Compare commits

..

3 Commits

Author SHA1 Message Date
Paul Masurel
d7973892a2 extra commit 2019-12-27 22:53:04 +09:00
Paul Masurel
cd7484c035 Added ReadOnlyDirectory and implemented Bundle Directory 2019-12-27 12:05:39 +09:00
Paul Masurel
7ed6bc8718 Added serialize to bundle in the RAMDirectory. 2019-12-26 10:06:52 +09:00
135 changed files with 3341 additions and 4388 deletions

View File

@@ -1,38 +1,3 @@
Tantivy 0.13.0
======================
- Bugfix in `FuzzyTermQuery` not matching terms by prefix when it should (@Peachball)
- Relaxed constraints on the custom/tweak score functions. At the segment level, they can be mut, and they are not required to be Sync + Send.
- `MMapDirectory::open` does not return a `Result` anymore.
- Change in the DocSet and Scorer API. (@fulmicoton).
A freshly created DocSet point directly to their first doc. A sentinel value called TERMINATED marks the end of a DocSet.
`.advance()` returns the new DocId. `Scorer::skip(target)` has been replaced by `Scorer::seek(target)` and returns the resulting DocId.
As a result, iterating through DocSet now looks as follows
```rust
let mut doc = docset.doc();
while doc != TERMINATED {
// ...
doc = docset.advance();
}
```
The change made it possible to greatly simplify a lot of the docset's code.
- Misc internal optimization and introduction of the `Scorer::for_each_pruning` function. (@fulmicoton)
- Added an offset option to the Top(.*)Collectors. (@robyoung)
Tantivy 0.12.0
======================
- Removing static dispatch in tokenizers for simplicity. (#762)
- Added backward iteration for `TermDictionary` stream. (@halvorboe)
- Fixed a performance issue when searching for the posting lists of a missing term (@audunhalland)
- Added a configurable maximum number of docs (10M by default) for a segment to be considered for merge (@hntd187, landed by @halvorboe #713)
- Important Bugfix #777, causing tantivy to retain memory mapping. (diagnosed by @poljar)
- Added support for field boosting. (#547, @fulmicoton)
## How to update?
Crates relying on custom tokenizer, or registering tokenizer in the manager will require some
minor changes. Check https://github.com/tantivy-search/tantivy/blob/master/examples/custom_tokenizer.rs
to check for some code sample.
Tantivy 0.11.3 Tantivy 0.11.3
======================= =======================
- Fixed DateTime as a fast field (#735) - Fixed DateTime as a fast field (#735)

View File

@@ -1,11 +1,11 @@
[package] [package]
name = "tantivy" name = "tantivy"
version = "0.12.0" version = "0.11.3"
authors = ["Paul Masurel <paul.masurel@gmail.com>"] authors = ["Paul Masurel <paul.masurel@gmail.com>"]
license = "MIT" license = "MIT"
categories = ["database-implementations", "data-structures"] categories = ["database-implementations", "data-structures"]
description = """Search engine library""" description = """Search engine library"""
documentation = "https://docs.rs/tantivy/" documentation = "https://tantivy-search.github.io/tantivy/tantivy/index.html"
homepage = "https://github.com/tantivy-search/tantivy" homepage = "https://github.com/tantivy-search/tantivy"
repository = "https://github.com/tantivy-search/tantivy" repository = "https://github.com/tantivy-search/tantivy"
readme = "README.md" readme = "README.md"
@@ -13,23 +13,25 @@ keywords = ["search", "information", "retrieval"]
edition = "2018" edition = "2018"
[dependencies] [dependencies]
base64 = "0.12.0" base64 = "0.11.0"
byteorder = "1.0" byteorder = "1.0"
crc32fast = "1.2.0" crc32fast = "1.2.0"
once_cell = "1.0" once_cell = "1.0"
regex ={version = "1.3.0", default-features = false, features = ["std"]} regex ={version = "1.3.0", default-features = false, features = ["std"]}
tantivy-fst = "0.3" tantivy-fst = "0.1"
memmap = {version = "0.7", optional=true} memmap = {version = "0.7", optional=true}
lz4 = {version="1.20", optional=true} lz4 = {version="1.20", optional=true}
snap = "1" snap = {version="0.2"}
atomicwrites = {version="0.2.2", optional=true} atomicwrites = {version="0.2.2", optional=true}
tempfile = "3.0" tempfile = "3.0"
log = "0.4" log = "0.4"
serde = {version="1.0", features=["derive"]} serde = "1.0"
serde_derive = "1.0"
serde_json = "1.0" serde_json = "1.0"
num_cpus = "1.2" num_cpus = "1.2"
fs2={version="0.4", optional=true} fs2={version="0.4", optional=true}
levenshtein_automata = "0.2" itertools = "0.8"
levenshtein_automata = {version="0.1", features=["fst_automaton"]}
notify = {version="4", optional=true} notify = {version="4", optional=true}
uuid = { version = "0.8", features = ["v4", "serde"] } uuid = { version = "0.8", features = ["v4", "serde"] }
crossbeam = "0.7" crossbeam = "0.7"
@@ -38,14 +40,14 @@ owning_ref = "0.4"
stable_deref_trait = "1.0.0" stable_deref_trait = "1.0.0"
rust-stemmers = "1.2" rust-stemmers = "1.2"
downcast-rs = { version="1.0" } downcast-rs = { version="1.0" }
tantivy-query-grammar = { version="0.13", path="./query-grammar" } tantivy-query-grammar = { version="0.11", path="./query-grammar" }
bitpacking = {version="0.8", default-features = false, features=["bitpacker4x"]} bitpacking = {version="0.8", default-features = false, features=["bitpacker4x"]}
census = "0.4" census = "0.4"
fnv = "1.0.6" fnv = "1.0.6"
owned-read = "0.4" owned-read = "0.4"
failure = "0.1" failure = "0.1"
htmlescape = "0.3.1" htmlescape = "0.3.1"
fail = "0.4" fail = "0.3"
murmurhash32 = "0.2" murmurhash32 = "0.2"
chrono = "0.4" chrono = "0.4"
smallvec = "1.0" smallvec = "1.0"
@@ -58,9 +60,10 @@ winapi = "0.3"
rand = "0.7" rand = "0.7"
maplit = "1" maplit = "1"
matches = "0.1.8" matches = "0.1.8"
time = "0.1.42"
[dev-dependencies.fail] [dev-dependencies.fail]
version = "0.4" version = "0.3"
features = ["failpoints"] features = ["failpoints"]
[profile.release] [profile.release]

View File

@@ -31,20 +31,16 @@ Tantivy is, in fact, strongly inspired by Lucene's design.
# Benchmark # Benchmark
Tantivy is typically faster than Lucene, but the results depend on
the nature of the queries in your workload.
The following [benchmark](https://tantivy-search.github.io/bench/) break downs The following [benchmark](https://tantivy-search.github.io/bench/) break downs
performance for different type of queries / collection. performance for different type of queries / collection.
In general, Tantivy tends to be
- slower than Lucene on union with a Top-K due to Block-WAND optimization.
- faster than Lucene on intersection and phrase queries.
Your mileage WILL vary depending on the nature of queries and their load.
# Features # Features
- Full-text search - Full-text search
- Configurable tokenizer (stemming available for 17 Latin languages with third party support for Chinese ([tantivy-jieba](https://crates.io/crates/tantivy-jieba) and [cang-jie](https://crates.io/crates/cang-jie)), Japanese ([lindera](https://github.com/lindera-morphology/lindera-tantivy) and [tantivy-tokenizer-tiny-segmente](https://crates.io/crates/tantivy-tokenizer-tiny-segmenter)) and Korean ([lindera](https://github.com/lindera-morphology/lindera-tantivy) + [lindera-ko-dic-builder](https://github.com/lindera-morphology/lindera-ko-dic-builder)) - Configurable tokenizer (stemming available for 17 Latin languages with third party support for Chinese ([tantivy-jieba](https://crates.io/crates/tantivy-jieba) and [cang-jie](https://crates.io/crates/cang-jie)) and [Japanese](https://crates.io/crates/tantivy-tokenizer-tiny-segmenter))
- Fast (check out the :racehorse: :sparkles: [benchmark](https://tantivy-search.github.io/bench/) :sparkles: :racehorse:) - Fast (check out the :racehorse: :sparkles: [benchmark](https://tantivy-search.github.io/bench/) :sparkles: :racehorse:)
- Tiny startup time (<10ms), perfect for command line tools - Tiny startup time (<10ms), perfect for command line tools
- BM25 scoring (the same as Lucene) - BM25 scoring (the same as Lucene)
@@ -63,17 +59,18 @@ Your mileage WILL vary depending on the nature of queries and their load.
- Configurable indexing (optional term frequency and position indexing) - Configurable indexing (optional term frequency and position indexing)
- Cheesy logo with a horse - Cheesy logo with a horse
## Non-features # Non-features
- Distributed search is out of the scope of Tantivy. That being said, Tantivy is a - Distributed search is out of the scope of Tantivy. That being said, Tantivy is a
library upon which one could build a distributed search. Serializable/mergeable collector state for instance, library upon which one could build a distributed search. Serializable/mergeable collector state for instance,
are within the scope of Tantivy. are within the scope of Tantivy.
# Supported OS and compiler
# Getting started
Tantivy works on stable Rust (>= 1.27) and supports Linux, MacOS, and Windows. Tantivy works on stable Rust (>= 1.27) and supports Linux, MacOS, and Windows.
# Getting started
- [Tantivy's simple search example](https://tantivy-search.github.io/examples/basic_search.html) - [Tantivy's simple search example](https://tantivy-search.github.io/examples/basic_search.html)
- [tantivy-cli and its tutorial](https://github.com/tantivy-search/tantivy-cli) - `tantivy-cli` is an actual command line interface that makes it easy for you to create a search engine, - [tantivy-cli and its tutorial](https://github.com/tantivy-search/tantivy-cli) - `tantivy-cli` is an actual command line interface that makes it easy for you to create a search engine,
index documents, and search via the CLI or a small server with a REST API. index documents, and search via the CLI or a small server with a REST API.

View File

@@ -18,5 +18,5 @@ install:
build: false build: false
test_script: test_script:
- REM SET RUST_LOG=tantivy,test & cargo test --all --verbose --no-default-features --features mmap - REM SET RUST_LOG=tantivy,test & cargo test --verbose --no-default-features --features mmap
- REM SET RUST_BACKTRACE=1 & cargo build --examples - REM SET RUST_BACKTRACE=1 & cargo build --examples

View File

@@ -1,98 +0,0 @@
use std::collections::HashSet;
use tantivy::collector::TopDocs;
use tantivy::doc;
use tantivy::query::BooleanQuery;
use tantivy::schema::*;
use tantivy::{DocId, Index, Score, SegmentReader};
fn main() -> tantivy::Result<()> {
let mut schema_builder = Schema::builder();
let title = schema_builder.add_text_field("title", STORED);
let ingredient = schema_builder.add_facet_field("ingredient");
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
let mut index_writer = index.writer(30_000_000)?;
index_writer.add_document(doc!(
title => "Fried egg",
ingredient => Facet::from("/ingredient/egg"),
ingredient => Facet::from("/ingredient/oil"),
));
index_writer.add_document(doc!(
title => "Scrambled egg",
ingredient => Facet::from("/ingredient/egg"),
ingredient => Facet::from("/ingredient/butter"),
ingredient => Facet::from("/ingredient/milk"),
ingredient => Facet::from("/ingredient/salt"),
));
index_writer.add_document(doc!(
title => "Egg rolls",
ingredient => Facet::from("/ingredient/egg"),
ingredient => Facet::from("/ingredient/garlic"),
ingredient => Facet::from("/ingredient/salt"),
ingredient => Facet::from("/ingredient/oil"),
ingredient => Facet::from("/ingredient/tortilla-wrap"),
ingredient => Facet::from("/ingredient/mushroom"),
));
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
{
let facets = vec![
Facet::from("/ingredient/egg"),
Facet::from("/ingredient/oil"),
Facet::from("/ingredient/garlic"),
Facet::from("/ingredient/mushroom"),
];
let query = BooleanQuery::new_multiterms_query(
facets
.iter()
.map(|key| Term::from_facet(ingredient, &key))
.collect(),
);
let top_docs_by_custom_score =
TopDocs::with_limit(2).tweak_score(move |segment_reader: &SegmentReader| {
let mut ingredient_reader = segment_reader.facet_reader(ingredient).unwrap();
let facet_dict = ingredient_reader.facet_dict();
let query_ords: HashSet<u64> = facets
.iter()
.filter_map(|key| facet_dict.term_ord(key.encoded_str()))
.collect();
let mut facet_ords_buffer: Vec<u64> = Vec::with_capacity(20);
move |doc: DocId, original_score: Score| {
ingredient_reader.facet_ords(doc, &mut facet_ords_buffer);
let missing_ingredients = facet_ords_buffer
.iter()
.filter(|ord| !query_ords.contains(ord))
.count();
let tweak = 1.0 / 4_f32.powi(missing_ingredients as i32);
original_score * tweak
}
});
let top_docs = searcher.search(&query, &top_docs_by_custom_score)?;
let titles: Vec<String> = top_docs
.iter()
.map(|(_, doc_id)| {
searcher
.doc(*doc_id)
.unwrap()
.get_first(title)
.unwrap()
.text()
.unwrap()
.to_owned()
})
.collect();
assert_eq!(titles, vec!["Fried egg", "Egg rolls"]);
}
Ok(())
}

View File

@@ -10,7 +10,7 @@
// --- // ---
// Importing tantivy... // Importing tantivy...
use tantivy::schema::*; use tantivy::schema::*;
use tantivy::{doc, DocSet, Index, Postings, TERMINATED}; use tantivy::{doc, DocId, DocSet, Index, Postings};
fn main() -> tantivy::Result<()> { fn main() -> tantivy::Result<()> {
// We first create a schema for the sake of the // We first create a schema for the sake of the
@@ -62,11 +62,12 @@ fn main() -> tantivy::Result<()> {
{ {
// this buffer will be used to request for positions // this buffer will be used to request for positions
let mut positions: Vec<u32> = Vec::with_capacity(100); let mut positions: Vec<u32> = Vec::with_capacity(100);
let mut doc_id = segment_postings.doc(); while segment_postings.advance() {
while doc_id != TERMINATED { // the number of time the term appears in the document.
let doc_id: DocId = segment_postings.doc(); //< do not try to access this before calling advance once.
// This MAY contains deleted documents as well. // This MAY contains deleted documents as well.
if segment_reader.is_deleted(doc_id) { if segment_reader.is_deleted(doc_id) {
doc_id = segment_postings.advance();
continue; continue;
} }
@@ -85,7 +86,6 @@ fn main() -> tantivy::Result<()> {
// Doc 2: TermFreq 1: [0] // Doc 2: TermFreq 1: [0]
// ``` // ```
println!("Doc {}: TermFreq {}: {:?}", doc_id, term_freq, positions); println!("Doc {}: TermFreq {}: {:?}", doc_id, term_freq, positions);
doc_id = segment_postings.advance();
} }
} }
} }

View File

@@ -9,10 +9,11 @@
// - import tokenized text straight from json, // - import tokenized text straight from json,
// - perform a search on documents with pre-tokenized text // - perform a search on documents with pre-tokenized text
use tantivy::tokenizer::{PreTokenizedString, SimpleTokenizer, Token, TokenStream, Tokenizer};
use tantivy::collector::{Count, TopDocs}; use tantivy::collector::{Count, TopDocs};
use tantivy::query::TermQuery; use tantivy::query::TermQuery;
use tantivy::schema::*; use tantivy::schema::*;
use tantivy::tokenizer::{PreTokenizedString, SimpleTokenizer, Token, Tokenizer};
use tantivy::{doc, Index, ReloadPolicy}; use tantivy::{doc, Index, ReloadPolicy};
use tempfile::TempDir; use tempfile::TempDir;

View File

@@ -50,7 +50,7 @@ fn main() -> tantivy::Result<()> {
// This tokenizer lowers all of the text (to help with stop word matching) // This tokenizer lowers all of the text (to help with stop word matching)
// then removes all instances of `the` and `and` from the corpus // then removes all instances of `the` and `and` from the corpus
let tokenizer = TextAnalyzer::from(SimpleTokenizer) let tokenizer = SimpleTokenizer
.filter(LowerCaser) .filter(LowerCaser)
.filter(StopWordFilter::remove(vec![ .filter(StopWordFilter::remove(vec![
"the".to_string(), "the".to_string(),

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "tantivy-query-grammar" name = "tantivy-query-grammar"
version = "0.13.0" version = "0.11.0"
authors = ["Paul Masurel <paul.masurel@gmail.com>"] authors = ["Paul Masurel <paul.masurel@gmail.com>"]
license = "MIT" license = "MIT"
categories = ["database-implementations", "data-structures"] categories = ["database-implementations", "data-structures"]
@@ -13,4 +13,4 @@ keywords = ["search", "information", "retrieval"]
edition = "2018" edition = "2018"
[dependencies] [dependencies]
combine = {version="4", default-features=false, features=[] } combine = ">=3.6.0,<4.0.0"

View File

@@ -1,3 +1,5 @@
#![recursion_limit = "100"]
mod occur; mod occur;
mod query_grammar; mod query_grammar;
mod user_input_ast; mod user_input_ast;

View File

@@ -1,209 +1,171 @@
use super::user_input_ast::{UserInputAST, UserInputBound, UserInputLeaf, UserInputLiteral}; use super::user_input_ast::*;
use crate::Occur; use crate::Occur;
use combine::error::StringStreamError; use combine::char::*;
use combine::parser::char::{char, digit, letter, space, spaces, string}; use combine::error::StreamError;
use combine::parser::Parser; use combine::stream::StreamErrorFor;
use combine::{ use combine::*;
attempt, choice, eof, many, many1, one_of, optional, parser, satisfy, skip_many1, value,
};
fn field<'a>() -> impl Parser<&'a str, Output = String> { parser! {
( fn field[I]()(I) -> String
letter(), where [I: Stream<Item = char>] {
many(satisfy(|c: char| c.is_alphanumeric() || c == '_')), (
) letter(),
.skip(char(':')) many(satisfy(|c: char| c.is_alphanumeric() || c == '_')),
).skip(char(':')).map(|(s1, s2): (char, String)| format!("{}{}", s1, s2))
}
}
parser! {
fn word[I]()(I) -> String
where [I: Stream<Item = char>] {
(
satisfy(|c: char| !c.is_whitespace() && !['-', '`', ':', '{', '}', '"', '[', ']', '(',')'].contains(&c) ),
many(satisfy(|c: char| !c.is_whitespace() && ![':', '{', '}', '"', '[', ']', '(',')'].contains(&c)))
)
.map(|(s1, s2): (char, String)| format!("{}{}", s1, s2)) .map(|(s1, s2): (char, String)| format!("{}{}", s1, s2))
.and_then(|s: String|
match s.as_str() {
"OR" => Err(StreamErrorFor::<I>::unexpected_static_message("OR")),
"AND" => Err(StreamErrorFor::<I>::unexpected_static_message("AND")),
"NOT" => Err(StreamErrorFor::<I>::unexpected_static_message("NOT")),
_ => Ok(s)
})
}
} }
fn word<'a>() -> impl Parser<&'a str, Output = String> { parser! {
( fn literal[I]()(I) -> UserInputLeaf
satisfy(|c: char| { where [I: Stream<Item = char>]
!c.is_whitespace() {
&& !['-', '^', '`', ':', '{', '}', '"', '[', ']', '(', ')'].contains(&c) let term_val = || {
}), let phrase = char('"').with(many1(satisfy(|c| c != '"'))).skip(char('"'));
many(satisfy(|c: char| { phrase.or(word())
!c.is_whitespace() && ![':', '^', '{', '}', '"', '[', ']', '(', ')'].contains(&c) };
})), let term_val_with_field = negative_number().or(term_val());
) let term_query =
.map(|(s1, s2): (char, String)| format!("{}{}", s1, s2)) (field(), term_val_with_field)
.and_then(|s: String| match s.as_str() { .map(|(field_name, phrase)| UserInputLiteral {
"OR" | "AND " | "NOT" => Err(StringStreamError::UnexpectedParse), field_name: Some(field_name),
_ => Ok(s), phrase,
});
let term_default_field = term_val().map(|phrase| UserInputLiteral {
field_name: None,
phrase,
});
attempt(term_query)
.or(term_default_field)
.map(UserInputLeaf::from)
}
}
parser! {
fn negative_number[I]()(I) -> String
where [I: Stream<Item = char>]
{
(char('-'), many1(satisfy(char::is_numeric)),
optional((char('.'), many1(satisfy(char::is_numeric)))))
.map(|(s1, s2, s3): (char, String, Option<(char, String)>)| {
if let Some(('.', s3)) = s3 {
format!("{}{}.{}", s1, s2, s3)
} else {
format!("{}{}", s1, s2)
}
})
}
}
parser! {
fn spaces1[I]()(I) -> ()
where [I: Stream<Item = char>] {
skip_many1(space())
}
}
parser! {
/// Function that parses a range out of a Stream
/// Supports ranges like:
/// [5 TO 10], {5 TO 10}, [* TO 10], [10 TO *], {10 TO *], >5, <=10
/// [a TO *], [a TO c], [abc TO bcd}
fn range[I]()(I) -> UserInputLeaf
where [I: Stream<Item = char>] {
let range_term_val = || {
word().or(negative_number()).or(char('*').with(value("*".to_string())))
};
// check for unbounded range in the form of <5, <=10, >5, >=5
let elastic_unbounded_range = (choice([attempt(string(">=")),
attempt(string("<=")),
attempt(string("<")),
attempt(string(">"))])
.skip(spaces()),
range_term_val()).
map(|(comparison_sign, bound): (&str, String)|
match comparison_sign {
">=" => (UserInputBound::Inclusive(bound), UserInputBound::Unbounded),
"<=" => (UserInputBound::Unbounded, UserInputBound::Inclusive(bound)),
"<" => (UserInputBound::Unbounded, UserInputBound::Exclusive(bound)),
">" => (UserInputBound::Exclusive(bound), UserInputBound::Unbounded),
// default case
_ => (UserInputBound::Unbounded, UserInputBound::Unbounded)
});
let lower_bound = (one_of("{[".chars()), range_term_val())
.map(|(boundary_char, lower_bound): (char, String)|
if lower_bound == "*" {
UserInputBound::Unbounded
} else if boundary_char == '{' {
UserInputBound::Exclusive(lower_bound)
} else {
UserInputBound::Inclusive(lower_bound)
});
let upper_bound = (range_term_val(), one_of("}]".chars()))
.map(|(higher_bound, boundary_char): (String, char)|
if higher_bound == "*" {
UserInputBound::Unbounded
} else if boundary_char == '}' {
UserInputBound::Exclusive(higher_bound)
} else {
UserInputBound::Inclusive(higher_bound)
});
// return only lower and upper
let lower_to_upper = (lower_bound.
skip((spaces(),
string("TO"),
spaces())),
upper_bound);
(optional(field()).skip(spaces()),
// try elastic first, if it matches, the range is unbounded
attempt(elastic_unbounded_range).or(lower_to_upper))
.map(|(field, (lower, upper))|
// Construct the leaf from extracted field (optional)
// and bounds
UserInputLeaf::Range {
field,
lower,
upper
}) })
} }
fn term_val<'a>() -> impl Parser<&'a str, Output = String> {
let phrase = char('"').with(many1(satisfy(|c| c != '"'))).skip(char('"'));
phrase.or(word())
}
fn term_query<'a>() -> impl Parser<&'a str, Output = UserInputLiteral> {
let term_val_with_field = negative_number().or(term_val());
(field(), term_val_with_field).map(|(field_name, phrase)| UserInputLiteral {
field_name: Some(field_name),
phrase,
})
}
fn literal<'a>() -> impl Parser<&'a str, Output = UserInputLeaf> {
let term_default_field = term_val().map(|phrase| UserInputLiteral {
field_name: None,
phrase,
});
attempt(term_query())
.or(term_default_field)
.map(UserInputLeaf::from)
}
fn negative_number<'a>() -> impl Parser<&'a str, Output = String> {
(
char('-'),
many1(digit()),
optional((char('.'), many1(digit()))),
)
.map(|(s1, s2, s3): (char, String, Option<(char, String)>)| {
if let Some(('.', s3)) = s3 {
format!("{}{}.{}", s1, s2, s3)
} else {
format!("{}{}", s1, s2)
}
})
}
fn spaces1<'a>() -> impl Parser<&'a str, Output = ()> {
skip_many1(space())
}
/// Function that parses a range out of a Stream
/// Supports ranges like:
/// [5 TO 10], {5 TO 10}, [* TO 10], [10 TO *], {10 TO *], >5, <=10
/// [a TO *], [a TO c], [abc TO bcd}
fn range<'a>() -> impl Parser<&'a str, Output = UserInputLeaf> {
let range_term_val = || {
word()
.or(negative_number())
.or(char('*').with(value("*".to_string())))
};
// check for unbounded range in the form of <5, <=10, >5, >=5
let elastic_unbounded_range = (
choice([
attempt(string(">=")),
attempt(string("<=")),
attempt(string("<")),
attempt(string(">")),
])
.skip(spaces()),
range_term_val(),
)
.map(
|(comparison_sign, bound): (&str, String)| match comparison_sign {
">=" => (UserInputBound::Inclusive(bound), UserInputBound::Unbounded),
"<=" => (UserInputBound::Unbounded, UserInputBound::Inclusive(bound)),
"<" => (UserInputBound::Unbounded, UserInputBound::Exclusive(bound)),
">" => (UserInputBound::Exclusive(bound), UserInputBound::Unbounded),
// default case
_ => (UserInputBound::Unbounded, UserInputBound::Unbounded),
},
);
let lower_bound = (one_of("{[".chars()), range_term_val()).map(
|(boundary_char, lower_bound): (char, String)| {
if lower_bound == "*" {
UserInputBound::Unbounded
} else if boundary_char == '{' {
UserInputBound::Exclusive(lower_bound)
} else {
UserInputBound::Inclusive(lower_bound)
}
},
);
let upper_bound = (range_term_val(), one_of("}]".chars())).map(
|(higher_bound, boundary_char): (String, char)| {
if higher_bound == "*" {
UserInputBound::Unbounded
} else if boundary_char == '}' {
UserInputBound::Exclusive(higher_bound)
} else {
UserInputBound::Inclusive(higher_bound)
}
},
);
// return only lower and upper
let lower_to_upper = (
lower_bound.skip((spaces(), string("TO"), spaces())),
upper_bound,
);
(
optional(field()).skip(spaces()),
// try elastic first, if it matches, the range is unbounded
attempt(elastic_unbounded_range).or(lower_to_upper),
)
.map(|(field, (lower, upper))|
// Construct the leaf from extracted field (optional)
// and bounds
UserInputLeaf::Range {
field,
lower,
upper
})
} }
fn negate(expr: UserInputAST) -> UserInputAST { fn negate(expr: UserInputAST) -> UserInputAST {
expr.unary(Occur::MustNot) expr.unary(Occur::MustNot)
} }
fn leaf<'a>() -> impl Parser<&'a str, Output = UserInputAST> { fn must(expr: UserInputAST) -> UserInputAST {
parser(|input| { expr.unary(Occur::Must)
char('(')
.with(ast())
.skip(char(')'))
.or(char('*').map(|_| UserInputAST::from(UserInputLeaf::All)))
.or(attempt(
string("NOT").skip(spaces1()).with(leaf()).map(negate),
))
.or(attempt(range().map(UserInputAST::from)))
.or(literal().map(UserInputAST::from))
.parse_stream(input)
.into_result()
})
} }
fn occur_symbol<'a>() -> impl Parser<&'a str, Output = Occur> { parser! {
char('-') fn leaf[I]()(I) -> UserInputAST
.map(|_| Occur::MustNot) where [I: Stream<Item = char>] {
.or(char('+').map(|_| Occur::Must)) char('-').with(leaf()).map(negate)
} .or(char('+').with(leaf()).map(must))
.or(char('(').with(ast()).skip(char(')')))
fn occur_leaf<'a>() -> impl Parser<&'a str, Output = (Option<Occur>, UserInputAST)> { .or(char('*').map(|_| UserInputAST::from(UserInputLeaf::All)))
(optional(occur_symbol()), boosted_leaf()) .or(attempt(string("NOT").skip(spaces1()).with(leaf()).map(negate)))
} .or(attempt(range().map(UserInputAST::from)))
.or(literal().map(UserInputAST::from))
fn positive_float_number<'a>() -> impl Parser<&'a str, Output = f32> { }
(many1(digit()), optional((char('.'), many1(digit())))).map(
|(int_part, decimal_part_opt): (String, Option<(char, String)>)| {
let mut float_str = int_part;
if let Some((chr, decimal_str)) = decimal_part_opt {
float_str.push(chr);
float_str.push_str(&decimal_str);
}
float_str.parse::<f32>().unwrap()
},
)
}
fn boost<'a>() -> impl Parser<&'a str, Output = f32> {
(char('^'), positive_float_number()).map(|(_, boost)| boost)
}
fn boosted_leaf<'a>() -> impl Parser<&'a str, Output = UserInputAST> {
(leaf(), optional(boost())).map(|(leaf, boost_opt)| match boost_opt {
Some(boost) if (boost - 1.0).abs() > std::f32::EPSILON => {
UserInputAST::Boost(Box::new(leaf), boost)
}
_ => leaf,
})
} }
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
@@ -212,10 +174,13 @@ enum BinaryOperand {
And, And,
} }
fn binary_operand<'a>() -> impl Parser<&'a str, Output = BinaryOperand> { parser! {
string("AND") fn binary_operand[I]()(I) -> BinaryOperand
.with(value(BinaryOperand::And)) where [I: Stream<Item = char>]
.or(string("OR").with(value(BinaryOperand::Or))) {
string("AND").with(value(BinaryOperand::And))
.or(string("OR").with(value(BinaryOperand::Or)))
}
} }
fn aggregate_binary_expressions( fn aggregate_binary_expressions(
@@ -243,81 +208,37 @@ fn aggregate_binary_expressions(
} }
} }
fn operand_leaf<'a>() -> impl Parser<&'a str, Output = (BinaryOperand, UserInputAST)> { parser! {
( pub fn ast[I]()(I) -> UserInputAST
binary_operand().skip(spaces()), where [I: Stream<Item = char>]
boosted_leaf().skip(spaces()), {
) let operand_leaf = (binary_operand().skip(spaces()), leaf().skip(spaces()));
} let boolean_expr = (leaf().skip(spaces().silent()), many1(operand_leaf)).map(
|(left, right)| aggregate_binary_expressions(left,right));
pub fn ast<'a>() -> impl Parser<&'a str, Output = UserInputAST> { let whitespace_separated_leaves = many1(leaf().skip(spaces().silent()))
let boolean_expr = (boosted_leaf().skip(spaces()), many1(operand_leaf())) .map(|subqueries: Vec<UserInputAST>|
.map(|(left, right)| aggregate_binary_expressions(left, right));
let whitespace_separated_leaves = many1(occur_leaf().skip(spaces().silent())).map(
|subqueries: Vec<(Option<Occur>, UserInputAST)>| {
if subqueries.len() == 1 { if subqueries.len() == 1 {
let (occur_opt, ast) = subqueries.into_iter().next().unwrap(); subqueries.into_iter().next().unwrap()
match occur_opt.unwrap_or(Occur::Should) {
Occur::Must | Occur::Should => ast,
Occur::MustNot => UserInputAST::Clause(vec![(Some(Occur::MustNot), ast)]),
}
} else { } else {
UserInputAST::Clause(subqueries.into_iter().collect()) UserInputAST::Clause(subqueries.into_iter().collect())
} });
}, let expr = attempt(boolean_expr).or(whitespace_separated_leaves);
); spaces().with(expr).skip(spaces())
let expr = attempt(boolean_expr).or(whitespace_separated_leaves); }
spaces().with(expr).skip(spaces())
} }
pub fn parse_to_ast<'a>() -> impl Parser<&'a str, Output = UserInputAST> { parser! {
spaces() pub fn parse_to_ast[I]()(I) -> UserInputAST
.with(optional(ast()).skip(eof())) where [I: Stream<Item = char>]
.map(|opt_ast| opt_ast.unwrap_or_else(UserInputAST::empty_query)) {
spaces().with(optional(ast()).skip(eof())).map(|opt_ast| opt_ast.unwrap_or_else(UserInputAST::empty_query))
}
} }
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::*; use super::*;
use combine::parser::Parser;
pub fn nearly_equals(a: f32, b: f32) -> bool {
(a - b).abs() < 0.0005 * (a + b).abs()
}
fn assert_nearly_equals(expected: f32, val: f32) {
assert!(
nearly_equals(val, expected),
"Got {}, expected {}.",
val,
expected
);
}
#[test]
fn test_occur_symbol() {
assert_eq!(super::occur_symbol().parse("-"), Ok((Occur::MustNot, "")));
assert_eq!(super::occur_symbol().parse("+"), Ok((Occur::Must, "")));
}
#[test]
fn test_positive_float_number() {
fn valid_parse(float_str: &str, expected_val: f32, expected_remaining: &str) {
let (val, remaining) = positive_float_number().parse(float_str).unwrap();
assert_eq!(remaining, expected_remaining);
assert_nearly_equals(val, expected_val);
}
fn error_parse(float_str: &str) {
assert!(positive_float_number().parse(float_str).is_err());
}
valid_parse("1.0", 1.0f32, "");
valid_parse("1", 1.0f32, "");
valid_parse("0.234234 aaa", 0.234234f32, " aaa");
error_parse(".3332");
error_parse("1.");
error_parse("-1.");
}
fn test_parse_query_to_ast_helper(query: &str, expected: &str) { fn test_parse_query_to_ast_helper(query: &str, expected: &str) {
let query = parse_to_ast().parse(query).unwrap().0; let query = parse_to_ast().parse(query).unwrap().0;
@@ -348,24 +269,15 @@ mod test {
"Err(UnexpectedParse)" "Err(UnexpectedParse)"
); );
test_parse_query_to_ast_helper("NOTa", "\"NOTa\""); test_parse_query_to_ast_helper("NOTa", "\"NOTa\"");
test_parse_query_to_ast_helper("NOT a", "(-\"a\")"); test_parse_query_to_ast_helper("NOT a", "-(\"a\")");
}
#[test]
fn test_boosting() {
assert!(parse_to_ast().parse("a^2^3").is_err());
assert!(parse_to_ast().parse("a^2^").is_err());
test_parse_query_to_ast_helper("a^3", "(\"a\")^3");
test_parse_query_to_ast_helper("a^3 b^2", "(*(\"a\")^3 *(\"b\")^2)");
test_parse_query_to_ast_helper("a^1", "\"a\"");
} }
#[test] #[test]
fn test_parse_query_to_ast_binary_op() { fn test_parse_query_to_ast_binary_op() {
test_parse_query_to_ast_helper("a AND b", "(+\"a\" +\"b\")"); test_parse_query_to_ast_helper("a AND b", "(+(\"a\") +(\"b\"))");
test_parse_query_to_ast_helper("a OR b", "(?\"a\" ?\"b\")"); test_parse_query_to_ast_helper("a OR b", "(?(\"a\") ?(\"b\"))");
test_parse_query_to_ast_helper("a OR b AND c", "(?\"a\" ?(+\"b\" +\"c\"))"); test_parse_query_to_ast_helper("a OR b AND c", "(?(\"a\") ?((+(\"b\") +(\"c\"))))");
test_parse_query_to_ast_helper("a AND b AND c", "(+\"a\" +\"b\" +\"c\")"); test_parse_query_to_ast_helper("a AND b AND c", "(+(\"a\") +(\"b\") +(\"c\"))");
assert_eq!( assert_eq!(
format!("{:?}", parse_to_ast().parse("a OR b aaa")), format!("{:?}", parse_to_ast().parse("a OR b aaa")),
"Err(UnexpectedParse)" "Err(UnexpectedParse)"
@@ -403,13 +315,6 @@ mod test {
test_parse_query_to_ast_helper("weight: <= 70.5", "weight:{\"*\" TO \"70.5\"]"); test_parse_query_to_ast_helper("weight: <= 70.5", "weight:{\"*\" TO \"70.5\"]");
} }
#[test]
fn test_occur_leaf() {
let ((occur, ast), _) = super::occur_leaf().parse("+abc").unwrap();
assert_eq!(occur, Some(Occur::Must));
assert_eq!(format!("{:?}", ast), "\"abc\"");
}
#[test] #[test]
fn test_range_parser() { fn test_range_parser() {
// testing the range() parser separately // testing the range() parser separately
@@ -438,67 +343,32 @@ mod test {
fn test_parse_query_to_triming_spaces() { fn test_parse_query_to_triming_spaces() {
test_parse_query_to_ast_helper(" abc", "\"abc\""); test_parse_query_to_ast_helper(" abc", "\"abc\"");
test_parse_query_to_ast_helper("abc ", "\"abc\""); test_parse_query_to_ast_helper("abc ", "\"abc\"");
test_parse_query_to_ast_helper("( a OR abc)", "(?\"a\" ?\"abc\")"); test_parse_query_to_ast_helper("( a OR abc)", "(?(\"a\") ?(\"abc\"))");
test_parse_query_to_ast_helper("(a OR abc)", "(?\"a\" ?\"abc\")"); test_parse_query_to_ast_helper("(a OR abc)", "(?(\"a\") ?(\"abc\"))");
test_parse_query_to_ast_helper("(a OR abc)", "(?\"a\" ?\"abc\")"); test_parse_query_to_ast_helper("(a OR abc)", "(?(\"a\") ?(\"abc\"))");
test_parse_query_to_ast_helper("a OR abc ", "(?\"a\" ?\"abc\")"); test_parse_query_to_ast_helper("a OR abc ", "(?(\"a\") ?(\"abc\"))");
test_parse_query_to_ast_helper("(a OR abc )", "(?\"a\" ?\"abc\")"); test_parse_query_to_ast_helper("(a OR abc )", "(?(\"a\") ?(\"abc\"))");
test_parse_query_to_ast_helper("(a OR abc) ", "(?\"a\" ?\"abc\")"); test_parse_query_to_ast_helper("(a OR abc) ", "(?(\"a\") ?(\"abc\"))");
} }
#[test] #[test]
fn test_parse_query_single_term() { fn test_parse_query_to_ast() {
test_parse_query_to_ast_helper("abc", "\"abc\""); test_parse_query_to_ast_helper("abc", "\"abc\"");
} test_parse_query_to_ast_helper("a b", "(\"a\" \"b\")");
test_parse_query_to_ast_helper("+(a b)", "+((\"a\" \"b\"))");
#[test] test_parse_query_to_ast_helper("+d", "+(\"d\")");
fn test_parse_query_default_clause() { test_parse_query_to_ast_helper("+(a b) +d", "(+((\"a\" \"b\")) +(\"d\"))");
test_parse_query_to_ast_helper("a b", "(*\"a\" *\"b\")"); test_parse_query_to_ast_helper("(+a +b) d", "((+(\"a\") +(\"b\")) \"d\")");
} test_parse_query_to_ast_helper("(+a)", "+(\"a\")");
test_parse_query_to_ast_helper("(+a +b)", "(+(\"a\") +(\"b\"))");
#[test]
fn test_parse_query_must_default_clause() {
test_parse_query_to_ast_helper("+(a b)", "(*\"a\" *\"b\")");
}
#[test]
fn test_parse_query_must_single_term() {
test_parse_query_to_ast_helper("+d", "\"d\"");
}
#[test]
fn test_single_term_with_field() {
test_parse_query_to_ast_helper("abc:toto", "abc:\"toto\""); test_parse_query_to_ast_helper("abc:toto", "abc:\"toto\"");
}
#[test]
fn test_single_term_with_float() {
test_parse_query_to_ast_helper("abc:1.1", "abc:\"1.1\""); test_parse_query_to_ast_helper("abc:1.1", "abc:\"1.1\"");
} test_parse_query_to_ast_helper("+abc:toto", "+(abc:\"toto\")");
test_parse_query_to_ast_helper("(+abc:toto -titi)", "(+(abc:\"toto\") -(\"titi\"))");
#[test] test_parse_query_to_ast_helper("-abc:toto", "-(abc:\"toto\")");
fn test_must_clause() { test_parse_query_to_ast_helper("abc:a b", "(abc:\"a\" \"b\")");
test_parse_query_to_ast_helper("(+a +b)", "(+\"a\" +\"b\")");
}
#[test]
fn test_parse_test_query_plus_a_b_plus_d() {
test_parse_query_to_ast_helper("+(a b) +d", "(+(*\"a\" *\"b\") +\"d\")");
}
#[test]
fn test_parse_test_query_other() {
test_parse_query_to_ast_helper("(+a +b) d", "(*(+\"a\" +\"b\") *\"d\")");
test_parse_query_to_ast_helper("+abc:toto", "abc:\"toto\"");
test_parse_query_to_ast_helper("(+abc:toto -titi)", "(+abc:\"toto\" -\"titi\")");
test_parse_query_to_ast_helper("-abc:toto", "(-abc:\"toto\")");
test_parse_query_to_ast_helper("abc:a b", "(*abc:\"a\" *\"b\")");
test_parse_query_to_ast_helper("abc:\"a b\"", "abc:\"a b\""); test_parse_query_to_ast_helper("abc:\"a b\"", "abc:\"a b\"");
test_parse_query_to_ast_helper("foo:[1 TO 5]", "foo:[\"1\" TO \"5\"]"); test_parse_query_to_ast_helper("foo:[1 TO 5]", "foo:[\"1\" TO \"5\"]");
}
#[test]
fn test_parse_query_with_range() {
test_parse_query_to_ast_helper("[1 TO 5]", "[\"1\" TO \"5\"]"); test_parse_query_to_ast_helper("[1 TO 5]", "[\"1\" TO \"5\"]");
test_parse_query_to_ast_helper("foo:{a TO z}", "foo:{\"a\" TO \"z\"}"); test_parse_query_to_ast_helper("foo:{a TO z}", "foo:{\"a\" TO \"z\"}");
test_parse_query_to_ast_helper("foo:[1 TO toto}", "foo:[\"1\" TO \"toto\"}"); test_parse_query_to_ast_helper("foo:[1 TO toto}", "foo:[\"1\" TO \"toto\"}");

View File

@@ -85,14 +85,14 @@ impl UserInputBound {
} }
pub enum UserInputAST { pub enum UserInputAST {
Clause(Vec<(Option<Occur>, UserInputAST)>), Clause(Vec<UserInputAST>),
Unary(Occur, Box<UserInputAST>),
Leaf(Box<UserInputLeaf>), Leaf(Box<UserInputLeaf>),
Boost(Box<UserInputAST>, f32),
} }
impl UserInputAST { impl UserInputAST {
pub fn unary(self, occur: Occur) -> UserInputAST { pub fn unary(self, occur: Occur) -> UserInputAST {
UserInputAST::Clause(vec![(Some(occur), self)]) UserInputAST::Unary(occur, Box::new(self))
} }
fn compose(occur: Occur, asts: Vec<UserInputAST>) -> UserInputAST { fn compose(occur: Occur, asts: Vec<UserInputAST>) -> UserInputAST {
@@ -103,7 +103,7 @@ impl UserInputAST {
} else { } else {
UserInputAST::Clause( UserInputAST::Clause(
asts.into_iter() asts.into_iter()
.map(|ast: UserInputAST| (Some(occur), ast)) .map(|ast: UserInputAST| ast.unary(occur))
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
) )
} }
@@ -134,38 +134,26 @@ impl From<UserInputLeaf> for UserInputAST {
} }
} }
fn print_occur_ast(
occur_opt: Option<Occur>,
ast: &UserInputAST,
formatter: &mut fmt::Formatter,
) -> fmt::Result {
if let Some(occur) = occur_opt {
write!(formatter, "{}{:?}", occur, ast)?;
} else {
write!(formatter, "*{:?}", ast)?;
}
Ok(())
}
impl fmt::Debug for UserInputAST { impl fmt::Debug for UserInputAST {
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
match *self { match *self {
UserInputAST::Clause(ref subqueries) => { UserInputAST::Clause(ref subqueries) => {
if subqueries.is_empty() { if subqueries.is_empty() {
write!(formatter, "<emptyclause>")?; write!(formatter, "<emptyclause>")?;
} else { } else {
write!(formatter, "(")?; write!(formatter, "(")?;
print_occur_ast(subqueries[0].0, &subqueries[0].1, formatter)?; write!(formatter, "{:?}", &subqueries[0])?;
for subquery in &subqueries[1..] { for subquery in &subqueries[1..] {
write!(formatter, " ")?; write!(formatter, " {:?}", subquery)?;
print_occur_ast(subquery.0, &subquery.1, formatter)?;
} }
write!(formatter, ")")?; write!(formatter, ")")?;
} }
Ok(()) Ok(())
} }
UserInputAST::Unary(ref occur, ref subquery) => {
write!(formatter, "{}({:?})", occur, subquery)
}
UserInputAST::Leaf(ref subquery) => write!(formatter, "{:?}", subquery), UserInputAST::Leaf(ref subquery) => write!(formatter, "{:?}", subquery),
UserInputAST::Boost(ref leaf, boost) => write!(formatter, "({:?})^{}", leaf, boost),
} }
} }
} }

View File

@@ -1,6 +1,7 @@
use super::Collector; use super::Collector;
use crate::collector::SegmentCollector; use crate::collector::SegmentCollector;
use crate::DocId; use crate::DocId;
use crate::Result;
use crate::Score; use crate::Score;
use crate::SegmentLocalId; use crate::SegmentLocalId;
use crate::SegmentReader; use crate::SegmentReader;
@@ -43,11 +44,7 @@ impl Collector for Count {
type Child = SegmentCountCollector; type Child = SegmentCountCollector;
fn for_segment( fn for_segment(&self, _: SegmentLocalId, _: &SegmentReader) -> Result<SegmentCountCollector> {
&self,
_: SegmentLocalId,
_: &SegmentReader,
) -> crate::Result<SegmentCountCollector> {
Ok(SegmentCountCollector::default()) Ok(SegmentCountCollector::default())
} }
@@ -55,7 +52,7 @@ impl Collector for Count {
false false
} }
fn merge_fruits(&self, segment_counts: Vec<usize>) -> crate::Result<usize> { fn merge_fruits(&self, segment_counts: Vec<usize>) -> Result<usize> {
Ok(segment_counts.into_iter().sum()) Ok(segment_counts.into_iter().sum())
} }
} }

View File

@@ -1,5 +1,6 @@
use crate::collector::top_collector::{TopCollector, TopSegmentCollector}; use crate::collector::top_collector::{TopCollector, TopSegmentCollector};
use crate::collector::{Collector, SegmentCollector}; use crate::collector::{Collector, SegmentCollector};
use crate::Result;
use crate::{DocAddress, DocId, Score, SegmentReader}; use crate::{DocAddress, DocId, Score, SegmentReader};
pub(crate) struct CustomScoreTopCollector<TCustomScorer, TScore = Score> { pub(crate) struct CustomScoreTopCollector<TCustomScorer, TScore = Score> {
@@ -11,13 +12,13 @@ impl<TCustomScorer, TScore> CustomScoreTopCollector<TCustomScorer, TScore>
where where
TScore: Clone + PartialOrd, TScore: Clone + PartialOrd,
{ {
pub(crate) fn new( pub fn new(
custom_scorer: TCustomScorer, custom_scorer: TCustomScorer,
collector: TopCollector<TScore>, limit: usize,
) -> CustomScoreTopCollector<TCustomScorer, TScore> { ) -> CustomScoreTopCollector<TCustomScorer, TScore> {
CustomScoreTopCollector { CustomScoreTopCollector {
custom_scorer, custom_scorer,
collector, collector: TopCollector::with_limit(limit),
} }
} }
} }
@@ -28,7 +29,7 @@ where
/// It is the segment local version of the [`CustomScorer`](./trait.CustomScorer.html). /// It is the segment local version of the [`CustomScorer`](./trait.CustomScorer.html).
pub trait CustomSegmentScorer<TScore>: 'static { pub trait CustomSegmentScorer<TScore>: 'static {
/// Computes the score of a specific `doc`. /// Computes the score of a specific `doc`.
fn score(&mut self, doc: DocId) -> TScore; fn score(&self, doc: DocId) -> TScore;
} }
/// `CustomScorer` makes it possible to define any kind of score. /// `CustomScorer` makes it possible to define any kind of score.
@@ -41,7 +42,7 @@ pub trait CustomScorer<TScore>: Sync {
type Child: CustomSegmentScorer<TScore>; type Child: CustomSegmentScorer<TScore>;
/// Builds a child scorer for a specific segment. The child scorer is associated to /// Builds a child scorer for a specific segment. The child scorer is associated to
/// a specific segment. /// a specific segment.
fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child>; fn segment_scorer(&self, segment_reader: &SegmentReader) -> Result<Self::Child>;
} }
impl<TCustomScorer, TScore> Collector for CustomScoreTopCollector<TCustomScorer, TScore> impl<TCustomScorer, TScore> Collector for CustomScoreTopCollector<TCustomScorer, TScore>
@@ -57,7 +58,7 @@ where
&self, &self,
segment_local_id: u32, segment_local_id: u32,
segment_reader: &SegmentReader, segment_reader: &SegmentReader,
) -> crate::Result<Self::Child> { ) -> Result<Self::Child> {
let segment_scorer = self.custom_scorer.segment_scorer(segment_reader)?; let segment_scorer = self.custom_scorer.segment_scorer(segment_reader)?;
let segment_collector = self let segment_collector = self
.collector .collector
@@ -72,7 +73,7 @@ where
false false
} }
fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> crate::Result<Self::Fruit> { fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> Result<Self::Fruit> {
self.collector.merge_fruits(segment_fruits) self.collector.merge_fruits(segment_fruits)
} }
} }
@@ -110,16 +111,16 @@ where
{ {
type Child = T; type Child = T;
fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child> { fn segment_scorer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
Ok((self)(segment_reader)) Ok((self)(segment_reader))
} }
} }
impl<F, TScore> CustomSegmentScorer<TScore> for F impl<F, TScore> CustomSegmentScorer<TScore> for F
where where
F: 'static + FnMut(DocId) -> TScore, F: 'static + Sync + Send + Fn(DocId) -> TScore,
{ {
fn score(&mut self, doc: DocId) -> TScore { fn score(&self, doc: DocId) -> TScore {
(self)(doc) (self)(doc)
} }
} }

View File

@@ -1,9 +1,11 @@
use crate::collector::Collector; use crate::collector::Collector;
use crate::collector::SegmentCollector; use crate::collector::SegmentCollector;
use crate::docset::SkipResult;
use crate::fastfield::FacetReader; use crate::fastfield::FacetReader;
use crate::schema::Facet; use crate::schema::Facet;
use crate::schema::Field; use crate::schema::Field;
use crate::DocId; use crate::DocId;
use crate::Result;
use crate::Score; use crate::Score;
use crate::SegmentLocalId; use crate::SegmentLocalId;
use crate::SegmentReader; use crate::SegmentReader;
@@ -82,9 +84,9 @@ fn facet_depth(facet_bytes: &[u8]) -> usize {
/// use tantivy::collector::FacetCollector; /// use tantivy::collector::FacetCollector;
/// use tantivy::query::AllQuery; /// use tantivy::query::AllQuery;
/// use tantivy::schema::{Facet, Schema, TEXT}; /// use tantivy::schema::{Facet, Schema, TEXT};
/// use tantivy::{doc, Index}; /// use tantivy::{doc, Index, Result};
/// ///
/// fn example() -> tantivy::Result<()> { /// fn example() -> Result<()> {
/// let mut schema_builder = Schema::builder(); /// let mut schema_builder = Schema::builder();
/// ///
/// // Facet have their own specific type. /// // Facet have their own specific type.
@@ -187,11 +189,6 @@ pub struct FacetSegmentCollector {
collapse_facet_ords: Vec<u64>, collapse_facet_ords: Vec<u64>,
} }
enum SkipResult {
Found,
NotFound,
}
fn skip<'a, I: Iterator<Item = &'a Facet>>( fn skip<'a, I: Iterator<Item = &'a Facet>>(
target: &[u8], target: &[u8],
collapse_it: &mut Peekable<I>, collapse_it: &mut Peekable<I>,
@@ -201,14 +198,14 @@ fn skip<'a, I: Iterator<Item = &'a Facet>>(
Some(facet_bytes) => match facet_bytes.encoded_str().as_bytes().cmp(target) { Some(facet_bytes) => match facet_bytes.encoded_str().as_bytes().cmp(target) {
Ordering::Less => {} Ordering::Less => {}
Ordering::Greater => { Ordering::Greater => {
return SkipResult::NotFound; return SkipResult::OverStep;
} }
Ordering::Equal => { Ordering::Equal => {
return SkipResult::Found; return SkipResult::Reached;
} }
}, },
None => { None => {
return SkipResult::NotFound; return SkipResult::End;
} }
} }
collapse_it.next(); collapse_it.next();
@@ -265,7 +262,7 @@ impl Collector for FacetCollector {
&self, &self,
_: SegmentLocalId, _: SegmentLocalId,
reader: &SegmentReader, reader: &SegmentReader,
) -> crate::Result<FacetSegmentCollector> { ) -> Result<FacetSegmentCollector> {
let field_name = reader.schema().get_field_name(self.field); let field_name = reader.schema().get_field_name(self.field);
let facet_reader = reader.facet_reader(self.field).ok_or_else(|| { let facet_reader = reader.facet_reader(self.field).ok_or_else(|| {
TantivyError::SchemaError(format!("Field {:?} is not a facet field.", field_name)) TantivyError::SchemaError(format!("Field {:?} is not a facet field.", field_name))
@@ -285,7 +282,7 @@ impl Collector for FacetCollector {
// is positionned on a term that has not been processed yet. // is positionned on a term that has not been processed yet.
let skip_result = skip(facet_streamer.key(), &mut collapse_facet_it); let skip_result = skip(facet_streamer.key(), &mut collapse_facet_it);
match skip_result { match skip_result {
SkipResult::Found => { SkipResult::Reached => {
// we reach a facet we decided to collapse. // we reach a facet we decided to collapse.
let collapse_depth = facet_depth(facet_streamer.key()); let collapse_depth = facet_depth(facet_streamer.key());
let mut collapsed_id = 0; let mut collapsed_id = 0;
@@ -305,7 +302,7 @@ impl Collector for FacetCollector {
} }
break; break;
} }
SkipResult::NotFound => { SkipResult::End | SkipResult::OverStep => {
collapse_mapping.push(0); collapse_mapping.push(0);
if !facet_streamer.advance() { if !facet_streamer.advance() {
break; break;
@@ -331,7 +328,7 @@ impl Collector for FacetCollector {
false false
} }
fn merge_fruits(&self, segments_facet_counts: Vec<FacetCounts>) -> crate::Result<FacetCounts> { fn merge_fruits(&self, segments_facet_counts: Vec<FacetCounts>) -> Result<FacetCounts> {
let mut facet_counts: BTreeMap<Facet, u64> = BTreeMap::new(); let mut facet_counts: BTreeMap<Facet, u64> = BTreeMap::new();
for segment_facet_counts in segments_facet_counts { for segment_facet_counts in segments_facet_counts {
for (facet, count) in segment_facet_counts.facet_counts { for (facet, count) in segment_facet_counts.facet_counts {

View File

@@ -85,6 +85,7 @@ See the `custom_collector` example.
*/ */
use crate::DocId; use crate::DocId;
use crate::Result;
use crate::Score; use crate::Score;
use crate::SegmentLocalId; use crate::SegmentLocalId;
use crate::SegmentReader; use crate::SegmentReader;
@@ -109,7 +110,6 @@ pub use self::tweak_score_top_collector::{ScoreSegmentTweaker, ScoreTweaker};
mod facet_collector; mod facet_collector;
pub use self::facet_collector::FacetCollector; pub use self::facet_collector::FacetCollector;
use crate::query::Weight;
/// `Fruit` is the type for the result of our collection. /// `Fruit` is the type for the result of our collection.
/// e.g. `usize` for the `Count` collector. /// e.g. `usize` for the `Count` collector.
@@ -147,37 +147,14 @@ pub trait Collector: Sync {
&self, &self,
segment_local_id: SegmentLocalId, segment_local_id: SegmentLocalId,
segment: &SegmentReader, segment: &SegmentReader,
) -> crate::Result<Self::Child>; ) -> Result<Self::Child>;
/// Returns true iff the collector requires to compute scores for documents. /// Returns true iff the collector requires to compute scores for documents.
fn requires_scoring(&self) -> bool; fn requires_scoring(&self) -> bool;
/// Combines the fruit associated to the collection of each segments /// Combines the fruit associated to the collection of each segments
/// into one fruit. /// into one fruit.
fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> crate::Result<Self::Fruit>; fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> Result<Self::Fruit>;
/// Created a segment collector and
fn collect_segment(
&self,
weight: &dyn Weight,
segment_ord: u32,
reader: &SegmentReader,
) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> {
let mut segment_collector = self.for_segment(segment_ord as u32, reader)?;
if let Some(delete_bitset) = reader.delete_bitset() {
weight.for_each(reader, &mut |doc, score| {
if delete_bitset.is_alive(doc) {
segment_collector.collect(doc, score);
}
})?;
} else {
weight.for_each(reader, &mut |doc, score| {
segment_collector.collect(doc, score);
})?;
}
Ok(segment_collector.harvest())
}
} }
/// The `SegmentCollector` is the trait in charge of defining the /// The `SegmentCollector` is the trait in charge of defining the
@@ -208,11 +185,7 @@ where
type Fruit = (Left::Fruit, Right::Fruit); type Fruit = (Left::Fruit, Right::Fruit);
type Child = (Left::Child, Right::Child); type Child = (Left::Child, Right::Child);
fn for_segment( fn for_segment(&self, segment_local_id: u32, segment: &SegmentReader) -> Result<Self::Child> {
&self,
segment_local_id: u32,
segment: &SegmentReader,
) -> crate::Result<Self::Child> {
let left = self.0.for_segment(segment_local_id, segment)?; let left = self.0.for_segment(segment_local_id, segment)?;
let right = self.1.for_segment(segment_local_id, segment)?; let right = self.1.for_segment(segment_local_id, segment)?;
Ok((left, right)) Ok((left, right))
@@ -225,7 +198,7 @@ where
fn merge_fruits( fn merge_fruits(
&self, &self,
children: Vec<(Left::Fruit, Right::Fruit)>, children: Vec<(Left::Fruit, Right::Fruit)>,
) -> crate::Result<(Left::Fruit, Right::Fruit)> { ) -> Result<(Left::Fruit, Right::Fruit)> {
let mut left_fruits = vec![]; let mut left_fruits = vec![];
let mut right_fruits = vec![]; let mut right_fruits = vec![];
for (left_fruit, right_fruit) in children { for (left_fruit, right_fruit) in children {
@@ -267,11 +240,7 @@ where
type Fruit = (One::Fruit, Two::Fruit, Three::Fruit); type Fruit = (One::Fruit, Two::Fruit, Three::Fruit);
type Child = (One::Child, Two::Child, Three::Child); type Child = (One::Child, Two::Child, Three::Child);
fn for_segment( fn for_segment(&self, segment_local_id: u32, segment: &SegmentReader) -> Result<Self::Child> {
&self,
segment_local_id: u32,
segment: &SegmentReader,
) -> crate::Result<Self::Child> {
let one = self.0.for_segment(segment_local_id, segment)?; let one = self.0.for_segment(segment_local_id, segment)?;
let two = self.1.for_segment(segment_local_id, segment)?; let two = self.1.for_segment(segment_local_id, segment)?;
let three = self.2.for_segment(segment_local_id, segment)?; let three = self.2.for_segment(segment_local_id, segment)?;
@@ -282,7 +251,7 @@ where
self.0.requires_scoring() || self.1.requires_scoring() || self.2.requires_scoring() self.0.requires_scoring() || self.1.requires_scoring() || self.2.requires_scoring()
} }
fn merge_fruits(&self, children: Vec<Self::Fruit>) -> crate::Result<Self::Fruit> { fn merge_fruits(&self, children: Vec<Self::Fruit>) -> Result<Self::Fruit> {
let mut one_fruits = vec![]; let mut one_fruits = vec![];
let mut two_fruits = vec![]; let mut two_fruits = vec![];
let mut three_fruits = vec![]; let mut three_fruits = vec![];
@@ -330,11 +299,7 @@ where
type Fruit = (One::Fruit, Two::Fruit, Three::Fruit, Four::Fruit); type Fruit = (One::Fruit, Two::Fruit, Three::Fruit, Four::Fruit);
type Child = (One::Child, Two::Child, Three::Child, Four::Child); type Child = (One::Child, Two::Child, Three::Child, Four::Child);
fn for_segment( fn for_segment(&self, segment_local_id: u32, segment: &SegmentReader) -> Result<Self::Child> {
&self,
segment_local_id: u32,
segment: &SegmentReader,
) -> crate::Result<Self::Child> {
let one = self.0.for_segment(segment_local_id, segment)?; let one = self.0.for_segment(segment_local_id, segment)?;
let two = self.1.for_segment(segment_local_id, segment)?; let two = self.1.for_segment(segment_local_id, segment)?;
let three = self.2.for_segment(segment_local_id, segment)?; let three = self.2.for_segment(segment_local_id, segment)?;
@@ -349,7 +314,7 @@ where
|| self.3.requires_scoring() || self.3.requires_scoring()
} }
fn merge_fruits(&self, children: Vec<Self::Fruit>) -> crate::Result<Self::Fruit> { fn merge_fruits(&self, children: Vec<Self::Fruit>) -> Result<Self::Fruit> {
let mut one_fruits = vec![]; let mut one_fruits = vec![];
let mut two_fruits = vec![]; let mut two_fruits = vec![];
let mut three_fruits = vec![]; let mut three_fruits = vec![];

View File

@@ -2,6 +2,7 @@ use super::Collector;
use super::SegmentCollector; use super::SegmentCollector;
use crate::collector::Fruit; use crate::collector::Fruit;
use crate::DocId; use crate::DocId;
use crate::Result;
use crate::Score; use crate::Score;
use crate::SegmentLocalId; use crate::SegmentLocalId;
use crate::SegmentReader; use crate::SegmentReader;
@@ -23,7 +24,7 @@ impl<TCollector: Collector> Collector for CollectorWrapper<TCollector> {
&self, &self,
segment_local_id: u32, segment_local_id: u32,
reader: &SegmentReader, reader: &SegmentReader,
) -> crate::Result<Box<dyn BoxableSegmentCollector>> { ) -> Result<Box<dyn BoxableSegmentCollector>> {
let child = self.0.for_segment(segment_local_id, reader)?; let child = self.0.for_segment(segment_local_id, reader)?;
Ok(Box::new(SegmentCollectorWrapper(child))) Ok(Box::new(SegmentCollectorWrapper(child)))
} }
@@ -32,10 +33,7 @@ impl<TCollector: Collector> Collector for CollectorWrapper<TCollector> {
self.0.requires_scoring() self.0.requires_scoring()
} }
fn merge_fruits( fn merge_fruits(&self, children: Vec<<Self as Collector>::Fruit>) -> Result<Box<dyn Fruit>> {
&self,
children: Vec<<Self as Collector>::Fruit>,
) -> crate::Result<Box<dyn Fruit>> {
let typed_fruit: Vec<TCollector::Fruit> = children let typed_fruit: Vec<TCollector::Fruit> = children
.into_iter() .into_iter()
.map(|untyped_fruit| { .map(|untyped_fruit| {
@@ -46,7 +44,7 @@ impl<TCollector: Collector> Collector for CollectorWrapper<TCollector> {
TantivyError::InvalidArgument("Failed to cast child fruit.".to_string()) TantivyError::InvalidArgument("Failed to cast child fruit.".to_string())
}) })
}) })
.collect::<crate::Result<_>>()?; .collect::<Result<_>>()?;
let merged_fruit = self.0.merge_fruits(typed_fruit)?; let merged_fruit = self.0.merge_fruits(typed_fruit)?;
Ok(Box::new(merged_fruit)) Ok(Box::new(merged_fruit))
} }
@@ -177,12 +175,12 @@ impl<'a> Collector for MultiCollector<'a> {
&self, &self,
segment_local_id: SegmentLocalId, segment_local_id: SegmentLocalId,
segment: &SegmentReader, segment: &SegmentReader,
) -> crate::Result<MultiCollectorChild> { ) -> Result<MultiCollectorChild> {
let children = self let children = self
.collector_wrappers .collector_wrappers
.iter() .iter()
.map(|collector_wrapper| collector_wrapper.for_segment(segment_local_id, segment)) .map(|collector_wrapper| collector_wrapper.for_segment(segment_local_id, segment))
.collect::<crate::Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
Ok(MultiCollectorChild { children }) Ok(MultiCollectorChild { children })
} }
@@ -193,7 +191,7 @@ impl<'a> Collector for MultiCollector<'a> {
.any(Collector::requires_scoring) .any(Collector::requires_scoring)
} }
fn merge_fruits(&self, segments_multifruits: Vec<MultiFruit>) -> crate::Result<MultiFruit> { fn merge_fruits(&self, segments_multifruits: Vec<MultiFruit>) -> Result<MultiFruit> {
let mut segment_fruits_list: Vec<Vec<Box<dyn Fruit>>> = (0..self.collector_wrappers.len()) let mut segment_fruits_list: Vec<Vec<Box<dyn Fruit>>> = (0..self.collector_wrappers.len())
.map(|_| Vec::with_capacity(segments_multifruits.len())) .map(|_| Vec::with_capacity(segments_multifruits.len()))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@@ -211,7 +209,7 @@ impl<'a> Collector for MultiCollector<'a> {
.map(|(child_collector, segment_fruits)| { .map(|(child_collector, segment_fruits)| {
Ok(Some(child_collector.merge_fruits(segment_fruits)?)) Ok(Some(child_collector.merge_fruits(segment_fruits)?))
}) })
.collect::<crate::Result<_>>()?; .collect::<Result<_>>()?;
Ok(MultiFruit { sub_fruits }) Ok(MultiFruit { sub_fruits })
} }
} }

View File

@@ -55,7 +55,7 @@ impl Collector for TestCollector {
&self, &self,
segment_id: SegmentLocalId, segment_id: SegmentLocalId,
_reader: &SegmentReader, _reader: &SegmentReader,
) -> crate::Result<TestSegmentCollector> { ) -> Result<TestSegmentCollector> {
Ok(TestSegmentCollector { Ok(TestSegmentCollector {
segment_id, segment_id,
fruit: TestFruit::default(), fruit: TestFruit::default(),
@@ -66,7 +66,7 @@ impl Collector for TestCollector {
self.compute_score self.compute_score
} }
fn merge_fruits(&self, mut children: Vec<TestFruit>) -> crate::Result<TestFruit> { fn merge_fruits(&self, mut children: Vec<TestFruit>) -> Result<TestFruit> {
children.sort_by_key(|fruit| { children.sort_by_key(|fruit| {
if fruit.docs().is_empty() { if fruit.docs().is_empty() {
0 0
@@ -124,7 +124,7 @@ impl Collector for FastFieldTestCollector {
&self, &self,
_: SegmentLocalId, _: SegmentLocalId,
segment_reader: &SegmentReader, segment_reader: &SegmentReader,
) -> crate::Result<FastFieldSegmentCollector> { ) -> Result<FastFieldSegmentCollector> {
let reader = segment_reader let reader = segment_reader
.fast_fields() .fast_fields()
.u64(self.field) .u64(self.field)
@@ -139,7 +139,7 @@ impl Collector for FastFieldTestCollector {
false false
} }
fn merge_fruits(&self, children: Vec<Vec<u64>>) -> crate::Result<Vec<u64>> { fn merge_fruits(&self, children: Vec<Vec<u64>>) -> Result<Vec<u64>> {
Ok(children.into_iter().flat_map(|v| v.into_iter()).collect()) Ok(children.into_iter().flat_map(|v| v.into_iter()).collect())
} }
} }
@@ -184,7 +184,7 @@ impl Collector for BytesFastFieldTestCollector {
&self, &self,
_segment_local_id: u32, _segment_local_id: u32,
segment_reader: &SegmentReader, segment_reader: &SegmentReader,
) -> crate::Result<BytesFastFieldSegmentCollector> { ) -> Result<BytesFastFieldSegmentCollector> {
Ok(BytesFastFieldSegmentCollector { Ok(BytesFastFieldSegmentCollector {
vals: Vec::new(), vals: Vec::new(),
reader: segment_reader reader: segment_reader
@@ -198,7 +198,7 @@ impl Collector for BytesFastFieldTestCollector {
false false
} }
fn merge_fruits(&self, children: Vec<Vec<u8>>) -> crate::Result<Vec<u8>> { fn merge_fruits(&self, children: Vec<Vec<u8>>) -> Result<Vec<u8>> {
Ok(children.into_iter().flat_map(|c| c.into_iter()).collect()) Ok(children.into_iter().flat_map(|c| c.into_iter()).collect())
} }
} }

View File

@@ -1,5 +1,6 @@
use crate::DocAddress; use crate::DocAddress;
use crate::DocId; use crate::DocId;
use crate::Result;
use crate::SegmentLocalId; use crate::SegmentLocalId;
use crate::SegmentReader; use crate::SegmentReader;
use serde::export::PhantomData; use serde::export::PhantomData;
@@ -18,9 +19,9 @@ use std::collections::BinaryHeap;
/// Two elements are equal if their feature is equal, and regardless of whether `doc` /// Two elements are equal if their feature is equal, and regardless of whether `doc`
/// is equal. This should be perfectly fine for this usage, but let's make sure this /// is equal. This should be perfectly fine for this usage, but let's make sure this
/// struct is never public. /// struct is never public.
pub(crate) struct ComparableDoc<T, D> { struct ComparableDoc<T, D> {
pub feature: T, feature: T,
pub doc: D, doc: D,
} }
impl<T: PartialOrd, D: PartialOrd> PartialOrd for ComparableDoc<T, D> { impl<T: PartialOrd, D: PartialOrd> PartialOrd for ComparableDoc<T, D> {
@@ -56,8 +57,7 @@ impl<T: PartialOrd, D: PartialOrd> PartialEq for ComparableDoc<T, D> {
impl<T: PartialOrd, D: PartialOrd> Eq for ComparableDoc<T, D> {} impl<T: PartialOrd, D: PartialOrd> Eq for ComparableDoc<T, D> {}
pub(crate) struct TopCollector<T> { pub(crate) struct TopCollector<T> {
pub limit: usize, limit: usize,
pub offset: usize,
_marker: PhantomData<T>, _marker: PhantomData<T>,
} }
@@ -73,33 +73,27 @@ where
if limit < 1 { if limit < 1 {
panic!("Limit must be strictly greater than 0."); panic!("Limit must be strictly greater than 0.");
} }
Self { TopCollector {
limit, limit,
offset: 0,
_marker: PhantomData, _marker: PhantomData,
} }
} }
/// Skip the first "offset" documents when collecting. pub fn limit(&self) -> usize {
/// self.limit
/// This is equivalent to `OFFSET` in MySQL or PostgreSQL and `start` in
/// Lucene's TopDocsCollector.
pub fn and_offset(mut self, offset: usize) -> TopCollector<T> {
self.offset = offset;
self
} }
pub fn merge_fruits( pub fn merge_fruits(
&self, &self,
children: Vec<Vec<(T, DocAddress)>>, children: Vec<Vec<(T, DocAddress)>>,
) -> crate::Result<Vec<(T, DocAddress)>> { ) -> Result<Vec<(T, DocAddress)>> {
if self.limit == 0 { if self.limit == 0 {
return Ok(Vec::new()); return Ok(Vec::new());
} }
let mut top_collector = BinaryHeap::new(); let mut top_collector = BinaryHeap::new();
for child_fruit in children { for child_fruit in children {
for (feature, doc) in child_fruit { for (feature, doc) in child_fruit {
if top_collector.len() < (self.limit + self.offset) { if top_collector.len() < self.limit {
top_collector.push(ComparableDoc { feature, doc }); top_collector.push(ComparableDoc { feature, doc });
} else if let Some(mut head) = top_collector.peek_mut() { } else if let Some(mut head) = top_collector.peek_mut() {
if head.feature < feature { if head.feature < feature {
@@ -111,7 +105,6 @@ where
Ok(top_collector Ok(top_collector
.into_sorted_vec() .into_sorted_vec()
.into_iter() .into_iter()
.skip(self.offset)
.map(|cdoc| (cdoc.feature, cdoc.doc)) .map(|cdoc| (cdoc.feature, cdoc.doc))
.collect()) .collect())
} }
@@ -120,24 +113,8 @@ where
&self, &self,
segment_id: SegmentLocalId, segment_id: SegmentLocalId,
_: &SegmentReader, _: &SegmentReader,
) -> crate::Result<TopSegmentCollector<F>> { ) -> Result<TopSegmentCollector<F>> {
Ok(TopSegmentCollector::new( Ok(TopSegmentCollector::new(segment_id, self.limit))
segment_id,
self.limit + self.offset,
))
}
/// Create a new TopCollector with the same limit and offset.
///
/// Ideally we would use Into but the blanket implementation seems to cause the Scorer traits
/// to fail.
#[doc(hidden)]
pub(crate) fn into_tscore<TScore: PartialOrd + Clone>(self) -> TopCollector<TScore> {
TopCollector {
limit: self.limit,
offset: self.offset,
_marker: PhantomData,
}
} }
} }
@@ -211,7 +188,7 @@ impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{TopCollector, TopSegmentCollector}; use super::TopSegmentCollector;
use crate::DocAddress; use crate::DocAddress;
#[test] #[test]
@@ -272,48 +249,6 @@ mod tests {
top_collector_limit_3.harvest()[..2].to_vec(), top_collector_limit_3.harvest()[..2].to_vec(),
); );
} }
#[test]
fn test_top_collector_with_limit_and_offset() {
let collector = TopCollector::with_limit(2).and_offset(1);
let results = collector
.merge_fruits(vec![vec![
(0.9, DocAddress(0, 1)),
(0.8, DocAddress(0, 2)),
(0.7, DocAddress(0, 3)),
(0.6, DocAddress(0, 4)),
(0.5, DocAddress(0, 5)),
]])
.unwrap();
assert_eq!(
results,
vec![(0.8, DocAddress(0, 2)), (0.7, DocAddress(0, 3)),]
);
}
#[test]
fn test_top_collector_with_limit_larger_than_set_and_offset() {
let collector = TopCollector::with_limit(2).and_offset(1);
let results = collector
.merge_fruits(vec![vec![(0.9, DocAddress(0, 1)), (0.8, DocAddress(0, 2))]])
.unwrap();
assert_eq!(results, vec![(0.8, DocAddress(0, 2)),]);
}
#[test]
fn test_top_collector_with_limit_and_offset_larger_than_set() {
let collector = TopCollector::with_limit(2).and_offset(20);
let results = collector
.merge_fruits(vec![vec![(0.9, DocAddress(0, 1)), (0.8, DocAddress(0, 2))]])
.unwrap();
assert_eq!(results, vec![]);
}
} }
#[cfg(all(test, feature = "unstable"))] #[cfg(all(test, feature = "unstable"))]

View File

@@ -1,20 +1,18 @@
use super::Collector; use super::Collector;
use crate::collector::custom_score_top_collector::CustomScoreTopCollector; use crate::collector::custom_score_top_collector::CustomScoreTopCollector;
use crate::collector::top_collector::TopCollector;
use crate::collector::top_collector::TopSegmentCollector; use crate::collector::top_collector::TopSegmentCollector;
use crate::collector::top_collector::{ComparableDoc, TopCollector};
use crate::collector::tweak_score_top_collector::TweakedScoreTopCollector; use crate::collector::tweak_score_top_collector::TweakedScoreTopCollector;
use crate::collector::{ use crate::collector::{
CustomScorer, CustomSegmentScorer, ScoreSegmentTweaker, ScoreTweaker, SegmentCollector, CustomScorer, CustomSegmentScorer, ScoreSegmentTweaker, ScoreTweaker, SegmentCollector,
}; };
use crate::fastfield::FastFieldReader;
use crate::query::Weight;
use crate::schema::Field; use crate::schema::Field;
use crate::DocAddress; use crate::DocAddress;
use crate::DocId; use crate::DocId;
use crate::Result;
use crate::Score; use crate::Score;
use crate::SegmentLocalId; use crate::SegmentLocalId;
use crate::SegmentReader; use crate::SegmentReader;
use std::collections::BinaryHeap;
use std::fmt; use std::fmt;
/// The `TopDocs` collector keeps track of the top `K` documents /// The `TopDocs` collector keeps track of the top `K` documents
@@ -59,42 +57,7 @@ pub struct TopDocs(TopCollector<Score>);
impl fmt::Debug for TopDocs { impl fmt::Debug for TopDocs {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!( write!(f, "TopDocs({})", self.0.limit())
f,
"TopDocs(limit={}, offset={})",
self.0.limit, self.0.offset
)
}
}
struct ScorerByFastFieldReader {
ff_reader: FastFieldReader<u64>,
}
impl CustomSegmentScorer<u64> for ScorerByFastFieldReader {
fn score(&mut self, doc: DocId) -> u64 {
self.ff_reader.get_u64(u64::from(doc))
}
}
struct ScorerByField {
field: Field,
}
impl CustomScorer<u64> for ScorerByField {
type Child = ScorerByFastFieldReader;
fn segment_scorer(&self, segment_reader: &SegmentReader) -> crate::Result<Self::Child> {
let ff_reader = segment_reader
.fast_fields()
.u64(self.field)
.ok_or_else(|| {
crate::TantivyError::SchemaError(format!(
"Field requested ({:?}) is not a i64/u64 fast field.",
self.field
))
})?;
Ok(ScorerByFastFieldReader { ff_reader })
} }
} }
@@ -107,50 +70,11 @@ impl TopDocs {
TopDocs(TopCollector::with_limit(limit)) TopDocs(TopCollector::with_limit(limit))
} }
/// Skip the first "offset" documents when collecting.
///
/// This is equivalent to `OFFSET` in MySQL or PostgreSQL and `start` in
/// Lucene's TopDocsCollector.
///
/// ```rust
/// use tantivy::collector::TopDocs;
/// use tantivy::query::QueryParser;
/// use tantivy::schema::{Schema, TEXT};
/// use tantivy::{doc, DocAddress, Index};
///
/// let mut schema_builder = Schema::builder();
/// let title = schema_builder.add_text_field("title", TEXT);
/// let schema = schema_builder.build();
/// let index = Index::create_in_ram(schema);
///
/// let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
/// index_writer.add_document(doc!(title => "The Name of the Wind"));
/// index_writer.add_document(doc!(title => "The Diary of Muadib"));
/// index_writer.add_document(doc!(title => "A Dairy Cow"));
/// index_writer.add_document(doc!(title => "The Diary of a Young Girl"));
/// index_writer.add_document(doc!(title => "The Diary of Lena Mukhina"));
/// assert!(index_writer.commit().is_ok());
///
/// let reader = index.reader().unwrap();
/// let searcher = reader.searcher();
///
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query = query_parser.parse_query("diary").unwrap();
/// let top_docs = searcher.search(&query, &TopDocs::with_limit(2).and_offset(1)).unwrap();
///
/// assert_eq!(top_docs.len(), 2);
/// assert_eq!(&top_docs[0], &(0.5204813, DocAddress(0, 4)));
/// assert_eq!(&top_docs[1], &(0.4793185, DocAddress(0, 3)));
/// ```
pub fn and_offset(self, offset: usize) -> TopDocs {
TopDocs(self.0.and_offset(offset))
}
/// Set top-K to rank documents by a given fast field. /// Set top-K to rank documents by a given fast field.
/// ///
/// ```rust /// ```rust
/// # use tantivy::schema::{Schema, FAST, TEXT}; /// # use tantivy::schema::{Schema, FAST, TEXT};
/// # use tantivy::{doc, Index, DocAddress}; /// # use tantivy::{doc, Index, Result, DocAddress};
/// # use tantivy::query::{Query, QueryParser}; /// # use tantivy::query::{Query, QueryParser};
/// use tantivy::Searcher; /// use tantivy::Searcher;
/// use tantivy::collector::TopDocs; /// use tantivy::collector::TopDocs;
@@ -187,7 +111,7 @@ impl TopDocs {
/// fn docs_sorted_by_rating(searcher: &Searcher, /// fn docs_sorted_by_rating(searcher: &Searcher,
/// query: &dyn Query, /// query: &dyn Query,
/// sort_by_field: Field) /// sort_by_field: Field)
/// -> tantivy::Result<Vec<(u64, DocAddress)>> { /// -> Result<Vec<(u64, DocAddress)>> {
/// ///
/// // This is where we build our topdocs collector /// // This is where we build our topdocs collector
/// // /// //
@@ -219,7 +143,14 @@ impl TopDocs {
self, self,
field: Field, field: Field,
) -> impl Collector<Fruit = Vec<(u64, DocAddress)>> { ) -> impl Collector<Fruit = Vec<(u64, DocAddress)>> {
self.custom_score(ScorerByField { field }) self.custom_score(move |segment_reader: &SegmentReader| {
let ff_reader = segment_reader
.fast_fields()
.u64(field)
.expect("Field requested is not a i64/u64 fast field.");
//TODO error message missmatch actual behavior for i64
move |doc: DocId| ff_reader.get(doc)
})
} }
/// Ranks the documents using a custom score. /// Ranks the documents using a custom score.
@@ -326,7 +257,7 @@ impl TopDocs {
TScoreSegmentTweaker: ScoreSegmentTweaker<TScore> + 'static, TScoreSegmentTweaker: ScoreSegmentTweaker<TScore> + 'static,
TScoreTweaker: ScoreTweaker<TScore, Child = TScoreSegmentTweaker>, TScoreTweaker: ScoreTweaker<TScore, Child = TScoreSegmentTweaker>,
{ {
TweakedScoreTopCollector::new(score_tweaker, self.0.into_tscore()) TweakedScoreTopCollector::new(score_tweaker, self.0.limit())
} }
/// Ranks the documents using a custom score. /// Ranks the documents using a custom score.
@@ -440,7 +371,7 @@ impl TopDocs {
TCustomSegmentScorer: CustomSegmentScorer<TScore> + 'static, TCustomSegmentScorer: CustomSegmentScorer<TScore> + 'static,
TCustomScorer: CustomScorer<TScore, Child = TCustomSegmentScorer>, TCustomScorer: CustomScorer<TScore, Child = TCustomSegmentScorer>,
{ {
CustomScoreTopCollector::new(custom_score, self.0.into_tscore()) CustomScoreTopCollector::new(custom_score, self.0.limit())
} }
} }
@@ -453,7 +384,7 @@ impl Collector for TopDocs {
&self, &self,
segment_local_id: SegmentLocalId, segment_local_id: SegmentLocalId,
reader: &SegmentReader, reader: &SegmentReader,
) -> crate::Result<Self::Child> { ) -> Result<Self::Child> {
let collector = self.0.for_segment(segment_local_id, reader)?; let collector = self.0.for_segment(segment_local_id, reader)?;
Ok(TopScoreSegmentCollector(collector)) Ok(TopScoreSegmentCollector(collector))
} }
@@ -462,70 +393,9 @@ impl Collector for TopDocs {
true true
} }
fn merge_fruits( fn merge_fruits(&self, child_fruits: Vec<Vec<(Score, DocAddress)>>) -> Result<Self::Fruit> {
&self,
child_fruits: Vec<Vec<(Score, DocAddress)>>,
) -> crate::Result<Self::Fruit> {
self.0.merge_fruits(child_fruits) self.0.merge_fruits(child_fruits)
} }
fn collect_segment(
&self,
weight: &dyn Weight,
segment_ord: u32,
reader: &SegmentReader,
) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> {
let heap_len = self.0.limit + self.0.offset;
let mut heap: BinaryHeap<ComparableDoc<Score, DocId>> = BinaryHeap::with_capacity(heap_len);
if let Some(delete_bitset) = reader.delete_bitset() {
let mut threshold = f32::MIN;
weight.for_each_pruning(threshold, reader, &mut |doc, score| {
if delete_bitset.is_deleted(doc) {
return threshold;
}
let heap_item = ComparableDoc {
feature: score,
doc,
};
if heap.len() < heap_len {
heap.push(heap_item);
if heap.len() == heap_len {
threshold = heap.peek().map(|el| el.feature).unwrap_or(f32::MIN);
}
return threshold;
}
*heap.peek_mut().unwrap() = heap_item;
threshold = heap.peek().map(|el| el.feature).unwrap_or(std::f32::MIN);
threshold
})?;
} else {
weight.for_each_pruning(f32::MIN, reader, &mut |doc, score| {
let heap_item = ComparableDoc {
feature: score,
doc,
};
if heap.len() < heap_len {
heap.push(heap_item);
// TODO the threshold is suboptimal for heap.len == heap_len
if heap.len() == heap_len {
return heap.peek().map(|el| el.feature).unwrap_or(f32::MIN);
} else {
return f32::MIN;
}
}
*heap.peek_mut().unwrap() = heap_item;
heap.peek().map(|el| el.feature).unwrap_or(std::f32::MIN)
})?;
}
let fruit = heap
.into_sorted_vec()
.into_iter()
.map(|cid| (cid.feature, DocAddress(segment_ord, cid.doc)))
.collect();
Ok(fruit)
}
} }
/// Segment Collector associated to `TopDocs`. /// Segment Collector associated to `TopDocs`.
@@ -535,7 +405,7 @@ impl SegmentCollector for TopScoreSegmentCollector {
type Fruit = Vec<(Score, DocAddress)>; type Fruit = Vec<(Score, DocAddress)>;
fn collect(&mut self, doc: DocId, score: Score) { fn collect(&mut self, doc: DocId, score: Score) {
self.0.collect(doc, score); self.0.collect(doc, score)
} }
fn harvest(self) -> Vec<(Score, DocAddress)> { fn harvest(self) -> Vec<(Score, DocAddress)> {
@@ -549,10 +419,11 @@ mod tests {
use crate::collector::Collector; use crate::collector::Collector;
use crate::query::{AllQuery, Query, QueryParser}; use crate::query::{AllQuery, Query, QueryParser};
use crate::schema::{Field, Schema, FAST, STORED, TEXT}; use crate::schema::{Field, Schema, FAST, STORED, TEXT};
use crate::DocAddress;
use crate::Index; use crate::Index;
use crate::IndexWriter; use crate::IndexWriter;
use crate::Score; use crate::Score;
use crate::{DocAddress, DocId, SegmentReader}; use itertools::Itertools;
fn make_index() -> Index { fn make_index() -> Index {
let mut schema_builder = Schema::builder(); let mut schema_builder = Schema::builder();
@@ -592,21 +463,6 @@ mod tests {
); );
} }
#[test]
fn test_top_collector_not_at_capacity_with_offset() {
let index = make_index();
let field = index.schema().get_field("text").unwrap();
let query_parser = QueryParser::for_index(&index, vec![field]);
let text_query = query_parser.parse_query("droopy tax").unwrap();
let score_docs: Vec<(Score, DocAddress)> = index
.reader()
.unwrap()
.searcher()
.search(&text_query, &TopDocs::with_limit(4).and_offset(2))
.unwrap();
assert_eq!(score_docs, vec![(0.48527452, DocAddress(0, 0))]);
}
#[test] #[test]
fn test_top_collector_at_capacity() { fn test_top_collector_at_capacity() {
let index = make_index(); let index = make_index();
@@ -628,27 +484,6 @@ mod tests {
); );
} }
#[test]
fn test_top_collector_at_capacity_with_offset() {
let index = make_index();
let field = index.schema().get_field("text").unwrap();
let query_parser = QueryParser::for_index(&index, vec![field]);
let text_query = query_parser.parse_query("droopy tax").unwrap();
let score_docs: Vec<(Score, DocAddress)> = index
.reader()
.unwrap()
.searcher()
.search(&text_query, &TopDocs::with_limit(2).and_offset(1))
.unwrap();
assert_eq!(
score_docs,
vec![
(0.5376842, DocAddress(0u32, 2)),
(0.48527452, DocAddress(0, 0))
]
);
}
#[test] #[test]
fn test_top_collector_stable_sorting() { fn test_top_collector_stable_sorting() {
let index = make_index(); let index = make_index();
@@ -662,8 +497,8 @@ mod tests {
// precondition for the test to be meaningful: we did get documents // precondition for the test to be meaningful: we did get documents
// with the same score // with the same score
assert!(page_1.iter().all(|result| result.0 == page_1[0].0)); assert!(page_1.iter().map(|result| result.0).all_equal());
assert!(page_2.iter().all(|result| result.0 == page_2[0].0)); assert!(page_2.iter().map(|result| result.0).all_equal());
// sanity check since we're relying on make_index() // sanity check since we're relying on make_index()
assert_eq!(page_1.len(), 2); assert_eq!(page_1.len(), 2);
@@ -737,6 +572,7 @@ mod tests {
} }
#[test] #[test]
#[should_panic(expected = "Field requested is not a i64/u64 fast field")]
fn test_field_not_fast_field() { fn test_field_not_fast_field() {
let mut schema_builder = Schema::builder(); let mut schema_builder = Schema::builder();
let title = schema_builder.add_text_field(TITLE, TEXT); let title = schema_builder.add_text_field(TITLE, TEXT);
@@ -751,59 +587,7 @@ mod tests {
let searcher = index.reader().unwrap().searcher(); let searcher = index.reader().unwrap().searcher();
let segment = searcher.segment_reader(0); let segment = searcher.segment_reader(0);
let top_collector = TopDocs::with_limit(4).order_by_u64_field(size); let top_collector = TopDocs::with_limit(4).order_by_u64_field(size);
let err = top_collector.for_segment(0, segment); assert!(top_collector.for_segment(0, segment).is_ok());
if let Err(crate::TantivyError::SchemaError(msg)) = err {
assert_eq!(
msg,
"Field requested (Field(1)) is not a i64/u64 fast field."
);
} else {
assert!(false);
}
}
#[test]
fn test_tweak_score_top_collector_with_offset() {
let index = make_index();
let field = index.schema().get_field("text").unwrap();
let query_parser = QueryParser::for_index(&index, vec![field]);
let text_query = query_parser.parse_query("droopy tax").unwrap();
let collector = TopDocs::with_limit(2).and_offset(1).tweak_score(
move |_segment_reader: &SegmentReader| move |doc: DocId, _original_score: Score| doc,
);
let score_docs: Vec<(u32, DocAddress)> = index
.reader()
.unwrap()
.searcher()
.search(&text_query, &collector)
.unwrap();
assert_eq!(
score_docs,
vec![(1, DocAddress(0, 1)), (0, DocAddress(0, 0)),]
);
}
#[test]
fn test_custom_score_top_collector_with_offset() {
let index = make_index();
let field = index.schema().get_field("text").unwrap();
let query_parser = QueryParser::for_index(&index, vec![field]);
let text_query = query_parser.parse_query("droopy tax").unwrap();
let collector = TopDocs::with_limit(2)
.and_offset(1)
.custom_score(move |_segment_reader: &SegmentReader| move |doc: DocId| doc);
let score_docs: Vec<(u32, DocAddress)> = index
.reader()
.unwrap()
.searcher()
.search(&text_query, &collector)
.unwrap();
assert_eq!(
score_docs,
vec![(1, DocAddress(0, 1)), (0, DocAddress(0, 0)),]
);
} }
fn index( fn index(

View File

@@ -14,11 +14,11 @@ where
{ {
pub fn new( pub fn new(
score_tweaker: TScoreTweaker, score_tweaker: TScoreTweaker,
collector: TopCollector<TScore>, limit: usize,
) -> TweakedScoreTopCollector<TScoreTweaker, TScore> { ) -> TweakedScoreTopCollector<TScoreTweaker, TScore> {
TweakedScoreTopCollector { TweakedScoreTopCollector {
score_tweaker, score_tweaker,
collector, collector: TopCollector::with_limit(limit),
} }
} }
} }
@@ -29,7 +29,7 @@ where
/// It is the segment local version of the [`ScoreTweaker`](./trait.ScoreTweaker.html). /// It is the segment local version of the [`ScoreTweaker`](./trait.ScoreTweaker.html).
pub trait ScoreSegmentTweaker<TScore>: 'static { pub trait ScoreSegmentTweaker<TScore>: 'static {
/// Tweak the given `score` for the document `doc`. /// Tweak the given `score` for the document `doc`.
fn score(&mut self, doc: DocId, score: Score) -> TScore; fn score(&self, doc: DocId, score: Score) -> TScore;
} }
/// `ScoreTweaker` makes it possible to tweak the score /// `ScoreTweaker` makes it possible to tweak the score
@@ -121,9 +121,9 @@ where
impl<F, TScore> ScoreSegmentTweaker<TScore> for F impl<F, TScore> ScoreSegmentTweaker<TScore> for F
where where
F: 'static + FnMut(DocId, Score) -> TScore, F: 'static + Sync + Send + Fn(DocId, Score) -> TScore,
{ {
fn score(&mut self, doc: DocId, score: Score) -> TScore { fn score(&self, doc: DocId, score: Score) -> TScore {
(self)(doc, score) (self)(doc, score)
} }
} }

View File

@@ -33,10 +33,6 @@ impl TinySet {
TinySet(0u64) TinySet(0u64)
} }
pub fn clear(&mut self) {
self.0 = 0u64;
}
/// Returns the complement of the set in `[0, 64[`. /// Returns the complement of the set in `[0, 64[`.
fn complement(self) -> TinySet { fn complement(self) -> TinySet {
TinySet(!self.0) TinySet(!self.0)
@@ -47,11 +43,6 @@ impl TinySet {
!self.intersect(TinySet::singleton(el)).is_empty() !self.intersect(TinySet::singleton(el)).is_empty()
} }
/// Returns the number of elements in the TinySet.
pub fn len(self) -> u32 {
self.0.count_ones()
}
/// Returns the intersection of `self` and `other` /// Returns the intersection of `self` and `other`
pub fn intersect(self, other: TinySet) -> TinySet { pub fn intersect(self, other: TinySet) -> TinySet {
TinySet(self.0 & other.0) TinySet(self.0 & other.0)
@@ -118,12 +109,22 @@ impl TinySet {
pub fn range_greater_or_equal(from_included: u32) -> TinySet { pub fn range_greater_or_equal(from_included: u32) -> TinySet {
TinySet::range_lower(from_included).complement() TinySet::range_lower(from_included).complement()
} }
pub fn clear(&mut self) {
self.0 = 0u64;
}
pub fn len(self) -> u32 {
self.0.count_ones()
}
} }
#[derive(Clone)] #[derive(Clone)]
pub struct BitSet { pub struct BitSet {
tinysets: Box<[TinySet]>, tinysets: Box<[TinySet]>,
len: usize, len: usize, //< Technically it should be u32, but we
// count multiple inserts.
// `usize` guards us from overflow.
max_value: u32, max_value: u32,
} }
@@ -203,7 +204,7 @@ mod tests {
use super::BitSet; use super::BitSet;
use super::TinySet; use super::TinySet;
use crate::docset::{DocSet, TERMINATED}; use crate::docset::DocSet;
use crate::query::BitSetDocSet; use crate::query::BitSetDocSet;
use crate::tests; use crate::tests;
use crate::tests::generate_nonunique_unsorted; use crate::tests::generate_nonunique_unsorted;
@@ -277,13 +278,11 @@ mod tests {
} }
assert_eq!(btreeset.len(), bitset.len()); assert_eq!(btreeset.len(), bitset.len());
let mut bitset_docset = BitSetDocSet::from(bitset); let mut bitset_docset = BitSetDocSet::from(bitset);
let mut remaining = true;
for el in btreeset.into_iter() { for el in btreeset.into_iter() {
assert!(remaining); bitset_docset.advance();
assert_eq!(bitset_docset.doc(), el); assert_eq!(bitset_docset.doc(), el);
remaining = bitset_docset.advance() != TERMINATED;
} }
assert!(!remaining); assert!(!bitset_docset.advance());
} }
#[test] #[test]

View File

@@ -186,7 +186,7 @@ mod test {
use super::{CompositeFile, CompositeWrite}; use super::{CompositeFile, CompositeWrite};
use crate::common::BinarySerializable; use crate::common::BinarySerializable;
use crate::common::VInt; use crate::common::VInt;
use crate::directory::{Directory, RAMDirectory}; use crate::directory::{Directory, RAMDirectory, ReadOnlyDirectory};
use crate::schema::Field; use crate::schema::Field;
use std::io::Write; use std::io::Write;
use std::path::Path; use std::path::Path;

View File

@@ -18,19 +18,6 @@ pub use byteorder::LittleEndian as Endianness;
/// We do not allow segments with more than /// We do not allow segments with more than
pub const MAX_DOC_LIMIT: u32 = 1 << 31; pub const MAX_DOC_LIMIT: u32 = 1 << 31;
pub fn minmax<I, T>(mut vals: I) -> Option<(T, T)>
where
I: Iterator<Item = T>,
T: Copy + Ord,
{
if let Some(first_el) = vals.next() {
return Some(vals.fold((first_el, first_el), |(min_val, max_val), el| {
(min_val.min(el), max_val.max(el))
}));
}
None
}
/// Computes the number of bits that will be used for bitpacking. /// Computes the number of bits that will be used for bitpacking.
/// ///
/// In general the target is the minimum number of bits /// In general the target is the minimum number of bits
@@ -147,7 +134,6 @@ pub fn u64_to_f64(val: u64) -> f64 {
#[cfg(test)] #[cfg(test)]
pub(crate) mod test { pub(crate) mod test {
pub use super::minmax;
pub use super::serialize::test::fixed_size_test; pub use super::serialize::test::fixed_size_test;
use super::{compute_num_bits, f64_to_u64, i64_to_u64, u64_to_f64, u64_to_i64}; use super::{compute_num_bits, f64_to_u64, i64_to_u64, u64_to_f64, u64_to_i64};
use std::f64; use std::f64;
@@ -213,21 +199,4 @@ pub(crate) mod test {
assert!(((super::MAX_DOC_LIMIT - 1) as i32) >= 0); assert!(((super::MAX_DOC_LIMIT - 1) as i32) >= 0);
assert!((super::MAX_DOC_LIMIT as i32) < 0); assert!((super::MAX_DOC_LIMIT as i32) < 0);
} }
#[test]
fn test_minmax_empty() {
let vals: Vec<u32> = vec![];
assert_eq!(minmax(vals.into_iter()), None);
}
#[test]
fn test_minmax_one() {
assert_eq!(minmax(vec![1].into_iter()), Some((1, 1)));
}
#[test]
fn test_minmax_two() {
assert_eq!(minmax(vec![1, 2].into_iter()), Some((1, 2)));
assert_eq!(minmax(vec![2, 1].into_iter()), Some((1, 2)));
}
} }

View File

@@ -1,3 +1,4 @@
use crate::Result;
use crossbeam::channel; use crossbeam::channel;
use rayon::{ThreadPool, ThreadPoolBuilder}; use rayon::{ThreadPool, ThreadPoolBuilder};
@@ -9,9 +10,7 @@ use rayon::{ThreadPool, ThreadPoolBuilder};
/// API of a dependency, knowing it might conflict with a different version /// API of a dependency, knowing it might conflict with a different version
/// used by the client. Second, we may stop using rayon in the future. /// used by the client. Second, we may stop using rayon in the future.
pub enum Executor { pub enum Executor {
/// Single thread variant of an Executor
SingleThread, SingleThread,
/// Thread pool variant of an Executor
ThreadPool(ThreadPool), ThreadPool(ThreadPool),
} }
@@ -21,8 +20,8 @@ impl Executor {
Executor::SingleThread Executor::SingleThread
} }
/// Creates an Executor that dispatches the tasks in a thread pool. // Creates an Executor that dispatches the tasks in a thread pool.
pub fn multi_thread(num_threads: usize, prefix: &'static str) -> crate::Result<Executor> { pub fn multi_thread(num_threads: usize, prefix: &'static str) -> Result<Executor> {
let pool = ThreadPoolBuilder::new() let pool = ThreadPoolBuilder::new()
.num_threads(num_threads) .num_threads(num_threads)
.thread_name(move |num| format!("{}{}", prefix, num)) .thread_name(move |num| format!("{}{}", prefix, num))
@@ -30,22 +29,22 @@ impl Executor {
Ok(Executor::ThreadPool(pool)) Ok(Executor::ThreadPool(pool))
} }
/// Perform a map in the thread pool. // Perform a map in the thread pool.
/// //
/// Regardless of the executor (`SingleThread` or `ThreadPool`), panics in the task // Regardless of the executor (`SingleThread` or `ThreadPool`), panics in the task
/// will propagate to the caller. // will propagate to the caller.
pub fn map< pub fn map<
A: Send, A: Send,
R: Send, R: Send,
AIterator: Iterator<Item = A>, AIterator: Iterator<Item = A>,
F: Sized + Sync + Fn(A) -> crate::Result<R>, F: Sized + Sync + Fn(A) -> Result<R>,
>( >(
&self, &self,
f: F, f: F,
args: AIterator, args: AIterator,
) -> crate::Result<Vec<R>> { ) -> Result<Vec<R>> {
match self { match self {
Executor::SingleThread => args.map(f).collect::<crate::Result<_>>(), Executor::SingleThread => args.map(f).collect::<Result<_>>(),
Executor::ThreadPool(pool) => { Executor::ThreadPool(pool) => {
let args_with_indices: Vec<(usize, A)> = args.enumerate().collect(); let args_with_indices: Vec<(usize, A)> = args.enumerate().collect();
let num_fruits = args_with_indices.len(); let num_fruits = args_with_indices.len();

View File

@@ -1,3 +1,4 @@
use super::segment::create_segment;
use super::segment::Segment; use super::segment::Segment;
use crate::core::Executor; use crate::core::Executor;
use crate::core::IndexMeta; use crate::core::IndexMeta;
@@ -19,21 +20,19 @@ use crate::reader::IndexReaderBuilder;
use crate::schema::Field; use crate::schema::Field;
use crate::schema::FieldType; use crate::schema::FieldType;
use crate::schema::Schema; use crate::schema::Schema;
use crate::tokenizer::{TextAnalyzer, TokenizerManager}; use crate::tokenizer::BoxedTokenizer;
use crate::tokenizer::TokenizerManager;
use crate::IndexWriter; use crate::IndexWriter;
use crate::Result;
use num_cpus;
use std::borrow::BorrowMut; use std::borrow::BorrowMut;
use std::collections::HashSet; use std::collections::HashSet;
use std::fmt; use std::fmt;
#[cfg(feature = "mmap")] #[cfg(feature = "mmap")]
use std::path::Path; use std::path::{Path, PathBuf};
use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
fn load_metas( fn load_metas(directory: &dyn Directory, inventory: &SegmentMetaInventory) -> Result<IndexMeta> {
directory: &dyn Directory,
inventory: &SegmentMetaInventory,
) -> crate::Result<IndexMeta> {
let meta_data = directory.atomic_read(&META_FILEPATH)?; let meta_data = directory.atomic_read(&META_FILEPATH)?;
let meta_string = String::from_utf8_lossy(&meta_data); let meta_string = String::from_utf8_lossy(&meta_data);
IndexMeta::deserialize(&meta_string, &inventory) IndexMeta::deserialize(&meta_string, &inventory)
@@ -74,14 +73,14 @@ impl Index {
/// Replace the default single thread search executor pool /// Replace the default single thread search executor pool
/// by a thread pool with a given number of threads. /// by a thread pool with a given number of threads.
pub fn set_multithread_executor(&mut self, num_threads: usize) -> crate::Result<()> { pub fn set_multithread_executor(&mut self, num_threads: usize) -> Result<()> {
self.executor = Arc::new(Executor::multi_thread(num_threads, "thrd-tantivy-search-")?); self.executor = Arc::new(Executor::multi_thread(num_threads, "thrd-tantivy-search-")?);
Ok(()) Ok(())
} }
/// Replace the default single thread search executor pool /// Replace the default single thread search executor pool
/// by a thread pool with a given number of threads. /// by a thread pool with a given number of threads.
pub fn set_default_multithread_executor(&mut self) -> crate::Result<()> { pub fn set_default_multithread_executor(&mut self) -> Result<()> {
let default_num_threads = num_cpus::get(); let default_num_threads = num_cpus::get();
self.set_multithread_executor(default_num_threads) self.set_multithread_executor(default_num_threads)
} }
@@ -100,10 +99,7 @@ impl Index {
/// ///
/// If a previous index was in this directory, then its meta file will be destroyed. /// If a previous index was in this directory, then its meta file will be destroyed.
#[cfg(feature = "mmap")] #[cfg(feature = "mmap")]
pub fn create_in_dir<P: AsRef<Path>>( pub fn create_in_dir<P: AsRef<Path>>(directory_path: P, schema: Schema) -> Result<Index> {
directory_path: P,
schema: Schema,
) -> crate::Result<Index> {
let mmap_directory = MmapDirectory::open(directory_path)?; let mmap_directory = MmapDirectory::open(directory_path)?;
if Index::exists(&mmap_directory) { if Index::exists(&mmap_directory) {
return Err(TantivyError::IndexAlreadyExists); return Err(TantivyError::IndexAlreadyExists);
@@ -112,7 +108,7 @@ impl Index {
} }
/// Opens or creates a new index in the provided directory /// Opens or creates a new index in the provided directory
pub fn open_or_create<Dir: Directory>(dir: Dir, schema: Schema) -> crate::Result<Index> { pub fn open_or_create<Dir: Directory>(dir: Dir, schema: Schema) -> Result<Index> {
if !Index::exists(&dir) { if !Index::exists(&dir) {
return Index::create(dir, schema); return Index::create(dir, schema);
} }
@@ -135,13 +131,13 @@ impl Index {
/// The temp directory is only used for testing the `MmapDirectory`. /// The temp directory is only used for testing the `MmapDirectory`.
/// For other unit tests, prefer the `RAMDirectory`, see: `create_in_ram`. /// For other unit tests, prefer the `RAMDirectory`, see: `create_in_ram`.
#[cfg(feature = "mmap")] #[cfg(feature = "mmap")]
pub fn create_from_tempdir(schema: Schema) -> crate::Result<Index> { pub fn create_from_tempdir(schema: Schema) -> Result<Index> {
let mmap_directory = MmapDirectory::create_from_tempdir()?; let mmap_directory = MmapDirectory::create_from_tempdir()?;
Index::create(mmap_directory, schema) Index::create(mmap_directory, schema)
} }
/// Creates a new index given an implementation of the trait `Directory` /// Creates a new index given an implementation of the trait `Directory`
pub fn create<Dir: Directory>(dir: Dir, schema: Schema) -> crate::Result<Index> { pub fn create<Dir: Directory>(dir: Dir, schema: Schema) -> Result<Index> {
let directory = ManagedDirectory::wrap(dir)?; let directory = ManagedDirectory::wrap(dir)?;
Index::from_directory(directory, schema) Index::from_directory(directory, schema)
} }
@@ -149,7 +145,7 @@ impl Index {
/// Create a new index from a directory. /// Create a new index from a directory.
/// ///
/// This will overwrite existing meta.json /// This will overwrite existing meta.json
fn from_directory(mut directory: ManagedDirectory, schema: Schema) -> crate::Result<Index> { fn from_directory(mut directory: ManagedDirectory, schema: Schema) -> Result<Index> {
save_new_metas(schema.clone(), directory.borrow_mut())?; save_new_metas(schema.clone(), directory.borrow_mut())?;
let metas = IndexMeta::with_schema(schema); let metas = IndexMeta::with_schema(schema);
Index::create_from_metas(directory, &metas, SegmentMetaInventory::default()) Index::create_from_metas(directory, &metas, SegmentMetaInventory::default())
@@ -160,7 +156,7 @@ impl Index {
directory: ManagedDirectory, directory: ManagedDirectory,
metas: &IndexMeta, metas: &IndexMeta,
inventory: SegmentMetaInventory, inventory: SegmentMetaInventory,
) -> crate::Result<Index> { ) -> Result<Index> {
let schema = metas.schema.clone(); let schema = metas.schema.clone();
Ok(Index { Ok(Index {
directory, directory,
@@ -177,11 +173,11 @@ impl Index {
} }
/// Helper to access the tokenizer associated to a specific field. /// Helper to access the tokenizer associated to a specific field.
pub fn tokenizer_for_field(&self, field: Field) -> crate::Result<TextAnalyzer> { pub fn tokenizer_for_field(&self, field: Field) -> Result<BoxedTokenizer> {
let field_entry = self.schema.get_field_entry(field); let field_entry = self.schema.get_field_entry(field);
let field_type = field_entry.field_type(); let field_type = field_entry.field_type();
let tokenizer_manager: &TokenizerManager = self.tokenizers(); let tokenizer_manager: &TokenizerManager = self.tokenizers();
let tokenizer_name_opt: Option<TextAnalyzer> = match field_type { let tokenizer_name_opt: Option<BoxedTokenizer> = match field_type {
FieldType::Str(text_options) => text_options FieldType::Str(text_options) => text_options
.get_indexing_options() .get_indexing_options()
.map(|text_indexing_options| text_indexing_options.tokenizer().to_string()) .map(|text_indexing_options| text_indexing_options.tokenizer().to_string())
@@ -200,7 +196,7 @@ impl Index {
/// Create a default `IndexReader` for the given index. /// Create a default `IndexReader` for the given index.
/// ///
/// See [`Index.reader_builder()`](#method.reader_builder). /// See [`Index.reader_builder()`](#method.reader_builder).
pub fn reader(&self) -> crate::Result<IndexReader> { pub fn reader(&self) -> Result<IndexReader> {
self.reader_builder().try_into() self.reader_builder().try_into()
} }
@@ -215,7 +211,7 @@ impl Index {
/// Opens a new directory from an index path. /// Opens a new directory from an index path.
#[cfg(feature = "mmap")] #[cfg(feature = "mmap")]
pub fn open_in_dir<P: AsRef<Path>>(directory_path: P) -> crate::Result<Index> { pub fn open_in_dir<P: AsRef<Path>>(directory_path: P) -> Result<Index> {
let mmap_directory = MmapDirectory::open(directory_path)?; let mmap_directory = MmapDirectory::open(directory_path)?;
Index::open(mmap_directory) Index::open(mmap_directory)
} }
@@ -239,7 +235,7 @@ impl Index {
} }
/// Open the index using the provided directory /// Open the index using the provided directory
pub fn open<D: Directory>(directory: D) -> crate::Result<Index> { pub fn open<D: Directory>(directory: D) -> Result<Index> {
let directory = ManagedDirectory::wrap(directory)?; let directory = ManagedDirectory::wrap(directory)?;
let inventory = SegmentMetaInventory::default(); let inventory = SegmentMetaInventory::default();
let metas = load_metas(&directory, &inventory)?; let metas = load_metas(&directory, &inventory)?;
@@ -247,7 +243,7 @@ impl Index {
} }
/// Reads the index meta file from the directory. /// Reads the index meta file from the directory.
pub fn load_metas(&self) -> crate::Result<IndexMeta> { pub fn load_metas(&self) -> Result<IndexMeta> {
load_metas(self.directory(), &self.inventory) load_metas(self.directory(), &self.inventory)
} }
@@ -275,7 +271,7 @@ impl Index {
&self, &self,
num_threads: usize, num_threads: usize,
overall_heap_size_in_bytes: usize, overall_heap_size_in_bytes: usize,
) -> crate::Result<IndexWriter> { ) -> Result<IndexWriter> {
let directory_lock = self let directory_lock = self
.directory .directory
.acquire_lock(&INDEX_WRITER_LOCK) .acquire_lock(&INDEX_WRITER_LOCK)
@@ -310,7 +306,7 @@ impl Index {
/// If the lockfile already exists, returns `Error::FileAlreadyExists`. /// If the lockfile already exists, returns `Error::FileAlreadyExists`.
/// # Panics /// # Panics
/// If the heap size per thread is too small, panics. /// If the heap size per thread is too small, panics.
pub fn writer(&self, overall_heap_size_in_bytes: usize) -> crate::Result<IndexWriter> { pub fn writer(&self, overall_heap_size_in_bytes: usize) -> Result<IndexWriter> {
let mut num_threads = num_cpus::get(); let mut num_threads = num_cpus::get();
let heap_size_in_bytes_per_thread = overall_heap_size_in_bytes / num_threads; let heap_size_in_bytes_per_thread = overall_heap_size_in_bytes / num_threads;
if heap_size_in_bytes_per_thread < HEAP_SIZE_MIN { if heap_size_in_bytes_per_thread < HEAP_SIZE_MIN {
@@ -327,7 +323,7 @@ impl Index {
} }
/// Returns the list of segments that are searchable /// Returns the list of segments that are searchable
pub fn searchable_segments(&self) -> crate::Result<Vec<Segment>> { pub fn searchable_segments(&self) -> Result<Vec<Segment>> {
Ok(self Ok(self
.searchable_segment_metas()? .searchable_segment_metas()?
.into_iter() .into_iter()
@@ -337,12 +333,12 @@ impl Index {
#[doc(hidden)] #[doc(hidden)]
pub fn segment(&self, segment_meta: SegmentMeta) -> Segment { pub fn segment(&self, segment_meta: SegmentMeta) -> Segment {
Segment::for_index(self.clone(), segment_meta) create_segment(self.clone(), segment_meta)
} }
/// Creates a new segment. /// Creates a new segment.
pub fn new_segment(&self) -> Segment { pub fn new_segment(&self) -> Segment {
let segment_meta = self let mut segment_meta = self
.inventory .inventory
.new_segment_meta(SegmentId::generate_random(), 0); .new_segment_meta(SegmentId::generate_random(), 0);
self.segment(segment_meta) self.segment(segment_meta)
@@ -360,12 +356,12 @@ impl Index {
/// Reads the meta.json and returns the list of /// Reads the meta.json and returns the list of
/// `SegmentMeta` from the last commit. /// `SegmentMeta` from the last commit.
pub fn searchable_segment_metas(&self) -> crate::Result<Vec<SegmentMeta>> { pub fn searchable_segment_metas(&self) -> Result<Vec<SegmentMeta>> {
Ok(self.load_metas()?.segments) Ok(self.load_metas()?.segments)
} }
/// Returns the list of segment ids that are searchable. /// Returns the list of segment ids that are searchable.
pub fn searchable_segment_ids(&self) -> crate::Result<Vec<SegmentId>> { pub fn searchable_segment_ids(&self) -> Result<Vec<SegmentId>> {
Ok(self Ok(self
.searchable_segment_metas()? .searchable_segment_metas()?
.iter() .iter()
@@ -374,7 +370,7 @@ impl Index {
} }
/// Returns the set of corrupted files /// Returns the set of corrupted files
pub fn validate_checksum(&self) -> crate::Result<HashSet<PathBuf>> { pub fn validate_checksum(&self) -> Result<HashSet<PathBuf>> {
self.directory.list_damaged().map_err(Into::into) self.directory.list_damaged().map_err(Into::into)
} }
} }

View File

@@ -3,7 +3,8 @@ use crate::core::SegmentId;
use crate::schema::Schema; use crate::schema::Schema;
use crate::Opstamp; use crate::Opstamp;
use census::{Inventory, TrackedObject}; use census::{Inventory, TrackedObject};
use serde::{Deserialize, Serialize}; use serde;
use serde_json;
use std::collections::HashSet; use std::collections::HashSet;
use std::fmt; use std::fmt;
use std::path::PathBuf; use std::path::PathBuf;
@@ -34,6 +35,7 @@ impl SegmentMetaInventory {
segment_id, segment_id,
max_doc, max_doc,
deletes: None, deletes: None,
bundled: false,
}; };
SegmentMeta::from(self.inventory.track(inner)) SegmentMeta::from(self.inventory.track(inner))
} }
@@ -80,6 +82,19 @@ impl SegmentMeta {
self.tracked.segment_id self.tracked.segment_id
} }
pub fn with_bundled(self) -> SegmentMeta {
SegmentMeta::from(self.tracked.map(|inner| InnerSegmentMeta {
segment_id: inner.segment_id,
max_doc: inner.max_doc,
deletes: inner.deletes.clone(),
bundled: true,
}))
}
pub fn is_bundled(&self) -> bool {
self.tracked.bundled
}
/// Returns the number of deleted documents. /// Returns the number of deleted documents.
pub fn num_deleted_docs(&self) -> u32 { pub fn num_deleted_docs(&self) -> u32 {
self.tracked self.tracked
@@ -106,8 +121,12 @@ impl SegmentMeta {
/// It just joins the segment id with the extension /// It just joins the segment id with the extension
/// associated to a segment component. /// associated to a segment component.
pub fn relative_path(&self, component: SegmentComponent) -> PathBuf { pub fn relative_path(&self, component: SegmentComponent) -> PathBuf {
let mut path = self.id().uuid_string(); let suffix = self.suffix(component);
path.push_str(&*match component { self.relative_path_from_suffix(&suffix)
}
fn suffix(&self, component: SegmentComponent) -> String {
match component {
SegmentComponent::POSTINGS => ".idx".to_string(), SegmentComponent::POSTINGS => ".idx".to_string(),
SegmentComponent::POSITIONS => ".pos".to_string(), SegmentComponent::POSITIONS => ".pos".to_string(),
SegmentComponent::POSITIONSSKIP => ".posidx".to_string(), SegmentComponent::POSITIONSSKIP => ".posidx".to_string(),
@@ -116,7 +135,17 @@ impl SegmentMeta {
SegmentComponent::FASTFIELDS => ".fast".to_string(), SegmentComponent::FASTFIELDS => ".fast".to_string(),
SegmentComponent::FIELDNORMS => ".fieldnorm".to_string(), SegmentComponent::FIELDNORMS => ".fieldnorm".to_string(),
SegmentComponent::DELETE => format!(".{}.del", self.delete_opstamp().unwrap_or(0)), SegmentComponent::DELETE => format!(".{}.del", self.delete_opstamp().unwrap_or(0)),
}); }
}
/// Returns the relative path of a component of our segment.
///
/// It just joins the segment id with the extension
/// associated to a segment component.
pub fn relative_path_from_suffix(&self, suffix: &str) -> PathBuf {
let mut path = self.id().uuid_string();
path.push_str(".");
path.push_str(&suffix);
PathBuf::from(path) PathBuf::from(path)
} }
@@ -160,6 +189,7 @@ impl SegmentMeta {
segment_id: inner_meta.segment_id, segment_id: inner_meta.segment_id,
max_doc, max_doc,
deletes: None, deletes: None,
bundled: inner_meta.bundled,
}); });
SegmentMeta { tracked } SegmentMeta { tracked }
} }
@@ -174,6 +204,7 @@ impl SegmentMeta {
segment_id: inner_meta.segment_id, segment_id: inner_meta.segment_id,
max_doc: inner_meta.max_doc, max_doc: inner_meta.max_doc,
deletes: Some(delete_meta), deletes: Some(delete_meta),
bundled: inner_meta.bundled,
}); });
SegmentMeta { tracked } SegmentMeta { tracked }
} }
@@ -184,6 +215,7 @@ struct InnerSegmentMeta {
segment_id: SegmentId, segment_id: SegmentId,
max_doc: u32, max_doc: u32,
deletes: Option<DeleteMeta>, deletes: Option<DeleteMeta>,
bundled: bool,
} }
impl InnerSegmentMeta { impl InnerSegmentMeta {

View File

@@ -7,6 +7,7 @@ use crate::schema::FieldType;
use crate::schema::IndexRecordOption; use crate::schema::IndexRecordOption;
use crate::schema::Term; use crate::schema::Term;
use crate::termdict::TermDictionary; use crate::termdict::TermDictionary;
use owned_read::OwnedRead;
/// The inverted index reader is in charge of accessing /// The inverted index reader is in charge of accessing
/// the inverted index associated to a specific field. /// the inverted index associated to a specific field.
@@ -59,7 +60,7 @@ impl InvertedIndexReader {
.get_index_record_option() .get_index_record_option()
.unwrap_or(IndexRecordOption::Basic); .unwrap_or(IndexRecordOption::Basic);
InvertedIndexReader { InvertedIndexReader {
termdict: TermDictionary::empty(), termdict: TermDictionary::empty(&field_type),
postings_source: ReadOnlySource::empty(), postings_source: ReadOnlySource::empty(),
positions_source: ReadOnlySource::empty(), positions_source: ReadOnlySource::empty(),
positions_idx_source: ReadOnlySource::empty(), positions_idx_source: ReadOnlySource::empty(),
@@ -96,7 +97,8 @@ impl InvertedIndexReader {
let offset = term_info.postings_offset as usize; let offset = term_info.postings_offset as usize;
let end_source = self.postings_source.len(); let end_source = self.postings_source.len();
let postings_slice = self.postings_source.slice(offset, end_source); let postings_slice = self.postings_source.slice(offset, end_source);
block_postings.reset(term_info.doc_freq, postings_slice); let postings_reader = OwnedRead::new(postings_slice);
block_postings.reset(term_info.doc_freq, postings_reader);
} }
/// Returns a block postings given a `Term`. /// Returns a block postings given a `Term`.
@@ -125,7 +127,7 @@ impl InvertedIndexReader {
let postings_data = self.postings_source.slice_from(offset); let postings_data = self.postings_source.slice_from(offset);
BlockSegmentPostings::from_data( BlockSegmentPostings::from_data(
term_info.doc_freq, term_info.doc_freq,
postings_data, OwnedRead::new(postings_data),
self.record_option, self.record_option,
requested_option, requested_option,
) )

View File

@@ -1,8 +1,11 @@
use crate::collector::Collector; use crate::collector::Collector;
use crate::collector::SegmentCollector;
use crate::core::Executor; use crate::core::Executor;
use crate::core::InvertedIndexReader; use crate::core::InvertedIndexReader;
use crate::core::SegmentReader; use crate::core::SegmentReader;
use crate::query::Query; use crate::query::Query;
use crate::query::Scorer;
use crate::query::Weight;
use crate::schema::Document; use crate::schema::Document;
use crate::schema::Schema; use crate::schema::Schema;
use crate::schema::{Field, Term}; use crate::schema::{Field, Term};
@@ -11,9 +14,30 @@ use crate::store::StoreReader;
use crate::termdict::TermMerger; use crate::termdict::TermMerger;
use crate::DocAddress; use crate::DocAddress;
use crate::Index; use crate::Index;
use crate::Result;
use std::fmt; use std::fmt;
use std::sync::Arc; use std::sync::Arc;
fn collect_segment<C: Collector>(
collector: &C,
weight: &dyn Weight,
segment_ord: u32,
segment_reader: &SegmentReader,
) -> Result<C::Fruit> {
let mut scorer = weight.scorer(segment_reader)?;
let mut segment_collector = collector.for_segment(segment_ord as u32, segment_reader)?;
if let Some(delete_bitset) = segment_reader.delete_bitset() {
scorer.for_each(&mut |doc, score| {
if delete_bitset.is_alive(doc) {
segment_collector.collect(doc, score);
}
});
} else {
scorer.for_each(&mut |doc, score| segment_collector.collect(doc, score));
}
Ok(segment_collector.harvest())
}
/// Holds a list of `SegmentReader`s ready for search. /// Holds a list of `SegmentReader`s ready for search.
/// ///
/// It guarantees that the `Segment` will not be removed before /// It guarantees that the `Segment` will not be removed before
@@ -54,7 +78,7 @@ impl Searcher {
/// ///
/// The searcher uses the segment ordinal to route the /// The searcher uses the segment ordinal to route the
/// the request to the right `Segment`. /// the request to the right `Segment`.
pub fn doc(&self, doc_address: DocAddress) -> crate::Result<Document> { pub fn doc(&self, doc_address: DocAddress) -> Result<Document> {
let DocAddress(segment_local_id, doc_id) = doc_address; let DocAddress(segment_local_id, doc_id) = doc_address;
let store_reader = &self.store_readers[segment_local_id as usize]; let store_reader = &self.store_readers[segment_local_id as usize];
store_reader.get(doc_id) store_reader.get(doc_id)
@@ -108,11 +132,7 @@ impl Searcher {
/// ///
/// Finally, the Collector merges each of the child collectors into itself for result usability /// Finally, the Collector merges each of the child collectors into itself for result usability
/// by the caller. /// by the caller.
pub fn search<C: Collector>( pub fn search<C: Collector>(&self, query: &dyn Query, collector: &C) -> Result<C::Fruit> {
&self,
query: &dyn Query,
collector: &C,
) -> crate::Result<C::Fruit> {
let executor = self.index.search_executor(); let executor = self.index.search_executor();
self.search_with_executor(query, collector, executor) self.search_with_executor(query, collector, executor)
} }
@@ -134,13 +154,18 @@ impl Searcher {
query: &dyn Query, query: &dyn Query,
collector: &C, collector: &C,
executor: &Executor, executor: &Executor,
) -> crate::Result<C::Fruit> { ) -> Result<C::Fruit> {
let scoring_enabled = collector.requires_scoring(); let scoring_enabled = collector.requires_scoring();
let weight = query.weight(self, scoring_enabled)?; let weight = query.weight(self, scoring_enabled)?;
let segment_readers = self.segment_readers(); let segment_readers = self.segment_readers();
let fruits = executor.map( let fruits = executor.map(
|(segment_ord, segment_reader)| { |(segment_ord, segment_reader)| {
collector.collect_segment(weight.as_ref(), segment_ord as u32, segment_reader) collect_segment(
collector,
weight.as_ref(),
segment_ord as u32,
segment_reader,
)
}, },
segment_readers.iter().enumerate(), segment_readers.iter().enumerate(),
)?; )?;

View File

@@ -4,7 +4,7 @@ use crate::core::SegmentId;
use crate::core::SegmentMeta; use crate::core::SegmentMeta;
use crate::directory::error::{OpenReadError, OpenWriteError}; use crate::directory::error::{OpenReadError, OpenWriteError};
use crate::directory::Directory; use crate::directory::Directory;
use crate::directory::{ReadOnlySource, WritePtr}; use crate::directory::{ReadOnlyDirectory, ReadOnlySource, WritePtr};
use crate::indexer::segment_serializer::SegmentSerializer; use crate::indexer::segment_serializer::SegmentSerializer;
use crate::schema::Schema; use crate::schema::Schema;
use crate::Opstamp; use crate::Opstamp;
@@ -24,12 +24,15 @@ impl fmt::Debug for Segment {
} }
} }
impl Segment { /// Creates a new segment given an `Index` and a `SegmentId`
/// Creates a new segment given an `Index` and a `SegmentId` ///
pub(crate) fn for_index(index: Index, meta: SegmentMeta) -> Segment { /// The function is here to make it private outside `tantivy`.
Segment { index, meta } /// #[doc(hidden)]
} pub fn create_segment(index: Index, meta: SegmentMeta) -> Segment {
Segment { index, meta }
}
impl Segment {
/// Returns the index the segment belongs to. /// Returns the index the segment belongs to.
pub fn index(&self) -> &Index { pub fn index(&self) -> &Index {
&self.index &self.index
@@ -87,8 +90,21 @@ impl Segment {
/// Open one of the component file for *regular* write. /// Open one of the component file for *regular* write.
pub fn open_write(&mut self, component: SegmentComponent) -> Result<WritePtr, OpenWriteError> { pub fn open_write(&mut self, component: SegmentComponent) -> Result<WritePtr, OpenWriteError> {
let path = self.relative_path(component); let path = self.relative_path(component);
let write = self.index.directory_mut().open_write(&path)?; self.index.directory_mut().open_write(&path)
Ok(write) }
pub fn open_bundle_writer(&mut self) -> Result<WritePtr, OpenWriteError> {
let path = self.meta.relative_path_from_suffix("bundle");
self.index.directory_mut().open_write(&path)
}
pub(crate) fn open_write_in_directory(
&mut self,
component: SegmentComponent,
directory: &mut dyn Directory,
) -> Result<WritePtr, OpenWriteError> {
let path = self.relative_path(component);
directory.open_write(&path)
} }
} }

View File

@@ -4,7 +4,6 @@ use uuid::Uuid;
#[cfg(test)] #[cfg(test)]
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::error::Error; use std::error::Error;
use std::str::FromStr; use std::str::FromStr;
#[cfg(test)] #[cfg(test)]

View File

@@ -16,6 +16,7 @@ use crate::space_usage::SegmentSpaceUsage;
use crate::store::StoreReader; use crate::store::StoreReader;
use crate::termdict::TermDictionary; use crate::termdict::TermDictionary;
use crate::DocId; use crate::DocId;
use crate::Result;
use fail::fail_point; use fail::fail_point;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt; use std::fmt;
@@ -144,7 +145,7 @@ impl SegmentReader {
} }
/// Open a new segment for reading. /// Open a new segment for reading.
pub fn open(segment: &Segment) -> crate::Result<SegmentReader> { pub fn open(segment: &Segment) -> Result<SegmentReader> {
let termdict_source = segment.open_read(SegmentComponent::TERMS)?; let termdict_source = segment.open_read(SegmentComponent::TERMS)?;
let termdict_composite = CompositeFile::open(&termdict_source)?; let termdict_composite = CompositeFile::open(&termdict_source)?;

View File

@@ -0,0 +1,97 @@
use crate::directory::directory::ReadOnlyDirectory;
use crate::directory::error::OpenReadError;
use crate::directory::ReadOnlySource;
use crate::error::DataCorruption;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
#[derive(Clone)]
struct BundleDirectory {
source_map: Arc<HashMap<PathBuf, ReadOnlySource>>,
}
impl BundleDirectory {
pub fn from_source(source: ReadOnlySource) -> Result<BundleDirectory, DataCorruption> {
let mut index_offset_buf = [0u8; 8];
let (body_idx, footer_offset) = source.split_from_end(8);
index_offset_buf.copy_from_slice(footer_offset.as_slice());
let offset = u64::from_le_bytes(index_offset_buf);
let (body_source, idx_source) = body_idx.split(offset as usize);
let idx: HashMap<PathBuf, (u64, u64)> = serde_json::from_slice(idx_source.as_slice())
.map_err(|err| {
let msg = format!("Failed to read index from bundle. {:?}", err);
DataCorruption::comment_only(msg)
})?;
let source_map: HashMap<PathBuf, ReadOnlySource> = idx
.into_iter()
.map(|(path, (start, stop))| {
let source = body_source.slice(start as usize, stop as usize);
(path, source)
})
.collect();
Ok(BundleDirectory {
source_map: Arc::new(source_map),
})
}
}
impl ReadOnlyDirectory for BundleDirectory {
fn open_read(&self, path: &Path) -> Result<ReadOnlySource, OpenReadError> {
self.source_map
.get(path)
.cloned()
.ok_or_else(|| OpenReadError::FileDoesNotExist(path.to_path_buf()))
}
fn exists(&self, path: &Path) -> bool {
self.source_map.contains_key(path)
}
fn atomic_read(&self, path: &Path) -> Result<Vec<u8>, OpenReadError> {
let source = self
.source_map
.get(path)
.ok_or_else(|| OpenReadError::FileDoesNotExist(path.to_path_buf()))?;
Ok(source.as_slice().to_vec())
}
}
#[cfg(test)]
mod tests {
use super::BundleDirectory;
use crate::directory::{RAMDirectory, ReadOnlyDirectory, TerminatingWrite};
use crate::Directory;
use std::io::Write;
use std::path::Path;
#[test]
fn test_bundle_directory() {
let mut ram_directory = RAMDirectory::default();
let test_path_atomic = Path::new("testpath_atomic");
let test_path_wrt = Path::new("testpath_wrt");
assert!(ram_directory
.atomic_write(test_path_atomic, b"titi")
.is_ok());
{
let mut test_wrt = ram_directory.open_write(test_path_wrt).unwrap();
assert!(test_wrt.write_all(b"toto").is_ok());
assert!(test_wrt.terminate().is_ok());
}
let mut dest_directory = RAMDirectory::default();
let bundle_path = Path::new("bundle");
let mut wrt = dest_directory.open_write(bundle_path).unwrap();
assert!(ram_directory.serialize_bundle(&mut wrt).is_ok());
assert!(wrt.terminate().is_ok());
let source = dest_directory.open_read(bundle_path).unwrap();
let bundle_directory = BundleDirectory::from_source(source).unwrap();
assert_eq!(
&bundle_directory.atomic_read(test_path_atomic).unwrap()[..],
b"titi"
);
assert_eq!(
&bundle_directory.open_read(test_path_wrt).unwrap()[..],
b"toto"
);
}
}

View File

@@ -100,17 +100,7 @@ fn retry_policy(is_blocking: bool) -> RetryPolicy {
} }
} }
/// Write-once read many (WORM) abstraction for where pub trait ReadOnlyDirectory {
/// tantivy's data should be stored.
///
/// There are currently two implementations of `Directory`
///
/// - The [`MMapDirectory`](struct.MmapDirectory.html), this
/// should be your default choice.
/// - The [`RAMDirectory`](struct.RAMDirectory.html), which
/// should be used mostly for tests.
///
pub trait Directory: DirectoryClone + fmt::Debug + Send + Sync + 'static {
/// Opens a virtual file for read. /// Opens a virtual file for read.
/// ///
/// Once a virtual file is open, its data may not /// Once a virtual file is open, its data may not
@@ -122,6 +112,31 @@ pub trait Directory: DirectoryClone + fmt::Debug + Send + Sync + 'static {
/// You should only use this to read files create with [Directory::open_write]. /// You should only use this to read files create with [Directory::open_write].
fn open_read(&self, path: &Path) -> result::Result<ReadOnlySource, OpenReadError>; fn open_read(&self, path: &Path) -> result::Result<ReadOnlySource, OpenReadError>;
/// Returns true iff the file exists
fn exists(&self, path: &Path) -> bool;
/// Reads the full content file that has been written using
/// atomic_write.
///
/// This should only be used for small files.
///
/// You should only use this to read files create with [Directory::atomic_write].
fn atomic_read(&self, path: &Path) -> Result<Vec<u8>, OpenReadError>;
}
/// Write-once read many (WORM) abstraction for where
/// tantivy's data should be stored.
///
/// There are currently two implementations of `Directory`
///
/// - The [`MMapDirectory`](struct.MmapDirectory.html), this
/// should be your default choice.
/// - The [`RAMDirectory`](struct.RAMDirectory.html), which
/// should be used mostly for tests.
///
pub trait Directory:
DirectoryClone + ReadOnlyDirectory + fmt::Debug + Send + Sync + 'static
{
/// Removes a file /// Removes a file
/// ///
/// Removing a file will not affect an eventual /// Removing a file will not affect an eventual
@@ -131,9 +146,6 @@ pub trait Directory: DirectoryClone + fmt::Debug + Send + Sync + 'static {
/// `DeleteError::DoesNotExist`. /// `DeleteError::DoesNotExist`.
fn delete(&self, path: &Path) -> result::Result<(), DeleteError>; fn delete(&self, path: &Path) -> result::Result<(), DeleteError>;
/// Returns true iff the file exists
fn exists(&self, path: &Path) -> bool;
/// Opens a writer for the *virtual file* associated with /// Opens a writer for the *virtual file* associated with
/// a Path. /// a Path.
/// ///
@@ -155,14 +167,6 @@ pub trait Directory: DirectoryClone + fmt::Debug + Send + Sync + 'static {
/// The file may not previously exist. /// The file may not previously exist.
fn open_write(&mut self, path: &Path) -> Result<WritePtr, OpenWriteError>; fn open_write(&mut self, path: &Path) -> Result<WritePtr, OpenWriteError>;
/// Reads the full content file that has been written using
/// atomic_write.
///
/// This should only be used for small files.
///
/// You should only use this to read files create with [Directory::atomic_write].
fn atomic_read(&self, path: &Path) -> Result<Vec<u8>, OpenReadError>;
/// Atomically replace the content of a file with data. /// Atomically replace the content of a file with data.
/// ///
/// This calls ensure that reads can never *observe* /// This calls ensure that reads can never *observe*

View File

@@ -8,8 +8,6 @@ use crc32fast::Hasher;
use std::io; use std::io;
use std::io::Write; use std::io::Write;
const FOOTER_MAX_LEN: usize = 10_000;
type CrcHashU32 = u32; type CrcHashU32 = u32;
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
@@ -145,23 +143,12 @@ impl BinarySerializable for VersionedFooter {
} }
} }
BinarySerializable::serialize(&VInt(buf.len() as u64), writer)?; BinarySerializable::serialize(&VInt(buf.len() as u64), writer)?;
assert!(buf.len() <= FOOTER_MAX_LEN);
writer.write_all(&buf[..])?; writer.write_all(&buf[..])?;
Ok(()) Ok(())
} }
fn deserialize<R: io::Read>(reader: &mut R) -> io::Result<Self> { fn deserialize<R: io::Read>(reader: &mut R) -> io::Result<Self> {
let len = VInt::deserialize(reader)?.0 as usize; let len = VInt::deserialize(reader)?.0 as usize;
if len > FOOTER_MAX_LEN {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"Footer seems invalid as it suggests a footer len of {}. File is corrupted, \
or the index was created with a different & old version of tantivy.",
len
),
));
}
let mut buf = vec![0u8; len]; let mut buf = vec![0u8; len];
reader.read_exact(&mut buf[..])?; reader.read_exact(&mut buf[..])?;
let mut cursor = &buf[..]; let mut cursor = &buf[..];
@@ -234,12 +221,11 @@ mod tests {
use super::CrcHashU32; use super::CrcHashU32;
use super::FooterProxy; use super::FooterProxy;
use crate::common::{BinarySerializable, VInt}; use crate::common::BinarySerializable;
use crate::directory::footer::{Footer, VersionedFooter}; use crate::directory::footer::{Footer, VersionedFooter};
use crate::directory::TerminatingWrite; use crate::directory::TerminatingWrite;
use byteorder::{ByteOrder, LittleEndian}; use byteorder::{ByteOrder, LittleEndian};
use regex::Regex; use regex::Regex;
use std::io;
#[test] #[test]
fn test_versioned_footer() { fn test_versioned_footer() {
@@ -350,20 +336,4 @@ mod tests {
let res = footer.is_compatible(); let res = footer.is_compatible();
assert!(res.is_err()); assert!(res.is_err());
} }
#[test]
fn test_deserialize_too_large_footer() {
let mut buf = vec![];
assert!(FooterProxy::new(&mut buf).terminate().is_ok());
let mut long_len_buf = [0u8; 10];
let num_bytes = VInt(super::FOOTER_MAX_LEN as u64 + 1u64).serialize_into(&mut long_len_buf);
buf[0..num_bytes].copy_from_slice(&long_len_buf[..num_bytes]);
let err = Footer::deserialize(&mut &buf[..]).unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
assert_eq!(
err.to_string(),
"Footer seems invalid as it suggests a footer len of 10001. File is corrupted, \
or the index was created with a different & old version of tantivy."
);
}
} }

View File

@@ -10,7 +10,9 @@ use crate::directory::{WatchCallback, WatchHandle};
use crate::error::DataCorruption; use crate::error::DataCorruption;
use crate::Directory; use crate::Directory;
use crate::directory::directory::ReadOnlyDirectory;
use crc32fast::Hasher; use crc32fast::Hasher;
use serde_json;
use std::collections::HashSet; use std::collections::HashSet;
use std::io; use std::io;
use std::io::Write; use std::io::Write;
@@ -149,7 +151,7 @@ impl ManagedDirectory {
} }
Err(err) => { Err(err) => {
error!("Failed to acquire lock for GC"); error!("Failed to acquire lock for GC");
return Err(crate::TantivyError::from(err)); return Err(crate::Error::from(err));
} }
} }
} }
@@ -263,14 +265,6 @@ impl ManagedDirectory {
} }
impl Directory for ManagedDirectory { impl Directory for ManagedDirectory {
fn open_read(&self, path: &Path) -> result::Result<ReadOnlySource, OpenReadError> {
let read_only_source = self.directory.open_read(path)?;
let (footer, reader) = Footer::extract_footer(read_only_source)
.map_err(|err| IOError::with_path(path.to_path_buf(), err))?;
footer.is_compatible()?;
Ok(reader)
}
fn open_write(&mut self, path: &Path) -> result::Result<WritePtr, OpenWriteError> { fn open_write(&mut self, path: &Path) -> result::Result<WritePtr, OpenWriteError> {
self.register_file_as_managed(path) self.register_file_as_managed(path)
.map_err(|e| IOError::with_path(path.to_owned(), e))?; .map_err(|e| IOError::with_path(path.to_owned(), e))?;
@@ -288,18 +282,10 @@ impl Directory for ManagedDirectory {
self.directory.atomic_write(path, data) self.directory.atomic_write(path, data)
} }
fn atomic_read(&self, path: &Path) -> result::Result<Vec<u8>, OpenReadError> {
self.directory.atomic_read(path)
}
fn delete(&self, path: &Path) -> result::Result<(), DeleteError> { fn delete(&self, path: &Path) -> result::Result<(), DeleteError> {
self.directory.delete(path) self.directory.delete(path)
} }
fn exists(&self, path: &Path) -> bool {
self.directory.exists(path)
}
fn acquire_lock(&self, lock: &Lock) -> result::Result<DirectoryLock, LockError> { fn acquire_lock(&self, lock: &Lock) -> result::Result<DirectoryLock, LockError> {
self.directory.acquire_lock(lock) self.directory.acquire_lock(lock)
} }
@@ -309,6 +295,24 @@ impl Directory for ManagedDirectory {
} }
} }
impl ReadOnlyDirectory for ManagedDirectory {
fn open_read(&self, path: &Path) -> result::Result<ReadOnlySource, OpenReadError> {
let read_only_source = self.directory.open_read(path)?;
let (footer, reader) = Footer::extract_footer(read_only_source)
.map_err(|err| IOError::with_path(path.to_path_buf(), err))?;
footer.is_compatible()?;
Ok(reader)
}
fn exists(&self, path: &Path) -> bool {
self.directory.exists(path)
}
fn atomic_read(&self, path: &Path) -> result::Result<Vec<u8>, OpenReadError> {
self.directory.atomic_read(path)
}
}
impl Clone for ManagedDirectory { impl Clone for ManagedDirectory {
fn clone(&self) -> ManagedDirectory { fn clone(&self) -> ManagedDirectory {
ManagedDirectory { ManagedDirectory {
@@ -322,7 +326,9 @@ impl Clone for ManagedDirectory {
#[cfg(test)] #[cfg(test)]
mod tests_mmap_specific { mod tests_mmap_specific {
use crate::directory::{Directory, ManagedDirectory, MmapDirectory, TerminatingWrite}; use crate::directory::{
Directory, ManagedDirectory, MmapDirectory, ReadOnlyDirectory, TerminatingWrite,
};
use std::collections::HashSet; use std::collections::HashSet;
use std::fs::OpenOptions; use std::fs::OpenOptions;
use std::io::Write; use std::io::Write;

View File

@@ -1,4 +1,12 @@
use fs2;
use notify;
use self::fs2::FileExt;
use self::notify::RawEvent;
use self::notify::RecursiveMode;
use self::notify::Watcher;
use crate::core::META_FILEPATH; use crate::core::META_FILEPATH;
use crate::directory::directory::ReadOnlyDirectory;
use crate::directory::error::LockError; use crate::directory::error::LockError;
use crate::directory::error::{ use crate::directory::error::{
DeleteError, IOError, OpenDirectoryError, OpenReadError, OpenWriteError, DeleteError, IOError, OpenDirectoryError, OpenReadError, OpenWriteError,
@@ -13,12 +21,8 @@ use crate::directory::WatchCallback;
use crate::directory::WatchCallbackList; use crate::directory::WatchCallbackList;
use crate::directory::WatchHandle; use crate::directory::WatchHandle;
use crate::directory::{TerminatingWrite, WritePtr}; use crate::directory::{TerminatingWrite, WritePtr};
use fs2::FileExt; use atomicwrites;
use memmap::Mmap; use memmap::Mmap;
use notify::RawEvent;
use notify::RecursiveMode;
use notify::Watcher;
use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::convert::From; use std::convert::From;
use std::fmt; use std::fmt;
@@ -220,13 +224,17 @@ struct MmapDirectoryInner {
} }
impl MmapDirectoryInner { impl MmapDirectoryInner {
fn new(root_path: PathBuf, temp_directory: Option<TempDir>) -> MmapDirectoryInner { fn new(
MmapDirectoryInner { root_path: PathBuf,
temp_directory: Option<TempDir>,
) -> Result<MmapDirectoryInner, OpenDirectoryError> {
let mmap_directory_inner = MmapDirectoryInner {
root_path, root_path,
mmap_cache: Default::default(), mmap_cache: Default::default(),
_temp_directory: temp_directory, _temp_directory: temp_directory,
watcher: RwLock::new(None), watcher: RwLock::new(None),
} };
Ok(mmap_directory_inner)
} }
fn watch(&self, watch_callback: WatchCallback) -> crate::Result<WatchHandle> { fn watch(&self, watch_callback: WatchCallback) -> crate::Result<WatchHandle> {
@@ -260,11 +268,14 @@ impl fmt::Debug for MmapDirectory {
} }
impl MmapDirectory { impl MmapDirectory {
fn new(root_path: PathBuf, temp_directory: Option<TempDir>) -> MmapDirectory { fn new(
let inner = MmapDirectoryInner::new(root_path, temp_directory); root_path: PathBuf,
MmapDirectory { temp_directory: Option<TempDir>,
) -> Result<MmapDirectory, OpenDirectoryError> {
let inner = MmapDirectoryInner::new(root_path, temp_directory)?;
Ok(MmapDirectory {
inner: Arc::new(inner), inner: Arc::new(inner),
} })
} }
/// Creates a new MmapDirectory in a temporary directory. /// Creates a new MmapDirectory in a temporary directory.
@@ -274,7 +285,7 @@ impl MmapDirectory {
pub fn create_from_tempdir() -> Result<MmapDirectory, OpenDirectoryError> { pub fn create_from_tempdir() -> Result<MmapDirectory, OpenDirectoryError> {
let tempdir = TempDir::new().map_err(OpenDirectoryError::IoError)?; let tempdir = TempDir::new().map_err(OpenDirectoryError::IoError)?;
let tempdir_path = PathBuf::from(tempdir.path()); let tempdir_path = PathBuf::from(tempdir.path());
Ok(MmapDirectory::new(tempdir_path, Some(tempdir))) MmapDirectory::new(tempdir_path, Some(tempdir))
} }
/// Opens a MmapDirectory in a directory. /// Opens a MmapDirectory in a directory.
@@ -292,7 +303,7 @@ impl MmapDirectory {
directory_path, directory_path,
))) )))
} else { } else {
Ok(MmapDirectory::new(PathBuf::from(directory_path), None)) Ok(MmapDirectory::new(PathBuf::from(directory_path), None)?)
} }
} }
@@ -397,24 +408,6 @@ impl TerminatingWrite for SafeFileWriter {
} }
impl Directory for MmapDirectory { impl Directory for MmapDirectory {
fn open_read(&self, path: &Path) -> result::Result<ReadOnlySource, OpenReadError> {
debug!("Open Read {:?}", path);
let full_path = self.resolve_path(path);
let mut mmap_cache = self.inner.mmap_cache.write().map_err(|_| {
let msg = format!(
"Failed to acquired write lock \
on mmap cache while reading {:?}",
path
);
IOError::with_path(path.to_owned(), make_io_err(msg))
})?;
Ok(mmap_cache
.get_mmap(&full_path)?
.map(ReadOnlySource::from)
.unwrap_or_else(ReadOnlySource::empty))
}
/// Any entry associated to the path in the mmap will be /// Any entry associated to the path in the mmap will be
/// removed before the file is deleted. /// removed before the file is deleted.
fn delete(&self, path: &Path) -> result::Result<(), DeleteError> { fn delete(&self, path: &Path) -> result::Result<(), DeleteError> {
@@ -433,11 +426,6 @@ impl Directory for MmapDirectory {
} }
} }
fn exists(&self, path: &Path) -> bool {
let full_path = self.resolve_path(path);
full_path.exists()
}
fn open_write(&mut self, path: &Path) -> Result<WritePtr, OpenWriteError> { fn open_write(&mut self, path: &Path) -> Result<WritePtr, OpenWriteError> {
debug!("Open Write {:?}", path); debug!("Open Write {:?}", path);
let full_path = self.resolve_path(path); let full_path = self.resolve_path(path);
@@ -468,25 +456,6 @@ impl Directory for MmapDirectory {
Ok(BufWriter::new(Box::new(writer))) Ok(BufWriter::new(Box::new(writer)))
} }
fn atomic_read(&self, path: &Path) -> Result<Vec<u8>, OpenReadError> {
let full_path = self.resolve_path(path);
let mut buffer = Vec::new();
match File::open(&full_path) {
Ok(mut file) => {
file.read_to_end(&mut buffer)
.map_err(|e| IOError::with_path(path.to_owned(), e))?;
Ok(buffer)
}
Err(e) => {
if e.kind() == io::ErrorKind::NotFound {
Err(OpenReadError::FileDoesNotExist(path.to_owned()))
} else {
Err(IOError::with_path(path.to_owned(), e).into())
}
}
}
}
fn atomic_write(&mut self, path: &Path, data: &[u8]) -> io::Result<()> { fn atomic_write(&mut self, path: &Path, data: &[u8]) -> io::Result<()> {
debug!("Atomic Write {:?}", path); debug!("Atomic Write {:?}", path);
let full_path = self.resolve_path(path); let full_path = self.resolve_path(path);
@@ -520,6 +489,50 @@ impl Directory for MmapDirectory {
} }
} }
impl ReadOnlyDirectory for MmapDirectory {
fn open_read(&self, path: &Path) -> result::Result<ReadOnlySource, OpenReadError> {
debug!("Open Read {:?}", path);
let full_path = self.resolve_path(path);
let mut mmap_cache = self.inner.mmap_cache.write().map_err(|_| {
let msg = format!(
"Failed to acquired write lock \
on mmap cache while reading {:?}",
path
);
IOError::with_path(path.to_owned(), make_io_err(msg))
})?;
Ok(mmap_cache
.get_mmap(&full_path)?
.map(ReadOnlySource::from)
.unwrap_or_else(ReadOnlySource::empty))
}
fn exists(&self, path: &Path) -> bool {
let full_path = self.resolve_path(path);
full_path.exists()
}
fn atomic_read(&self, path: &Path) -> Result<Vec<u8>, OpenReadError> {
let full_path = self.resolve_path(path);
let mut buffer = Vec::new();
match File::open(&full_path) {
Ok(mut file) => {
file.read_to_end(&mut buffer)
.map_err(|e| IOError::with_path(path.to_owned(), e))?;
Ok(buffer)
}
Err(e) => {
if e.kind() == io::ErrorKind::NotFound {
Err(OpenReadError::FileDoesNotExist(path.to_owned()))
} else {
Err(IOError::with_path(path.to_owned(), e).into())
}
}
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {

View File

@@ -7,6 +7,7 @@ WORM directory abstraction.
#[cfg(feature = "mmap")] #[cfg(feature = "mmap")]
mod mmap_directory; mod mmap_directory;
mod bundle_directory;
mod directory; mod directory;
mod directory_lock; mod directory_lock;
mod footer; mod footer;
@@ -19,7 +20,7 @@ mod watch_event_router;
pub mod error; pub mod error;
pub use self::directory::DirectoryLock; pub use self::directory::DirectoryLock;
pub use self::directory::{Directory, DirectoryClone}; pub use self::directory::{Directory, DirectoryClone, ReadOnlyDirectory};
pub use self::directory_lock::{Lock, INDEX_WRITER_LOCK, META_LOCK}; pub use self::directory_lock::{Lock, INDEX_WRITER_LOCK, META_LOCK};
pub use self::ram_directory::RAMDirectory; pub use self::ram_directory::RAMDirectory;
pub use self::read_only_source::ReadOnlySource; pub use self::read_only_source::ReadOnlySource;

View File

@@ -1,4 +1,6 @@
use crate::common::CountingWriter;
use crate::core::META_FILEPATH; use crate::core::META_FILEPATH;
use crate::directory::directory::ReadOnlyDirectory;
use crate::directory::error::{DeleteError, OpenReadError, OpenWriteError}; use crate::directory::error::{DeleteError, OpenReadError, OpenWriteError};
use crate::directory::AntiCallToken; use crate::directory::AntiCallToken;
use crate::directory::WatchCallbackList; use crate::directory::WatchCallbackList;
@@ -115,6 +117,22 @@ impl InnerDirectory {
fn total_mem_usage(&self) -> usize { fn total_mem_usage(&self) -> usize {
self.fs.values().map(|f| f.len()).sum() self.fs.values().map(|f| f.len()).sum()
} }
fn serialize_bundle(&self, wrt: &mut WritePtr) -> io::Result<()> {
let mut counting_writer = CountingWriter::wrap(wrt);
let mut file_index: HashMap<PathBuf, (u64, u64)> = HashMap::default();
for (path, source) in &self.fs {
let start = counting_writer.written_bytes();
counting_writer.write_all(source.as_slice())?;
let stop = counting_writer.written_bytes();
file_index.insert(path.to_path_buf(), (start, stop));
}
let index_offset = counting_writer.written_bytes();
serde_json::to_writer(&mut counting_writer, &file_index)?;
let index_offset_buffer = index_offset.to_le_bytes();
counting_writer.write_all(&index_offset_buffer[..])?;
Ok(())
}
} }
impl fmt::Debug for RAMDirectory { impl fmt::Debug for RAMDirectory {
@@ -145,28 +163,17 @@ impl RAMDirectory {
self.fs.read().unwrap().total_mem_usage() self.fs.read().unwrap().total_mem_usage()
} }
/// Write a copy of all of the files saved in the RAMDirectory in the target `Directory`. /// Serialize the RAMDirectory into a bundle.
/// ///
/// Files are all written using the `Directory::write` meaning, even if they were /// This method will fail, write nothing, and return an error if a
/// written using the `atomic_write` api. /// clone of this repository exists.
/// pub fn serialize_bundle(self, wrt: &mut WritePtr) -> io::Result<()> {
/// If an error is encounterred, files may be persisted partially. let inner_directory_rlock = self.fs.read().unwrap();
pub fn persist(&self, dest: &mut dyn Directory) -> crate::Result<()> { inner_directory_rlock.serialize_bundle(wrt)
let wlock = self.fs.write().unwrap();
for (path, source) in wlock.fs.iter() {
let mut dest_wrt = dest.open_write(path)?;
dest_wrt.write_all(source.as_slice())?;
dest_wrt.terminate()?;
}
Ok(())
} }
} }
impl Directory for RAMDirectory { impl Directory for RAMDirectory {
fn open_read(&self, path: &Path) -> result::Result<ReadOnlySource, OpenReadError> {
self.fs.read().unwrap().open_read(path)
}
fn delete(&self, path: &Path) -> result::Result<(), DeleteError> { fn delete(&self, path: &Path) -> result::Result<(), DeleteError> {
fail_point!("RAMDirectory::delete", |_| { fail_point!("RAMDirectory::delete", |_| {
use crate::directory::error::IOError; use crate::directory::error::IOError;
@@ -176,10 +183,6 @@ impl Directory for RAMDirectory {
self.fs.write().unwrap().delete(path) self.fs.write().unwrap().delete(path)
} }
fn exists(&self, path: &Path) -> bool {
self.fs.read().unwrap().exists(path)
}
fn open_write(&mut self, path: &Path) -> Result<WritePtr, OpenWriteError> { fn open_write(&mut self, path: &Path) -> Result<WritePtr, OpenWriteError> {
let mut fs = self.fs.write().unwrap(); let mut fs = self.fs.write().unwrap();
let path_buf = PathBuf::from(path); let path_buf = PathBuf::from(path);
@@ -193,10 +196,6 @@ impl Directory for RAMDirectory {
} }
} }
fn atomic_read(&self, path: &Path) -> Result<Vec<u8>, OpenReadError> {
Ok(self.open_read(path)?.as_slice().to_owned())
}
fn atomic_write(&mut self, path: &Path, data: &[u8]) -> io::Result<()> { fn atomic_write(&mut self, path: &Path, data: &[u8]) -> io::Result<()> {
fail_point!("RAMDirectory::atomic_write", |msg| Err(io::Error::new( fail_point!("RAMDirectory::atomic_write", |msg| Err(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
@@ -221,27 +220,16 @@ impl Directory for RAMDirectory {
} }
} }
#[cfg(test)] impl ReadOnlyDirectory for RAMDirectory {
mod tests { fn open_read(&self, path: &Path) -> result::Result<ReadOnlySource, OpenReadError> {
use super::RAMDirectory; self.fs.read().unwrap().open_read(path)
use crate::Directory; }
use std::io::Write;
use std::path::Path;
#[test] fn exists(&self, path: &Path) -> bool {
fn test_persist() { self.fs.read().unwrap().exists(path)
let msg_atomic: &'static [u8] = b"atomic is the way"; }
let msg_seq: &'static [u8] = b"sequential is the way";
let path_atomic: &'static Path = Path::new("atomic"); fn atomic_read(&self, path: &Path) -> Result<Vec<u8>, OpenReadError> {
let path_seq: &'static Path = Path::new("seq"); Ok(self.open_read(path)?.as_slice().to_owned())
let mut directory = RAMDirectory::create();
assert!(directory.atomic_write(path_atomic, msg_atomic).is_ok());
let mut wrt = directory.open_write(path_seq).unwrap();
assert!(wrt.write_all(msg_seq).is_ok());
assert!(wrt.flush().is_ok());
let mut directory_copy = RAMDirectory::create();
assert!(directory.persist(&mut directory_copy).is_ok());
assert_eq!(directory_copy.atomic_read(path_atomic).unwrap(), msg_atomic);
assert_eq!(directory_copy.atomic_read(path_seq).unwrap(), msg_seq);
} }
} }

View File

@@ -1,47 +1,58 @@
use crate::common::BitSet;
use crate::fastfield::DeleteBitSet; use crate::fastfield::DeleteBitSet;
use crate::DocId; use crate::DocId;
use std::borrow::Borrow; use std::borrow::Borrow;
use std::borrow::BorrowMut; use std::borrow::BorrowMut;
use std::cmp::Ordering;
/// Sentinel value returned when a DocSet has been entirely consumed. /// Expresses the outcome of a call to `DocSet`'s `.skip_next(...)`.
/// #[derive(PartialEq, Eq, Debug)]
/// This is not u32::MAX as one would have expected, due to the lack of SSE2 instructions pub enum SkipResult {
/// to compare [u32; 4]. /// target was in the docset
pub const TERMINATED: DocId = std::i32::MAX as u32; Reached,
/// target was not in the docset, skipping stopped as a greater element was found
OverStep,
/// the docset was entirely consumed without finding the target, nor any
/// element greater than the target.
End,
}
/// Represents an iterable set of sorted doc ids. /// Represents an iterable set of sorted doc ids.
pub trait DocSet { pub trait DocSet {
/// Goes to the next element. /// Goes to the next element.
/// /// `.advance(...)` needs to be called a first time to point to the correct
/// The DocId of the next element is returned. /// element.
/// In other words we should always have : fn advance(&mut self) -> bool;
/// ```ignore
/// let doc = docset.advance();
/// assert_eq!(doc, docset.doc());
/// ```
///
/// If we reached the end of the DocSet, TERMINATED should be returned.
///
/// Calling `.advance()` on a terminated DocSet should be supported, and TERMINATED should
/// be returned.
/// TODO Test existing docsets.
fn advance(&mut self) -> DocId;
/// Advances the DocSet forward until reaching the target, or going to the /// After skipping, position the iterator in such a way that `.doc()`
/// lowest DocId greater than the target. /// will return a value greater than or equal to target.
/// ///
/// If the end of the DocSet is reached, TERMINATED is returned. /// SkipResult expresses whether the `target value` was reached, overstepped,
/// or if the `DocSet` was entirely consumed without finding any value
/// greater or equal to the `target`.
/// ///
/// Calling `.seek(target)` on a terminated DocSet is legal. Implementation /// WARNING: Calling skip always advances the docset.
/// of DocSet should support it. /// More specifically, if the docset is already positionned on the target
/// skipping will advance to the next position and return SkipResult::Overstep.
/// ///
/// Calling `seek(TERMINATED)` is also legal and is the normal way to consume a DocSet. /// If `.skip_next()` oversteps, then the docset must be positionned correctly
fn seek(&mut self, target: DocId) -> DocId { /// on an existing document. In other words, `.doc()` should return the first document
let mut doc = self.doc(); /// greater than `DocId`.
while doc < target { fn skip_next(&mut self, target: DocId) -> SkipResult {
doc = self.advance(); if !self.advance() {
return SkipResult::End;
}
loop {
match self.doc().cmp(&target) {
Ordering::Less => {
if !self.advance() {
return SkipResult::End;
}
}
Ordering::Equal => return SkipResult::Reached,
Ordering::Greater => return SkipResult::OverStep,
}
} }
doc
} }
/// Fills a given mutable buffer with the next doc ids from the /// Fills a given mutable buffer with the next doc ids from the
@@ -60,38 +71,38 @@ pub trait DocSet {
/// use case where batching. The normal way to /// use case where batching. The normal way to
/// go through the `DocId`'s is to call `.advance()`. /// go through the `DocId`'s is to call `.advance()`.
fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize { fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize {
if self.doc() == TERMINATED {
return 0;
}
for (i, buffer_val) in buffer.iter_mut().enumerate() { for (i, buffer_val) in buffer.iter_mut().enumerate() {
*buffer_val = self.doc(); if self.advance() {
if self.advance() == TERMINATED { *buffer_val = self.doc();
return i + 1; } else {
return i;
} }
} }
buffer.len() buffer.len()
} }
/// Returns the current document /// Returns the current document
/// Right after creating a new DocSet, the docset points to the first document.
///
/// If the DocSet is empty, .doc() should return `TERMINATED`.
fn doc(&self) -> DocId; fn doc(&self) -> DocId;
/// Returns a best-effort hint of the /// Returns a best-effort hint of the
/// length of the docset. /// length of the docset.
fn size_hint(&self) -> u32; fn size_hint(&self) -> u32;
/// Appends all docs to a `bitset`.
fn append_to_bitset(&mut self, bitset: &mut BitSet) {
while self.advance() {
bitset.insert(self.doc());
}
}
/// Returns the number documents matching. /// Returns the number documents matching.
/// Calling this method consumes the `DocSet`. /// Calling this method consumes the `DocSet`.
fn count(&mut self, delete_bitset: &DeleteBitSet) -> u32 { fn count(&mut self, delete_bitset: &DeleteBitSet) -> u32 {
let mut count = 0u32; let mut count = 0u32;
let mut doc = self.doc(); while self.advance() {
while doc != TERMINATED { if !delete_bitset.is_deleted(self.doc()) {
if !delete_bitset.is_deleted(doc) {
count += 1u32; count += 1u32;
} }
doc = self.advance();
} }
count count
} }
@@ -103,42 +114,22 @@ pub trait DocSet {
/// given by `count()`. /// given by `count()`.
fn count_including_deleted(&mut self) -> u32 { fn count_including_deleted(&mut self) -> u32 {
let mut count = 0u32; let mut count = 0u32;
let mut doc = self.doc(); while self.advance() {
while doc != TERMINATED {
count += 1u32; count += 1u32;
doc = self.advance();
} }
count count
} }
} }
impl<'a> DocSet for &'a mut dyn DocSet {
fn advance(&mut self) -> u32 {
(**self).advance()
}
fn seek(&mut self, target: DocId) -> DocId {
(**self).seek(target)
}
fn doc(&self) -> u32 {
(**self).doc()
}
fn size_hint(&self) -> u32 {
(**self).size_hint()
}
}
impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> { impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
fn advance(&mut self) -> DocId { fn advance(&mut self) -> bool {
let unboxed: &mut TDocSet = self.borrow_mut(); let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.advance() unboxed.advance()
} }
fn seek(&mut self, target: DocId) -> DocId { fn skip_next(&mut self, target: DocId) -> SkipResult {
let unboxed: &mut TDocSet = self.borrow_mut(); let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.seek(target) unboxed.skip_next(target)
} }
fn doc(&self) -> DocId { fn doc(&self) -> DocId {
@@ -160,4 +151,9 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
let unboxed: &mut TDocSet = self.borrow_mut(); let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.count_including_deleted() unboxed.count_including_deleted()
} }
fn append_to_bitset(&mut self, bitset: &mut BitSet) {
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.append_to_bitset(bitset);
}
} }

View File

@@ -7,6 +7,7 @@ use crate::directory::error::{Incompatibility, LockError};
use crate::fastfield::FastFieldNotAvailableError; use crate::fastfield::FastFieldNotAvailableError;
use crate::query; use crate::query;
use crate::schema; use crate::schema;
use serde_json;
use std::fmt; use std::fmt;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::PoisonError; use std::sync::PoisonError;
@@ -24,10 +25,10 @@ impl DataCorruption {
} }
} }
pub fn comment_only(comment: String) -> DataCorruption { pub fn comment_only<TS: ToString>(comment: TS) -> DataCorruption {
DataCorruption { DataCorruption {
filepath: None, filepath: None,
comment, comment: comment.to_string(),
} }
} }
} }

View File

@@ -179,7 +179,7 @@ mod tests {
use super::*; use super::*;
use crate::common::CompositeFile; use crate::common::CompositeFile;
use crate::directory::{Directory, RAMDirectory, WritePtr}; use crate::directory::{Directory, RAMDirectory, ReadOnlyDirectory, WritePtr};
use crate::fastfield::FastFieldReader; use crate::fastfield::FastFieldReader;
use crate::merge_policy::NoMergePolicy; use crate::merge_policy::NoMergePolicy;
use crate::schema::Field; use crate::schema::Field;

View File

@@ -7,6 +7,9 @@ pub use self::writer::MultiValueIntFastFieldWriter;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use time;
use self::time::Duration;
use crate::collector::TopDocs; use crate::collector::TopDocs;
use crate::query::QueryParser; use crate::query::QueryParser;
use crate::schema::Cardinality; use crate::schema::Cardinality;
@@ -14,7 +17,6 @@ mod tests {
use crate::schema::IntOptions; use crate::schema::IntOptions;
use crate::schema::Schema; use crate::schema::Schema;
use crate::Index; use crate::Index;
use chrono::Duration;
#[test] #[test]
fn test_multivalued_u64() { fn test_multivalued_u64() {

View File

@@ -6,6 +6,7 @@ use crate::schema::{Document, Field};
use crate::termdict::TermOrdinal; use crate::termdict::TermOrdinal;
use crate::DocId; use crate::DocId;
use fnv::FnvHashMap; use fnv::FnvHashMap;
use itertools::Itertools;
use std::io; use std::io;
/// Writer for multi-valued (as in, more than one value per document) /// Writer for multi-valued (as in, more than one value per document)
@@ -150,8 +151,8 @@ impl MultiValueIntFastFieldWriter {
} }
} }
None => { None => {
let val_min_max = crate::common::minmax(self.vals.iter().cloned()); let val_min_max = self.vals.iter().cloned().minmax();
let (val_min, val_max) = val_min_max.unwrap_or((0u64, 0u64)); let (val_min, val_max) = val_min_max.into_option().unwrap_or((0u64, 0u64));
value_serializer = value_serializer =
serializer.new_u64_fast_field_with_idx(self.field, val_min, val_max, 1)?; serializer.new_u64_fast_field_with_idx(self.field, val_min, val_max, 1)?;
for &val in &self.vals { for &val in &self.vals {

View File

@@ -4,7 +4,7 @@ use crate::common::compute_num_bits;
use crate::common::BinarySerializable; use crate::common::BinarySerializable;
use crate::common::CompositeFile; use crate::common::CompositeFile;
use crate::directory::ReadOnlySource; use crate::directory::ReadOnlySource;
use crate::directory::{Directory, RAMDirectory, WritePtr}; use crate::directory::{Directory, RAMDirectory, ReadOnlyDirectory, WritePtr};
use crate::fastfield::{FastFieldSerializer, FastFieldsWriter}; use crate::fastfield::{FastFieldSerializer, FastFieldsWriter};
use crate::schema::Schema; use crate::schema::Schema;
use crate::schema::FAST; use crate::schema::FAST;

View File

@@ -4,6 +4,7 @@ use crate::fastfield::MultiValueIntFastFieldReader;
use crate::fastfield::{FastFieldNotAvailableError, FastFieldReader}; use crate::fastfield::{FastFieldNotAvailableError, FastFieldReader};
use crate::schema::{Cardinality, Field, FieldType, Schema}; use crate::schema::{Cardinality, Field, FieldType, Schema};
use crate::space_usage::PerFieldSpaceUsage; use crate::space_usage::PerFieldSpaceUsage;
use crate::Result;
use std::collections::HashMap; use std::collections::HashMap;
/// Provides access to all of the FastFieldReader. /// Provides access to all of the FastFieldReader.
@@ -53,7 +54,7 @@ impl FastFieldReaders {
pub(crate) fn load_all( pub(crate) fn load_all(
schema: &Schema, schema: &Schema,
fast_fields_composite: &CompositeFile, fast_fields_composite: &CompositeFile,
) -> crate::Result<FastFieldReaders> { ) -> Result<FastFieldReaders> {
let mut fast_field_readers = FastFieldReaders { let mut fast_field_readers = FastFieldReaders {
fast_field_i64: Default::default(), fast_field_i64: Default::default(),
fast_field_u64: Default::default(), fast_field_u64: Default::default(),

View File

@@ -10,7 +10,7 @@ use crate::core::SegmentMeta;
use crate::core::SegmentReader; use crate::core::SegmentReader;
use crate::directory::TerminatingWrite; use crate::directory::TerminatingWrite;
use crate::directory::{DirectoryLock, GarbageCollectionResult}; use crate::directory::{DirectoryLock, GarbageCollectionResult};
use crate::docset::{DocSet, TERMINATED}; use crate::docset::DocSet;
use crate::error::TantivyError; use crate::error::TantivyError;
use crate::fastfield::write_delete_bitset; use crate::fastfield::write_delete_bitset;
use crate::indexer::delete_queue::{DeleteCursor, DeleteQueue}; use crate::indexer::delete_queue::{DeleteCursor, DeleteQueue};
@@ -112,15 +112,15 @@ fn compute_deleted_bitset(
if let Some(mut docset) = if let Some(mut docset) =
inverted_index.read_postings(&delete_op.term, IndexRecordOption::Basic) inverted_index.read_postings(&delete_op.term, IndexRecordOption::Basic)
{ {
let mut deleted_doc = docset.doc(); while docset.advance() {
while deleted_doc != TERMINATED { let deleted_doc = docset.doc();
if deleted_doc < limit_doc { if deleted_doc < limit_doc {
delete_bitset.insert(deleted_doc); delete_bitset.insert(deleted_doc);
might_have_changed = true; might_have_changed = true;
} }
deleted_doc = docset.advance();
} }
} }
delete_cursor.advance(); delete_cursor.advance();
} }
Ok(might_have_changed) Ok(might_have_changed)
@@ -155,8 +155,6 @@ pub(crate) fn advance_deletes(
None => BitSet::with_max_value(max_doc), None => BitSet::with_max_value(max_doc),
}; };
let num_deleted_docs_before = segment.meta().num_deleted_docs();
compute_deleted_bitset( compute_deleted_bitset(
&mut delete_bitset, &mut delete_bitset,
&segment_reader, &segment_reader,
@@ -166,8 +164,6 @@ pub(crate) fn advance_deletes(
)?; )?;
// TODO optimize // TODO optimize
// It should be possible to do something smarter by manipulation bitsets directly
// to compute this union.
if let Some(seg_delete_bitset) = segment_reader.delete_bitset() { if let Some(seg_delete_bitset) = segment_reader.delete_bitset() {
for doc in 0u32..max_doc { for doc in 0u32..max_doc {
if seg_delete_bitset.is_deleted(doc) { if seg_delete_bitset.is_deleted(doc) {
@@ -176,9 +172,8 @@ pub(crate) fn advance_deletes(
} }
} }
let num_deleted_docs: u32 = delete_bitset.len() as u32; let num_deleted_docs = delete_bitset.len();
if num_deleted_docs > num_deleted_docs_before { if num_deleted_docs > 0 {
// There are new deletes. We need to write a new delete file.
segment = segment.with_delete_meta(num_deleted_docs as u32, target_opstamp); segment = segment.with_delete_meta(num_deleted_docs as u32, target_opstamp);
let mut delete_file = segment.open_write(SegmentComponent::DELETE)?; let mut delete_file = segment.open_write(SegmentComponent::DELETE)?;
write_delete_bitset(&delete_bitset, max_doc, &mut delete_file)?; write_delete_bitset(&delete_bitset, max_doc, &mut delete_file)?;
@@ -346,7 +341,7 @@ impl IndexWriter {
fn drop_sender(&mut self) { fn drop_sender(&mut self) {
let (sender, _receiver) = channel::bounded(1); let (sender, _receiver) = channel::bounded(1);
self.operation_sender = sender; mem::replace(&mut self.operation_sender, sender);
} }
/// If there are some merging threads, blocks until they all finish their work and /// If there are some merging threads, blocks until they all finish their work and
@@ -808,46 +803,6 @@ mod tests {
assert_eq!(batch_opstamp1, 2u64); assert_eq!(batch_opstamp1, 2u64);
} }
#[test]
fn test_no_need_to_rewrite_delete_file_if_no_new_deletes() {
let mut schema_builder = schema::Schema::builder();
let text_field = schema_builder.add_text_field("text", schema::TEXT);
let index = Index::create_in_ram(schema_builder.build());
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
index_writer.add_document(doc!(text_field => "hello1"));
index_writer.add_document(doc!(text_field => "hello2"));
assert!(index_writer.commit().is_ok());
let reader = index.reader().unwrap();
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
assert_eq!(searcher.segment_reader(0u32).num_deleted_docs(), 0);
index_writer.delete_term(Term::from_field_text(text_field, "hello1"));
assert!(index_writer.commit().is_ok());
assert!(reader.reload().is_ok());
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
assert_eq!(searcher.segment_reader(0u32).num_deleted_docs(), 1);
let previous_delete_opstamp = index.load_metas().unwrap().segments[0].delete_opstamp();
// All docs containing hello1 have been already removed.
// We should not update the delete meta.
index_writer.delete_term(Term::from_field_text(text_field, "hello1"));
assert!(index_writer.commit().is_ok());
assert!(reader.reload().is_ok());
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
assert_eq!(searcher.segment_reader(0u32).num_deleted_docs(), 1);
let after_delete_opstamp = index.load_metas().unwrap().segments[0].delete_opstamp();
assert_eq!(after_delete_opstamp, previous_delete_opstamp);
}
#[test] #[test]
fn test_ordered_batched_operations() { fn test_ordered_batched_operations() {
// * one delete for `doc!(field=>"a")` // * one delete for `doc!(field=>"a")`
@@ -942,7 +897,7 @@ mod tests {
let index_writer = index.writer(3_000_000).unwrap(); let index_writer = index.writer(3_000_000).unwrap();
assert_eq!( assert_eq!(
format!("{:?}", index_writer.get_merge_policy()), format!("{:?}", index_writer.get_merge_policy()),
"LogMergePolicy { min_merge_size: 8, max_merge_size: 10000000, min_layer_size: 10000, \ "LogMergePolicy { min_merge_size: 8, min_layer_size: 10000, \
level_log_size: 0.75 }" level_log_size: 0.75 }"
); );
let merge_policy = Box::new(NoMergePolicy::default()); let merge_policy = Box::new(NoMergePolicy::default());

View File

@@ -6,14 +6,12 @@ use std::f64;
const DEFAULT_LEVEL_LOG_SIZE: f64 = 0.75; const DEFAULT_LEVEL_LOG_SIZE: f64 = 0.75;
const DEFAULT_MIN_LAYER_SIZE: u32 = 10_000; const DEFAULT_MIN_LAYER_SIZE: u32 = 10_000;
const DEFAULT_MIN_MERGE_SIZE: usize = 8; const DEFAULT_MIN_MERGE_SIZE: usize = 8;
const DEFAULT_MAX_MERGE_SIZE: usize = 10_000_000;
/// `LogMergePolicy` tries tries to merge segments that have a similar number of /// `LogMergePolicy` tries tries to merge segments that have a similar number of
/// documents. /// documents.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct LogMergePolicy { pub struct LogMergePolicy {
min_merge_size: usize, min_merge_size: usize,
max_merge_size: usize,
min_layer_size: u32, min_layer_size: u32,
level_log_size: f64, level_log_size: f64,
} }
@@ -28,12 +26,6 @@ impl LogMergePolicy {
self.min_merge_size = min_merge_size; self.min_merge_size = min_merge_size;
} }
/// Set the maximum number docs in a segment for it to be considered for
/// merging.
pub fn set_max_merge_size(&mut self, max_merge_size: usize) {
self.max_merge_size = max_merge_size;
}
/// Set the minimum segment size under which all segment belong /// Set the minimum segment size under which all segment belong
/// to the same level. /// to the same level.
pub fn set_min_layer_size(&mut self, min_layer_size: u32) { pub fn set_min_layer_size(&mut self, min_layer_size: u32) {
@@ -54,44 +46,39 @@ impl LogMergePolicy {
impl MergePolicy for LogMergePolicy { impl MergePolicy for LogMergePolicy {
fn compute_merge_candidates(&self, segments: &[SegmentMeta]) -> Vec<MergeCandidate> { fn compute_merge_candidates(&self, segments: &[SegmentMeta]) -> Vec<MergeCandidate> {
if segments.is_empty() {
return Vec::new();
}
let mut size_sorted_tuples = segments let mut size_sorted_tuples = segments
.iter() .iter()
.map(SegmentMeta::num_docs) .map(SegmentMeta::num_docs)
.filter(|s| s <= &(self.max_merge_size as u32))
.enumerate() .enumerate()
.collect::<Vec<(usize, u32)>>(); .collect::<Vec<(usize, u32)>>();
size_sorted_tuples.sort_by(|x, y| y.1.cmp(&(x.1))); size_sorted_tuples.sort_by(|x, y| y.1.cmp(&(x.1)));
if size_sorted_tuples.len() <= 1 {
return Vec::new();
}
let size_sorted_log_tuples: Vec<_> = size_sorted_tuples let size_sorted_log_tuples: Vec<_> = size_sorted_tuples
.into_iter() .into_iter()
.map(|(ind, num_docs)| (ind, f64::from(self.clip_min_size(num_docs)).log2())) .map(|(ind, num_docs)| (ind, f64::from(self.clip_min_size(num_docs)).log2()))
.collect(); .collect();
if let Some(&(first_ind, first_score)) = size_sorted_log_tuples.first() { let (first_ind, first_score) = size_sorted_log_tuples[0];
let mut current_max_log_size = first_score; let mut current_max_log_size = first_score;
let mut levels = vec![vec![first_ind]]; let mut levels = vec![vec![first_ind]];
for &(ind, score) in (&size_sorted_log_tuples).iter().skip(1) { for &(ind, score) in (&size_sorted_log_tuples).iter().skip(1) {
if score < (current_max_log_size - self.level_log_size) { if score < (current_max_log_size - self.level_log_size) {
current_max_log_size = score; current_max_log_size = score;
levels.push(Vec::new()); levels.push(Vec::new());
}
levels.last_mut().unwrap().push(ind);
} }
levels levels.last_mut().unwrap().push(ind);
.iter()
.filter(|level| level.len() >= self.min_merge_size)
.map(|ind_vec| {
MergeCandidate(ind_vec.iter().map(|&ind| segments[ind].id()).collect())
})
.collect()
} else {
return vec![];
} }
levels
.iter()
.filter(|level| level.len() >= self.min_merge_size)
.map(|ind_vec| MergeCandidate(ind_vec.iter().map(|&ind| segments[ind].id()).collect()))
.collect()
} }
} }
@@ -99,7 +86,6 @@ impl Default for LogMergePolicy {
fn default() -> LogMergePolicy { fn default() -> LogMergePolicy {
LogMergePolicy { LogMergePolicy {
min_merge_size: DEFAULT_MIN_MERGE_SIZE, min_merge_size: DEFAULT_MIN_MERGE_SIZE,
max_merge_size: DEFAULT_MAX_MERGE_SIZE,
min_layer_size: DEFAULT_MIN_LAYER_SIZE, min_layer_size: DEFAULT_MIN_LAYER_SIZE,
level_log_size: DEFAULT_LEVEL_LOG_SIZE, level_log_size: DEFAULT_LEVEL_LOG_SIZE,
} }
@@ -118,7 +104,6 @@ mod tests {
fn test_merge_policy() -> LogMergePolicy { fn test_merge_policy() -> LogMergePolicy {
let mut log_merge_policy = LogMergePolicy::default(); let mut log_merge_policy = LogMergePolicy::default();
log_merge_policy.set_min_merge_size(3); log_merge_policy.set_min_merge_size(3);
log_merge_policy.set_max_merge_size(100_000);
log_merge_policy.set_min_layer_size(2); log_merge_policy.set_min_layer_size(2);
log_merge_policy log_merge_policy
} }
@@ -156,11 +141,11 @@ mod tests {
create_random_segment_meta(10), create_random_segment_meta(10),
create_random_segment_meta(10), create_random_segment_meta(10),
create_random_segment_meta(10), create_random_segment_meta(10),
create_random_segment_meta(1_000), create_random_segment_meta(1000),
create_random_segment_meta(1_000), create_random_segment_meta(1000),
create_random_segment_meta(1_000), create_random_segment_meta(1000),
create_random_segment_meta(10_000), create_random_segment_meta(10000),
create_random_segment_meta(10_000), create_random_segment_meta(10000),
create_random_segment_meta(10), create_random_segment_meta(10),
create_random_segment_meta(10), create_random_segment_meta(10),
create_random_segment_meta(10), create_random_segment_meta(10),
@@ -183,7 +168,6 @@ mod tests {
let result_list = test_merge_policy().compute_merge_candidates(&test_input); let result_list = test_merge_policy().compute_merge_candidates(&test_input);
assert_eq!(result_list.len(), 2); assert_eq!(result_list.len(), 2);
} }
#[test] #[test]
fn test_log_merge_policy_small_segments() { fn test_log_merge_policy_small_segments() {
// segments under min_layer_size are merged together // segments under min_layer_size are merged together
@@ -198,30 +182,4 @@ mod tests {
let result_list = test_merge_policy().compute_merge_candidates(&test_input); let result_list = test_merge_policy().compute_merge_candidates(&test_input);
assert_eq!(result_list.len(), 1); assert_eq!(result_list.len(), 1);
} }
#[test]
fn test_log_merge_policy_all_segments_too_large_to_merge() {
let eight_large_segments: Vec<SegmentMeta> =
std::iter::repeat_with(|| create_random_segment_meta(100_001))
.take(8)
.collect();
assert!(test_merge_policy()
.compute_merge_candidates(&eight_large_segments)
.is_empty());
}
#[test]
fn test_large_merge_segments() {
let test_input = vec![
create_random_segment_meta(1_000_000),
create_random_segment_meta(100_001),
create_random_segment_meta(100_000),
create_random_segment_meta(100_000),
create_random_segment_meta(100_000),
];
let result_list = test_merge_policy().compute_merge_candidates(&test_input);
// Do not include large segments
assert_eq!(result_list.len(), 1);
assert_eq!(result_list[0].0.len(), 3)
}
} }

View File

@@ -2,7 +2,7 @@ use crate::common::MAX_DOC_LIMIT;
use crate::core::Segment; use crate::core::Segment;
use crate::core::SegmentReader; use crate::core::SegmentReader;
use crate::core::SerializableSegment; use crate::core::SerializableSegment;
use crate::docset::{DocSet, TERMINATED}; use crate::docset::DocSet;
use crate::fastfield::BytesFastFieldReader; use crate::fastfield::BytesFastFieldReader;
use crate::fastfield::DeleteBitSet; use crate::fastfield::DeleteBitSet;
use crate::fastfield::FastFieldReader; use crate::fastfield::FastFieldReader;
@@ -21,6 +21,9 @@ use crate::store::StoreWriter;
use crate::termdict::TermMerger; use crate::termdict::TermMerger;
use crate::termdict::TermOrdinal; use crate::termdict::TermOrdinal;
use crate::DocId; use crate::DocId;
use crate::Result;
use crate::TantivyError;
use itertools::Itertools;
use std::cmp; use std::cmp;
use std::collections::HashMap; use std::collections::HashMap;
@@ -69,11 +72,11 @@ fn compute_min_max_val(
Some(delete_bitset) => { Some(delete_bitset) => {
// some deleted documents, // some deleted documents,
// we need to recompute the max / min // we need to recompute the max / min
crate::common::minmax( (0..max_doc)
(0..max_doc) .filter(|doc_id| delete_bitset.is_alive(*doc_id))
.filter(|doc_id| delete_bitset.is_alive(*doc_id)) .map(|doc_id| u64_reader.get(doc_id))
.map(|doc_id| u64_reader.get(doc_id)), .minmax()
) .into_option()
} }
None => { None => {
// no deleted documents, // no deleted documents,
@@ -140,7 +143,7 @@ impl DeltaComputer {
} }
impl IndexMerger { impl IndexMerger {
pub fn open(schema: Schema, segments: &[Segment]) -> crate::Result<IndexMerger> { pub fn open(schema: Schema, segments: &[Segment]) -> Result<IndexMerger> {
let mut readers = vec![]; let mut readers = vec![];
let mut max_doc: u32 = 0u32; let mut max_doc: u32 = 0u32;
for segment in segments { for segment in segments {
@@ -156,7 +159,7 @@ impl IndexMerger {
which exceeds the limit {}.", which exceeds the limit {}.",
max_doc, MAX_DOC_LIMIT max_doc, MAX_DOC_LIMIT
); );
return Err(crate::TantivyError::InvalidArgument(err_msg)); return Err(TantivyError::InvalidArgument(err_msg));
} }
Ok(IndexMerger { Ok(IndexMerger {
schema, schema,
@@ -165,10 +168,7 @@ impl IndexMerger {
}) })
} }
fn write_fieldnorms( fn write_fieldnorms(&self, fieldnorms_serializer: &mut FieldNormsSerializer) -> Result<()> {
&self,
fieldnorms_serializer: &mut FieldNormsSerializer,
) -> crate::Result<()> {
let fields = FieldNormsWriter::fields_with_fieldnorm(&self.schema); let fields = FieldNormsWriter::fields_with_fieldnorm(&self.schema);
let mut fieldnorms_data = Vec::with_capacity(self.max_doc as usize); let mut fieldnorms_data = Vec::with_capacity(self.max_doc as usize);
for field in fields { for field in fields {
@@ -189,7 +189,7 @@ impl IndexMerger {
&self, &self,
fast_field_serializer: &mut FastFieldSerializer, fast_field_serializer: &mut FastFieldSerializer,
mut term_ord_mappings: HashMap<Field, TermOrdinalMapping>, mut term_ord_mappings: HashMap<Field, TermOrdinalMapping>,
) -> crate::Result<()> { ) -> Result<()> {
for (field, field_entry) in self.schema.fields() { for (field, field_entry) in self.schema.fields() {
let field_type = field_entry.field_type(); let field_type = field_entry.field_type();
match *field_type { match *field_type {
@@ -234,7 +234,7 @@ impl IndexMerger {
&self, &self,
field: Field, field: Field,
fast_field_serializer: &mut FastFieldSerializer, fast_field_serializer: &mut FastFieldSerializer,
) -> crate::Result<()> { ) -> Result<()> {
let mut u64_readers = vec![]; let mut u64_readers = vec![];
let mut min_value = u64::max_value(); let mut min_value = u64::max_value();
let mut max_value = u64::min_value(); let mut max_value = u64::min_value();
@@ -284,7 +284,7 @@ impl IndexMerger {
&self, &self,
field: Field, field: Field,
fast_field_serializer: &mut FastFieldSerializer, fast_field_serializer: &mut FastFieldSerializer,
) -> crate::Result<()> { ) -> Result<()> {
let mut total_num_vals = 0u64; let mut total_num_vals = 0u64;
let mut u64s_readers: Vec<MultiValueIntFastFieldReader<u64>> = Vec::new(); let mut u64s_readers: Vec<MultiValueIntFastFieldReader<u64>> = Vec::new();
@@ -331,7 +331,7 @@ impl IndexMerger {
field: Field, field: Field,
term_ordinal_mappings: &TermOrdinalMapping, term_ordinal_mappings: &TermOrdinalMapping,
fast_field_serializer: &mut FastFieldSerializer, fast_field_serializer: &mut FastFieldSerializer,
) -> crate::Result<()> { ) -> Result<()> {
// Multifastfield consists in 2 fastfields. // Multifastfield consists in 2 fastfields.
// The first serves as an index into the second one and is stricly increasing. // The first serves as an index into the second one and is stricly increasing.
// The second contains the actual values. // The second contains the actual values.
@@ -371,7 +371,7 @@ impl IndexMerger {
&self, &self,
field: Field, field: Field,
fast_field_serializer: &mut FastFieldSerializer, fast_field_serializer: &mut FastFieldSerializer,
) -> crate::Result<()> { ) -> Result<()> {
// Multifastfield consists in 2 fastfields. // Multifastfield consists in 2 fastfields.
// The first serves as an index into the second one and is stricly increasing. // The first serves as an index into the second one and is stricly increasing.
// The second contains the actual values. // The second contains the actual values.
@@ -436,7 +436,7 @@ impl IndexMerger {
&self, &self,
field: Field, field: Field,
fast_field_serializer: &mut FastFieldSerializer, fast_field_serializer: &mut FastFieldSerializer,
) -> crate::Result<()> { ) -> Result<()> {
let mut total_num_vals = 0u64; let mut total_num_vals = 0u64;
let mut bytes_readers: Vec<BytesFastFieldReader> = Vec::new(); let mut bytes_readers: Vec<BytesFastFieldReader> = Vec::new();
@@ -492,7 +492,7 @@ impl IndexMerger {
indexed_field: Field, indexed_field: Field,
field_type: &FieldType, field_type: &FieldType,
serializer: &mut InvertedIndexSerializer, serializer: &mut InvertedIndexSerializer,
) -> crate::Result<Option<TermOrdinalMapping>> { ) -> Result<Option<TermOrdinalMapping>> {
let mut positions_buffer: Vec<u32> = Vec::with_capacity(1_000); let mut positions_buffer: Vec<u32> = Vec::with_capacity(1_000);
let mut delta_computer = DeltaComputer::new(); let mut delta_computer = DeltaComputer::new();
let field_readers = self let field_readers = self
@@ -574,12 +574,10 @@ impl IndexMerger {
let inverted_index = segment_reader.inverted_index(indexed_field); let inverted_index = segment_reader.inverted_index(indexed_field);
let mut segment_postings = inverted_index let mut segment_postings = inverted_index
.read_postings_from_terminfo(term_info, segment_postings_option); .read_postings_from_terminfo(term_info, segment_postings_option);
let mut doc = segment_postings.doc(); while segment_postings.advance() {
while doc != TERMINATED { if !segment_reader.is_deleted(segment_postings.doc()) {
if !segment_reader.is_deleted(doc) {
return Some((segment_ord, segment_postings)); return Some((segment_ord, segment_postings));
} }
doc = segment_postings.advance();
} }
None None
}) })
@@ -589,45 +587,57 @@ impl IndexMerger {
// of all of the segments containing the given term. // of all of the segments containing the given term.
// //
// These segments are non-empty and advance has already been called. // These segments are non-empty and advance has already been called.
if segment_postings.is_empty() { if !segment_postings.is_empty() {
continue; // If not, the `term` will be entirely removed.
}
// If not, the `term` will be entirely removed.
// We know that there is at least one document containing // We know that there is at least one document containing
// the term, so we add it. // the term, so we add it.
let to_term_ord = field_serializer.new_term(term_bytes)?; let to_term_ord = field_serializer.new_term(term_bytes)?;
if let Some(ref mut term_ord_mapping) = term_ord_mapping_opt { if let Some(ref mut term_ord_mapping) = term_ord_mapping_opt {
for (segment_ord, from_term_ord) in merged_terms.matching_segments() { for (segment_ord, from_term_ord) in merged_terms.matching_segments() {
term_ord_mapping.register_from_to(segment_ord, from_term_ord, to_term_ord); term_ord_mapping.register_from_to(segment_ord, from_term_ord, to_term_ord);
}
}
// We can now serialize this postings, by pushing each document to the
// postings serializer.
for (segment_ord, mut segment_postings) in segment_postings {
let old_to_new_doc_id = &merged_doc_id_map[segment_ord];
let mut doc = segment_postings.doc();
while doc != TERMINATED {
// deleted doc are skipped as they do not have a `remapped_doc_id`.
if let Some(remapped_doc_id) = old_to_new_doc_id[doc as usize] {
// we make sure to only write the term iff
// there is at least one document.
let term_freq = segment_postings.term_freq();
segment_postings.positions(&mut positions_buffer);
let delta_positions = delta_computer.compute_delta(&positions_buffer);
field_serializer.write_doc(remapped_doc_id, term_freq, delta_positions)?;
} }
doc = segment_postings.advance();
} }
}
// closing the term. // We can now serialize this postings, by pushing each document to the
field_serializer.close_term()?; // postings serializer.
for (segment_ord, mut segment_postings) in segment_postings {
let old_to_new_doc_id = &merged_doc_id_map[segment_ord];
loop {
let doc = segment_postings.doc();
// `.advance()` has been called once before the loop.
//
// It was required to make sure we only consider segments
// that effectively contain at least one non-deleted document
// and remove terms that do not have documents associated.
//
// For this reason, we cannot use a `while segment_postings.advance()` loop.
// deleted doc are skipped as they do not have a `remapped_doc_id`.
if let Some(remapped_doc_id) = old_to_new_doc_id[doc as usize] {
// we make sure to only write the term iff
// there is at least one document.
let term_freq = segment_postings.term_freq();
segment_postings.positions(&mut positions_buffer);
let delta_positions = delta_computer.compute_delta(&positions_buffer);
field_serializer.write_doc(
remapped_doc_id,
term_freq,
delta_positions,
)?;
}
if !segment_postings.advance() {
break;
}
}
}
// closing the term.
field_serializer.close_term()?;
}
} }
field_serializer.close()?; field_serializer.close()?;
Ok(term_ord_mapping_opt) Ok(term_ord_mapping_opt)
@@ -636,7 +646,7 @@ impl IndexMerger {
fn write_postings( fn write_postings(
&self, &self,
serializer: &mut InvertedIndexSerializer, serializer: &mut InvertedIndexSerializer,
) -> crate::Result<HashMap<Field, TermOrdinalMapping>> { ) -> Result<HashMap<Field, TermOrdinalMapping>> {
let mut term_ordinal_mappings = HashMap::new(); let mut term_ordinal_mappings = HashMap::new();
for (field, field_entry) in self.schema.fields() { for (field, field_entry) in self.schema.fields() {
if field_entry.is_indexed() { if field_entry.is_indexed() {
@@ -650,7 +660,7 @@ impl IndexMerger {
Ok(term_ordinal_mappings) Ok(term_ordinal_mappings)
} }
fn write_storable_fields(&self, store_writer: &mut StoreWriter) -> crate::Result<()> { fn write_storable_fields(&self, store_writer: &mut StoreWriter) -> Result<()> {
for reader in &self.readers { for reader in &self.readers {
let store_reader = reader.get_store_reader(); let store_reader = reader.get_store_reader();
if reader.num_deleted_docs() > 0 { if reader.num_deleted_docs() > 0 {
@@ -667,7 +677,7 @@ impl IndexMerger {
} }
impl SerializableSegment for IndexMerger { impl SerializableSegment for IndexMerger {
fn write(&self, mut serializer: SegmentSerializer) -> crate::Result<u32> { fn write(&self, mut serializer: SegmentSerializer) -> Result<u32> {
let term_ord_mappings = self.write_postings(serializer.get_postings_serializer())?; let term_ord_mappings = self.write_postings(serializer.get_postings_serializer())?;
self.write_fieldnorms(serializer.get_fieldnorms_serializer())?; self.write_fieldnorms(serializer.get_fieldnorms_serializer())?;
self.write_fast_fields(serializer.get_fast_field_serializer(), term_ord_mappings)?; self.write_fast_fields(serializer.get_fast_field_serializer(), term_ord_mappings)?;

View File

@@ -19,8 +19,6 @@ pub struct AddOperation {
/// UserOperation is an enum type that encapsulates other operation types. /// UserOperation is an enum type that encapsulates other operation types.
#[derive(Eq, PartialEq, Debug)] #[derive(Eq, PartialEq, Debug)]
pub enum UserOperation { pub enum UserOperation {
/// Add operation
Add(Document), Add(Document),
/// Delete operation
Delete(Term), Delete(Term),
} }

View File

@@ -1,5 +1,6 @@
use super::IndexWriter; use super::IndexWriter;
use crate::Opstamp; use crate::Opstamp;
use crate::Result;
use futures::executor::block_on; use futures::executor::block_on;
/// A prepared commit /// A prepared commit
@@ -26,11 +27,11 @@ impl<'a> PreparedCommit<'a> {
self.payload = Some(payload.to_string()) self.payload = Some(payload.to_string())
} }
pub fn abort(self) -> crate::Result<Opstamp> { pub fn abort(self) -> Result<Opstamp> {
self.index_writer.rollback() self.index_writer.rollback()
} }
pub fn commit(self) -> crate::Result<Opstamp> { pub fn commit(self) -> Result<Opstamp> {
info!("committing {}", self.opstamp); info!("committing {}", self.opstamp);
let _ = block_on( let _ = block_on(
self.index_writer self.index_writer

View File

@@ -4,6 +4,7 @@ use crate::core::SegmentMeta;
use crate::error::TantivyError; use crate::error::TantivyError;
use crate::indexer::delete_queue::DeleteCursor; use crate::indexer::delete_queue::DeleteCursor;
use crate::indexer::SegmentEntry; use crate::indexer::SegmentEntry;
use crate::Result as TantivyResult;
use std::collections::hash_set::HashSet; use std::collections::hash_set::HashSet;
use std::fmt::{self, Debug, Formatter}; use std::fmt::{self, Debug, Formatter};
use std::sync::RwLock; use std::sync::RwLock;
@@ -48,7 +49,7 @@ pub struct SegmentManager {
} }
impl Debug for SegmentManager { impl Debug for SegmentManager {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
let lock = self.read(); let lock = self.read();
write!( write!(
f, f,
@@ -144,7 +145,7 @@ impl SegmentManager {
/// Returns an error if some segments are missing, or if /// Returns an error if some segments are missing, or if
/// the `segment_ids` are not either all committed or all /// the `segment_ids` are not either all committed or all
/// uncommitted. /// uncommitted.
pub fn start_merge(&self, segment_ids: &[SegmentId]) -> crate::Result<Vec<SegmentEntry>> { pub fn start_merge(&self, segment_ids: &[SegmentId]) -> TantivyResult<Vec<SegmentEntry>> {
let registers_lock = self.read(); let registers_lock = self.read();
let mut segment_entries = vec![]; let mut segment_entries = vec![];
if registers_lock.uncommitted.contains_all(segment_ids) { if registers_lock.uncommitted.contains_all(segment_ids) {
@@ -187,7 +188,7 @@ impl SegmentManager {
.segments_status(before_merge_segment_ids) .segments_status(before_merge_segment_ids)
.ok_or_else(|| { .ok_or_else(|| {
warn!("couldn't find segment in SegmentManager"); warn!("couldn't find segment in SegmentManager");
crate::TantivyError::InvalidArgument( crate::Error::InvalidArgument(
"The segments that were merged could not be found in the SegmentManager. \ "The segments that were merged could not be found in the SegmentManager. \
This is not necessarily a bug, and can happen after a rollback for instance." This is not necessarily a bug, and can happen after a rollback for instance."
.to_string(), .to_string(),

View File

@@ -1,8 +1,13 @@
use crate::Directory;
use crate::core::Segment; use crate::core::Segment;
use crate::core::SegmentComponent; use crate::core::SegmentComponent;
use crate::directory::error::OpenWriteError;
use crate::directory::{DirectoryClone, RAMDirectory, TerminatingWrite, WritePtr};
use crate::fastfield::FastFieldSerializer; use crate::fastfield::FastFieldSerializer;
use crate::fieldnorm::FieldNormsSerializer; use crate::fieldnorm::FieldNormsSerializer;
use crate::postings::InvertedIndexSerializer; use crate::postings::InvertedIndexSerializer;
use crate::schema::Schema;
use crate::store::StoreWriter; use crate::store::StoreWriter;
/// Segment serializer is in charge of laying out on disk /// Segment serializer is in charge of laying out on disk
@@ -12,25 +17,50 @@ pub struct SegmentSerializer {
fast_field_serializer: FastFieldSerializer, fast_field_serializer: FastFieldSerializer,
fieldnorms_serializer: FieldNormsSerializer, fieldnorms_serializer: FieldNormsSerializer,
postings_serializer: InvertedIndexSerializer, postings_serializer: InvertedIndexSerializer,
bundle_writer: Option<(RAMDirectory, WritePtr)>,
}
pub(crate) struct SegmentSerializerWriters {
postings_wrt: WritePtr,
positions_skip_wrt: WritePtr,
positions_wrt: WritePtr,
terms_wrt: WritePtr,
fast_field_wrt: WritePtr,
fieldnorms_wrt: WritePtr,
store_wrt: WritePtr,
}
impl SegmentSerializerWriters {
pub(crate) fn for_segment(segment: &mut Segment) -> Result<Self, OpenWriteError> {
Ok(SegmentSerializerWriters {
postings_wrt: segment.open_write(SegmentComponent::POSTINGS)?,
positions_skip_wrt: segment.open_write(SegmentComponent::POSITIONS)?,
positions_wrt: segment.open_write(SegmentComponent::POSITIONSSKIP)?,
terms_wrt: segment.open_write(SegmentComponent::TERMS)?,
fast_field_wrt: segment.open_write(SegmentComponent::FASTFIELDS)?,
fieldnorms_wrt: segment.open_write(SegmentComponent::FIELDNORMS)?,
store_wrt: segment.open_write(SegmentComponent::STORE)?,
})
}
} }
impl SegmentSerializer { impl SegmentSerializer {
/// Creates a new `SegmentSerializer`. pub(crate) fn new(schema: Schema, writers: SegmentSerializerWriters) -> crate::Result<Self> {
pub fn for_segment(segment: &mut Segment) -> crate::Result<SegmentSerializer> { let fast_field_serializer = FastFieldSerializer::from_write(writers.fast_field_wrt)?;
let store_write = segment.open_write(SegmentComponent::STORE)?; let fieldnorms_serializer = FieldNormsSerializer::from_write(writers.fieldnorms_wrt)?;
let postings_serializer = InvertedIndexSerializer::open(
let fast_field_write = segment.open_write(SegmentComponent::FASTFIELDS)?; schema,
let fast_field_serializer = FastFieldSerializer::from_write(fast_field_write)?; writers.terms_wrt,
writers.postings_wrt,
let fieldnorms_write = segment.open_write(SegmentComponent::FIELDNORMS)?; writers.positions_wrt,
let fieldnorms_serializer = FieldNormsSerializer::from_write(fieldnorms_write)?; writers.positions_skip_wrt,
);
let postings_serializer = InvertedIndexSerializer::open(segment)?;
Ok(SegmentSerializer { Ok(SegmentSerializer {
store_writer: StoreWriter::new(store_write), store_writer: StoreWriter::new(writers.store_wrt),
fast_field_serializer, fast_field_serializer,
fieldnorms_serializer, fieldnorms_serializer,
postings_serializer, postings_serializer,
bundle_writer: None,
}) })
} }
@@ -55,11 +85,15 @@ impl SegmentSerializer {
} }
/// Finalize the segment serialization. /// Finalize the segment serialization.
pub fn close(self) -> crate::Result<()> { pub fn close(mut self) -> crate::Result<()> {
self.fast_field_serializer.close()?; self.fast_field_serializer.close()?;
self.postings_serializer.close()?; self.postings_serializer.close()?;
self.store_writer.close()?; self.store_writer.close()?;
self.fieldnorms_serializer.close()?; self.fieldnorms_serializer.close()?;
if let Some((ram_directory, mut bundle_wrt)) = self.bundle_writer.take() {
ram_directory.serialize_bundle(&mut bundle_wrt)?;
bundle_wrt.terminate()?;
}
Ok(()) Ok(())
} }
} }

View File

@@ -12,6 +12,7 @@ use crate::indexer::index_writer::advance_deletes;
use crate::indexer::merge_operation::MergeOperationInventory; use crate::indexer::merge_operation::MergeOperationInventory;
use crate::indexer::merger::IndexMerger; use crate::indexer::merger::IndexMerger;
use crate::indexer::segment_manager::SegmentsStatus; use crate::indexer::segment_manager::SegmentsStatus;
use crate::indexer::segment_serializer::SegmentSerializerWriters;
use crate::indexer::stamper::Stamper; use crate::indexer::stamper::Stamper;
use crate::indexer::SegmentEntry; use crate::indexer::SegmentEntry;
use crate::indexer::SegmentSerializer; use crate::indexer::SegmentSerializer;
@@ -23,6 +24,7 @@ use futures::channel::oneshot;
use futures::executor::{ThreadPool, ThreadPoolBuilder}; use futures::executor::{ThreadPool, ThreadPoolBuilder};
use futures::future::Future; use futures::future::Future;
use futures::future::TryFutureExt; use futures::future::TryFutureExt;
use serde_json;
use std::borrow::BorrowMut; use std::borrow::BorrowMut;
use std::collections::HashSet; use std::collections::HashSet;
use std::io::Write; use std::io::Write;
@@ -131,7 +133,9 @@ fn merge(
let merger: IndexMerger = IndexMerger::open(index.schema(), &segments[..])?; let merger: IndexMerger = IndexMerger::open(index.schema(), &segments[..])?;
// ... we just serialize this index merger in our new segment to merge the two segments. // ... we just serialize this index merger in our new segment to merge the two segments.
let segment_serializer = SegmentSerializer::for_segment(&mut merged_segment)?; let segment_serializer_wrts = SegmentSerializerWriters::for_segment(&mut merged_segment)?;
let segment_serializer =
SegmentSerializer::new(merged_segment.schema(), segment_serializer_wrts)?;
let num_docs = merger.write(segment_serializer)?; let num_docs = merger.write(segment_serializer)?;
@@ -172,18 +176,14 @@ impl SegmentUpdater {
.pool_size(1) .pool_size(1)
.create() .create()
.map_err(|_| { .map_err(|_| {
crate::TantivyError::SystemError( crate::Error::SystemError("Failed to spawn segment updater thread".to_string())
"Failed to spawn segment updater thread".to_string(),
)
})?; })?;
let merge_thread_pool = ThreadPoolBuilder::new() let merge_thread_pool = ThreadPoolBuilder::new()
.name_prefix("merge_thread") .name_prefix("merge_thread")
.pool_size(NUM_MERGE_THREADS) .pool_size(NUM_MERGE_THREADS)
.create() .create()
.map_err(|_| { .map_err(|_| {
crate::TantivyError::SystemError( crate::Error::SystemError("Failed to spawn segment merging thread".to_string())
"Failed to spawn segment merging thread".to_string(),
)
})?; })?;
let index_meta = index.load_metas()?; let index_meta = index.load_metas()?;
Ok(SegmentUpdater(Arc::new(InnerSegmentUpdater { Ok(SegmentUpdater(Arc::new(InnerSegmentUpdater {
@@ -225,7 +225,7 @@ impl SegmentUpdater {
receiver.unwrap_or_else(|_| { receiver.unwrap_or_else(|_| {
let err_msg = let err_msg =
"A segment_updater future did not success. This should never happen.".to_string(); "A segment_updater future did not success. This should never happen.".to_string();
Err(crate::TantivyError::SystemError(err_msg)) Err(crate::Error::SystemError(err_msg))
}) })
} }
@@ -422,7 +422,7 @@ impl SegmentUpdater {
}); });
Ok(merging_future_recv Ok(merging_future_recv
.unwrap_or_else(|_| Err(crate::TantivyError::SystemError("Merge failed".to_string())))) .unwrap_or_else(|_| Err(crate::Error::SystemError("Merge failed".to_string()))))
} }
async fn consider_merge_options(&self) { async fn consider_merge_options(&self) {

View File

@@ -3,7 +3,7 @@ use crate::core::Segment;
use crate::core::SerializableSegment; use crate::core::SerializableSegment;
use crate::fastfield::FastFieldsWriter; use crate::fastfield::FastFieldsWriter;
use crate::fieldnorm::FieldNormsWriter; use crate::fieldnorm::FieldNormsWriter;
use crate::indexer::segment_serializer::SegmentSerializer; use crate::indexer::segment_serializer::{SegmentSerializer, SegmentSerializerWriters};
use crate::postings::compute_table_size; use crate::postings::compute_table_size;
use crate::postings::MultiFieldPostingsWriter; use crate::postings::MultiFieldPostingsWriter;
use crate::schema::FieldType; use crate::schema::FieldType;
@@ -11,18 +11,21 @@ use crate::schema::Schema;
use crate::schema::Term; use crate::schema::Term;
use crate::schema::Value; use crate::schema::Value;
use crate::schema::{Field, FieldEntry}; use crate::schema::{Field, FieldEntry};
use crate::tokenizer::{BoxTokenStream, PreTokenizedStream}; use crate::tokenizer::BoxedTokenizer;
use crate::tokenizer::{FacetTokenizer, TextAnalyzer}; use crate::tokenizer::FacetTokenizer;
use crate::tokenizer::{TokenStreamChain, Tokenizer}; use crate::tokenizer::PreTokenizedStream;
use crate::tokenizer::{TokenStream, TokenStreamChain, Tokenizer};
use crate::DocId; use crate::DocId;
use crate::Opstamp; use crate::Opstamp;
use crate::Result;
use crate::TantivyError;
use std::io; use std::io;
use std::str; use std::str;
/// Computes the initial size of the hash table. /// Computes the initial size of the hash table.
/// ///
/// Returns a number of bit `b`, such that the recommended initial table size is 2^b. /// Returns a number of bit `b`, such that the recommended initial table size is 2^b.
fn initial_table_size(per_thread_memory_budget: usize) -> crate::Result<usize> { fn initial_table_size(per_thread_memory_budget: usize) -> Result<usize> {
let table_memory_upper_bound = per_thread_memory_budget / 3; let table_memory_upper_bound = per_thread_memory_budget / 3;
if let Some(limit) = (10..) if let Some(limit) = (10..)
.take_while(|num_bits: &usize| compute_table_size(*num_bits) < table_memory_upper_bound) .take_while(|num_bits: &usize| compute_table_size(*num_bits) < table_memory_upper_bound)
@@ -30,7 +33,7 @@ fn initial_table_size(per_thread_memory_budget: usize) -> crate::Result<usize> {
{ {
Ok(limit.min(19)) // we cap it at 2^19 = 512K. Ok(limit.min(19)) // we cap it at 2^19 = 512K.
} else { } else {
Err(crate::TantivyError::InvalidArgument( Err(TantivyError::InvalidArgument(
format!("per thread memory budget (={}) is too small. Raise the memory budget or lower the number of threads.", per_thread_memory_budget))) format!("per thread memory budget (={}) is too small. Raise the memory budget or lower the number of threads.", per_thread_memory_budget)))
} }
} }
@@ -47,7 +50,7 @@ pub struct SegmentWriter {
fast_field_writers: FastFieldsWriter, fast_field_writers: FastFieldsWriter,
fieldnorms_writer: FieldNormsWriter, fieldnorms_writer: FieldNormsWriter,
doc_opstamps: Vec<Opstamp>, doc_opstamps: Vec<Opstamp>,
tokenizers: Vec<Option<TextAnalyzer>>, tokenizers: Vec<Option<BoxedTokenizer>>,
} }
impl SegmentWriter { impl SegmentWriter {
@@ -64,9 +67,10 @@ impl SegmentWriter {
memory_budget: usize, memory_budget: usize,
mut segment: Segment, mut segment: Segment,
schema: &Schema, schema: &Schema,
) -> crate::Result<SegmentWriter> { ) -> Result<SegmentWriter> {
let table_num_bits = initial_table_size(memory_budget)?; let table_num_bits = initial_table_size(memory_budget)?;
let segment_serializer = SegmentSerializer::for_segment(&mut segment)?; let segment_serializer_wrts = SegmentSerializerWriters::for_segment(&mut segment)?;
let segment_serializer = SegmentSerializer::new(segment.schema(), segment_serializer_wrts)?;
let multifield_postings = MultiFieldPostingsWriter::new(schema, table_num_bits); let multifield_postings = MultiFieldPostingsWriter::new(schema, table_num_bits);
let tokenizers = schema let tokenizers = schema
.fields() .fields()
@@ -97,7 +101,7 @@ impl SegmentWriter {
/// ///
/// Finalize consumes the `SegmentWriter`, so that it cannot /// Finalize consumes the `SegmentWriter`, so that it cannot
/// be used afterwards. /// be used afterwards.
pub fn finalize(mut self) -> crate::Result<Vec<u64>> { pub fn finalize(mut self) -> Result<Vec<u64>> {
self.fieldnorms_writer.fill_up_to_max_doc(self.max_doc); self.fieldnorms_writer.fill_up_to_max_doc(self.max_doc);
write( write(
&self.multifield_postings, &self.multifield_postings,
@@ -156,7 +160,7 @@ impl SegmentWriter {
} }
} }
FieldType::Str(_) => { FieldType::Str(_) => {
let mut token_streams: Vec<BoxTokenStream> = vec![]; let mut token_streams: Vec<Box<dyn TokenStream>> = vec![];
let mut offsets = vec![]; let mut offsets = vec![];
let mut total_offset = 0; let mut total_offset = 0;
@@ -169,7 +173,7 @@ impl SegmentWriter {
} }
token_streams token_streams
.push(PreTokenizedStream::from(tok_str.clone()).into()); .push(Box::new(PreTokenizedStream::from(tok_str.clone())));
} }
Value::Str(ref text) => { Value::Str(ref text) => {
if let Some(ref mut tokenizer) = if let Some(ref mut tokenizer) =
@@ -188,7 +192,8 @@ impl SegmentWriter {
let num_tokens = if token_streams.is_empty() { let num_tokens = if token_streams.is_empty() {
0 0
} else { } else {
let mut token_stream = TokenStreamChain::new(offsets, token_streams); let mut token_stream: Box<dyn TokenStream> =
Box::new(TokenStreamChain::new(offsets, token_streams));
self.multifield_postings self.multifield_postings
.index_text(doc_id, field, &mut token_stream) .index_text(doc_id, field, &mut token_stream)
}; };
@@ -279,7 +284,7 @@ fn write(
fast_field_writers: &FastFieldsWriter, fast_field_writers: &FastFieldsWriter,
fieldnorms_writer: &FieldNormsWriter, fieldnorms_writer: &FieldNormsWriter,
mut serializer: SegmentSerializer, mut serializer: SegmentSerializer,
) -> crate::Result<()> { ) -> Result<()> {
let term_ord_map = multifield_postings.serialize(serializer.get_postings_serializer())?; let term_ord_map = multifield_postings.serialize(serializer.get_postings_serializer())?;
fast_field_writers.serialize(serializer.get_fast_field_serializer(), &term_ord_map)?; fast_field_writers.serialize(serializer.get_fast_field_serializer(), &term_ord_map)?;
fieldnorms_writer.serialize(serializer.get_fieldnorms_serializer())?; fieldnorms_writer.serialize(serializer.get_fieldnorms_serializer())?;
@@ -288,7 +293,7 @@ fn write(
} }
impl SerializableSegment for SegmentWriter { impl SerializableSegment for SegmentWriter {
fn write(&self, serializer: SegmentSerializer) -> crate::Result<u32> { fn write(&self, serializer: SegmentSerializer) -> Result<u32> {
let max_doc = self.max_doc; let max_doc = self.max_doc;
write( write(
&self.multifield_postings, &self.multifield_postings,

View File

@@ -1,76 +1,18 @@
use crate::Opstamp; use crate::Opstamp;
use std::ops::Range; use std::ops::Range;
use std::sync::atomic::Ordering; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc; use std::sync::Arc;
#[cfg(not(target_arch = "arm"))]
mod atomic_impl {
use crate::Opstamp;
use std::sync::atomic::{AtomicU64, Ordering};
#[derive(Default)]
pub struct AtomicU64Wrapper(AtomicU64);
impl AtomicU64Wrapper {
pub fn new(first_opstamp: Opstamp) -> AtomicU64Wrapper {
AtomicU64Wrapper(AtomicU64::new(first_opstamp as u64))
}
pub fn fetch_add(&self, val: u64, order: Ordering) -> u64 {
self.0.fetch_add(val as u64, order) as u64
}
pub fn revert(&self, val: u64, order: Ordering) -> u64 {
self.0.store(val, order);
val
}
}
}
#[cfg(target_arch = "arm")]
mod atomic_impl {
use crate::Opstamp;
/// Under other architecture, we rely on a mutex.
use std::sync::atomic::Ordering;
use std::sync::RwLock;
#[derive(Default)]
pub struct AtomicU64Wrapper(RwLock<u64>);
impl AtomicU64Wrapper {
pub fn new(first_opstamp: Opstamp) -> AtomicU64Wrapper {
AtomicU64Wrapper(RwLock::new(first_opstamp))
}
pub fn fetch_add(&self, incr: u64, _order: Ordering) -> u64 {
let mut lock = self.0.write().unwrap();
let previous_val = *lock;
*lock = previous_val + incr;
previous_val
}
pub fn revert(&self, val: u64, _order: Ordering) -> u64 {
let mut lock = self.0.write().unwrap();
*lock = val;
val
}
}
}
use self::atomic_impl::AtomicU64Wrapper;
/// Stamper provides Opstamps, which is just an auto-increment id to label /// Stamper provides Opstamps, which is just an auto-increment id to label
/// an operation. /// an operation.
/// ///
/// Cloning does not "fork" the stamp generation. The stamper actually wraps an `Arc`. /// Cloning does not "fork" the stamp generation. The stamper actually wraps an `Arc`.
#[derive(Clone, Default)] #[derive(Clone, Default)]
pub struct Stamper(Arc<AtomicU64Wrapper>); pub struct Stamper(Arc<AtomicU64>);
impl Stamper { impl Stamper {
pub fn new(first_opstamp: Opstamp) -> Stamper { pub fn new(first_opstamp: Opstamp) -> Stamper {
Stamper(Arc::new(AtomicU64Wrapper::new(first_opstamp))) Stamper(Arc::new(AtomicU64::new(first_opstamp)))
} }
pub fn stamp(&self) -> Opstamp { pub fn stamp(&self) -> Opstamp {
@@ -89,7 +31,8 @@ impl Stamper {
/// Reverts the stamper to a given `Opstamp` value and returns it /// Reverts the stamper to a given `Opstamp` value and returns it
pub fn revert(&self, to_opstamp: Opstamp) -> Opstamp { pub fn revert(&self, to_opstamp: Opstamp) -> Opstamp {
self.0.revert(to_opstamp, Ordering::SeqCst) self.0.store(to_opstamp, Ordering::SeqCst);
to_opstamp
} }
} }

View File

@@ -98,6 +98,9 @@
//! [literate programming](https://tantivy-search.github.io/examples/basic_search.html) / //! [literate programming](https://tantivy-search.github.io/examples/basic_search.html) /
//! [source code](https://github.com/tantivy-search/tantivy/blob/master/examples/basic_search.rs)) //! [source code](https://github.com/tantivy-search/tantivy/blob/master/examples/basic_search.rs))
#[macro_use]
extern crate serde_derive;
#[cfg_attr(test, macro_use)] #[cfg_attr(test, macro_use)]
extern crate serde_json; extern crate serde_json;
@@ -118,13 +121,13 @@ mod functional_test;
mod macros; mod macros;
pub use crate::error::TantivyError; pub use crate::error::TantivyError;
#[deprecated(since = "0.7.0", note = "please use `tantivy::TantivyError` instead")]
pub use crate::error::TantivyError as Error;
pub use chrono; pub use chrono;
/// Tantivy result. /// Tantivy result.
/// pub type Result<T> = std::result::Result<T, error::TantivyError>;
/// Within tantivy, please avoid importing `Result` using `use crate::Result`
/// and instead, refer to this as `crate::Result<T>`.
pub type Result<T> = std::result::Result<T, TantivyError>;
/// Tantivy DateTime /// Tantivy DateTime
pub type DateTime = chrono::DateTime<chrono::Utc>; pub type DateTime = chrono::DateTime<chrono::Utc>;
@@ -156,13 +159,12 @@ mod snippet;
pub use self::snippet::{Snippet, SnippetGenerator}; pub use self::snippet::{Snippet, SnippetGenerator};
mod docset; mod docset;
pub use self::docset::{DocSet, TERMINATED}; pub use self::docset::{DocSet, SkipResult};
pub use crate::common::{f64_to_u64, i64_to_u64, u64_to_f64, u64_to_i64}; pub use crate::common::{f64_to_u64, i64_to_u64, u64_to_f64, u64_to_i64};
pub use crate::core::{Executor, SegmentComponent}; pub use crate::core::SegmentComponent;
pub use crate::core::{Index, IndexMeta, Searcher, Segment, SegmentId, SegmentMeta}; pub use crate::core::{Index, IndexMeta, Searcher, Segment, SegmentId, SegmentMeta};
pub use crate::core::{InvertedIndexReader, SegmentReader}; pub use crate::core::{InvertedIndexReader, SegmentReader};
pub use crate::directory::Directory; pub use crate::directory::Directory;
pub use crate::indexer::operation::UserOperation;
pub use crate::indexer::IndexWriter; pub use crate::indexer::IndexWriter;
pub use crate::postings::Postings; pub use crate::postings::Postings;
pub use crate::reader::LeasedItem; pub use crate::reader::LeasedItem;
@@ -170,7 +172,6 @@ pub use crate::schema::{Document, Term};
use std::fmt; use std::fmt;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
/// Index format version. /// Index format version.
const INDEX_FORMAT_VERSION: u32 = 1; const INDEX_FORMAT_VERSION: u32 = 1;
@@ -285,7 +286,7 @@ mod tests {
use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE; use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE;
use crate::core::SegmentReader; use crate::core::SegmentReader;
use crate::docset::{DocSet, TERMINATED}; use crate::docset::DocSet;
use crate::query::BooleanQuery; use crate::query::BooleanQuery;
use crate::schema::*; use crate::schema::*;
use crate::DocAddress; use crate::DocAddress;
@@ -381,12 +382,19 @@ mod tests {
index_writer.commit().unwrap(); index_writer.commit().unwrap();
} }
{ {
index_writer.add_document(doc!(text_field=>"a")); {
index_writer.add_document(doc!(text_field=>"a a")); let doc = doc!(text_field=>"a");
index_writer.add_document(doc);
}
{
let doc = doc!(text_field=>"a a");
index_writer.add_document(doc);
}
index_writer.commit().unwrap(); index_writer.commit().unwrap();
} }
{ {
index_writer.add_document(doc!(text_field=>"c")); let doc = doc!(text_field=>"c");
index_writer.add_document(doc);
index_writer.commit().unwrap(); index_writer.commit().unwrap();
} }
{ {
@@ -465,12 +473,10 @@ mod tests {
} }
fn advance_undeleted(docset: &mut dyn DocSet, reader: &SegmentReader) -> bool { fn advance_undeleted(docset: &mut dyn DocSet, reader: &SegmentReader) -> bool {
let mut doc = docset.advance(); while docset.advance() {
while doc != TERMINATED { if !reader.is_deleted(docset.doc()) {
if !reader.is_deleted(doc) {
return true; return true;
} }
doc = docset.advance();
} }
false false
} }
@@ -636,8 +642,9 @@ mod tests {
.inverted_index(term.field()) .inverted_index(term.field())
.read_postings(&term, IndexRecordOption::Basic) .read_postings(&term, IndexRecordOption::Basic)
.unwrap(); .unwrap();
assert!(postings.advance());
assert_eq!(postings.doc(), 0); assert_eq!(postings.doc(), 0);
assert_eq!(postings.advance(), TERMINATED); assert!(!postings.advance());
} }
#[test] #[test]
@@ -659,8 +666,9 @@ mod tests {
.inverted_index(term.field()) .inverted_index(term.field())
.read_postings(&term, IndexRecordOption::Basic) .read_postings(&term, IndexRecordOption::Basic)
.unwrap(); .unwrap();
assert!(postings.advance());
assert_eq!(postings.doc(), 0); assert_eq!(postings.doc(), 0);
assert_eq!(postings.advance(), TERMINATED); assert!(!postings.advance());
} }
#[test] #[test]
@@ -682,8 +690,9 @@ mod tests {
.inverted_index(term.field()) .inverted_index(term.field())
.read_postings(&term, IndexRecordOption::Basic) .read_postings(&term, IndexRecordOption::Basic)
.unwrap(); .unwrap();
assert!(postings.advance());
assert_eq!(postings.doc(), 0); assert_eq!(postings.doc(), 0);
assert_eq!(postings.advance(), TERMINATED); assert!(!postings.advance());
} }
#[test] #[test]
@@ -752,8 +761,10 @@ mod tests {
{ {
// writing the segment // writing the segment
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap(); let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
let doc = doc!(text_field=>"af af af bc bc"); {
index_writer.add_document(doc); let doc = doc!(text_field=>"af af af bc bc");
index_writer.add_document(doc);
}
index_writer.commit().unwrap(); index_writer.commit().unwrap();
} }
{ {
@@ -769,9 +780,10 @@ mod tests {
let mut postings = inverted_index let mut postings = inverted_index
.read_postings(&term_af, IndexRecordOption::WithFreqsAndPositions) .read_postings(&term_af, IndexRecordOption::WithFreqsAndPositions)
.unwrap(); .unwrap();
assert!(postings.advance());
assert_eq!(postings.doc(), 0); assert_eq!(postings.doc(), 0);
assert_eq!(postings.term_freq(), 3); assert_eq!(postings.term_freq(), 3);
assert_eq!(postings.advance(), TERMINATED); assert!(!postings.advance());
} }
} }

View File

@@ -37,9 +37,9 @@ const LONG_SKIP_INTERVAL: u64 = (LONG_SKIP_IN_BLOCKS * COMPRESSION_BLOCK_SIZE) a
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
use super::PositionSerializer; use super::{PositionReader, PositionSerializer};
use crate::directory::ReadOnlySource; use crate::directory::ReadOnlySource;
use crate::positions::reader::PositionReader; use crate::positions::COMPRESSION_BLOCK_SIZE;
use std::iter; use std::iter;
fn create_stream_buffer(vals: &[u32]) -> (ReadOnlySource, ReadOnlySource) { fn create_stream_buffer(vals: &[u32]) -> (ReadOnlySource, ReadOnlySource) {
@@ -68,7 +68,7 @@ pub mod tests {
let mut position_reader = PositionReader::new(stream, skip, 0u64); let mut position_reader = PositionReader::new(stream, skip, 0u64);
for &n in &[1, 10, 127, 128, 130, 312] { for &n in &[1, 10, 127, 128, 130, 312] {
let mut v = vec![0u32; n]; let mut v = vec![0u32; n];
position_reader.read(0, &mut v[..]); position_reader.read(&mut v[..n]);
for i in 0..n { for i in 0..n {
assert_eq!(v[i], i as u32); assert_eq!(v[i], i as u32);
} }
@@ -76,19 +76,19 @@ pub mod tests {
} }
#[test] #[test]
fn test_position_read_with_offset() { fn test_position_skip() {
let v: Vec<u32> = (0..1000).collect(); let v: Vec<u32> = (0..1_000).collect();
let (stream, skip) = create_stream_buffer(&v[..]); let (stream, skip) = create_stream_buffer(&v[..]);
assert_eq!(skip.len(), 12); assert_eq!(skip.len(), 12);
assert_eq!(stream.len(), 1168); assert_eq!(stream.len(), 1168);
let mut position_reader = PositionReader::new(stream, skip, 0u64); let mut position_reader = PositionReader::new(stream, skip, 0u64);
for &offset in &[1u64, 10u64, 127u64, 128u64, 130u64, 312u64] { position_reader.skip(10);
for &len in &[1, 10, 130, 500] { for &n in &[10, 127, COMPRESSION_BLOCK_SIZE, 130, 312] {
let mut v = vec![0u32; len]; let mut v = vec![0u32; n];
position_reader.read(offset, &mut v[..]); position_reader.read(&mut v[..n]);
for i in 0..len { for i in 0..n {
assert_eq!(v[i], i as u32 + offset as u32); assert_eq!(v[i], 10u32 + i as u32);
}
} }
} }
} }
@@ -103,12 +103,11 @@ pub mod tests {
let mut position_reader = PositionReader::new(stream, skip, 0u64); let mut position_reader = PositionReader::new(stream, skip, 0u64);
let mut buf = [0u32; 7]; let mut buf = [0u32; 7];
let mut c = 0; let mut c = 0;
let mut offset = 0;
for _ in 0..100 { for _ in 0..100 {
position_reader.read(offset, &mut buf); position_reader.read(&mut buf);
position_reader.read(offset, &mut buf); position_reader.read(&mut buf);
offset += 7; position_reader.skip(4);
position_reader.skip(3);
for &el in &buf { for &el in &buf {
assert_eq!(c, el); assert_eq!(c, el);
c += 1; c += 1;
@@ -116,58 +115,6 @@ pub mod tests {
} }
} }
#[test]
fn test_position_reread_anchor_different_than_block() {
let v: Vec<u32> = (0..2_000_000).collect();
let (stream, skip) = create_stream_buffer(&v[..]);
assert_eq!(skip.len(), 15_749);
assert_eq!(stream.len(), 4_987_872);
let mut position_reader = PositionReader::new(stream.clone(), skip.clone(), 0);
let mut buf = [0u32; 256];
position_reader.read(128, &mut buf);
for i in 0..256 {
assert_eq!(buf[i], (128 + i) as u32);
}
position_reader.read(128, &mut buf);
for i in 0..256 {
assert_eq!(buf[i], (128 + i) as u32);
}
}
#[test]
#[should_panic(expected = "offset arguments should be increasing.")]
fn test_position_panic_if_called_previous_anchor() {
let v: Vec<u32> = (0..2_000_000).collect();
let (stream, skip) = create_stream_buffer(&v[..]);
assert_eq!(skip.len(), 15_749);
assert_eq!(stream.len(), 4_987_872);
let mut buf = [0u32; 1];
let mut position_reader = PositionReader::new(stream.clone(), skip.clone(), 200_000);
position_reader.read(230, &mut buf);
position_reader.read(9, &mut buf);
}
#[test]
fn test_positions_bug() {
let mut v: Vec<u32> = vec![];
for i in 1..200 {
for j in 0..i {
v.push(j);
}
}
let (stream, skip) = create_stream_buffer(&v[..]);
let mut buf = Vec::new();
let mut position_reader = PositionReader::new(stream.clone(), skip.clone(), 0);
let mut offset = 0;
for i in 1..24 {
buf.resize(i, 0);
position_reader.read(offset, &mut buf[..]);
offset += i as u64;
let r: Vec<u32> = (0..i).map(|el| el as u32).collect();
assert_eq!(buf, &r[..]);
}
}
#[test] #[test]
fn test_position_long_skip_const() { fn test_position_long_skip_const() {
const CONST_VAL: u32 = 9u32; const CONST_VAL: u32 = 9u32;
@@ -177,7 +124,7 @@ pub mod tests {
assert_eq!(stream.len(), 1_000_000); assert_eq!(stream.len(), 1_000_000);
let mut position_reader = PositionReader::new(stream, skip, 128 * 1024); let mut position_reader = PositionReader::new(stream, skip, 128 * 1024);
let mut buf = [0u32; 1]; let mut buf = [0u32; 1];
position_reader.read(0, &mut buf); position_reader.read(&mut buf);
assert_eq!(buf[0], CONST_VAL); assert_eq!(buf[0], CONST_VAL);
} }
@@ -196,7 +143,7 @@ pub mod tests {
] { ] {
let mut position_reader = PositionReader::new(stream.clone(), skip.clone(), offset); let mut position_reader = PositionReader::new(stream.clone(), skip.clone(), offset);
let mut buf = [0u32; 1]; let mut buf = [0u32; 1];
position_reader.read(0, &mut buf); position_reader.read(&mut buf);
assert_eq!(buf[0], offset as u32); assert_eq!(buf[0], offset as u32);
} }
} }

View File

@@ -3,6 +3,7 @@ use crate::directory::ReadOnlySource;
use crate::positions::COMPRESSION_BLOCK_SIZE; use crate::positions::COMPRESSION_BLOCK_SIZE;
use crate::positions::LONG_SKIP_INTERVAL; use crate::positions::LONG_SKIP_INTERVAL;
use crate::positions::LONG_SKIP_IN_BLOCKS; use crate::positions::LONG_SKIP_IN_BLOCKS;
use crate::postings::compression::compressed_block_size;
/// Positions works as a long sequence of compressed block. /// Positions works as a long sequence of compressed block.
/// All terms are chained one after the other. /// All terms are chained one after the other.
/// ///
@@ -61,20 +62,22 @@ impl Positions {
fn reader(&self, offset: u64) -> PositionReader { fn reader(&self, offset: u64) -> PositionReader {
let long_skip_id = (offset / LONG_SKIP_INTERVAL) as usize; let long_skip_id = (offset / LONG_SKIP_INTERVAL) as usize;
let small_skip = (offset % LONG_SKIP_INTERVAL) as usize;
let offset_num_bytes: u64 = self.long_skip(long_skip_id); let offset_num_bytes: u64 = self.long_skip(long_skip_id);
let mut position_read = OwnedRead::new(self.position_source.clone()); let mut position_read = OwnedRead::new(self.position_source.clone());
position_read.advance(offset_num_bytes as usize); position_read.advance(offset_num_bytes as usize);
let mut skip_read = OwnedRead::new(self.skip_source.clone()); let mut skip_read = OwnedRead::new(self.skip_source.clone());
skip_read.advance(long_skip_id * LONG_SKIP_IN_BLOCKS); skip_read.advance(long_skip_id * LONG_SKIP_IN_BLOCKS);
PositionReader { let mut position_reader = PositionReader {
bit_packer: self.bit_packer, bit_packer: self.bit_packer,
skip_read, skip_read,
position_read, position_read,
inner_offset: 0,
buffer: Box::new([0u32; 128]), buffer: Box::new([0u32; 128]),
block_offset: std::i64::MAX as u64, ahead: None,
anchor_offset: (long_skip_id as u64) * LONG_SKIP_INTERVAL, };
abs_offset: offset, position_reader.skip(small_skip);
} position_reader
} }
} }
@@ -82,12 +85,51 @@ pub struct PositionReader {
skip_read: OwnedRead, skip_read: OwnedRead,
position_read: OwnedRead, position_read: OwnedRead,
bit_packer: BitPacker4x, bit_packer: BitPacker4x,
buffer: Box<[u32; COMPRESSION_BLOCK_SIZE]>, inner_offset: usize,
buffer: Box<[u32; 128]>,
ahead: Option<usize>, // if None, no block is loaded.
// if Some(num_blocks), the block currently loaded is num_blocks ahead
// of the block of the next int to read.
}
block_offset: u64, // `ahead` represents the offset of the block currently loaded
anchor_offset: u64, // compared to the cursor of the actual stream.
//
abs_offset: u64, // By contract, when this function is called, the current block has to be
// decompressed.
//
// If the requested number of els ends exactly at a given block, the next
// block is not decompressed.
fn read_impl(
bit_packer: BitPacker4x,
mut position: &[u8],
buffer: &mut [u32; 128],
mut inner_offset: usize,
num_bits: &[u8],
output: &mut [u32],
) -> usize {
let mut output_start = 0;
let mut output_len = output.len();
let mut ahead = 0;
loop {
let available_len = COMPRESSION_BLOCK_SIZE - inner_offset;
// We have enough elements in the current block.
// Let's copy the requested elements in the output buffer,
// and return.
if output_len <= available_len {
output[output_start..].copy_from_slice(&buffer[inner_offset..][..output_len]);
return ahead;
}
output[output_start..][..available_len].copy_from_slice(&buffer[inner_offset..]);
output_len -= available_len;
output_start += available_len;
inner_offset = 0;
let num_bits = num_bits[ahead];
bit_packer.decompress(position, &mut buffer[..], num_bits);
let block_len = compressed_block_size(num_bits);
position = &position[block_len..];
ahead += 1;
}
} }
impl PositionReader { impl PositionReader {
@@ -99,65 +141,57 @@ impl PositionReader {
Positions::new(position_source, skip_source).reader(offset) Positions::new(position_source, skip_source).reader(offset)
} }
fn advance_num_blocks(&mut self, num_blocks: usize) { /// Fills a buffer with the next `output.len()` integers.
let num_bits: usize = self.skip_read.as_ref()[..num_blocks] /// This does not consume / advance the stream.
.iter() pub fn read(&mut self, output: &mut [u32]) {
.cloned() let skip_data = self.skip_read.as_ref();
.map(|num_bits| num_bits as usize) let position_data = self.position_read.as_ref();
.sum(); let num_bits = self.skip_read.get(0);
let num_bytes_to_skip = num_bits * COMPRESSION_BLOCK_SIZE / 8; if self.ahead != Some(0) {
self.skip_read.advance(num_blocks as usize); // the block currently available is not the block
self.position_read.advance(num_bytes_to_skip); // for the current position
}
/// Fills a buffer with the positions `[offset..offset+output.len())` integers.
///
/// `offset` is required to have a value >= to the offsets given in previous calls
/// for the given `PositionReaderAbsolute` instance.
pub fn read(&mut self, mut offset: u64, mut output: &mut [u32]) {
offset += self.abs_offset;
assert!(
offset >= self.anchor_offset,
"offset arguments should be increasing."
);
let delta_to_block_offset = offset as i64 - self.block_offset as i64;
if delta_to_block_offset < 0 || delta_to_block_offset >= 128 {
// The first position is not within the first block.
// We need to decompress the first block.
let delta_to_anchor_offset = offset - self.anchor_offset;
let num_blocks_to_skip =
(delta_to_anchor_offset / (COMPRESSION_BLOCK_SIZE as u64)) as usize;
self.advance_num_blocks(num_blocks_to_skip);
self.anchor_offset = offset - (offset % COMPRESSION_BLOCK_SIZE as u64);
self.block_offset = self.anchor_offset;
let num_bits = self.skip_read.get(0);
self.bit_packer
.decompress(self.position_read.as_ref(), self.buffer.as_mut(), num_bits);
} else {
let num_blocks_to_skip =
((self.block_offset - self.anchor_offset) / COMPRESSION_BLOCK_SIZE as u64) as usize;
self.advance_num_blocks(num_blocks_to_skip);
self.anchor_offset = self.block_offset;
}
let mut num_bits = self.skip_read.get(0);
let mut position_data = self.position_read.as_ref();
for i in 1.. {
let offset_in_block = (offset as usize) % COMPRESSION_BLOCK_SIZE;
let remaining_in_block = COMPRESSION_BLOCK_SIZE - offset_in_block;
if remaining_in_block >= output.len() {
output.copy_from_slice(&self.buffer[offset_in_block..][..output.len()]);
break;
}
output[..remaining_in_block].copy_from_slice(&self.buffer[offset_in_block..]);
output = &mut output[remaining_in_block..];
offset += remaining_in_block as u64;
position_data = &position_data[(num_bits as usize * COMPRESSION_BLOCK_SIZE / 8)..];
num_bits = self.skip_read.get(i);
self.bit_packer self.bit_packer
.decompress(position_data, self.buffer.as_mut(), num_bits); .decompress(position_data, self.buffer.as_mut(), num_bits);
self.block_offset += COMPRESSION_BLOCK_SIZE as u64; self.ahead = Some(0);
} }
let block_len = compressed_block_size(num_bits);
self.ahead = Some(read_impl(
self.bit_packer,
&position_data[block_len..],
self.buffer.as_mut(),
self.inner_offset,
&skip_data[1..],
output,
));
}
/// Skip the next `skip_len` integer.
///
/// If a full block is skipped, calling
/// `.skip(...)` will avoid decompressing it.
///
/// May panic if the end of the stream is reached.
pub fn skip(&mut self, skip_len: usize) {
let skip_len_plus_inner_offset = skip_len + self.inner_offset;
let num_blocks_to_advance = skip_len_plus_inner_offset / COMPRESSION_BLOCK_SIZE;
self.inner_offset = skip_len_plus_inner_offset % COMPRESSION_BLOCK_SIZE;
self.ahead = self.ahead.and_then(|num_blocks| {
if num_blocks >= num_blocks_to_advance {
Some(num_blocks - num_blocks_to_advance)
} else {
None
}
});
let skip_len_in_bits = self.skip_read.as_ref()[..num_blocks_to_advance]
.iter()
.map(|num_bits| *num_bits as usize)
.sum::<usize>()
* COMPRESSION_BLOCK_SIZE;
let skip_len_in_bytes = skip_len_in_bits / 8;
self.skip_read.advance(num_blocks_to_advance);
self.position_read.advance(skip_len_in_bytes);
} }
} }

View File

@@ -87,7 +87,6 @@ fn exponential_search(arr: &[u32], target: u32) -> (usize, usize) {
(begin, end) (begin, end)
} }
#[inline(never)]
fn galloping(block_docs: &[u32], target: u32) -> usize { fn galloping(block_docs: &[u32], target: u32) -> usize {
let (start, end) = exponential_search(&block_docs, target); let (start, end) = exponential_search(&block_docs, target);
start + linear_search(&block_docs[start..end], target) start + linear_search(&block_docs[start..end], target)
@@ -107,7 +106,7 @@ impl BlockSearcher {
/// the target. /// the target.
/// ///
/// The results should be equivalent to /// The results should be equivalent to
/// ```compile_fail /// ```ignore
/// block[..] /// block[..]
// .iter() // .iter()
// .take_while(|&&val| val < target) // .take_while(|&&val| val < target)
@@ -130,18 +129,23 @@ impl BlockSearcher {
/// ///
/// If SSE2 instructions are available in the `(platform, running CPU)`, /// If SSE2 instructions are available in the `(platform, running CPU)`,
/// then we use a different implementation that does an exhaustive linear search over /// then we use a different implementation that does an exhaustive linear search over
/// the block regardless of whether the block is full or not. /// the full block whenever the block is full (`len == 128`). It is surprisingly faster, most likely because of the lack
/// /// of branch.
/// Indeed, if the block is not full, the remaining items are TERMINATED. pub(crate) fn search_in_block(
/// It is surprisingly faster, most likely because of the lack of branch misprediction. self,
pub(crate) fn search_in_block(self, block_docs: &AlignedBuffer, target: u32) -> usize { block_docs: &AlignedBuffer,
len: usize,
start: usize,
target: u32,
) -> usize {
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
{ {
if self == BlockSearcher::SSE2 { use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
if self == BlockSearcher::SSE2 && len == COMPRESSION_BLOCK_SIZE {
return sse2::linear_search_sse2_128(block_docs, target); return sse2::linear_search_sse2_128(block_docs, target);
} }
} }
galloping(&block_docs.0[..], target) start + galloping(&block_docs.0[start..len], target)
} }
} }
@@ -162,7 +166,6 @@ mod tests {
use super::exponential_search; use super::exponential_search;
use super::linear_search; use super::linear_search;
use super::BlockSearcher; use super::BlockSearcher;
use crate::docset::TERMINATED;
use crate::postings::compression::{AlignedBuffer, COMPRESSION_BLOCK_SIZE}; use crate::postings::compression::{AlignedBuffer, COMPRESSION_BLOCK_SIZE};
#[test] #[test]
@@ -193,12 +196,19 @@ mod tests {
fn util_test_search_in_block(block_searcher: BlockSearcher, block: &[u32], target: u32) { fn util_test_search_in_block(block_searcher: BlockSearcher, block: &[u32], target: u32) {
let cursor = search_in_block_trivial_but_slow(block, target); let cursor = search_in_block_trivial_but_slow(block, target);
assert!(block.len() < COMPRESSION_BLOCK_SIZE); assert!(block.len() < COMPRESSION_BLOCK_SIZE);
let mut output_buffer = [TERMINATED; COMPRESSION_BLOCK_SIZE]; let mut output_buffer = [u32::max_value(); COMPRESSION_BLOCK_SIZE];
output_buffer[..block.len()].copy_from_slice(block); output_buffer[..block.len()].copy_from_slice(block);
assert_eq!( for i in 0..cursor {
block_searcher.search_in_block(&AlignedBuffer(output_buffer), target), assert_eq!(
cursor block_searcher.search_in_block(
); &AlignedBuffer(output_buffer),
block.len(),
i,
target
),
cursor
);
}
} }
fn util_test_search_in_block_all(block_searcher: BlockSearcher, block: &[u32]) { fn util_test_search_in_block_all(block_searcher: BlockSearcher, block: &[u32]) {

View File

@@ -1,427 +0,0 @@
use crate::common::{BinarySerializable, VInt};
use crate::directory::ReadOnlySource;
use crate::postings::compression::{
AlignedBuffer, BlockDecoder, VIntDecoder, COMPRESSION_BLOCK_SIZE,
};
use crate::postings::{BlockInfo, FreqReadingOption, SkipReader};
use crate::schema::IndexRecordOption;
use crate::{DocId, TERMINATED};
/// `BlockSegmentPostings` is a cursor iterating over blocks
/// of documents.
///
/// # Warning
///
/// While it is useful for some very specific high-performance
/// use cases, you should prefer using `SegmentPostings` for most usage.
pub struct BlockSegmentPostings {
pub(crate) doc_decoder: BlockDecoder,
loaded_offset: usize,
freq_decoder: BlockDecoder,
freq_reading_option: FreqReadingOption,
doc_freq: usize,
data: ReadOnlySource,
skip_reader: SkipReader,
}
fn decode_bitpacked_block(
doc_decoder: &mut BlockDecoder,
freq_decoder_opt: Option<&mut BlockDecoder>,
data: &[u8],
doc_offset: DocId,
doc_num_bits: u8,
tf_num_bits: u8,
) {
let num_consumed_bytes = doc_decoder.uncompress_block_sorted(data, doc_offset, doc_num_bits);
if let Some(freq_decoder) = freq_decoder_opt {
freq_decoder.uncompress_block_unsorted(&data[num_consumed_bytes..], tf_num_bits);
}
}
fn decode_vint_block(
doc_decoder: &mut BlockDecoder,
freq_decoder_opt: Option<&mut BlockDecoder>,
data: &[u8],
doc_offset: DocId,
num_vint_docs: usize,
) {
doc_decoder.clear();
let num_consumed_bytes = doc_decoder.uncompress_vint_sorted(data, doc_offset, num_vint_docs);
if let Some(freq_decoder) = freq_decoder_opt {
freq_decoder.uncompress_vint_unsorted(&data[num_consumed_bytes..], num_vint_docs);
}
}
fn split_into_skips_and_postings(
doc_freq: u32,
data: ReadOnlySource,
) -> (Option<ReadOnlySource>, ReadOnlySource) {
if doc_freq < COMPRESSION_BLOCK_SIZE as u32 {
return (None, data);
}
let mut data_byte_arr = data.as_slice();
let skip_len = VInt::deserialize(&mut data_byte_arr)
.expect("Data corrupted")
.0 as usize;
let vint_len = data.len() - data_byte_arr.len();
let (skip_data, postings_data) = data.slice_from(vint_len).split(skip_len);
(Some(skip_data), postings_data)
}
impl BlockSegmentPostings {
pub(crate) fn from_data(
doc_freq: u32,
data: ReadOnlySource,
record_option: IndexRecordOption,
requested_option: IndexRecordOption,
) -> BlockSegmentPostings {
let freq_reading_option = match (record_option, requested_option) {
(IndexRecordOption::Basic, _) => FreqReadingOption::NoFreq,
(_, IndexRecordOption::Basic) => FreqReadingOption::SkipFreq,
(_, _) => FreqReadingOption::ReadFreq,
};
let (skip_data_opt, postings_data) = split_into_skips_and_postings(doc_freq, data);
let skip_reader = match skip_data_opt {
Some(skip_data) => SkipReader::new(skip_data, doc_freq, record_option),
None => SkipReader::new(ReadOnlySource::empty(), doc_freq, record_option),
};
let doc_freq = doc_freq as usize;
let mut block_segment_postings = BlockSegmentPostings {
doc_decoder: BlockDecoder::with_val(TERMINATED),
loaded_offset: std::usize::MAX,
freq_decoder: BlockDecoder::with_val(1),
freq_reading_option,
doc_freq,
data: postings_data,
skip_reader,
};
block_segment_postings.advance();
block_segment_postings
}
// Resets the block segment postings on another position
// in the postings file.
//
// This is useful for enumerating through a list of terms,
// and consuming the associated posting lists while avoiding
// reallocating a `BlockSegmentPostings`.
//
// # Warning
//
// This does not reset the positions list.
pub(crate) fn reset(&mut self, doc_freq: u32, postings_data: ReadOnlySource) {
let (skip_data_opt, postings_data) = split_into_skips_and_postings(doc_freq, postings_data);
self.data = ReadOnlySource::new(postings_data);
self.loaded_offset = std::usize::MAX;
self.loaded_offset = std::usize::MAX;
if let Some(skip_data) = skip_data_opt {
self.skip_reader.reset(skip_data, doc_freq);
} else {
self.skip_reader.reset(ReadOnlySource::empty(), doc_freq);
}
self.doc_freq = doc_freq as usize;
}
/// Returns the document frequency associated to this block postings.
///
/// This `doc_freq` is simply the sum of the length of all of the blocks
/// length, and it does not take in account deleted documents.
pub fn doc_freq(&self) -> usize {
self.doc_freq
}
/// Returns the array of docs in the current block.
///
/// Before the first call to `.advance()`, the block
/// returned by `.docs()` is empty.
#[inline]
pub fn docs(&self) -> &[DocId] {
self.doc_decoder.output_array()
}
#[inline(always)]
pub(crate) fn docs_aligned(&self) -> &AlignedBuffer {
self.doc_decoder.output_aligned()
}
/// Return the document at index `idx` of the block.
#[inline(always)]
pub fn doc(&self, idx: usize) -> u32 {
self.doc_decoder.output(idx)
}
/// Return the array of `term freq` in the block.
#[inline]
pub fn freqs(&self) -> &[u32] {
self.freq_decoder.output_array()
}
/// Return the frequency at index `idx` of the block.
#[inline]
pub fn freq(&self, idx: usize) -> u32 {
self.freq_decoder.output(idx)
}
/// Returns the length of the current block.
///
/// All blocks have a length of `NUM_DOCS_PER_BLOCK`,
/// except the last block that may have a length
/// of any number between 1 and `NUM_DOCS_PER_BLOCK - 1`
#[inline]
pub fn block_len(&self) -> usize {
self.doc_decoder.output_len
}
pub(crate) fn position_offset(&self) -> u64 {
self.skip_reader.position_offset()
}
/// Position on a block that may contains `target_doc`.
///
/// If all docs are smaller than target, the block loaded may be empty,
/// or be the last an incomplete VInt block.
pub fn seek(&mut self, target_doc: DocId) {
self.skip_reader.seek(target_doc);
self.load_block();
}
fn load_block(&mut self) {
let offset = self.skip_reader.byte_offset();
if self.loaded_offset == offset {
return;
}
self.loaded_offset = offset;
match self.skip_reader.block_info() {
BlockInfo::BitPacked {
doc_num_bits,
tf_num_bits,
..
} => {
decode_bitpacked_block(
&mut self.doc_decoder,
if let FreqReadingOption::ReadFreq = self.freq_reading_option {
Some(&mut self.freq_decoder)
} else {
None
},
&self.data.as_slice()[offset..],
self.skip_reader.last_doc_in_previous_block,
doc_num_bits,
tf_num_bits,
);
}
BlockInfo::VInt(num_vint_docs) => {
decode_vint_block(
&mut self.doc_decoder,
if let FreqReadingOption::ReadFreq = self.freq_reading_option {
Some(&mut self.freq_decoder)
} else {
None
},
&self.data.as_slice()[offset..],
self.skip_reader.last_doc_in_previous_block,
num_vint_docs as usize,
);
}
}
}
/// Advance to the next block.
///
/// Returns false iff there was no remaining blocks.
pub fn advance(&mut self) -> bool {
if !self.skip_reader.advance() {
return false;
}
self.load_block();
true
}
/// Returns an empty segment postings object
pub fn empty() -> BlockSegmentPostings {
BlockSegmentPostings {
doc_decoder: BlockDecoder::with_val(TERMINATED),
loaded_offset: std::usize::MAX,
freq_decoder: BlockDecoder::with_val(1),
freq_reading_option: FreqReadingOption::NoFreq,
doc_freq: 0,
data: ReadOnlySource::new(vec![]),
skip_reader: SkipReader::new(ReadOnlySource::new(vec![]), 0, IndexRecordOption::Basic),
}
}
}
#[cfg(test)]
mod tests {
use super::BlockSegmentPostings;
use crate::common::HasLen;
use crate::core::Index;
use crate::docset::{DocSet, TERMINATED};
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::postings::postings::Postings;
use crate::postings::SegmentPostings;
use crate::schema::IndexRecordOption;
use crate::schema::Schema;
use crate::schema::Term;
use crate::schema::INDEXED;
use crate::DocId;
#[test]
fn test_empty_segment_postings() {
let mut postings = SegmentPostings::empty();
assert_eq!(postings.advance(), TERMINATED);
assert_eq!(postings.advance(), TERMINATED);
assert_eq!(postings.len(), 0);
}
#[test]
fn test_empty_postings_doc_returns_terminated() {
let mut postings = SegmentPostings::empty();
assert_eq!(postings.doc(), TERMINATED);
assert_eq!(postings.advance(), TERMINATED);
}
#[test]
fn test_empty_postings_doc_term_freq_returns_0() {
let postings = SegmentPostings::empty();
assert_eq!(postings.term_freq(), 1);
}
#[test]
fn test_empty_block_segment_postings() {
let mut postings = BlockSegmentPostings::empty();
assert!(!postings.advance());
assert_eq!(postings.doc_freq(), 0);
}
#[test]
fn test_block_segment_postings() {
let mut block_segments = build_block_postings(&(0..100_000).collect::<Vec<u32>>());
let mut offset: u32 = 0u32;
// checking that the `doc_freq` is correct
assert_eq!(block_segments.doc_freq(), 100_000);
loop {
let block = block_segments.docs();
for (i, doc) in block.iter().cloned().enumerate() {
assert_eq!(offset + (i as u32), doc);
}
offset += block.len() as u32;
if block_segments.advance() {
break;
}
}
}
#[test]
fn test_skip_right_at_new_block() {
let mut doc_ids = (0..128).collect::<Vec<u32>>();
// 128 is missing
doc_ids.push(129);
doc_ids.push(130);
{
let block_segments = build_block_postings(&doc_ids);
let mut docset = SegmentPostings::from_block_postings(block_segments, None);
assert_eq!(docset.seek(128), 129);
assert_eq!(docset.doc(), 129);
assert_eq!(docset.advance(), 130);
assert_eq!(docset.doc(), 130);
assert_eq!(docset.advance(), TERMINATED);
}
{
let block_segments = build_block_postings(&doc_ids);
let mut docset = SegmentPostings::from_block_postings(block_segments, None);
assert_eq!(docset.seek(129), 129);
assert_eq!(docset.doc(), 129);
assert_eq!(docset.advance(), 130);
assert_eq!(docset.doc(), 130);
assert_eq!(docset.advance(), TERMINATED);
}
{
let block_segments = build_block_postings(&doc_ids);
let mut docset = SegmentPostings::from_block_postings(block_segments, None);
assert_eq!(docset.doc(), 0);
assert_eq!(docset.seek(131), TERMINATED);
assert_eq!(docset.doc(), TERMINATED);
}
}
fn build_block_postings(docs: &[DocId]) -> BlockSegmentPostings {
let mut schema_builder = Schema::builder();
let int_field = schema_builder.add_u64_field("id", INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
let mut last_doc = 0u32;
for &doc in docs {
for _ in last_doc..doc {
index_writer.add_document(doc!(int_field=>1u64));
}
index_writer.add_document(doc!(int_field=>0u64));
last_doc = doc + 1;
}
index_writer.commit().unwrap();
let searcher = index.reader().unwrap().searcher();
let segment_reader = searcher.segment_reader(0);
let inverted_index = segment_reader.inverted_index(int_field);
let term = Term::from_field_u64(int_field, 0u64);
let term_info = inverted_index.get_term_info(&term).unwrap();
inverted_index.read_block_postings_from_terminfo(&term_info, IndexRecordOption::Basic)
}
#[test]
fn test_block_segment_postings_skip2() {
let mut docs = vec![0];
for i in 0..1300 {
docs.push((i * i / 100) + i);
}
let mut block_postings = build_block_postings(&docs[..]);
for i in vec![0, 424, 10000] {
block_postings.seek(i);
let docs = block_postings.docs();
assert!(docs[0] <= i);
assert!(docs.last().cloned().unwrap_or(0u32) >= i);
}
block_postings.seek(100_000);
assert_eq!(block_postings.doc(COMPRESSION_BLOCK_SIZE - 1), TERMINATED);
}
#[test]
fn test_reset_block_segment_postings() {
let mut schema_builder = Schema::builder();
let int_field = schema_builder.add_u64_field("id", INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
// create two postings list, one containg even number,
// the other containing odd numbers.
for i in 0..6 {
let doc = doc!(int_field=> (i % 2) as u64);
index_writer.add_document(doc);
}
index_writer.commit().unwrap();
let searcher = index.reader().unwrap().searcher();
let segment_reader = searcher.segment_reader(0);
let mut block_segments;
{
let term = Term::from_field_u64(int_field, 0u64);
let inverted_index = segment_reader.inverted_index(int_field);
let term_info = inverted_index.get_term_info(&term).unwrap();
block_segments = inverted_index
.read_block_postings_from_terminfo(&term_info, IndexRecordOption::Basic);
}
assert_eq!(block_segments.docs(), &[0, 2, 4]);
{
let term = Term::from_field_u64(int_field, 1u64);
let inverted_index = segment_reader.inverted_index(int_field);
let term_info = inverted_index.get_term_info(&term).unwrap();
inverted_index.reset_block_postings_from_terminfo(&term_info, &mut block_segments);
}
assert!(block_segments.advance());
assert_eq!(block_segments.docs(), &[1, 3, 5]);
}
}

View File

@@ -1,5 +1,4 @@
use crate::common::FixedSize; use crate::common::FixedSize;
use crate::docset::TERMINATED;
use bitpacking::{BitPacker, BitPacker4x}; use bitpacking::{BitPacker, BitPacker4x};
pub const COMPRESSION_BLOCK_SIZE: usize = BitPacker4x::BLOCK_LEN; pub const COMPRESSION_BLOCK_SIZE: usize = BitPacker4x::BLOCK_LEN;
@@ -18,12 +17,6 @@ pub struct BlockEncoder {
pub output_len: usize, pub output_len: usize,
} }
impl Default for BlockEncoder {
fn default() -> Self {
BlockEncoder::new()
}
}
impl BlockEncoder { impl BlockEncoder {
pub fn new() -> BlockEncoder { pub fn new() -> BlockEncoder {
BlockEncoder { BlockEncoder {
@@ -61,13 +54,11 @@ pub struct BlockDecoder {
pub output_len: usize, pub output_len: usize,
} }
impl Default for BlockDecoder { impl BlockDecoder {
fn default() -> Self { pub fn new() -> BlockDecoder {
BlockDecoder::with_val(0u32) BlockDecoder::with_val(0u32)
} }
}
impl BlockDecoder {
pub fn with_val(val: u32) -> BlockDecoder { pub fn with_val(val: u32) -> BlockDecoder {
BlockDecoder { BlockDecoder {
bitpacker: BitPacker4x::new(), bitpacker: BitPacker4x::new(),
@@ -99,18 +90,14 @@ impl BlockDecoder {
} }
#[inline] #[inline]
pub(crate) fn output_aligned(&self) -> &AlignedBuffer { pub(crate) fn output_aligned(&self) -> (&AlignedBuffer, usize) {
&self.output (&self.output, self.output_len)
} }
#[inline] #[inline]
pub fn output(&self, idx: usize) -> u32 { pub fn output(&self, idx: usize) -> u32 {
self.output.0[idx] self.output.0[idx]
} }
pub fn clear(&mut self) {
self.output.0.iter_mut().for_each(|el| *el = TERMINATED);
}
} }
pub trait VIntEncoder { pub trait VIntEncoder {
@@ -147,9 +134,9 @@ pub trait VIntDecoder {
/// For instance, if delta encoded are `1, 3, 9`, and the /// For instance, if delta encoded are `1, 3, 9`, and the
/// `offset` is 5, then the output will be: /// `offset` is 5, then the output will be:
/// `5 + 1 = 6, 6 + 3= 9, 9 + 9 = 18` /// `5 + 1 = 6, 6 + 3= 9, 9 + 9 = 18`
fn uncompress_vint_sorted( fn uncompress_vint_sorted<'a>(
&mut self, &mut self,
compressed_data: &[u8], compressed_data: &'a [u8],
offset: u32, offset: u32,
num_els: usize, num_els: usize,
) -> usize; ) -> usize;
@@ -159,7 +146,7 @@ pub trait VIntDecoder {
/// ///
/// The method takes a number of int to decompress, and returns /// The method takes a number of int to decompress, and returns
/// the amount of bytes that were read to decompress them. /// the amount of bytes that were read to decompress them.
fn uncompress_vint_unsorted(&mut self, compressed_data: &[u8], num_els: usize) -> usize; fn uncompress_vint_unsorted<'a>(&mut self, compressed_data: &'a [u8], num_els: usize) -> usize;
} }
impl VIntEncoder for BlockEncoder { impl VIntEncoder for BlockEncoder {
@@ -173,9 +160,9 @@ impl VIntEncoder for BlockEncoder {
} }
impl VIntDecoder for BlockDecoder { impl VIntDecoder for BlockDecoder {
fn uncompress_vint_sorted( fn uncompress_vint_sorted<'a>(
&mut self, &mut self,
compressed_data: &[u8], compressed_data: &'a [u8],
offset: u32, offset: u32,
num_els: usize, num_els: usize,
) -> usize { ) -> usize {
@@ -183,7 +170,7 @@ impl VIntDecoder for BlockDecoder {
vint::uncompress_sorted(compressed_data, &mut self.output.0[..num_els], offset) vint::uncompress_sorted(compressed_data, &mut self.output.0[..num_els], offset)
} }
fn uncompress_vint_unsorted(&mut self, compressed_data: &[u8], num_els: usize) -> usize { fn uncompress_vint_unsorted<'a>(&mut self, compressed_data: &'a [u8], num_els: usize) -> usize {
self.output_len = num_els; self.output_len = num_els;
vint::uncompress_unsorted(compressed_data, &mut self.output.0[..num_els]) vint::uncompress_unsorted(compressed_data, &mut self.output.0[..num_els])
} }
@@ -199,7 +186,7 @@ pub mod tests {
let vals: Vec<u32> = (0u32..128u32).map(|i| i * 7).collect(); let vals: Vec<u32> = (0u32..128u32).map(|i| i * 7).collect();
let mut encoder = BlockEncoder::new(); let mut encoder = BlockEncoder::new();
let (num_bits, compressed_data) = encoder.compress_block_sorted(&vals, 0); let (num_bits, compressed_data) = encoder.compress_block_sorted(&vals, 0);
let mut decoder = BlockDecoder::default(); let mut decoder = BlockDecoder::new();
{ {
let consumed_num_bytes = decoder.uncompress_block_sorted(compressed_data, 0, num_bits); let consumed_num_bytes = decoder.uncompress_block_sorted(compressed_data, 0, num_bits);
assert_eq!(consumed_num_bytes, compressed_data.len()); assert_eq!(consumed_num_bytes, compressed_data.len());
@@ -212,9 +199,9 @@ pub mod tests {
#[test] #[test]
fn test_encode_sorted_block_with_offset() { fn test_encode_sorted_block_with_offset() {
let vals: Vec<u32> = (0u32..128u32).map(|i| 11 + i * 7).collect(); let vals: Vec<u32> = (0u32..128u32).map(|i| 11 + i * 7).collect();
let mut encoder = BlockEncoder::default(); let mut encoder = BlockEncoder::new();
let (num_bits, compressed_data) = encoder.compress_block_sorted(&vals, 10); let (num_bits, compressed_data) = encoder.compress_block_sorted(&vals, 10);
let mut decoder = BlockDecoder::default(); let mut decoder = BlockDecoder::new();
{ {
let consumed_num_bytes = decoder.uncompress_block_sorted(compressed_data, 10, num_bits); let consumed_num_bytes = decoder.uncompress_block_sorted(compressed_data, 10, num_bits);
assert_eq!(consumed_num_bytes, compressed_data.len()); assert_eq!(consumed_num_bytes, compressed_data.len());
@@ -229,11 +216,11 @@ pub mod tests {
let mut compressed: Vec<u8> = Vec::new(); let mut compressed: Vec<u8> = Vec::new();
let n = 128; let n = 128;
let vals: Vec<u32> = (0..n).map(|i| 11u32 + (i as u32) * 7u32).collect(); let vals: Vec<u32> = (0..n).map(|i| 11u32 + (i as u32) * 7u32).collect();
let mut encoder = BlockEncoder::default(); let mut encoder = BlockEncoder::new();
let (num_bits, compressed_data) = encoder.compress_block_sorted(&vals, 10); let (num_bits, compressed_data) = encoder.compress_block_sorted(&vals, 10);
compressed.extend_from_slice(compressed_data); compressed.extend_from_slice(compressed_data);
compressed.push(173u8); compressed.push(173u8);
let mut decoder = BlockDecoder::default(); let mut decoder = BlockDecoder::new();
{ {
let consumed_num_bytes = decoder.uncompress_block_sorted(&compressed, 10, num_bits); let consumed_num_bytes = decoder.uncompress_block_sorted(&compressed, 10, num_bits);
assert_eq!(consumed_num_bytes, compressed.len() - 1); assert_eq!(consumed_num_bytes, compressed.len() - 1);
@@ -249,11 +236,11 @@ pub mod tests {
let mut compressed: Vec<u8> = Vec::new(); let mut compressed: Vec<u8> = Vec::new();
let n = 128; let n = 128;
let vals: Vec<u32> = (0..n).map(|i| 11u32 + (i as u32) * 7u32 % 12).collect(); let vals: Vec<u32> = (0..n).map(|i| 11u32 + (i as u32) * 7u32 % 12).collect();
let mut encoder = BlockEncoder::default(); let mut encoder = BlockEncoder::new();
let (num_bits, compressed_data) = encoder.compress_block_unsorted(&vals); let (num_bits, compressed_data) = encoder.compress_block_unsorted(&vals);
compressed.extend_from_slice(compressed_data); compressed.extend_from_slice(compressed_data);
compressed.push(173u8); compressed.push(173u8);
let mut decoder = BlockDecoder::default(); let mut decoder = BlockDecoder::new();
{ {
let consumed_num_bytes = decoder.uncompress_block_unsorted(&compressed, num_bits); let consumed_num_bytes = decoder.uncompress_block_unsorted(&compressed, num_bits);
assert_eq!(consumed_num_bytes + 1, compressed.len()); assert_eq!(consumed_num_bytes + 1, compressed.len());
@@ -264,11 +251,6 @@ pub mod tests {
} }
} }
#[test]
fn test_block_decoder_initialization() {
let block = BlockDecoder::with_val(TERMINATED);
assert_eq!(block.output(0), TERMINATED);
}
#[test] #[test]
fn test_encode_vint() { fn test_encode_vint() {
{ {
@@ -278,7 +260,7 @@ pub mod tests {
for offset in &[0u32, 1u32, 2u32] { for offset in &[0u32, 1u32, 2u32] {
let encoded_data = encoder.compress_vint_sorted(&input, *offset); let encoded_data = encoder.compress_vint_sorted(&input, *offset);
assert!(encoded_data.len() <= expected_length); assert!(encoded_data.len() <= expected_length);
let mut decoder = BlockDecoder::default(); let mut decoder = BlockDecoder::new();
let consumed_num_bytes = let consumed_num_bytes =
decoder.uncompress_vint_sorted(&encoded_data, *offset, input.len()); decoder.uncompress_vint_sorted(&encoded_data, *offset, input.len());
assert_eq!(consumed_num_bytes, encoded_data.len()); assert_eq!(consumed_num_bytes, encoded_data.len());

View File

@@ -42,7 +42,7 @@ pub(crate) fn compress_unsorted<'a>(input: &[u32], output: &'a mut [u8]) -> &'a
} }
#[inline(always)] #[inline(always)]
pub fn uncompress_sorted(compressed_data: &[u8], output: &mut [u32], offset: u32) -> usize { pub fn uncompress_sorted<'a>(compressed_data: &'a [u8], output: &mut [u32], offset: u32) -> usize {
let mut read_byte = 0; let mut read_byte = 0;
let mut result = offset; let mut result = offset;
for output_mut in output.iter_mut() { for output_mut in output.iter_mut() {

View File

@@ -3,8 +3,11 @@ Postings module (also called inverted index)
*/ */
mod block_search; mod block_search;
mod block_segment_postings;
pub(crate) mod compression; pub(crate) mod compression;
/// Postings module
///
/// Postings, also called inverted lists, is the key datastructure
/// to full-text search.
mod postings; mod postings;
mod postings_writer; mod postings_writer;
mod recorder; mod recorder;
@@ -19,17 +22,18 @@ pub(crate) use self::block_search::BlockSearcher;
pub(crate) use self::postings_writer::MultiFieldPostingsWriter; pub(crate) use self::postings_writer::MultiFieldPostingsWriter;
pub use self::serializer::{FieldSerializer, InvertedIndexSerializer}; pub use self::serializer::{FieldSerializer, InvertedIndexSerializer};
use self::compression::COMPRESSION_BLOCK_SIZE;
pub use self::postings::Postings; pub use self::postings::Postings;
pub(crate) use self::skip::{BlockInfo, SkipReader}; pub(crate) use self::skip::SkipReader;
pub use self::term_info::TermInfo; pub use self::term_info::TermInfo;
pub use self::block_segment_postings::BlockSegmentPostings; pub use self::segment_postings::{BlockSegmentPostings, SegmentPostings};
pub use self::segment_postings::SegmentPostings;
pub(crate) use self::stacker::compute_table_size; pub(crate) use self::stacker::compute_table_size;
pub use crate::common::HasLen; pub use crate::common::HasLen;
pub(crate) const USE_SKIP_INFO_LIMIT: u32 = COMPRESSION_BLOCK_SIZE as u32;
pub(crate) type UnorderedTermId = u64; pub(crate) type UnorderedTermId = u64;
#[cfg_attr(feature = "cargo-clippy", allow(clippy::enum_variant_names))] #[cfg_attr(feature = "cargo-clippy", allow(clippy::enum_variant_names))]
@@ -47,7 +51,7 @@ pub mod tests {
use crate::core::Index; use crate::core::Index;
use crate::core::SegmentComponent; use crate::core::SegmentComponent;
use crate::core::SegmentReader; use crate::core::SegmentReader;
use crate::docset::{DocSet, TERMINATED}; use crate::docset::{DocSet, SkipResult};
use crate::fieldnorm::FieldNormReader; use crate::fieldnorm::FieldNormReader;
use crate::indexer::operation::AddOperation; use crate::indexer::operation::AddOperation;
use crate::indexer::SegmentWriter; use crate::indexer::SegmentWriter;
@@ -71,7 +75,7 @@ pub mod tests {
let schema = schema_builder.build(); let schema = schema_builder.build();
let index = Index::create_in_ram(schema); let index = Index::create_in_ram(schema);
let mut segment = index.new_segment(); let mut segment = index.new_segment();
let mut posting_serializer = InvertedIndexSerializer::open(&mut segment).unwrap(); let mut posting_serializer = InvertedIndexSerializer::for_segment(&mut segment).unwrap();
{ {
let mut field_serializer = posting_serializer.new_field(text_field, 120 * 4).unwrap(); let mut field_serializer = posting_serializer.new_field(text_field, 120 * 4).unwrap();
field_serializer.new_term("abc".as_bytes()).unwrap(); field_serializer.new_term("abc".as_bytes()).unwrap();
@@ -111,12 +115,29 @@ pub mod tests {
let mut postings = inverted_index let mut postings = inverted_index
.read_postings(&term, IndexRecordOption::WithFreqsAndPositions) .read_postings(&term, IndexRecordOption::WithFreqsAndPositions)
.unwrap(); .unwrap();
assert_eq!(postings.doc(), 0); postings.advance();
postings.positions(&mut positions); postings.positions(&mut positions);
assert_eq!(&[0, 1, 2], &positions[..]); assert_eq!(&[0, 1, 2], &positions[..]);
postings.positions(&mut positions); postings.positions(&mut positions);
assert_eq!(&[0, 1, 2], &positions[..]); assert_eq!(&[0, 1, 2], &positions[..]);
assert_eq!(postings.advance(), 1); postings.advance();
postings.positions(&mut positions);
assert_eq!(&[0, 5], &positions[..]);
}
{
let mut postings = inverted_index
.read_postings(&term, IndexRecordOption::WithFreqsAndPositions)
.unwrap();
postings.advance();
postings.advance();
postings.positions(&mut positions);
assert_eq!(&[0, 5], &positions[..]);
}
{
let mut postings = inverted_index
.read_postings(&term, IndexRecordOption::WithFreqsAndPositions)
.unwrap();
assert_eq!(postings.skip_next(1), SkipResult::Reached);
assert_eq!(postings.doc(), 1); assert_eq!(postings.doc(), 1);
postings.positions(&mut positions); postings.positions(&mut positions);
assert_eq!(&[0, 5], &positions[..]); assert_eq!(&[0, 5], &positions[..]);
@@ -125,25 +146,7 @@ pub mod tests {
let mut postings = inverted_index let mut postings = inverted_index
.read_postings(&term, IndexRecordOption::WithFreqsAndPositions) .read_postings(&term, IndexRecordOption::WithFreqsAndPositions)
.unwrap(); .unwrap();
assert_eq!(postings.doc(), 0); assert_eq!(postings.skip_next(1002), SkipResult::Reached);
assert_eq!(postings.advance(), 1);
postings.positions(&mut positions);
assert_eq!(&[0, 5], &positions[..]);
}
{
let mut postings = inverted_index
.read_postings(&term, IndexRecordOption::WithFreqsAndPositions)
.unwrap();
assert_eq!(postings.seek(1), 1);
assert_eq!(postings.doc(), 1);
postings.positions(&mut positions);
assert_eq!(&[0, 5], &positions[..]);
}
{
let mut postings = inverted_index
.read_postings(&term, IndexRecordOption::WithFreqsAndPositions)
.unwrap();
assert_eq!(postings.seek(1002), 1002);
assert_eq!(postings.doc(), 1002); assert_eq!(postings.doc(), 1002);
postings.positions(&mut positions); postings.positions(&mut positions);
assert_eq!(&[0, 5], &positions[..]); assert_eq!(&[0, 5], &positions[..]);
@@ -152,8 +155,8 @@ pub mod tests {
let mut postings = inverted_index let mut postings = inverted_index
.read_postings(&term, IndexRecordOption::WithFreqsAndPositions) .read_postings(&term, IndexRecordOption::WithFreqsAndPositions)
.unwrap(); .unwrap();
assert_eq!(postings.seek(100), 100); assert_eq!(postings.skip_next(100), SkipResult::Reached);
assert_eq!(postings.seek(1002), 1002); assert_eq!(postings.skip_next(1002), SkipResult::Reached);
assert_eq!(postings.doc(), 1002); assert_eq!(postings.doc(), 1002);
postings.positions(&mut positions); postings.positions(&mut positions);
assert_eq!(&[0, 5], &positions[..]); assert_eq!(&[0, 5], &positions[..]);
@@ -278,21 +281,22 @@ pub mod tests {
.read_postings(&term_a, IndexRecordOption::WithFreqsAndPositions) .read_postings(&term_a, IndexRecordOption::WithFreqsAndPositions)
.unwrap(); .unwrap();
assert_eq!(postings_a.len(), 1000); assert_eq!(postings_a.len(), 1000);
assert!(postings_a.advance());
assert_eq!(postings_a.doc(), 0); assert_eq!(postings_a.doc(), 0);
assert_eq!(postings_a.term_freq(), 6); assert_eq!(postings_a.term_freq(), 6);
postings_a.positions(&mut positions); postings_a.positions(&mut positions);
assert_eq!(&positions[..], [0, 2, 4, 6, 7, 13]); assert_eq!(&positions[..], [0, 2, 4, 6, 7, 13]);
assert_eq!(postings_a.advance(), 1u32); assert!(postings_a.advance());
assert_eq!(postings_a.doc(), 1u32); assert_eq!(postings_a.doc(), 1u32);
assert_eq!(postings_a.term_freq(), 1); assert_eq!(postings_a.term_freq(), 1);
for i in 2u32..1000u32 { for i in 2u32..1000u32 {
assert_eq!(postings_a.advance(), i); assert!(postings_a.advance());
assert_eq!(postings_a.term_freq(), 1); assert_eq!(postings_a.term_freq(), 1);
postings_a.positions(&mut positions); postings_a.positions(&mut positions);
assert_eq!(&positions[..], [i]); assert_eq!(&positions[..], [i]);
assert_eq!(postings_a.doc(), i); assert_eq!(postings_a.doc(), i);
} }
assert_eq!(postings_a.advance(), TERMINATED); assert!(!postings_a.advance());
} }
{ {
let term_e = Term::from_field_text(text_field, "e"); let term_e = Term::from_field_text(text_field, "e");
@@ -302,6 +306,7 @@ pub mod tests {
.unwrap(); .unwrap();
assert_eq!(postings_e.len(), 1000 - 2); assert_eq!(postings_e.len(), 1000 - 2);
for i in 2u32..1000u32 { for i in 2u32..1000u32 {
assert!(postings_e.advance());
assert_eq!(postings_e.term_freq(), i); assert_eq!(postings_e.term_freq(), i);
postings_e.positions(&mut positions); postings_e.positions(&mut positions);
assert_eq!(positions.len(), i as usize); assert_eq!(positions.len(), i as usize);
@@ -309,9 +314,8 @@ pub mod tests {
assert_eq!(positions[j], (j as u32)); assert_eq!(positions[j], (j as u32));
} }
assert_eq!(postings_e.doc(), i); assert_eq!(postings_e.doc(), i);
postings_e.advance();
} }
assert_eq!(postings_e.doc(), TERMINATED); assert!(!postings_e.advance());
} }
} }
} }
@@ -325,8 +329,16 @@ pub mod tests {
let index = Index::create_in_ram(schema); let index = Index::create_in_ram(schema);
{ {
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap(); let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
index_writer.add_document(doc!(text_field => "g b b d c g c")); {
index_writer.add_document(doc!(text_field => "g a b b a d c g c")); let mut doc = Document::default();
doc.add_text(text_field, "g b b d c g c");
index_writer.add_document(doc);
}
{
let mut doc = Document::default();
doc.add_text(text_field, "g a b b a d c g c");
index_writer.add_document(doc);
}
assert!(index_writer.commit().is_ok()); assert!(index_writer.commit().is_ok());
} }
let term_a = Term::from_field_text(text_field, "a"); let term_a = Term::from_field_text(text_field, "a");
@@ -336,6 +348,7 @@ pub mod tests {
.inverted_index(text_field) .inverted_index(text_field)
.read_postings(&term_a, IndexRecordOption::WithFreqsAndPositions) .read_postings(&term_a, IndexRecordOption::WithFreqsAndPositions)
.unwrap(); .unwrap();
assert!(postings.advance());
assert_eq!(postings.doc(), 1u32); assert_eq!(postings.doc(), 1u32);
postings.positions(&mut positions); postings.positions(&mut positions);
assert_eq!(&positions[..], &[1u32, 4]); assert_eq!(&positions[..], &[1u32, 4]);
@@ -357,8 +370,11 @@ pub mod tests {
let index = Index::create_in_ram(schema); let index = Index::create_in_ram(schema);
{ {
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap(); let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
for i in 0u64..num_docs as u64 { for i in 0..num_docs {
let doc = doc!(value_field => 2u64, value_field => i % 2u64); let mut doc = Document::default();
doc.add_u64(value_field, 2);
doc.add_u64(value_field, (i % 2) as u64);
index_writer.add_document(doc); index_writer.add_document(doc);
} }
assert!(index_writer.commit().is_ok()); assert!(index_writer.commit().is_ok());
@@ -375,10 +391,11 @@ pub mod tests {
.inverted_index(term_2.field()) .inverted_index(term_2.field())
.read_postings(&term_2, IndexRecordOption::Basic) .read_postings(&term_2, IndexRecordOption::Basic)
.unwrap(); .unwrap();
assert_eq!(segment_postings.seek(i), i);
assert_eq!(segment_postings.skip_next(i), SkipResult::Reached);
assert_eq!(segment_postings.doc(), i); assert_eq!(segment_postings.doc(), i);
assert_eq!(segment_postings.seek(j), j); assert_eq!(segment_postings.skip_next(j), SkipResult::Reached);
assert_eq!(segment_postings.doc(), j); assert_eq!(segment_postings.doc(), j);
} }
} }
@@ -390,16 +407,17 @@ pub mod tests {
.unwrap(); .unwrap();
// check that `skip_next` advances the iterator // check that `skip_next` advances the iterator
assert!(segment_postings.advance());
assert_eq!(segment_postings.doc(), 0); assert_eq!(segment_postings.doc(), 0);
assert_eq!(segment_postings.seek(1), 1); assert_eq!(segment_postings.skip_next(1), SkipResult::Reached);
assert_eq!(segment_postings.doc(), 1); assert_eq!(segment_postings.doc(), 1);
assert_eq!(segment_postings.seek(1), 1); assert_eq!(segment_postings.skip_next(1), SkipResult::OverStep);
assert_eq!(segment_postings.doc(), 1); assert_eq!(segment_postings.doc(), 2);
// check that going beyond the end is handled // check that going beyond the end is handled
assert_eq!(segment_postings.seek(num_docs), TERMINATED); assert_eq!(segment_postings.skip_next(num_docs), SkipResult::End);
} }
// check that filtering works // check that filtering works
@@ -410,7 +428,7 @@ pub mod tests {
.unwrap(); .unwrap();
for i in 0..num_docs / 2 { for i in 0..num_docs / 2 {
assert_eq!(segment_postings.seek(i * 2), i * 2); assert_eq!(segment_postings.skip_next(i * 2), SkipResult::Reached);
assert_eq!(segment_postings.doc(), i * 2); assert_eq!(segment_postings.doc(), i * 2);
} }
@@ -420,7 +438,7 @@ pub mod tests {
.unwrap(); .unwrap();
for i in 0..num_docs / 2 - 1 { for i in 0..num_docs / 2 - 1 {
assert!(segment_postings.seek(i * 2 + 1) > (i * 1) * 2); assert_eq!(segment_postings.skip_next(i * 2 + 1), SkipResult::OverStep);
assert_eq!(segment_postings.doc(), (i + 1) * 2); assert_eq!(segment_postings.doc(), (i + 1) * 2);
} }
} }
@@ -432,7 +450,6 @@ pub mod tests {
assert!(index_writer.commit().is_ok()); assert!(index_writer.commit().is_ok());
} }
let searcher = index.reader().unwrap().searcher(); let searcher = index.reader().unwrap().searcher();
assert_eq!(searcher.segment_readers().len(), 1);
let segment_reader = searcher.segment_reader(0); let segment_reader = searcher.segment_reader(0);
// make sure seeking still works // make sure seeking still works
@@ -443,11 +460,11 @@ pub mod tests {
.unwrap(); .unwrap();
if i % 2 == 0 { if i % 2 == 0 {
assert_eq!(segment_postings.seek(i), i); assert_eq!(segment_postings.skip_next(i), SkipResult::Reached);
assert_eq!(segment_postings.doc(), i); assert_eq!(segment_postings.doc(), i);
assert!(segment_reader.is_deleted(i)); assert!(segment_reader.is_deleted(i));
} else { } else {
assert_eq!(segment_postings.seek(i), i); assert_eq!(segment_postings.skip_next(i), SkipResult::Reached);
assert_eq!(segment_postings.doc(), i); assert_eq!(segment_postings.doc(), i);
} }
} }
@@ -462,16 +479,12 @@ pub mod tests {
let mut last = 2; // start from 5 to avoid seeking to 3 twice let mut last = 2; // start from 5 to avoid seeking to 3 twice
let mut cur = 3; let mut cur = 3;
loop { loop {
let seek = segment_postings.seek(cur); match segment_postings.skip_next(cur) {
if seek == TERMINATED { SkipResult::End => break,
break; SkipResult::Reached => assert_eq!(segment_postings.doc(), cur),
} SkipResult::OverStep => assert_eq!(segment_postings.doc(), cur + 1),
assert_eq!(seek, segment_postings.doc());
if seek == cur {
assert_eq!(segment_postings.doc(), cur);
} else {
assert_eq!(segment_postings.doc(), cur + 1);
} }
let next = cur + last; let next = cur + last;
last = cur; last = cur;
cur = next; cur = next;
@@ -557,7 +570,7 @@ pub mod tests {
} }
impl<TDocSet: DocSet> DocSet for UnoptimizedDocSet<TDocSet> { impl<TDocSet: DocSet> DocSet for UnoptimizedDocSet<TDocSet> {
fn advance(&mut self) -> DocId { fn advance(&mut self) -> bool {
self.0.advance() self.0.advance()
} }
@@ -583,22 +596,30 @@ pub mod tests {
for target in targets { for target in targets {
let mut postings_opt = postings_factory(); let mut postings_opt = postings_factory();
let mut postings_unopt = UnoptimizedDocSet::wrap(postings_factory()); let mut postings_unopt = UnoptimizedDocSet::wrap(postings_factory());
let skip_result_opt = postings_opt.seek(target); let skip_result_opt = postings_opt.skip_next(target);
let skip_result_unopt = postings_unopt.seek(target); let skip_result_unopt = postings_unopt.skip_next(target);
assert_eq!( assert_eq!(
skip_result_unopt, skip_result_opt, skip_result_unopt, skip_result_opt,
"Failed while skipping to {}", "Failed while skipping to {}",
target target
); );
assert!(skip_result_opt >= target); match skip_result_opt {
assert_eq!(skip_result_opt, postings_opt.doc()); SkipResult::Reached => assert_eq!(postings_opt.doc(), target),
if skip_result_opt == TERMINATED { SkipResult::OverStep => assert!(postings_opt.doc() > target),
return; SkipResult::End => {
return;
}
} }
while postings_opt.doc() != TERMINATED { while postings_opt.advance() {
assert_eq!(postings_opt.doc(), postings_unopt.doc()); assert!(postings_unopt.advance());
assert_eq!(postings_opt.advance(), postings_unopt.advance()); assert_eq!(
postings_opt.doc(),
postings_unopt.doc(),
"Failed while skipping to {}",
target
);
} }
assert!(!postings_unopt.advance());
} }
} }
} }
@@ -607,7 +628,7 @@ pub mod tests {
mod bench { mod bench {
use super::tests::*; use super::tests::*;
use crate::docset::TERMINATED; use crate::docset::SkipResult;
use crate::query::Intersection; use crate::query::Intersection;
use crate::schema::IndexRecordOption; use crate::schema::IndexRecordOption;
use crate::tests; use crate::tests;
@@ -625,7 +646,7 @@ mod bench {
.inverted_index(TERM_A.field()) .inverted_index(TERM_A.field())
.read_postings(&*TERM_A, IndexRecordOption::Basic) .read_postings(&*TERM_A, IndexRecordOption::Basic)
.unwrap(); .unwrap();
while segment_postings.advance() != TERMINATED {} while segment_postings.advance() {}
}); });
} }
@@ -657,7 +678,7 @@ mod bench {
segment_postings_c, segment_postings_c,
segment_postings_d, segment_postings_d,
]); ]);
while intersection.advance() != TERMINATED {} while intersection.advance() {}
}); });
} }
@@ -673,10 +694,11 @@ mod bench {
.unwrap(); .unwrap();
let mut existing_docs = Vec::new(); let mut existing_docs = Vec::new();
segment_postings.advance();
for doc in &docs { for doc in &docs {
if *doc >= segment_postings.doc() { if *doc >= segment_postings.doc() {
existing_docs.push(*doc); existing_docs.push(*doc);
if segment_postings.seek(*doc) == TERMINATED { if segment_postings.skip_next(*doc) == SkipResult::End {
break; break;
} }
} }
@@ -688,7 +710,7 @@ mod bench {
.read_postings(&*TERM_A, IndexRecordOption::Basic) .read_postings(&*TERM_A, IndexRecordOption::Basic)
.unwrap(); .unwrap();
for doc in &existing_docs { for doc in &existing_docs {
if segment_postings.seek(*doc) == TERMINATED { if segment_postings.skip_next(*doc) == SkipResult::End {
break; break;
} }
} }
@@ -727,9 +749,8 @@ mod bench {
.read_postings(&*TERM_A, IndexRecordOption::Basic) .read_postings(&*TERM_A, IndexRecordOption::Basic)
.unwrap(); .unwrap();
let mut s = 0u32; let mut s = 0u32;
while segment_postings.doc() != TERMINATED { while segment_postings.advance() {
s += (segment_postings.doc() & n) % 1024; s += (segment_postings.doc() & n) % 1024;
segment_postings.advance()
} }
s s
}); });

View File

@@ -11,6 +11,7 @@ use crate::termdict::TermOrdinal;
use crate::tokenizer::TokenStream; use crate::tokenizer::TokenStream;
use crate::tokenizer::{Token, MAX_TOKEN_LEN}; use crate::tokenizer::{Token, MAX_TOKEN_LEN};
use crate::DocId; use crate::DocId;
use crate::Result;
use fnv::FnvHashMap; use fnv::FnvHashMap;
use std::collections::HashMap; use std::collections::HashMap;
use std::io; use std::io;
@@ -128,7 +129,7 @@ impl MultiFieldPostingsWriter {
pub fn serialize( pub fn serialize(
&self, &self,
serializer: &mut InvertedIndexSerializer, serializer: &mut InvertedIndexSerializer,
) -> crate::Result<HashMap<Field, FnvHashMap<UnorderedTermId, TermOrdinal>>> { ) -> Result<HashMap<Field, FnvHashMap<UnorderedTermId, TermOrdinal>>> {
let mut term_offsets: Vec<(&[u8], Addr, UnorderedTermId)> = let mut term_offsets: Vec<(&[u8], Addr, UnorderedTermId)> =
self.term_index.iter().collect(); self.term_index.iter().collect();
term_offsets.sort_unstable_by_key(|&(k, _, _)| k); term_offsets.sort_unstable_by_key(|&(k, _, _)| k);

View File

@@ -1,19 +1,56 @@
use crate::common::BitSet;
use crate::common::HasLen; use crate::common::HasLen;
use crate::common::{BinarySerializable, VInt};
use crate::docset::DocSet; use crate::docset::{DocSet, SkipResult};
use crate::positions::PositionReader; use crate::positions::PositionReader;
use crate::postings::compression::{compressed_block_size, AlignedBuffer};
use crate::postings::compression::COMPRESSION_BLOCK_SIZE; use crate::postings::compression::{BlockDecoder, VIntDecoder, COMPRESSION_BLOCK_SIZE};
use crate::postings::serializer::PostingsSerializer; use crate::postings::serializer::PostingsSerializer;
use crate::postings::BlockSearcher; use crate::postings::BlockSearcher;
use crate::postings::FreqReadingOption;
use crate::postings::Postings; use crate::postings::Postings;
use crate::postings::SkipReader;
use crate::postings::USE_SKIP_INFO_LIMIT;
use crate::schema::IndexRecordOption; use crate::schema::IndexRecordOption;
use crate::DocId; use crate::DocId;
use owned_read::OwnedRead;
use std::cmp::Ordering;
use tantivy_fst::Streamer;
use crate::directory::ReadOnlySource; struct PositionComputer {
use crate::postings::BlockSegmentPostings; // store the amount of position int
// before reading positions.
//
// if none, position are already loaded in
// the positions vec.
position_to_skip: usize,
position_reader: PositionReader,
}
impl PositionComputer {
pub fn new(position_reader: PositionReader) -> PositionComputer {
PositionComputer {
position_to_skip: 0,
position_reader,
}
}
pub fn add_skip(&mut self, num_skip: usize) {
self.position_to_skip += num_skip;
}
// Positions can only be read once.
pub fn positions_with_offset(&mut self, offset: u32, output: &mut [u32]) {
self.position_reader.skip(self.position_to_skip);
self.position_to_skip = 0;
self.position_reader.read(output);
let mut cum = offset;
for output_mut in output.iter_mut() {
cum += *output_mut;
*output_mut = cum;
}
}
}
/// `SegmentPostings` represents the inverted list or postings associated to /// `SegmentPostings` represents the inverted list or postings associated to
/// a term in a `Segment`. /// a term in a `Segment`.
@@ -23,17 +60,18 @@ use crate::postings::BlockSegmentPostings;
pub struct SegmentPostings { pub struct SegmentPostings {
block_cursor: BlockSegmentPostings, block_cursor: BlockSegmentPostings,
cur: usize, cur: usize,
position_reader: Option<PositionReader>, position_computer: Option<PositionComputer>,
block_searcher: BlockSearcher, block_searcher: BlockSearcher,
} }
impl SegmentPostings { impl SegmentPostings {
/// Returns an empty segment postings object /// Returns an empty segment postings object
pub fn empty() -> Self { pub fn empty() -> Self {
let empty_block_cursor = BlockSegmentPostings::empty();
SegmentPostings { SegmentPostings {
block_cursor: BlockSegmentPostings::empty(), block_cursor: empty_block_cursor,
cur: 0, cur: COMPRESSION_BLOCK_SIZE,
position_reader: None, position_computer: None,
block_searcher: BlockSearcher::default(), block_searcher: BlockSearcher::default(),
} }
} }
@@ -59,13 +97,15 @@ impl SegmentPostings {
} }
let block_segment_postings = BlockSegmentPostings::from_data( let block_segment_postings = BlockSegmentPostings::from_data(
docs.len() as u32, docs.len() as u32,
ReadOnlySource::from(buffer), OwnedRead::new(buffer),
IndexRecordOption::Basic, IndexRecordOption::Basic,
IndexRecordOption::Basic, IndexRecordOption::Basic,
); );
SegmentPostings::from_block_postings(block_segment_postings, None) SegmentPostings::from_block_postings(block_segment_postings, None)
} }
}
impl SegmentPostings {
/// Reads a Segment postings from an &[u8] /// Reads a Segment postings from an &[u8]
/// ///
/// * `len` - number of document in the posting lists. /// * `len` - number of document in the posting lists.
@@ -74,12 +114,12 @@ impl SegmentPostings {
/// frequencies and/or positions /// frequencies and/or positions
pub(crate) fn from_block_postings( pub(crate) fn from_block_postings(
segment_block_postings: BlockSegmentPostings, segment_block_postings: BlockSegmentPostings,
position_reader: Option<PositionReader>, positions_stream_opt: Option<PositionReader>,
) -> SegmentPostings { ) -> SegmentPostings {
SegmentPostings { SegmentPostings {
block_cursor: segment_block_postings, block_cursor: segment_block_postings,
cur: 0, // cursor within the block cur: COMPRESSION_BLOCK_SIZE, // cursor within the block
position_reader, position_computer: positions_stream_opt.map(PositionComputer::new),
block_searcher: BlockSearcher::default(), block_searcher: BlockSearcher::default(),
} }
} }
@@ -89,52 +129,134 @@ impl DocSet for SegmentPostings {
// goes to the next element. // goes to the next element.
// next needs to be called a first time to point to the correct element. // next needs to be called a first time to point to the correct element.
#[inline] #[inline]
fn advance(&mut self) -> DocId { fn advance(&mut self) -> bool {
if self.cur == COMPRESSION_BLOCK_SIZE - 1 { if self.position_computer.is_some() && self.cur < COMPRESSION_BLOCK_SIZE {
self.cur = 0; let term_freq = self.term_freq() as usize;
self.block_cursor.advance(); if let Some(position_computer) = self.position_computer.as_mut() {
} else { position_computer.add_skip(term_freq);
self.cur += 1; }
} }
self.doc() self.cur += 1;
if self.cur >= self.block_cursor.block_len() {
self.cur = 0;
if !self.block_cursor.advance() {
self.cur = COMPRESSION_BLOCK_SIZE;
return false;
}
}
true
} }
fn seek(&mut self, target: DocId) -> DocId { fn skip_next(&mut self, target: DocId) -> SkipResult {
if self.doc() == target { if !self.advance() {
return target; return SkipResult::End;
}
match self.doc().cmp(&target) {
Ordering::Equal => {
return SkipResult::Reached;
}
Ordering::Greater => {
return SkipResult::OverStep;
}
_ => {
// ...
}
} }
self.block_cursor.seek(target);
// At this point we are on the block, that might contain our document. // In the following, thanks to the call to advance above,
let output = self.block_cursor.docs_aligned(); // we know that the position is not loaded and we need
// to skip every doc_freq we cross.
self.cur = self.block_searcher.search_in_block(&output, target); // skip blocks until one that might contain the target
// check if we need to go to the next block
let mut sum_freqs_skipped: u32 = 0;
if !self
.block_cursor
.docs()
.last()
.map(|doc| *doc >= target)
.unwrap_or(false)
// there should always be at least a document in the block
// since advance returned.
{
// we are not in the right block.
//
// First compute all of the freqs skipped from the current block.
if self.position_computer.is_some() {
sum_freqs_skipped = self.block_cursor.freqs()[self.cur..].iter().sum();
match self.block_cursor.skip_to(target) {
BlockSegmentPostingsSkipResult::Success(block_skip_freqs) => {
sum_freqs_skipped += block_skip_freqs;
}
BlockSegmentPostingsSkipResult::Terminated => {
return SkipResult::End;
}
}
} else if self.block_cursor.skip_to(target)
== BlockSegmentPostingsSkipResult::Terminated
{
// no positions needed. no need to sum freqs.
return SkipResult::End;
}
self.cur = 0;
}
// The last block is not full and padded with the value TERMINATED, let cur = self.cur;
// so that we are guaranteed to have at least doc in the block (a real one or the padding)
// that is greater or equal to the target. // we're in the right block now, start with an exponential search
debug_assert!(self.cur < COMPRESSION_BLOCK_SIZE); let (output, len) = self.block_cursor.docs_aligned();
let new_cur = self
.block_searcher
.search_in_block(&output, len, cur, target);
if let Some(position_computer) = self.position_computer.as_mut() {
sum_freqs_skipped += self.block_cursor.freqs()[cur..new_cur].iter().sum::<u32>();
position_computer.add_skip(sum_freqs_skipped as usize);
}
self.cur = new_cur;
// `doc` is now the first element >= `target` // `doc` is now the first element >= `target`
let doc = output.0[new_cur];
// If all docs are smaller than target the current block should be incomplemented and padded
// with the value `TERMINATED`.
//
// After the search, the cursor should point to the first value of TERMINATED.
let doc = output.0[self.cur];
debug_assert!(doc >= target); debug_assert!(doc >= target);
doc if doc == target {
SkipResult::Reached
} else {
SkipResult::OverStep
}
} }
/// Return the current document's `DocId`. /// Return the current document's `DocId`.
///
/// # Panics
///
/// Will panics if called without having called advance before.
#[inline] #[inline]
fn doc(&self) -> DocId { fn doc(&self) -> DocId {
self.block_cursor.doc(self.cur) let docs = self.block_cursor.docs();
debug_assert!(
self.cur < docs.len(),
"Have you forgotten to call `.advance()` at least once before calling `.doc()` ."
);
docs[self.cur]
} }
fn size_hint(&self) -> u32 { fn size_hint(&self) -> u32 {
self.len() as u32 self.len() as u32
} }
fn append_to_bitset(&mut self, bitset: &mut BitSet) {
// finish the current block
if self.advance() {
for &doc in &self.block_cursor.docs()[self.cur..] {
bitset.insert(doc);
}
// ... iterate through the remaining blocks.
while self.block_cursor.advance() {
for &doc in self.block_cursor.docs() {
bitset.insert(doc);
}
}
}
}
} }
impl HasLen for SegmentPostings { impl HasLen for SegmentPostings {
@@ -168,52 +290,515 @@ impl Postings for SegmentPostings {
fn positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) { fn positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
let term_freq = self.term_freq() as usize; let term_freq = self.term_freq() as usize;
if let Some(position_reader) = self.position_reader.as_mut() { if let Some(position_comp) = self.position_computer.as_mut() {
let read_offset = self.block_cursor.position_offset()
+ (self.block_cursor.freqs()[..self.cur]
.iter()
.cloned()
.sum::<u32>() as u64);
output.resize(term_freq, 0u32); output.resize(term_freq, 0u32);
position_reader.read(read_offset, &mut output[..]); position_comp.positions_with_offset(offset, &mut output[..]);
let mut cum = offset;
for output_mut in output.iter_mut() {
cum += *output_mut;
*output_mut = cum;
}
} else { } else {
output.clear(); output.clear();
} }
} }
} }
/// `BlockSegmentPostings` is a cursor iterating over blocks
/// of documents.
///
/// # Warning
///
/// While it is useful for some very specific high-performance
/// use cases, you should prefer using `SegmentPostings` for most usage.
pub struct BlockSegmentPostings {
doc_decoder: BlockDecoder,
freq_decoder: BlockDecoder,
freq_reading_option: FreqReadingOption,
doc_freq: usize,
doc_offset: DocId,
num_vint_docs: usize,
remaining_data: OwnedRead,
skip_reader: SkipReader,
}
fn split_into_skips_and_postings(
doc_freq: u32,
mut data: OwnedRead,
) -> (Option<OwnedRead>, OwnedRead) {
if doc_freq >= USE_SKIP_INFO_LIMIT {
let skip_len = VInt::deserialize(&mut data).expect("Data corrupted").0 as usize;
let mut postings_data = data.clone();
postings_data.advance(skip_len);
data.clip(skip_len);
(Some(data), postings_data)
} else {
(None, data)
}
}
#[derive(Debug, Eq, PartialEq)]
pub enum BlockSegmentPostingsSkipResult {
Terminated,
Success(u32), //< number of term freqs to skip
}
impl BlockSegmentPostings {
pub(crate) fn from_data(
doc_freq: u32,
data: OwnedRead,
record_option: IndexRecordOption,
requested_option: IndexRecordOption,
) -> BlockSegmentPostings {
let freq_reading_option = match (record_option, requested_option) {
(IndexRecordOption::Basic, _) => FreqReadingOption::NoFreq,
(_, IndexRecordOption::Basic) => FreqReadingOption::SkipFreq,
(_, _) => FreqReadingOption::ReadFreq,
};
let (skip_data_opt, postings_data) = split_into_skips_and_postings(doc_freq, data);
let skip_reader = match skip_data_opt {
Some(skip_data) => SkipReader::new(skip_data, record_option),
None => SkipReader::new(OwnedRead::new(&[][..]), record_option),
};
let doc_freq = doc_freq as usize;
let num_vint_docs = doc_freq % COMPRESSION_BLOCK_SIZE;
BlockSegmentPostings {
num_vint_docs,
doc_decoder: BlockDecoder::new(),
freq_decoder: BlockDecoder::with_val(1),
freq_reading_option,
doc_offset: 0,
doc_freq,
remaining_data: postings_data,
skip_reader,
}
}
// Resets the block segment postings on another position
// in the postings file.
//
// This is useful for enumerating through a list of terms,
// and consuming the associated posting lists while avoiding
// reallocating a `BlockSegmentPostings`.
//
// # Warning
//
// This does not reset the positions list.
pub(crate) fn reset(&mut self, doc_freq: u32, postings_data: OwnedRead) {
let (skip_data_opt, postings_data) = split_into_skips_and_postings(doc_freq, postings_data);
let num_vint_docs = (doc_freq as usize) & (COMPRESSION_BLOCK_SIZE - 1);
self.num_vint_docs = num_vint_docs;
self.remaining_data = postings_data;
if let Some(skip_data) = skip_data_opt {
self.skip_reader.reset(skip_data);
} else {
self.skip_reader.reset(OwnedRead::new(&[][..]))
}
self.doc_offset = 0;
self.doc_freq = doc_freq as usize;
}
/// Returns the document frequency associated to this block postings.
///
/// This `doc_freq` is simply the sum of the length of all of the blocks
/// length, and it does not take in account deleted documents.
pub fn doc_freq(&self) -> usize {
self.doc_freq
}
/// Returns the array of docs in the current block.
///
/// Before the first call to `.advance()`, the block
/// returned by `.docs()` is empty.
#[inline]
pub fn docs(&self) -> &[DocId] {
self.doc_decoder.output_array()
}
pub(crate) fn docs_aligned(&self) -> (&AlignedBuffer, usize) {
self.doc_decoder.output_aligned()
}
/// Return the document at index `idx` of the block.
#[inline]
pub fn doc(&self, idx: usize) -> u32 {
self.doc_decoder.output(idx)
}
/// Return the array of `term freq` in the block.
#[inline]
pub fn freqs(&self) -> &[u32] {
self.freq_decoder.output_array()
}
/// Return the frequency at index `idx` of the block.
#[inline]
pub fn freq(&self, idx: usize) -> u32 {
self.freq_decoder.output(idx)
}
/// Returns the length of the current block.
///
/// All blocks have a length of `NUM_DOCS_PER_BLOCK`,
/// except the last block that may have a length
/// of any number between 1 and `NUM_DOCS_PER_BLOCK - 1`
#[inline]
fn block_len(&self) -> usize {
self.doc_decoder.output_len
}
/// position on a block that may contains `doc_id`.
/// Always advance the current block.
///
/// Returns true if a block that has an element greater or equal to the target is found.
/// Returning true does not guarantee that the smallest element of the block is smaller
/// than the target. It only guarantees that the last element is greater or equal.
///
/// Returns false iff all of the document remaining are smaller than
/// `doc_id`. In that case, all of these document are consumed.
///
pub fn skip_to(&mut self, target_doc: DocId) -> BlockSegmentPostingsSkipResult {
let mut skip_freqs = 0u32;
while self.skip_reader.advance() {
if self.skip_reader.doc() >= target_doc {
// the last document of the current block is larger
// than the target.
//
// We found our block!
let num_bits = self.skip_reader.doc_num_bits();
let num_consumed_bytes = self.doc_decoder.uncompress_block_sorted(
self.remaining_data.as_ref(),
self.doc_offset,
num_bits,
);
self.remaining_data.advance(num_consumed_bytes);
let tf_num_bits = self.skip_reader.tf_num_bits();
match self.freq_reading_option {
FreqReadingOption::NoFreq => {}
FreqReadingOption::SkipFreq => {
let num_bytes_to_skip = compressed_block_size(tf_num_bits);
self.remaining_data.advance(num_bytes_to_skip);
}
FreqReadingOption::ReadFreq => {
let num_consumed_bytes = self
.freq_decoder
.uncompress_block_unsorted(self.remaining_data.as_ref(), tf_num_bits);
self.remaining_data.advance(num_consumed_bytes);
}
}
self.doc_offset = self.skip_reader.doc();
return BlockSegmentPostingsSkipResult::Success(skip_freqs);
} else {
skip_freqs += self.skip_reader.tf_sum();
let advance_len = self.skip_reader.total_block_len();
self.doc_offset = self.skip_reader.doc();
self.remaining_data.advance(advance_len);
}
}
// we are now on the last, incomplete, variable encoded block.
if self.num_vint_docs > 0 {
let num_compressed_bytes = self.doc_decoder.uncompress_vint_sorted(
self.remaining_data.as_ref(),
self.doc_offset,
self.num_vint_docs,
);
self.remaining_data.advance(num_compressed_bytes);
match self.freq_reading_option {
FreqReadingOption::NoFreq | FreqReadingOption::SkipFreq => {}
FreqReadingOption::ReadFreq => {
self.freq_decoder
.uncompress_vint_unsorted(self.remaining_data.as_ref(), self.num_vint_docs);
}
}
self.num_vint_docs = 0;
return self
.docs()
.last()
.map(|last_doc| {
if *last_doc >= target_doc {
BlockSegmentPostingsSkipResult::Success(skip_freqs)
} else {
BlockSegmentPostingsSkipResult::Terminated
}
})
.unwrap_or(BlockSegmentPostingsSkipResult::Terminated);
}
BlockSegmentPostingsSkipResult::Terminated
}
/// Advance to the next block.
///
/// Returns false iff there was no remaining blocks.
pub fn advance(&mut self) -> bool {
if self.skip_reader.advance() {
let num_bits = self.skip_reader.doc_num_bits();
let num_consumed_bytes = self.doc_decoder.uncompress_block_sorted(
self.remaining_data.as_ref(),
self.doc_offset,
num_bits,
);
self.remaining_data.advance(num_consumed_bytes);
let tf_num_bits = self.skip_reader.tf_num_bits();
match self.freq_reading_option {
FreqReadingOption::NoFreq => {}
FreqReadingOption::SkipFreq => {
let num_bytes_to_skip = compressed_block_size(tf_num_bits);
self.remaining_data.advance(num_bytes_to_skip);
}
FreqReadingOption::ReadFreq => {
let num_consumed_bytes = self
.freq_decoder
.uncompress_block_unsorted(self.remaining_data.as_ref(), tf_num_bits);
self.remaining_data.advance(num_consumed_bytes);
}
}
// it will be used as the next offset.
self.doc_offset = self.doc_decoder.output(COMPRESSION_BLOCK_SIZE - 1);
true
} else if self.num_vint_docs > 0 {
let num_compressed_bytes = self.doc_decoder.uncompress_vint_sorted(
self.remaining_data.as_ref(),
self.doc_offset,
self.num_vint_docs,
);
self.remaining_data.advance(num_compressed_bytes);
match self.freq_reading_option {
FreqReadingOption::NoFreq | FreqReadingOption::SkipFreq => {}
FreqReadingOption::ReadFreq => {
self.freq_decoder
.uncompress_vint_unsorted(self.remaining_data.as_ref(), self.num_vint_docs);
}
}
self.num_vint_docs = 0;
true
} else {
false
}
}
/// Returns an empty segment postings object
pub fn empty() -> BlockSegmentPostings {
BlockSegmentPostings {
num_vint_docs: 0,
doc_decoder: BlockDecoder::new(),
freq_decoder: BlockDecoder::with_val(1),
freq_reading_option: FreqReadingOption::NoFreq,
doc_offset: 0,
doc_freq: 0,
remaining_data: OwnedRead::new(vec![]),
skip_reader: SkipReader::new(OwnedRead::new(vec![]), IndexRecordOption::Basic),
}
}
}
impl<'b> Streamer<'b> for BlockSegmentPostings {
type Item = &'b [DocId];
fn next(&'b mut self) -> Option<&'b [DocId]> {
if self.advance() {
Some(self.docs())
} else {
None
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::BlockSegmentPostings;
use super::BlockSegmentPostingsSkipResult;
use super::SegmentPostings; use super::SegmentPostings;
use crate::common::HasLen; use crate::common::HasLen;
use crate::core::Index;
use crate::docset::{DocSet, TERMINATED}; use crate::docset::DocSet;
use crate::postings::postings::Postings; use crate::postings::postings::Postings;
use crate::schema::IndexRecordOption;
use crate::schema::Schema;
use crate::schema::Term;
use crate::schema::INDEXED;
use crate::DocId;
use crate::SkipResult;
use tantivy_fst::Streamer;
#[test] #[test]
fn test_empty_segment_postings() { fn test_empty_segment_postings() {
let mut postings = SegmentPostings::empty(); let mut postings = SegmentPostings::empty();
assert_eq!(postings.advance(), TERMINATED); assert!(!postings.advance());
assert_eq!(postings.advance(), TERMINATED); assert!(!postings.advance());
assert_eq!(postings.len(), 0); assert_eq!(postings.len(), 0);
} }
#[test] #[test]
fn test_empty_postings_doc_returns_terminated() { #[should_panic(expected = "Have you forgotten to call `.advance()`")]
let mut postings = SegmentPostings::empty(); fn test_panic_if_doc_called_before_advance() {
assert_eq!(postings.doc(), TERMINATED); SegmentPostings::empty().doc();
assert_eq!(postings.advance(), TERMINATED);
} }
#[test] #[test]
fn test_empty_postings_doc_term_freq_returns_0() { #[should_panic(expected = "Have you forgotten to call `.advance()`")]
let postings = SegmentPostings::empty(); fn test_panic_if_freq_called_before_advance() {
assert_eq!(postings.term_freq(), 1); SegmentPostings::empty().term_freq();
}
#[test]
fn test_empty_block_segment_postings() {
let mut postings = BlockSegmentPostings::empty();
assert!(!postings.advance());
assert_eq!(postings.doc_freq(), 0);
}
#[test]
fn test_block_segment_postings() {
let mut block_segments = build_block_postings(&(0..100_000).collect::<Vec<u32>>());
let mut offset: u32 = 0u32;
// checking that the block before calling advance is empty
assert!(block_segments.docs().is_empty());
// checking that the `doc_freq` is correct
assert_eq!(block_segments.doc_freq(), 100_000);
while let Some(block) = block_segments.next() {
for (i, doc) in block.iter().cloned().enumerate() {
assert_eq!(offset + (i as u32), doc);
}
offset += block.len() as u32;
}
}
#[test]
fn test_skip_right_at_new_block() {
let mut doc_ids = (0..128).collect::<Vec<u32>>();
doc_ids.push(129);
doc_ids.push(130);
{
let block_segments = build_block_postings(&doc_ids);
let mut docset = SegmentPostings::from_block_postings(block_segments, None);
assert_eq!(docset.skip_next(128), SkipResult::OverStep);
assert_eq!(docset.doc(), 129);
assert!(docset.advance());
assert_eq!(docset.doc(), 130);
assert!(!docset.advance());
}
{
let block_segments = build_block_postings(&doc_ids);
let mut docset = SegmentPostings::from_block_postings(block_segments, None);
assert_eq!(docset.skip_next(129), SkipResult::Reached);
assert_eq!(docset.doc(), 129);
assert!(docset.advance());
assert_eq!(docset.doc(), 130);
assert!(!docset.advance());
}
{
let block_segments = build_block_postings(&doc_ids);
let mut docset = SegmentPostings::from_block_postings(block_segments, None);
assert_eq!(docset.skip_next(131), SkipResult::End);
}
}
fn build_block_postings(docs: &[DocId]) -> BlockSegmentPostings {
let mut schema_builder = Schema::builder();
let int_field = schema_builder.add_u64_field("id", INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
let mut last_doc = 0u32;
for &doc in docs {
for _ in last_doc..doc {
index_writer.add_document(doc!(int_field=>1u64));
}
index_writer.add_document(doc!(int_field=>0u64));
last_doc = doc + 1;
}
index_writer.commit().unwrap();
let searcher = index.reader().unwrap().searcher();
let segment_reader = searcher.segment_reader(0);
let inverted_index = segment_reader.inverted_index(int_field);
let term = Term::from_field_u64(int_field, 0u64);
let term_info = inverted_index.get_term_info(&term).unwrap();
inverted_index.read_block_postings_from_terminfo(&term_info, IndexRecordOption::Basic)
}
#[test]
fn test_block_segment_postings_skip() {
for i in 0..4 {
let mut block_postings = build_block_postings(&[3]);
assert_eq!(
block_postings.skip_to(i),
BlockSegmentPostingsSkipResult::Success(0u32)
);
assert_eq!(
block_postings.skip_to(i),
BlockSegmentPostingsSkipResult::Terminated
);
}
let mut block_postings = build_block_postings(&[3]);
assert_eq!(
block_postings.skip_to(4u32),
BlockSegmentPostingsSkipResult::Terminated
);
}
#[test]
fn test_block_segment_postings_skip2() {
let mut docs = vec![0];
for i in 0..1300 {
docs.push((i * i / 100) + i);
}
let mut block_postings = build_block_postings(&docs[..]);
for i in vec![0, 424, 10000] {
assert_eq!(
block_postings.skip_to(i),
BlockSegmentPostingsSkipResult::Success(0u32)
);
let docs = block_postings.docs();
assert!(docs[0] <= i);
assert!(docs.last().cloned().unwrap_or(0u32) >= i);
}
assert_eq!(
block_postings.skip_to(100_000),
BlockSegmentPostingsSkipResult::Terminated
);
assert_eq!(
block_postings.skip_to(101_000),
BlockSegmentPostingsSkipResult::Terminated
);
}
#[test]
fn test_reset_block_segment_postings() {
let mut schema_builder = Schema::builder();
let int_field = schema_builder.add_u64_field("id", INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
// create two postings list, one containg even number,
// the other containing odd numbers.
for i in 0..6 {
let doc = doc!(int_field=> (i % 2) as u64);
index_writer.add_document(doc);
}
index_writer.commit().unwrap();
let searcher = index.reader().unwrap().searcher();
let segment_reader = searcher.segment_reader(0);
let mut block_segments;
{
let term = Term::from_field_u64(int_field, 0u64);
let inverted_index = segment_reader.inverted_index(int_field);
let term_info = inverted_index.get_term_info(&term).unwrap();
block_segments = inverted_index
.read_block_postings_from_terminfo(&term_info, IndexRecordOption::Basic);
}
assert!(block_segments.advance());
assert_eq!(block_segments.docs(), &[0, 2, 4]);
{
let term = Term::from_field_u64(int_field, 1u64);
let inverted_index = segment_reader.inverted_index(int_field);
let term_info = inverted_index.get_term_info(&term).unwrap();
inverted_index.reset_block_postings_from_terminfo(&term_info, &mut block_segments);
}
assert!(block_segments.advance());
assert_eq!(block_segments.docs(), &[1, 3, 5]);
} }
} }

View File

@@ -6,10 +6,12 @@ use crate::directory::WritePtr;
use crate::positions::PositionSerializer; use crate::positions::PositionSerializer;
use crate::postings::compression::{BlockEncoder, VIntEncoder, COMPRESSION_BLOCK_SIZE}; use crate::postings::compression::{BlockEncoder, VIntEncoder, COMPRESSION_BLOCK_SIZE};
use crate::postings::skip::SkipSerializer; use crate::postings::skip::SkipSerializer;
use crate::postings::USE_SKIP_INFO_LIMIT;
use crate::schema::Schema; use crate::schema::Schema;
use crate::schema::{Field, FieldEntry, FieldType}; use crate::schema::{Field, FieldEntry, FieldType};
use crate::termdict::{TermDictionaryBuilder, TermOrdinal}; use crate::termdict::{TermDictionaryBuilder, TermOrdinal};
use crate::DocId; use crate::Result;
use crate::{Directory, DocId};
use std::io::{self, Write}; use std::io::{self, Write};
/// `InvertedIndexSerializer` is in charge of serializing /// `InvertedIndexSerializer` is in charge of serializing
@@ -52,33 +54,36 @@ pub struct InvertedIndexSerializer {
} }
impl InvertedIndexSerializer { impl InvertedIndexSerializer {
/// Open a new `InvertedIndexSerializer` for the given segment pub(crate) fn for_segment(segment: &mut Segment) -> crate::Result<Self> {
fn create( let schema = segment.schema();
terms_write: CompositeWrite<WritePtr>, use crate::core::SegmentComponent;
postings_write: CompositeWrite<WritePtr>, let terms_wrt = segment.open_write(SegmentComponent::TERMS)?;
positions_write: CompositeWrite<WritePtr>, let postings_wrt = segment.open_write(SegmentComponent::POSTINGS)?;
positionsidx_write: CompositeWrite<WritePtr>, let positions_wrt = segment.open_write(SegmentComponent::POSITIONS)?;
schema: Schema, let positions_idx_wrt = segment.open_write(SegmentComponent::POSITIONSSKIP)?;
) -> crate::Result<InvertedIndexSerializer> { Ok(Self::open(
Ok(InvertedIndexSerializer {
terms_write,
postings_write,
positions_write,
positionsidx_write,
schema, schema,
}) terms_wrt,
postings_wrt,
positions_wrt,
positions_idx_wrt,
))
} }
/// Open a new `PostingsSerializer` for the given segment /// Open a new `PostingsSerializer` for the given segment
pub fn open(segment: &mut Segment) -> crate::Result<InvertedIndexSerializer> { pub(crate) fn open(
use crate::SegmentComponent::{POSITIONS, POSITIONSSKIP, POSTINGS, TERMS}; schema: Schema,
InvertedIndexSerializer::create( terms_wrt: WritePtr,
CompositeWrite::wrap(segment.open_write(TERMS)?), postings_wrt: WritePtr,
CompositeWrite::wrap(segment.open_write(POSTINGS)?), positions_wrt: WritePtr,
CompositeWrite::wrap(segment.open_write(POSITIONS)?), positions_idx_wrt: WritePtr,
CompositeWrite::wrap(segment.open_write(POSITIONSSKIP)?), ) -> InvertedIndexSerializer {
segment.schema(), InvertedIndexSerializer {
) terms_write: CompositeWrite::wrap(terms_wrt),
postings_write: CompositeWrite::wrap(postings_wrt),
positions_write: CompositeWrite::wrap(positions_wrt),
positionsidx_write: CompositeWrite::wrap(positions_idx_wrt),
schema,
}
} }
/// Must be called before starting pushing terms of /// Must be called before starting pushing terms of
@@ -146,7 +151,8 @@ impl<'a> FieldSerializer<'a> {
} }
_ => (false, false), _ => (false, false),
}; };
let term_dictionary_builder = TermDictionaryBuilder::create(term_dictionary_write)?; let term_dictionary_builder =
TermDictionaryBuilder::create(term_dictionary_write, &field_type)?;
let postings_serializer = let postings_serializer =
PostingsSerializer::new(postings_write, term_freq_enabled, position_enabled); PostingsSerializer::new(postings_write, term_freq_enabled, position_enabled);
let positions_serializer_opt = if position_enabled { let positions_serializer_opt = if position_enabled {
@@ -390,7 +396,7 @@ impl<W: Write> PostingsSerializer<W> {
} }
self.block.clear(); self.block.clear();
} }
if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 { if doc_freq >= USE_SKIP_INFO_LIMIT {
let skip_data = self.skip_write.data(); let skip_data = self.skip_write.data();
VInt(skip_data.len() as u64).serialize(&mut self.output_write)?; VInt(skip_data.len() as u64).serialize(&mut self.output_write)?;
self.output_write.write_all(skip_data)?; self.output_write.write_all(skip_data)?;

View File

@@ -1,8 +1,7 @@
use crate::common::BinarySerializable; use crate::common::BinarySerializable;
use crate::directory::ReadOnlySource; use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::postings::compression::{compressed_block_size, COMPRESSION_BLOCK_SIZE};
use crate::schema::IndexRecordOption; use crate::schema::IndexRecordOption;
use crate::{DocId, TERMINATED}; use crate::DocId;
use owned_read::OwnedRead; use owned_read::OwnedRead;
pub struct SkipSerializer { pub struct SkipSerializer {
@@ -51,143 +50,80 @@ impl SkipSerializer {
} }
pub(crate) struct SkipReader { pub(crate) struct SkipReader {
last_doc_in_block: DocId, doc: DocId,
pub(crate) last_doc_in_previous_block: DocId,
owned_read: OwnedRead, owned_read: OwnedRead,
doc_num_bits: u8,
tf_num_bits: u8,
tf_sum: u32,
skip_info: IndexRecordOption, skip_info: IndexRecordOption,
byte_offset: usize,
remaining_docs: u32, // number of docs remaining, including the
// documents in the current block.
block_info: BlockInfo,
position_offset: u64,
}
#[derive(Clone, Eq, PartialEq, Copy, Debug)]
pub(crate) enum BlockInfo {
BitPacked {
doc_num_bits: u8,
tf_num_bits: u8,
tf_sum: u32,
},
VInt(u32),
}
impl Default for BlockInfo {
fn default() -> Self {
BlockInfo::VInt(0)
}
} }
impl SkipReader { impl SkipReader {
pub fn new(data: ReadOnlySource, doc_freq: u32, skip_info: IndexRecordOption) -> SkipReader { pub fn new(data: OwnedRead, skip_info: IndexRecordOption) -> SkipReader {
SkipReader { SkipReader {
last_doc_in_block: 0u32, doc: 0u32,
last_doc_in_previous_block: 0u32, owned_read: data,
owned_read: OwnedRead::new(data),
skip_info, skip_info,
block_info: BlockInfo::default(), doc_num_bits: 0u8,
byte_offset: 0, tf_num_bits: 0u8,
remaining_docs: doc_freq, tf_sum: 0u32,
position_offset: 0u64,
} }
} }
pub fn reset(&mut self, data: ReadOnlySource, doc_freq: u32) { pub fn reset(&mut self, data: OwnedRead) {
self.last_doc_in_block = 0u32; self.doc = 0u32;
self.last_doc_in_previous_block = 0u32; self.owned_read = data;
self.owned_read = OwnedRead::new(data); self.doc_num_bits = 0u8;
self.block_info = BlockInfo::default(); self.tf_num_bits = 0u8;
self.byte_offset = 0; self.tf_sum = 0u32;
self.remaining_docs = doc_freq;
} }
#[cfg(test)] pub fn total_block_len(&self) -> usize {
#[inline(always)] (self.doc_num_bits + self.tf_num_bits) as usize * COMPRESSION_BLOCK_SIZE / 8
pub(crate) fn last_doc_in_block(&self) -> DocId {
self.last_doc_in_block
} }
pub fn position_offset(&self) -> u64 { pub fn doc(&self) -> DocId {
self.position_offset self.doc
} }
pub fn byte_offset(&self) -> usize { pub fn doc_num_bits(&self) -> u8 {
self.byte_offset self.doc_num_bits
} }
fn read_block_info(&mut self) { /// Number of bits used to encode term frequencies
let doc_delta = u32::deserialize(&mut self.owned_read).expect("Skip data corrupted");
self.last_doc_in_block += doc_delta as DocId;
let doc_num_bits = self.owned_read.get(0);
match self.skip_info {
IndexRecordOption::Basic => {
self.owned_read.advance(1);
self.block_info = BlockInfo::BitPacked {
doc_num_bits,
tf_num_bits: 0,
tf_sum: 0,
};
}
IndexRecordOption::WithFreqs => {
let tf_num_bits = self.owned_read.get(1);
self.block_info = BlockInfo::BitPacked {
doc_num_bits,
tf_num_bits,
tf_sum: 0,
};
self.owned_read.advance(2);
}
IndexRecordOption::WithFreqsAndPositions => {
let tf_num_bits = self.owned_read.get(1);
self.owned_read.advance(2);
let tf_sum = u32::deserialize(&mut self.owned_read).expect("Failed reading tf_sum");
self.block_info = BlockInfo::BitPacked {
doc_num_bits,
tf_num_bits,
tf_sum,
};
}
}
}
pub fn block_info(&self) -> BlockInfo {
self.block_info
}
/// Advance the skip reader to the block that may contain the target.
/// ///
/// If the target is larger than all documents, the skip_reader /// 0 if term frequencies are not enabled.
/// then advance to the last Variable In block. pub fn tf_num_bits(&self) -> u8 {
pub fn seek(&mut self, target: DocId) { self.tf_num_bits
while self.last_doc_in_block < target { }
self.advance();
} pub fn tf_sum(&self) -> u32 {
self.tf_sum
} }
pub fn advance(&mut self) -> bool { pub fn advance(&mut self) -> bool {
match self.block_info { if self.owned_read.as_ref().is_empty() {
BlockInfo::BitPacked { false
doc_num_bits,
tf_num_bits,
tf_sum,
} => {
self.remaining_docs -= COMPRESSION_BLOCK_SIZE as u32;
self.byte_offset += compressed_block_size(doc_num_bits + tf_num_bits);
self.position_offset += tf_sum as u64;
}
BlockInfo::VInt(num_vint_docs) => {
self.remaining_docs -= num_vint_docs;
}
}
self.last_doc_in_previous_block = self.last_doc_in_block;
if self.remaining_docs >= COMPRESSION_BLOCK_SIZE as u32 {
self.read_block_info();
true
} else { } else {
self.last_doc_in_block = TERMINATED; let doc_delta = u32::deserialize(&mut self.owned_read).expect("Skip data corrupted");
self.block_info = BlockInfo::VInt(self.remaining_docs); self.doc += doc_delta as DocId;
self.remaining_docs > 0 self.doc_num_bits = self.owned_read.get(0);
match self.skip_info {
IndexRecordOption::Basic => {
self.owned_read.advance(1);
}
IndexRecordOption::WithFreqs => {
self.tf_num_bits = self.owned_read.get(1);
self.owned_read.advance(2);
}
IndexRecordOption::WithFreqsAndPositions => {
self.tf_num_bits = self.owned_read.get(1);
self.owned_read.advance(2);
self.tf_sum =
u32::deserialize(&mut self.owned_read).expect("Failed reading tf_sum");
}
}
true
} }
} }
} }
@@ -195,11 +131,9 @@ impl SkipReader {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::BlockInfo;
use super::IndexRecordOption; use super::IndexRecordOption;
use super::{SkipReader, SkipSerializer}; use super::{SkipReader, SkipSerializer};
use crate::directory::ReadOnlySource; use owned_read::OwnedRead;
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
#[test] #[test]
fn test_skip_with_freq() { fn test_skip_with_freq() {
@@ -211,34 +145,15 @@ mod tests {
skip_serializer.write_term_freq(2u8); skip_serializer.write_term_freq(2u8);
skip_serializer.data().to_owned() skip_serializer.data().to_owned()
}; };
let doc_freq = 3u32 + (COMPRESSION_BLOCK_SIZE * 2) as u32; let mut skip_reader = SkipReader::new(OwnedRead::new(buf), IndexRecordOption::WithFreqs);
let mut skip_reader = SkipReader::new(
ReadOnlySource::new(buf),
doc_freq,
IndexRecordOption::WithFreqs,
);
assert!(skip_reader.advance()); assert!(skip_reader.advance());
assert_eq!(skip_reader.last_doc_in_block(), 1u32); assert_eq!(skip_reader.doc(), 1u32);
assert_eq!( assert_eq!(skip_reader.doc_num_bits(), 2u8);
skip_reader.block_info(), assert_eq!(skip_reader.tf_num_bits(), 3u8);
BlockInfo::BitPacked {
doc_num_bits: 2u8,
tf_num_bits: 3u8,
tf_sum: 0
}
);
assert!(skip_reader.advance()); assert!(skip_reader.advance());
assert_eq!(skip_reader.last_doc_in_block(), 5u32); assert_eq!(skip_reader.doc(), 5u32);
assert_eq!( assert_eq!(skip_reader.doc_num_bits(), 5u8);
skip_reader.block_info(), assert_eq!(skip_reader.tf_num_bits(), 2u8);
BlockInfo::BitPacked {
doc_num_bits: 5u8,
tf_num_bits: 2u8,
tf_sum: 0
}
);
assert!(skip_reader.advance());
assert_eq!(skip_reader.block_info(), BlockInfo::VInt(3u32));
assert!(!skip_reader.advance()); assert!(!skip_reader.advance());
} }
@@ -250,60 +165,13 @@ mod tests {
skip_serializer.write_doc(5u32, 5u8); skip_serializer.write_doc(5u32, 5u8);
skip_serializer.data().to_owned() skip_serializer.data().to_owned()
}; };
let doc_freq = 3u32 + (COMPRESSION_BLOCK_SIZE * 2) as u32; let mut skip_reader = SkipReader::new(OwnedRead::new(buf), IndexRecordOption::Basic);
let mut skip_reader = SkipReader::new(
ReadOnlySource::from(buf),
doc_freq,
IndexRecordOption::Basic,
);
assert!(skip_reader.advance()); assert!(skip_reader.advance());
assert_eq!(skip_reader.last_doc_in_block(), 1u32); assert_eq!(skip_reader.doc(), 1u32);
assert_eq!( assert_eq!(skip_reader.doc_num_bits(), 2u8);
skip_reader.block_info(),
BlockInfo::BitPacked {
doc_num_bits: 2u8,
tf_num_bits: 0,
tf_sum: 0u32
}
);
assert!(skip_reader.advance()); assert!(skip_reader.advance());
assert_eq!(skip_reader.last_doc_in_block(), 5u32); assert_eq!(skip_reader.doc(), 5u32);
assert_eq!( assert_eq!(skip_reader.doc_num_bits(), 5u8);
skip_reader.block_info(),
BlockInfo::BitPacked {
doc_num_bits: 5u8,
tf_num_bits: 0,
tf_sum: 0u32
}
);
assert!(skip_reader.advance());
assert_eq!(skip_reader.block_info(), BlockInfo::VInt(3u32));
assert!(!skip_reader.advance());
}
#[test]
fn test_skip_multiple_of_block_size() {
let buf = {
let mut skip_serializer = SkipSerializer::new();
skip_serializer.write_doc(1u32, 2u8);
skip_serializer.data().to_owned()
};
let doc_freq = COMPRESSION_BLOCK_SIZE as u32;
let mut skip_reader = SkipReader::new(
ReadOnlySource::from(buf),
doc_freq,
IndexRecordOption::Basic,
);
assert!(skip_reader.advance());
assert_eq!(skip_reader.last_doc_in_block(), 1u32);
assert_eq!(
skip_reader.block_info(),
BlockInfo::BitPacked {
doc_num_bits: 2u8,
tf_num_bits: 0,
tf_sum: 0u32
}
);
assert!(!skip_reader.advance()); assert!(!skip_reader.advance());
} }
} }

View File

@@ -1,4 +1,6 @@
use murmurhash32::murmurhash2; use murmurhash32;
use self::murmurhash32::murmurhash2;
use super::{Addr, MemoryArena}; use super::{Addr, MemoryArena};
use crate::postings::stacker::memory_arena::store; use crate::postings::stacker::memory_arena::store;

View File

@@ -1,10 +1,10 @@
use crate::core::Searcher; use crate::core::Searcher;
use crate::core::SegmentReader; use crate::core::SegmentReader;
use crate::docset::{DocSet, TERMINATED}; use crate::docset::DocSet;
use crate::query::boost_query::BoostScorer;
use crate::query::explanation::does_not_match; use crate::query::explanation::does_not_match;
use crate::query::{Explanation, Query, Scorer, Weight}; use crate::query::{Explanation, Query, Scorer, Weight};
use crate::DocId; use crate::DocId;
use crate::Result;
use crate::Score; use crate::Score;
/// Query that matches all of the documents. /// Query that matches all of the documents.
@@ -14,7 +14,7 @@ use crate::Score;
pub struct AllQuery; pub struct AllQuery;
impl Query for AllQuery { impl Query for AllQuery {
fn weight(&self, _: &Searcher, _: bool) -> crate::Result<Box<dyn Weight>> { fn weight(&self, _: &Searcher, _: bool) -> Result<Box<dyn Weight>> {
Ok(Box::new(AllWeight)) Ok(Box::new(AllWeight))
} }
} }
@@ -23,15 +23,15 @@ impl Query for AllQuery {
pub struct AllWeight; pub struct AllWeight;
impl Weight for AllWeight { impl Weight for AllWeight {
fn scorer(&self, reader: &SegmentReader, boost: f32) -> crate::Result<Box<dyn Scorer>> { fn scorer(&self, reader: &SegmentReader) -> Result<Box<dyn Scorer>> {
let all_scorer = AllScorer { Ok(Box::new(AllScorer {
state: State::NotStarted,
doc: 0u32, doc: 0u32,
max_doc: reader.max_doc(), max_doc: reader.max_doc(),
}; }))
Ok(Box::new(BoostScorer::new(all_scorer, boost)))
} }
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> { fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation> {
if doc >= reader.max_doc() { if doc >= reader.max_doc() {
return Err(does_not_match(doc)); return Err(does_not_match(doc));
} }
@@ -39,20 +39,39 @@ impl Weight for AllWeight {
} }
} }
enum State {
NotStarted,
Started,
Finished,
}
/// Scorer associated to the `AllQuery` query. /// Scorer associated to the `AllQuery` query.
pub struct AllScorer { pub struct AllScorer {
state: State,
doc: DocId, doc: DocId,
max_doc: DocId, max_doc: DocId,
} }
impl DocSet for AllScorer { impl DocSet for AllScorer {
fn advance(&mut self) -> DocId { fn advance(&mut self) -> bool {
if self.doc + 1 >= self.max_doc { match self.state {
self.doc = TERMINATED; State::NotStarted => {
return TERMINATED; self.state = State::Started;
self.doc = 0;
}
State::Started => {
self.doc += 1u32;
}
State::Finished => {
return false;
}
}
if self.doc < self.max_doc {
true
} else {
self.state = State::Finished;
false
} }
self.doc += 1;
self.doc
} }
fn doc(&self) -> DocId { fn doc(&self) -> DocId {
@@ -72,13 +91,14 @@ impl Scorer for AllScorer {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::AllQuery; use super::AllQuery;
use crate::docset::TERMINATED;
use crate::query::Query; use crate::query::Query;
use crate::schema::{Schema, TEXT}; use crate::schema::{Schema, TEXT};
use crate::Index; use crate::Index;
fn create_test_index() -> Index { #[test]
fn test_all_query() {
let mut schema_builder = Schema::builder(); let mut schema_builder = Schema::builder();
let field = schema_builder.add_text_field("text", TEXT); let field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build(); let schema = schema_builder.build();
@@ -89,47 +109,25 @@ mod tests {
index_writer.commit().unwrap(); index_writer.commit().unwrap();
index_writer.add_document(doc!(field=>"ccc")); index_writer.add_document(doc!(field=>"ccc"));
index_writer.commit().unwrap(); index_writer.commit().unwrap();
index
}
#[test]
fn test_all_query() {
let index = create_test_index();
let reader = index.reader().unwrap(); let reader = index.reader().unwrap();
reader.reload().unwrap();
let searcher = reader.searcher(); let searcher = reader.searcher();
let weight = AllQuery.weight(&searcher, false).unwrap(); let weight = AllQuery.weight(&searcher, false).unwrap();
{ {
let reader = searcher.segment_reader(0); let reader = searcher.segment_reader(0);
let mut scorer = weight.scorer(reader, 1.0f32).unwrap(); let mut scorer = weight.scorer(reader).unwrap();
assert!(scorer.advance());
assert_eq!(scorer.doc(), 0u32); assert_eq!(scorer.doc(), 0u32);
assert_eq!(scorer.advance(), 1u32); assert!(scorer.advance());
assert_eq!(scorer.doc(), 1u32); assert_eq!(scorer.doc(), 1u32);
assert_eq!(scorer.advance(), TERMINATED); assert!(!scorer.advance());
} }
{ {
let reader = searcher.segment_reader(1); let reader = searcher.segment_reader(1);
let mut scorer = weight.scorer(reader, 1.0f32).unwrap(); let mut scorer = weight.scorer(reader).unwrap();
assert!(scorer.advance());
assert_eq!(scorer.doc(), 0u32); assert_eq!(scorer.doc(), 0u32);
assert_eq!(scorer.advance(), TERMINATED); assert!(!scorer.advance());
}
}
#[test]
fn test_all_query_with_boost() {
let index = create_test_index();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let weight = AllQuery.weight(&searcher, false).unwrap();
let reader = searcher.segment_reader(0);
{
let mut scorer = weight.scorer(reader, 2.0f32).unwrap();
assert_eq!(scorer.doc(), 0u32);
assert_eq!(scorer.score(), 2.0f32);
}
{
let mut scorer = weight.scorer(reader, 1.5f32).unwrap();
assert_eq!(scorer.doc(), 0u32);
assert_eq!(scorer.score(), 1.5f32);
} }
} }
} }

View File

@@ -6,8 +6,8 @@ use crate::query::{Scorer, Weight};
use crate::schema::{Field, IndexRecordOption}; use crate::schema::{Field, IndexRecordOption};
use crate::termdict::{TermDictionary, TermStreamer}; use crate::termdict::{TermDictionary, TermStreamer};
use crate::DocId; use crate::DocId;
use crate::Result;
use crate::TantivyError; use crate::TantivyError;
use crate::{Result, SkipResult};
use std::sync::Arc; use std::sync::Arc;
use tantivy_fst::Automaton; use tantivy_fst::Automaton;
@@ -40,7 +40,7 @@ impl<A> Weight for AutomatonWeight<A>
where where
A: Automaton + Send + Sync + 'static, A: Automaton + Send + Sync + 'static,
{ {
fn scorer(&self, reader: &SegmentReader, boost: f32) -> Result<Box<dyn Scorer>> { fn scorer(&self, reader: &SegmentReader) -> Result<Box<dyn Scorer>> {
let max_doc = reader.max_doc(); let max_doc = reader.max_doc();
let mut doc_bitset = BitSet::with_max_value(max_doc); let mut doc_bitset = BitSet::with_max_value(max_doc);
@@ -51,23 +51,19 @@ where
let term_info = term_stream.value(); let term_info = term_stream.value();
let mut block_segment_postings = inverted_index let mut block_segment_postings = inverted_index
.read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic); .read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic);
loop { while block_segment_postings.advance() {
for &doc in block_segment_postings.docs() { for &doc in block_segment_postings.docs() {
doc_bitset.insert(doc); doc_bitset.insert(doc);
} }
if !block_segment_postings.advance() {
break;
}
} }
} }
let doc_bitset = BitSetDocSet::from(doc_bitset); let doc_bitset = BitSetDocSet::from(doc_bitset);
let const_scorer = ConstScorer::new(doc_bitset, boost); Ok(Box::new(ConstScorer::new(doc_bitset)))
Ok(Box::new(const_scorer))
} }
fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation> { fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0f32)?; let mut scorer = self.scorer(reader)?;
if scorer.seek(doc) == doc { if scorer.skip_next(doc) == SkipResult::Reached {
Ok(Explanation::new("AutomatonScorer", 1.0f32)) Ok(Explanation::new("AutomatonScorer", 1.0f32))
} else { } else {
Err(TantivyError::InvalidArgument( Err(TantivyError::InvalidArgument(
@@ -76,94 +72,3 @@ where
} }
} }
} }
#[cfg(test)]
mod tests {
use super::AutomatonWeight;
use crate::docset::TERMINATED;
use crate::query::Weight;
use crate::schema::{Schema, STRING};
use crate::Index;
use tantivy_fst::Automaton;
fn create_index() -> Index {
let mut schema = Schema::builder();
let title = schema.add_text_field("title", STRING);
let index = Index::create_in_ram(schema.build());
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
index_writer.add_document(doc!(title=>"abc"));
index_writer.add_document(doc!(title=>"bcd"));
index_writer.add_document(doc!(title=>"abcd"));
assert!(index_writer.commit().is_ok());
index
}
enum State {
Start,
NotMatching,
AfterA,
}
struct PrefixedByA;
impl Automaton for PrefixedByA {
type State = State;
fn start(&self) -> Self::State {
State::Start
}
fn is_match(&self, state: &Self::State) -> bool {
match *state {
State::AfterA => true,
_ => false,
}
}
fn accept(&self, state: &Self::State, byte: u8) -> Self::State {
match *state {
State::Start => {
if byte == b'a' {
State::AfterA
} else {
State::NotMatching
}
}
State::AfterA => State::AfterA,
State::NotMatching => State::NotMatching,
}
}
}
#[test]
fn test_automaton_weight() {
let index = create_index();
let field = index.schema().get_field("title").unwrap();
let automaton_weight = AutomatonWeight::new(field, PrefixedByA);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let mut scorer = automaton_weight
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap();
assert_eq!(scorer.doc(), 0u32);
assert_eq!(scorer.score(), 1.0f32);
assert_eq!(scorer.advance(), 2u32);
assert_eq!(scorer.doc(), 2u32);
assert_eq!(scorer.score(), 1.0f32);
assert_eq!(scorer.advance(), TERMINATED);
}
#[test]
fn test_automaton_weight_boost() {
let index = create_index();
let field = index.schema().get_field("title").unwrap();
let automaton_weight = AutomatonWeight::new(field, PrefixedByA);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let mut scorer = automaton_weight
.scorer(searcher.segment_reader(0u32), 1.32f32)
.unwrap();
assert_eq!(scorer.doc(), 0u32);
assert_eq!(scorer.score(), 1.32f32);
}
}

View File

@@ -1,6 +1,7 @@
use crate::common::{BitSet, TinySet}; use crate::common::{BitSet, TinySet};
use crate::docset::{DocSet, TERMINATED}; use crate::docset::{DocSet, SkipResult};
use crate::DocId; use crate::DocId;
use std::cmp::Ordering;
/// A `BitSetDocSet` makes it possible to iterate through a bitset as if it was a `DocSet`. /// A `BitSetDocSet` makes it possible to iterate through a bitset as if it was a `DocSet`.
/// ///
@@ -32,50 +33,75 @@ impl From<BitSet> for BitSetDocSet {
} else { } else {
docs.tinyset(0) docs.tinyset(0)
}; };
let mut docset = BitSetDocSet { BitSetDocSet {
docs, docs,
cursor_bucket: 0, cursor_bucket: 0,
cursor_tinybitset: first_tiny_bitset, cursor_tinybitset: first_tiny_bitset,
doc: 0u32, doc: 0u32,
}; }
docset.advance();
docset
} }
} }
impl DocSet for BitSetDocSet { impl DocSet for BitSetDocSet {
fn advance(&mut self) -> DocId { fn advance(&mut self) -> bool {
if let Some(lower) = self.cursor_tinybitset.pop_lowest() { if let Some(lower) = self.cursor_tinybitset.pop_lowest() {
self.doc = (self.cursor_bucket as u32 * 64u32) | lower; self.doc = (self.cursor_bucket as u32 * 64u32) | lower;
return self.doc; return true;
} }
if let Some(cursor_bucket) = self.docs.first_non_empty_bucket(self.cursor_bucket + 1) { if let Some(cursor_bucket) = self.docs.first_non_empty_bucket(self.cursor_bucket + 1) {
self.go_to_bucket(cursor_bucket); self.go_to_bucket(cursor_bucket);
let lower = self.cursor_tinybitset.pop_lowest().unwrap(); let lower = self.cursor_tinybitset.pop_lowest().unwrap();
self.doc = (cursor_bucket * 64u32) | lower; self.doc = (cursor_bucket * 64u32) | lower;
self.doc true
} else { } else {
self.doc = TERMINATED; false
TERMINATED
} }
} }
fn seek(&mut self, target: DocId) -> DocId { fn skip_next(&mut self, target: DocId) -> SkipResult {
// skip is required to advance.
if !self.advance() {
return SkipResult::End;
}
let target_bucket = target / 64u32; let target_bucket = target / 64u32;
// Mask for all of the bits greater or equal // Mask for all of the bits greater or equal
// to our target document. // to our target document.
if target_bucket > self.cursor_bucket { match target_bucket.cmp(&self.cursor_bucket) {
self.go_to_bucket(target_bucket); Ordering::Greater => {
let greater_filter: TinySet = TinySet::range_greater_or_equal(target); self.go_to_bucket(target_bucket);
self.cursor_tinybitset = self.cursor_tinybitset.intersect(greater_filter); let greater_filter: TinySet = TinySet::range_greater_or_equal(target);
self.advance(); self.cursor_tinybitset = self.cursor_tinybitset.intersect(greater_filter);
if !self.advance() {
SkipResult::End
} else if self.doc() == target {
SkipResult::Reached
} else {
debug_assert!(self.doc() > target);
SkipResult::OverStep
}
}
Ordering::Equal => loop {
match self.doc().cmp(&target) {
Ordering::Less => {
if !self.advance() {
return SkipResult::End;
}
}
Ordering::Equal => {
return SkipResult::Reached;
}
Ordering::Greater => {
debug_assert!(self.doc() > target);
return SkipResult::OverStep;
}
}
},
Ordering::Less => {
debug_assert!(self.doc() > target);
SkipResult::OverStep
}
} }
let mut doc = self.doc();
while doc < target {
doc = self.advance();
}
doc
} }
/// Returns the current document /// Returns the current document
@@ -96,7 +122,7 @@ impl DocSet for BitSetDocSet {
mod tests { mod tests {
use super::BitSetDocSet; use super::BitSetDocSet;
use crate::common::BitSet; use crate::common::BitSet;
use crate::docset::{DocSet, TERMINATED}; use crate::docset::{DocSet, SkipResult};
use crate::DocId; use crate::DocId;
fn create_docbitset(docs: &[DocId], max_doc: DocId) -> BitSetDocSet { fn create_docbitset(docs: &[DocId], max_doc: DocId) -> BitSetDocSet {
@@ -107,24 +133,19 @@ mod tests {
BitSetDocSet::from(docset) BitSetDocSet::from(docset)
} }
#[test]
fn test_empty() {
let bitset = BitSet::with_max_value(1000);
let mut empty = BitSetDocSet::from(bitset);
assert_eq!(empty.advance(), TERMINATED)
}
fn test_go_through_sequential(docs: &[DocId]) { fn test_go_through_sequential(docs: &[DocId]) {
let mut docset = create_docbitset(docs, 1_000u32); let mut docset = create_docbitset(docs, 1_000u32);
for &doc in docs { for &doc in docs {
assert!(docset.advance());
assert_eq!(doc, docset.doc()); assert_eq!(doc, docset.doc());
docset.advance();
} }
assert_eq!(docset.advance(), TERMINATED); assert!(!docset.advance());
assert!(!docset.advance());
} }
#[test] #[test]
fn test_docbitset_sequential() { fn test_docbitset_sequential() {
test_go_through_sequential(&[]);
test_go_through_sequential(&[1, 2, 3]); test_go_through_sequential(&[1, 2, 3]);
test_go_through_sequential(&[1, 2, 3, 4, 5, 63, 64, 65]); test_go_through_sequential(&[1, 2, 3, 4, 5, 63, 64, 65]);
test_go_through_sequential(&[63, 64, 65]); test_go_through_sequential(&[63, 64, 65]);
@@ -135,64 +156,64 @@ mod tests {
fn test_docbitset_skip() { fn test_docbitset_skip() {
{ {
let mut docset = create_docbitset(&[1, 5, 6, 7, 5112], 10_000); let mut docset = create_docbitset(&[1, 5, 6, 7, 5112], 10_000);
assert_eq!(docset.seek(7), 7); assert_eq!(docset.skip_next(7), SkipResult::Reached);
assert_eq!(docset.doc(), 7); assert_eq!(docset.doc(), 7);
assert_eq!(docset.advance(), 5112); assert!(docset.advance(), 7);
assert_eq!(docset.doc(), 5112); assert_eq!(docset.doc(), 5112);
assert_eq!(docset.advance(), TERMINATED); assert!(!docset.advance());
} }
{ {
let mut docset = create_docbitset(&[1, 5, 6, 7, 5112], 10_000); let mut docset = create_docbitset(&[1, 5, 6, 7, 5112], 10_000);
assert_eq!(docset.seek(3), 5); assert_eq!(docset.skip_next(3), SkipResult::OverStep);
assert_eq!(docset.doc(), 5); assert_eq!(docset.doc(), 5);
assert_eq!(docset.advance(), 6); assert!(docset.advance());
} }
{ {
let mut docset = create_docbitset(&[5112], 10_000); let mut docset = create_docbitset(&[5112], 10_000);
assert_eq!(docset.seek(5112), 5112); assert_eq!(docset.skip_next(5112), SkipResult::Reached);
assert_eq!(docset.doc(), 5112); assert_eq!(docset.doc(), 5112);
assert_eq!(docset.advance(), TERMINATED); assert!(!docset.advance());
} }
{ {
let mut docset = create_docbitset(&[5112], 10_000); let mut docset = create_docbitset(&[5112], 10_000);
assert_eq!(docset.seek(5113), TERMINATED); assert_eq!(docset.skip_next(5113), SkipResult::End);
assert_eq!(docset.advance(), TERMINATED); assert!(!docset.advance());
} }
{ {
let mut docset = create_docbitset(&[5112], 10_000); let mut docset = create_docbitset(&[5112], 10_000);
assert_eq!(docset.seek(5111), 5112); assert_eq!(docset.skip_next(5111), SkipResult::OverStep);
assert_eq!(docset.doc(), 5112); assert_eq!(docset.doc(), 5112);
assert_eq!(docset.advance(), TERMINATED); assert!(!docset.advance());
} }
{ {
let mut docset = create_docbitset(&[1, 5, 6, 7, 5112, 5500, 6666], 10_000); let mut docset = create_docbitset(&[1, 5, 6, 7, 5112, 5500, 6666], 10_000);
assert_eq!(docset.seek(5112), 5112); assert_eq!(docset.skip_next(5112), SkipResult::Reached);
assert_eq!(docset.doc(), 5112); assert_eq!(docset.doc(), 5112);
assert_eq!(docset.advance(), 5500); assert!(docset.advance());
assert_eq!(docset.doc(), 5500); assert_eq!(docset.doc(), 5500);
assert_eq!(docset.advance(), 6666); assert!(docset.advance());
assert_eq!(docset.doc(), 6666); assert_eq!(docset.doc(), 6666);
assert_eq!(docset.advance(), TERMINATED); assert!(!docset.advance());
} }
{ {
let mut docset = create_docbitset(&[1, 5, 6, 7, 5112, 5500, 6666], 10_000); let mut docset = create_docbitset(&[1, 5, 6, 7, 5112, 5500, 6666], 10_000);
assert_eq!(docset.seek(5111), 5112); assert_eq!(docset.skip_next(5111), SkipResult::OverStep);
assert_eq!(docset.doc(), 5112); assert_eq!(docset.doc(), 5112);
assert_eq!(docset.advance(), 5500); assert!(docset.advance());
assert_eq!(docset.doc(), 5500); assert_eq!(docset.doc(), 5500);
assert_eq!(docset.advance(), 6666); assert!(docset.advance());
assert_eq!(docset.doc(), 6666); assert_eq!(docset.doc(), 6666);
assert_eq!(docset.advance(), TERMINATED); assert!(!docset.advance());
} }
{ {
let mut docset = create_docbitset(&[1, 5, 6, 7, 5112, 5513, 6666], 10_000); let mut docset = create_docbitset(&[1, 5, 6, 7, 5112, 5513, 6666], 10_000);
assert_eq!(docset.seek(5111), 5112); assert_eq!(docset.skip_next(5111), SkipResult::OverStep);
assert_eq!(docset.doc(), 5112); assert_eq!(docset.doc(), 5112);
assert_eq!(docset.advance(), 5513); assert!(docset.advance());
assert_eq!(docset.doc(), 5513); assert_eq!(docset.doc(), 5513);
assert_eq!(docset.advance(), 6666); assert!(docset.advance());
assert_eq!(docset.doc(), 6666); assert_eq!(docset.doc(), 6666);
assert_eq!(docset.advance(), TERMINATED); assert!(!docset.advance());
} }
} }
} }
@@ -202,7 +223,6 @@ mod bench {
use super::BitSet; use super::BitSet;
use super::BitSetDocSet; use super::BitSetDocSet;
use crate::docset::TERMINATED;
use crate::test; use crate::test;
use crate::tests; use crate::tests;
use crate::DocSet; use crate::DocSet;
@@ -237,7 +257,7 @@ mod bench {
} }
b.iter(|| { b.iter(|| {
let mut docset = BitSetDocSet::from(bitset.clone()); let mut docset = BitSetDocSet::from(bitset.clone());
while docset.advance() != TERMINATED {} while docset.advance() {}
}); });
} }
} }

View File

@@ -25,6 +25,7 @@ fn compute_tf_cache(average_fieldnorm: f32) -> [f32; 256] {
cache cache
} }
#[derive(Clone)]
pub struct BM25Weight { pub struct BM25Weight {
idf_explain: Explanation, idf_explain: Explanation,
weight: f32, weight: f32,
@@ -33,15 +34,6 @@ pub struct BM25Weight {
} }
impl BM25Weight { impl BM25Weight {
pub fn boost_by(&self, boost: f32) -> BM25Weight {
BM25Weight {
idf_explain: self.idf_explain.clone(),
weight: self.weight * boost,
cache: self.cache,
average_fieldnorm: self.average_fieldnorm,
}
}
pub fn for_terms(searcher: &Searcher, terms: &[Term]) -> BM25Weight { pub fn for_terms(searcher: &Searcher, terms: &[Term]) -> BM25Weight {
assert!(!terms.is_empty(), "BM25 requires at least one term"); assert!(!terms.is_empty(), "BM25 requires at least one term");
let field = terms[0].field(); let field = terms[0].field();

View File

@@ -5,6 +5,7 @@ use crate::query::TermQuery;
use crate::query::Weight; use crate::query::Weight;
use crate::schema::IndexRecordOption; use crate::schema::IndexRecordOption;
use crate::schema::Term; use crate::schema::Term;
use crate::Result;
use crate::Searcher; use crate::Searcher;
use std::collections::BTreeSet; use std::collections::BTreeSet;
@@ -29,9 +30,9 @@ use std::collections::BTreeSet;
///use tantivy::query::{BooleanQuery, Occur, PhraseQuery, Query, TermQuery}; ///use tantivy::query::{BooleanQuery, Occur, PhraseQuery, Query, TermQuery};
///use tantivy::schema::{IndexRecordOption, Schema, TEXT}; ///use tantivy::schema::{IndexRecordOption, Schema, TEXT};
///use tantivy::Term; ///use tantivy::Term;
///use tantivy::Index; ///use tantivy::{Index, Result};
/// ///
///fn main() -> tantivy::Result<()> { ///fn main() -> Result<()> {
/// let mut schema_builder = Schema::builder(); /// let mut schema_builder = Schema::builder();
/// let title = schema_builder.add_text_field("title", TEXT); /// let title = schema_builder.add_text_field("title", TEXT);
/// let body = schema_builder.add_text_field("body", TEXT); /// let body = schema_builder.add_text_field("body", TEXT);
@@ -148,14 +149,14 @@ impl From<Vec<(Occur, Box<dyn Query>)>> for BooleanQuery {
} }
impl Query for BooleanQuery { impl Query for BooleanQuery {
fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> crate::Result<Box<dyn Weight>> { fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> Result<Box<dyn Weight>> {
let sub_weights = self let sub_weights = self
.subqueries .subqueries
.iter() .iter()
.map(|&(ref occur, ref subquery)| { .map(|&(ref occur, ref subquery)| {
Ok((*occur, subquery.weight(searcher, scoring_enabled)?)) Ok((*occur, subquery.weight(searcher, scoring_enabled)?))
}) })
.collect::<crate::Result<_>>()?; .collect::<Result<_>>()?;
Ok(Box::new(BooleanWeight::new(sub_weights, scoring_enabled))) Ok(Box::new(BooleanWeight::new(sub_weights, scoring_enabled)))
} }

View File

@@ -2,7 +2,6 @@ use crate::core::SegmentReader;
use crate::query::explanation::does_not_match; use crate::query::explanation::does_not_match;
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner, SumWithCoordsCombiner}; use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner, SumWithCoordsCombiner};
use crate::query::term_query::TermScorer; use crate::query::term_query::TermScorer;
use crate::query::weight::{for_each_pruning_scorer, for_each_scorer};
use crate::query::EmptyScorer; use crate::query::EmptyScorer;
use crate::query::Exclude; use crate::query::Exclude;
use crate::query::Occur; use crate::query::Occur;
@@ -11,21 +10,17 @@ use crate::query::Scorer;
use crate::query::Union; use crate::query::Union;
use crate::query::Weight; use crate::query::Weight;
use crate::query::{intersect_scorers, Explanation}; use crate::query::{intersect_scorers, Explanation};
use crate::{DocId, Score}; use crate::Result;
use crate::{DocId, SkipResult};
use std::collections::HashMap; use std::collections::HashMap;
enum SpecializedScorer<TScoreCombiner: ScoreCombiner> { fn scorer_union<TScoreCombiner>(scorers: Vec<Box<dyn Scorer>>) -> Box<dyn Scorer>
TermUnion(Union<TermScorer, TScoreCombiner>),
Other(Box<dyn Scorer>),
}
fn scorer_union<TScoreCombiner>(scorers: Vec<Box<dyn Scorer>>) -> SpecializedScorer<TScoreCombiner>
where where
TScoreCombiner: ScoreCombiner, TScoreCombiner: ScoreCombiner,
{ {
assert!(!scorers.is_empty()); assert!(!scorers.is_empty());
if scorers.len() == 1 { if scorers.len() == 1 {
return SpecializedScorer::Other(scorers.into_iter().next().unwrap()); //< we checked the size beforehands return scorers.into_iter().next().unwrap(); //< we checked the size beforehands
} }
{ {
@@ -35,21 +30,14 @@ where
.into_iter() .into_iter()
.map(|scorer| *(scorer.downcast::<TermScorer>().map_err(|_| ()).unwrap())) .map(|scorer| *(scorer.downcast::<TermScorer>().map_err(|_| ()).unwrap()))
.collect(); .collect();
return SpecializedScorer::TermUnion(Union::<TermScorer, TScoreCombiner>::from( let scorer: Box<dyn Scorer> =
scorers, Box::new(Union::<TermScorer, TScoreCombiner>::from(scorers));
)); return scorer;
} }
} }
SpecializedScorer::Other(Box::new(Union::<_, TScoreCombiner>::from(scorers)))
}
impl<TScoreCombiner: ScoreCombiner> Into<Box<dyn Scorer>> for SpecializedScorer<TScoreCombiner> { let scorer: Box<dyn Scorer> = Box::new(Union::<_, TScoreCombiner>::from(scorers));
fn into(self) -> Box<dyn Scorer> { scorer
match self {
Self::TermUnion(union) => Box::new(union),
Self::Other(scorer) => scorer,
}
}
} }
pub struct BooleanWeight { pub struct BooleanWeight {
@@ -68,11 +56,10 @@ impl BooleanWeight {
fn per_occur_scorers( fn per_occur_scorers(
&self, &self,
reader: &SegmentReader, reader: &SegmentReader,
boost: f32, ) -> Result<HashMap<Occur, Vec<Box<dyn Scorer>>>> {
) -> crate::Result<HashMap<Occur, Vec<Box<dyn Scorer>>>> {
let mut per_occur_scorers: HashMap<Occur, Vec<Box<dyn Scorer>>> = HashMap::new(); let mut per_occur_scorers: HashMap<Occur, Vec<Box<dyn Scorer>>> = HashMap::new();
for &(ref occur, ref subweight) in &self.weights { for &(ref occur, ref subweight) in &self.weights {
let sub_scorer: Box<dyn Scorer> = subweight.scorer(reader, boost)?; let sub_scorer: Box<dyn Scorer> = subweight.scorer(reader)?;
per_occur_scorers per_occur_scorers
.entry(*occur) .entry(*occur)
.or_insert_with(Vec::new) .or_insert_with(Vec::new)
@@ -84,51 +71,41 @@ impl BooleanWeight {
fn complex_scorer<TScoreCombiner: ScoreCombiner>( fn complex_scorer<TScoreCombiner: ScoreCombiner>(
&self, &self,
reader: &SegmentReader, reader: &SegmentReader,
boost: f32, ) -> Result<Box<dyn Scorer>> {
) -> crate::Result<SpecializedScorer<TScoreCombiner>> { let mut per_occur_scorers = self.per_occur_scorers(reader)?;
let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
let should_scorer_opt: Option<SpecializedScorer<TScoreCombiner>> = per_occur_scorers let should_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::Should) .remove(&Occur::Should)
.map(scorer_union::<TScoreCombiner>); .map(scorer_union::<TScoreCombiner>);
let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::MustNot) .remove(&Occur::MustNot)
.map(scorer_union::<TScoreCombiner>) .map(scorer_union::<TScoreCombiner>);
.map(Into::into);
let must_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers let must_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::Must) .remove(&Occur::Must)
.map(intersect_scorers); .map(intersect_scorers);
let positive_scorer: SpecializedScorer<TScoreCombiner> = let positive_scorer: Box<dyn Scorer> = match (should_scorer_opt, must_scorer_opt) {
match (should_scorer_opt, must_scorer_opt) { (Some(should_scorer), Some(must_scorer)) => {
(Some(should_scorer), Some(must_scorer)) => { if self.scoring_enabled {
if self.scoring_enabled { Box::new(RequiredOptionalScorer::<_, _, TScoreCombiner>::new(
SpecializedScorer::Other(Box::new(RequiredOptionalScorer::< must_scorer,
Box<dyn Scorer>, should_scorer,
Box<dyn Scorer>, ))
TScoreCombiner, } else {
>::new( must_scorer
must_scorer, should_scorer.into()
)))
} else {
SpecializedScorer::Other(must_scorer)
}
} }
(None, Some(must_scorer)) => SpecializedScorer::Other(must_scorer), }
(Some(should_scorer), None) => should_scorer, (None, Some(must_scorer)) => must_scorer,
(None, None) => { (Some(should_scorer), None) => should_scorer,
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer))); (None, None) => {
} return Ok(Box::new(EmptyScorer));
}; }
};
if let Some(exclude_scorer) = exclude_scorer_opt { if let Some(exclude_scorer) = exclude_scorer_opt {
let positive_scorer_boxed: Box<dyn Scorer> = positive_scorer.into(); Ok(Box::new(Exclude::new(positive_scorer, exclude_scorer)))
Ok(SpecializedScorer::Other(Box::new(Exclude::new(
positive_scorer_boxed,
exclude_scorer,
))))
} else { } else {
Ok(positive_scorer) Ok(positive_scorer)
} }
@@ -136,7 +113,7 @@ impl BooleanWeight {
} }
impl Weight for BooleanWeight { impl Weight for BooleanWeight {
fn scorer(&self, reader: &SegmentReader, boost: f32) -> crate::Result<Box<dyn Scorer>> { fn scorer(&self, reader: &SegmentReader) -> Result<Box<dyn Scorer>> {
if self.weights.is_empty() { if self.weights.is_empty() {
Ok(Box::new(EmptyScorer)) Ok(Box::new(EmptyScorer))
} else if self.weights.len() == 1 { } else if self.weights.len() == 1 {
@@ -144,20 +121,18 @@ impl Weight for BooleanWeight {
if occur == Occur::MustNot { if occur == Occur::MustNot {
Ok(Box::new(EmptyScorer)) Ok(Box::new(EmptyScorer))
} else { } else {
weight.scorer(reader, boost) weight.scorer(reader)
} }
} else if self.scoring_enabled { } else if self.scoring_enabled {
self.complex_scorer::<SumWithCoordsCombiner>(reader, boost) self.complex_scorer::<SumWithCoordsCombiner>(reader)
.map(Into::into)
} else { } else {
self.complex_scorer::<DoNothingCombiner>(reader, boost) self.complex_scorer::<DoNothingCombiner>(reader)
.map(Into::into)
} }
} }
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> { fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0f32)?; let mut scorer = self.scorer(reader)?;
if scorer.seek(doc) != doc { if scorer.skip_next(doc) != SkipResult::Reached {
return Err(does_not_match(doc)); return Err(does_not_match(doc));
} }
if !self.scoring_enabled { if !self.scoring_enabled {
@@ -174,51 +149,6 @@ impl Weight for BooleanWeight {
} }
Ok(explanation) Ok(explanation)
} }
fn for_each(
&self,
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score),
) -> crate::Result<()> {
let scorer = self.complex_scorer::<SumWithCoordsCombiner>(reader, 1.0f32)?;
match scorer {
SpecializedScorer::TermUnion(mut union_scorer) => {
for_each_scorer(&mut union_scorer, callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_scorer(scorer.as_mut(), callback);
}
}
Ok(())
}
/// Calls `callback` with all of the `(doc, score)` for which score
/// is exceeding a given threshold.
///
/// This method is useful for the TopDocs collector.
/// For all docsets, the blanket implementation has the benefit
/// of prefiltering (doc, score) pairs, avoiding the
/// virtual dispatch cost.
///
/// More importantly, it makes it possible for scorers to implement
/// important optimization (e.g. BlockWAND for union).
fn for_each_pruning(
&self,
threshold: f32,
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score) -> Score,
) -> crate::Result<()> {
let scorer = self.complex_scorer::<SumWithCoordsCombiner>(reader, 1.0f32)?;
match scorer {
SpecializedScorer::TermUnion(mut union_scorer) => {
for_each_pruning_scorer(&mut union_scorer, threshold, callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_pruning_scorer(scorer.as_mut(), threshold, callback);
}
}
Ok(())
}
} }
fn is_positive_occur(occur: Occur) -> bool { fn is_positive_occur(occur: Occur) -> bool {

View File

@@ -18,7 +18,6 @@ mod tests {
use crate::query::Scorer; use crate::query::Scorer;
use crate::query::TermQuery; use crate::query::TermQuery;
use crate::schema::*; use crate::schema::*;
use crate::tests::assert_nearly_equals;
use crate::Index; use crate::Index;
use crate::{DocAddress, DocId}; use crate::{DocAddress, DocId};
@@ -31,11 +30,24 @@ mod tests {
// writing the segment // writing the segment
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap(); let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
{ {
index_writer.add_document(doc!(text_field => "a b c")); let doc = doc!(text_field => "a b c");
index_writer.add_document(doc!(text_field => "a c")); index_writer.add_document(doc);
index_writer.add_document(doc!(text_field => "b c")); }
index_writer.add_document(doc!(text_field => "a b c d")); {
index_writer.add_document(doc!(text_field => "d")); let doc = doc!(text_field => "a c");
index_writer.add_document(doc);
}
{
let doc = doc!(text_field => "b c");
index_writer.add_document(doc);
}
{
let doc = doc!(text_field => "a b c d");
index_writer.add_document(doc);
}
{
let doc = doc!(text_field => "d");
index_writer.add_document(doc);
} }
assert!(index_writer.commit().is_ok()); assert!(index_writer.commit().is_ok());
} }
@@ -58,9 +70,7 @@ mod tests {
let query = query_parser.parse_query("+a").unwrap(); let query = query_parser.parse_query("+a").unwrap();
let searcher = index.reader().unwrap().searcher(); let searcher = index.reader().unwrap().searcher();
let weight = query.weight(&searcher, true).unwrap(); let weight = query.weight(&searcher, true).unwrap();
let scorer = weight let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap();
assert!(scorer.is::<TermScorer>()); assert!(scorer.is::<TermScorer>());
} }
@@ -72,17 +82,13 @@ mod tests {
{ {
let query = query_parser.parse_query("+a +b +c").unwrap(); let query = query_parser.parse_query("+a +b +c").unwrap();
let weight = query.weight(&searcher, true).unwrap(); let weight = query.weight(&searcher, true).unwrap();
let scorer = weight let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap();
assert!(scorer.is::<Intersection<TermScorer>>()); assert!(scorer.is::<Intersection<TermScorer>>());
} }
{ {
let query = query_parser.parse_query("+a +(b c)").unwrap(); let query = query_parser.parse_query("+a +(b c)").unwrap();
let weight = query.weight(&searcher, true).unwrap(); let weight = query.weight(&searcher, true).unwrap();
let scorer = weight let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap();
assert!(scorer.is::<Intersection<Box<dyn Scorer>>>()); assert!(scorer.is::<Intersection<Box<dyn Scorer>>>());
} }
} }
@@ -95,9 +101,7 @@ mod tests {
{ {
let query = query_parser.parse_query("+a b").unwrap(); let query = query_parser.parse_query("+a b").unwrap();
let weight = query.weight(&searcher, true).unwrap(); let weight = query.weight(&searcher, true).unwrap();
let scorer = weight let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap();
assert!(scorer.is::<RequiredOptionalScorer< assert!(scorer.is::<RequiredOptionalScorer<
Box<dyn Scorer>, Box<dyn Scorer>,
Box<dyn Scorer>, Box<dyn Scorer>,
@@ -107,9 +111,7 @@ mod tests {
{ {
let query = query_parser.parse_query("+a b").unwrap(); let query = query_parser.parse_query("+a b").unwrap();
let weight = query.weight(&searcher, false).unwrap(); let weight = query.weight(&searcher, false).unwrap();
let scorer = weight let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap();
assert!(scorer.is::<TermScorer>()); assert!(scorer.is::<TermScorer>());
} }
} }
@@ -177,48 +179,6 @@ mod tests {
} }
} }
#[test]
pub fn test_boolean_query_with_weight() {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
index_writer.add_document(doc!(text_field => "a b c"));
index_writer.add_document(doc!(text_field => "a c"));
index_writer.add_document(doc!(text_field => "b c"));
assert!(index_writer.commit().is_ok());
}
let term_a: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "a"),
IndexRecordOption::WithFreqs,
));
let term_b: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "b"),
IndexRecordOption::WithFreqs,
));
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let boolean_query =
BooleanQuery::from(vec![(Occur::Should, term_a), (Occur::Should, term_b)]);
let boolean_weight = boolean_query.weight(&searcher, true).unwrap();
{
let mut boolean_scorer = boolean_weight
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap();
assert_eq!(boolean_scorer.doc(), 0u32);
assert_nearly_equals(boolean_scorer.score(), 0.84163445f32);
}
{
let mut boolean_scorer = boolean_weight
.scorer(searcher.segment_reader(0u32), 2.0f32)
.unwrap();
assert_eq!(boolean_scorer.doc(), 0u32);
assert_nearly_equals(boolean_scorer.score(), 1.6832689f32);
}
}
#[test] #[test]
pub fn test_intersection_score() { pub fn test_intersection_score() {
let (index, text_field) = aux_test_helper(); let (index, text_field) = aux_test_helper();
@@ -289,9 +249,7 @@ mod tests {
let query_parser = QueryParser::for_index(&index, vec![title, text]); let query_parser = QueryParser::for_index(&index, vec![title, text]);
let query = query_parser.parse_query("Оксана Лифенко").unwrap(); let query = query_parser.parse_query("Оксана Лифенко").unwrap();
let weight = query.weight(&searcher, true).unwrap(); let weight = query.weight(&searcher, true).unwrap();
let mut scorer = weight let mut scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap();
scorer.advance(); scorer.advance();
let explanation = query.explain(&searcher, DocAddress(0u32, 0u32)).unwrap(); let explanation = query.explain(&searcher, DocAddress(0u32, 0u32)).unwrap();

View File

@@ -1,159 +0,0 @@
use crate::fastfield::DeleteBitSet;
use crate::query::explanation::does_not_match;
use crate::query::{Explanation, Query, Scorer, Weight};
use crate::{DocId, DocSet, Searcher, SegmentReader, Term};
use std::collections::BTreeSet;
use std::fmt;
/// `BoostQuery` is a wrapper over a query used to boost its score.
///
/// The document set matched by the `BoostQuery` is strictly the same as the underlying query.
/// The score of each document, is the score of the underlying query multiplied by the `boost`
/// factor.
pub struct BoostQuery {
query: Box<dyn Query>,
boost: f32,
}
impl BoostQuery {
/// Builds a boost query.
pub fn new(query: Box<dyn Query>, boost: f32) -> BoostQuery {
BoostQuery { query, boost }
}
}
impl Clone for BoostQuery {
fn clone(&self) -> Self {
BoostQuery {
query: self.query.box_clone(),
boost: self.boost,
}
}
}
impl fmt::Debug for BoostQuery {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Boost(query={:?}, boost={})", self.query, self.boost)
}
}
impl Query for BoostQuery {
fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> crate::Result<Box<dyn Weight>> {
let weight_without_boost = self.query.weight(searcher, scoring_enabled)?;
let boosted_weight = if scoring_enabled {
Box::new(BoostWeight::new(weight_without_boost, self.boost))
} else {
weight_without_boost
};
Ok(boosted_weight)
}
fn query_terms(&self, term_set: &mut BTreeSet<Term>) {
self.query.query_terms(term_set)
}
}
pub(crate) struct BoostWeight {
weight: Box<dyn Weight>,
boost: f32,
}
impl BoostWeight {
pub fn new(weight: Box<dyn Weight>, boost: f32) -> Self {
BoostWeight { weight, boost }
}
}
impl Weight for BoostWeight {
fn scorer(&self, reader: &SegmentReader, boost: f32) -> crate::Result<Box<dyn Scorer>> {
self.weight.scorer(reader, boost * self.boost)
}
fn explain(&self, reader: &SegmentReader, doc: u32) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0f32)?;
if scorer.seek(doc) != doc {
return Err(does_not_match(doc));
}
let mut explanation =
Explanation::new(format!("Boost x{} of ...", self.boost), scorer.score());
let underlying_explanation = self.weight.explain(reader, doc)?;
explanation.add_detail(underlying_explanation);
Ok(explanation)
}
fn count(&self, reader: &SegmentReader) -> crate::Result<u32> {
self.weight.count(reader)
}
}
pub(crate) struct BoostScorer<S: Scorer> {
underlying: S,
boost: f32,
}
impl<S: Scorer> BoostScorer<S> {
pub fn new(underlying: S, boost: f32) -> BoostScorer<S> {
BoostScorer { underlying, boost }
}
}
impl<S: Scorer> DocSet for BoostScorer<S> {
fn advance(&mut self) -> DocId {
self.underlying.advance()
}
fn seek(&mut self, target: DocId) -> DocId {
self.underlying.seek(target)
}
fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize {
self.underlying.fill_buffer(buffer)
}
fn doc(&self) -> u32 {
self.underlying.doc()
}
fn size_hint(&self) -> u32 {
self.underlying.size_hint()
}
fn count(&mut self, delete_bitset: &DeleteBitSet) -> u32 {
self.underlying.count(delete_bitset)
}
fn count_including_deleted(&mut self) -> u32 {
self.underlying.count_including_deleted()
}
}
impl<S: Scorer> Scorer for BoostScorer<S> {
fn score(&mut self) -> f32 {
self.underlying.score() * self.boost
}
}
#[cfg(test)]
mod tests {
use super::BoostQuery;
use crate::query::{AllQuery, Query};
use crate::schema::Schema;
use crate::{DocAddress, Document, Index};
#[test]
fn test_boost_query_explain() {
let schema = Schema::builder().build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
index_writer.add_document(Document::new());
assert!(index_writer.commit().is_ok());
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let query = BoostQuery::new(Box::new(AllQuery), 0.2);
let explanation = query.explain(&searcher, DocAddress(0, 0u32)).unwrap();
assert_eq!(
explanation.to_pretty_json(),
"{\n \"value\": 0.2,\n \"description\": \"Boost x0.2 of ...\",\n \"details\": [\n {\n \"value\": 1.0,\n \"description\": \"AllQuery\"\n }\n ]\n}"
)
}
}

View File

@@ -1,10 +1,10 @@
use super::Scorer; use super::Scorer;
use crate::docset::TERMINATED;
use crate::query::explanation::does_not_match; use crate::query::explanation::does_not_match;
use crate::query::Weight; use crate::query::Weight;
use crate::query::{Explanation, Query}; use crate::query::{Explanation, Query};
use crate::DocId; use crate::DocId;
use crate::DocSet; use crate::DocSet;
use crate::Result;
use crate::Score; use crate::Score;
use crate::Searcher; use crate::Searcher;
use crate::SegmentReader; use crate::SegmentReader;
@@ -16,15 +16,11 @@ use crate::SegmentReader;
pub struct EmptyQuery; pub struct EmptyQuery;
impl Query for EmptyQuery { impl Query for EmptyQuery {
fn weight( fn weight(&self, _searcher: &Searcher, _scoring_enabled: bool) -> Result<Box<dyn Weight>> {
&self,
_searcher: &Searcher,
_scoring_enabled: bool,
) -> crate::Result<Box<dyn Weight>> {
Ok(Box::new(EmptyWeight)) Ok(Box::new(EmptyWeight))
} }
fn count(&self, _searcher: &Searcher) -> crate::Result<usize> { fn count(&self, _searcher: &Searcher) -> Result<usize> {
Ok(0) Ok(0)
} }
} }
@@ -34,11 +30,11 @@ impl Query for EmptyQuery {
/// It is useful for tests and handling edge cases. /// It is useful for tests and handling edge cases.
pub struct EmptyWeight; pub struct EmptyWeight;
impl Weight for EmptyWeight { impl Weight for EmptyWeight {
fn scorer(&self, _reader: &SegmentReader, _boost: f32) -> crate::Result<Box<dyn Scorer>> { fn scorer(&self, _reader: &SegmentReader) -> Result<Box<dyn Scorer>> {
Ok(Box::new(EmptyScorer)) Ok(Box::new(EmptyScorer))
} }
fn explain(&self, _reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> { fn explain(&self, _reader: &SegmentReader, doc: DocId) -> Result<Explanation> {
Err(does_not_match(doc)) Err(does_not_match(doc))
} }
} }
@@ -49,12 +45,15 @@ impl Weight for EmptyWeight {
pub struct EmptyScorer; pub struct EmptyScorer;
impl DocSet for EmptyScorer { impl DocSet for EmptyScorer {
fn advance(&mut self) -> DocId { fn advance(&mut self) -> bool {
TERMINATED false
} }
fn doc(&self) -> DocId { fn doc(&self) -> DocId {
TERMINATED panic!(
"You may not call .doc() on a scorer \
where the last call to advance() did not return true."
);
} }
fn size_hint(&self) -> u32 { fn size_hint(&self) -> u32 {
@@ -70,15 +69,18 @@ impl Scorer for EmptyScorer {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::docset::TERMINATED;
use crate::query::EmptyScorer; use crate::query::EmptyScorer;
use crate::DocSet; use crate::DocSet;
#[test] #[test]
fn test_empty_scorer() { fn test_empty_scorer() {
let mut empty_scorer = EmptyScorer; let mut empty_scorer = EmptyScorer;
assert_eq!(empty_scorer.doc(), TERMINATED); assert!(!empty_scorer.advance());
assert_eq!(empty_scorer.advance(), TERMINATED); }
assert_eq!(empty_scorer.doc(), TERMINATED);
#[test]
#[should_panic]
fn test_empty_scorer_panic_on_doc_call() {
EmptyScorer.doc();
} }
} }

View File

@@ -1,37 +1,41 @@
use crate::docset::{DocSet, TERMINATED}; use crate::docset::{DocSet, SkipResult};
use crate::query::Scorer; use crate::query::Scorer;
use crate::DocId; use crate::DocId;
use crate::Score; use crate::Score;
#[derive(Clone, Copy, Debug)]
enum State {
ExcludeOne(DocId),
Finished,
}
/// Filters a given `DocSet` by removing the docs from a given `DocSet`. /// Filters a given `DocSet` by removing the docs from a given `DocSet`.
/// ///
/// The excluding docset has no impact on scoring. /// The excluding docset has no impact on scoring.
pub struct Exclude<TDocSet, TDocSetExclude> { pub struct Exclude<TDocSet, TDocSetExclude> {
underlying_docset: TDocSet, underlying_docset: TDocSet,
excluding_docset: TDocSetExclude, excluding_docset: TDocSetExclude,
excluding_state: State,
} }
impl<TDocSet, TDocSetExclude> Exclude<TDocSet, TDocSetExclude> impl<TDocSet, TDocSetExclude> Exclude<TDocSet, TDocSetExclude>
where where
TDocSet: DocSet,
TDocSetExclude: DocSet, TDocSetExclude: DocSet,
{ {
/// Creates a new `ExcludeScorer` /// Creates a new `ExcludeScorer`
pub fn new( pub fn new(
mut underlying_docset: TDocSet, underlying_docset: TDocSet,
mut excluding_docset: TDocSetExclude, mut excluding_docset: TDocSetExclude,
) -> Exclude<TDocSet, TDocSetExclude> { ) -> Exclude<TDocSet, TDocSetExclude> {
while underlying_docset.doc() != TERMINATED { let state = if excluding_docset.advance() {
let target = underlying_docset.doc(); State::ExcludeOne(excluding_docset.doc())
if excluding_docset.seek(target) != target { } else {
// this document is not excluded. State::Finished
break; };
}
underlying_docset.advance();
}
Exclude { Exclude {
underlying_docset, underlying_docset,
excluding_docset, excluding_docset,
excluding_state: state,
} }
} }
} }
@@ -47,7 +51,28 @@ where
/// increasing `doc`. /// increasing `doc`.
fn accept(&mut self) -> bool { fn accept(&mut self) -> bool {
let doc = self.underlying_docset.doc(); let doc = self.underlying_docset.doc();
self.excluding_docset.seek(doc) != doc match self.excluding_state {
State::ExcludeOne(excluded_doc) => {
if doc == excluded_doc {
return false;
}
if excluded_doc > doc {
return true;
}
match self.excluding_docset.skip_next(doc) {
SkipResult::OverStep => {
self.excluding_state = State::ExcludeOne(self.excluding_docset.doc());
true
}
SkipResult::End => {
self.excluding_state = State::Finished;
true
}
SkipResult::Reached => false,
}
}
State::Finished => true,
}
} }
} }
@@ -56,24 +81,27 @@ where
TDocSet: DocSet, TDocSet: DocSet,
TDocSetExclude: DocSet, TDocSetExclude: DocSet,
{ {
fn advance(&mut self) -> DocId { fn advance(&mut self) -> bool {
while self.underlying_docset.advance() != TERMINATED { while self.underlying_docset.advance() {
if self.accept() { if self.accept() {
return self.doc(); return true;
} }
} }
TERMINATED false
} }
fn seek(&mut self, target: DocId) -> DocId { fn skip_next(&mut self, target: DocId) -> SkipResult {
let underlying_seek_result = self.underlying_docset.seek(target); let underlying_skip_result = self.underlying_docset.skip_next(target);
if underlying_seek_result == TERMINATED { if underlying_skip_result == SkipResult::End {
return TERMINATED; return SkipResult::End;
} }
if self.accept() { if self.accept() {
return underlying_seek_result; underlying_skip_result
} else if self.advance() {
SkipResult::OverStep
} else {
SkipResult::End
} }
self.advance()
} }
fn doc(&self) -> DocId { fn doc(&self) -> DocId {
@@ -113,9 +141,8 @@ mod tests {
VecDocSet::from(vec![1, 2, 3, 10, 16, 24]), VecDocSet::from(vec![1, 2, 3, 10, 16, 24]),
); );
let mut els = vec![]; let mut els = vec![];
while exclude_scorer.doc() != TERMINATED { while exclude_scorer.advance() {
els.push(exclude_scorer.doc()); els.push(exclude_scorer.doc());
exclude_scorer.advance();
} }
assert_eq!(els, vec![5, 8, 15]); assert_eq!(els, vec![5, 8, 15]);
} }

View File

@@ -1,5 +1,4 @@
use crate::{DocId, TantivyError}; use crate::{DocId, TantivyError};
use serde::Serialize;
pub(crate) fn does_not_match(doc: DocId) -> TantivyError { pub(crate) fn does_not_match(doc: DocId) -> TantivyError {
TantivyError::InvalidArgument(format!("Document #({}) does not match", doc)) TantivyError::InvalidArgument(format!("Document #({}) does not match", doc))

View File

@@ -1,41 +1,16 @@
use crate::error::TantivyError::InvalidArgument;
use crate::query::{AutomatonWeight, Query, Weight}; use crate::query::{AutomatonWeight, Query, Weight};
use crate::schema::Term; use crate::schema::Term;
use crate::Result;
use crate::Searcher; use crate::Searcher;
use crate::TantivyError::InvalidArgument; use levenshtein_automata::{LevenshteinAutomatonBuilder, DFA};
use levenshtein_automata::{Distance, LevenshteinAutomatonBuilder, DFA};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use std::collections::HashMap; use std::collections::HashMap;
use std::ops::Range; use std::ops::Range;
use tantivy_fst::Automaton;
pub(crate) struct DFAWrapper(pub DFA);
impl Automaton for DFAWrapper {
type State = u32;
fn start(&self) -> Self::State {
self.0.initial_state()
}
fn is_match(&self, state: &Self::State) -> bool {
match self.0.distance(*state) {
Distance::Exact(_) => true,
Distance::AtLeast(_) => false,
}
}
fn can_match(&self, state: &u32) -> bool {
*state != levenshtein_automata::SINK_STATE
}
fn accept(&self, state: &Self::State, byte: u8) -> Self::State {
self.0.transition(*state, byte)
}
}
/// A range of Levenshtein distances that we will build DFAs for our terms /// A range of Levenshtein distances that we will build DFAs for our terms
/// The computation is exponential, so best keep it to low single digits /// The computation is exponential, so best keep it to low single digits
const VALID_LEVENSHTEIN_DISTANCE_RANGE: Range<u8> = 0..3; const VALID_LEVENSHTEIN_DISTANCE_RANGE: Range<u8> = (0..3);
static LEV_BUILDER: Lazy<HashMap<(u8, bool), LevenshteinAutomatonBuilder>> = Lazy::new(|| { static LEV_BUILDER: Lazy<HashMap<(u8, bool), LevenshteinAutomatonBuilder>> = Lazy::new(|| {
let mut lev_builder_cache = HashMap::new(); let mut lev_builder_cache = HashMap::new();
@@ -56,9 +31,9 @@ static LEV_BUILDER: Lazy<HashMap<(u8, bool), LevenshteinAutomatonBuilder>> = Laz
/// use tantivy::collector::{Count, TopDocs}; /// use tantivy::collector::{Count, TopDocs};
/// use tantivy::query::FuzzyTermQuery; /// use tantivy::query::FuzzyTermQuery;
/// use tantivy::schema::{Schema, TEXT}; /// use tantivy::schema::{Schema, TEXT};
/// use tantivy::{doc, Index, Term}; /// use tantivy::{doc, Index, Result, Term};
/// ///
/// fn example() -> tantivy::Result<()> { /// fn example() -> Result<()> {
/// let mut schema_builder = Schema::builder(); /// let mut schema_builder = Schema::builder();
/// let title = schema_builder.add_text_field("title", TEXT); /// let title = schema_builder.add_text_field("title", TEXT);
/// let schema = schema_builder.build(); /// let schema = schema_builder.build();
@@ -117,7 +92,7 @@ impl FuzzyTermQuery {
} }
} }
/// Creates a new Fuzzy Query of the Term prefix /// Creates a new Fuzzy Query that treats transpositions as cost one rather than two
pub fn new_prefix(term: Term, distance: u8, transposition_cost_one: bool) -> FuzzyTermQuery { pub fn new_prefix(term: Term, distance: u8, transposition_cost_one: bool) -> FuzzyTermQuery {
FuzzyTermQuery { FuzzyTermQuery {
term, term,
@@ -127,20 +102,13 @@ impl FuzzyTermQuery {
} }
} }
fn specialized_weight(&self) -> crate::Result<AutomatonWeight<DFAWrapper>> { fn specialized_weight(&self) -> Result<AutomatonWeight<DFA>> {
// LEV_BUILDER is a HashMap, whose `get` method returns an Option // LEV_BUILDER is a HashMap, whose `get` method returns an Option
match LEV_BUILDER.get(&(self.distance, false)) { match LEV_BUILDER.get(&(self.distance, false)) {
// Unwrap the option and build the Ok(AutomatonWeight) // Unwrap the option and build the Ok(AutomatonWeight)
Some(automaton_builder) => { Some(automaton_builder) => {
let automaton = if self.prefix { let automaton = automaton_builder.build_dfa(self.term.text());
automaton_builder.build_prefix_dfa(self.term.text()) Ok(AutomatonWeight::new(self.term.field(), automaton))
} else {
automaton_builder.build_dfa(self.term.text())
};
Ok(AutomatonWeight::new(
self.term.field(),
DFAWrapper(automaton),
))
} }
None => Err(InvalidArgument(format!( None => Err(InvalidArgument(format!(
"Levenshtein distance of {} is not allowed. Choose a value in the {:?} range", "Levenshtein distance of {} is not allowed. Choose a value in the {:?} range",
@@ -151,11 +119,7 @@ impl FuzzyTermQuery {
} }
impl Query for FuzzyTermQuery { impl Query for FuzzyTermQuery {
fn weight( fn weight(&self, _searcher: &Searcher, _scoring_enabled: bool) -> Result<Box<dyn Weight>> {
&self,
_searcher: &Searcher,
_scoring_enabled: bool,
) -> crate::Result<Box<dyn Weight>> {
Ok(Box::new(self.specialized_weight()?)) Ok(Box::new(self.specialized_weight()?))
} }
} }
@@ -188,8 +152,6 @@ mod test {
} }
let reader = index.reader().unwrap(); let reader = index.reader().unwrap();
let searcher = reader.searcher(); let searcher = reader.searcher();
// passes because Levenshtein distance is 1 (substitute 'o' with 'a')
{ {
let term = Term::from_field_text(country_field, "japon"); let term = Term::from_field_text(country_field, "japon");
@@ -201,29 +163,5 @@ mod test {
let (score, _) = top_docs[0]; let (score, _) = top_docs[0];
assert_nearly_equals(1f32, score); assert_nearly_equals(1f32, score);
} }
// fails because non-prefix Levenshtein distance is more than 1 (add 'a' and 'n')
{
let term = Term::from_field_text(country_field, "jap");
let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
let top_docs = searcher
.search(&fuzzy_query, &TopDocs::with_limit(2))
.unwrap();
assert_eq!(top_docs.len(), 0, "Expected no document");
}
// passes because prefix Levenshtein distance is 0
{
let term = Term::from_field_text(country_field, "jap");
let fuzzy_query = FuzzyTermQuery::new_prefix(term, 1, true);
let top_docs = searcher
.search(&fuzzy_query, &TopDocs::with_limit(2))
.unwrap();
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
let (score, _) = top_docs[0];
assert_nearly_equals(1f32, score);
}
} }
} }

View File

@@ -1,4 +1,4 @@
use crate::docset::{DocSet, TERMINATED}; use crate::docset::{DocSet, SkipResult};
use crate::query::term_query::TermScorer; use crate::query::term_query::TermScorer;
use crate::query::EmptyScorer; use crate::query::EmptyScorer;
use crate::query::Scorer; use crate::query::Scorer;
@@ -20,14 +20,12 @@ pub fn intersect_scorers(mut scorers: Vec<Box<dyn Scorer>>) -> Box<dyn Scorer> {
if scorers.len() == 1 { if scorers.len() == 1 {
return scorers.pop().unwrap(); return scorers.pop().unwrap();
} }
scorers.sort_by_key(|scorer| scorer.size_hint());
let doc = go_to_first_doc(&mut scorers[..]);
if doc == TERMINATED {
return Box::new(EmptyScorer);
}
// We know that we have at least 2 elements. // We know that we have at least 2 elements.
let left = scorers.remove(0); let num_docsets = scorers.len();
let right = scorers.remove(0); scorers.sort_by(|left, right| right.size_hint().cmp(&left.size_hint()));
let left = scorers.pop().unwrap();
let right = scorers.pop().unwrap();
scorers.reverse();
let all_term_scorers = [&left, &right] let all_term_scorers = [&left, &right]
.iter() .iter()
.all(|&scorer| scorer.is::<TermScorer>()); .all(|&scorer| scorer.is::<TermScorer>());
@@ -36,12 +34,14 @@ pub fn intersect_scorers(mut scorers: Vec<Box<dyn Scorer>>) -> Box<dyn Scorer> {
left: *(left.downcast::<TermScorer>().map_err(|_| ()).unwrap()), left: *(left.downcast::<TermScorer>().map_err(|_| ()).unwrap()),
right: *(right.downcast::<TermScorer>().map_err(|_| ()).unwrap()), right: *(right.downcast::<TermScorer>().map_err(|_| ()).unwrap()),
others: scorers, others: scorers,
num_docsets,
}); });
} }
Box::new(Intersection { Box::new(Intersection {
left, left,
right, right,
others: scorers, others: scorers,
num_docsets,
}) })
} }
@@ -50,34 +50,22 @@ pub struct Intersection<TDocSet: DocSet, TOtherDocSet: DocSet = Box<dyn Scorer>>
left: TDocSet, left: TDocSet,
right: TDocSet, right: TDocSet,
others: Vec<TOtherDocSet>, others: Vec<TOtherDocSet>,
} num_docsets: usize,
fn go_to_first_doc<TDocSet: DocSet>(docsets: &mut [TDocSet]) -> DocId {
let mut candidate = 0;
'outer: loop {
for docset in docsets.iter_mut() {
let seek_doc = docset.seek(candidate);
if seek_doc > candidate {
candidate = docset.doc();
continue 'outer;
}
}
return candidate;
}
} }
impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> { impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
pub(crate) fn new(mut docsets: Vec<TDocSet>) -> Intersection<TDocSet, TDocSet> { pub(crate) fn new(mut docsets: Vec<TDocSet>) -> Intersection<TDocSet, TDocSet> {
let num_docsets = docsets.len(); let num_docsets = docsets.len();
assert!(num_docsets >= 2); assert!(num_docsets >= 2);
docsets.sort_by_key(|docset| docset.size_hint()); docsets.sort_by(|left, right| right.size_hint().cmp(&left.size_hint()));
go_to_first_doc(&mut docsets); let left = docsets.pop().unwrap();
let left = docsets.remove(0); let right = docsets.pop().unwrap();
let right = docsets.remove(0); docsets.reverse();
Intersection { Intersection {
left, left,
right, right,
others: docsets, others: docsets,
num_docsets,
} }
} }
} }
@@ -92,44 +80,128 @@ impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
} }
} }
impl<TDocSet: DocSet, TOtherDocSet: DocSet> Intersection<TDocSet, TOtherDocSet> {
pub(crate) fn docset_mut(&mut self, ord: usize) -> &mut dyn DocSet {
match ord {
0 => &mut self.left,
1 => &mut self.right,
n => &mut self.others[n - 2],
}
}
}
impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOtherDocSet> { impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOtherDocSet> {
fn advance(&mut self) -> DocId { fn advance(&mut self) -> bool {
let (left, right) = (&mut self.left, &mut self.right); let (left, right) = (&mut self.left, &mut self.right);
let mut candidate = left.advance();
if !left.advance() {
return false;
}
let mut candidate = left.doc();
let mut other_candidate_ord: usize = usize::max_value();
'outer: loop { 'outer: loop {
// In the first part we look for a document in the intersection // In the first part we look for a document in the intersection
// of the two rarest `DocSet` in the intersection. // of the two rarest `DocSet` in the intersection.
loop { loop {
let right_doc = right.seek(candidate); match right.skip_next(candidate) {
candidate = left.seek(right_doc); SkipResult::Reached => {
if candidate == right_doc { break;
break; }
SkipResult::OverStep => {
candidate = right.doc();
other_candidate_ord = usize::max_value();
}
SkipResult::End => {
return false;
}
}
match left.skip_next(candidate) {
SkipResult::Reached => {
break;
}
SkipResult::OverStep => {
candidate = left.doc();
other_candidate_ord = usize::max_value();
}
SkipResult::End => {
return false;
}
} }
} }
debug_assert_eq!(left.doc(), right.doc());
// test the remaining scorers; // test the remaining scorers;
for docset in self.others.iter_mut() { for (ord, docset) in self.others.iter_mut().enumerate() {
let seek_doc = docset.seek(candidate); if ord == other_candidate_ord {
if seek_doc > candidate { continue;
candidate = left.seek(seek_doc); }
continue 'outer; // `candidate_ord` is already at the
// right position.
//
// Calling `skip_next` would advance this docset
// and miss it.
match docset.skip_next(candidate) {
SkipResult::Reached => {}
SkipResult::OverStep => {
// this is not in the intersection,
// let's update our candidate.
candidate = docset.doc();
match left.skip_next(candidate) {
SkipResult::Reached => {
other_candidate_ord = ord;
}
SkipResult::OverStep => {
candidate = left.doc();
other_candidate_ord = usize::max_value();
}
SkipResult::End => {
return false;
}
}
continue 'outer;
}
SkipResult::End => {
return false;
}
} }
} }
return true;
return candidate;
} }
} }
fn seek(&mut self, target: DocId) -> DocId { fn skip_next(&mut self, target: DocId) -> SkipResult {
self.left.seek(target); // We optimize skipping by skipping every single member
let mut docsets: Vec<&mut dyn DocSet> = vec![&mut self.left, &mut self.right]; // of the intersection to target.
for docset in &mut self.others { let mut current_target: DocId = target;
docsets.push(docset); let mut current_ord = self.num_docsets;
'outer: loop {
for ord in 0..self.num_docsets {
let docset = self.docset_mut(ord);
if ord == current_ord {
continue;
}
match docset.skip_next(current_target) {
SkipResult::End => {
return SkipResult::End;
}
SkipResult::OverStep => {
// update the target
// for the remaining members of the intersection.
current_target = docset.doc();
current_ord = ord;
continue 'outer;
}
SkipResult::Reached => {}
}
}
if target == current_target {
return SkipResult::Reached;
} else {
assert!(current_target > target);
return SkipResult::OverStep;
}
} }
go_to_first_doc(&mut docsets[..])
} }
fn doc(&self) -> DocId { fn doc(&self) -> DocId {
@@ -156,7 +228,7 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::Intersection; use super::Intersection;
use crate::docset::{DocSet, TERMINATED}; use crate::docset::{DocSet, SkipResult};
use crate::postings::tests::test_skip_against_unoptimized; use crate::postings::tests::test_skip_against_unoptimized;
use crate::query::VecDocSet; use crate::query::VecDocSet;
@@ -166,18 +238,20 @@ mod tests {
let left = VecDocSet::from(vec![1, 3, 9]); let left = VecDocSet::from(vec![1, 3, 9]);
let right = VecDocSet::from(vec![3, 4, 9, 18]); let right = VecDocSet::from(vec![3, 4, 9, 18]);
let mut intersection = Intersection::new(vec![left, right]); let mut intersection = Intersection::new(vec![left, right]);
assert!(intersection.advance());
assert_eq!(intersection.doc(), 3); assert_eq!(intersection.doc(), 3);
assert_eq!(intersection.advance(), 9); assert!(intersection.advance());
assert_eq!(intersection.doc(), 9); assert_eq!(intersection.doc(), 9);
assert_eq!(intersection.advance(), TERMINATED); assert!(!intersection.advance());
} }
{ {
let a = VecDocSet::from(vec![1, 3, 9]); let a = VecDocSet::from(vec![1, 3, 9]);
let b = VecDocSet::from(vec![3, 4, 9, 18]); let b = VecDocSet::from(vec![3, 4, 9, 18]);
let c = VecDocSet::from(vec![1, 5, 9, 111]); let c = VecDocSet::from(vec![1, 5, 9, 111]);
let mut intersection = Intersection::new(vec![a, b, c]); let mut intersection = Intersection::new(vec![a, b, c]);
assert!(intersection.advance());
assert_eq!(intersection.doc(), 9); assert_eq!(intersection.doc(), 9);
assert_eq!(intersection.advance(), TERMINATED); assert!(!intersection.advance());
} }
} }
@@ -186,8 +260,8 @@ mod tests {
let left = VecDocSet::from(vec![0]); let left = VecDocSet::from(vec![0]);
let right = VecDocSet::from(vec![0]); let right = VecDocSet::from(vec![0]);
let mut intersection = Intersection::new(vec![left, right]); let mut intersection = Intersection::new(vec![left, right]);
assert!(intersection.advance());
assert_eq!(intersection.doc(), 0); assert_eq!(intersection.doc(), 0);
assert_eq!(intersection.advance(), TERMINATED);
} }
#[test] #[test]
@@ -195,7 +269,7 @@ mod tests {
let left = VecDocSet::from(vec![0, 1, 2, 4]); let left = VecDocSet::from(vec![0, 1, 2, 4]);
let right = VecDocSet::from(vec![2, 5]); let right = VecDocSet::from(vec![2, 5]);
let mut intersection = Intersection::new(vec![left, right]); let mut intersection = Intersection::new(vec![left, right]);
assert_eq!(intersection.seek(2), 2); assert_eq!(intersection.skip_next(2), SkipResult::Reached);
assert_eq!(intersection.doc(), 2); assert_eq!(intersection.doc(), 2);
} }
@@ -238,7 +312,7 @@ mod tests {
let a = VecDocSet::from(vec![1, 3]); let a = VecDocSet::from(vec![1, 3]);
let b = VecDocSet::from(vec![1, 4]); let b = VecDocSet::from(vec![1, 4]);
let c = VecDocSet::from(vec![3, 9]); let c = VecDocSet::from(vec![3, 9]);
let intersection = Intersection::new(vec![a, b, c]); let mut intersection = Intersection::new(vec![a, b, c]);
assert_eq!(intersection.doc(), TERMINATED); assert!(!intersection.advance());
} }
} }

View File

@@ -1,11 +1,12 @@
/*! Query Module */ /*!
Query
*/
mod all_query; mod all_query;
mod automaton_weight; mod automaton_weight;
mod bitset; mod bitset;
mod bm25; mod bm25;
mod boolean_query; mod boolean_query;
mod boost_query;
mod empty_query; mod empty_query;
mod exclude; mod exclude;
mod explanation; mod explanation;
@@ -36,12 +37,9 @@ pub use self::all_query::{AllQuery, AllScorer, AllWeight};
pub use self::automaton_weight::AutomatonWeight; pub use self::automaton_weight::AutomatonWeight;
pub use self::bitset::BitSetDocSet; pub use self::bitset::BitSetDocSet;
pub use self::boolean_query::BooleanQuery; pub use self::boolean_query::BooleanQuery;
pub use self::boost_query::BoostQuery;
pub use self::empty_query::{EmptyQuery, EmptyScorer, EmptyWeight}; pub use self::empty_query::{EmptyQuery, EmptyScorer, EmptyWeight};
pub use self::exclude::Exclude; pub use self::exclude::Exclude;
pub use self::explanation::Explanation; pub use self::explanation::Explanation;
#[cfg(test)]
pub(crate) use self::fuzzy_query::DFAWrapper;
pub use self::fuzzy_query::FuzzyTermQuery; pub use self::fuzzy_query::FuzzyTermQuery;
pub use self::intersection::intersect_scorers; pub use self::intersection::intersect_scorers;
pub use self::phrase_query::PhraseQuery; pub use self::phrase_query::PhraseQuery;

View File

@@ -7,17 +7,18 @@ pub use self::phrase_scorer::PhraseScorer;
pub use self::phrase_weight::PhraseWeight; pub use self::phrase_weight::PhraseWeight;
#[cfg(test)] #[cfg(test)]
pub mod tests { mod tests {
use super::*; use super::*;
use crate::collector::tests::{TEST_COLLECTOR_WITHOUT_SCORE, TEST_COLLECTOR_WITH_SCORE}; use crate::collector::tests::{TEST_COLLECTOR_WITHOUT_SCORE, TEST_COLLECTOR_WITH_SCORE};
use crate::core::Index; use crate::core::Index;
use crate::error::TantivyError;
use crate::schema::{Schema, Term, TEXT}; use crate::schema::{Schema, Term, TEXT};
use crate::tests::assert_nearly_equals; use crate::tests::assert_nearly_equals;
use crate::DocAddress;
use crate::DocId; use crate::DocId;
use crate::{DocAddress, DocSet};
pub fn create_index(texts: &[&'static str]) -> Index { fn create_index(texts: &[&'static str]) -> Index {
let mut schema_builder = Schema::builder(); let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT); let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build(); let schema = schema_builder.build();
@@ -60,8 +61,8 @@ pub mod tests {
.map(|docaddr| docaddr.1) .map(|docaddr| docaddr.1)
.collect::<Vec<_>>() .collect::<Vec<_>>()
}; };
assert_eq!(test_query(vec!["a", "b"]), vec![1, 2, 3, 4]);
assert_eq!(test_query(vec!["a", "b", "c"]), vec![2, 4]); assert_eq!(test_query(vec!["a", "b", "c"]), vec![2, 4]);
assert_eq!(test_query(vec!["a", "b"]), vec![1, 2, 3, 4]);
assert_eq!(test_query(vec!["b", "b"]), vec![0, 1]); assert_eq!(test_query(vec!["b", "b"]), vec![0, 1]);
assert!(test_query(vec!["g", "ewrwer"]).is_empty()); assert!(test_query(vec!["g", "ewrwer"]).is_empty());
assert!(test_query(vec!["g", "a"]).is_empty()); assert!(test_query(vec!["g", "a"]).is_empty());
@@ -101,6 +102,30 @@ pub mod tests {
assert!(test_query(vec!["g", "a"]).is_empty()); assert!(test_query(vec!["g", "a"]).is_empty());
} }
#[test]
pub fn test_phrase_count() {
let index = create_index(&["a c", "a a b d a b c", " a b"]);
let schema = index.schema();
let text_field = schema.get_field("text").unwrap();
let searcher = index.reader().unwrap().searcher();
let phrase_query = PhraseQuery::new(vec![
Term::from_field_text(text_field, "a"),
Term::from_field_text(text_field, "b"),
]);
let phrase_weight = phrase_query.phrase_weight(&searcher, true).unwrap();
let mut phrase_scorer = phrase_weight
.phrase_scorer(searcher.segment_reader(0u32))
.unwrap()
.unwrap();
assert!(phrase_scorer.advance());
assert_eq!(phrase_scorer.doc(), 1);
assert_eq!(phrase_scorer.phrase_count(), 2);
assert!(phrase_scorer.advance());
assert_eq!(phrase_scorer.doc(), 2);
assert_eq!(phrase_scorer.phrase_count(), 1);
assert!(!phrase_scorer.advance());
}
#[test] #[test]
pub fn test_phrase_query_no_positions() { pub fn test_phrase_query_no_positions() {
let mut schema_builder = Schema::builder(); let mut schema_builder = Schema::builder();
@@ -126,16 +151,21 @@ pub mod tests {
Term::from_field_text(text_field, "a"), Term::from_field_text(text_field, "a"),
Term::from_field_text(text_field, "b"), Term::from_field_text(text_field, "b"),
]); ]);
match searcher
let search_result = searcher
.search(&phrase_query, &TEST_COLLECTOR_WITH_SCORE) .search(&phrase_query, &TEST_COLLECTOR_WITH_SCORE)
.map(|_| ()); .map(|_| ())
assert!(matches!( .unwrap_err()
search_result, {
Err(crate::TantivyError::SchemaError(msg)) TantivyError::SchemaError(ref msg) => {
if msg == "Applied phrase query on field \"text\", which does not have positions \ assert_eq!(
indexed" "Applied phrase query on field \"text\", which does not have positions indexed",
)); msg.as_str()
);
}
_ => {
panic!("Should have returned an error");
}
}
} }
#[test] #[test]

View File

@@ -1,10 +1,12 @@
use super::PhraseWeight; use super::PhraseWeight;
use crate::core::searcher::Searcher; use crate::core::searcher::Searcher;
use crate::error::TantivyError;
use crate::query::bm25::BM25Weight; use crate::query::bm25::BM25Weight;
use crate::query::Query; use crate::query::Query;
use crate::query::Weight; use crate::query::Weight;
use crate::schema::IndexRecordOption; use crate::schema::IndexRecordOption;
use crate::schema::{Field, Term}; use crate::schema::{Field, Term};
use crate::Result;
use std::collections::BTreeSet; use std::collections::BTreeSet;
/// `PhraseQuery` matches a specific sequence of words. /// `PhraseQuery` matches a specific sequence of words.
@@ -79,7 +81,7 @@ impl PhraseQuery {
&self, &self,
searcher: &Searcher, searcher: &Searcher,
scoring_enabled: bool, scoring_enabled: bool,
) -> crate::Result<PhraseWeight> { ) -> Result<PhraseWeight> {
let schema = searcher.schema(); let schema = searcher.schema();
let field_entry = schema.get_field_entry(self.field); let field_entry = schema.get_field_entry(self.field);
let has_positions = field_entry let has_positions = field_entry
@@ -89,7 +91,7 @@ impl PhraseQuery {
.unwrap_or(false); .unwrap_or(false);
if !has_positions { if !has_positions {
let field_name = field_entry.name(); let field_name = field_entry.name();
return Err(crate::TantivyError::SchemaError(format!( return Err(TantivyError::SchemaError(format!(
"Applied phrase query on field {:?}, which does not have positions indexed", "Applied phrase query on field {:?}, which does not have positions indexed",
field_name field_name
))); )));
@@ -108,7 +110,7 @@ impl Query for PhraseQuery {
/// Create the weight associated to a query. /// Create the weight associated to a query.
/// ///
/// See [`Weight`](./trait.Weight.html). /// See [`Weight`](./trait.Weight.html).
fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> crate::Result<Box<dyn Weight>> { fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> Result<Box<dyn Weight>> {
let phrase_weight = self.phrase_weight(searcher, scoring_enabled)?; let phrase_weight = self.phrase_weight(searcher, scoring_enabled)?;
Ok(Box::new(phrase_weight)) Ok(Box::new(phrase_weight))
} }

View File

@@ -1,4 +1,4 @@
use crate::docset::{DocSet, TERMINATED}; use crate::docset::{DocSet, SkipResult};
use crate::fieldnorm::FieldNormReader; use crate::fieldnorm::FieldNormReader;
use crate::postings::Postings; use crate::postings::Postings;
use crate::query::bm25::BM25Weight; use crate::query::bm25::BM25Weight;
@@ -25,12 +25,12 @@ impl<TPostings: Postings> PostingsWithOffset<TPostings> {
} }
impl<TPostings: Postings> DocSet for PostingsWithOffset<TPostings> { impl<TPostings: Postings> DocSet for PostingsWithOffset<TPostings> {
fn advance(&mut self) -> DocId { fn advance(&mut self) -> bool {
self.postings.advance() self.postings.advance()
} }
fn seek(&mut self, target: DocId) -> DocId { fn skip_next(&mut self, target: DocId) -> SkipResult {
self.postings.seek(target) self.postings.skip_next(target)
} }
fn doc(&self) -> DocId { fn doc(&self) -> DocId {
@@ -149,7 +149,7 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
PostingsWithOffset::new(postings, (max_offset - offset) as u32) PostingsWithOffset::new(postings, (max_offset - offset) as u32)
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let mut scorer = PhraseScorer { PhraseScorer {
intersection_docset: Intersection::new(postings_with_offsets), intersection_docset: Intersection::new(postings_with_offsets),
num_terms: num_docsets, num_terms: num_docsets,
left: Vec::with_capacity(100), left: Vec::with_capacity(100),
@@ -158,11 +158,7 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
similarity_weight, similarity_weight,
fieldnorm_reader, fieldnorm_reader,
score_needed, score_needed,
};
if scorer.doc() != TERMINATED && !scorer.phrase_match() {
scorer.advance();
} }
scorer
} }
pub fn phrase_count(&self) -> u32 { pub fn phrase_count(&self) -> u32 {
@@ -229,21 +225,31 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
} }
impl<TPostings: Postings> DocSet for PhraseScorer<TPostings> { impl<TPostings: Postings> DocSet for PhraseScorer<TPostings> {
fn advance(&mut self) -> DocId { fn advance(&mut self) -> bool {
loop { while self.intersection_docset.advance() {
let doc = self.intersection_docset.advance(); if self.phrase_match() {
if doc == TERMINATED || self.phrase_match() { return true;
return doc;
} }
} }
false
} }
fn seek(&mut self, target: DocId) -> DocId { fn skip_next(&mut self, target: DocId) -> SkipResult {
let doc = self.intersection_docset.seek(target); if self.intersection_docset.skip_next(target) == SkipResult::End {
if doc == TERMINATED || self.phrase_match() { return SkipResult::End;
return doc; }
if self.phrase_match() {
if self.doc() == target {
return SkipResult::Reached;
} else {
return SkipResult::OverStep;
}
}
if self.advance() {
SkipResult::OverStep
} else {
SkipResult::End
} }
self.advance()
} }
fn doc(&self) -> DocId { fn doc(&self) -> DocId {

View File

@@ -9,8 +9,8 @@ use crate::query::Weight;
use crate::query::{EmptyScorer, Explanation}; use crate::query::{EmptyScorer, Explanation};
use crate::schema::IndexRecordOption; use crate::schema::IndexRecordOption;
use crate::schema::Term; use crate::schema::Term;
use crate::Result;
use crate::{DocId, DocSet}; use crate::{DocId, DocSet};
use crate::{Result, SkipResult};
pub struct PhraseWeight { pub struct PhraseWeight {
phrase_terms: Vec<(usize, Term)>, phrase_terms: Vec<(usize, Term)>,
@@ -37,12 +37,11 @@ impl PhraseWeight {
reader.get_fieldnorms_reader(field) reader.get_fieldnorms_reader(field)
} }
fn phrase_scorer( pub fn phrase_scorer(
&self, &self,
reader: &SegmentReader, reader: &SegmentReader,
boost: f32,
) -> Result<Option<PhraseScorer<SegmentPostings>>> { ) -> Result<Option<PhraseScorer<SegmentPostings>>> {
let similarity_weight = self.similarity_weight.boost_by(boost); let similarity_weight = self.similarity_weight.clone();
let fieldnorm_reader = self.fieldnorm_reader(reader); let fieldnorm_reader = self.fieldnorm_reader(reader);
if reader.has_deletes() { if reader.has_deletes() {
let mut term_postings_list = Vec::new(); let mut term_postings_list = Vec::new();
@@ -85,8 +84,8 @@ impl PhraseWeight {
} }
impl Weight for PhraseWeight { impl Weight for PhraseWeight {
fn scorer(&self, reader: &SegmentReader, boost: f32) -> Result<Box<dyn Scorer>> { fn scorer(&self, reader: &SegmentReader) -> Result<Box<dyn Scorer>> {
if let Some(scorer) = self.phrase_scorer(reader, boost)? { if let Some(scorer) = self.phrase_scorer(reader)? {
Ok(Box::new(scorer)) Ok(Box::new(scorer))
} else { } else {
Ok(Box::new(EmptyScorer)) Ok(Box::new(EmptyScorer))
@@ -94,12 +93,12 @@ impl Weight for PhraseWeight {
} }
fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation> { fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation> {
let scorer_opt = self.phrase_scorer(reader, 1.0f32)?; let scorer_opt = self.phrase_scorer(reader)?;
if scorer_opt.is_none() { if scorer_opt.is_none() {
return Err(does_not_match(doc)); return Err(does_not_match(doc));
} }
let mut scorer = scorer_opt.unwrap(); let mut scorer = scorer_opt.unwrap();
if scorer.seek(doc) != doc { if scorer.skip_next(doc) != SkipResult::Reached {
return Err(does_not_match(doc)); return Err(does_not_match(doc));
} }
let fieldnorm_reader = self.fieldnorm_reader(reader); let fieldnorm_reader = self.fieldnorm_reader(reader);
@@ -110,34 +109,3 @@ impl Weight for PhraseWeight {
Ok(explanation) Ok(explanation)
} }
} }
#[cfg(test)]
mod tests {
use super::super::tests::create_index;
use crate::docset::TERMINATED;
use crate::query::PhraseQuery;
use crate::{DocSet, Term};
#[test]
pub fn test_phrase_count() {
let index = create_index(&["a c", "a a b d a b c", " a b"]);
let schema = index.schema();
let text_field = schema.get_field("text").unwrap();
let searcher = index.reader().unwrap().searcher();
let phrase_query = PhraseQuery::new(vec![
Term::from_field_text(text_field, "a"),
Term::from_field_text(text_field, "b"),
]);
let phrase_weight = phrase_query.phrase_weight(&searcher, true).unwrap();
let mut phrase_scorer = phrase_weight
.phrase_scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap()
.unwrap();
assert_eq!(phrase_scorer.doc(), 1);
assert_eq!(phrase_scorer.phrase_count(), 2);
assert_eq!(phrase_scorer.advance(), 2);
assert_eq!(phrase_scorer.doc(), 2);
assert_eq!(phrase_scorer.phrase_count(), 1);
assert_eq!(phrase_scorer.advance(), TERMINATED);
}
}

View File

@@ -2,6 +2,7 @@ use super::Weight;
use crate::core::searcher::Searcher; use crate::core::searcher::Searcher;
use crate::query::Explanation; use crate::query::Explanation;
use crate::DocAddress; use crate::DocAddress;
use crate::Result;
use crate::Term; use crate::Term;
use downcast_rs::impl_downcast; use downcast_rs::impl_downcast;
use std::collections::BTreeSet; use std::collections::BTreeSet;
@@ -47,17 +48,17 @@ pub trait Query: QueryClone + downcast_rs::Downcast + fmt::Debug {
/// can increase performances. /// can increase performances.
/// ///
/// See [`Weight`](./trait.Weight.html). /// See [`Weight`](./trait.Weight.html).
fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> crate::Result<Box<dyn Weight>>; fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> Result<Box<dyn Weight>>;
/// Returns an `Explanation` for the score of the document. /// Returns an `Explanation` for the score of the document.
fn explain(&self, searcher: &Searcher, doc_address: DocAddress) -> crate::Result<Explanation> { fn explain(&self, searcher: &Searcher, doc_address: DocAddress) -> Result<Explanation> {
let reader = searcher.segment_reader(doc_address.segment_ord()); let reader = searcher.segment_reader(doc_address.segment_ord());
let weight = self.weight(searcher, true)?; let weight = self.weight(searcher, true)?;
weight.explain(reader, doc_address.doc()) weight.explain(reader, doc_address.doc())
} }
/// Returns the number of documents matching the query. /// Returns the number of documents matching the query.
fn count(&self, searcher: &Searcher) -> crate::Result<usize> { fn count(&self, searcher: &Searcher) -> Result<usize> {
let weight = self.weight(searcher, false)?; let weight = self.weight(searcher, false)?;
let mut result = 0; let mut result = 0;
for reader in searcher.segment_readers() { for reader in searcher.segment_readers() {
@@ -85,11 +86,11 @@ where
} }
impl Query for Box<dyn Query> { impl Query for Box<dyn Query> {
fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> crate::Result<Box<dyn Weight>> { fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> Result<Box<dyn Weight>> {
self.as_ref().weight(searcher, scoring_enabled) self.as_ref().weight(searcher, scoring_enabled)
} }
fn count(&self, searcher: &Searcher) -> crate::Result<usize> { fn count(&self, searcher: &Searcher) -> Result<usize> {
self.as_ref().count(searcher) self.as_ref().count(searcher)
} }

View File

@@ -21,17 +21,6 @@ pub enum LogicalLiteral {
pub enum LogicalAST { pub enum LogicalAST {
Clause(Vec<(Occur, LogicalAST)>), Clause(Vec<(Occur, LogicalAST)>),
Leaf(Box<LogicalLiteral>), Leaf(Box<LogicalLiteral>),
Boost(Box<LogicalAST>, f32),
}
impl LogicalAST {
pub fn boost(self, boost: f32) -> LogicalAST {
if (boost - 1.0f32).abs() < std::f32::EPSILON {
self
} else {
LogicalAST::Boost(Box::new(self), boost)
}
}
} }
fn occur_letter(occur: Occur) -> &'static str { fn occur_letter(occur: Occur) -> &'static str {
@@ -58,7 +47,6 @@ impl fmt::Debug for LogicalAST {
} }
Ok(()) Ok(())
} }
LogicalAST::Boost(ref ast, boost) => write!(formatter, "{:?}^{}", ast, boost),
LogicalAST::Leaf(ref literal) => write!(formatter, "{:?}", literal), LogicalAST::Leaf(ref literal) => write!(formatter, "{:?}", literal),
} }
} }

View File

@@ -1,5 +1,6 @@
use super::logical_ast::*; use super::logical_ast::*;
use crate::core::Index; use crate::core::Index;
use crate::query::AllQuery;
use crate::query::BooleanQuery; use crate::query::BooleanQuery;
use crate::query::EmptyQuery; use crate::query::EmptyQuery;
use crate::query::Occur; use crate::query::Occur;
@@ -7,13 +8,11 @@ use crate::query::PhraseQuery;
use crate::query::Query; use crate::query::Query;
use crate::query::RangeQuery; use crate::query::RangeQuery;
use crate::query::TermQuery; use crate::query::TermQuery;
use crate::query::{AllQuery, BoostQuery};
use crate::schema::{Facet, IndexRecordOption}; use crate::schema::{Facet, IndexRecordOption};
use crate::schema::{Field, Schema}; use crate::schema::{Field, Schema};
use crate::schema::{FieldType, Term}; use crate::schema::{FieldType, Term};
use crate::tokenizer::TokenizerManager; use crate::tokenizer::TokenizerManager;
use std::borrow::Cow; use std::borrow::Cow;
use std::collections::HashMap;
use std::num::{ParseFloatError, ParseIntError}; use std::num::{ParseFloatError, ParseIntError};
use std::ops::Bound; use std::ops::Bound;
use std::str::FromStr; use std::str::FromStr;
@@ -113,9 +112,8 @@ fn trim_ast(logical_ast: LogicalAST) -> Option<LogicalAST> {
/// The language covered by the current parser is extremely simple. /// The language covered by the current parser is extremely simple.
/// ///
/// * simple terms: "e.g.: `Barack Obama` are simply tokenized using /// * simple terms: "e.g.: `Barack Obama` are simply tokenized using
/// tantivy's [`SimpleTokenizer`](tantivy::tokenizer::SimpleTokenizer), hence /// tantivy's `StandardTokenizer`, hence becoming `["barack", "obama"]`.
/// becoming `["barack", "obama"]`. The terms are then searched within /// The terms are then searched within the default terms of the query parser.
/// the default terms of the query parser.
/// ///
/// e.g. If `body` and `title` are default fields, our example terms are /// e.g. If `body` and `title` are default fields, our example terms are
/// `["title:barack", "body:barack", "title:obama", "body:obama"]`. /// `["title:barack", "body:barack", "title:obama", "body:obama"]`.
@@ -146,6 +144,7 @@ fn trim_ast(logical_ast: LogicalAST) -> Option<LogicalAST> {
/// ///
/// * must terms: By prepending a term by a `+`, a term can be made required for the search. /// * must terms: By prepending a term by a `+`, a term can be made required for the search.
/// ///
///
/// * phrase terms: Quoted terms become phrase searches on fields that have positions indexed. /// * phrase terms: Quoted terms become phrase searches on fields that have positions indexed.
/// e.g., `title:"Barack Obama"` will only find documents that have "barack" immediately followed /// e.g., `title:"Barack Obama"` will only find documents that have "barack" immediately followed
/// by "obama". /// by "obama".
@@ -159,30 +158,12 @@ fn trim_ast(logical_ast: LogicalAST) -> Option<LogicalAST> {
/// ///
/// * all docs query: A plain `*` will match all documents in the index. /// * all docs query: A plain `*` will match all documents in the index.
/// ///
/// Parts of the queries can be boosted by appending `^boostfactor`.
/// For instance, `"SRE"^2.0 OR devops^0.4` will boost documents containing `SRE` instead of
/// devops. Negative boosts are not allowed.
///
/// It is also possible to define a boost for a some specific field, at the query parser level.
/// (See [`set_boost(...)`](#method.set_field_boost) ). Typically you may want to boost a title
/// field.
#[derive(Clone)] #[derive(Clone)]
pub struct QueryParser { pub struct QueryParser {
schema: Schema, schema: Schema,
default_fields: Vec<Field>, default_fields: Vec<Field>,
conjunction_by_default: bool, conjunction_by_default: bool,
tokenizer_manager: TokenizerManager, tokenizer_manager: TokenizerManager,
boost: HashMap<Field, f32>,
}
fn all_negative(ast: &LogicalAST) -> bool {
match ast {
LogicalAST::Leaf(_) => false,
LogicalAST::Boost(ref child_ast, _) => all_negative(&*child_ast),
LogicalAST::Clause(children) => children
.iter()
.all(|(ref occur, child)| (*occur == Occur::MustNot) || all_negative(child)),
}
} }
impl QueryParser { impl QueryParser {
@@ -200,7 +181,6 @@ impl QueryParser {
default_fields, default_fields,
tokenizer_manager, tokenizer_manager,
conjunction_by_default: false, conjunction_by_default: false,
boost: Default::default(),
} }
} }
@@ -221,17 +201,6 @@ impl QueryParser {
self.conjunction_by_default = true; self.conjunction_by_default = true;
} }
/// Sets a boost for a specific field.
///
/// The parse query will automatically boost this field.
///
/// If the query defines a query boost through the query language (e.g: `country:France^3.0`),
/// the two boosts (the one defined in the query, and the one defined in the `QueryParser`)
/// are multiplied together.
pub fn set_field_boost(&mut self, field: Field, boost: f32) {
self.boost.insert(field, boost);
}
/// Parse a query /// Parse a query
/// ///
/// Note that `parse_query` returns an error if the input /// Note that `parse_query` returns an error if the input
@@ -264,13 +233,8 @@ impl QueryParser {
&self, &self,
user_input_ast: UserInputAST, user_input_ast: UserInputAST,
) -> Result<LogicalAST, QueryParserError> { ) -> Result<LogicalAST, QueryParserError> {
let ast = self.compute_logical_ast_with_occur(user_input_ast)?; let (occur, ast) = self.compute_logical_ast_with_occur(user_input_ast)?;
if let LogicalAST::Clause(children) = &ast { if occur == Occur::MustNot {
if children.is_empty() {
return Ok(ast);
}
}
if all_negative(&ast) {
return Err(QueryParserError::AllButQueryForbidden); return Err(QueryParserError::AllButQueryForbidden);
} }
Ok(ast) Ok(ast)
@@ -426,30 +390,30 @@ impl QueryParser {
fn compute_logical_ast_with_occur( fn compute_logical_ast_with_occur(
&self, &self,
user_input_ast: UserInputAST, user_input_ast: UserInputAST,
) -> Result<LogicalAST, QueryParserError> { ) -> Result<(Occur, LogicalAST), QueryParserError> {
match user_input_ast { match user_input_ast {
UserInputAST::Clause(sub_queries) => { UserInputAST::Clause(sub_queries) => {
let default_occur = self.default_occur(); let default_occur = self.default_occur();
let mut logical_sub_queries: Vec<(Occur, LogicalAST)> = Vec::new(); let mut logical_sub_queries: Vec<(Occur, LogicalAST)> = Vec::new();
for (occur_opt, sub_ast) in sub_queries { for sub_query in sub_queries {
let sub_ast = self.compute_logical_ast_with_occur(sub_ast)?; let (occur, sub_ast) = self.compute_logical_ast_with_occur(sub_query)?;
let occur = occur_opt.unwrap_or(default_occur); let new_occur = Occur::compose(default_occur, occur);
logical_sub_queries.push((occur, sub_ast)); logical_sub_queries.push((new_occur, sub_ast));
} }
Ok(LogicalAST::Clause(logical_sub_queries)) Ok((Occur::Should, LogicalAST::Clause(logical_sub_queries)))
} }
UserInputAST::Boost(ast, boost) => { UserInputAST::Unary(left_occur, subquery) => {
let ast = self.compute_logical_ast_with_occur(*ast)?; let (right_occur, logical_sub_queries) =
Ok(ast.boost(boost)) self.compute_logical_ast_with_occur(*subquery)?;
Ok((Occur::compose(left_occur, right_occur), logical_sub_queries))
}
UserInputAST::Leaf(leaf) => {
let result_ast = self.compute_logical_ast_from_leaf(*leaf)?;
Ok((Occur::Should, result_ast))
} }
UserInputAST::Leaf(leaf) => self.compute_logical_ast_from_leaf(*leaf),
} }
} }
fn field_boost(&self, field: Field) -> f32 {
self.boost.get(&field).cloned().unwrap_or(1.0f32)
}
fn compute_logical_ast_from_leaf( fn compute_logical_ast_from_leaf(
&self, &self,
leaf: UserInputLeaf, leaf: UserInputLeaf,
@@ -475,9 +439,7 @@ impl QueryParser {
let mut asts: Vec<LogicalAST> = Vec::new(); let mut asts: Vec<LogicalAST> = Vec::new();
for (field, phrase) in term_phrases { for (field, phrase) in term_phrases {
if let Some(ast) = self.compute_logical_ast_for_leaf(field, &phrase)? { if let Some(ast) = self.compute_logical_ast_for_leaf(field, &phrase)? {
// Apply some field specific boost defined at the query parser level. asts.push(LogicalAST::Leaf(Box::new(ast)));
let boost = self.field_boost(field);
asts.push(LogicalAST::Leaf(Box::new(ast)).boost(boost));
} }
} }
let result_ast: LogicalAST = if asts.len() == 1 { let result_ast: LogicalAST = if asts.len() == 1 {
@@ -497,16 +459,14 @@ impl QueryParser {
let mut clauses = fields let mut clauses = fields
.iter() .iter()
.map(|&field| { .map(|&field| {
let boost = self.field_boost(field);
let field_entry = self.schema.get_field_entry(field); let field_entry = self.schema.get_field_entry(field);
let value_type = field_entry.field_type().value_type(); let value_type = field_entry.field_type().value_type();
let logical_ast = LogicalAST::Leaf(Box::new(LogicalLiteral::Range { Ok(LogicalAST::Leaf(Box::new(LogicalLiteral::Range {
field, field,
value_type, value_type,
lower: self.resolve_bound(field, &lower)?, lower: self.resolve_bound(field, &lower)?,
upper: self.resolve_bound(field, &upper)?, upper: self.resolve_bound(field, &upper)?,
})); })))
Ok(logical_ast.boost(boost))
}) })
.collect::<Result<Vec<_>, QueryParserError>>()?; .collect::<Result<Vec<_>, QueryParserError>>()?;
let result_ast = if clauses.len() == 1 { let result_ast = if clauses.len() == 1 {
@@ -559,11 +519,6 @@ fn convert_to_query(logical_ast: LogicalAST) -> Box<dyn Query> {
Some(LogicalAST::Leaf(trimmed_logical_literal)) => { Some(LogicalAST::Leaf(trimmed_logical_literal)) => {
convert_literal_to_query(*trimmed_logical_literal) convert_literal_to_query(*trimmed_logical_literal)
} }
Some(LogicalAST::Boost(ast, boost)) => {
let query = convert_to_query(*ast);
let boosted_query = BoostQuery::new(query, boost);
Box::new(boosted_query)
}
None => Box::new(EmptyQuery), None => Box::new(EmptyQuery),
} }
} }
@@ -578,12 +533,12 @@ mod test {
use crate::schema::{IndexRecordOption, TextFieldIndexing, TextOptions}; use crate::schema::{IndexRecordOption, TextFieldIndexing, TextOptions};
use crate::schema::{Schema, Term, INDEXED, STORED, STRING, TEXT}; use crate::schema::{Schema, Term, INDEXED, STORED, STRING, TEXT};
use crate::tokenizer::{ use crate::tokenizer::{
LowerCaser, SimpleTokenizer, StopWordFilter, TextAnalyzer, TokenizerManager, LowerCaser, SimpleTokenizer, StopWordFilter, Tokenizer, TokenizerManager,
}; };
use crate::Index; use crate::Index;
use matches::assert_matches; use matches::assert_matches;
fn make_schema() -> Schema { fn make_query_parser() -> QueryParser {
let mut schema_builder = Schema::builder(); let mut schema_builder = Schema::builder();
let text_field_indexing = TextFieldIndexing::default() let text_field_indexing = TextFieldIndexing::default()
.set_tokenizer("en_with_stop_words") .set_tokenizer("en_with_stop_words")
@@ -591,8 +546,8 @@ mod test {
let text_options = TextOptions::default() let text_options = TextOptions::default()
.set_indexing_options(text_field_indexing) .set_indexing_options(text_field_indexing)
.set_stored(); .set_stored();
schema_builder.add_text_field("title", TEXT); let title = schema_builder.add_text_field("title", TEXT);
schema_builder.add_text_field("text", TEXT); let text = schema_builder.add_text_field("text", TEXT);
schema_builder.add_i64_field("signed", INDEXED); schema_builder.add_i64_field("signed", INDEXED);
schema_builder.add_u64_field("unsigned", INDEXED); schema_builder.add_u64_field("unsigned", INDEXED);
schema_builder.add_text_field("notindexed_text", STORED); schema_builder.add_text_field("notindexed_text", STORED);
@@ -603,19 +558,12 @@ mod test {
schema_builder.add_date_field("date", INDEXED); schema_builder.add_date_field("date", INDEXED);
schema_builder.add_f64_field("float", INDEXED); schema_builder.add_f64_field("float", INDEXED);
schema_builder.add_facet_field("facet"); schema_builder.add_facet_field("facet");
schema_builder.build() let schema = schema_builder.build();
} let default_fields = vec![title, text];
fn make_query_parser() -> QueryParser {
let schema = make_schema();
let default_fields: Vec<Field> = vec!["title", "text"]
.into_iter()
.flat_map(|field_name| schema.get_field(field_name))
.collect();
let tokenizer_manager = TokenizerManager::default(); let tokenizer_manager = TokenizerManager::default();
tokenizer_manager.register( tokenizer_manager.register(
"en_with_stop_words", "en_with_stop_words",
TextAnalyzer::from(SimpleTokenizer) SimpleTokenizer
.filter(LowerCaser) .filter(LowerCaser)
.filter(StopWordFilter::remove(vec!["the".to_string()])), .filter(StopWordFilter::remove(vec!["the".to_string()])),
); );
@@ -653,45 +601,6 @@ mod test {
); );
} }
#[test]
pub fn test_parse_query_with_boost() {
let mut query_parser = make_query_parser();
let schema = make_schema();
let text_field = schema.get_field("text").unwrap();
query_parser.set_field_boost(text_field, 2.0f32);
let query = query_parser.parse_query("text:hello").unwrap();
assert_eq!(
format!("{:?}", query),
"Boost(query=TermQuery(Term(field=1,bytes=[104, 101, 108, 108, 111])), boost=2)"
);
}
#[test]
pub fn test_parse_query_range_with_boost() {
let mut query_parser = make_query_parser();
let schema = make_schema();
let title_field = schema.get_field("title").unwrap();
query_parser.set_field_boost(title_field, 2.0f32);
let query = query_parser.parse_query("title:[A TO B]").unwrap();
assert_eq!(
format!("{:?}", query),
"Boost(query=RangeQuery { field: Field(0), value_type: Str, left_bound: Included([97]), right_bound: Included([98]) }, boost=2)"
);
}
#[test]
pub fn test_parse_query_with_default_boost_and_custom_boost() {
let mut query_parser = make_query_parser();
let schema = make_schema();
let text_field = schema.get_field("text").unwrap();
query_parser.set_field_boost(text_field, 2.0f32);
let query = query_parser.parse_query("text:hello^2").unwrap();
assert_eq!(
format!("{:?}", query),
"Boost(query=Boost(query=TermQuery(Term(field=1,bytes=[104, 101, 108, 108, 111])), boost=2), boost=2)"
);
}
#[test] #[test]
pub fn test_parse_nonindexed_field_yields_error() { pub fn test_parse_nonindexed_field_yields_error() {
let query_parser = make_query_parser(); let query_parser = make_query_parser();
@@ -790,20 +699,6 @@ mod test {
); );
} }
#[test]
fn test_parse_query_to_ast_ab_c() {
test_parse_query_to_logical_ast_helper(
"(+title:a +title:b) title:c",
"((+Term(field=0,bytes=[97]) +Term(field=0,bytes=[98])) Term(field=0,bytes=[99]))",
false,
);
test_parse_query_to_logical_ast_helper(
"(+title:a +title:b) title:c",
"(+(+Term(field=0,bytes=[97]) +Term(field=0,bytes=[98])) +Term(field=0,bytes=[99]))",
true,
);
}
#[test] #[test]
pub fn test_parse_query_to_ast_single_term() { pub fn test_parse_query_to_ast_single_term() {
test_parse_query_to_logical_ast_helper( test_parse_query_to_logical_ast_helper(
@@ -823,13 +718,11 @@ mod test {
Term(field=1,bytes=[116, 105, 116, 105])))", Term(field=1,bytes=[116, 105, 116, 105])))",
false, false,
); );
} assert_eq!(
parse_query_to_logical_ast("-title:toto", false)
#[test] .err()
fn test_single_negative_term() { .unwrap(),
assert_matches!( QueryParserError::AllButQueryForbidden
parse_query_to_logical_ast("-title:toto", false),
Err(QueryParserError::AllButQueryForbidden)
); );
} }
@@ -989,18 +882,6 @@ mod test {
assert!(query_parser.parse_query("with_stop_words:the").is_ok()); assert!(query_parser.parse_query("with_stop_words:the").is_ok());
} }
#[test]
pub fn test_parse_query_single_negative_term_through_error() {
assert_matches!(
parse_query_to_logical_ast("-title:toto", true),
Err(QueryParserError::AllButQueryForbidden)
);
assert_matches!(
parse_query_to_logical_ast("-title:toto", false),
Err(QueryParserError::AllButQueryForbidden)
);
}
#[test] #[test]
pub fn test_parse_query_to_ast_conjunction() { pub fn test_parse_query_to_ast_conjunction() {
test_parse_query_to_logical_ast_helper( test_parse_query_to_logical_ast_helper(
@@ -1020,6 +901,12 @@ mod test {
Term(field=1,bytes=[116, 105, 116, 105])))", Term(field=1,bytes=[116, 105, 116, 105])))",
true, true,
); );
assert_eq!(
parse_query_to_logical_ast("-title:toto", true)
.err()
.unwrap(),
QueryParserError::AllButQueryForbidden
);
test_parse_query_to_logical_ast_helper( test_parse_query_to_logical_ast_helper(
"title:a b", "title:a b",
"(+Term(field=0,bytes=[97]) \ "(+Term(field=0,bytes=[97]) \
@@ -1043,26 +930,4 @@ mod test {
false false
); );
} }
#[test]
fn test_and_default_regardless_of_default_conjunctive() {
for &default_conjunction in &[false, true] {
test_parse_query_to_logical_ast_helper(
"title:a AND title:b",
"(+Term(field=0,bytes=[97]) +Term(field=0,bytes=[98]))",
default_conjunction,
);
}
}
#[test]
fn test_or_default_conjunctive() {
for &default_conjunction in &[false, true] {
test_parse_query_to_logical_ast_helper(
"title:a OR title:b",
"(Term(field=0,bytes=[97]) Term(field=0,bytes=[98]))",
default_conjunction,
);
}
}
} }

View File

@@ -10,7 +10,7 @@ use crate::schema::Type;
use crate::schema::{Field, IndexRecordOption, Term}; use crate::schema::{Field, IndexRecordOption, Term};
use crate::termdict::{TermDictionary, TermStreamer}; use crate::termdict::{TermDictionary, TermStreamer};
use crate::DocId; use crate::DocId;
use crate::Result; use crate::{Result, SkipResult};
use std::collections::Bound; use std::collections::Bound;
use std::ops::Range; use std::ops::Range;
@@ -289,7 +289,7 @@ impl RangeWeight {
} }
impl Weight for RangeWeight { impl Weight for RangeWeight {
fn scorer(&self, reader: &SegmentReader, boost: f32) -> Result<Box<dyn Scorer>> { fn scorer(&self, reader: &SegmentReader) -> Result<Box<dyn Scorer>> {
let max_doc = reader.max_doc(); let max_doc = reader.max_doc();
let mut doc_bitset = BitSet::with_max_value(max_doc); let mut doc_bitset = BitSet::with_max_value(max_doc);
@@ -300,22 +300,19 @@ impl Weight for RangeWeight {
let term_info = term_range.value(); let term_info = term_range.value();
let mut block_segment_postings = inverted_index let mut block_segment_postings = inverted_index
.read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic); .read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic);
loop { while block_segment_postings.advance() {
for &doc in block_segment_postings.docs() { for &doc in block_segment_postings.docs() {
doc_bitset.insert(doc); doc_bitset.insert(doc);
} }
if !block_segment_postings.advance() {
break;
}
} }
} }
let doc_bitset = BitSetDocSet::from(doc_bitset); let doc_bitset = BitSetDocSet::from(doc_bitset);
Ok(Box::new(ConstScorer::new(doc_bitset, boost))) Ok(Box::new(ConstScorer::new(doc_bitset)))
} }
fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation> { fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0f32)?; let mut scorer = self.scorer(reader)?;
if scorer.seek(doc) != doc { if scorer.skip_next(doc) != SkipResult::Reached {
return Err(does_not_match(doc)); return Err(does_not_match(doc));
} }
Ok(Explanation::new("RangeQuery", 1.0f32)) Ok(Explanation::new("RangeQuery", 1.0f32))

View File

@@ -1,6 +1,7 @@
use crate::error::TantivyError; use crate::error::TantivyError;
use crate::query::{AutomatonWeight, Query, Weight}; use crate::query::{AutomatonWeight, Query, Weight};
use crate::schema::Field; use crate::schema::Field;
use crate::Result;
use crate::Searcher; use crate::Searcher;
use std::clone::Clone; use std::clone::Clone;
use std::sync::Arc; use std::sync::Arc;
@@ -57,7 +58,7 @@ pub struct RegexQuery {
impl RegexQuery { impl RegexQuery {
/// Creates a new RegexQuery from a given pattern /// Creates a new RegexQuery from a given pattern
pub fn from_pattern(regex_pattern: &str, field: Field) -> crate::Result<Self> { pub fn from_pattern(regex_pattern: &str, field: Field) -> Result<Self> {
let regex = Regex::new(&regex_pattern) let regex = Regex::new(&regex_pattern)
.map_err(|_| TantivyError::InvalidArgument(regex_pattern.to_string()))?; .map_err(|_| TantivyError::InvalidArgument(regex_pattern.to_string()))?;
Ok(RegexQuery::from_regex(regex, field)) Ok(RegexQuery::from_regex(regex, field))
@@ -77,11 +78,7 @@ impl RegexQuery {
} }
impl Query for RegexQuery { impl Query for RegexQuery {
fn weight( fn weight(&self, _searcher: &Searcher, _scoring_enabled: bool) -> Result<Box<dyn Weight>> {
&self,
_searcher: &Searcher,
_scoring_enabled: bool,
) -> crate::Result<Box<dyn Weight>> {
Ok(Box::new(self.specialized_weight())) Ok(Box::new(self.specialized_weight()))
} }
} }

View File

@@ -1,8 +1,9 @@
use crate::docset::DocSet; use crate::docset::{DocSet, SkipResult};
use crate::query::score_combiner::ScoreCombiner; use crate::query::score_combiner::ScoreCombiner;
use crate::query::Scorer; use crate::query::Scorer;
use crate::DocId; use crate::DocId;
use crate::Score; use crate::Score;
use std::cmp::Ordering;
use std::marker::PhantomData; use std::marker::PhantomData;
/// Given a required scorer and an optional scorer /// Given a required scorer and an optional scorer
@@ -16,6 +17,7 @@ pub struct RequiredOptionalScorer<TReqScorer, TOptScorer, TScoreCombiner> {
req_scorer: TReqScorer, req_scorer: TReqScorer,
opt_scorer: TOptScorer, opt_scorer: TOptScorer,
score_cache: Option<Score>, score_cache: Option<Score>,
opt_finished: bool,
_phantom: PhantomData<TScoreCombiner>, _phantom: PhantomData<TScoreCombiner>,
} }
@@ -27,12 +29,14 @@ where
/// Creates a new `RequiredOptionalScorer`. /// Creates a new `RequiredOptionalScorer`.
pub fn new( pub fn new(
req_scorer: TReqScorer, req_scorer: TReqScorer,
opt_scorer: TOptScorer, mut opt_scorer: TOptScorer,
) -> RequiredOptionalScorer<TReqScorer, TOptScorer, TScoreCombiner> { ) -> RequiredOptionalScorer<TReqScorer, TOptScorer, TScoreCombiner> {
let opt_finished = !opt_scorer.advance();
RequiredOptionalScorer { RequiredOptionalScorer {
req_scorer, req_scorer,
opt_scorer, opt_scorer,
score_cache: None, score_cache: None,
opt_finished,
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
@@ -44,7 +48,7 @@ where
TReqScorer: DocSet, TReqScorer: DocSet,
TOptScorer: DocSet, TOptScorer: DocSet,
{ {
fn advance(&mut self) -> DocId { fn advance(&mut self) -> bool {
self.score_cache = None; self.score_cache = None;
self.req_scorer.advance() self.req_scorer.advance()
} }
@@ -72,8 +76,22 @@ where
let doc = self.doc(); let doc = self.doc();
let mut score_combiner = TScoreCombiner::default(); let mut score_combiner = TScoreCombiner::default();
score_combiner.update(&mut self.req_scorer); score_combiner.update(&mut self.req_scorer);
if self.opt_scorer.seek(doc) == doc { if !self.opt_finished {
score_combiner.update(&mut self.opt_scorer); match self.opt_scorer.doc().cmp(&doc) {
Ordering::Greater => {}
Ordering::Equal => {
score_combiner.update(&mut self.opt_scorer);
}
Ordering::Less => match self.opt_scorer.skip_next(doc) {
SkipResult::Reached => {
score_combiner.update(&mut self.opt_scorer);
}
SkipResult::End => {
self.opt_finished = true;
}
SkipResult::OverStep => {}
},
}
} }
let score = score_combiner.score(); let score = score_combiner.score();
self.score_cache = Some(score); self.score_cache = Some(score);
@@ -84,7 +102,7 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::RequiredOptionalScorer; use super::RequiredOptionalScorer;
use crate::docset::{DocSet, TERMINATED}; use crate::docset::DocSet;
use crate::postings::tests::test_skip_against_unoptimized; use crate::postings::tests::test_skip_against_unoptimized;
use crate::query::score_combiner::{DoNothingCombiner, SumCombiner}; use crate::query::score_combiner::{DoNothingCombiner, SumCombiner};
use crate::query::ConstScorer; use crate::query::ConstScorer;
@@ -97,13 +115,12 @@ mod tests {
let req = vec![1, 3, 7]; let req = vec![1, 3, 7];
let mut reqoptscorer: RequiredOptionalScorer<_, _, SumCombiner> = let mut reqoptscorer: RequiredOptionalScorer<_, _, SumCombiner> =
RequiredOptionalScorer::new( RequiredOptionalScorer::new(
ConstScorer::from(VecDocSet::from(req.clone())), ConstScorer::new(VecDocSet::from(req.clone())),
ConstScorer::from(VecDocSet::from(vec![])), ConstScorer::new(VecDocSet::from(vec![])),
); );
let mut docs = vec![]; let mut docs = vec![];
while reqoptscorer.doc() != TERMINATED { while reqoptscorer.advance() {
docs.push(reqoptscorer.doc()); docs.push(reqoptscorer.doc());
reqoptscorer.advance();
} }
assert_eq!(docs, req); assert_eq!(docs, req);
} }
@@ -112,49 +129,50 @@ mod tests {
fn test_reqopt_scorer() { fn test_reqopt_scorer() {
let mut reqoptscorer: RequiredOptionalScorer<_, _, SumCombiner> = let mut reqoptscorer: RequiredOptionalScorer<_, _, SumCombiner> =
RequiredOptionalScorer::new( RequiredOptionalScorer::new(
ConstScorer::new(VecDocSet::from(vec![1, 3, 7, 8, 9, 10, 13, 15]), 1.0f32), ConstScorer::new(VecDocSet::from(vec![1, 3, 7, 8, 9, 10, 13, 15])),
ConstScorer::new(VecDocSet::from(vec![1, 2, 7, 11, 12, 15]), 1.0f32), ConstScorer::new(VecDocSet::from(vec![1, 2, 7, 11, 12, 15])),
); );
{ {
assert!(reqoptscorer.advance());
assert_eq!(reqoptscorer.doc(), 1); assert_eq!(reqoptscorer.doc(), 1);
assert_eq!(reqoptscorer.score(), 2f32); assert_eq!(reqoptscorer.score(), 2f32);
} }
{ {
assert_eq!(reqoptscorer.advance(), 3); assert!(reqoptscorer.advance());
assert_eq!(reqoptscorer.doc(), 3); assert_eq!(reqoptscorer.doc(), 3);
assert_eq!(reqoptscorer.score(), 1f32); assert_eq!(reqoptscorer.score(), 1f32);
} }
{ {
assert_eq!(reqoptscorer.advance(), 7); assert!(reqoptscorer.advance());
assert_eq!(reqoptscorer.doc(), 7); assert_eq!(reqoptscorer.doc(), 7);
assert_eq!(reqoptscorer.score(), 2f32); assert_eq!(reqoptscorer.score(), 2f32);
} }
{ {
assert_eq!(reqoptscorer.advance(), 8); assert!(reqoptscorer.advance());
assert_eq!(reqoptscorer.doc(), 8); assert_eq!(reqoptscorer.doc(), 8);
assert_eq!(reqoptscorer.score(), 1f32); assert_eq!(reqoptscorer.score(), 1f32);
} }
{ {
assert_eq!(reqoptscorer.advance(), 9); assert!(reqoptscorer.advance());
assert_eq!(reqoptscorer.doc(), 9); assert_eq!(reqoptscorer.doc(), 9);
assert_eq!(reqoptscorer.score(), 1f32); assert_eq!(reqoptscorer.score(), 1f32);
} }
{ {
assert_eq!(reqoptscorer.advance(), 10); assert!(reqoptscorer.advance());
assert_eq!(reqoptscorer.doc(), 10); assert_eq!(reqoptscorer.doc(), 10);
assert_eq!(reqoptscorer.score(), 1f32); assert_eq!(reqoptscorer.score(), 1f32);
} }
{ {
assert_eq!(reqoptscorer.advance(), 13); assert!(reqoptscorer.advance());
assert_eq!(reqoptscorer.doc(), 13); assert_eq!(reqoptscorer.doc(), 13);
assert_eq!(reqoptscorer.score(), 1f32); assert_eq!(reqoptscorer.score(), 1f32);
} }
{ {
assert_eq!(reqoptscorer.advance(), 15); assert!(reqoptscorer.advance());
assert_eq!(reqoptscorer.doc(), 15); assert_eq!(reqoptscorer.doc(), 15);
assert_eq!(reqoptscorer.score(), 2f32); assert_eq!(reqoptscorer.score(), 2f32);
} }
assert_eq!(reqoptscorer.advance(), TERMINATED); assert!(!reqoptscorer.advance());
} }
#[test] #[test]
@@ -165,8 +183,8 @@ mod tests {
test_skip_against_unoptimized( test_skip_against_unoptimized(
|| { || {
Box::new(RequiredOptionalScorer::<_, _, DoNothingCombiner>::new( Box::new(RequiredOptionalScorer::<_, _, DoNothingCombiner>::new(
ConstScorer::from(VecDocSet::from(req_docs.clone())), ConstScorer::new(VecDocSet::from(req_docs.clone())),
ConstScorer::from(VecDocSet::from(opt_docs.clone())), ConstScorer::new(VecDocSet::from(opt_docs.clone())),
)) ))
}, },
skip_docs, skip_docs,

View File

@@ -1,4 +1,5 @@
use crate::docset::DocSet; use crate::common::BitSet;
use crate::docset::{DocSet, SkipResult};
use crate::DocId; use crate::DocId;
use crate::Score; use crate::Score;
use downcast_rs::impl_downcast; use downcast_rs::impl_downcast;
@@ -12,6 +13,14 @@ pub trait Scorer: downcast_rs::Downcast + DocSet + 'static {
/// ///
/// This method will perform a bit of computation and is not cached. /// This method will perform a bit of computation and is not cached.
fn score(&mut self) -> Score; fn score(&mut self) -> Score;
/// Iterates through all of the document matched by the DocSet
/// `DocSet` and push the scored documents to the collector.
fn for_each(&mut self, callback: &mut dyn FnMut(DocId, Score)) {
while self.advance() {
callback(self.doc(), self.score());
}
}
} }
impl_downcast!(Scorer); impl_downcast!(Scorer);
@@ -20,6 +29,11 @@ impl Scorer for Box<dyn Scorer> {
fn score(&mut self) -> Score { fn score(&mut self) -> Score {
self.deref_mut().score() self.deref_mut().score()
} }
fn for_each(&mut self, callback: &mut dyn FnMut(DocId, Score)) {
let scorer = self.deref_mut();
scorer.for_each(callback);
}
} }
/// Wraps a `DocSet` and simply returns a constant `Scorer`. /// Wraps a `DocSet` and simply returns a constant `Scorer`.
@@ -35,24 +49,26 @@ pub struct ConstScorer<TDocSet: DocSet> {
impl<TDocSet: DocSet> ConstScorer<TDocSet> { impl<TDocSet: DocSet> ConstScorer<TDocSet> {
/// Creates a new `ConstScorer`. /// Creates a new `ConstScorer`.
pub fn new(docset: TDocSet, score: f32) -> ConstScorer<TDocSet> { pub fn new(docset: TDocSet) -> ConstScorer<TDocSet> {
ConstScorer { docset, score } ConstScorer {
docset,
score: 1f32,
}
} }
}
impl<TDocSet: DocSet> From<TDocSet> for ConstScorer<TDocSet> { /// Sets the constant score to a different value.
fn from(docset: TDocSet) -> Self { pub fn set_score(&mut self, score: Score) {
ConstScorer::new(docset, 1.0f32) self.score = score;
} }
} }
impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> { impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> {
fn advance(&mut self) -> DocId { fn advance(&mut self) -> bool {
self.docset.advance() self.docset.advance()
} }
fn seek(&mut self, target: DocId) -> DocId { fn skip_next(&mut self, target: DocId) -> SkipResult {
self.docset.seek(target) self.docset.skip_next(target)
} }
fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize { fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize {
@@ -66,10 +82,14 @@ impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> {
fn size_hint(&self) -> u32 { fn size_hint(&self) -> u32 {
self.docset.size_hint() self.docset.size_hint()
} }
fn append_to_bitset(&mut self, bitset: &mut BitSet) {
self.docset.append_to_bitset(bitset);
}
} }
impl<TDocSet: DocSet + 'static> Scorer for ConstScorer<TDocSet> { impl<TDocSet: DocSet + 'static> Scorer for ConstScorer<TDocSet> {
fn score(&mut self) -> Score { fn score(&mut self) -> Score {
self.score 1f32
} }
} }

View File

@@ -26,8 +26,10 @@ mod tests {
{ {
// writing the segment // writing the segment
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap(); let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
let doc = doc!(text_field => "a"); {
index_writer.add_document(doc); let doc = doc!(text_field => "a");
index_writer.add_document(doc);
}
assert!(index_writer.commit().is_ok()); assert!(index_writer.commit().is_ok());
} }
let searcher = index.reader().unwrap().searcher(); let searcher = index.reader().unwrap().searcher();
@@ -37,7 +39,8 @@ mod tests {
); );
let term_weight = term_query.weight(&searcher, true).unwrap(); let term_weight = term_query.weight(&searcher, true).unwrap();
let segment_reader = searcher.segment_reader(0); let segment_reader = searcher.segment_reader(0);
let mut term_scorer = term_weight.scorer(segment_reader, 1.0f32).unwrap(); let mut term_scorer = term_weight.scorer(segment_reader).unwrap();
assert!(term_scorer.advance());
assert_eq!(term_scorer.doc(), 0); assert_eq!(term_scorer.doc(), 0);
assert_eq!(term_scorer.score(), 0.28768212); assert_eq!(term_scorer.score(), 0.28768212);
} }

View File

@@ -3,6 +3,7 @@ use crate::query::bm25::BM25Weight;
use crate::query::Query; use crate::query::Query;
use crate::query::Weight; use crate::query::Weight;
use crate::schema::IndexRecordOption; use crate::schema::IndexRecordOption;
use crate::Result;
use crate::Searcher; use crate::Searcher;
use crate::Term; use crate::Term;
use std::collections::BTreeSet; use std::collections::BTreeSet;
@@ -100,7 +101,7 @@ impl TermQuery {
} }
impl Query for TermQuery { impl Query for TermQuery {
fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> crate::Result<Box<dyn Weight>> { fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> Result<Box<dyn Weight>> {
Ok(Box::new(self.specialized_weight(searcher, scoring_enabled))) Ok(Box::new(self.specialized_weight(searcher, scoring_enabled)))
} }
fn query_terms(&self, term_set: &mut BTreeSet<Term>) { fn query_terms(&self, term_set: &mut BTreeSet<Term>) {

View File

@@ -1,4 +1,4 @@
use crate::docset::DocSet; use crate::docset::{DocSet, SkipResult};
use crate::query::{Explanation, Scorer}; use crate::query::{Explanation, Scorer};
use crate::DocId; use crate::DocId;
use crate::Score; use crate::Score;
@@ -45,12 +45,12 @@ impl TermScorer {
} }
impl DocSet for TermScorer { impl DocSet for TermScorer {
fn advance(&mut self) -> DocId { fn advance(&mut self) -> bool {
self.postings.advance() self.postings.advance()
} }
fn seek(&mut self, target: DocId) -> DocId { fn skip_next(&mut self, target: DocId) -> SkipResult {
self.postings.seek(target) self.postings.skip_next(target)
} }
fn doc(&self) -> DocId { fn doc(&self) -> DocId {

View File

@@ -4,13 +4,12 @@ use crate::docset::DocSet;
use crate::postings::SegmentPostings; use crate::postings::SegmentPostings;
use crate::query::bm25::BM25Weight; use crate::query::bm25::BM25Weight;
use crate::query::explanation::does_not_match; use crate::query::explanation::does_not_match;
use crate::query::weight::{for_each_pruning_scorer, for_each_scorer};
use crate::query::Weight; use crate::query::Weight;
use crate::query::{Explanation, Scorer}; use crate::query::{Explanation, Scorer};
use crate::schema::IndexRecordOption; use crate::schema::IndexRecordOption;
use crate::Result; use crate::DocId;
use crate::Term; use crate::Term;
use crate::{DocId, Score}; use crate::{Result, SkipResult};
pub struct TermWeight { pub struct TermWeight {
term: Term, term: Term,
@@ -19,14 +18,14 @@ pub struct TermWeight {
} }
impl Weight for TermWeight { impl Weight for TermWeight {
fn scorer(&self, reader: &SegmentReader, boost: f32) -> Result<Box<dyn Scorer>> { fn scorer(&self, reader: &SegmentReader) -> Result<Box<dyn Scorer>> {
let term_scorer = self.scorer_specialized(reader, boost)?; let term_scorer = self.scorer_specialized(reader)?;
Ok(Box::new(term_scorer)) Ok(Box::new(term_scorer))
} }
fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation> { fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation> {
let mut scorer = self.scorer_specialized(reader, 1.0f32)?; let mut scorer = self.scorer_specialized(reader)?;
if scorer.seek(doc) != doc { if scorer.skip_next(doc) != SkipResult::Reached {
return Err(does_not_match(doc)); return Err(does_not_match(doc));
} }
Ok(scorer.explain()) Ok(scorer.explain())
@@ -34,7 +33,7 @@ impl Weight for TermWeight {
fn count(&self, reader: &SegmentReader) -> Result<u32> { fn count(&self, reader: &SegmentReader) -> Result<u32> {
if let Some(delete_bitset) = reader.delete_bitset() { if let Some(delete_bitset) = reader.delete_bitset() {
Ok(self.scorer(reader, 1.0f32)?.count(delete_bitset)) Ok(self.scorer(reader)?.count(delete_bitset))
} else { } else {
let field = self.term.field(); let field = self.term.field();
Ok(reader Ok(reader
@@ -44,39 +43,6 @@ impl Weight for TermWeight {
.unwrap_or(0)) .unwrap_or(0))
} }
} }
/// Iterates through all of the document matched by the DocSet
/// `DocSet` and push the scored documents to the collector.
fn for_each(
&self,
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score),
) -> crate::Result<()> {
let mut scorer = self.scorer_specialized(reader, 1.0f32)?;
for_each_scorer(&mut scorer, callback);
Ok(())
}
/// Calls `callback` with all of the `(doc, score)` for which score
/// is exceeding a given threshold.
///
/// This method is useful for the TopDocs collector.
/// For all docsets, the blanket implementation has the benefit
/// of prefiltering (doc, score) pairs, avoiding the
/// virtual dispatch cost.
///
/// More importantly, it makes it possible for scorers to implement
/// important optimization (e.g. BlockWAND for union).
fn for_each_pruning(
&self,
threshold: f32,
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score) -> Score,
) -> crate::Result<()> {
let mut scorer = self.scorer(reader, 1.0f32)?;
for_each_pruning_scorer(&mut scorer, threshold, callback);
Ok(())
}
} }
impl TermWeight { impl TermWeight {
@@ -92,11 +58,11 @@ impl TermWeight {
} }
} }
fn scorer_specialized(&self, reader: &SegmentReader, boost: f32) -> Result<TermScorer> { fn scorer_specialized(&self, reader: &SegmentReader) -> Result<TermScorer> {
let field = self.term.field(); let field = self.term.field();
let inverted_index = reader.inverted_index(field); let inverted_index = reader.inverted_index(field);
let fieldnorm_reader = reader.get_fieldnorms_reader(field); let fieldnorm_reader = reader.get_fieldnorms_reader(field);
let similarity_weight = self.similarity_weight.boost_by(boost); let similarity_weight = self.similarity_weight.clone();
let postings_opt: Option<SegmentPostings> = let postings_opt: Option<SegmentPostings> =
inverted_index.read_postings(&self.term, self.index_record_option); inverted_index.read_postings(&self.term, self.index_record_option);
if let Some(segment_postings) = postings_opt { if let Some(segment_postings) = postings_opt {

View File

@@ -1,9 +1,10 @@
use crate::common::TinySet; use crate::common::TinySet;
use crate::docset::{DocSet, TERMINATED}; use crate::docset::{DocSet, SkipResult};
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner}; use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
use crate::query::Scorer; use crate::query::Scorer;
use crate::DocId; use crate::DocId;
use crate::Score; use crate::Score;
use std::cmp::Ordering;
const HORIZON_NUM_TINYBITSETS: usize = 64; const HORIZON_NUM_TINYBITSETS: usize = 64;
const HORIZON: u32 = 64u32 * HORIZON_NUM_TINYBITSETS as u32; const HORIZON: u32 = 64u32 * HORIZON_NUM_TINYBITSETS as u32;
@@ -46,9 +47,17 @@ where
fn from(docsets: Vec<TScorer>) -> Union<TScorer, TScoreCombiner> { fn from(docsets: Vec<TScorer>) -> Union<TScorer, TScoreCombiner> {
let non_empty_docsets: Vec<TScorer> = docsets let non_empty_docsets: Vec<TScorer> = docsets
.into_iter() .into_iter()
.filter(|docset| docset.doc() != TERMINATED) .flat_map(
|mut docset| {
if docset.advance() {
Some(docset)
} else {
None
}
},
)
.collect(); .collect();
let mut union = Union { Union {
docsets: non_empty_docsets, docsets: non_empty_docsets,
bitsets: Box::new([TinySet::empty(); HORIZON_NUM_TINYBITSETS]), bitsets: Box::new([TinySet::empty(); HORIZON_NUM_TINYBITSETS]),
scores: Box::new([TScoreCombiner::default(); HORIZON as usize]), scores: Box::new([TScoreCombiner::default(); HORIZON as usize]),
@@ -56,13 +65,7 @@ where
offset: 0, offset: 0,
doc: 0, doc: 0,
score: 0f32, score: 0f32,
};
if union.refill() {
union.advance();
} else {
union.doc = TERMINATED;
} }
union
} }
} }
@@ -83,7 +86,7 @@ fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
let delta = doc - min_doc; let delta = doc - min_doc;
bitsets[(delta / 64) as usize].insert_mut(delta % 64u32); bitsets[(delta / 64) as usize].insert_mut(delta % 64u32);
score_combiner[delta as usize].update(scorer); score_combiner[delta as usize].update(scorer);
if scorer.advance() == TERMINATED { if !scorer.advance() {
// remove the docset, it has been entirely consumed. // remove the docset, it has been entirely consumed.
return true; return true;
} }
@@ -96,7 +99,6 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Union<TScorer, TScoreCombin
if let Some(min_doc) = self.docsets.iter().map(DocSet::doc).min() { if let Some(min_doc) = self.docsets.iter().map(DocSet::doc).min() {
self.offset = min_doc; self.offset = min_doc;
self.cursor = 0; self.cursor = 0;
self.doc = min_doc;
refill( refill(
&mut self.docsets, &mut self.docsets,
&mut *self.bitsets, &mut *self.bitsets,
@@ -131,23 +133,50 @@ where
TScorer: Scorer, TScorer: Scorer,
TScoreCombiner: ScoreCombiner, TScoreCombiner: ScoreCombiner,
{ {
fn advance(&mut self) -> DocId { fn advance(&mut self) -> bool {
if self.advance_buffered() { if self.advance_buffered() {
return self.doc; return true;
} }
if !self.refill() { if self.refill() {
self.doc = TERMINATED; self.advance();
return TERMINATED; true
} else {
false
} }
if !self.advance_buffered() {
return TERMINATED;
}
self.doc
} }
fn seek(&mut self, target: DocId) -> DocId { fn count_including_deleted(&mut self) -> u32 {
if self.doc >= target { let mut count = self.bitsets[self.cursor..HORIZON_NUM_TINYBITSETS]
return self.doc; .iter()
.map(|bitset| bitset.len())
.sum::<u32>();
for bitset in self.bitsets.iter_mut() {
bitset.clear();
}
while self.refill() {
count += self.bitsets.iter().map(|bitset| bitset.len()).sum::<u32>();
for bitset in self.bitsets.iter_mut() {
bitset.clear();
}
}
self.cursor = HORIZON_NUM_TINYBITSETS;
count
}
// TODO implement `count` efficiently.
fn skip_next(&mut self, target: DocId) -> SkipResult {
if !self.advance() {
return SkipResult::End;
}
match self.doc.cmp(&target) {
Ordering::Equal => {
return SkipResult::Reached;
}
Ordering::Greater => {
return SkipResult::OverStep;
}
Ordering::Less => {}
} }
let gap = target - self.offset; let gap = target - self.offset;
if gap < HORIZON { if gap < HORIZON {
@@ -165,11 +194,18 @@ where
// Advancing until we reach the end of the bucket // Advancing until we reach the end of the bucket
// or we reach a doc greater or equal to the target. // or we reach a doc greater or equal to the target.
let mut doc = self.doc(); while self.advance() {
while doc < target { match self.doc().cmp(&target) {
doc = self.advance(); Ordering::Equal => {
return SkipResult::Reached;
}
Ordering::Greater => {
return SkipResult::OverStep;
}
Ordering::Less => {}
}
} }
doc SkipResult::End
} else { } else {
// clear the buffered info. // clear the buffered info.
for obsolete_tinyset in self.bitsets.iter_mut() { for obsolete_tinyset in self.bitsets.iter_mut() {
@@ -183,53 +219,36 @@ where
// advance all docsets to a doc >= to the target. // advance all docsets to a doc >= to the target.
#[cfg_attr(feature = "cargo-clippy", allow(clippy::clippy::collapsible_if))] #[cfg_attr(feature = "cargo-clippy", allow(clippy::clippy::collapsible_if))]
unordered_drain_filter(&mut self.docsets, |docset| { unordered_drain_filter(&mut self.docsets, |docset| {
docset.seek(target) == TERMINATED if docset.doc() < target {
if docset.skip_next(target) == SkipResult::End {
return true;
}
}
false
}); });
// at this point all of the docsets // at this point all of the docsets
// are positionned on a doc >= to the target. // are positionned on a doc >= to the target.
if !self.refill() { if self.refill() {
self.doc = TERMINATED; self.advance();
return TERMINATED; if self.doc() == target {
SkipResult::Reached
} else {
debug_assert!(self.doc() > target);
SkipResult::OverStep
}
} else {
SkipResult::End
} }
self.advance()
} }
} }
// TODO Also implement `count` with deletes efficiently.
fn doc(&self) -> DocId { fn doc(&self) -> DocId {
self.doc self.doc
} }
fn size_hint(&self) -> u32 { fn size_hint(&self) -> u32 {
self.docsets 0u32
.iter()
.map(|docset| docset.size_hint())
.max()
.unwrap_or(0u32)
}
fn count_including_deleted(&mut self) -> u32 {
if self.doc == TERMINATED {
return 0;
}
let mut count = self.bitsets[self.cursor..HORIZON_NUM_TINYBITSETS]
.iter()
.map(|bitset| bitset.len())
.sum::<u32>()
+ 1;
for bitset in self.bitsets.iter_mut() {
bitset.clear();
}
while self.refill() {
count += self.bitsets.iter().map(|bitset| bitset.len()).sum::<u32>();
for bitset in self.bitsets.iter_mut() {
bitset.clear();
}
}
self.cursor = HORIZON_NUM_TINYBITSETS;
count
} }
} }
@@ -248,7 +267,7 @@ mod tests {
use super::Union; use super::Union;
use super::HORIZON; use super::HORIZON;
use crate::docset::{DocSet, TERMINATED}; use crate::docset::{DocSet, SkipResult};
use crate::postings::tests::test_skip_against_unoptimized; use crate::postings::tests::test_skip_against_unoptimized;
use crate::query::score_combiner::DoNothingCombiner; use crate::query::score_combiner::DoNothingCombiner;
use crate::query::ConstScorer; use crate::query::ConstScorer;
@@ -271,18 +290,18 @@ mod tests {
vals.iter() vals.iter()
.cloned() .cloned()
.map(VecDocSet::from) .map(VecDocSet::from)
.map(|docset| ConstScorer::new(docset, 1.0f32)) .map(ConstScorer::new)
.collect::<Vec<ConstScorer<VecDocSet>>>(), .collect::<Vec<ConstScorer<VecDocSet>>>(),
) )
}; };
let mut union: Union<_, DoNothingCombiner> = make_union(); let mut union: Union<_, DoNothingCombiner> = make_union();
let mut count = 0; let mut count = 0;
while union.doc() != TERMINATED { while union.advance() {
assert!(union_expected.advance());
assert_eq!(union_expected.doc(), union.doc()); assert_eq!(union_expected.doc(), union.doc());
assert_eq!(union_expected.advance(), union.advance());
count += 1; count += 1;
} }
assert_eq!(union_expected.advance(), TERMINATED); assert!(!union_expected.advance());
assert_eq!(count, make_union().count_including_deleted()); assert_eq!(count, make_union().count_including_deleted());
} }
@@ -310,7 +329,9 @@ mod tests {
fn test_aux_union_skip(docs_list: &[Vec<DocId>], skip_targets: Vec<DocId>) { fn test_aux_union_skip(docs_list: &[Vec<DocId>], skip_targets: Vec<DocId>) {
let mut btree_set = BTreeSet::new(); let mut btree_set = BTreeSet::new();
for docs in docs_list { for docs in docs_list {
btree_set.extend(docs.iter().cloned()); for &doc in docs.iter() {
btree_set.insert(doc);
}
} }
let docset_factory = || { let docset_factory = || {
let res: Box<dyn DocSet> = Box::new(Union::<_, DoNothingCombiner>::from( let res: Box<dyn DocSet> = Box::new(Union::<_, DoNothingCombiner>::from(
@@ -318,17 +339,17 @@ mod tests {
.iter() .iter()
.map(|docs| docs.clone()) .map(|docs| docs.clone())
.map(VecDocSet::from) .map(VecDocSet::from)
.map(|docset| ConstScorer::new(docset, 1.0f32)) .map(ConstScorer::new)
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
)); ));
res res
}; };
let mut docset = docset_factory(); let mut docset = docset_factory();
for el in btree_set { for el in btree_set {
assert!(docset.advance());
assert_eq!(el, docset.doc()); assert_eq!(el, docset.doc());
docset.advance();
} }
assert_eq!(docset.doc(), TERMINATED); assert!(!docset.advance());
test_skip_against_unoptimized(docset_factory, skip_targets); test_skip_against_unoptimized(docset_factory, skip_targets);
} }
@@ -348,13 +369,13 @@ mod tests {
#[test] #[test]
fn test_union_skip_corner_case3() { fn test_union_skip_corner_case3() {
let mut docset = Union::<_, DoNothingCombiner>::from(vec![ let mut docset = Union::<_, DoNothingCombiner>::from(vec![
ConstScorer::from(VecDocSet::from(vec![0u32, 5u32])), ConstScorer::new(VecDocSet::from(vec![0u32, 5u32])),
ConstScorer::from(VecDocSet::from(vec![1u32, 4u32])), ConstScorer::new(VecDocSet::from(vec![1u32, 4u32])),
]); ]);
assert!(docset.advance());
assert_eq!(docset.doc(), 0u32); assert_eq!(docset.doc(), 0u32);
assert_eq!(docset.seek(0u32), 0u32); assert_eq!(docset.skip_next(0u32), SkipResult::OverStep);
assert_eq!(docset.seek(0u32), 0u32); assert_eq!(docset.doc(), 1u32)
assert_eq!(docset.doc(), 0u32)
} }
#[test] #[test]

View File

@@ -1,8 +1,9 @@
#![allow(dead_code)] #![allow(dead_code)]
use crate::common::HasLen; use crate::common::HasLen;
use crate::docset::{DocSet, TERMINATED}; use crate::docset::DocSet;
use crate::DocId; use crate::DocId;
use std::num::Wrapping;
/// Simulate a `Postings` objects from a `VecPostings`. /// Simulate a `Postings` objects from a `VecPostings`.
/// `VecPostings` only exist for testing purposes. /// `VecPostings` only exist for testing purposes.
@@ -11,30 +12,26 @@ use crate::DocId;
/// No positions are returned. /// No positions are returned.
pub struct VecDocSet { pub struct VecDocSet {
doc_ids: Vec<DocId>, doc_ids: Vec<DocId>,
cursor: usize, cursor: Wrapping<usize>,
} }
impl From<Vec<DocId>> for VecDocSet { impl From<Vec<DocId>> for VecDocSet {
fn from(doc_ids: Vec<DocId>) -> VecDocSet { fn from(doc_ids: Vec<DocId>) -> VecDocSet {
VecDocSet { doc_ids, cursor: 0 } VecDocSet {
doc_ids,
cursor: Wrapping(usize::max_value()),
}
} }
} }
impl DocSet for VecDocSet { impl DocSet for VecDocSet {
fn advance(&mut self) -> DocId { fn advance(&mut self) -> bool {
self.cursor += 1; self.cursor += Wrapping(1);
if self.cursor >= self.doc_ids.len() { self.doc_ids.len() > self.cursor.0
self.cursor = self.doc_ids.len();
return TERMINATED;
}
self.doc()
} }
fn doc(&self) -> DocId { fn doc(&self) -> DocId {
if self.cursor == self.doc_ids.len() { self.doc_ids[self.cursor.0]
return TERMINATED;
}
self.doc_ids[self.cursor]
} }
fn size_hint(&self) -> u32 { fn size_hint(&self) -> u32 {
@@ -52,21 +49,22 @@ impl HasLen for VecDocSet {
pub mod tests { pub mod tests {
use super::*; use super::*;
use crate::docset::DocSet; use crate::docset::{DocSet, SkipResult};
use crate::DocId; use crate::DocId;
#[test] #[test]
pub fn test_vec_postings() { pub fn test_vec_postings() {
let doc_ids: Vec<DocId> = (0u32..1024u32).map(|e| e * 3).collect(); let doc_ids: Vec<DocId> = (0u32..1024u32).map(|e| e * 3).collect();
let mut postings = VecDocSet::from(doc_ids); let mut postings = VecDocSet::from(doc_ids);
assert!(postings.advance());
assert_eq!(postings.doc(), 0u32); assert_eq!(postings.doc(), 0u32);
assert_eq!(postings.advance(), 3u32); assert!(postings.advance());
assert_eq!(postings.doc(), 3u32); assert_eq!(postings.doc(), 3u32);
assert_eq!(postings.seek(14u32), 15u32); assert_eq!(postings.skip_next(14u32), SkipResult::OverStep);
assert_eq!(postings.doc(), 15u32); assert_eq!(postings.doc(), 15u32);
assert_eq!(postings.seek(300u32), 300u32); assert_eq!(postings.skip_next(300u32), SkipResult::Reached);
assert_eq!(postings.doc(), 300u32); assert_eq!(postings.doc(), 300u32);
assert_eq!(postings.seek(6000u32), TERMINATED); assert_eq!(postings.skip_next(6000u32), SkipResult::End);
} }
#[test] #[test]

Some files were not shown because too many files have changed in this diff Show More