Compare commits

..

72 Commits

Author SHA1 Message Date
Paul Masurel
ce7db2e1d0 Merge branch 'master' into blockwand 2020-07-16 15:39:52 +09:00
lyj
410aed0176 Update segment_updater.rs (#848) 2020-07-16 12:33:11 +09:00
aptend
00a239a712 fix typo in index_meta.rs (#851) 2020-07-16 12:32:45 +09:00
Paul Masurel
68fe406924 Removed asserts (#850) 2020-07-16 12:24:55 +09:00
Paul Masurel
f71b04acb0 Bugfix. (#849)
go_to_first_doc was typically calling seek with a target smaller than
doc.

Since SegmentPostings typically do a linear search on the full block,
regardless of the current position, it could have our segment postings
go backward.
2020-07-16 10:57:51 +09:00
Paul Masurel
1a462c641b merge 2020-07-14 14:01:48 +09:00
Paul Masurel
663db14d70 Bugfix.
go_to_first_doc was typically calling seek with a target smaller than
doc.

Since SegmentPostings typically do a linear search on the full block,
regardless of the current position, it could have our segment postings
go backward.
2020-07-14 12:59:20 +09:00
Paul Masurel
75ea74e465 added blockwand information 2020-07-14 09:56:20 +09:00
lyj
1ab7f660a4 Update index.rs (#846) 2020-07-02 15:11:38 +09:00
Sean Stangl
0ebbc4cb5a Fix incorrect SimpleTokenizer link in documentation (#844) 2020-07-01 10:26:36 +09:00
lyj
5300cb5da0 Update mod.rs (#845) 2020-07-01 10:25:26 +09:00
Paul Masurel
d71447a9e0 introducing Block WAND params to TextOptions 2020-06-10 09:44:34 +09:00
Ype Kingma
7d773abc92 Boolean query: do not combine excluded scores. (#840)
* Do nothing when combining score values of excluded scores.

* Add test case for two excluded.

* Test score for two excluded terms.

* Use TopDocs in test_boolean_query_two_excluded
2020-06-08 20:01:19 +09:00
Paul Masurel
c34541ccce Alive doc iterator. (#837) 2020-06-05 19:42:51 +09:00
Paul Masurel
1cc5bd706c Fixes build for no-default-features (#839) 2020-06-05 19:41:55 +09:00
Paul Masurel
7df5a8a530 ll 2020-06-05 19:37:38 +09:00
Paul Masurel
4026d183bc Small readability change 2020-06-03 09:04:57 +09:00
Paul Masurel
f0ab0fa5b8 Relying on blockwand 2020-06-01 22:28:08 +09:00
Paul Masurel
a53572069b merge 2020-06-01 13:57:32 +09:00
Paul Masurel
c0f5645cd9 Move for_each functions from Scorer to Weight. (#836)
* Move for_each functions from Scorer to Weight.

* Specialized foreach / foreach_pruning for union of termscorer.
2020-06-01 11:31:18 +09:00
Paul Masurel
522953ce5c merged 2020-05-27 17:13:49 +09:00
Paul Masurel
f750b18fd6 merged 2020-05-27 16:57:50 +09:00
Paul Masurel
cbff874e43 Change the loading of blocks. 2020-05-27 16:36:50 +09:00
Paul Masurel
baf015fc57 Simplification of the segment postings seek implementation. (#834) 2020-05-27 08:49:47 +09:00
Paul Masurel
7275ebdf3c Skiprefactoring skipabsolute (#831)
Simplification of the way we handle positions.
2020-05-25 09:51:23 +09:00
Paul Masurel
b974e7ce34 Closes #828. (#829)
There was a bug in the LogMergePolicy that was surfacing when there were
segments, but all of the segments were larger than the max limit.

After filtering, the list of segments candidate for merge was 0, and
the code was indexing the first element of an empty Vec.
2020-05-22 16:24:07 +09:00
Paul Masurel
8f8f34499f Updated CHANGELOG with the TopCollector offset information and cargo fmt. 2020-05-20 22:26:54 +09:00
Rob Young
6ea6f4bfcd Add offset to TopDocsCollector (#826)
* Add offset to TopDocsCollector

Add an offset to TopDocsCollector and TopDocs to make it clearer how to
handle pagination.

Closes #822

* Address review comments

- Make Debug formatting of TopDocs clearer.
- Add unit tests for limit and offset on TopCollector.
- Change API for using offset to a fluent interface.
- Add some context to the docstring to clarify what limit and offset are
  equivalent to in other projects.

* Changes required by rebase on e25284

- Pass Collector into TweakedScoreTopCollector and
  CustomScoreTopCollector.
- Add std:: qualifier to f32, i32 etc. Not sure why this was not failing
  already.
- Add unit tests for TopDocs with offset including for tweaked and
  custom score collectors.

In order to convert a TopCollector<Score> to a TopCollector<TScore> I
had to add a `into_tscore` method to `TopCollector`. This is a hack but
I don't know how to avoid it.
2020-05-20 22:25:24 +09:00
Paul Masurel
5623112132 blop 2020-05-19 17:31:29 +09:00
Paul Masurel
dd20454cc7 First stab at blockwand 2020-05-17 16:09:04 +09:00
Paul Masurel
e25284bafe Major change in the DocSet/Scorer API (#824)
- Change in the DocSet and Scorer API. (@fulmicoton). 
A freshly created DocSet point directly to their first doc. A sentinel value called TERMINATED marks the end of a DocSet.
`.advance()` returns the new DocId. `Scorer::skip(target)` has been replaced by `Scorer::seek(target)` and returns the resulting DocId.
As a result, iterating through DocSet now looks as follows
```rust
let mut doc = docset.doc();
while doc != TERMINATED {
   // ...
   doc = docset.advance();
}
```
The change made it possible to greatly simplify a lot of the docset's code.
- Misc internal optimization and introduction of the `Scorer::for_each_pruning` function. (@fulmicoton)
2020-05-16 16:33:36 +09:00
Fisher Darling
8b67877cd5 Made field methods const fns (#823) 2020-05-16 10:59:50 +09:00
Rob Young
9de1360538 Minor doc and test improvements around fuzzy querying (#825) 2020-05-16 10:59:24 +09:00
Paul Masurel
c55db83609 Closes #805 (#820)
Added TryInto implementation for IndexReaderBuilder
2020-04-27 12:01:17 +09:00
Paul Masurel
1e5ebdbf3c Format and remove useless import (#819) 2020-04-27 11:56:49 +09:00
Paul Masurel
9a2090ab21 Create the MMapDirectory does not return a Directory. (#818) 2020-04-27 11:42:20 +09:00
Paul Masurel
e4aaacdb86 Minor change in README.md 2020-04-21 21:30:34 +09:00
Paul Masurel
29acf1104d Update README's claim on performance. 2020-04-21 14:44:26 +09:00
Paul Masurel
3d34fa0b69 Fixed changelog 2020-04-19 15:55:54 +09:00
Rob Young
77f363987a Make TweakScore and CustomScore mutable at the segment level (#807)
* Make TweakScore and CustomScore mutable

Make TweakScore and CustomScore mutable at the segment level.

Addresses issue #806

* Add example to show tweak_score working for facets
2020-04-19 15:54:00 +09:00
Paul Masurel
c0be461191 Removing tantivy-fst conf and removing warning. (#813) 2020-04-18 20:19:23 +09:00
dependabot-preview[bot]
1fb562f44a Update fail requirement from 0.3 to 0.4 (#810)
Updates the requirements on [fail](https://github.com/tikv/fail-rs) to permit the latest version.
- [Release notes](https://github.com/tikv/fail-rs/releases)
- [Changelog](https://github.com/tikv/fail-rs/blob/master/CHANGELOG.md)
- [Commits](https://github.com/tikv/fail-rs/compare/v0.3.0...v0.4.0)

Signed-off-by: dependabot-preview[bot] <support@dependabot.com>

Co-authored-by: dependabot-preview[bot] <27856297+dependabot-preview[bot]@users.noreply.github.com>
2020-04-17 07:14:19 +09:00
Rob Young
c591d0e591 Switch fst dependency to git (#808)
Closes #803

This allows the package to be built without first cloning the
tantivy-search/fst repo into the expected place. This should fix CI.
2020-04-16 23:05:12 +09:00
Paul Masurel
186d7fc20e Fix build 2020-04-01 09:32:45 +09:00
Paul Masurel
cfbdef5186 Using tantivy-fst version 0.3. 2020-03-31 23:24:54 +09:00
Paul Masurel
d04368b1d4 Closes #788. OR not working when using conjunction by default. (#802) 2020-03-31 21:13:50 +09:00
Chen Xu
b167058028 Fix prefix option for FuzzyTermQuery (#797)
* Fix prefix option for FuzzyTermQuery

* Update changelog
2020-03-19 20:19:32 +09:00
Paul Masurel
262957717b unit test fix and use of matches 2020-03-15 00:20:17 +09:00
Paul Masurel
873a808321 Removed itertools (#792) 2020-03-11 18:41:04 +09:00
dependabot-preview[bot]
6fa8f9330e Update base64 requirement from 0.11.0 to 0.12.0 (#791)
Updates the requirements on [base64](https://github.com/marshallpierce/rust-base64) to permit the latest version.
- [Release notes](https://github.com/marshallpierce/rust-base64/releases)
- [Changelog](https://github.com/marshallpierce/rust-base64/blob/master/RELEASE-NOTES.md)
- [Commits](https://github.com/marshallpierce/rust-base64/compare/v0.11.0...v0.12.0)

Signed-off-by: dependabot-preview[bot] <support@dependabot.com>

Co-authored-by: dependabot-preview[bot] <27856297+dependabot-preview[bot]@users.noreply.github.com>
2020-03-11 17:51:22 +09:00
Paul Masurel
b3f0ef0878 Avoid writing a new delete file if there was no actual deletes. (#787)
When applying the delete operations in the delete queue, it is possible
that there was no new deleted document.

In this case, avoid creating a new delete file, and updating the delete
opstamp.
2020-03-08 13:04:21 +09:00
Paul Masurel
04304262ba cargo fmt 2020-03-08 09:58:42 +09:00
Paul Masurel
920ced364a Added a method to persist the RAMDirectory into a different directory. 2020-03-07 17:00:50 +09:00
Paul Masurel
e0499118e2 Minor refactoring 2020-03-07 15:56:03 +09:00
Paul Masurel
50b5efae46 Added derive feature to serde crate 2020-03-06 23:46:29 +09:00
Paul Masurel
486b8fa9c5 Removing serde-derive dependency (#786) 2020-03-06 23:33:58 +09:00
Minoru Osuka
b2baed9bdd Add Lindera to README.md (#785)
* Add Lindera to README.md

* Put lindera in first place
2020-03-03 20:23:59 +09:00
Paul Masurel
b591542c0b Removing err.description() before deprecation. 2020-03-03 09:58:49 +09:00
Paul Masurel
a83fa00ac4 Faster compilation of query-grammar. (#784) 2020-03-02 22:12:42 +09:00
Paul Masurel
7ff5c7c797 Removing the fst feature in the levenshtein_automata crate. 2020-03-02 21:47:05 +09:00
Paul Masurel
1748602691 ignore -> compile_fail 2020-03-02 09:59:48 +09:00
Paul Masurel
6542dd5337 Removing parenthesis. 2020-03-01 09:41:53 +09:00
Nicholas Connor
c64a44b9e1 Slight re-organization to increase contrast of "Getting Started" (#783) 2020-02-28 08:42:38 +09:00
Paul Masurel
fccc5b3bed Closes #758 2020-02-27 17:58:43 +09:00
Paul Masurel
98b9d5c6c4 Closes #780. Will be fixed on the next published release. 2020-02-21 09:41:52 +09:00
Paul Masurel
afd2c1a8ad Merge branch 'master' of github.com:tantivy-search/tantivy 2020-02-19 22:08:44 +09:00
Paul Masurel
81f35a3ceb Bumped tantivy-grammar version 2020-02-19 22:08:31 +09:00
Paul Masurel
7e2e765f4a Bumped tantivy-grammar version 2020-02-19 22:07:54 +09:00
Paul Masurel
7d6cfa58e1 [WIP] Alternative take on boosted queries (#772)
* Alternative take on boosted queries

* Fixing unit test

* Added boosting to the query grammar.

* Made BoostQuery public.

* Added support for boosting field in QueryParser

Closes #547
2020-02-19 11:04:38 +09:00
Paul Masurel
14735ce3aa Update snap version to 1. (#781) 2020-02-17 10:41:44 +09:00
Paul Masurel
72f7cc1569 Closes #777 (#779) 2020-02-17 09:53:38 +09:00
Paul Masurel
abef5c4e74 Updating combine to version 4 (#775) 2020-02-06 23:02:48 +09:00
106 changed files with 4664 additions and 2511 deletions

View File

@@ -1,9 +1,31 @@
Tantivy 0.13.0
======================
- Bugfix in `FuzzyTermQuery` not matching terms by prefix when it should (@Peachball)
- Relaxed constraints on the custom/tweak score functions. At the segment level, they can be mut, and they are not required to be Sync + Send.
- `MMapDirectory::open` does not return a `Result` anymore.
- Change in the DocSet and Scorer API. (@fulmicoton).
A freshly created DocSet point directly to their first doc. A sentinel value called TERMINATED marks the end of a DocSet.
`.advance()` returns the new DocId. `Scorer::skip(target)` has been replaced by `Scorer::seek(target)` and returns the resulting DocId.
As a result, iterating through DocSet now looks as follows
```rust
let mut doc = docset.doc();
while doc != TERMINATED {
// ...
doc = docset.advance();
}
```
The change made it possible to greatly simplify a lot of the docset's code.
- Misc internal optimization and introduction of the `Scorer::for_each_pruning` function. (@fulmicoton)
- Added an offset option to the Top(.*)Collectors. (@robyoung)
Tantivy 0.12.0
======================
- Removing static dispatch in tokenizers for simplicity. (#762)
- Added backward iteration for `TermDictionary` stream. (@halvorboe)
- Fixed a performance issue when searching for the posting lists of a missing term (@audunhalland)
- Added a configurable maximum number of docs (10M by default) for a segment to be considered for merge (@hntd187, landed by @halvorboe #713)
- Important Bugfix #777, causing tantivy to retain memory mapping. (diagnosed by @poljar)
- Added support for field boosting. (#547, @fulmicoton)
## How to update?

View File

@@ -5,7 +5,7 @@ authors = ["Paul Masurel <paul.masurel@gmail.com>"]
license = "MIT"
categories = ["database-implementations", "data-structures"]
description = """Search engine library"""
documentation = "https://tantivy-search.github.io/tantivy/tantivy/index.html"
documentation = "https://docs.rs/tantivy/"
homepage = "https://github.com/tantivy-search/tantivy"
repository = "https://github.com/tantivy-search/tantivy"
readme = "README.md"
@@ -13,25 +13,23 @@ keywords = ["search", "information", "retrieval"]
edition = "2018"
[dependencies]
base64 = "0.11.0"
base64 = "0.12.0"
byteorder = "1.0"
crc32fast = "1.2.0"
once_cell = "1.0"
regex ={version = "1.3.0", default-features = false, features = ["std"]}
tantivy-fst = "0.2.1"
tantivy-fst = "0.3"
memmap = {version = "0.7", optional=true}
lz4 = {version="1.20", optional=true}
snap = {version="0.2"}
snap = "1"
atomicwrites = {version="0.2.2", optional=true}
tempfile = "3.0"
log = "0.4"
serde = "1.0"
serde_derive = "1.0"
serde = {version="1.0", features=["derive"]}
serde_json = "1.0"
num_cpus = "1.2"
fs2={version="0.4", optional=true}
itertools = "0.8"
levenshtein_automata = {version="0.1", features=["fst_automaton"]}
levenshtein_automata = "0.2"
notify = {version="4", optional=true}
uuid = { version = "0.8", features = ["v4", "serde"] }
crossbeam = "0.7"
@@ -40,14 +38,14 @@ owning_ref = "0.4"
stable_deref_trait = "1.0.0"
rust-stemmers = "1.2"
downcast-rs = { version="1.0" }
tantivy-query-grammar = { version="0.11", path="./query-grammar" }
tantivy-query-grammar = { version="0.13", path="./query-grammar" }
bitpacking = {version="0.8", default-features = false, features=["bitpacker4x"]}
census = "0.4"
fnv = "1.0.6"
owned-read = "0.4"
failure = "0.1"
htmlescape = "0.3.1"
fail = "0.3"
fail = "0.4"
murmurhash32 = "0.2"
chrono = "0.4"
smallvec = "1.0"
@@ -60,9 +58,10 @@ winapi = "0.3"
rand = "0.7"
maplit = "1"
matches = "0.1.8"
proptest = "0.10"
[dev-dependencies.fail]
version = "0.3"
version = "0.4"
features = ["failpoints"]
[profile.release]

View File

@@ -31,16 +31,20 @@ Tantivy is, in fact, strongly inspired by Lucene's design.
# Benchmark
Tantivy is typically faster than Lucene, but the results depend on
the nature of the queries in your workload.
The following [benchmark](https://tantivy-search.github.io/bench/) break downs
performance for different type of queries / collection.
In general, Tantivy tends to be
- slower than Lucene on union with a Top-K due to Block-WAND optimization.
- faster than Lucene on intersection and phrase queries.
Your mileage WILL vary depending on the nature of queries and their load.
# Features
- Full-text search
- Configurable tokenizer (stemming available for 17 Latin languages with third party support for Chinese ([tantivy-jieba](https://crates.io/crates/tantivy-jieba) and [cang-jie](https://crates.io/crates/cang-jie)) and [Japanese](https://crates.io/crates/tantivy-tokenizer-tiny-segmenter))
- Configurable tokenizer (stemming available for 17 Latin languages with third party support for Chinese ([tantivy-jieba](https://crates.io/crates/tantivy-jieba) and [cang-jie](https://crates.io/crates/cang-jie)), Japanese ([lindera](https://github.com/lindera-morphology/lindera-tantivy) and [tantivy-tokenizer-tiny-segmente](https://crates.io/crates/tantivy-tokenizer-tiny-segmenter)) and Korean ([lindera](https://github.com/lindera-morphology/lindera-tantivy) + [lindera-ko-dic-builder](https://github.com/lindera-morphology/lindera-ko-dic-builder))
- Fast (check out the :racehorse: :sparkles: [benchmark](https://tantivy-search.github.io/bench/) :sparkles: :racehorse:)
- Tiny startup time (<10ms), perfect for command line tools
- BM25 scoring (the same as Lucene)
@@ -59,18 +63,17 @@ performance for different type of queries / collection.
- Configurable indexing (optional term frequency and position indexing)
- Cheesy logo with a horse
# Non-features
## Non-features
- Distributed search is out of the scope of Tantivy. That being said, Tantivy is a
library upon which one could build a distributed search. Serializable/mergeable collector state for instance,
are within the scope of Tantivy.
# Supported OS and compiler
Tantivy works on stable Rust (>= 1.27) and supports Linux, MacOS, and Windows.
# Getting started
Tantivy works on stable Rust (>= 1.27) and supports Linux, MacOS, and Windows.
- [Tantivy's simple search example](https://tantivy-search.github.io/examples/basic_search.html)
- [tantivy-cli and its tutorial](https://github.com/tantivy-search/tantivy-cli) - `tantivy-cli` is an actual command line interface that makes it easy for you to create a search engine,
index documents, and search via the CLI or a small server with a REST API.

View File

@@ -18,5 +18,5 @@ install:
build: false
test_script:
- REM SET RUST_LOG=tantivy,test & cargo test --verbose --no-default-features --features mmap
- REM SET RUST_LOG=tantivy,test & cargo test --all --verbose --no-default-features --features mmap
- REM SET RUST_BACKTRACE=1 & cargo build --examples

View File

@@ -0,0 +1,98 @@
use std::collections::HashSet;
use tantivy::collector::TopDocs;
use tantivy::doc;
use tantivy::query::BooleanQuery;
use tantivy::schema::*;
use tantivy::{DocId, Index, Score, SegmentReader};
fn main() -> tantivy::Result<()> {
let mut schema_builder = Schema::builder();
let title = schema_builder.add_text_field("title", STORED);
let ingredient = schema_builder.add_facet_field("ingredient");
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
let mut index_writer = index.writer(30_000_000)?;
index_writer.add_document(doc!(
title => "Fried egg",
ingredient => Facet::from("/ingredient/egg"),
ingredient => Facet::from("/ingredient/oil"),
));
index_writer.add_document(doc!(
title => "Scrambled egg",
ingredient => Facet::from("/ingredient/egg"),
ingredient => Facet::from("/ingredient/butter"),
ingredient => Facet::from("/ingredient/milk"),
ingredient => Facet::from("/ingredient/salt"),
));
index_writer.add_document(doc!(
title => "Egg rolls",
ingredient => Facet::from("/ingredient/egg"),
ingredient => Facet::from("/ingredient/garlic"),
ingredient => Facet::from("/ingredient/salt"),
ingredient => Facet::from("/ingredient/oil"),
ingredient => Facet::from("/ingredient/tortilla-wrap"),
ingredient => Facet::from("/ingredient/mushroom"),
));
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
{
let facets = vec![
Facet::from("/ingredient/egg"),
Facet::from("/ingredient/oil"),
Facet::from("/ingredient/garlic"),
Facet::from("/ingredient/mushroom"),
];
let query = BooleanQuery::new_multiterms_query(
facets
.iter()
.map(|key| Term::from_facet(ingredient, &key))
.collect(),
);
let top_docs_by_custom_score =
TopDocs::with_limit(2).tweak_score(move |segment_reader: &SegmentReader| {
let mut ingredient_reader = segment_reader.facet_reader(ingredient).unwrap();
let facet_dict = ingredient_reader.facet_dict();
let query_ords: HashSet<u64> = facets
.iter()
.filter_map(|key| facet_dict.term_ord(key.encoded_str()))
.collect();
let mut facet_ords_buffer: Vec<u64> = Vec::with_capacity(20);
move |doc: DocId, original_score: Score| {
ingredient_reader.facet_ords(doc, &mut facet_ords_buffer);
let missing_ingredients = facet_ords_buffer
.iter()
.filter(|ord| !query_ords.contains(ord))
.count();
let tweak = 1.0 / 4_f32.powi(missing_ingredients as i32);
original_score * tweak
}
});
let top_docs = searcher.search(&query, &top_docs_by_custom_score)?;
let titles: Vec<String> = top_docs
.iter()
.map(|(_, doc_id)| {
searcher
.doc(*doc_id)
.unwrap()
.get_first(title)
.unwrap()
.text()
.unwrap()
.to_owned()
})
.collect();
assert_eq!(titles, vec!["Fried egg", "Egg rolls"]);
}
Ok(())
}

View File

@@ -10,7 +10,7 @@
// ---
// Importing tantivy...
use tantivy::schema::*;
use tantivy::{doc, DocId, DocSet, Index, Postings};
use tantivy::{doc, DocSet, Index, Postings, TERMINATED};
fn main() -> tantivy::Result<()> {
// We first create a schema for the sake of the
@@ -62,12 +62,11 @@ fn main() -> tantivy::Result<()> {
{
// this buffer will be used to request for positions
let mut positions: Vec<u32> = Vec::with_capacity(100);
while segment_postings.advance() {
// the number of time the term appears in the document.
let doc_id: DocId = segment_postings.doc(); //< do not try to access this before calling advance once.
let mut doc_id = segment_postings.doc();
while doc_id != TERMINATED {
// This MAY contains deleted documents as well.
if segment_reader.is_deleted(doc_id) {
doc_id = segment_postings.advance();
continue;
}
@@ -86,6 +85,7 @@ fn main() -> tantivy::Result<()> {
// Doc 2: TermFreq 1: [0]
// ```
println!("Doc {}: TermFreq {}: {:?}", doc_id, term_freq, positions);
doc_id = segment_postings.advance();
}
}
}
@@ -117,11 +117,16 @@ fn main() -> tantivy::Result<()> {
if let Some(mut block_segment_postings) =
inverted_index.read_block_postings(&term_the, IndexRecordOption::Basic)
{
while block_segment_postings.advance() {
loop {
let docs = block_segment_postings.docs();
if docs.is_empty() {
break;
}
// Once again these docs MAY contains deleted documents as well.
let docs = block_segment_postings.docs();
// Prints `Docs [0, 2].`
println!("Docs {:?}", docs);
block_segment_postings.advance();
}
}
}

View File

@@ -9,11 +9,10 @@
// - import tokenized text straight from json,
// - perform a search on documents with pre-tokenized text
use tantivy::tokenizer::{PreTokenizedString, SimpleTokenizer, Token, Tokenizer};
use tantivy::collector::{Count, TopDocs};
use tantivy::query::TermQuery;
use tantivy::schema::*;
use tantivy::tokenizer::{PreTokenizedString, SimpleTokenizer, Token, Tokenizer};
use tantivy::{doc, Index, ReloadPolicy};
use tempfile::TempDir;

View File

@@ -1,6 +1,6 @@
[package]
name = "tantivy-query-grammar"
version = "0.11.0"
version = "0.13.0"
authors = ["Paul Masurel <paul.masurel@gmail.com>"]
license = "MIT"
categories = ["database-implementations", "data-structures"]
@@ -13,4 +13,4 @@ keywords = ["search", "information", "retrieval"]
edition = "2018"
[dependencies]
combine = ">=3.6.0,<4.0.0"
combine = {version="4", default-features=false, features=[] }

View File

@@ -1,5 +1,3 @@
#![recursion_limit = "100"]
mod occur;
mod query_grammar;
mod user_input_ast;

View File

@@ -1,171 +1,209 @@
use super::user_input_ast::*;
use super::user_input_ast::{UserInputAST, UserInputBound, UserInputLeaf, UserInputLiteral};
use crate::Occur;
use combine::char::*;
use combine::error::StreamError;
use combine::stream::StreamErrorFor;
use combine::*;
use combine::error::StringStreamError;
use combine::parser::char::{char, digit, letter, space, spaces, string};
use combine::parser::Parser;
use combine::{
attempt, choice, eof, many, many1, one_of, optional, parser, satisfy, skip_many1, value,
};
parser! {
fn field[I]()(I) -> String
where [I: Stream<Item = char>] {
(
letter(),
many(satisfy(|c: char| c.is_alphanumeric() || c == '_')),
).skip(char(':')).map(|(s1, s2): (char, String)| format!("{}{}", s1, s2))
}
}
parser! {
fn word[I]()(I) -> String
where [I: Stream<Item = char>] {
(
satisfy(|c: char| !c.is_whitespace() && !['-', '`', ':', '{', '}', '"', '[', ']', '(',')'].contains(&c) ),
many(satisfy(|c: char| !c.is_whitespace() && ![':', '{', '}', '"', '[', ']', '(',')'].contains(&c)))
)
fn field<'a>() -> impl Parser<&'a str, Output = String> {
(
letter(),
many(satisfy(|c: char| c.is_alphanumeric() || c == '_')),
)
.skip(char(':'))
.map(|(s1, s2): (char, String)| format!("{}{}", s1, s2))
.and_then(|s: String|
match s.as_str() {
"OR" => Err(StreamErrorFor::<I>::unexpected_static_message("OR")),
"AND" => Err(StreamErrorFor::<I>::unexpected_static_message("AND")),
"NOT" => Err(StreamErrorFor::<I>::unexpected_static_message("NOT")),
_ => Ok(s)
})
}
}
parser! {
fn literal[I]()(I) -> UserInputLeaf
where [I: Stream<Item = char>]
{
let term_val = || {
let phrase = char('"').with(many1(satisfy(|c| c != '"'))).skip(char('"'));
phrase.or(word())
};
let term_val_with_field = negative_number().or(term_val());
let term_query =
(field(), term_val_with_field)
.map(|(field_name, phrase)| UserInputLiteral {
field_name: Some(field_name),
phrase,
});
let term_default_field = term_val().map(|phrase| UserInputLiteral {
field_name: None,
phrase,
});
attempt(term_query)
.or(term_default_field)
.map(UserInputLeaf::from)
}
}
parser! {
fn negative_number[I]()(I) -> String
where [I: Stream<Item = char>]
{
(char('-'), many1(satisfy(char::is_numeric)),
optional((char('.'), many1(satisfy(char::is_numeric)))))
.map(|(s1, s2, s3): (char, String, Option<(char, String)>)| {
if let Some(('.', s3)) = s3 {
format!("{}{}.{}", s1, s2, s3)
} else {
format!("{}{}", s1, s2)
}
})
}
}
parser! {
fn spaces1[I]()(I) -> ()
where [I: Stream<Item = char>] {
skip_many1(space())
}
}
parser! {
/// Function that parses a range out of a Stream
/// Supports ranges like:
/// [5 TO 10], {5 TO 10}, [* TO 10], [10 TO *], {10 TO *], >5, <=10
/// [a TO *], [a TO c], [abc TO bcd}
fn range[I]()(I) -> UserInputLeaf
where [I: Stream<Item = char>] {
let range_term_val = || {
word().or(negative_number()).or(char('*').with(value("*".to_string())))
};
// check for unbounded range in the form of <5, <=10, >5, >=5
let elastic_unbounded_range = (choice([attempt(string(">=")),
attempt(string("<=")),
attempt(string("<")),
attempt(string(">"))])
.skip(spaces()),
range_term_val()).
map(|(comparison_sign, bound): (&str, String)|
match comparison_sign {
">=" => (UserInputBound::Inclusive(bound), UserInputBound::Unbounded),
"<=" => (UserInputBound::Unbounded, UserInputBound::Inclusive(bound)),
"<" => (UserInputBound::Unbounded, UserInputBound::Exclusive(bound)),
">" => (UserInputBound::Exclusive(bound), UserInputBound::Unbounded),
// default case
_ => (UserInputBound::Unbounded, UserInputBound::Unbounded)
});
let lower_bound = (one_of("{[".chars()), range_term_val())
.map(|(boundary_char, lower_bound): (char, String)|
if lower_bound == "*" {
UserInputBound::Unbounded
} else if boundary_char == '{' {
UserInputBound::Exclusive(lower_bound)
} else {
UserInputBound::Inclusive(lower_bound)
});
let upper_bound = (range_term_val(), one_of("}]".chars()))
.map(|(higher_bound, boundary_char): (String, char)|
if higher_bound == "*" {
UserInputBound::Unbounded
} else if boundary_char == '}' {
UserInputBound::Exclusive(higher_bound)
} else {
UserInputBound::Inclusive(higher_bound)
});
// return only lower and upper
let lower_to_upper = (lower_bound.
skip((spaces(),
string("TO"),
spaces())),
upper_bound);
(optional(field()).skip(spaces()),
// try elastic first, if it matches, the range is unbounded
attempt(elastic_unbounded_range).or(lower_to_upper))
.map(|(field, (lower, upper))|
// Construct the leaf from extracted field (optional)
// and bounds
UserInputLeaf::Range {
field,
lower,
upper
fn word<'a>() -> impl Parser<&'a str, Output = String> {
(
satisfy(|c: char| {
!c.is_whitespace()
&& !['-', '^', '`', ':', '{', '}', '"', '[', ']', '(', ')'].contains(&c)
}),
many(satisfy(|c: char| {
!c.is_whitespace() && ![':', '^', '{', '}', '"', '[', ']', '(', ')'].contains(&c)
})),
)
.map(|(s1, s2): (char, String)| format!("{}{}", s1, s2))
.and_then(|s: String| match s.as_str() {
"OR" | "AND " | "NOT" => Err(StringStreamError::UnexpectedParse),
_ => Ok(s),
})
}
}
fn term_val<'a>() -> impl Parser<&'a str, Output = String> {
let phrase = char('"').with(many1(satisfy(|c| c != '"'))).skip(char('"'));
phrase.or(word())
}
fn term_query<'a>() -> impl Parser<&'a str, Output = UserInputLiteral> {
let term_val_with_field = negative_number().or(term_val());
(field(), term_val_with_field).map(|(field_name, phrase)| UserInputLiteral {
field_name: Some(field_name),
phrase,
})
}
fn literal<'a>() -> impl Parser<&'a str, Output = UserInputLeaf> {
let term_default_field = term_val().map(|phrase| UserInputLiteral {
field_name: None,
phrase,
});
attempt(term_query())
.or(term_default_field)
.map(UserInputLeaf::from)
}
fn negative_number<'a>() -> impl Parser<&'a str, Output = String> {
(
char('-'),
many1(digit()),
optional((char('.'), many1(digit()))),
)
.map(|(s1, s2, s3): (char, String, Option<(char, String)>)| {
if let Some(('.', s3)) = s3 {
format!("{}{}.{}", s1, s2, s3)
} else {
format!("{}{}", s1, s2)
}
})
}
fn spaces1<'a>() -> impl Parser<&'a str, Output = ()> {
skip_many1(space())
}
/// Function that parses a range out of a Stream
/// Supports ranges like:
/// [5 TO 10], {5 TO 10}, [* TO 10], [10 TO *], {10 TO *], >5, <=10
/// [a TO *], [a TO c], [abc TO bcd}
fn range<'a>() -> impl Parser<&'a str, Output = UserInputLeaf> {
let range_term_val = || {
word()
.or(negative_number())
.or(char('*').with(value("*".to_string())))
};
// check for unbounded range in the form of <5, <=10, >5, >=5
let elastic_unbounded_range = (
choice([
attempt(string(">=")),
attempt(string("<=")),
attempt(string("<")),
attempt(string(">")),
])
.skip(spaces()),
range_term_val(),
)
.map(
|(comparison_sign, bound): (&str, String)| match comparison_sign {
">=" => (UserInputBound::Inclusive(bound), UserInputBound::Unbounded),
"<=" => (UserInputBound::Unbounded, UserInputBound::Inclusive(bound)),
"<" => (UserInputBound::Unbounded, UserInputBound::Exclusive(bound)),
">" => (UserInputBound::Exclusive(bound), UserInputBound::Unbounded),
// default case
_ => (UserInputBound::Unbounded, UserInputBound::Unbounded),
},
);
let lower_bound = (one_of("{[".chars()), range_term_val()).map(
|(boundary_char, lower_bound): (char, String)| {
if lower_bound == "*" {
UserInputBound::Unbounded
} else if boundary_char == '{' {
UserInputBound::Exclusive(lower_bound)
} else {
UserInputBound::Inclusive(lower_bound)
}
},
);
let upper_bound = (range_term_val(), one_of("}]".chars())).map(
|(higher_bound, boundary_char): (String, char)| {
if higher_bound == "*" {
UserInputBound::Unbounded
} else if boundary_char == '}' {
UserInputBound::Exclusive(higher_bound)
} else {
UserInputBound::Inclusive(higher_bound)
}
},
);
// return only lower and upper
let lower_to_upper = (
lower_bound.skip((spaces(), string("TO"), spaces())),
upper_bound,
);
(
optional(field()).skip(spaces()),
// try elastic first, if it matches, the range is unbounded
attempt(elastic_unbounded_range).or(lower_to_upper),
)
.map(|(field, (lower, upper))|
// Construct the leaf from extracted field (optional)
// and bounds
UserInputLeaf::Range {
field,
lower,
upper
})
}
fn negate(expr: UserInputAST) -> UserInputAST {
expr.unary(Occur::MustNot)
}
fn must(expr: UserInputAST) -> UserInputAST {
expr.unary(Occur::Must)
fn leaf<'a>() -> impl Parser<&'a str, Output = UserInputAST> {
parser(|input| {
char('(')
.with(ast())
.skip(char(')'))
.or(char('*').map(|_| UserInputAST::from(UserInputLeaf::All)))
.or(attempt(
string("NOT").skip(spaces1()).with(leaf()).map(negate),
))
.or(attempt(range().map(UserInputAST::from)))
.or(literal().map(UserInputAST::from))
.parse_stream(input)
.into_result()
})
}
parser! {
fn leaf[I]()(I) -> UserInputAST
where [I: Stream<Item = char>] {
char('-').with(leaf()).map(negate)
.or(char('+').with(leaf()).map(must))
.or(char('(').with(ast()).skip(char(')')))
.or(char('*').map(|_| UserInputAST::from(UserInputLeaf::All)))
.or(attempt(string("NOT").skip(spaces1()).with(leaf()).map(negate)))
.or(attempt(range().map(UserInputAST::from)))
.or(literal().map(UserInputAST::from))
}
fn occur_symbol<'a>() -> impl Parser<&'a str, Output = Occur> {
char('-')
.map(|_| Occur::MustNot)
.or(char('+').map(|_| Occur::Must))
}
fn occur_leaf<'a>() -> impl Parser<&'a str, Output = (Option<Occur>, UserInputAST)> {
(optional(occur_symbol()), boosted_leaf())
}
fn positive_float_number<'a>() -> impl Parser<&'a str, Output = f32> {
(many1(digit()), optional((char('.'), many1(digit())))).map(
|(int_part, decimal_part_opt): (String, Option<(char, String)>)| {
let mut float_str = int_part;
if let Some((chr, decimal_str)) = decimal_part_opt {
float_str.push(chr);
float_str.push_str(&decimal_str);
}
float_str.parse::<f32>().unwrap()
},
)
}
fn boost<'a>() -> impl Parser<&'a str, Output = f32> {
(char('^'), positive_float_number()).map(|(_, boost)| boost)
}
fn boosted_leaf<'a>() -> impl Parser<&'a str, Output = UserInputAST> {
(leaf(), optional(boost())).map(|(leaf, boost_opt)| match boost_opt {
Some(boost) if (boost - 1.0).abs() > std::f32::EPSILON => {
UserInputAST::Boost(Box::new(leaf), boost)
}
_ => leaf,
})
}
#[derive(Clone, Copy)]
@@ -174,13 +212,10 @@ enum BinaryOperand {
And,
}
parser! {
fn binary_operand[I]()(I) -> BinaryOperand
where [I: Stream<Item = char>]
{
string("AND").with(value(BinaryOperand::And))
.or(string("OR").with(value(BinaryOperand::Or)))
}
fn binary_operand<'a>() -> impl Parser<&'a str, Output = BinaryOperand> {
string("AND")
.with(value(BinaryOperand::And))
.or(string("OR").with(value(BinaryOperand::Or)))
}
fn aggregate_binary_expressions(
@@ -208,37 +243,81 @@ fn aggregate_binary_expressions(
}
}
parser! {
pub fn ast[I]()(I) -> UserInputAST
where [I: Stream<Item = char>]
{
let operand_leaf = (binary_operand().skip(spaces()), leaf().skip(spaces()));
let boolean_expr = (leaf().skip(spaces().silent()), many1(operand_leaf)).map(
|(left, right)| aggregate_binary_expressions(left,right));
let whitespace_separated_leaves = many1(leaf().skip(spaces().silent()))
.map(|subqueries: Vec<UserInputAST>|
if subqueries.len() == 1 {
subqueries.into_iter().next().unwrap()
} else {
UserInputAST::Clause(subqueries.into_iter().collect())
});
let expr = attempt(boolean_expr).or(whitespace_separated_leaves);
spaces().with(expr).skip(spaces())
}
fn operand_leaf<'a>() -> impl Parser<&'a str, Output = (BinaryOperand, UserInputAST)> {
(
binary_operand().skip(spaces()),
boosted_leaf().skip(spaces()),
)
}
parser! {
pub fn parse_to_ast[I]()(I) -> UserInputAST
where [I: Stream<Item = char>]
{
spaces().with(optional(ast()).skip(eof())).map(|opt_ast| opt_ast.unwrap_or_else(UserInputAST::empty_query))
}
pub fn ast<'a>() -> impl Parser<&'a str, Output = UserInputAST> {
let boolean_expr = (boosted_leaf().skip(spaces()), many1(operand_leaf()))
.map(|(left, right)| aggregate_binary_expressions(left, right));
let whitespace_separated_leaves = many1(occur_leaf().skip(spaces().silent())).map(
|subqueries: Vec<(Option<Occur>, UserInputAST)>| {
if subqueries.len() == 1 {
let (occur_opt, ast) = subqueries.into_iter().next().unwrap();
match occur_opt.unwrap_or(Occur::Should) {
Occur::Must | Occur::Should => ast,
Occur::MustNot => UserInputAST::Clause(vec![(Some(Occur::MustNot), ast)]),
}
} else {
UserInputAST::Clause(subqueries.into_iter().collect())
}
},
);
let expr = attempt(boolean_expr).or(whitespace_separated_leaves);
spaces().with(expr).skip(spaces())
}
pub fn parse_to_ast<'a>() -> impl Parser<&'a str, Output = UserInputAST> {
spaces()
.with(optional(ast()).skip(eof()))
.map(|opt_ast| opt_ast.unwrap_or_else(UserInputAST::empty_query))
}
#[cfg(test)]
mod test {
use super::*;
use combine::parser::Parser;
pub fn nearly_equals(a: f32, b: f32) -> bool {
(a - b).abs() < 0.0005 * (a + b).abs()
}
fn assert_nearly_equals(expected: f32, val: f32) {
assert!(
nearly_equals(val, expected),
"Got {}, expected {}.",
val,
expected
);
}
#[test]
fn test_occur_symbol() {
assert_eq!(super::occur_symbol().parse("-"), Ok((Occur::MustNot, "")));
assert_eq!(super::occur_symbol().parse("+"), Ok((Occur::Must, "")));
}
#[test]
fn test_positive_float_number() {
fn valid_parse(float_str: &str, expected_val: f32, expected_remaining: &str) {
let (val, remaining) = positive_float_number().parse(float_str).unwrap();
assert_eq!(remaining, expected_remaining);
assert_nearly_equals(val, expected_val);
}
fn error_parse(float_str: &str) {
assert!(positive_float_number().parse(float_str).is_err());
}
valid_parse("1.0", 1.0f32, "");
valid_parse("1", 1.0f32, "");
valid_parse("0.234234 aaa", 0.234234f32, " aaa");
error_parse(".3332");
error_parse("1.");
error_parse("-1.");
}
fn test_parse_query_to_ast_helper(query: &str, expected: &str) {
let query = parse_to_ast().parse(query).unwrap().0;
@@ -269,15 +348,24 @@ mod test {
"Err(UnexpectedParse)"
);
test_parse_query_to_ast_helper("NOTa", "\"NOTa\"");
test_parse_query_to_ast_helper("NOT a", "-(\"a\")");
test_parse_query_to_ast_helper("NOT a", "(-\"a\")");
}
#[test]
fn test_boosting() {
assert!(parse_to_ast().parse("a^2^3").is_err());
assert!(parse_to_ast().parse("a^2^").is_err());
test_parse_query_to_ast_helper("a^3", "(\"a\")^3");
test_parse_query_to_ast_helper("a^3 b^2", "(*(\"a\")^3 *(\"b\")^2)");
test_parse_query_to_ast_helper("a^1", "\"a\"");
}
#[test]
fn test_parse_query_to_ast_binary_op() {
test_parse_query_to_ast_helper("a AND b", "(+(\"a\") +(\"b\"))");
test_parse_query_to_ast_helper("a OR b", "(?(\"a\") ?(\"b\"))");
test_parse_query_to_ast_helper("a OR b AND c", "(?(\"a\") ?((+(\"b\") +(\"c\"))))");
test_parse_query_to_ast_helper("a AND b AND c", "(+(\"a\") +(\"b\") +(\"c\"))");
test_parse_query_to_ast_helper("a AND b", "(+\"a\" +\"b\")");
test_parse_query_to_ast_helper("a OR b", "(?\"a\" ?\"b\")");
test_parse_query_to_ast_helper("a OR b AND c", "(?\"a\" ?(+\"b\" +\"c\"))");
test_parse_query_to_ast_helper("a AND b AND c", "(+\"a\" +\"b\" +\"c\")");
assert_eq!(
format!("{:?}", parse_to_ast().parse("a OR b aaa")),
"Err(UnexpectedParse)"
@@ -315,6 +403,13 @@ mod test {
test_parse_query_to_ast_helper("weight: <= 70.5", "weight:{\"*\" TO \"70.5\"]");
}
#[test]
fn test_occur_leaf() {
let ((occur, ast), _) = super::occur_leaf().parse("+abc").unwrap();
assert_eq!(occur, Some(Occur::Must));
assert_eq!(format!("{:?}", ast), "\"abc\"");
}
#[test]
fn test_range_parser() {
// testing the range() parser separately
@@ -343,32 +438,67 @@ mod test {
fn test_parse_query_to_triming_spaces() {
test_parse_query_to_ast_helper(" abc", "\"abc\"");
test_parse_query_to_ast_helper("abc ", "\"abc\"");
test_parse_query_to_ast_helper("( a OR abc)", "(?(\"a\") ?(\"abc\"))");
test_parse_query_to_ast_helper("(a OR abc)", "(?(\"a\") ?(\"abc\"))");
test_parse_query_to_ast_helper("(a OR abc)", "(?(\"a\") ?(\"abc\"))");
test_parse_query_to_ast_helper("a OR abc ", "(?(\"a\") ?(\"abc\"))");
test_parse_query_to_ast_helper("(a OR abc )", "(?(\"a\") ?(\"abc\"))");
test_parse_query_to_ast_helper("(a OR abc) ", "(?(\"a\") ?(\"abc\"))");
test_parse_query_to_ast_helper("( a OR abc)", "(?\"a\" ?\"abc\")");
test_parse_query_to_ast_helper("(a OR abc)", "(?\"a\" ?\"abc\")");
test_parse_query_to_ast_helper("(a OR abc)", "(?\"a\" ?\"abc\")");
test_parse_query_to_ast_helper("a OR abc ", "(?\"a\" ?\"abc\")");
test_parse_query_to_ast_helper("(a OR abc )", "(?\"a\" ?\"abc\")");
test_parse_query_to_ast_helper("(a OR abc) ", "(?\"a\" ?\"abc\")");
}
#[test]
fn test_parse_query_to_ast() {
fn test_parse_query_single_term() {
test_parse_query_to_ast_helper("abc", "\"abc\"");
test_parse_query_to_ast_helper("a b", "(\"a\" \"b\")");
test_parse_query_to_ast_helper("+(a b)", "+((\"a\" \"b\"))");
test_parse_query_to_ast_helper("+d", "+(\"d\")");
test_parse_query_to_ast_helper("+(a b) +d", "(+((\"a\" \"b\")) +(\"d\"))");
test_parse_query_to_ast_helper("(+a +b) d", "((+(\"a\") +(\"b\")) \"d\")");
test_parse_query_to_ast_helper("(+a)", "+(\"a\")");
test_parse_query_to_ast_helper("(+a +b)", "(+(\"a\") +(\"b\"))");
}
#[test]
fn test_parse_query_default_clause() {
test_parse_query_to_ast_helper("a b", "(*\"a\" *\"b\")");
}
#[test]
fn test_parse_query_must_default_clause() {
test_parse_query_to_ast_helper("+(a b)", "(*\"a\" *\"b\")");
}
#[test]
fn test_parse_query_must_single_term() {
test_parse_query_to_ast_helper("+d", "\"d\"");
}
#[test]
fn test_single_term_with_field() {
test_parse_query_to_ast_helper("abc:toto", "abc:\"toto\"");
}
#[test]
fn test_single_term_with_float() {
test_parse_query_to_ast_helper("abc:1.1", "abc:\"1.1\"");
test_parse_query_to_ast_helper("+abc:toto", "+(abc:\"toto\")");
test_parse_query_to_ast_helper("(+abc:toto -titi)", "(+(abc:\"toto\") -(\"titi\"))");
test_parse_query_to_ast_helper("-abc:toto", "-(abc:\"toto\")");
test_parse_query_to_ast_helper("abc:a b", "(abc:\"a\" \"b\")");
}
#[test]
fn test_must_clause() {
test_parse_query_to_ast_helper("(+a +b)", "(+\"a\" +\"b\")");
}
#[test]
fn test_parse_test_query_plus_a_b_plus_d() {
test_parse_query_to_ast_helper("+(a b) +d", "(+(*\"a\" *\"b\") +\"d\")");
}
#[test]
fn test_parse_test_query_other() {
test_parse_query_to_ast_helper("(+a +b) d", "(*(+\"a\" +\"b\") *\"d\")");
test_parse_query_to_ast_helper("+abc:toto", "abc:\"toto\"");
test_parse_query_to_ast_helper("(+abc:toto -titi)", "(+abc:\"toto\" -\"titi\")");
test_parse_query_to_ast_helper("-abc:toto", "(-abc:\"toto\")");
test_parse_query_to_ast_helper("abc:a b", "(*abc:\"a\" *\"b\")");
test_parse_query_to_ast_helper("abc:\"a b\"", "abc:\"a b\"");
test_parse_query_to_ast_helper("foo:[1 TO 5]", "foo:[\"1\" TO \"5\"]");
}
#[test]
fn test_parse_query_with_range() {
test_parse_query_to_ast_helper("[1 TO 5]", "[\"1\" TO \"5\"]");
test_parse_query_to_ast_helper("foo:{a TO z}", "foo:{\"a\" TO \"z\"}");
test_parse_query_to_ast_helper("foo:[1 TO toto}", "foo:[\"1\" TO \"toto\"}");

View File

@@ -85,14 +85,14 @@ impl UserInputBound {
}
pub enum UserInputAST {
Clause(Vec<UserInputAST>),
Unary(Occur, Box<UserInputAST>),
Clause(Vec<(Option<Occur>, UserInputAST)>),
Leaf(Box<UserInputLeaf>),
Boost(Box<UserInputAST>, f32),
}
impl UserInputAST {
pub fn unary(self, occur: Occur) -> UserInputAST {
UserInputAST::Unary(occur, Box::new(self))
UserInputAST::Clause(vec![(Some(occur), self)])
}
fn compose(occur: Occur, asts: Vec<UserInputAST>) -> UserInputAST {
@@ -103,7 +103,7 @@ impl UserInputAST {
} else {
UserInputAST::Clause(
asts.into_iter()
.map(|ast: UserInputAST| ast.unary(occur))
.map(|ast: UserInputAST| (Some(occur), ast))
.collect::<Vec<_>>(),
)
}
@@ -134,26 +134,38 @@ impl From<UserInputLeaf> for UserInputAST {
}
}
fn print_occur_ast(
occur_opt: Option<Occur>,
ast: &UserInputAST,
formatter: &mut fmt::Formatter,
) -> fmt::Result {
if let Some(occur) = occur_opt {
write!(formatter, "{}{:?}", occur, ast)?;
} else {
write!(formatter, "*{:?}", ast)?;
}
Ok(())
}
impl fmt::Debug for UserInputAST {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
match *self {
UserInputAST::Clause(ref subqueries) => {
if subqueries.is_empty() {
write!(formatter, "<emptyclause>")?;
} else {
write!(formatter, "(")?;
write!(formatter, "{:?}", &subqueries[0])?;
print_occur_ast(subqueries[0].0, &subqueries[0].1, formatter)?;
for subquery in &subqueries[1..] {
write!(formatter, " {:?}", subquery)?;
write!(formatter, " ")?;
print_occur_ast(subquery.0, &subquery.1, formatter)?;
}
write!(formatter, ")")?;
}
Ok(())
}
UserInputAST::Unary(ref occur, ref subquery) => {
write!(formatter, "{}({:?})", occur, subquery)
}
UserInputAST::Leaf(ref subquery) => write!(formatter, "{:?}", subquery),
UserInputAST::Boost(ref leaf, boost) => write!(formatter, "({:?})^{}", leaf, boost),
}
}
}

View File

@@ -11,13 +11,13 @@ impl<TCustomScorer, TScore> CustomScoreTopCollector<TCustomScorer, TScore>
where
TScore: Clone + PartialOrd,
{
pub fn new(
pub(crate) fn new(
custom_scorer: TCustomScorer,
limit: usize,
collector: TopCollector<TScore>,
) -> CustomScoreTopCollector<TCustomScorer, TScore> {
CustomScoreTopCollector {
custom_scorer,
collector: TopCollector::with_limit(limit),
collector,
}
}
}
@@ -28,7 +28,7 @@ where
/// It is the segment local version of the [`CustomScorer`](./trait.CustomScorer.html).
pub trait CustomSegmentScorer<TScore>: 'static {
/// Computes the score of a specific `doc`.
fn score(&self, doc: DocId) -> TScore;
fn score(&mut self, doc: DocId) -> TScore;
}
/// `CustomScorer` makes it possible to define any kind of score.
@@ -117,9 +117,9 @@ where
impl<F, TScore> CustomSegmentScorer<TScore> for F
where
F: 'static + Sync + Send + Fn(DocId) -> TScore,
F: 'static + FnMut(DocId) -> TScore,
{
fn score(&self, doc: DocId) -> TScore {
fn score(&mut self, doc: DocId) -> TScore {
(self)(doc)
}
}

View File

@@ -1,6 +1,5 @@
use crate::collector::Collector;
use crate::collector::SegmentCollector;
use crate::docset::SkipResult;
use crate::fastfield::FacetReader;
use crate::schema::Facet;
use crate::schema::Field;
@@ -188,6 +187,11 @@ pub struct FacetSegmentCollector {
collapse_facet_ords: Vec<u64>,
}
enum SkipResult {
Found,
NotFound,
}
fn skip<'a, I: Iterator<Item = &'a Facet>>(
target: &[u8],
collapse_it: &mut Peekable<I>,
@@ -197,14 +201,14 @@ fn skip<'a, I: Iterator<Item = &'a Facet>>(
Some(facet_bytes) => match facet_bytes.encoded_str().as_bytes().cmp(target) {
Ordering::Less => {}
Ordering::Greater => {
return SkipResult::OverStep;
return SkipResult::NotFound;
}
Ordering::Equal => {
return SkipResult::Reached;
return SkipResult::Found;
}
},
None => {
return SkipResult::End;
return SkipResult::NotFound;
}
}
collapse_it.next();
@@ -281,7 +285,7 @@ impl Collector for FacetCollector {
// is positionned on a term that has not been processed yet.
let skip_result = skip(facet_streamer.key(), &mut collapse_facet_it);
match skip_result {
SkipResult::Reached => {
SkipResult::Found => {
// we reach a facet we decided to collapse.
let collapse_depth = facet_depth(facet_streamer.key());
let mut collapsed_id = 0;
@@ -301,7 +305,7 @@ impl Collector for FacetCollector {
}
break;
}
SkipResult::End | SkipResult::OverStep => {
SkipResult::NotFound => {
collapse_mapping.push(0);
if !facet_streamer.advance() {
break;

View File

@@ -109,6 +109,7 @@ pub use self::tweak_score_top_collector::{ScoreSegmentTweaker, ScoreTweaker};
mod facet_collector;
pub use self::facet_collector::FacetCollector;
use crate::query::Weight;
/// `Fruit` is the type for the result of our collection.
/// e.g. `usize` for the `Count` collector.
@@ -154,6 +155,29 @@ pub trait Collector: Sync {
/// Combines the fruit associated to the collection of each segments
/// into one fruit.
fn merge_fruits(&self, segment_fruits: Vec<Self::Fruit>) -> crate::Result<Self::Fruit>;
/// Created a segment collector and
fn collect_segment(
&self,
weight: &dyn Weight,
segment_ord: u32,
reader: &SegmentReader,
) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> {
let mut segment_collector = self.for_segment(segment_ord as u32, reader)?;
if let Some(delete_bitset) = reader.delete_bitset() {
weight.for_each(reader, &mut |doc, score| {
if delete_bitset.is_alive(doc) {
segment_collector.collect(doc, score);
}
})?;
} else {
weight.for_each(reader, &mut |doc, score| {
segment_collector.collect(doc, score);
})?;
}
Ok(segment_collector.harvest())
}
}
/// The `SegmentCollector` is the trait in charge of defining the

View File

@@ -18,9 +18,9 @@ use std::collections::BinaryHeap;
/// Two elements are equal if their feature is equal, and regardless of whether `doc`
/// is equal. This should be perfectly fine for this usage, but let's make sure this
/// struct is never public.
struct ComparableDoc<T, D> {
feature: T,
doc: D,
pub(crate) struct ComparableDoc<T, D> {
pub feature: T,
pub doc: D,
}
impl<T: PartialOrd, D: PartialOrd> PartialOrd for ComparableDoc<T, D> {
@@ -56,7 +56,8 @@ impl<T: PartialOrd, D: PartialOrd> PartialEq for ComparableDoc<T, D> {
impl<T: PartialOrd, D: PartialOrd> Eq for ComparableDoc<T, D> {}
pub(crate) struct TopCollector<T> {
limit: usize,
pub limit: usize,
pub offset: usize,
_marker: PhantomData<T>,
}
@@ -72,14 +73,20 @@ where
if limit < 1 {
panic!("Limit must be strictly greater than 0.");
}
TopCollector {
Self {
limit,
offset: 0,
_marker: PhantomData,
}
}
pub fn limit(&self) -> usize {
self.limit
/// Skip the first "offset" documents when collecting.
///
/// This is equivalent to `OFFSET` in MySQL or PostgreSQL and `start` in
/// Lucene's TopDocsCollector.
pub fn and_offset(mut self, offset: usize) -> TopCollector<T> {
self.offset = offset;
self
}
pub fn merge_fruits(
@@ -92,7 +99,7 @@ where
let mut top_collector = BinaryHeap::new();
for child_fruit in children {
for (feature, doc) in child_fruit {
if top_collector.len() < self.limit {
if top_collector.len() < (self.limit + self.offset) {
top_collector.push(ComparableDoc { feature, doc });
} else if let Some(mut head) = top_collector.peek_mut() {
if head.feature < feature {
@@ -104,6 +111,7 @@ where
Ok(top_collector
.into_sorted_vec()
.into_iter()
.skip(self.offset)
.map(|cdoc| (cdoc.feature, cdoc.doc))
.collect())
}
@@ -113,7 +121,23 @@ where
segment_id: SegmentLocalId,
_: &SegmentReader,
) -> crate::Result<TopSegmentCollector<F>> {
Ok(TopSegmentCollector::new(segment_id, self.limit))
Ok(TopSegmentCollector::new(
segment_id,
self.limit + self.offset,
))
}
/// Create a new TopCollector with the same limit and offset.
///
/// Ideally we would use Into but the blanket implementation seems to cause the Scorer traits
/// to fail.
#[doc(hidden)]
pub(crate) fn into_tscore<TScore: PartialOrd + Clone>(self) -> TopCollector<TScore> {
TopCollector {
limit: self.limit,
offset: self.offset,
_marker: PhantomData,
}
}
}
@@ -187,7 +211,7 @@ impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
#[cfg(test)]
mod tests {
use super::TopSegmentCollector;
use super::{TopCollector, TopSegmentCollector};
use crate::DocAddress;
#[test]
@@ -248,6 +272,48 @@ mod tests {
top_collector_limit_3.harvest()[..2].to_vec(),
);
}
#[test]
fn test_top_collector_with_limit_and_offset() {
let collector = TopCollector::with_limit(2).and_offset(1);
let results = collector
.merge_fruits(vec![vec![
(0.9, DocAddress(0, 1)),
(0.8, DocAddress(0, 2)),
(0.7, DocAddress(0, 3)),
(0.6, DocAddress(0, 4)),
(0.5, DocAddress(0, 5)),
]])
.unwrap();
assert_eq!(
results,
vec![(0.8, DocAddress(0, 2)), (0.7, DocAddress(0, 3)),]
);
}
#[test]
fn test_top_collector_with_limit_larger_than_set_and_offset() {
let collector = TopCollector::with_limit(2).and_offset(1);
let results = collector
.merge_fruits(vec![vec![(0.9, DocAddress(0, 1)), (0.8, DocAddress(0, 2))]])
.unwrap();
assert_eq!(results, vec![(0.8, DocAddress(0, 2)),]);
}
#[test]
fn test_top_collector_with_limit_and_offset_larger_than_set() {
let collector = TopCollector::with_limit(2).and_offset(20);
let results = collector
.merge_fruits(vec![vec![(0.9, DocAddress(0, 1)), (0.8, DocAddress(0, 2))]])
.unwrap();
assert_eq!(results, vec![]);
}
}
#[cfg(all(test, feature = "unstable"))]

View File

@@ -1,18 +1,20 @@
use super::Collector;
use crate::collector::custom_score_top_collector::CustomScoreTopCollector;
use crate::collector::top_collector::TopCollector;
use crate::collector::top_collector::TopSegmentCollector;
use crate::collector::top_collector::{ComparableDoc, TopCollector};
use crate::collector::tweak_score_top_collector::TweakedScoreTopCollector;
use crate::collector::{
CustomScorer, CustomSegmentScorer, ScoreSegmentTweaker, ScoreTweaker, SegmentCollector,
};
use crate::fastfield::FastFieldReader;
use crate::query::Weight;
use crate::schema::Field;
use crate::DocAddress;
use crate::DocId;
use crate::Score;
use crate::SegmentLocalId;
use crate::SegmentReader;
use std::collections::BinaryHeap;
use std::fmt;
/// The `TopDocs` collector keeps track of the top `K` documents
@@ -57,7 +59,11 @@ pub struct TopDocs(TopCollector<Score>);
impl fmt::Debug for TopDocs {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "TopDocs({})", self.0.limit())
write!(
f,
"TopDocs(limit={}, offset={})",
self.0.limit, self.0.offset
)
}
}
@@ -66,7 +72,7 @@ struct ScorerByFastFieldReader {
}
impl CustomSegmentScorer<u64> for ScorerByFastFieldReader {
fn score(&self, doc: DocId) -> u64 {
fn score(&mut self, doc: DocId) -> u64 {
self.ff_reader.get_u64(u64::from(doc))
}
}
@@ -84,7 +90,8 @@ impl CustomScorer<u64> for ScorerByField {
.u64(self.field)
.ok_or_else(|| {
crate::TantivyError::SchemaError(format!(
"Field requested is not a i64/u64 fast field."
"Field requested ({:?}) is not a i64/u64 fast field.",
self.field
))
})?;
Ok(ScorerByFastFieldReader { ff_reader })
@@ -100,6 +107,45 @@ impl TopDocs {
TopDocs(TopCollector::with_limit(limit))
}
/// Skip the first "offset" documents when collecting.
///
/// This is equivalent to `OFFSET` in MySQL or PostgreSQL and `start` in
/// Lucene's TopDocsCollector.
///
/// ```rust
/// use tantivy::collector::TopDocs;
/// use tantivy::query::QueryParser;
/// use tantivy::schema::{Schema, TEXT};
/// use tantivy::{doc, DocAddress, Index};
///
/// let mut schema_builder = Schema::builder();
/// let title = schema_builder.add_text_field("title", TEXT);
/// let schema = schema_builder.build();
/// let index = Index::create_in_ram(schema);
///
/// let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
/// index_writer.add_document(doc!(title => "The Name of the Wind"));
/// index_writer.add_document(doc!(title => "The Diary of Muadib"));
/// index_writer.add_document(doc!(title => "A Dairy Cow"));
/// index_writer.add_document(doc!(title => "The Diary of a Young Girl"));
/// index_writer.add_document(doc!(title => "The Diary of Lena Mukhina"));
/// assert!(index_writer.commit().is_ok());
///
/// let reader = index.reader().unwrap();
/// let searcher = reader.searcher();
///
/// let query_parser = QueryParser::for_index(&index, vec![title]);
/// let query = query_parser.parse_query("diary").unwrap();
/// let top_docs = searcher.search(&query, &TopDocs::with_limit(2).and_offset(1)).unwrap();
///
/// assert_eq!(top_docs.len(), 2);
/// assert_eq!(&top_docs[0], &(0.5204813, DocAddress(0, 4)));
/// assert_eq!(&top_docs[1], &(0.4793185, DocAddress(0, 3)));
/// ```
pub fn and_offset(self, offset: usize) -> TopDocs {
TopDocs(self.0.and_offset(offset))
}
/// Set top-K to rank documents by a given fast field.
///
/// ```rust
@@ -280,7 +326,7 @@ impl TopDocs {
TScoreSegmentTweaker: ScoreSegmentTweaker<TScore> + 'static,
TScoreTweaker: ScoreTweaker<TScore, Child = TScoreSegmentTweaker>,
{
TweakedScoreTopCollector::new(score_tweaker, self.0.limit())
TweakedScoreTopCollector::new(score_tweaker, self.0.into_tscore())
}
/// Ranks the documents using a custom score.
@@ -394,7 +440,7 @@ impl TopDocs {
TCustomSegmentScorer: CustomSegmentScorer<TScore> + 'static,
TCustomScorer: CustomScorer<TScore, Child = TCustomSegmentScorer>,
{
CustomScoreTopCollector::new(custom_score, self.0.limit())
CustomScoreTopCollector::new(custom_score, self.0.into_tscore())
}
}
@@ -422,6 +468,64 @@ impl Collector for TopDocs {
) -> crate::Result<Self::Fruit> {
self.0.merge_fruits(child_fruits)
}
fn collect_segment(
&self,
weight: &dyn Weight,
segment_ord: u32,
reader: &SegmentReader,
) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> {
let heap_len = self.0.limit + self.0.offset;
let mut heap: BinaryHeap<ComparableDoc<Score, DocId>> = BinaryHeap::with_capacity(heap_len);
if let Some(delete_bitset) = reader.delete_bitset() {
let mut threshold = f32::MIN;
weight.for_each_pruning(threshold, reader, &mut |doc, score| {
if delete_bitset.is_deleted(doc) {
return threshold;
}
let heap_item = ComparableDoc {
feature: score,
doc,
};
if heap.len() < heap_len {
heap.push(heap_item);
if heap.len() == heap_len {
threshold = heap.peek().map(|el| el.feature).unwrap_or(f32::MIN);
}
return threshold;
}
*heap.peek_mut().unwrap() = heap_item;
threshold = heap.peek().map(|el| el.feature).unwrap_or(std::f32::MIN);
threshold
})?;
} else {
weight.for_each_pruning(f32::MIN, reader, &mut |doc, score| {
let heap_item = ComparableDoc {
feature: score,
doc,
};
if heap.len() < heap_len {
heap.push(heap_item);
// TODO the threshold is suboptimal for heap.len == heap_len
if heap.len() == heap_len {
return heap.peek().map(|el| el.feature).unwrap_or(f32::MIN);
} else {
return f32::MIN;
}
}
*heap.peek_mut().unwrap() = heap_item;
heap.peek().map(|el| el.feature).unwrap_or(std::f32::MIN)
})?;
}
let fruit: Vec<(Score, DocAddress)> = heap
.into_sorted_vec()
.into_iter()
.map(|cid| (cid.feature, DocAddress(segment_ord, cid.doc)))
.collect();
Ok(fruit)
}
}
/// Segment Collector associated to `TopDocs`.
@@ -431,7 +535,7 @@ impl SegmentCollector for TopScoreSegmentCollector {
type Fruit = Vec<(Score, DocAddress)>;
fn collect(&mut self, doc: DocId, score: Score) {
self.0.collect(doc, score)
self.0.collect(doc, score);
}
fn harvest(self) -> Vec<(Score, DocAddress)> {
@@ -445,11 +549,10 @@ mod tests {
use crate::collector::Collector;
use crate::query::{AllQuery, Query, QueryParser};
use crate::schema::{Field, Schema, FAST, STORED, TEXT};
use crate::DocAddress;
use crate::Index;
use crate::IndexWriter;
use crate::Score;
use itertools::Itertools;
use crate::{DocAddress, DocId, SegmentReader};
fn make_index() -> Index {
let mut schema_builder = Schema::builder();
@@ -489,6 +592,21 @@ mod tests {
);
}
#[test]
fn test_top_collector_not_at_capacity_with_offset() {
let index = make_index();
let field = index.schema().get_field("text").unwrap();
let query_parser = QueryParser::for_index(&index, vec![field]);
let text_query = query_parser.parse_query("droopy tax").unwrap();
let score_docs: Vec<(Score, DocAddress)> = index
.reader()
.unwrap()
.searcher()
.search(&text_query, &TopDocs::with_limit(4).and_offset(2))
.unwrap();
assert_eq!(score_docs, vec![(0.48527452, DocAddress(0, 0))]);
}
#[test]
fn test_top_collector_at_capacity() {
let index = make_index();
@@ -510,6 +628,27 @@ mod tests {
);
}
#[test]
fn test_top_collector_at_capacity_with_offset() {
let index = make_index();
let field = index.schema().get_field("text").unwrap();
let query_parser = QueryParser::for_index(&index, vec![field]);
let text_query = query_parser.parse_query("droopy tax").unwrap();
let score_docs: Vec<(Score, DocAddress)> = index
.reader()
.unwrap()
.searcher()
.search(&text_query, &TopDocs::with_limit(2).and_offset(1))
.unwrap();
assert_eq!(
score_docs,
vec![
(0.5376842, DocAddress(0u32, 2)),
(0.48527452, DocAddress(0, 0))
]
);
}
#[test]
fn test_top_collector_stable_sorting() {
let index = make_index();
@@ -523,8 +662,8 @@ mod tests {
// precondition for the test to be meaningful: we did get documents
// with the same score
assert!(page_1.iter().map(|result| result.0).all_equal());
assert!(page_2.iter().map(|result| result.0).all_equal());
assert!(page_1.iter().all(|result| result.0 == page_1[0].0));
assert!(page_2.iter().all(|result| result.0 == page_2[0].0));
// sanity check since we're relying on make_index()
assert_eq!(page_1.len(), 2);
@@ -614,12 +753,59 @@ mod tests {
let top_collector = TopDocs::with_limit(4).order_by_u64_field(size);
let err = top_collector.for_segment(0, segment);
if let Err(crate::TantivyError::SchemaError(msg)) = err {
assert_eq!(msg, "Field requested is not a i64/u64 fast field.");
assert_eq!(
msg,
"Field requested (Field(1)) is not a i64/u64 fast field."
);
} else {
assert!(false);
}
}
#[test]
fn test_tweak_score_top_collector_with_offset() {
let index = make_index();
let field = index.schema().get_field("text").unwrap();
let query_parser = QueryParser::for_index(&index, vec![field]);
let text_query = query_parser.parse_query("droopy tax").unwrap();
let collector = TopDocs::with_limit(2).and_offset(1).tweak_score(
move |_segment_reader: &SegmentReader| move |doc: DocId, _original_score: Score| doc,
);
let score_docs: Vec<(u32, DocAddress)> = index
.reader()
.unwrap()
.searcher()
.search(&text_query, &collector)
.unwrap();
assert_eq!(
score_docs,
vec![(1, DocAddress(0, 1)), (0, DocAddress(0, 0)),]
);
}
#[test]
fn test_custom_score_top_collector_with_offset() {
let index = make_index();
let field = index.schema().get_field("text").unwrap();
let query_parser = QueryParser::for_index(&index, vec![field]);
let text_query = query_parser.parse_query("droopy tax").unwrap();
let collector = TopDocs::with_limit(2)
.and_offset(1)
.custom_score(move |_segment_reader: &SegmentReader| move |doc: DocId| doc);
let score_docs: Vec<(u32, DocAddress)> = index
.reader()
.unwrap()
.searcher()
.search(&text_query, &collector)
.unwrap();
assert_eq!(
score_docs,
vec![(1, DocAddress(0, 1)), (0, DocAddress(0, 0)),]
);
}
fn index(
query: &str,
query_field: Field,

View File

@@ -14,11 +14,11 @@ where
{
pub fn new(
score_tweaker: TScoreTweaker,
limit: usize,
collector: TopCollector<TScore>,
) -> TweakedScoreTopCollector<TScoreTweaker, TScore> {
TweakedScoreTopCollector {
score_tweaker,
collector: TopCollector::with_limit(limit),
collector,
}
}
}
@@ -29,7 +29,7 @@ where
/// It is the segment local version of the [`ScoreTweaker`](./trait.ScoreTweaker.html).
pub trait ScoreSegmentTweaker<TScore>: 'static {
/// Tweak the given `score` for the document `doc`.
fn score(&self, doc: DocId, score: Score) -> TScore;
fn score(&mut self, doc: DocId, score: Score) -> TScore;
}
/// `ScoreTweaker` makes it possible to tweak the score
@@ -121,9 +121,9 @@ where
impl<F, TScore> ScoreSegmentTweaker<TScore> for F
where
F: 'static + Sync + Send + Fn(DocId, Score) -> TScore,
F: 'static + FnMut(DocId, Score) -> TScore,
{
fn score(&self, doc: DocId, score: Score) -> TScore {
fn score(&mut self, doc: DocId, score: Score) -> TScore {
(self)(doc, score)
}
}

View File

@@ -33,6 +33,10 @@ impl TinySet {
TinySet(0u64)
}
pub fn clear(&mut self) {
self.0 = 0u64;
}
/// Returns the complement of the set in `[0, 64[`.
fn complement(self) -> TinySet {
TinySet(!self.0)
@@ -43,6 +47,11 @@ impl TinySet {
!self.intersect(TinySet::singleton(el)).is_empty()
}
/// Returns the number of elements in the TinySet.
pub fn len(self) -> u32 {
self.0.count_ones()
}
/// Returns the intersection of `self` and `other`
pub fn intersect(self, other: TinySet) -> TinySet {
TinySet(self.0 & other.0)
@@ -109,22 +118,12 @@ impl TinySet {
pub fn range_greater_or_equal(from_included: u32) -> TinySet {
TinySet::range_lower(from_included).complement()
}
pub fn clear(&mut self) {
self.0 = 0u64;
}
pub fn len(self) -> u32 {
self.0.count_ones()
}
}
#[derive(Clone)]
pub struct BitSet {
tinysets: Box<[TinySet]>,
len: usize, //< Technically it should be u32, but we
// count multiple inserts.
// `usize` guards us from overflow.
len: usize,
max_value: u32,
}
@@ -204,7 +203,7 @@ mod tests {
use super::BitSet;
use super::TinySet;
use crate::docset::DocSet;
use crate::docset::{DocSet, TERMINATED};
use crate::query::BitSetDocSet;
use crate::tests;
use crate::tests::generate_nonunique_unsorted;
@@ -278,11 +277,13 @@ mod tests {
}
assert_eq!(btreeset.len(), bitset.len());
let mut bitset_docset = BitSetDocSet::from(bitset);
let mut remaining = true;
for el in btreeset.into_iter() {
bitset_docset.advance();
assert!(remaining);
assert_eq!(bitset_docset.doc(), el);
remaining = bitset_docset.advance() != TERMINATED;
}
assert!(!bitset_docset.advance());
assert!(!remaining);
}
#[test]

View File

@@ -18,6 +18,19 @@ pub use byteorder::LittleEndian as Endianness;
/// We do not allow segments with more than
pub const MAX_DOC_LIMIT: u32 = 1 << 31;
pub fn minmax<I, T>(mut vals: I) -> Option<(T, T)>
where
I: Iterator<Item = T>,
T: Copy + Ord,
{
if let Some(first_el) = vals.next() {
return Some(vals.fold((first_el, first_el), |(min_val, max_val), el| {
(min_val.min(el), max_val.max(el))
}));
}
None
}
/// Computes the number of bits that will be used for bitpacking.
///
/// In general the target is the minimum number of bits
@@ -134,6 +147,7 @@ pub fn u64_to_f64(val: u64) -> f64 {
#[cfg(test)]
pub(crate) mod test {
pub use super::minmax;
pub use super::serialize::test::fixed_size_test;
use super::{compute_num_bits, f64_to_u64, i64_to_u64, u64_to_f64, u64_to_i64};
use std::f64;
@@ -199,4 +213,21 @@ pub(crate) mod test {
assert!(((super::MAX_DOC_LIMIT - 1) as i32) >= 0);
assert!((super::MAX_DOC_LIMIT as i32) < 0);
}
#[test]
fn test_minmax_empty() {
let vals: Vec<u32> = vec![];
assert_eq!(minmax(vals.into_iter()), None);
}
#[test]
fn test_minmax_one() {
assert_eq!(minmax(vec![1].into_iter()), Some((1, 1)));
}
#[test]
fn test_minmax_two() {
assert_eq!(minmax(vec![1, 2].into_iter()), Some((1, 2)));
assert_eq!(minmax(vec![2, 1].into_iter()), Some((1, 2)));
}
}

View File

@@ -89,6 +89,19 @@ impl FixedSize for u64 {
const SIZE_IN_BYTES: usize = 8;
}
impl BinarySerializable for f32 {
fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
writer.write_f32::<Endianness>(*self)
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
reader.read_f32::<Endianness>()
}
}
impl FixedSize for f32 {
const SIZE_IN_BYTES: usize = 4;
}
impl BinarySerializable for i64 {
fn serialize<W: Write>(&self, writer: &mut W) -> io::Result<()> {
writer.write_i64::<Endianness>(*self)

View File

@@ -5,7 +5,7 @@ use std::io::Read;
use std::io::Write;
/// Wrapper over a `u64` that serializes as a variable int.
#[derive(Debug, Eq, PartialEq)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct VInt(pub u64);
const STOP_BIT: u8 = 128;

View File

@@ -1,4 +1,3 @@
use super::segment::create_segment;
use super::segment::Segment;
use crate::core::Executor;
use crate::core::IndexMeta;
@@ -22,12 +21,13 @@ use crate::schema::FieldType;
use crate::schema::Schema;
use crate::tokenizer::{TextAnalyzer, TokenizerManager};
use crate::IndexWriter;
use num_cpus;
use std::borrow::BorrowMut;
use std::collections::HashSet;
use std::fmt;
#[cfg(feature = "mmap")]
use std::path::{Path, PathBuf};
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
fn load_metas(
@@ -283,7 +283,7 @@ impl Index {
TantivyError::LockFailure(
err,
Some(
"Failed to acquire index lock. If you are using\
"Failed to acquire index lock. If you are using \
a regular directory, this means there is already an \
`IndexWriter` working on this `Directory`, in this process \
or in a different process."
@@ -337,7 +337,7 @@ impl Index {
#[doc(hidden)]
pub fn segment(&self, segment_meta: SegmentMeta) -> Segment {
create_segment(self.clone(), segment_meta)
Segment::for_index(self.clone(), segment_meta)
}
/// Creates a new segment.

View File

@@ -3,8 +3,7 @@ use crate::core::SegmentId;
use crate::schema::Schema;
use crate::Opstamp;
use census::{Inventory, TrackedObject};
use serde;
use serde_json;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fmt;
use std::path::PathBuf;
@@ -214,7 +213,7 @@ pub struct IndexMeta {
#[serde(skip_serializing_if = "Option::is_none")]
/// Payload associated to the last commit.
///
/// Upon commit, clients can optionally add a small `Striing` payload to their commit
/// Upon commit, clients can optionally add a small `String` payload to their commit
/// to help identify this commit.
/// This payload is entirely unused by tantivy.
pub payload: Option<String>,

View File

@@ -7,7 +7,6 @@ use crate::schema::FieldType;
use crate::schema::IndexRecordOption;
use crate::schema::Term;
use crate::termdict::TermDictionary;
use owned_read::OwnedRead;
/// The inverted index reader is in charge of accessing
/// the inverted index associated to a specific field.
@@ -97,8 +96,7 @@ impl InvertedIndexReader {
let offset = term_info.postings_offset as usize;
let end_source = self.postings_source.len();
let postings_slice = self.postings_source.slice(offset, end_source);
let postings_reader = OwnedRead::new(postings_slice);
block_postings.reset(term_info.doc_freq, postings_reader);
block_postings.reset(term_info.doc_freq, postings_slice);
}
/// Returns a block postings given a `Term`.
@@ -127,7 +125,7 @@ impl InvertedIndexReader {
let postings_data = self.postings_source.slice_from(offset);
BlockSegmentPostings::from_data(
term_info.doc_freq,
OwnedRead::new(postings_data),
postings_data,
self.record_option,
requested_option,
)

View File

@@ -1,11 +1,8 @@
use crate::collector::Collector;
use crate::collector::SegmentCollector;
use crate::core::Executor;
use crate::core::InvertedIndexReader;
use crate::core::SegmentReader;
use crate::query::Query;
use crate::query::Scorer;
use crate::query::Weight;
use crate::schema::Document;
use crate::schema::Schema;
use crate::schema::{Field, Term};
@@ -17,26 +14,6 @@ use crate::Index;
use std::fmt;
use std::sync::Arc;
fn collect_segment<C: Collector>(
collector: &C,
weight: &dyn Weight,
segment_ord: u32,
segment_reader: &SegmentReader,
) -> crate::Result<C::Fruit> {
let mut scorer = weight.scorer(segment_reader)?;
let mut segment_collector = collector.for_segment(segment_ord as u32, segment_reader)?;
if let Some(delete_bitset) = segment_reader.delete_bitset() {
scorer.for_each(&mut |doc, score| {
if delete_bitset.is_alive(doc) {
segment_collector.collect(doc, score);
}
});
} else {
scorer.for_each(&mut |doc, score| segment_collector.collect(doc, score));
}
Ok(segment_collector.harvest())
}
/// Holds a list of `SegmentReader`s ready for search.
///
/// It guarantees that the `Segment` will not be removed before
@@ -163,12 +140,7 @@ impl Searcher {
let segment_readers = self.segment_readers();
let fruits = executor.map(
|(segment_ord, segment_reader)| {
collect_segment(
collector,
weight.as_ref(),
segment_ord as u32,
segment_reader,
)
collector.collect_segment(weight.as_ref(), segment_ord as u32, segment_reader)
},
segment_readers.iter().enumerate(),
)?;

View File

@@ -24,15 +24,12 @@ impl fmt::Debug for Segment {
}
}
/// Creates a new segment given an `Index` and a `SegmentId`
///
/// The function is here to make it private outside `tantivy`.
/// #[doc(hidden)]
pub fn create_segment(index: Index, meta: SegmentMeta) -> Segment {
Segment { index, meta }
}
impl Segment {
/// Creates a new segment given an `Index` and a `SegmentId`
pub(crate) fn for_index(index: Index, meta: SegmentMeta) -> Segment {
Segment { index, meta }
}
/// Returns the index the segment belongs to.
pub fn index(&self) -> &Index {
&self.index

View File

@@ -4,6 +4,7 @@ use uuid::Uuid;
#[cfg(test)]
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::str::FromStr;
#[cfg(test)]

View File

@@ -8,7 +8,7 @@ use crate::directory::ReadOnlySource;
use crate::fastfield::DeleteBitSet;
use crate::fastfield::FacetReader;
use crate::fastfield::FastFieldReaders;
use crate::fieldnorm::FieldNormReader;
use crate::fieldnorm::{FieldNormReader, FieldNormReaders};
use crate::schema::Field;
use crate::schema::FieldType;
use crate::schema::Schema;
@@ -48,7 +48,7 @@ pub struct SegmentReader {
positions_composite: CompositeFile,
positions_idx_composite: CompositeFile,
fast_fields_readers: Arc<FastFieldReaders>,
fieldnorms_composite: CompositeFile,
fieldnorm_readers: FieldNormReaders,
store_source: ReadOnlySource,
delete_bitset_opt: Option<DeleteBitSet>,
@@ -126,8 +126,8 @@ impl SegmentReader {
/// They are simply stored as a fast field, serialized in
/// the `.fieldnorm` file of the segment.
pub fn get_fieldnorms_reader(&self, field: Field) -> FieldNormReader {
if let Some(fieldnorm_source) = self.fieldnorms_composite.open_read(field) {
FieldNormReader::open(fieldnorm_source)
if let Some(fieldnorm_source) = self.fieldnorm_readers.get_field(field) {
fieldnorm_source
} else {
let field_name = self.schema.get_field_name(field);
let err_msg = format!(
@@ -178,8 +178,8 @@ impl SegmentReader {
let fast_field_readers =
Arc::new(FastFieldReaders::load_all(&schema, &fast_fields_composite)?);
let fieldnorms_data = segment.open_read(SegmentComponent::FIELDNORMS)?;
let fieldnorms_composite = CompositeFile::open(&fieldnorms_data)?;
let fieldnorm_data = segment.open_read(SegmentComponent::FIELDNORMS)?;
let fieldnorm_readers = FieldNormReaders::new(fieldnorm_data)?;
let delete_bitset_opt = if segment.meta().has_deletes() {
let delete_data = segment.open_read(SegmentComponent::DELETE)?;
@@ -195,7 +195,7 @@ impl SegmentReader {
termdict_composite,
postings_composite,
fast_fields_readers: fast_field_readers,
fieldnorms_composite,
fieldnorm_readers,
segment_id: segment.id(),
store_source,
delete_bitset_opt,
@@ -295,8 +295,8 @@ impl SegmentReader {
}
/// Returns an iterator that will iterate over the alive document ids
pub fn doc_ids_alive(&self) -> SegmentReaderAliveDocsIterator<'_> {
SegmentReaderAliveDocsIterator::new(&self)
pub fn doc_ids_alive<'a>(&'a self) -> impl Iterator<Item = DocId> + 'a {
(0u32..self.max_doc).filter(move |doc| !self.is_deleted(*doc))
}
/// Summarize total space usage of this segment.
@@ -308,7 +308,7 @@ impl SegmentReader {
self.positions_composite.space_usage(),
self.positions_idx_composite.space_usage(),
self.fast_fields_readers.space_usage(),
self.fieldnorms_composite.space_usage(),
self.fieldnorm_readers.space_usage(),
self.get_store_reader().space_usage(),
self.delete_bitset_opt
.as_ref()
@@ -324,52 +324,6 @@ impl fmt::Debug for SegmentReader {
}
}
/// Implements the iterator trait to allow easy iteration
/// over non-deleted ("alive") DocIds in a SegmentReader
pub struct SegmentReaderAliveDocsIterator<'a> {
reader: &'a SegmentReader,
max_doc: DocId,
current: DocId,
}
impl<'a> SegmentReaderAliveDocsIterator<'a> {
pub fn new(reader: &'a SegmentReader) -> SegmentReaderAliveDocsIterator<'a> {
SegmentReaderAliveDocsIterator {
reader,
max_doc: reader.max_doc(),
current: 0,
}
}
}
impl<'a> Iterator for SegmentReaderAliveDocsIterator<'a> {
type Item = DocId;
fn next(&mut self) -> Option<Self::Item> {
// TODO: Use TinySet (like in BitSetDocSet) to speed this process up
if self.current >= self.max_doc {
return None;
}
// find the next alive doc id
while self.reader.is_deleted(self.current) {
self.current += 1;
if self.current >= self.max_doc {
return None;
}
}
// capture the current alive DocId
let result = Some(self.current);
// move down the chain
self.current += 1;
result
}
}
#[cfg(test)]
mod test {
use crate::core::Index;

View File

@@ -8,6 +8,8 @@ use crc32fast::Hasher;
use std::io;
use std::io::Write;
const FOOTER_MAX_LEN: usize = 10_000;
type CrcHashU32 = u32;
#[derive(Debug, Clone, PartialEq)]
@@ -92,12 +94,24 @@ impl Footer {
match &self.versioned_footer {
VersionedFooter::V1 {
crc32: _crc,
store_compression: compression,
store_compression,
} => {
if &library_version.store_compression != compression {
if &library_version.store_compression != store_compression {
return Err(Incompatibility::CompressionMismatch {
library_compression_format: library_version.store_compression.to_string(),
index_compression_format: compression.to_string(),
index_compression_format: store_compression.to_string(),
});
}
Ok(())
}
VersionedFooter::V2 {
crc32: _crc,
store_compression,
} => {
if &library_version.store_compression != store_compression {
return Err(Incompatibility::CompressionMismatch {
library_compression_format: library_version.store_compression.to_string(),
index_compression_format: store_compression.to_string(),
});
}
Ok(())
@@ -118,24 +132,29 @@ pub enum VersionedFooter {
crc32: CrcHashU32,
store_compression: String,
},
// Introduction of the Block WAND information.
V2 {
crc32: CrcHashU32,
store_compression: String,
},
}
impl BinarySerializable for VersionedFooter {
fn serialize<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
let mut buf = Vec::new();
match self {
VersionedFooter::V1 {
VersionedFooter::V2 {
crc32,
store_compression: compression,
} => {
// Serializes a valid `VersionedFooter` or panics if the version is unknown
// [ version | crc_hash | compression_mode ]
// [ 0..4 | 4..8 | variable ]
BinarySerializable::serialize(&1u32, &mut buf)?;
BinarySerializable::serialize(&2u32, &mut buf)?;
BinarySerializable::serialize(crc32, &mut buf)?;
BinarySerializable::serialize(compression, &mut buf)?;
}
VersionedFooter::UnknownVersion => {
VersionedFooter::V1 { .. } | VersionedFooter::UnknownVersion => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Cannot serialize an unknown versioned footer ",
@@ -143,22 +162,40 @@ impl BinarySerializable for VersionedFooter {
}
}
BinarySerializable::serialize(&VInt(buf.len() as u64), writer)?;
assert!(buf.len() <= FOOTER_MAX_LEN);
writer.write_all(&buf[..])?;
Ok(())
}
fn deserialize<R: io::Read>(reader: &mut R) -> io::Result<Self> {
let len = VInt::deserialize(reader)?.0 as usize;
if len > FOOTER_MAX_LEN {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"Footer seems invalid as it suggests a footer len of {}. File is corrupted, \
or the index was created with a different & old version of tantivy.",
len
),
));
}
let mut buf = vec![0u8; len];
reader.read_exact(&mut buf[..])?;
let mut cursor = &buf[..];
let version = u32::deserialize(&mut cursor)?;
if version == 1 {
let crc32 = u32::deserialize(&mut cursor)?;
let compression = String::deserialize(&mut cursor)?;
let store_compression = String::deserialize(&mut cursor)?;
Ok(VersionedFooter::V1 {
crc32,
store_compression: compression,
store_compression,
})
} else if version == 2 {
let crc32 = u32::deserialize(&mut cursor)?;
let store_compression = String::deserialize(&mut cursor)?;
Ok(VersionedFooter::V2 {
crc32,
store_compression,
})
} else {
Ok(VersionedFooter::UnknownVersion)
@@ -169,6 +206,7 @@ impl BinarySerializable for VersionedFooter {
impl VersionedFooter {
pub fn crc(&self) -> Option<CrcHashU32> {
match self {
VersionedFooter::V2 { crc32, .. } => Some(*crc32),
VersionedFooter::V1 { crc32, .. } => Some(*crc32),
VersionedFooter::UnknownVersion { .. } => None,
}
@@ -206,7 +244,7 @@ impl<W: TerminatingWrite> Write for FooterProxy<W> {
impl<W: TerminatingWrite> TerminatingWrite for FooterProxy<W> {
fn terminate_ref(&mut self, _: AntiCallToken) -> io::Result<()> {
let crc32 = self.hasher.take().unwrap().finalize();
let footer = Footer::new(VersionedFooter::V1 {
let footer = Footer::new(VersionedFooter::V2 {
crc32,
store_compression: crate::store::COMPRESSION.to_string(),
});
@@ -221,11 +259,12 @@ mod tests {
use super::CrcHashU32;
use super::FooterProxy;
use crate::common::BinarySerializable;
use crate::common::{BinarySerializable, VInt};
use crate::directory::footer::{Footer, VersionedFooter};
use crate::directory::TerminatingWrite;
use byteorder::{ByteOrder, LittleEndian};
use regex::Regex;
use std::io;
#[test]
fn test_versioned_footer() {
@@ -234,15 +273,11 @@ mod tests {
assert!(footer_proxy.terminate().is_ok());
assert_eq!(vec.len(), 167);
let footer = Footer::deserialize(&mut &vec[..]).unwrap();
if let VersionedFooter::V1 {
crc32: _,
store_compression,
} = footer.versioned_footer
{
assert_eq!(store_compression, crate::store::COMPRESSION);
} else {
panic!("Versioned footer should be V1.");
}
assert!(matches!(
footer.versioned_footer,
VersionedFooter::V2 { store_compression, .. }
if store_compression == crate::store::COMPRESSION
));
assert_eq!(&footer.version, crate::version());
}
@@ -250,7 +285,7 @@ mod tests {
fn test_serialize_deserialize_footer() {
let mut buffer = Vec::new();
let crc32 = 123456u32;
let footer: Footer = Footer::new(VersionedFooter::V1 {
let footer: Footer = Footer::new(VersionedFooter::V2 {
crc32,
store_compression: "lz4".to_string(),
});
@@ -262,7 +297,7 @@ mod tests {
#[test]
fn footer_length() {
let crc32 = 1111111u32;
let versioned_footer = VersionedFooter::V1 {
let versioned_footer = VersionedFooter::V2 {
crc32,
store_compression: "lz4".to_string(),
};
@@ -283,7 +318,7 @@ mod tests {
// versionned footer length
12 | 128,
// index format version
1,
2,
0,
0,
0,
@@ -302,7 +337,7 @@ mod tests {
let versioned_footer = VersionedFooter::deserialize(&mut cursor).unwrap();
assert!(cursor.is_empty());
let expected_crc: u32 = LittleEndian::read_u32(&v_footer_bytes[5..9]) as CrcHashU32;
let expected_versioned_footer: VersionedFooter = VersionedFooter::V1 {
let expected_versioned_footer: VersionedFooter = VersionedFooter::V2 {
crc32: expected_crc,
store_compression: "lz4".to_string(),
};
@@ -336,4 +371,20 @@ mod tests {
let res = footer.is_compatible();
assert!(res.is_err());
}
#[test]
fn test_deserialize_too_large_footer() {
let mut buf = vec![];
assert!(FooterProxy::new(&mut buf).terminate().is_ok());
let mut long_len_buf = [0u8; 10];
let num_bytes = VInt(super::FOOTER_MAX_LEN as u64 + 1u64).serialize_into(&mut long_len_buf);
buf[0..num_bytes].copy_from_slice(&long_len_buf[..num_bytes]);
let err = Footer::deserialize(&mut &buf[..]).unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
assert_eq!(
err.to_string(),
"Footer seems invalid as it suggests a footer len of 10001. File is corrupted, \
or the index was created with a different & old version of tantivy."
);
}
}

View File

@@ -11,7 +11,6 @@ use crate::error::DataCorruption;
use crate::Directory;
use crc32fast::Hasher;
use serde_json;
use std::collections::HashSet;
use std::io;
use std::io::Write;

View File

@@ -1,10 +1,3 @@
use fs2;
use notify;
use self::fs2::FileExt;
use self::notify::RawEvent;
use self::notify::RecursiveMode;
use self::notify::Watcher;
use crate::core::META_FILEPATH;
use crate::directory::error::LockError;
use crate::directory::error::{
@@ -20,8 +13,12 @@ use crate::directory::WatchCallback;
use crate::directory::WatchCallbackList;
use crate::directory::WatchHandle;
use crate::directory::{TerminatingWrite, WritePtr};
use atomicwrites;
use fs2::FileExt;
use memmap::Mmap;
use notify::RawEvent;
use notify::RecursiveMode;
use notify::Watcher;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::convert::From;
use std::fmt;
@@ -142,7 +139,7 @@ impl MmapCache {
}
struct WatcherWrapper {
_watcher: Mutex<notify::PollWatcher>,
_watcher: Mutex<notify::RecommendedWatcher>,
watcher_router: Arc<WatchCallbackList>,
}
@@ -150,7 +147,7 @@ impl WatcherWrapper {
pub fn new(path: &Path) -> Result<Self, OpenDirectoryError> {
let (tx, watcher_recv): (Sender<RawEvent>, Receiver<RawEvent>) = channel();
// We need to initialize the
let watcher = notify::poll::PollWatcher::with_delay_ms(tx, 1)
let watcher = notify::raw_watcher(tx)
.and_then(|mut watcher| {
watcher.watch(path, RecursiveMode::Recursive)?;
Ok(watcher)
@@ -223,17 +220,13 @@ struct MmapDirectoryInner {
}
impl MmapDirectoryInner {
fn new(
root_path: PathBuf,
temp_directory: Option<TempDir>,
) -> Result<MmapDirectoryInner, OpenDirectoryError> {
let mmap_directory_inner = MmapDirectoryInner {
fn new(root_path: PathBuf, temp_directory: Option<TempDir>) -> MmapDirectoryInner {
MmapDirectoryInner {
root_path,
mmap_cache: Default::default(),
_temp_directory: temp_directory,
watcher: RwLock::new(None),
};
Ok(mmap_directory_inner)
}
}
fn watch(&self, watch_callback: WatchCallback) -> crate::Result<WatchHandle> {
@@ -267,14 +260,11 @@ impl fmt::Debug for MmapDirectory {
}
impl MmapDirectory {
fn new(
root_path: PathBuf,
temp_directory: Option<TempDir>,
) -> Result<MmapDirectory, OpenDirectoryError> {
let inner = MmapDirectoryInner::new(root_path, temp_directory)?;
Ok(MmapDirectory {
fn new(root_path: PathBuf, temp_directory: Option<TempDir>) -> MmapDirectory {
let inner = MmapDirectoryInner::new(root_path, temp_directory);
MmapDirectory {
inner: Arc::new(inner),
})
}
}
/// Creates a new MmapDirectory in a temporary directory.
@@ -284,7 +274,7 @@ impl MmapDirectory {
pub fn create_from_tempdir() -> Result<MmapDirectory, OpenDirectoryError> {
let tempdir = TempDir::new().map_err(OpenDirectoryError::IoError)?;
let tempdir_path = PathBuf::from(tempdir.path());
MmapDirectory::new(tempdir_path, Some(tempdir))
Ok(MmapDirectory::new(tempdir_path, Some(tempdir)))
}
/// Opens a MmapDirectory in a directory.
@@ -302,7 +292,7 @@ impl MmapDirectory {
directory_path,
)))
} else {
Ok(MmapDirectory::new(PathBuf::from(directory_path), None)?)
Ok(MmapDirectory::new(PathBuf::from(directory_path), None))
}
}

View File

@@ -144,6 +144,22 @@ impl RAMDirectory {
pub fn total_mem_usage(&self) -> usize {
self.fs.read().unwrap().total_mem_usage()
}
/// Write a copy of all of the files saved in the RAMDirectory in the target `Directory`.
///
/// Files are all written using the `Directory::write` meaning, even if they were
/// written using the `atomic_write` api.
///
/// If an error is encounterred, files may be persisted partially.
pub fn persist(&self, dest: &mut dyn Directory) -> crate::Result<()> {
let wlock = self.fs.write().unwrap();
for (path, source) in wlock.fs.iter() {
let mut dest_wrt = dest.open_write(path)?;
dest_wrt.write_all(source.as_slice())?;
dest_wrt.terminate()?;
}
Ok(())
}
}
impl Directory for RAMDirectory {
@@ -204,3 +220,28 @@ impl Directory for RAMDirectory {
Ok(self.fs.write().unwrap().watch(watch_callback))
}
}
#[cfg(test)]
mod tests {
use super::RAMDirectory;
use crate::Directory;
use std::io::Write;
use std::path::Path;
#[test]
fn test_persist() {
let msg_atomic: &'static [u8] = b"atomic is the way";
let msg_seq: &'static [u8] = b"sequential is the way";
let path_atomic: &'static Path = Path::new("atomic");
let path_seq: &'static Path = Path::new("seq");
let mut directory = RAMDirectory::create();
assert!(directory.atomic_write(path_atomic, msg_atomic).is_ok());
let mut wrt = directory.open_write(path_seq).unwrap();
assert!(wrt.write_all(msg_seq).is_ok());
assert!(wrt.flush().is_ok());
let mut directory_copy = RAMDirectory::create();
assert!(directory.persist(&mut directory_copy).is_ok());
assert_eq!(directory_copy.atomic_read(path_atomic).unwrap(), msg_atomic);
assert_eq!(directory_copy.atomic_read(path_seq).unwrap(), msg_seq);
}
}

View File

@@ -1,58 +1,48 @@
use crate::common::BitSet;
use crate::fastfield::DeleteBitSet;
use crate::DocId;
use std::borrow::Borrow;
use std::borrow::BorrowMut;
use std::cmp::Ordering;
/// Expresses the outcome of a call to `DocSet`'s `.skip_next(...)`.
#[derive(PartialEq, Eq, Debug)]
pub enum SkipResult {
/// target was in the docset
Reached,
/// target was not in the docset, skipping stopped as a greater element was found
OverStep,
/// the docset was entirely consumed without finding the target, nor any
/// element greater than the target.
End,
}
/// Sentinel value returned when a DocSet has been entirely consumed.
///
/// This is not u32::MAX as one would have expected, due to the lack of SSE2 instructions
/// to compare [u32; 4].
pub const TERMINATED: DocId = std::i32::MAX as u32;
/// Represents an iterable set of sorted doc ids.
pub trait DocSet {
/// Goes to the next element.
/// `.advance(...)` needs to be called a first time to point to the correct
/// element.
fn advance(&mut self) -> bool;
///
/// The DocId of the next element is returned.
/// In other words we should always have :
/// ```ignore
/// let doc = docset.advance();
/// assert_eq!(doc, docset.doc());
/// ```
///
/// If we reached the end of the DocSet, TERMINATED should be returned.
///
/// Calling `.advance()` on a terminated DocSet should be supported, and TERMINATED should
/// be returned.
/// TODO Test existing docsets.
fn advance(&mut self) -> DocId;
/// After skipping, position the iterator in such a way that `.doc()`
/// will return a value greater than or equal to target.
/// Advances the DocSet forward until reaching the target, or going to the
/// lowest DocId greater than the target.
///
/// SkipResult expresses whether the `target value` was reached, overstepped,
/// or if the `DocSet` was entirely consumed without finding any value
/// greater or equal to the `target`.
/// If the end of the DocSet is reached, TERMINATED is returned.
///
/// WARNING: Calling skip always advances the docset.
/// More specifically, if the docset is already positionned on the target
/// skipping will advance to the next position and return SkipResult::Overstep.
/// Calling `.seek(target)` on a terminated DocSet is legal. Implementation
/// of DocSet should support it.
///
/// If `.skip_next()` oversteps, then the docset must be positionned correctly
/// on an existing document. In other words, `.doc()` should return the first document
/// greater than `DocId`.
fn skip_next(&mut self, target: DocId) -> SkipResult {
if !self.advance() {
return SkipResult::End;
}
loop {
match self.doc().cmp(&target) {
Ordering::Less => {
if !self.advance() {
return SkipResult::End;
}
}
Ordering::Equal => return SkipResult::Reached,
Ordering::Greater => return SkipResult::OverStep,
}
/// Calling `seek(TERMINATED)` is also legal and is the normal way to consume a DocSet.
fn seek(&mut self, target: DocId) -> DocId {
let mut doc = self.doc();
debug_assert!(doc <= target);
while doc < target {
doc = self.advance();
}
doc
}
/// Fills a given mutable buffer with the next doc ids from the
@@ -71,38 +61,38 @@ pub trait DocSet {
/// use case where batching. The normal way to
/// go through the `DocId`'s is to call `.advance()`.
fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize {
if self.doc() == TERMINATED {
return 0;
}
for (i, buffer_val) in buffer.iter_mut().enumerate() {
if self.advance() {
*buffer_val = self.doc();
} else {
return i;
*buffer_val = self.doc();
if self.advance() == TERMINATED {
return i + 1;
}
}
buffer.len()
}
/// Returns the current document
/// Right after creating a new DocSet, the docset points to the first document.
///
/// If the DocSet is empty, .doc() should return `TERMINATED`.
fn doc(&self) -> DocId;
/// Returns a best-effort hint of the
/// length of the docset.
fn size_hint(&self) -> u32;
/// Appends all docs to a `bitset`.
fn append_to_bitset(&mut self, bitset: &mut BitSet) {
while self.advance() {
bitset.insert(self.doc());
}
}
/// Returns the number documents matching.
/// Calling this method consumes the `DocSet`.
fn count(&mut self, delete_bitset: &DeleteBitSet) -> u32 {
let mut count = 0u32;
while self.advance() {
if !delete_bitset.is_deleted(self.doc()) {
let mut doc = self.doc();
while doc != TERMINATED {
if !delete_bitset.is_deleted(doc) {
count += 1u32;
}
doc = self.advance();
}
count
}
@@ -114,22 +104,42 @@ pub trait DocSet {
/// given by `count()`.
fn count_including_deleted(&mut self) -> u32 {
let mut count = 0u32;
while self.advance() {
let mut doc = self.doc();
while doc != TERMINATED {
count += 1u32;
doc = self.advance();
}
count
}
}
impl<'a> DocSet for &'a mut dyn DocSet {
fn advance(&mut self) -> u32 {
(**self).advance()
}
fn seek(&mut self, target: DocId) -> DocId {
(**self).seek(target)
}
fn doc(&self) -> u32 {
(**self).doc()
}
fn size_hint(&self) -> u32 {
(**self).size_hint()
}
}
impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
fn advance(&mut self) -> bool {
fn advance(&mut self) -> DocId {
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.advance()
}
fn skip_next(&mut self, target: DocId) -> SkipResult {
fn seek(&mut self, target: DocId) -> DocId {
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.skip_next(target)
unboxed.seek(target)
}
fn doc(&self) -> DocId {
@@ -151,9 +161,4 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.count_including_deleted()
}
fn append_to_bitset(&mut self, bitset: &mut BitSet) {
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.append_to_bitset(bitset);
}
}

View File

@@ -7,7 +7,6 @@ use crate::directory::error::{Incompatibility, LockError};
use crate::fastfield::FastFieldNotAvailableError;
use crate::query;
use crate::schema;
use serde_json;
use std::fmt;
use std::path::PathBuf;
use std::sync::PoisonError;

View File

@@ -6,7 +6,6 @@ use crate::schema::{Document, Field};
use crate::termdict::TermOrdinal;
use crate::DocId;
use fnv::FnvHashMap;
use itertools::Itertools;
use std::io;
/// Writer for multi-valued (as in, more than one value per document)
@@ -151,8 +150,8 @@ impl MultiValueIntFastFieldWriter {
}
}
None => {
let val_min_max = self.vals.iter().cloned().minmax();
let (val_min, val_max) = val_min_max.into_option().unwrap_or((0u64, 0u64));
let val_min_max = crate::common::minmax(self.vals.iter().cloned());
let (val_min, val_max) = val_min_max.unwrap_or((0u64, 0u64));
value_serializer =
serializer.new_u64_fast_field_with_idx(self.field, val_min, val_max, 1)?;
for &val in &self.vals {

View File

@@ -21,7 +21,7 @@ mod reader;
mod serializer;
mod writer;
pub use self::reader::FieldNormReader;
pub use self::reader::{FieldNormReader, FieldNormReaders};
pub use self::serializer::FieldNormsSerializer;
pub use self::writer::FieldNormsWriter;

View File

@@ -1,6 +1,41 @@
use super::{fieldnorm_to_id, id_to_fieldnorm};
use crate::common::CompositeFile;
use crate::directory::ReadOnlySource;
use crate::schema::Field;
use crate::space_usage::PerFieldSpaceUsage;
use crate::DocId;
use std::sync::Arc;
/// Reader for the fieldnorm (for each document, the number of tokens indexed in the
/// field) of all indexed fields in the index.
///
/// Each fieldnorm is approximately compressed over one byte. We refer to this byte as
/// `fieldnorm_id`.
/// The mapping from `fieldnorm` to `fieldnorm_id` is given by monotonic.
#[derive(Clone)]
pub struct FieldNormReaders {
data: Arc<CompositeFile>,
}
impl FieldNormReaders {
/// Creates a field norm reader.
pub fn new(source: ReadOnlySource) -> crate::Result<FieldNormReaders> {
let data = CompositeFile::open(&source)?;
Ok(FieldNormReaders {
data: Arc::new(data),
})
}
/// Returns the FieldNormReader for a specific field.
pub fn get_field(&self, field: Field) -> Option<FieldNormReader> {
self.data.open_read(field).map(FieldNormReader::open)
}
/// Return a break down of the space usage per field.
pub fn space_usage(&self) -> PerFieldSpaceUsage {
self.data.space_usage()
}
}
/// Reads the fieldnorm associated to a document.
/// The fieldnorm represents the length associated to
@@ -19,6 +54,7 @@ use crate::DocId;
/// Apart from compression, this scale also makes it possible to
/// precompute computationally expensive functions of the fieldnorm
/// in a very short array.
#[derive(Clone)]
pub struct FieldNormReader {
data: ReadOnlySource,
}
@@ -29,6 +65,11 @@ impl FieldNormReader {
FieldNormReader { data }
}
/// Returns the number of documents in this segment.
pub fn num_docs(&self) -> u32 {
self.data.len() as u32
}
/// Returns the `fieldnorm` associated to a doc id.
/// The fieldnorm is a value approximating the number
/// of tokens in a given field of the `doc_id`.
@@ -65,10 +106,11 @@ impl FieldNormReader {
}
#[cfg(test)]
impl From<Vec<u32>> for FieldNormReader {
fn from(field_norms: Vec<u32>) -> FieldNormReader {
impl From<&[u32]> for FieldNormReader {
fn from(field_norms: &[u32]) -> FieldNormReader {
let field_norms_id = field_norms
.into_iter()
.iter()
.cloned()
.map(FieldNormReader::fieldnorm_to_id)
.collect::<Vec<u8>>();
let field_norms_data = ReadOnlySource::from(field_norms_id);

View File

@@ -78,11 +78,12 @@ impl FieldNormsWriter {
}
/// Serialize the seen fieldnorm values to the serializer for all fields.
pub fn serialize(&self, fieldnorms_serializer: &mut FieldNormsSerializer) -> io::Result<()> {
pub fn serialize(&self, mut fieldnorms_serializer: FieldNormsSerializer) -> io::Result<()> {
for &field in self.fields.iter() {
let fieldnorm_values: &[u8] = &self.fieldnorms_buffer[field.field_id() as usize][..];
fieldnorms_serializer.serialize_field(field, fieldnorm_values)?;
}
fieldnorms_serializer.close()?;
Ok(())
}
}

View File

@@ -10,7 +10,7 @@ use crate::core::SegmentMeta;
use crate::core::SegmentReader;
use crate::directory::TerminatingWrite;
use crate::directory::{DirectoryLock, GarbageCollectionResult};
use crate::docset::DocSet;
use crate::docset::{DocSet, TERMINATED};
use crate::error::TantivyError;
use crate::fastfield::write_delete_bitset;
use crate::indexer::delete_queue::{DeleteCursor, DeleteQueue};
@@ -112,15 +112,15 @@ fn compute_deleted_bitset(
if let Some(mut docset) =
inverted_index.read_postings(&delete_op.term, IndexRecordOption::Basic)
{
while docset.advance() {
let deleted_doc = docset.doc();
let mut deleted_doc = docset.doc();
while deleted_doc != TERMINATED {
if deleted_doc < limit_doc {
delete_bitset.insert(deleted_doc);
might_have_changed = true;
}
deleted_doc = docset.advance();
}
}
delete_cursor.advance();
}
Ok(might_have_changed)
@@ -155,6 +155,8 @@ pub(crate) fn advance_deletes(
None => BitSet::with_max_value(max_doc),
};
let num_deleted_docs_before = segment.meta().num_deleted_docs();
compute_deleted_bitset(
&mut delete_bitset,
&segment_reader,
@@ -164,6 +166,8 @@ pub(crate) fn advance_deletes(
)?;
// TODO optimize
// It should be possible to do something smarter by manipulation bitsets directly
// to compute this union.
if let Some(seg_delete_bitset) = segment_reader.delete_bitset() {
for doc in 0u32..max_doc {
if seg_delete_bitset.is_deleted(doc) {
@@ -172,8 +176,9 @@ pub(crate) fn advance_deletes(
}
}
let num_deleted_docs = delete_bitset.len();
if num_deleted_docs > 0 {
let num_deleted_docs: u32 = delete_bitset.len() as u32;
if num_deleted_docs > num_deleted_docs_before {
// There are new deletes. We need to write a new delete file.
segment = segment.with_delete_meta(num_deleted_docs as u32, target_opstamp);
let mut delete_file = segment.open_write(SegmentComponent::DELETE)?;
write_delete_bitset(&delete_bitset, max_doc, &mut delete_file)?;
@@ -341,7 +346,7 @@ impl IndexWriter {
fn drop_sender(&mut self) {
let (sender, _receiver) = channel::bounded(1);
mem::replace(&mut self.operation_sender, sender);
self.operation_sender = sender;
}
/// If there are some merging threads, blocks until they all finish their work and
@@ -803,6 +808,46 @@ mod tests {
assert_eq!(batch_opstamp1, 2u64);
}
#[test]
fn test_no_need_to_rewrite_delete_file_if_no_new_deletes() {
let mut schema_builder = schema::Schema::builder();
let text_field = schema_builder.add_text_field("text", schema::TEXT);
let index = Index::create_in_ram(schema_builder.build());
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
index_writer.add_document(doc!(text_field => "hello1"));
index_writer.add_document(doc!(text_field => "hello2"));
assert!(index_writer.commit().is_ok());
let reader = index.reader().unwrap();
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
assert_eq!(searcher.segment_reader(0u32).num_deleted_docs(), 0);
index_writer.delete_term(Term::from_field_text(text_field, "hello1"));
assert!(index_writer.commit().is_ok());
assert!(reader.reload().is_ok());
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
assert_eq!(searcher.segment_reader(0u32).num_deleted_docs(), 1);
let previous_delete_opstamp = index.load_metas().unwrap().segments[0].delete_opstamp();
// All docs containing hello1 have been already removed.
// We should not update the delete meta.
index_writer.delete_term(Term::from_field_text(text_field, "hello1"));
assert!(index_writer.commit().is_ok());
assert!(reader.reload().is_ok());
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
assert_eq!(searcher.segment_reader(0u32).num_deleted_docs(), 1);
let after_delete_opstamp = index.load_metas().unwrap().segments[0].delete_opstamp();
assert_eq!(after_delete_opstamp, previous_delete_opstamp);
}
#[test]
fn test_ordered_batched_operations() {
// * one delete for `doc!(field=>"a")`

View File

@@ -54,10 +54,6 @@ impl LogMergePolicy {
impl MergePolicy for LogMergePolicy {
fn compute_merge_candidates(&self, segments: &[SegmentMeta]) -> Vec<MergeCandidate> {
if segments.is_empty() {
return Vec::new();
}
let mut size_sorted_tuples = segments
.iter()
.map(SegmentMeta::num_docs)
@@ -67,27 +63,35 @@ impl MergePolicy for LogMergePolicy {
size_sorted_tuples.sort_by(|x, y| y.1.cmp(&(x.1)));
if size_sorted_tuples.len() <= 1 {
return Vec::new();
}
let size_sorted_log_tuples: Vec<_> = size_sorted_tuples
.into_iter()
.map(|(ind, num_docs)| (ind, f64::from(self.clip_min_size(num_docs)).log2()))
.collect();
let (first_ind, first_score) = size_sorted_log_tuples[0];
let mut current_max_log_size = first_score;
let mut levels = vec![vec![first_ind]];
for &(ind, score) in (&size_sorted_log_tuples).iter().skip(1) {
if score < (current_max_log_size - self.level_log_size) {
current_max_log_size = score;
levels.push(Vec::new());
if let Some(&(first_ind, first_score)) = size_sorted_log_tuples.first() {
let mut current_max_log_size = first_score;
let mut levels = vec![vec![first_ind]];
for &(ind, score) in (&size_sorted_log_tuples).iter().skip(1) {
if score < (current_max_log_size - self.level_log_size) {
current_max_log_size = score;
levels.push(Vec::new());
}
levels.last_mut().unwrap().push(ind);
}
levels.last_mut().unwrap().push(ind);
levels
.iter()
.filter(|level| level.len() >= self.min_merge_size)
.map(|ind_vec| {
MergeCandidate(ind_vec.iter().map(|&ind| segments[ind].id()).collect())
})
.collect()
} else {
return vec![];
}
levels
.iter()
.filter(|level| level.len() >= self.min_merge_size)
.map(|ind_vec| MergeCandidate(ind_vec.iter().map(|&ind| segments[ind].id()).collect()))
.collect()
}
}
@@ -179,6 +183,7 @@ mod tests {
let result_list = test_merge_policy().compute_merge_candidates(&test_input);
assert_eq!(result_list.len(), 2);
}
#[test]
fn test_log_merge_policy_small_segments() {
// segments under min_layer_size are merged together
@@ -194,6 +199,17 @@ mod tests {
assert_eq!(result_list.len(), 1);
}
#[test]
fn test_log_merge_policy_all_segments_too_large_to_merge() {
let eight_large_segments: Vec<SegmentMeta> =
std::iter::repeat_with(|| create_random_segment_meta(100_001))
.take(8)
.collect();
assert!(test_merge_policy()
.compute_merge_candidates(&eight_large_segments)
.is_empty());
}
#[test]
fn test_large_merge_segments() {
let test_input = vec![

View File

@@ -2,15 +2,15 @@ use crate::common::MAX_DOC_LIMIT;
use crate::core::Segment;
use crate::core::SegmentReader;
use crate::core::SerializableSegment;
use crate::docset::DocSet;
use crate::docset::{DocSet, TERMINATED};
use crate::fastfield::BytesFastFieldReader;
use crate::fastfield::DeleteBitSet;
use crate::fastfield::FastFieldReader;
use crate::fastfield::FastFieldSerializer;
use crate::fastfield::MultiValueIntFastFieldReader;
use crate::fieldnorm::FieldNormReader;
use crate::fieldnorm::FieldNormsSerializer;
use crate::fieldnorm::FieldNormsWriter;
use crate::fieldnorm::{FieldNormReader, FieldNormReaders};
use crate::indexer::SegmentSerializer;
use crate::postings::InvertedIndexSerializer;
use crate::postings::Postings;
@@ -20,8 +20,7 @@ use crate::schema::{Field, Schema};
use crate::store::StoreWriter;
use crate::termdict::TermMerger;
use crate::termdict::TermOrdinal;
use crate::DocId;
use itertools::Itertools;
use crate::{DocId, SegmentComponent};
use std::cmp;
use std::collections::HashMap;
@@ -70,11 +69,11 @@ fn compute_min_max_val(
Some(delete_bitset) => {
// some deleted documents,
// we need to recompute the max / min
(0..max_doc)
.filter(|doc_id| delete_bitset.is_alive(*doc_id))
.map(|doc_id| u64_reader.get(doc_id))
.minmax()
.into_option()
crate::common::minmax(
(0..max_doc)
.filter(|doc_id| delete_bitset.is_alive(*doc_id))
.map(|doc_id| u64_reader.get(doc_id)),
)
}
None => {
// no deleted documents,
@@ -168,7 +167,7 @@ impl IndexMerger {
fn write_fieldnorms(
&self,
fieldnorms_serializer: &mut FieldNormsSerializer,
mut fieldnorms_serializer: FieldNormsSerializer,
) -> crate::Result<()> {
let fields = FieldNormsWriter::fields_with_fieldnorm(&self.schema);
let mut fieldnorms_data = Vec::with_capacity(self.max_doc as usize);
@@ -183,6 +182,7 @@ impl IndexMerger {
}
fieldnorms_serializer.serialize_field(field, &fieldnorms_data[..])?;
}
fieldnorms_serializer.close()?;
Ok(())
}
@@ -493,6 +493,7 @@ impl IndexMerger {
indexed_field: Field,
field_type: &FieldType,
serializer: &mut InvertedIndexSerializer,
fieldnorm_reader: Option<FieldNormReader>,
) -> crate::Result<Option<TermOrdinalMapping>> {
let mut positions_buffer: Vec<u32> = Vec::with_capacity(1_000);
let mut delta_computer = DeltaComputer::new();
@@ -551,7 +552,8 @@ impl IndexMerger {
// - Segment 2's doc ids become [seg0.max_doc + seg1.max_doc,
// seg0.max_doc + seg1.max_doc + seg2.max_doc]
// ...
let mut field_serializer = serializer.new_field(indexed_field, total_num_tokens)?;
let mut field_serializer =
serializer.new_field(indexed_field, total_num_tokens, fieldnorm_reader)?;
let field_entry = self.schema.get_field_entry(indexed_field);
@@ -575,10 +577,12 @@ impl IndexMerger {
let inverted_index = segment_reader.inverted_index(indexed_field);
let mut segment_postings = inverted_index
.read_postings_from_terminfo(term_info, segment_postings_option);
while segment_postings.advance() {
if !segment_reader.is_deleted(segment_postings.doc()) {
let mut doc = segment_postings.doc();
while doc != TERMINATED {
if !segment_reader.is_deleted(doc) {
return Some((segment_ord, segment_postings));
}
doc = segment_postings.advance();
}
None
})
@@ -588,57 +592,49 @@ impl IndexMerger {
// of all of the segments containing the given term.
//
// These segments are non-empty and advance has already been called.
if !segment_postings.is_empty() {
// If not, the `term` will be entirely removed.
// We know that there is at least one document containing
// the term, so we add it.
let to_term_ord = field_serializer.new_term(term_bytes)?;
if let Some(ref mut term_ord_mapping) = term_ord_mapping_opt {
for (segment_ord, from_term_ord) in merged_terms.matching_segments() {
term_ord_mapping.register_from_to(segment_ord, from_term_ord, to_term_ord);
}
}
// We can now serialize this postings, by pushing each document to the
// postings serializer.
for (segment_ord, mut segment_postings) in segment_postings {
let old_to_new_doc_id = &merged_doc_id_map[segment_ord];
loop {
let doc = segment_postings.doc();
// `.advance()` has been called once before the loop.
//
// It was required to make sure we only consider segments
// that effectively contain at least one non-deleted document
// and remove terms that do not have documents associated.
//
// For this reason, we cannot use a `while segment_postings.advance()` loop.
// deleted doc are skipped as they do not have a `remapped_doc_id`.
if let Some(remapped_doc_id) = old_to_new_doc_id[doc as usize] {
// we make sure to only write the term iff
// there is at least one document.
let term_freq = segment_postings.term_freq();
segment_postings.positions(&mut positions_buffer);
let delta_positions = delta_computer.compute_delta(&positions_buffer);
field_serializer.write_doc(
remapped_doc_id,
term_freq,
delta_positions,
)?;
}
if !segment_postings.advance() {
break;
}
}
}
// closing the term.
field_serializer.close_term()?;
if segment_postings.is_empty() {
continue;
}
// If not, the `term` will be entirely removed.
// We know that there is at least one document containing
// the term, so we add it.
let term_doc_freq = segment_postings
.iter()
.map(|(_, segment_posting)| segment_posting.doc_freq())
.sum();
let to_term_ord = field_serializer.new_term(term_bytes, term_doc_freq)?;
if let Some(ref mut term_ord_mapping) = term_ord_mapping_opt {
for (segment_ord, from_term_ord) in merged_terms.matching_segments() {
term_ord_mapping.register_from_to(segment_ord, from_term_ord, to_term_ord);
}
}
// We can now serialize this postings, by pushing each document to the
// postings serializer.
for (segment_ord, mut segment_postings) in segment_postings {
let old_to_new_doc_id = &merged_doc_id_map[segment_ord];
let mut doc = segment_postings.doc();
while doc != TERMINATED {
// deleted doc are skipped as they do not have a `remapped_doc_id`.
if let Some(remapped_doc_id) = old_to_new_doc_id[doc as usize] {
// we make sure to only write the term iff
// there is at least one document.
let term_freq = segment_postings.term_freq();
segment_postings.positions(&mut positions_buffer);
let delta_positions = delta_computer.compute_delta(&positions_buffer);
field_serializer.write_doc(remapped_doc_id, term_freq, delta_positions)?;
}
doc = segment_postings.advance();
}
}
// closing the term.
field_serializer.close_term()?;
}
field_serializer.close()?;
Ok(term_ord_mapping_opt)
@@ -647,13 +643,18 @@ impl IndexMerger {
fn write_postings(
&self,
serializer: &mut InvertedIndexSerializer,
fieldnorm_readers: FieldNormReaders,
) -> crate::Result<HashMap<Field, TermOrdinalMapping>> {
let mut term_ordinal_mappings = HashMap::new();
for (field, field_entry) in self.schema.fields() {
let fieldnorm_reader = fieldnorm_readers.get_field(field);
if field_entry.is_indexed() {
if let Some(term_ordinal_mapping) =
self.write_postings_for_field(field, field_entry.field_type(), serializer)?
{
if let Some(term_ordinal_mapping) = self.write_postings_for_field(
field,
field_entry.field_type(),
serializer,
fieldnorm_reader,
)? {
term_ordinal_mappings.insert(field, term_ordinal_mapping);
}
}
@@ -679,8 +680,15 @@ impl IndexMerger {
impl SerializableSegment for IndexMerger {
fn write(&self, mut serializer: SegmentSerializer) -> crate::Result<u32> {
let term_ord_mappings = self.write_postings(serializer.get_postings_serializer())?;
self.write_fieldnorms(serializer.get_fieldnorms_serializer())?;
if let Some(fieldnorms_serializer) = serializer.extract_fieldnorms_serializer() {
self.write_fieldnorms(fieldnorms_serializer)?;
}
let fieldnorm_data = serializer
.segment()
.open_read(SegmentComponent::FIELDNORMS)?;
let fieldnorm_readers = FieldNormReaders::new(fieldnorm_data)?;
let term_ord_mappings =
self.write_postings(serializer.get_postings_serializer(), fieldnorm_readers)?;
self.write_fast_fields(serializer.get_fast_field_serializer(), term_ord_mappings)?;
self.write_storable_fields(serializer.get_store_writer())?;
serializer.close()?;
@@ -690,15 +698,15 @@ impl SerializableSegment for IndexMerger {
#[cfg(test)]
mod tests {
use crate::assert_nearly_equals;
use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE;
use crate::collector::tests::{BytesFastFieldTestCollector, FastFieldTestCollector};
use crate::collector::{Count, FacetCollector};
use crate::core::Index;
use crate::query::AllQuery;
use crate::query::BooleanQuery;
use crate::query::Scorer;
use crate::query::TermQuery;
use crate::schema;
use crate::schema::Cardinality;
use crate::schema::Document;
use crate::schema::Facet;
use crate::schema::IndexRecordOption;
@@ -706,9 +714,11 @@ mod tests {
use crate::schema::Term;
use crate::schema::TextFieldIndexing;
use crate::schema::INDEXED;
use crate::schema::{Cardinality, TEXT};
use crate::DocAddress;
use crate::IndexWriter;
use crate::Searcher;
use crate::{schema, DocSet, SegmentId};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use futures::executor::block_on;
use std::io::Cursor;
@@ -1515,12 +1525,9 @@ mod tests {
for i in 0..100 {
let mut doc = Document::new();
doc.add_f64(field, 42.0);
doc.add_f64(multi_field, 0.24);
doc.add_f64(multi_field, 0.27);
writer.add_document(doc);
if i % 5 == 0 {
writer.commit()?;
}
@@ -1532,6 +1539,72 @@ mod tests {
// If a merging thread fails, we should end up with more
// than one segment here
assert_eq!(1, index.searchable_segments()?.len());
Ok(())
}
#[test]
fn test_merged_index_has_blockwand() -> crate::Result<()> {
let mut builder = schema::SchemaBuilder::new();
let text = builder.add_text_field("text", TEXT);
let index = Index::create_in_ram(builder.build());
let mut writer = index.writer_with_num_threads(1, 3_000_000)?;
let happy_term = Term::from_field_text(text, "happy");
let term_query = TermQuery::new(happy_term, IndexRecordOption::WithFreqs);
for _ in 0..62 {
writer.add_document(doc!(text=>"hello happy tax payer"));
}
writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
let mut term_scorer = term_query
.specialized_weight(&searcher, true)
.specialized_scorer(searcher.segment_reader(0u32), 1.0f32)?;
assert_eq!(term_scorer.doc(), 0);
assert_nearly_equals!(term_scorer.block_max_score(), 0.0079681855);
assert_nearly_equals!(term_scorer.score(), 0.0079681855);
for _ in 0..81 {
writer.add_document(doc!(text=>"hello happy tax payer"));
}
writer.commit()?;
reader.reload()?;
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 2);
for segment_reader in searcher.segment_readers() {
let mut term_scorer = term_query
.specialized_weight(&searcher, true)
.specialized_scorer(segment_reader, 1.0f32)?;
// the difference compared to before is instrinsic to the bm25 formula. no worries there.
for doc in segment_reader.doc_ids_alive() {
assert_eq!(term_scorer.doc(), doc);
assert_nearly_equals!(term_scorer.block_max_score(), 0.003478312);
assert_nearly_equals!(term_scorer.score(), 0.003478312);
term_scorer.advance();
}
}
let segment_ids: Vec<SegmentId> = searcher
.segment_readers()
.iter()
.map(|reader| reader.segment_id())
.collect();
block_on(writer.merge(&segment_ids[..]))?;
reader.reload()?;
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
let segment_reader = searcher.segment_reader(0u32);
let mut term_scorer = term_query
.specialized_weight(&searcher, true)
.specialized_scorer(segment_reader, 1.0f32)?;
// the difference compared to before is instrinsic to the bm25 formula. no worries there.
for doc in segment_reader.doc_ids_alive() {
assert_eq!(term_scorer.doc(), doc);
assert_nearly_equals!(term_scorer.block_max_score(), 0.003478312);
assert_nearly_equals!(term_scorer.score(), 0.003478312);
term_scorer.advance();
}
Ok(())
}

View File

@@ -8,15 +8,16 @@ use crate::store::StoreWriter;
/// Segment serializer is in charge of laying out on disk
/// the data accumulated and sorted by the `SegmentWriter`.
pub struct SegmentSerializer {
segment: Segment,
store_writer: StoreWriter,
fast_field_serializer: FastFieldSerializer,
fieldnorms_serializer: FieldNormsSerializer,
fieldnorms_serializer: Option<FieldNormsSerializer>,
postings_serializer: InvertedIndexSerializer,
}
impl SegmentSerializer {
/// Creates a new `SegmentSerializer`.
pub fn for_segment(segment: &mut Segment) -> crate::Result<SegmentSerializer> {
pub fn for_segment(mut segment: Segment) -> crate::Result<SegmentSerializer> {
let store_write = segment.open_write(SegmentComponent::STORE)?;
let fast_field_write = segment.open_write(SegmentComponent::FASTFIELDS)?;
@@ -25,15 +26,20 @@ impl SegmentSerializer {
let fieldnorms_write = segment.open_write(SegmentComponent::FIELDNORMS)?;
let fieldnorms_serializer = FieldNormsSerializer::from_write(fieldnorms_write)?;
let postings_serializer = InvertedIndexSerializer::open(segment)?;
let postings_serializer = InvertedIndexSerializer::open(&mut segment)?;
Ok(SegmentSerializer {
segment,
store_writer: StoreWriter::new(store_write),
fast_field_serializer,
fieldnorms_serializer,
fieldnorms_serializer: Some(fieldnorms_serializer),
postings_serializer,
})
}
pub fn segment(&self) -> &Segment {
&self.segment
}
/// Accessor to the `PostingsSerializer`.
pub fn get_postings_serializer(&mut self) -> &mut InvertedIndexSerializer {
&mut self.postings_serializer
@@ -44,9 +50,11 @@ impl SegmentSerializer {
&mut self.fast_field_serializer
}
/// Accessor to the field norm serializer.
pub fn get_fieldnorms_serializer(&mut self) -> &mut FieldNormsSerializer {
&mut self.fieldnorms_serializer
/// Extract the field norm serializer.
///
/// Note the fieldnorms serializer can only be extracted once.
pub fn extract_fieldnorms_serializer(&mut self) -> Option<FieldNormsSerializer> {
self.fieldnorms_serializer.take()
}
/// Accessor to the `StoreWriter`.
@@ -55,11 +63,13 @@ impl SegmentSerializer {
}
/// Finalize the segment serialization.
pub fn close(self) -> crate::Result<()> {
pub fn close(mut self) -> crate::Result<()> {
if let Some(fieldnorms_serializer) = self.extract_fieldnorms_serializer() {
fieldnorms_serializer.close()?;
}
self.fast_field_serializer.close()?;
self.postings_serializer.close()?;
self.store_writer.close()?;
self.fieldnorms_serializer.close()?;
Ok(())
}
}

View File

@@ -23,7 +23,6 @@ use futures::channel::oneshot;
use futures::executor::{ThreadPool, ThreadPoolBuilder};
use futures::future::Future;
use futures::future::TryFutureExt;
use serde_json;
use std::borrow::BorrowMut;
use std::collections::HashSet;
use std::io::Write;
@@ -113,7 +112,7 @@ fn merge(
target_opstamp: Opstamp,
) -> crate::Result<SegmentEntry> {
// first we need to apply deletes to our segment.
let mut merged_segment = index.new_segment();
let merged_segment = index.new_segment();
// First we apply all of the delet to the merged segment, up to the target opstamp.
for segment_entry in &mut segment_entries {
@@ -131,12 +130,14 @@ fn merge(
// An IndexMerger is like a "view" of our merged segments.
let merger: IndexMerger = IndexMerger::open(index.schema(), &segments[..])?;
let merged_segment_id = merged_segment.id();
// ... we just serialize this index merger in our new segment to merge the two segments.
let segment_serializer = SegmentSerializer::for_segment(&mut merged_segment)?;
let segment_serializer = SegmentSerializer::for_segment(merged_segment)?;
let num_docs = merger.write(segment_serializer)?;
let segment_meta = index.new_segment_meta(merged_segment.id(), num_docs);
let segment_meta = index.new_segment_meta(merged_segment_id, num_docs);
Ok(SegmentEntry::new(segment_meta, delete_cursor, None))
}
@@ -522,7 +523,7 @@ impl SegmentUpdater {
///
/// Upon termination of the current merging threads,
/// merge opportunity may appear.
//
///
/// We keep waiting until the merge policy judges that
/// no opportunity is available.
///

View File

@@ -2,7 +2,7 @@ use super::operation::AddOperation;
use crate::core::Segment;
use crate::core::SerializableSegment;
use crate::fastfield::FastFieldsWriter;
use crate::fieldnorm::FieldNormsWriter;
use crate::fieldnorm::{FieldNormReaders, FieldNormsWriter};
use crate::indexer::segment_serializer::SegmentSerializer;
use crate::postings::compute_table_size;
use crate::postings::MultiFieldPostingsWriter;
@@ -14,8 +14,8 @@ use crate::schema::{Field, FieldEntry};
use crate::tokenizer::{BoxTokenStream, PreTokenizedStream};
use crate::tokenizer::{FacetTokenizer, TextAnalyzer};
use crate::tokenizer::{TokenStreamChain, Tokenizer};
use crate::DocId;
use crate::Opstamp;
use crate::{DocId, SegmentComponent};
use std::io;
use std::str;
@@ -62,11 +62,12 @@ impl SegmentWriter {
/// - schema
pub fn for_segment(
memory_budget: usize,
mut segment: Segment,
segment: Segment,
schema: &Schema,
) -> crate::Result<SegmentWriter> {
let tokenizer_manager = segment.index().tokenizers().clone();
let table_num_bits = initial_table_size(memory_budget)?;
let segment_serializer = SegmentSerializer::for_segment(&mut segment)?;
let segment_serializer = SegmentSerializer::for_segment(segment)?;
let multifield_postings = MultiFieldPostingsWriter::new(schema, table_num_bits);
let tokenizers = schema
.fields()
@@ -76,7 +77,7 @@ impl SegmentWriter {
.get_indexing_options()
.and_then(|text_index_option| {
let tokenizer_name = &text_index_option.tokenizer();
segment.index().tokenizers().get(tokenizer_name)
tokenizer_manager.get(tokenizer_name)
}),
_ => None,
},
@@ -280,9 +281,16 @@ fn write(
fieldnorms_writer: &FieldNormsWriter,
mut serializer: SegmentSerializer,
) -> crate::Result<()> {
let term_ord_map = multifield_postings.serialize(serializer.get_postings_serializer())?;
if let Some(fieldnorms_serializer) = serializer.extract_fieldnorms_serializer() {
fieldnorms_writer.serialize(fieldnorms_serializer)?;
}
let fieldnorm_data = serializer
.segment()
.open_read(SegmentComponent::FIELDNORMS)?;
let fieldnorm_readers = FieldNormReaders::new(fieldnorm_data)?;
let term_ord_map =
multifield_postings.serialize(serializer.get_postings_serializer(), fieldnorm_readers)?;
fast_field_writers.serialize(serializer.get_fast_field_serializer(), &term_ord_map)?;
fieldnorms_writer.serialize(serializer.get_fieldnorms_serializer())?;
serializer.close()?;
Ok(())
}

View File

@@ -98,9 +98,6 @@
//! [literate programming](https://tantivy-search.github.io/examples/basic_search.html) /
//! [source code](https://github.com/tantivy-search/tantivy/blob/master/examples/basic_search.rs))
#[macro_use]
extern crate serde_derive;
#[cfg_attr(test, macro_use)]
extern crate serde_json;
@@ -159,7 +156,7 @@ mod snippet;
pub use self::snippet::{Snippet, SnippetGenerator};
mod docset;
pub use self::docset::{DocSet, SkipResult};
pub use self::docset::{DocSet, TERMINATED};
pub use crate::common::{f64_to_u64, i64_to_u64, u64_to_f64, u64_to_i64};
pub use crate::core::{Executor, SegmentComponent};
pub use crate::core::{Index, IndexMeta, Searcher, Segment, SegmentId, SegmentMeta};
@@ -173,9 +170,10 @@ pub use crate::schema::{Document, Term};
use std::fmt;
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
/// Index format version.
const INDEX_FORMAT_VERSION: u32 = 1;
const INDEX_FORMAT_VERSION: u32 = 2;
/// Structure version for the index.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
@@ -287,7 +285,7 @@ mod tests {
use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE;
use crate::core::SegmentReader;
use crate::docset::DocSet;
use crate::docset::{DocSet, TERMINATED};
use crate::query::BooleanQuery;
use crate::schema::*;
use crate::DocAddress;
@@ -300,17 +298,26 @@ mod tests {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
pub fn assert_nearly_equals(expected: f32, val: f32) {
assert!(
nearly_equals(val, expected),
"Got {}, expected {}.",
val,
expected
);
}
pub fn nearly_equals(a: f32, b: f32) -> bool {
(a - b).abs() < 0.0005 * (a + b).abs()
/// Checks if left and right are close one to each other.
/// Panics if the two values are more than 0.5% apart.
#[macro_export]
macro_rules! assert_nearly_equals {
($left:expr, $right:expr) => {{
match (&$left, &$right) {
(left_val, right_val) => {
let diff = (left_val - right_val).abs();
let add = left_val.abs() + right_val.abs();
if diff > 0.0005 * add {
panic!(
r#"assertion failed: `(left ~= right)`
left: `{:?}`,
right: `{:?}`"#,
&*left_val, &*right_val
)
}
}
}
}};
}
pub fn generate_nonunique_unsorted(max_value: u32, n_elems: usize) -> Vec<u32> {
@@ -383,19 +390,12 @@ mod tests {
index_writer.commit().unwrap();
}
{
{
let doc = doc!(text_field=>"a");
index_writer.add_document(doc);
}
{
let doc = doc!(text_field=>"a a");
index_writer.add_document(doc);
}
index_writer.add_document(doc!(text_field=>"a"));
index_writer.add_document(doc!(text_field=>"a a"));
index_writer.commit().unwrap();
}
{
let doc = doc!(text_field=>"c");
index_writer.add_document(doc);
index_writer.add_document(doc!(text_field=>"c"));
index_writer.commit().unwrap();
}
{
@@ -474,10 +474,12 @@ mod tests {
}
fn advance_undeleted(docset: &mut dyn DocSet, reader: &SegmentReader) -> bool {
while docset.advance() {
if !reader.is_deleted(docset.doc()) {
let mut doc = docset.advance();
while doc != TERMINATED {
if !reader.is_deleted(doc) {
return true;
}
doc = docset.advance();
}
false
}
@@ -643,9 +645,8 @@ mod tests {
.inverted_index(term.field())
.read_postings(&term, IndexRecordOption::Basic)
.unwrap();
assert!(postings.advance());
assert_eq!(postings.doc(), 0);
assert!(!postings.advance());
assert_eq!(postings.advance(), TERMINATED);
}
#[test]
@@ -667,9 +668,8 @@ mod tests {
.inverted_index(term.field())
.read_postings(&term, IndexRecordOption::Basic)
.unwrap();
assert!(postings.advance());
assert_eq!(postings.doc(), 0);
assert!(!postings.advance());
assert_eq!(postings.advance(), TERMINATED);
}
#[test]
@@ -691,9 +691,8 @@ mod tests {
.inverted_index(term.field())
.read_postings(&term, IndexRecordOption::Basic)
.unwrap();
assert!(postings.advance());
assert_eq!(postings.doc(), 0);
assert!(!postings.advance());
assert_eq!(postings.advance(), TERMINATED);
}
#[test]
@@ -762,10 +761,8 @@ mod tests {
{
// writing the segment
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
{
let doc = doc!(text_field=>"af af af bc bc");
index_writer.add_document(doc);
}
let doc = doc!(text_field=>"af af af bc bc");
index_writer.add_document(doc);
index_writer.commit().unwrap();
}
{
@@ -781,10 +778,9 @@ mod tests {
let mut postings = inverted_index
.read_postings(&term_af, IndexRecordOption::WithFreqsAndPositions)
.unwrap();
assert!(postings.advance());
assert_eq!(postings.doc(), 0);
assert_eq!(postings.term_freq(), 3);
assert!(!postings.advance());
assert_eq!(postings.advance(), TERMINATED);
}
}

View File

@@ -37,9 +37,9 @@ const LONG_SKIP_INTERVAL: u64 = (LONG_SKIP_IN_BLOCKS * COMPRESSION_BLOCK_SIZE) a
#[cfg(test)]
pub mod tests {
use super::{PositionReader, PositionSerializer};
use super::PositionSerializer;
use crate::directory::ReadOnlySource;
use crate::positions::COMPRESSION_BLOCK_SIZE;
use crate::positions::reader::PositionReader;
use std::iter;
fn create_stream_buffer(vals: &[u32]) -> (ReadOnlySource, ReadOnlySource) {
@@ -68,7 +68,7 @@ pub mod tests {
let mut position_reader = PositionReader::new(stream, skip, 0u64);
for &n in &[1, 10, 127, 128, 130, 312] {
let mut v = vec![0u32; n];
position_reader.read(&mut v[..n]);
position_reader.read(0, &mut v[..]);
for i in 0..n {
assert_eq!(v[i], i as u32);
}
@@ -76,19 +76,19 @@ pub mod tests {
}
#[test]
fn test_position_skip() {
let v: Vec<u32> = (0..1_000).collect();
fn test_position_read_with_offset() {
let v: Vec<u32> = (0..1000).collect();
let (stream, skip) = create_stream_buffer(&v[..]);
assert_eq!(skip.len(), 12);
assert_eq!(stream.len(), 1168);
let mut position_reader = PositionReader::new(stream, skip, 0u64);
position_reader.skip(10);
for &n in &[10, 127, COMPRESSION_BLOCK_SIZE, 130, 312] {
let mut v = vec![0u32; n];
position_reader.read(&mut v[..n]);
for i in 0..n {
assert_eq!(v[i], 10u32 + i as u32);
for &offset in &[1u64, 10u64, 127u64, 128u64, 130u64, 312u64] {
for &len in &[1, 10, 130, 500] {
let mut v = vec![0u32; len];
position_reader.read(offset, &mut v[..]);
for i in 0..len {
assert_eq!(v[i], i as u32 + offset as u32);
}
}
}
}
@@ -103,11 +103,12 @@ pub mod tests {
let mut position_reader = PositionReader::new(stream, skip, 0u64);
let mut buf = [0u32; 7];
let mut c = 0;
let mut offset = 0;
for _ in 0..100 {
position_reader.read(&mut buf);
position_reader.read(&mut buf);
position_reader.skip(4);
position_reader.skip(3);
position_reader.read(offset, &mut buf);
position_reader.read(offset, &mut buf);
offset += 7;
for &el in &buf {
assert_eq!(c, el);
c += 1;
@@ -115,6 +116,58 @@ pub mod tests {
}
}
#[test]
fn test_position_reread_anchor_different_than_block() {
let v: Vec<u32> = (0..2_000_000).collect();
let (stream, skip) = create_stream_buffer(&v[..]);
assert_eq!(skip.len(), 15_749);
assert_eq!(stream.len(), 4_987_872);
let mut position_reader = PositionReader::new(stream.clone(), skip.clone(), 0);
let mut buf = [0u32; 256];
position_reader.read(128, &mut buf);
for i in 0..256 {
assert_eq!(buf[i], (128 + i) as u32);
}
position_reader.read(128, &mut buf);
for i in 0..256 {
assert_eq!(buf[i], (128 + i) as u32);
}
}
#[test]
#[should_panic(expected = "offset arguments should be increasing.")]
fn test_position_panic_if_called_previous_anchor() {
let v: Vec<u32> = (0..2_000_000).collect();
let (stream, skip) = create_stream_buffer(&v[..]);
assert_eq!(skip.len(), 15_749);
assert_eq!(stream.len(), 4_987_872);
let mut buf = [0u32; 1];
let mut position_reader = PositionReader::new(stream.clone(), skip.clone(), 200_000);
position_reader.read(230, &mut buf);
position_reader.read(9, &mut buf);
}
#[test]
fn test_positions_bug() {
let mut v: Vec<u32> = vec![];
for i in 1..200 {
for j in 0..i {
v.push(j);
}
}
let (stream, skip) = create_stream_buffer(&v[..]);
let mut buf = Vec::new();
let mut position_reader = PositionReader::new(stream.clone(), skip.clone(), 0);
let mut offset = 0;
for i in 1..24 {
buf.resize(i, 0);
position_reader.read(offset, &mut buf[..]);
offset += i as u64;
let r: Vec<u32> = (0..i).map(|el| el as u32).collect();
assert_eq!(buf, &r[..]);
}
}
#[test]
fn test_position_long_skip_const() {
const CONST_VAL: u32 = 9u32;
@@ -124,7 +177,7 @@ pub mod tests {
assert_eq!(stream.len(), 1_000_000);
let mut position_reader = PositionReader::new(stream, skip, 128 * 1024);
let mut buf = [0u32; 1];
position_reader.read(&mut buf);
position_reader.read(0, &mut buf);
assert_eq!(buf[0], CONST_VAL);
}
@@ -143,7 +196,7 @@ pub mod tests {
] {
let mut position_reader = PositionReader::new(stream.clone(), skip.clone(), offset);
let mut buf = [0u32; 1];
position_reader.read(&mut buf);
position_reader.read(0, &mut buf);
assert_eq!(buf[0], offset as u32);
}
}

View File

@@ -3,7 +3,6 @@ use crate::directory::ReadOnlySource;
use crate::positions::COMPRESSION_BLOCK_SIZE;
use crate::positions::LONG_SKIP_INTERVAL;
use crate::positions::LONG_SKIP_IN_BLOCKS;
use crate::postings::compression::compressed_block_size;
/// Positions works as a long sequence of compressed block.
/// All terms are chained one after the other.
///
@@ -62,22 +61,20 @@ impl Positions {
fn reader(&self, offset: u64) -> PositionReader {
let long_skip_id = (offset / LONG_SKIP_INTERVAL) as usize;
let small_skip = (offset % LONG_SKIP_INTERVAL) as usize;
let offset_num_bytes: u64 = self.long_skip(long_skip_id);
let mut position_read = OwnedRead::new(self.position_source.clone());
position_read.advance(offset_num_bytes as usize);
let mut skip_read = OwnedRead::new(self.skip_source.clone());
skip_read.advance(long_skip_id * LONG_SKIP_IN_BLOCKS);
let mut position_reader = PositionReader {
PositionReader {
bit_packer: self.bit_packer,
skip_read,
position_read,
inner_offset: 0,
buffer: Box::new([0u32; 128]),
ahead: None,
};
position_reader.skip(small_skip);
position_reader
block_offset: std::i64::MAX as u64,
anchor_offset: (long_skip_id as u64) * LONG_SKIP_INTERVAL,
abs_offset: offset,
}
}
}
@@ -85,51 +82,12 @@ pub struct PositionReader {
skip_read: OwnedRead,
position_read: OwnedRead,
bit_packer: BitPacker4x,
inner_offset: usize,
buffer: Box<[u32; 128]>,
ahead: Option<usize>, // if None, no block is loaded.
// if Some(num_blocks), the block currently loaded is num_blocks ahead
// of the block of the next int to read.
}
buffer: Box<[u32; COMPRESSION_BLOCK_SIZE]>,
// `ahead` represents the offset of the block currently loaded
// compared to the cursor of the actual stream.
//
// By contract, when this function is called, the current block has to be
// decompressed.
//
// If the requested number of els ends exactly at a given block, the next
// block is not decompressed.
fn read_impl(
bit_packer: BitPacker4x,
mut position: &[u8],
buffer: &mut [u32; 128],
mut inner_offset: usize,
num_bits: &[u8],
output: &mut [u32],
) -> usize {
let mut output_start = 0;
let mut output_len = output.len();
let mut ahead = 0;
loop {
let available_len = COMPRESSION_BLOCK_SIZE - inner_offset;
// We have enough elements in the current block.
// Let's copy the requested elements in the output buffer,
// and return.
if output_len <= available_len {
output[output_start..].copy_from_slice(&buffer[inner_offset..][..output_len]);
return ahead;
}
output[output_start..][..available_len].copy_from_slice(&buffer[inner_offset..]);
output_len -= available_len;
output_start += available_len;
inner_offset = 0;
let num_bits = num_bits[ahead];
bit_packer.decompress(position, &mut buffer[..], num_bits);
let block_len = compressed_block_size(num_bits);
position = &position[block_len..];
ahead += 1;
}
block_offset: u64,
anchor_offset: u64,
abs_offset: u64,
}
impl PositionReader {
@@ -141,57 +99,65 @@ impl PositionReader {
Positions::new(position_source, skip_source).reader(offset)
}
/// Fills a buffer with the next `output.len()` integers.
/// This does not consume / advance the stream.
pub fn read(&mut self, output: &mut [u32]) {
let skip_data = self.skip_read.as_ref();
let position_data = self.position_read.as_ref();
let num_bits = self.skip_read.get(0);
if self.ahead != Some(0) {
// the block currently available is not the block
// for the current position
fn advance_num_blocks(&mut self, num_blocks: usize) {
let num_bits: usize = self.skip_read.as_ref()[..num_blocks]
.iter()
.cloned()
.map(|num_bits| num_bits as usize)
.sum();
let num_bytes_to_skip = num_bits * COMPRESSION_BLOCK_SIZE / 8;
self.skip_read.advance(num_blocks as usize);
self.position_read.advance(num_bytes_to_skip);
}
/// Fills a buffer with the positions `[offset..offset+output.len())` integers.
///
/// `offset` is required to have a value >= to the offsets given in previous calls
/// for the given `PositionReaderAbsolute` instance.
pub fn read(&mut self, mut offset: u64, mut output: &mut [u32]) {
offset += self.abs_offset;
assert!(
offset >= self.anchor_offset,
"offset arguments should be increasing."
);
let delta_to_block_offset = offset as i64 - self.block_offset as i64;
if delta_to_block_offset < 0 || delta_to_block_offset >= 128 {
// The first position is not within the first block.
// We need to decompress the first block.
let delta_to_anchor_offset = offset - self.anchor_offset;
let num_blocks_to_skip =
(delta_to_anchor_offset / (COMPRESSION_BLOCK_SIZE as u64)) as usize;
self.advance_num_blocks(num_blocks_to_skip);
self.anchor_offset = offset - (offset % COMPRESSION_BLOCK_SIZE as u64);
self.block_offset = self.anchor_offset;
let num_bits = self.skip_read.get(0);
self.bit_packer
.decompress(self.position_read.as_ref(), self.buffer.as_mut(), num_bits);
} else {
let num_blocks_to_skip =
((self.block_offset - self.anchor_offset) / COMPRESSION_BLOCK_SIZE as u64) as usize;
self.advance_num_blocks(num_blocks_to_skip);
self.anchor_offset = self.block_offset;
}
let mut num_bits = self.skip_read.get(0);
let mut position_data = self.position_read.as_ref();
for i in 1.. {
let offset_in_block = (offset as usize) % COMPRESSION_BLOCK_SIZE;
let remaining_in_block = COMPRESSION_BLOCK_SIZE - offset_in_block;
if remaining_in_block >= output.len() {
output.copy_from_slice(&self.buffer[offset_in_block..][..output.len()]);
break;
}
output[..remaining_in_block].copy_from_slice(&self.buffer[offset_in_block..]);
output = &mut output[remaining_in_block..];
offset += remaining_in_block as u64;
position_data = &position_data[(num_bits as usize * COMPRESSION_BLOCK_SIZE / 8)..];
num_bits = self.skip_read.get(i);
self.bit_packer
.decompress(position_data, self.buffer.as_mut(), num_bits);
self.ahead = Some(0);
self.block_offset += COMPRESSION_BLOCK_SIZE as u64;
}
let block_len = compressed_block_size(num_bits);
self.ahead = Some(read_impl(
self.bit_packer,
&position_data[block_len..],
self.buffer.as_mut(),
self.inner_offset,
&skip_data[1..],
output,
));
}
/// Skip the next `skip_len` integer.
///
/// If a full block is skipped, calling
/// `.skip(...)` will avoid decompressing it.
///
/// May panic if the end of the stream is reached.
pub fn skip(&mut self, skip_len: usize) {
let skip_len_plus_inner_offset = skip_len + self.inner_offset;
let num_blocks_to_advance = skip_len_plus_inner_offset / COMPRESSION_BLOCK_SIZE;
self.inner_offset = skip_len_plus_inner_offset % COMPRESSION_BLOCK_SIZE;
self.ahead = self.ahead.and_then(|num_blocks| {
if num_blocks >= num_blocks_to_advance {
Some(num_blocks - num_blocks_to_advance)
} else {
None
}
});
let skip_len_in_bits = self.skip_read.as_ref()[..num_blocks_to_advance]
.iter()
.map(|num_bits| *num_bits as usize)
.sum::<usize>()
* COMPRESSION_BLOCK_SIZE;
let skip_len_in_bytes = skip_len_in_bits / 8;
self.skip_read.advance(num_blocks_to_advance);
self.position_read.advance(skip_len_in_bytes);
}
}

View File

@@ -87,6 +87,7 @@ fn exponential_search(arr: &[u32], target: u32) -> (usize, usize) {
(begin, end)
}
#[inline(never)]
fn galloping(block_docs: &[u32], target: u32) -> usize {
let (start, end) = exponential_search(&block_docs, target);
start + linear_search(&block_docs[start..end], target)
@@ -106,7 +107,7 @@ impl BlockSearcher {
/// the target.
///
/// The results should be equivalent to
/// ```ignore
/// ```compile_fail
/// block[..]
// .iter()
// .take_while(|&&val| val < target)
@@ -129,23 +130,18 @@ impl BlockSearcher {
///
/// If SSE2 instructions are available in the `(platform, running CPU)`,
/// then we use a different implementation that does an exhaustive linear search over
/// the full block whenever the block is full (`len == 128`). It is surprisingly faster, most likely because of the lack
/// of branch.
pub(crate) fn search_in_block(
self,
block_docs: &AlignedBuffer,
len: usize,
start: usize,
target: u32,
) -> usize {
/// the block regardless of whether the block is full or not.
///
/// Indeed, if the block is not full, the remaining items are TERMINATED.
/// It is surprisingly faster, most likely because of the lack of branch misprediction.
pub(crate) fn search_in_block(self, block_docs: &AlignedBuffer, target: u32) -> usize {
#[cfg(target_arch = "x86_64")]
{
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
if self == BlockSearcher::SSE2 && len == COMPRESSION_BLOCK_SIZE {
if self == BlockSearcher::SSE2 {
return sse2::linear_search_sse2_128(block_docs, target);
}
}
start + galloping(&block_docs.0[start..len], target)
galloping(&block_docs.0[..], target)
}
}
@@ -166,6 +162,7 @@ mod tests {
use super::exponential_search;
use super::linear_search;
use super::BlockSearcher;
use crate::docset::TERMINATED;
use crate::postings::compression::{AlignedBuffer, COMPRESSION_BLOCK_SIZE};
#[test]
@@ -196,19 +193,12 @@ mod tests {
fn util_test_search_in_block(block_searcher: BlockSearcher, block: &[u32], target: u32) {
let cursor = search_in_block_trivial_but_slow(block, target);
assert!(block.len() < COMPRESSION_BLOCK_SIZE);
let mut output_buffer = [u32::max_value(); COMPRESSION_BLOCK_SIZE];
let mut output_buffer = [TERMINATED; COMPRESSION_BLOCK_SIZE];
output_buffer[..block.len()].copy_from_slice(block);
for i in 0..cursor {
assert_eq!(
block_searcher.search_in_block(
&AlignedBuffer(output_buffer),
block.len(),
i,
target
),
cursor
);
}
assert_eq!(
block_searcher.search_in_block(&AlignedBuffer(output_buffer), target),
cursor
);
}
fn util_test_search_in_block_all(block_searcher: BlockSearcher, block: &[u32]) {

View File

@@ -0,0 +1,515 @@
use crate::common::{BinarySerializable, VInt};
use crate::directory::ReadOnlySource;
use crate::fieldnorm::FieldNormReader;
use crate::postings::compression::{
AlignedBuffer, BlockDecoder, VIntDecoder, COMPRESSION_BLOCK_SIZE,
};
use crate::postings::{BlockInfo, FreqReadingOption, SkipReader};
use crate::query::BM25Weight;
use crate::schema::IndexRecordOption;
use crate::{DocId, Score, TERMINATED};
fn max_f32<I: Iterator<Item = f32>>(mut it: I) -> Option<f32> {
if let Some(first) = it.next() {
Some(it.fold(first, f32::max))
} else {
None
}
}
/// `BlockSegmentPostings` is a cursor iterating over blocks
/// of documents.
///
/// # Warning
///
/// While it is useful for some very specific high-performance
/// use cases, you should prefer using `SegmentPostings` for most usage.
pub struct BlockSegmentPostings {
pub(crate) doc_decoder: BlockDecoder,
loaded_offset: usize,
freq_decoder: BlockDecoder,
freq_reading_option: FreqReadingOption,
block_max_score_cache: Option<Score>,
doc_freq: u32,
data: ReadOnlySource,
pub(crate) skip_reader: SkipReader,
}
fn decode_bitpacked_block(
doc_decoder: &mut BlockDecoder,
freq_decoder_opt: Option<&mut BlockDecoder>,
data: &[u8],
doc_offset: DocId,
doc_num_bits: u8,
tf_num_bits: u8,
) {
let num_consumed_bytes = doc_decoder.uncompress_block_sorted(data, doc_offset, doc_num_bits);
if let Some(freq_decoder) = freq_decoder_opt {
freq_decoder.uncompress_block_unsorted(&data[num_consumed_bytes..], tf_num_bits);
}
}
fn decode_vint_block(
doc_decoder: &mut BlockDecoder,
freq_decoder_opt: Option<&mut BlockDecoder>,
data: &[u8],
doc_offset: DocId,
num_vint_docs: usize,
) {
doc_decoder.fill(TERMINATED);
let num_consumed_bytes = doc_decoder.uncompress_vint_sorted(data, doc_offset, num_vint_docs);
if let Some(freq_decoder) = freq_decoder_opt {
freq_decoder.uncompress_vint_unsorted(&data[num_consumed_bytes..], num_vint_docs);
}
}
fn split_into_skips_and_postings(
doc_freq: u32,
data: ReadOnlySource,
) -> (Option<ReadOnlySource>, ReadOnlySource) {
if doc_freq < COMPRESSION_BLOCK_SIZE as u32 {
return (None, data);
}
let mut data_byte_arr = data.as_slice();
let skip_len = VInt::deserialize(&mut data_byte_arr)
.expect("Data corrupted")
.0 as usize;
let vint_len = data.len() - data_byte_arr.len();
let (skip_data, postings_data) = data.slice_from(vint_len).split(skip_len);
(Some(skip_data), postings_data)
}
impl BlockSegmentPostings {
pub(crate) fn from_data(
doc_freq: u32,
data: ReadOnlySource,
record_option: IndexRecordOption,
requested_option: IndexRecordOption,
) -> BlockSegmentPostings {
let freq_reading_option = match (record_option, requested_option) {
(IndexRecordOption::Basic, _) => FreqReadingOption::NoFreq,
(_, IndexRecordOption::Basic) => FreqReadingOption::SkipFreq,
(_, _) => FreqReadingOption::ReadFreq,
};
let (skip_data_opt, postings_data) = split_into_skips_and_postings(doc_freq, data);
let skip_reader = match skip_data_opt {
Some(skip_data) => SkipReader::new(skip_data, doc_freq, record_option),
None => SkipReader::new(ReadOnlySource::empty(), doc_freq, record_option),
};
let mut block_segment_postings = BlockSegmentPostings {
doc_decoder: BlockDecoder::with_val(TERMINATED),
loaded_offset: std::usize::MAX,
freq_decoder: BlockDecoder::with_val(1),
freq_reading_option,
block_max_score_cache: None,
doc_freq,
data: postings_data,
skip_reader,
};
block_segment_postings.load_block();
block_segment_postings
}
/// Returns the block_max_score for the current block.
/// It does not require the block to be loaded. For instance, it is ok to call this method
/// after having called `.shallow_advance(..)`.
///
/// See `TermScorer::block_max_score(..)` for more information.
pub fn block_max_score(
&mut self,
fieldnorm_reader: &FieldNormReader,
bm25_weight: &BM25Weight,
) -> Score {
let (block_max_score_cache, skip_reader, doc_decoder, freq_decoder) = (
&mut self.block_max_score_cache,
&self.skip_reader,
&self.doc_decoder,
&self.freq_decoder,
);
*block_max_score_cache.get_or_insert_with(|| {
skip_reader
.block_max_score(bm25_weight)
.or_else(|| {
let docs = doc_decoder.output_array();
let freqs = freq_decoder.output_array();
max_f32(docs.iter().cloned().zip(freqs.iter().cloned()).map(
|(doc, term_freq)| {
let fieldnorm_id = fieldnorm_reader.fieldnorm_id(doc);
bm25_weight.score(fieldnorm_id, term_freq)
},
))
})
.unwrap_or(0f32)
})
}
pub(crate) fn freq_reading_option(&self) -> FreqReadingOption {
self.freq_reading_option
}
// Resets the block segment postings on another position
// in the postings file.
//
// This is useful for enumerating through a list of terms,
// and consuming the associated posting lists while avoiding
// reallocating a `BlockSegmentPostings`.
//
// # Warning
//
// This does not reset the positions list.
pub(crate) fn reset(&mut self, doc_freq: u32, postings_data: ReadOnlySource) {
let (skip_data_opt, postings_data) = split_into_skips_and_postings(doc_freq, postings_data);
self.data = ReadOnlySource::new(postings_data);
self.loaded_offset = std::usize::MAX;
if let Some(skip_data) = skip_data_opt {
self.skip_reader.reset(skip_data, doc_freq);
} else {
self.skip_reader.reset(ReadOnlySource::empty(), doc_freq);
}
self.doc_freq = doc_freq;
self.load_block();
}
/// Returns the overall number of documents in the block postings.
/// It does not take in account whether documents are deleted or not.
///
/// This `doc_freq` is simply the sum of the length of all of the blocks
/// length, and it does not take in account deleted documents.
pub fn doc_freq(&self) -> u32 {
self.doc_freq
}
/// Returns the array of docs in the current block.
///
/// Before the first call to `.advance()`, the block
/// returned by `.docs()` is empty.
#[inline]
pub fn docs(&self) -> &[DocId] {
debug_assert!(self.block_is_loaded());
self.doc_decoder.output_array()
}
/// Returns a full block, regardless of whetehr the block is complete or incomplete (
/// as it happens for the last block of the posting list).
///
/// In the latter case, the block is guaranteed to be padded with the sentinel value:
/// `TERMINATED`. The array is also guaranteed to be aligned on 16 bytes = 128 bits.
///
/// This method is useful to run SSE2 linear search.
#[inline(always)]
pub(crate) fn docs_aligned(&self) -> &AlignedBuffer {
debug_assert!(self.block_is_loaded());
self.doc_decoder.output_aligned()
}
/// Return the document at index `idx` of the block.
#[inline(always)]
pub fn doc(&self, idx: usize) -> u32 {
self.doc_decoder.output(idx)
}
/// Return the array of `term freq` in the block.
#[inline]
pub fn freqs(&self) -> &[u32] {
debug_assert!(self.block_is_loaded());
self.freq_decoder.output_array()
}
/// Return the frequency at index `idx` of the block.
#[inline]
pub fn freq(&self, idx: usize) -> u32 {
debug_assert!(self.block_is_loaded());
self.freq_decoder.output(idx)
}
/// Returns the length of the current block.
///
/// All blocks have a length of `NUM_DOCS_PER_BLOCK`,
/// except the last block that may have a length
/// of any number between 1 and `NUM_DOCS_PER_BLOCK - 1`
#[inline]
pub fn block_len(&self) -> usize {
debug_assert!(self.block_is_loaded());
self.doc_decoder.output_len
}
/// Position on a block that may contains `target_doc`.
///
/// If all docs are smaller than target, the block loaded may be empty,
/// or be the last an incomplete VInt block.
pub fn seek(&mut self, target_doc: DocId) {
self.skip_reader.seek(target_doc);
self.load_block();
}
pub(crate) fn position_offset(&self) -> u64 {
self.skip_reader.position_offset()
}
/// Dangerous API! This calls seek on the skip list,
/// but does not `.load_block()` afterwards.
///
/// `.load_block()` needs to be called manually afterwards.
/// If all docs are smaller than target, the block loaded may be empty,
/// or be the last an incomplete VInt block.
pub(crate) fn shallow_seek(&mut self, target_doc: DocId) {
self.skip_reader.seek(target_doc);
}
pub(crate) fn block_is_loaded(&self) -> bool {
self.loaded_offset == self.skip_reader.byte_offset()
}
pub(crate) fn load_block(&mut self) {
self.block_max_score_cache = None;
let offset = self.skip_reader.byte_offset();
if self.loaded_offset == offset {
return;
}
self.loaded_offset = offset;
match self.skip_reader.block_info() {
BlockInfo::BitPacked {
doc_num_bits,
tf_num_bits,
..
} => {
decode_bitpacked_block(
&mut self.doc_decoder,
if let FreqReadingOption::ReadFreq = self.freq_reading_option {
Some(&mut self.freq_decoder)
} else {
None
},
&self.data.as_slice()[offset..],
self.skip_reader.last_doc_in_previous_block,
doc_num_bits,
tf_num_bits,
);
}
BlockInfo::VInt { num_docs } => {
let data = {
if num_docs == 0 {
&[]
} else {
&self.data.as_slice()[offset..]
}
};
decode_vint_block(
&mut self.doc_decoder,
if let FreqReadingOption::ReadFreq = self.freq_reading_option {
Some(&mut self.freq_decoder)
} else {
None
},
data,
self.skip_reader.last_doc_in_previous_block,
num_docs as usize,
);
}
}
}
/// Advance to the next block.
///
/// Returns false iff there was no remaining blocks.
pub fn advance(&mut self) {
self.skip_reader.advance();
self.load_block();
}
/// Returns an empty segment postings object
pub fn empty() -> BlockSegmentPostings {
BlockSegmentPostings {
doc_decoder: BlockDecoder::with_val(TERMINATED),
loaded_offset: 0,
freq_decoder: BlockDecoder::with_val(1),
freq_reading_option: FreqReadingOption::NoFreq,
block_max_score_cache: None,
doc_freq: 0,
data: ReadOnlySource::new(vec![]),
skip_reader: SkipReader::new(ReadOnlySource::new(vec![]), 0, IndexRecordOption::Basic),
}
}
}
#[cfg(test)]
mod tests {
use super::BlockSegmentPostings;
use crate::common::HasLen;
use crate::core::Index;
use crate::docset::{DocSet, TERMINATED};
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::postings::postings::Postings;
use crate::postings::SegmentPostings;
use crate::schema::IndexRecordOption;
use crate::schema::Schema;
use crate::schema::Term;
use crate::schema::INDEXED;
use crate::DocId;
#[test]
fn test_empty_segment_postings() {
let mut postings = SegmentPostings::empty();
assert_eq!(postings.doc(), TERMINATED);
assert_eq!(postings.advance(), TERMINATED);
assert_eq!(postings.advance(), TERMINATED);
assert_eq!(postings.doc_freq(), 0);
assert_eq!(postings.len(), 0);
}
#[test]
fn test_empty_postings_doc_returns_terminated() {
let mut postings = SegmentPostings::empty();
assert_eq!(postings.doc(), TERMINATED);
assert_eq!(postings.advance(), TERMINATED);
}
#[test]
fn test_empty_postings_doc_term_freq_returns_0() {
let postings = SegmentPostings::empty();
assert_eq!(postings.term_freq(), 1);
}
#[test]
fn test_empty_block_segment_postings() {
let mut postings = BlockSegmentPostings::empty();
assert!(postings.docs().is_empty());
assert_eq!(postings.doc_freq(), 0);
postings.advance();
assert!(postings.docs().is_empty());
assert_eq!(postings.doc_freq(), 0);
}
#[test]
fn test_block_segment_postings() {
let mut block_segments = build_block_postings(&(0..100_000).collect::<Vec<u32>>());
let mut offset: u32 = 0u32;
// checking that the `doc_freq` is correct
assert_eq!(block_segments.doc_freq(), 100_000);
loop {
let block = block_segments.docs();
if block.is_empty() {
break;
}
for (i, doc) in block.iter().cloned().enumerate() {
assert_eq!(offset + (i as u32), doc);
}
offset += block.len() as u32;
block_segments.advance();
}
}
#[test]
fn test_skip_right_at_new_block() {
let mut doc_ids = (0..128).collect::<Vec<u32>>();
// 128 is missing
doc_ids.push(129);
doc_ids.push(130);
{
let block_segments = build_block_postings(&doc_ids);
let mut docset = SegmentPostings::from_block_postings(block_segments, None);
assert_eq!(docset.seek(128), 129);
assert_eq!(docset.doc(), 129);
assert_eq!(docset.advance(), 130);
assert_eq!(docset.doc(), 130);
assert_eq!(docset.advance(), TERMINATED);
}
{
let block_segments = build_block_postings(&doc_ids);
let mut docset = SegmentPostings::from_block_postings(block_segments, None);
assert_eq!(docset.seek(129), 129);
assert_eq!(docset.doc(), 129);
assert_eq!(docset.advance(), 130);
assert_eq!(docset.doc(), 130);
assert_eq!(docset.advance(), TERMINATED);
}
{
let block_segments = build_block_postings(&doc_ids);
let mut docset = SegmentPostings::from_block_postings(block_segments, None);
assert_eq!(docset.doc(), 0);
assert_eq!(docset.seek(131), TERMINATED);
assert_eq!(docset.doc(), TERMINATED);
}
}
fn build_block_postings(docs: &[DocId]) -> BlockSegmentPostings {
let mut schema_builder = Schema::builder();
let int_field = schema_builder.add_u64_field("id", INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
let mut last_doc = 0u32;
for &doc in docs {
for _ in last_doc..doc {
index_writer.add_document(doc!(int_field=>1u64));
}
index_writer.add_document(doc!(int_field=>0u64));
last_doc = doc + 1;
}
index_writer.commit().unwrap();
let searcher = index.reader().unwrap().searcher();
let segment_reader = searcher.segment_reader(0);
let inverted_index = segment_reader.inverted_index(int_field);
let term = Term::from_field_u64(int_field, 0u64);
let term_info = inverted_index.get_term_info(&term).unwrap();
inverted_index.read_block_postings_from_terminfo(&term_info, IndexRecordOption::Basic)
}
#[test]
fn test_block_segment_postings_seek2() {
let mut docs = vec![0];
for i in 0..1300 {
docs.push((i * i / 100) + i);
}
let mut block_postings = build_block_postings(&docs[..]);
for i in vec![0, 424, 10000] {
block_postings.shallow_seek(i);
block_postings.load_block();
let docs = block_postings.docs();
assert!(docs[0] <= i);
assert!(docs.last().cloned().unwrap_or(0u32) >= i);
}
block_postings.shallow_seek(100_000);
block_postings.load_block();
assert_eq!(block_postings.doc(COMPRESSION_BLOCK_SIZE - 1), TERMINATED);
}
#[test]
fn test_reset_block_segment_postings() {
let mut schema_builder = Schema::builder();
let int_field = schema_builder.add_u64_field("id", INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
// create two postings list, one containg even number,
// the other containing odd numbers.
for i in 0..6 {
let doc = doc!(int_field=> (i % 2) as u64);
index_writer.add_document(doc);
}
index_writer.commit().unwrap();
let searcher = index.reader().unwrap().searcher();
let segment_reader = searcher.segment_reader(0);
let mut block_segments;
{
let term = Term::from_field_u64(int_field, 0u64);
let inverted_index = segment_reader.inverted_index(int_field);
let term_info = inverted_index.get_term_info(&term).unwrap();
block_segments = inverted_index
.read_block_postings_from_terminfo(&term_info, IndexRecordOption::Basic);
}
assert_eq!(block_segments.docs(), &[0, 2, 4]);
{
let term = Term::from_field_u64(int_field, 1u64);
let inverted_index = segment_reader.inverted_index(int_field);
let term_info = inverted_index.get_term_info(&term).unwrap();
inverted_index.reset_block_postings_from_terminfo(&term_info, &mut block_segments);
}
assert_eq!(block_segments.docs(), &[1, 3, 5]);
}
}

View File

@@ -1,4 +1,5 @@
use crate::common::FixedSize;
use crate::docset::TERMINATED;
use bitpacking::{BitPacker, BitPacker4x};
pub const COMPRESSION_BLOCK_SIZE: usize = BitPacker4x::BLOCK_LEN;
@@ -17,6 +18,12 @@ pub struct BlockEncoder {
pub output_len: usize,
}
impl Default for BlockEncoder {
fn default() -> Self {
BlockEncoder::new()
}
}
impl BlockEncoder {
pub fn new() -> BlockEncoder {
BlockEncoder {
@@ -54,11 +61,13 @@ pub struct BlockDecoder {
pub output_len: usize,
}
impl BlockDecoder {
pub fn new() -> BlockDecoder {
impl Default for BlockDecoder {
fn default() -> Self {
BlockDecoder::with_val(0u32)
}
}
impl BlockDecoder {
pub fn with_val(val: u32) -> BlockDecoder {
BlockDecoder {
bitpacker: BitPacker4x::new(),
@@ -90,14 +99,18 @@ impl BlockDecoder {
}
#[inline]
pub(crate) fn output_aligned(&self) -> (&AlignedBuffer, usize) {
(&self.output, self.output_len)
pub(crate) fn output_aligned(&self) -> &AlignedBuffer {
&self.output
}
#[inline]
pub fn output(&self, idx: usize) -> u32 {
self.output.0[idx]
}
pub fn fill(&mut self, val: u32) {
self.output.0.iter_mut().for_each(|el| *el = val);
}
}
pub trait VIntEncoder {
@@ -134,9 +147,9 @@ pub trait VIntDecoder {
/// For instance, if delta encoded are `1, 3, 9`, and the
/// `offset` is 5, then the output will be:
/// `5 + 1 = 6, 6 + 3= 9, 9 + 9 = 18`
fn uncompress_vint_sorted<'a>(
fn uncompress_vint_sorted(
&mut self,
compressed_data: &'a [u8],
compressed_data: &[u8],
offset: u32,
num_els: usize,
) -> usize;
@@ -146,7 +159,7 @@ pub trait VIntDecoder {
///
/// The method takes a number of int to decompress, and returns
/// the amount of bytes that were read to decompress them.
fn uncompress_vint_unsorted<'a>(&mut self, compressed_data: &'a [u8], num_els: usize) -> usize;
fn uncompress_vint_unsorted(&mut self, compressed_data: &[u8], num_els: usize) -> usize;
}
impl VIntEncoder for BlockEncoder {
@@ -160,9 +173,9 @@ impl VIntEncoder for BlockEncoder {
}
impl VIntDecoder for BlockDecoder {
fn uncompress_vint_sorted<'a>(
fn uncompress_vint_sorted(
&mut self,
compressed_data: &'a [u8],
compressed_data: &[u8],
offset: u32,
num_els: usize,
) -> usize {
@@ -170,7 +183,7 @@ impl VIntDecoder for BlockDecoder {
vint::uncompress_sorted(compressed_data, &mut self.output.0[..num_els], offset)
}
fn uncompress_vint_unsorted<'a>(&mut self, compressed_data: &'a [u8], num_els: usize) -> usize {
fn uncompress_vint_unsorted(&mut self, compressed_data: &[u8], num_els: usize) -> usize {
self.output_len = num_els;
vint::uncompress_unsorted(compressed_data, &mut self.output.0[..num_els])
}
@@ -186,7 +199,7 @@ pub mod tests {
let vals: Vec<u32> = (0u32..128u32).map(|i| i * 7).collect();
let mut encoder = BlockEncoder::new();
let (num_bits, compressed_data) = encoder.compress_block_sorted(&vals, 0);
let mut decoder = BlockDecoder::new();
let mut decoder = BlockDecoder::default();
{
let consumed_num_bytes = decoder.uncompress_block_sorted(compressed_data, 0, num_bits);
assert_eq!(consumed_num_bytes, compressed_data.len());
@@ -199,9 +212,9 @@ pub mod tests {
#[test]
fn test_encode_sorted_block_with_offset() {
let vals: Vec<u32> = (0u32..128u32).map(|i| 11 + i * 7).collect();
let mut encoder = BlockEncoder::new();
let mut encoder = BlockEncoder::default();
let (num_bits, compressed_data) = encoder.compress_block_sorted(&vals, 10);
let mut decoder = BlockDecoder::new();
let mut decoder = BlockDecoder::default();
{
let consumed_num_bytes = decoder.uncompress_block_sorted(compressed_data, 10, num_bits);
assert_eq!(consumed_num_bytes, compressed_data.len());
@@ -216,11 +229,11 @@ pub mod tests {
let mut compressed: Vec<u8> = Vec::new();
let n = 128;
let vals: Vec<u32> = (0..n).map(|i| 11u32 + (i as u32) * 7u32).collect();
let mut encoder = BlockEncoder::new();
let mut encoder = BlockEncoder::default();
let (num_bits, compressed_data) = encoder.compress_block_sorted(&vals, 10);
compressed.extend_from_slice(compressed_data);
compressed.push(173u8);
let mut decoder = BlockDecoder::new();
let mut decoder = BlockDecoder::default();
{
let consumed_num_bytes = decoder.uncompress_block_sorted(&compressed, 10, num_bits);
assert_eq!(consumed_num_bytes, compressed.len() - 1);
@@ -236,11 +249,11 @@ pub mod tests {
let mut compressed: Vec<u8> = Vec::new();
let n = 128;
let vals: Vec<u32> = (0..n).map(|i| 11u32 + (i as u32) * 7u32 % 12).collect();
let mut encoder = BlockEncoder::new();
let mut encoder = BlockEncoder::default();
let (num_bits, compressed_data) = encoder.compress_block_unsorted(&vals);
compressed.extend_from_slice(compressed_data);
compressed.push(173u8);
let mut decoder = BlockDecoder::new();
let mut decoder = BlockDecoder::default();
{
let consumed_num_bytes = decoder.uncompress_block_unsorted(&compressed, num_bits);
assert_eq!(consumed_num_bytes + 1, compressed.len());
@@ -251,6 +264,11 @@ pub mod tests {
}
}
#[test]
fn test_block_decoder_initialization() {
let block = BlockDecoder::with_val(TERMINATED);
assert_eq!(block.output(0), TERMINATED);
}
#[test]
fn test_encode_vint() {
{
@@ -260,7 +278,7 @@ pub mod tests {
for offset in &[0u32, 1u32, 2u32] {
let encoded_data = encoder.compress_vint_sorted(&input, *offset);
assert!(encoded_data.len() <= expected_length);
let mut decoder = BlockDecoder::new();
let mut decoder = BlockDecoder::default();
let consumed_num_bytes =
decoder.uncompress_vint_sorted(&encoded_data, *offset, input.len());
assert_eq!(consumed_num_bytes, encoded_data.len());

View File

@@ -42,7 +42,7 @@ pub(crate) fn compress_unsorted<'a>(input: &[u32], output: &'a mut [u8]) -> &'a
}
#[inline(always)]
pub fn uncompress_sorted<'a>(compressed_data: &'a [u8], output: &mut [u32], offset: u32) -> usize {
pub fn uncompress_sorted(compressed_data: &[u8], output: &mut [u32], offset: u32) -> usize {
let mut read_byte = 0;
let mut result = offset;
for output_mut in output.iter_mut() {

View File

@@ -3,11 +3,8 @@ Postings module (also called inverted index)
*/
mod block_search;
mod block_segment_postings;
pub(crate) mod compression;
/// Postings module
///
/// Postings, also called inverted lists, is the key datastructure
/// to full-text search.
mod postings;
mod postings_writer;
mod recorder;
@@ -22,18 +19,17 @@ pub(crate) use self::block_search::BlockSearcher;
pub(crate) use self::postings_writer::MultiFieldPostingsWriter;
pub use self::serializer::{FieldSerializer, InvertedIndexSerializer};
use self::compression::COMPRESSION_BLOCK_SIZE;
pub use self::postings::Postings;
pub(crate) use self::skip::SkipReader;
pub(crate) use self::skip::{BlockInfo, SkipReader};
pub use self::term_info::TermInfo;
pub use self::segment_postings::{BlockSegmentPostings, SegmentPostings};
pub use self::block_segment_postings::BlockSegmentPostings;
pub use self::segment_postings::SegmentPostings;
pub(crate) use self::stacker::compute_table_size;
pub use crate::common::HasLen;
pub(crate) const USE_SKIP_INFO_LIMIT: u32 = COMPRESSION_BLOCK_SIZE as u32;
pub(crate) type UnorderedTermId = u64;
#[cfg_attr(feature = "cargo-clippy", allow(clippy::enum_variant_names))]
@@ -51,7 +47,7 @@ pub mod tests {
use crate::core::Index;
use crate::core::SegmentComponent;
use crate::core::SegmentReader;
use crate::docset::{DocSet, SkipResult};
use crate::docset::{DocSet, TERMINATED};
use crate::fieldnorm::FieldNormReader;
use crate::indexer::operation::AddOperation;
use crate::indexer::SegmentWriter;
@@ -77,8 +73,10 @@ pub mod tests {
let mut segment = index.new_segment();
let mut posting_serializer = InvertedIndexSerializer::open(&mut segment).unwrap();
{
let mut field_serializer = posting_serializer.new_field(text_field, 120 * 4).unwrap();
field_serializer.new_term("abc".as_bytes()).unwrap();
let mut field_serializer = posting_serializer
.new_field(text_field, 120 * 4, None)
.unwrap();
field_serializer.new_term("abc".as_bytes(), 12u32).unwrap();
for doc_id in 0u32..120u32 {
let delta_positions = vec![1, 2, 3, 2];
field_serializer
@@ -115,29 +113,12 @@ pub mod tests {
let mut postings = inverted_index
.read_postings(&term, IndexRecordOption::WithFreqsAndPositions)
.unwrap();
postings.advance();
assert_eq!(postings.doc(), 0);
postings.positions(&mut positions);
assert_eq!(&[0, 1, 2], &positions[..]);
postings.positions(&mut positions);
assert_eq!(&[0, 1, 2], &positions[..]);
postings.advance();
postings.positions(&mut positions);
assert_eq!(&[0, 5], &positions[..]);
}
{
let mut postings = inverted_index
.read_postings(&term, IndexRecordOption::WithFreqsAndPositions)
.unwrap();
postings.advance();
postings.advance();
postings.positions(&mut positions);
assert_eq!(&[0, 5], &positions[..]);
}
{
let mut postings = inverted_index
.read_postings(&term, IndexRecordOption::WithFreqsAndPositions)
.unwrap();
assert_eq!(postings.skip_next(1), SkipResult::Reached);
assert_eq!(postings.advance(), 1);
assert_eq!(postings.doc(), 1);
postings.positions(&mut positions);
assert_eq!(&[0, 5], &positions[..]);
@@ -146,7 +127,25 @@ pub mod tests {
let mut postings = inverted_index
.read_postings(&term, IndexRecordOption::WithFreqsAndPositions)
.unwrap();
assert_eq!(postings.skip_next(1002), SkipResult::Reached);
assert_eq!(postings.doc(), 0);
assert_eq!(postings.advance(), 1);
postings.positions(&mut positions);
assert_eq!(&[0, 5], &positions[..]);
}
{
let mut postings = inverted_index
.read_postings(&term, IndexRecordOption::WithFreqsAndPositions)
.unwrap();
assert_eq!(postings.seek(1), 1);
assert_eq!(postings.doc(), 1);
postings.positions(&mut positions);
assert_eq!(&[0, 5], &positions[..]);
}
{
let mut postings = inverted_index
.read_postings(&term, IndexRecordOption::WithFreqsAndPositions)
.unwrap();
assert_eq!(postings.seek(1002), 1002);
assert_eq!(postings.doc(), 1002);
postings.positions(&mut positions);
assert_eq!(&[0, 5], &positions[..]);
@@ -155,8 +154,8 @@ pub mod tests {
let mut postings = inverted_index
.read_postings(&term, IndexRecordOption::WithFreqsAndPositions)
.unwrap();
assert_eq!(postings.skip_next(100), SkipResult::Reached);
assert_eq!(postings.skip_next(1002), SkipResult::Reached);
assert_eq!(postings.seek(100), 100);
assert_eq!(postings.seek(1002), 1002);
assert_eq!(postings.doc(), 1002);
postings.positions(&mut positions);
assert_eq!(&[0, 5], &positions[..]);
@@ -281,22 +280,21 @@ pub mod tests {
.read_postings(&term_a, IndexRecordOption::WithFreqsAndPositions)
.unwrap();
assert_eq!(postings_a.len(), 1000);
assert!(postings_a.advance());
assert_eq!(postings_a.doc(), 0);
assert_eq!(postings_a.term_freq(), 6);
postings_a.positions(&mut positions);
assert_eq!(&positions[..], [0, 2, 4, 6, 7, 13]);
assert!(postings_a.advance());
assert_eq!(postings_a.advance(), 1u32);
assert_eq!(postings_a.doc(), 1u32);
assert_eq!(postings_a.term_freq(), 1);
for i in 2u32..1000u32 {
assert!(postings_a.advance());
assert_eq!(postings_a.advance(), i);
assert_eq!(postings_a.term_freq(), 1);
postings_a.positions(&mut positions);
assert_eq!(&positions[..], [i]);
assert_eq!(postings_a.doc(), i);
}
assert!(!postings_a.advance());
assert_eq!(postings_a.advance(), TERMINATED);
}
{
let term_e = Term::from_field_text(text_field, "e");
@@ -306,7 +304,6 @@ pub mod tests {
.unwrap();
assert_eq!(postings_e.len(), 1000 - 2);
for i in 2u32..1000u32 {
assert!(postings_e.advance());
assert_eq!(postings_e.term_freq(), i);
postings_e.positions(&mut positions);
assert_eq!(positions.len(), i as usize);
@@ -314,8 +311,9 @@ pub mod tests {
assert_eq!(positions[j], (j as u32));
}
assert_eq!(postings_e.doc(), i);
postings_e.advance();
}
assert!(!postings_e.advance());
assert_eq!(postings_e.doc(), TERMINATED);
}
}
}
@@ -329,16 +327,8 @@ pub mod tests {
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
{
let mut doc = Document::default();
doc.add_text(text_field, "g b b d c g c");
index_writer.add_document(doc);
}
{
let mut doc = Document::default();
doc.add_text(text_field, "g a b b a d c g c");
index_writer.add_document(doc);
}
index_writer.add_document(doc!(text_field => "g b b d c g c"));
index_writer.add_document(doc!(text_field => "g a b b a d c g c"));
assert!(index_writer.commit().is_ok());
}
let term_a = Term::from_field_text(text_field, "a");
@@ -348,7 +338,6 @@ pub mod tests {
.inverted_index(text_field)
.read_postings(&term_a, IndexRecordOption::WithFreqsAndPositions)
.unwrap();
assert!(postings.advance());
assert_eq!(postings.doc(), 1u32);
postings.positions(&mut positions);
assert_eq!(&positions[..], &[1u32, 4]);
@@ -370,11 +359,8 @@ pub mod tests {
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
for i in 0..num_docs {
let mut doc = Document::default();
doc.add_u64(value_field, 2);
doc.add_u64(value_field, (i % 2) as u64);
for i in 0u64..num_docs as u64 {
let doc = doc!(value_field => 2u64, value_field => i % 2u64);
index_writer.add_document(doc);
}
assert!(index_writer.commit().is_ok());
@@ -391,11 +377,10 @@ pub mod tests {
.inverted_index(term_2.field())
.read_postings(&term_2, IndexRecordOption::Basic)
.unwrap();
assert_eq!(segment_postings.skip_next(i), SkipResult::Reached);
assert_eq!(segment_postings.seek(i), i);
assert_eq!(segment_postings.doc(), i);
assert_eq!(segment_postings.skip_next(j), SkipResult::Reached);
assert_eq!(segment_postings.seek(j), j);
assert_eq!(segment_postings.doc(), j);
}
}
@@ -407,17 +392,16 @@ pub mod tests {
.unwrap();
// check that `skip_next` advances the iterator
assert!(segment_postings.advance());
assert_eq!(segment_postings.doc(), 0);
assert_eq!(segment_postings.skip_next(1), SkipResult::Reached);
assert_eq!(segment_postings.seek(1), 1);
assert_eq!(segment_postings.doc(), 1);
assert_eq!(segment_postings.skip_next(1), SkipResult::OverStep);
assert_eq!(segment_postings.doc(), 2);
assert_eq!(segment_postings.seek(1), 1);
assert_eq!(segment_postings.doc(), 1);
// check that going beyond the end is handled
assert_eq!(segment_postings.skip_next(num_docs), SkipResult::End);
assert_eq!(segment_postings.seek(num_docs), TERMINATED);
}
// check that filtering works
@@ -428,7 +412,7 @@ pub mod tests {
.unwrap();
for i in 0..num_docs / 2 {
assert_eq!(segment_postings.skip_next(i * 2), SkipResult::Reached);
assert_eq!(segment_postings.seek(i * 2), i * 2);
assert_eq!(segment_postings.doc(), i * 2);
}
@@ -438,7 +422,7 @@ pub mod tests {
.unwrap();
for i in 0..num_docs / 2 - 1 {
assert_eq!(segment_postings.skip_next(i * 2 + 1), SkipResult::OverStep);
assert!(segment_postings.seek(i * 2 + 1) > (i * 1) * 2);
assert_eq!(segment_postings.doc(), (i + 1) * 2);
}
}
@@ -450,6 +434,7 @@ pub mod tests {
assert!(index_writer.commit().is_ok());
}
let searcher = index.reader().unwrap().searcher();
assert_eq!(searcher.segment_readers().len(), 1);
let segment_reader = searcher.segment_reader(0);
// make sure seeking still works
@@ -460,11 +445,11 @@ pub mod tests {
.unwrap();
if i % 2 == 0 {
assert_eq!(segment_postings.skip_next(i), SkipResult::Reached);
assert_eq!(segment_postings.seek(i), i);
assert_eq!(segment_postings.doc(), i);
assert!(segment_reader.is_deleted(i));
} else {
assert_eq!(segment_postings.skip_next(i), SkipResult::Reached);
assert_eq!(segment_postings.seek(i), i);
assert_eq!(segment_postings.doc(), i);
}
}
@@ -479,12 +464,16 @@ pub mod tests {
let mut last = 2; // start from 5 to avoid seeking to 3 twice
let mut cur = 3;
loop {
match segment_postings.skip_next(cur) {
SkipResult::End => break,
SkipResult::Reached => assert_eq!(segment_postings.doc(), cur),
SkipResult::OverStep => assert_eq!(segment_postings.doc(), cur + 1),
let seek = segment_postings.seek(cur);
if seek == TERMINATED {
break;
}
assert_eq!(seek, segment_postings.doc());
if seek == cur {
assert_eq!(segment_postings.doc(), cur);
} else {
assert_eq!(segment_postings.doc(), cur + 1);
}
let next = cur + last;
last = cur;
cur = next;
@@ -570,7 +559,7 @@ pub mod tests {
}
impl<TDocSet: DocSet> DocSet for UnoptimizedDocSet<TDocSet> {
fn advance(&mut self) -> bool {
fn advance(&mut self) -> DocId {
self.0.advance()
}
@@ -595,31 +584,26 @@ pub mod tests {
) {
for target in targets {
let mut postings_opt = postings_factory();
if target < postings_opt.doc() {
continue;
}
let mut postings_unopt = UnoptimizedDocSet::wrap(postings_factory());
let skip_result_opt = postings_opt.skip_next(target);
let skip_result_unopt = postings_unopt.skip_next(target);
let skip_result_opt = postings_opt.seek(target);
let skip_result_unopt = postings_unopt.seek(target);
assert_eq!(
skip_result_unopt, skip_result_opt,
"Failed while skipping to {}",
target
);
match skip_result_opt {
SkipResult::Reached => assert_eq!(postings_opt.doc(), target),
SkipResult::OverStep => assert!(postings_opt.doc() > target),
SkipResult::End => {
return;
}
assert!(skip_result_opt >= target);
assert_eq!(skip_result_opt, postings_opt.doc());
if skip_result_opt == TERMINATED {
return;
}
while postings_opt.advance() {
assert!(postings_unopt.advance());
assert_eq!(
postings_opt.doc(),
postings_unopt.doc(),
"Failed while skipping to {}",
target
);
while postings_opt.doc() != TERMINATED {
assert_eq!(postings_opt.doc(), postings_unopt.doc());
assert_eq!(postings_opt.advance(), postings_unopt.advance());
}
assert!(!postings_unopt.advance());
}
}
}
@@ -628,7 +612,7 @@ pub mod tests {
mod bench {
use super::tests::*;
use crate::docset::SkipResult;
use crate::docset::TERMINATED;
use crate::query::Intersection;
use crate::schema::IndexRecordOption;
use crate::tests;
@@ -646,7 +630,7 @@ mod bench {
.inverted_index(TERM_A.field())
.read_postings(&*TERM_A, IndexRecordOption::Basic)
.unwrap();
while segment_postings.advance() {}
while segment_postings.advance() != TERMINATED {}
});
}
@@ -678,7 +662,7 @@ mod bench {
segment_postings_c,
segment_postings_d,
]);
while intersection.advance() {}
while intersection.advance() != TERMINATED {}
});
}
@@ -694,11 +678,10 @@ mod bench {
.unwrap();
let mut existing_docs = Vec::new();
segment_postings.advance();
for doc in &docs {
if *doc >= segment_postings.doc() {
existing_docs.push(*doc);
if segment_postings.skip_next(*doc) == SkipResult::End {
if segment_postings.seek(*doc) == TERMINATED {
break;
}
}
@@ -710,7 +693,7 @@ mod bench {
.read_postings(&*TERM_A, IndexRecordOption::Basic)
.unwrap();
for doc in &existing_docs {
if segment_postings.skip_next(*doc) == SkipResult::End {
if segment_postings.seek(*doc) == TERMINATED {
break;
}
}
@@ -749,8 +732,9 @@ mod bench {
.read_postings(&*TERM_A, IndexRecordOption::Basic)
.unwrap();
let mut s = 0u32;
while segment_postings.advance() {
while segment_postings.doc() != TERMINATED {
s += (segment_postings.doc() & n) % 1024;
segment_postings.advance()
}
s
});

View File

@@ -1,5 +1,6 @@
use super::stacker::{Addr, MemoryArena, TermHashMap};
use crate::fieldnorm::FieldNormReaders;
use crate::postings::recorder::{
BufferLender, NothingRecorder, Recorder, TFAndPositionRecorder, TermFrequencyRecorder,
};
@@ -128,6 +129,7 @@ impl MultiFieldPostingsWriter {
pub fn serialize(
&self,
serializer: &mut InvertedIndexSerializer,
fieldnorm_readers: FieldNormReaders,
) -> crate::Result<HashMap<Field, FnvHashMap<UnorderedTermId, TermOrdinal>>> {
let mut term_offsets: Vec<(&[u8], Addr, UnorderedTermId)> =
self.term_index.iter().collect();
@@ -161,8 +163,12 @@ impl MultiFieldPostingsWriter {
}
let postings_writer = &self.per_field_postings_writers[field.field_id() as usize];
let mut field_serializer =
serializer.new_field(field, postings_writer.total_num_tokens())?;
let fieldnorm_reader = fieldnorm_readers.get_field(field);
let mut field_serializer = serializer.new_field(
field,
postings_writer.total_num_tokens(),
fieldnorm_reader,
)?;
postings_writer.serialize(
&term_offsets[start..stop],
&mut field_serializer,
@@ -297,7 +303,8 @@ impl<Rec: Recorder + 'static> PostingsWriter for SpecializedPostingsWriter<Rec>
let mut buffer_lender = BufferLender::default();
for &(term_bytes, addr, _) in term_addrs {
let recorder: Rec = termdict_heap.read(addr);
serializer.new_term(&term_bytes[4..])?;
let term_doc_freq = recorder.term_doc_freq().unwrap_or(0u32);
serializer.new_term(&term_bytes[4..], term_doc_freq)?;
recorder.serialize(&mut buffer_lender, serializer, heap)?;
serializer.close_term()?;
}

View File

@@ -75,6 +75,8 @@ pub(crate) trait Recorder: Copy + 'static {
serializer: &mut FieldSerializer<'_>,
heap: &MemoryArena,
) -> io::Result<()>;
/// Returns the number of document containg this term.
fn term_doc_freq(&self) -> Option<u32>;
}
/// Only records the doc ids
@@ -113,11 +115,16 @@ impl Recorder for NothingRecorder {
) -> io::Result<()> {
let buffer = buffer_lender.lend_u8();
self.stack.read_to_end(heap, buffer);
// TODO avoid reading twice.
for doc in VInt32Reader::new(&buffer[..]) {
serializer.write_doc(doc as u32, 0u32, &[][..])?;
}
Ok(())
}
fn term_doc_freq(&self) -> Option<u32> {
None
}
}
/// Recorder encoding document ids, and term frequencies
@@ -126,6 +133,7 @@ pub struct TermFrequencyRecorder {
stack: ExpUnrolledLinkedList,
current_doc: DocId,
current_tf: u32,
term_doc_freq: u32,
}
impl Recorder for TermFrequencyRecorder {
@@ -134,6 +142,7 @@ impl Recorder for TermFrequencyRecorder {
stack: ExpUnrolledLinkedList::new(),
current_doc: u32::max_value(),
current_tf: 0u32,
term_doc_freq: 0u32,
}
}
@@ -142,6 +151,7 @@ impl Recorder for TermFrequencyRecorder {
}
fn new_doc(&mut self, doc: DocId, heap: &mut MemoryArena) {
self.term_doc_freq += 1;
self.current_doc = doc;
let _ = write_u32_vint(doc, &mut self.stack.writer(heap));
}
@@ -172,6 +182,10 @@ impl Recorder for TermFrequencyRecorder {
Ok(())
}
fn term_doc_freq(&self) -> Option<u32> {
Some(self.term_doc_freq)
}
}
/// Recorder encoding term frequencies as well as positions.
@@ -179,12 +193,14 @@ impl Recorder for TermFrequencyRecorder {
pub struct TFAndPositionRecorder {
stack: ExpUnrolledLinkedList,
current_doc: DocId,
term_doc_freq: u32,
}
impl Recorder for TFAndPositionRecorder {
fn new() -> Self {
TFAndPositionRecorder {
stack: ExpUnrolledLinkedList::new(),
current_doc: u32::max_value(),
term_doc_freq: 0u32,
}
}
@@ -194,6 +210,7 @@ impl Recorder for TFAndPositionRecorder {
fn new_doc(&mut self, doc: DocId, heap: &mut MemoryArena) {
self.current_doc = doc;
self.term_doc_freq += 1u32;
let _ = write_u32_vint(doc, &mut self.stack.writer(heap));
}
@@ -233,6 +250,10 @@ impl Recorder for TFAndPositionRecorder {
}
Ok(())
}
fn term_doc_freq(&self) -> Option<u32> {
Some(self.term_doc_freq)
}
}
#[cfg(test)]

View File

@@ -1,56 +1,20 @@
use crate::common::BitSet;
use crate::common::HasLen;
use crate::common::{BinarySerializable, VInt};
use crate::docset::{DocSet, SkipResult};
use crate::docset::DocSet;
use crate::positions::PositionReader;
use crate::postings::compression::{compressed_block_size, AlignedBuffer};
use crate::postings::compression::{BlockDecoder, VIntDecoder, COMPRESSION_BLOCK_SIZE};
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::postings::serializer::PostingsSerializer;
use crate::postings::BlockSearcher;
use crate::postings::FreqReadingOption;
use crate::postings::Postings;
use crate::postings::SkipReader;
use crate::postings::USE_SKIP_INFO_LIMIT;
use crate::schema::IndexRecordOption;
use crate::DocId;
use owned_read::OwnedRead;
use std::cmp::Ordering;
use tantivy_fst::Streamer;
struct PositionComputer {
// store the amount of position int
// before reading positions.
//
// if none, position are already loaded in
// the positions vec.
position_to_skip: usize,
position_reader: PositionReader,
}
impl PositionComputer {
pub fn new(position_reader: PositionReader) -> PositionComputer {
PositionComputer {
position_to_skip: 0,
position_reader,
}
}
pub fn add_skip(&mut self, num_skip: usize) {
self.position_to_skip += num_skip;
}
// Positions can only be read once.
pub fn positions_with_offset(&mut self, offset: u32, output: &mut [u32]) {
self.position_reader.skip(self.position_to_skip);
self.position_to_skip = 0;
self.position_reader.read(output);
let mut cum = offset;
for output_mut in output.iter_mut() {
cum += *output_mut;
*output_mut = cum;
}
}
}
use crate::directory::ReadOnlySource;
use crate::fieldnorm::FieldNormReader;
use crate::postings::BlockSegmentPostings;
/// `SegmentPostings` represents the inverted list or postings associated to
/// a term in a `Segment`.
@@ -58,24 +22,29 @@ impl PositionComputer {
/// As we iterate through the `SegmentPostings`, the frequencies are optionally decoded.
/// Positions on the other hand, are optionally entirely decoded upfront.
pub struct SegmentPostings {
block_cursor: BlockSegmentPostings,
pub(crate) block_cursor: BlockSegmentPostings,
cur: usize,
position_computer: Option<PositionComputer>,
position_reader: Option<PositionReader>,
block_searcher: BlockSearcher,
}
impl SegmentPostings {
/// Returns an empty segment postings object
pub fn empty() -> Self {
let empty_block_cursor = BlockSegmentPostings::empty();
SegmentPostings {
block_cursor: empty_block_cursor,
cur: COMPRESSION_BLOCK_SIZE,
position_computer: None,
block_cursor: BlockSegmentPostings::empty(),
cur: 0,
position_reader: None,
block_searcher: BlockSearcher::default(),
}
}
/// Returns the overall number of documents in the block postings.
/// It does not take in account whether documents are deleted or not.
pub fn doc_freq(&self) -> u32 {
self.block_cursor.doc_freq()
}
/// Creates a segment postings object with the given documents
/// and no frequency encoded.
///
@@ -87,7 +56,8 @@ impl SegmentPostings {
pub fn create_from_docs(docs: &[u32]) -> SegmentPostings {
let mut buffer = Vec::new();
{
let mut postings_serializer = PostingsSerializer::new(&mut buffer, false, false);
let mut postings_serializer = PostingsSerializer::new(&mut buffer, false, false, None);
postings_serializer.new_term(docs.len() as u32);
for &doc in docs {
postings_serializer.write_doc(doc, 1u32);
}
@@ -97,15 +67,36 @@ impl SegmentPostings {
}
let block_segment_postings = BlockSegmentPostings::from_data(
docs.len() as u32,
OwnedRead::new(buffer),
ReadOnlySource::from(buffer),
IndexRecordOption::Basic,
IndexRecordOption::Basic,
);
SegmentPostings::from_block_postings(block_segment_postings, None)
}
}
impl SegmentPostings {
/// Helper functions to create `SegmentPostings` for tests.
pub fn create_from_docs_and_tfs(
doc_and_tfs: &[(u32, u32)],
fieldnorm_reader: Option<FieldNormReader>,
) -> crate::Result<SegmentPostings> {
let mut buffer = Vec::new();
let mut postings_serializer =
PostingsSerializer::new(&mut buffer, true, false, fieldnorm_reader);
postings_serializer.new_term(doc_and_tfs.len() as u32);
for &(doc, tf) in doc_and_tfs {
postings_serializer.write_doc(doc, tf);
}
postings_serializer
.close_term(doc_and_tfs.len() as u32)?;
let block_segment_postings = BlockSegmentPostings::from_data(
doc_and_tfs.len() as u32,
ReadOnlySource::from(buffer),
IndexRecordOption::WithFreqs,
IndexRecordOption::WithFreqs,
);
Ok(SegmentPostings::from_block_postings(block_segment_postings, None))
}
/// Reads a Segment postings from an &[u8]
///
/// * `len` - number of document in the posting lists.
@@ -114,12 +105,12 @@ impl SegmentPostings {
/// frequencies and/or positions
pub(crate) fn from_block_postings(
segment_block_postings: BlockSegmentPostings,
positions_stream_opt: Option<PositionReader>,
position_reader: Option<PositionReader>,
) -> SegmentPostings {
SegmentPostings {
block_cursor: segment_block_postings,
cur: COMPRESSION_BLOCK_SIZE, // cursor within the block
position_computer: positions_stream_opt.map(PositionComputer::new),
cur: 0, // cursor within the block
position_reader,
block_searcher: BlockSearcher::default(),
}
}
@@ -129,139 +120,60 @@ impl DocSet for SegmentPostings {
// goes to the next element.
// next needs to be called a first time to point to the correct element.
#[inline]
fn advance(&mut self) -> bool {
if self.position_computer.is_some() && self.cur < COMPRESSION_BLOCK_SIZE {
let term_freq = self.term_freq() as usize;
if let Some(position_computer) = self.position_computer.as_mut() {
position_computer.add_skip(term_freq);
}
}
self.cur += 1;
if self.cur >= self.block_cursor.block_len() {
fn advance(&mut self) -> DocId {
assert!(self.block_cursor.block_is_loaded());
if self.cur == COMPRESSION_BLOCK_SIZE - 1 {
self.cur = 0;
if !self.block_cursor.advance() {
self.cur = COMPRESSION_BLOCK_SIZE;
return false;
}
self.block_cursor.advance();
} else {
self.cur += 1;
}
true
self.doc()
}
fn skip_next(&mut self, target: DocId) -> SkipResult {
if !self.advance() {
return SkipResult::End;
}
match self.doc().cmp(&target) {
Ordering::Equal => {
return SkipResult::Reached;
}
Ordering::Greater => {
return SkipResult::OverStep;
}
_ => {
// ...
}
fn seek(&mut self, target: DocId) -> DocId {
debug_assert!(self.doc() <= target);
if self.doc() >= target {
return self.doc();
}
// In the following, thanks to the call to advance above,
// we know that the position is not loaded and we need
// to skip every doc_freq we cross.
self.block_cursor.seek(target);
// skip blocks until one that might contain the target
// check if we need to go to the next block
let mut sum_freqs_skipped: u32 = 0;
if !self
.block_cursor
.docs()
.last()
.map(|doc| *doc >= target)
.unwrap_or(false)
// there should always be at least a document in the block
// since advance returned.
{
// we are not in the right block.
//
// First compute all of the freqs skipped from the current block.
if self.position_computer.is_some() {
sum_freqs_skipped = self.block_cursor.freqs()[self.cur..].iter().sum();
match self.block_cursor.skip_to(target) {
BlockSegmentPostingsSkipResult::Success(block_skip_freqs) => {
sum_freqs_skipped += block_skip_freqs;
}
BlockSegmentPostingsSkipResult::Terminated => {
return SkipResult::End;
}
}
} else if self.block_cursor.skip_to(target)
== BlockSegmentPostingsSkipResult::Terminated
{
// no positions needed. no need to sum freqs.
return SkipResult::End;
}
self.cur = 0;
}
// At this point we are on the block, that might contain our document.
let output = self.block_cursor.docs_aligned();
self.cur = self.block_searcher.search_in_block(&output, target);
let cur = self.cur;
// we're in the right block now, start with an exponential search
let (output, len) = self.block_cursor.docs_aligned();
let new_cur = self
.block_searcher
.search_in_block(&output, len, cur, target);
if let Some(position_computer) = self.position_computer.as_mut() {
sum_freqs_skipped += self.block_cursor.freqs()[cur..new_cur].iter().sum::<u32>();
position_computer.add_skip(sum_freqs_skipped as usize);
}
self.cur = new_cur;
// The last block is not full and padded with the value TERMINATED,
// so that we are guaranteed to have at least doc in the block (a real one or the padding)
// that is greater or equal to the target.
debug_assert!(self.cur < COMPRESSION_BLOCK_SIZE);
// `doc` is now the first element >= `target`
let doc = output.0[new_cur];
// If all docs are smaller than target the current block should be incomplemented and padded
// with the value `TERMINATED`.
//
// After the search, the cursor should point to the first value of TERMINATED.
let doc = output.0[self.cur];
debug_assert!(doc >= target);
if doc == target {
SkipResult::Reached
} else {
SkipResult::OverStep
}
debug_assert_eq!(doc, self.doc());
doc
}
/// Return the current document's `DocId`.
///
/// # Panics
///
/// Will panics if called without having called advance before.
#[inline]
fn doc(&self) -> DocId {
let docs = self.block_cursor.docs();
debug_assert!(
self.cur < docs.len(),
"Have you forgotten to call `.advance()` at least once before calling `.doc()` ."
);
docs[self.cur]
self.block_cursor.doc(self.cur)
}
fn size_hint(&self) -> u32 {
self.len() as u32
}
fn append_to_bitset(&mut self, bitset: &mut BitSet) {
// finish the current block
if self.advance() {
for &doc in &self.block_cursor.docs()[self.cur..] {
bitset.insert(doc);
}
// ... iterate through the remaining blocks.
while self.block_cursor.advance() {
for &doc in self.block_cursor.docs() {
bitset.insert(doc);
}
}
}
}
}
impl HasLen for SegmentPostings {
fn len(&self) -> usize {
self.block_cursor.doc_freq()
self.block_cursor.doc_freq() as usize
}
}
@@ -290,515 +202,52 @@ impl Postings for SegmentPostings {
fn positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
let term_freq = self.term_freq() as usize;
if let Some(position_comp) = self.position_computer.as_mut() {
if let Some(position_reader) = self.position_reader.as_mut() {
let read_offset = self.block_cursor.position_offset()
+ (self.block_cursor.freqs()[..self.cur]
.iter()
.cloned()
.sum::<u32>() as u64);
output.resize(term_freq, 0u32);
position_comp.positions_with_offset(offset, &mut output[..]);
position_reader.read(read_offset, &mut output[..]);
let mut cum = offset;
for output_mut in output.iter_mut() {
cum += *output_mut;
*output_mut = cum;
}
} else {
output.clear();
}
}
}
/// `BlockSegmentPostings` is a cursor iterating over blocks
/// of documents.
///
/// # Warning
///
/// While it is useful for some very specific high-performance
/// use cases, you should prefer using `SegmentPostings` for most usage.
pub struct BlockSegmentPostings {
doc_decoder: BlockDecoder,
freq_decoder: BlockDecoder,
freq_reading_option: FreqReadingOption,
doc_freq: usize,
doc_offset: DocId,
num_vint_docs: usize,
remaining_data: OwnedRead,
skip_reader: SkipReader,
}
fn split_into_skips_and_postings(
doc_freq: u32,
mut data: OwnedRead,
) -> (Option<OwnedRead>, OwnedRead) {
if doc_freq >= USE_SKIP_INFO_LIMIT {
let skip_len = VInt::deserialize(&mut data).expect("Data corrupted").0 as usize;
let mut postings_data = data.clone();
postings_data.advance(skip_len);
data.clip(skip_len);
(Some(data), postings_data)
} else {
(None, data)
}
}
#[derive(Debug, Eq, PartialEq)]
pub enum BlockSegmentPostingsSkipResult {
Terminated,
Success(u32), //< number of term freqs to skip
}
impl BlockSegmentPostings {
pub(crate) fn from_data(
doc_freq: u32,
data: OwnedRead,
record_option: IndexRecordOption,
requested_option: IndexRecordOption,
) -> BlockSegmentPostings {
let freq_reading_option = match (record_option, requested_option) {
(IndexRecordOption::Basic, _) => FreqReadingOption::NoFreq,
(_, IndexRecordOption::Basic) => FreqReadingOption::SkipFreq,
(_, _) => FreqReadingOption::ReadFreq,
};
let (skip_data_opt, postings_data) = split_into_skips_and_postings(doc_freq, data);
let skip_reader = match skip_data_opt {
Some(skip_data) => SkipReader::new(skip_data, record_option),
None => SkipReader::new(OwnedRead::new(&[][..]), record_option),
};
let doc_freq = doc_freq as usize;
let num_vint_docs = doc_freq % COMPRESSION_BLOCK_SIZE;
BlockSegmentPostings {
num_vint_docs,
doc_decoder: BlockDecoder::new(),
freq_decoder: BlockDecoder::with_val(1),
freq_reading_option,
doc_offset: 0,
doc_freq,
remaining_data: postings_data,
skip_reader,
}
}
// Resets the block segment postings on another position
// in the postings file.
//
// This is useful for enumerating through a list of terms,
// and consuming the associated posting lists while avoiding
// reallocating a `BlockSegmentPostings`.
//
// # Warning
//
// This does not reset the positions list.
pub(crate) fn reset(&mut self, doc_freq: u32, postings_data: OwnedRead) {
let (skip_data_opt, postings_data) = split_into_skips_and_postings(doc_freq, postings_data);
let num_vint_docs = (doc_freq as usize) & (COMPRESSION_BLOCK_SIZE - 1);
self.num_vint_docs = num_vint_docs;
self.remaining_data = postings_data;
if let Some(skip_data) = skip_data_opt {
self.skip_reader.reset(skip_data);
} else {
self.skip_reader.reset(OwnedRead::new(&[][..]))
}
self.doc_offset = 0;
self.doc_freq = doc_freq as usize;
}
/// Returns the document frequency associated to this block postings.
///
/// This `doc_freq` is simply the sum of the length of all of the blocks
/// length, and it does not take in account deleted documents.
pub fn doc_freq(&self) -> usize {
self.doc_freq
}
/// Returns the array of docs in the current block.
///
/// Before the first call to `.advance()`, the block
/// returned by `.docs()` is empty.
#[inline]
pub fn docs(&self) -> &[DocId] {
self.doc_decoder.output_array()
}
pub(crate) fn docs_aligned(&self) -> (&AlignedBuffer, usize) {
self.doc_decoder.output_aligned()
}
/// Return the document at index `idx` of the block.
#[inline]
pub fn doc(&self, idx: usize) -> u32 {
self.doc_decoder.output(idx)
}
/// Return the array of `term freq` in the block.
#[inline]
pub fn freqs(&self) -> &[u32] {
self.freq_decoder.output_array()
}
/// Return the frequency at index `idx` of the block.
#[inline]
pub fn freq(&self, idx: usize) -> u32 {
self.freq_decoder.output(idx)
}
/// Returns the length of the current block.
///
/// All blocks have a length of `NUM_DOCS_PER_BLOCK`,
/// except the last block that may have a length
/// of any number between 1 and `NUM_DOCS_PER_BLOCK - 1`
#[inline]
fn block_len(&self) -> usize {
self.doc_decoder.output_len
}
/// position on a block that may contains `doc_id`.
/// Always advance the current block.
///
/// Returns true if a block that has an element greater or equal to the target is found.
/// Returning true does not guarantee that the smallest element of the block is smaller
/// than the target. It only guarantees that the last element is greater or equal.
///
/// Returns false iff all of the document remaining are smaller than
/// `doc_id`. In that case, all of these document are consumed.
///
pub fn skip_to(&mut self, target_doc: DocId) -> BlockSegmentPostingsSkipResult {
let mut skip_freqs = 0u32;
while self.skip_reader.advance() {
if self.skip_reader.doc() >= target_doc {
// the last document of the current block is larger
// than the target.
//
// We found our block!
let num_bits = self.skip_reader.doc_num_bits();
let num_consumed_bytes = self.doc_decoder.uncompress_block_sorted(
self.remaining_data.as_ref(),
self.doc_offset,
num_bits,
);
self.remaining_data.advance(num_consumed_bytes);
let tf_num_bits = self.skip_reader.tf_num_bits();
match self.freq_reading_option {
FreqReadingOption::NoFreq => {}
FreqReadingOption::SkipFreq => {
let num_bytes_to_skip = compressed_block_size(tf_num_bits);
self.remaining_data.advance(num_bytes_to_skip);
}
FreqReadingOption::ReadFreq => {
let num_consumed_bytes = self
.freq_decoder
.uncompress_block_unsorted(self.remaining_data.as_ref(), tf_num_bits);
self.remaining_data.advance(num_consumed_bytes);
}
}
self.doc_offset = self.skip_reader.doc();
return BlockSegmentPostingsSkipResult::Success(skip_freqs);
} else {
skip_freqs += self.skip_reader.tf_sum();
let advance_len = self.skip_reader.total_block_len();
self.doc_offset = self.skip_reader.doc();
self.remaining_data.advance(advance_len);
}
}
// we are now on the last, incomplete, variable encoded block.
if self.num_vint_docs > 0 {
let num_compressed_bytes = self.doc_decoder.uncompress_vint_sorted(
self.remaining_data.as_ref(),
self.doc_offset,
self.num_vint_docs,
);
self.remaining_data.advance(num_compressed_bytes);
match self.freq_reading_option {
FreqReadingOption::NoFreq | FreqReadingOption::SkipFreq => {}
FreqReadingOption::ReadFreq => {
self.freq_decoder
.uncompress_vint_unsorted(self.remaining_data.as_ref(), self.num_vint_docs);
}
}
self.num_vint_docs = 0;
return self
.docs()
.last()
.map(|last_doc| {
if *last_doc >= target_doc {
BlockSegmentPostingsSkipResult::Success(skip_freqs)
} else {
BlockSegmentPostingsSkipResult::Terminated
}
})
.unwrap_or(BlockSegmentPostingsSkipResult::Terminated);
}
BlockSegmentPostingsSkipResult::Terminated
}
/// Advance to the next block.
///
/// Returns false iff there was no remaining blocks.
pub fn advance(&mut self) -> bool {
if self.skip_reader.advance() {
let num_bits = self.skip_reader.doc_num_bits();
let num_consumed_bytes = self.doc_decoder.uncompress_block_sorted(
self.remaining_data.as_ref(),
self.doc_offset,
num_bits,
);
self.remaining_data.advance(num_consumed_bytes);
let tf_num_bits = self.skip_reader.tf_num_bits();
match self.freq_reading_option {
FreqReadingOption::NoFreq => {}
FreqReadingOption::SkipFreq => {
let num_bytes_to_skip = compressed_block_size(tf_num_bits);
self.remaining_data.advance(num_bytes_to_skip);
}
FreqReadingOption::ReadFreq => {
let num_consumed_bytes = self
.freq_decoder
.uncompress_block_unsorted(self.remaining_data.as_ref(), tf_num_bits);
self.remaining_data.advance(num_consumed_bytes);
}
}
// it will be used as the next offset.
self.doc_offset = self.doc_decoder.output(COMPRESSION_BLOCK_SIZE - 1);
true
} else if self.num_vint_docs > 0 {
let num_compressed_bytes = self.doc_decoder.uncompress_vint_sorted(
self.remaining_data.as_ref(),
self.doc_offset,
self.num_vint_docs,
);
self.remaining_data.advance(num_compressed_bytes);
match self.freq_reading_option {
FreqReadingOption::NoFreq | FreqReadingOption::SkipFreq => {}
FreqReadingOption::ReadFreq => {
self.freq_decoder
.uncompress_vint_unsorted(self.remaining_data.as_ref(), self.num_vint_docs);
}
}
self.num_vint_docs = 0;
true
} else {
false
}
}
/// Returns an empty segment postings object
pub fn empty() -> BlockSegmentPostings {
BlockSegmentPostings {
num_vint_docs: 0,
doc_decoder: BlockDecoder::new(),
freq_decoder: BlockDecoder::with_val(1),
freq_reading_option: FreqReadingOption::NoFreq,
doc_offset: 0,
doc_freq: 0,
remaining_data: OwnedRead::new(vec![]),
skip_reader: SkipReader::new(OwnedRead::new(vec![]), IndexRecordOption::Basic),
}
}
}
impl<'b> Streamer<'b> for BlockSegmentPostings {
type Item = &'b [DocId];
fn next(&'b mut self) -> Option<&'b [DocId]> {
if self.advance() {
Some(self.docs())
} else {
None
}
}
}
#[cfg(test)]
mod tests {
use super::BlockSegmentPostings;
use super::BlockSegmentPostingsSkipResult;
use super::SegmentPostings;
use crate::common::HasLen;
use crate::core::Index;
use crate::docset::DocSet;
use crate::docset::{DocSet, TERMINATED};
use crate::postings::postings::Postings;
use crate::schema::IndexRecordOption;
use crate::schema::Schema;
use crate::schema::Term;
use crate::schema::INDEXED;
use crate::DocId;
use crate::SkipResult;
use tantivy_fst::Streamer;
#[test]
fn test_empty_segment_postings() {
let mut postings = SegmentPostings::empty();
assert!(!postings.advance());
assert!(!postings.advance());
assert_eq!(postings.advance(), TERMINATED);
assert_eq!(postings.advance(), TERMINATED);
assert_eq!(postings.len(), 0);
}
#[test]
#[should_panic(expected = "Have you forgotten to call `.advance()`")]
fn test_panic_if_doc_called_before_advance() {
SegmentPostings::empty().doc();
fn test_empty_postings_doc_returns_terminated() {
let mut postings = SegmentPostings::empty();
assert_eq!(postings.doc(), TERMINATED);
assert_eq!(postings.advance(), TERMINATED);
}
#[test]
#[should_panic(expected = "Have you forgotten to call `.advance()`")]
fn test_panic_if_freq_called_before_advance() {
SegmentPostings::empty().term_freq();
}
#[test]
fn test_empty_block_segment_postings() {
let mut postings = BlockSegmentPostings::empty();
assert!(!postings.advance());
assert_eq!(postings.doc_freq(), 0);
}
#[test]
fn test_block_segment_postings() {
let mut block_segments = build_block_postings(&(0..100_000).collect::<Vec<u32>>());
let mut offset: u32 = 0u32;
// checking that the block before calling advance is empty
assert!(block_segments.docs().is_empty());
// checking that the `doc_freq` is correct
assert_eq!(block_segments.doc_freq(), 100_000);
while let Some(block) = block_segments.next() {
for (i, doc) in block.iter().cloned().enumerate() {
assert_eq!(offset + (i as u32), doc);
}
offset += block.len() as u32;
}
}
#[test]
fn test_skip_right_at_new_block() {
let mut doc_ids = (0..128).collect::<Vec<u32>>();
doc_ids.push(129);
doc_ids.push(130);
{
let block_segments = build_block_postings(&doc_ids);
let mut docset = SegmentPostings::from_block_postings(block_segments, None);
assert_eq!(docset.skip_next(128), SkipResult::OverStep);
assert_eq!(docset.doc(), 129);
assert!(docset.advance());
assert_eq!(docset.doc(), 130);
assert!(!docset.advance());
}
{
let block_segments = build_block_postings(&doc_ids);
let mut docset = SegmentPostings::from_block_postings(block_segments, None);
assert_eq!(docset.skip_next(129), SkipResult::Reached);
assert_eq!(docset.doc(), 129);
assert!(docset.advance());
assert_eq!(docset.doc(), 130);
assert!(!docset.advance());
}
{
let block_segments = build_block_postings(&doc_ids);
let mut docset = SegmentPostings::from_block_postings(block_segments, None);
assert_eq!(docset.skip_next(131), SkipResult::End);
}
}
fn build_block_postings(docs: &[DocId]) -> BlockSegmentPostings {
let mut schema_builder = Schema::builder();
let int_field = schema_builder.add_u64_field("id", INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
let mut last_doc = 0u32;
for &doc in docs {
for _ in last_doc..doc {
index_writer.add_document(doc!(int_field=>1u64));
}
index_writer.add_document(doc!(int_field=>0u64));
last_doc = doc + 1;
}
index_writer.commit().unwrap();
let searcher = index.reader().unwrap().searcher();
let segment_reader = searcher.segment_reader(0);
let inverted_index = segment_reader.inverted_index(int_field);
let term = Term::from_field_u64(int_field, 0u64);
let term_info = inverted_index.get_term_info(&term).unwrap();
inverted_index.read_block_postings_from_terminfo(&term_info, IndexRecordOption::Basic)
}
#[test]
fn test_block_segment_postings_skip() {
for i in 0..4 {
let mut block_postings = build_block_postings(&[3]);
assert_eq!(
block_postings.skip_to(i),
BlockSegmentPostingsSkipResult::Success(0u32)
);
assert_eq!(
block_postings.skip_to(i),
BlockSegmentPostingsSkipResult::Terminated
);
}
let mut block_postings = build_block_postings(&[3]);
assert_eq!(
block_postings.skip_to(4u32),
BlockSegmentPostingsSkipResult::Terminated
);
}
#[test]
fn test_block_segment_postings_skip2() {
let mut docs = vec![0];
for i in 0..1300 {
docs.push((i * i / 100) + i);
}
let mut block_postings = build_block_postings(&docs[..]);
for i in vec![0, 424, 10000] {
assert_eq!(
block_postings.skip_to(i),
BlockSegmentPostingsSkipResult::Success(0u32)
);
let docs = block_postings.docs();
assert!(docs[0] <= i);
assert!(docs.last().cloned().unwrap_or(0u32) >= i);
}
assert_eq!(
block_postings.skip_to(100_000),
BlockSegmentPostingsSkipResult::Terminated
);
assert_eq!(
block_postings.skip_to(101_000),
BlockSegmentPostingsSkipResult::Terminated
);
}
#[test]
fn test_reset_block_segment_postings() {
let mut schema_builder = Schema::builder();
let int_field = schema_builder.add_u64_field("id", INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
// create two postings list, one containg even number,
// the other containing odd numbers.
for i in 0..6 {
let doc = doc!(int_field=> (i % 2) as u64);
index_writer.add_document(doc);
}
index_writer.commit().unwrap();
let searcher = index.reader().unwrap().searcher();
let segment_reader = searcher.segment_reader(0);
let mut block_segments;
{
let term = Term::from_field_u64(int_field, 0u64);
let inverted_index = segment_reader.inverted_index(int_field);
let term_info = inverted_index.get_term_info(&term).unwrap();
block_segments = inverted_index
.read_block_postings_from_terminfo(&term_info, IndexRecordOption::Basic);
}
assert!(block_segments.advance());
assert_eq!(block_segments.docs(), &[0, 2, 4]);
{
let term = Term::from_field_u64(int_field, 1u64);
let inverted_index = segment_reader.inverted_index(int_field);
let term_info = inverted_index.get_term_info(&term).unwrap();
inverted_index.reset_block_postings_from_terminfo(&term_info, &mut block_segments);
}
assert!(block_segments.advance());
assert_eq!(block_segments.docs(), &[1, 3, 5]);
fn test_empty_postings_doc_term_freq_returns_0() {
let postings = SegmentPostings::empty();
assert_eq!(postings.term_freq(), 1);
}
}

View File

@@ -3,14 +3,16 @@ use crate::common::{BinarySerializable, VInt};
use crate::common::{CompositeWrite, CountingWriter};
use crate::core::Segment;
use crate::directory::WritePtr;
use crate::fieldnorm::FieldNormReader;
use crate::positions::PositionSerializer;
use crate::postings::compression::{BlockEncoder, VIntEncoder, COMPRESSION_BLOCK_SIZE};
use crate::postings::skip::SkipSerializer;
use crate::postings::USE_SKIP_INFO_LIMIT;
use crate::query::BM25Weight;
use crate::schema::Schema;
use crate::schema::{Field, FieldEntry, FieldType};
use crate::termdict::{TermDictionaryBuilder, TermOrdinal};
use crate::DocId;
use std::cmp::Ordering;
use std::io::{self, Write};
/// `InvertedIndexSerializer` is in charge of serializing
@@ -90,6 +92,7 @@ impl InvertedIndexSerializer {
&mut self,
field: Field,
total_num_tokens: u64,
fieldnorm_reader: Option<FieldNormReader>,
) -> io::Result<FieldSerializer<'_>> {
let field_entry: &FieldEntry = self.schema.get_field_entry(field);
let term_dictionary_write = self.terms_write.for_field(field);
@@ -104,6 +107,7 @@ impl InvertedIndexSerializer {
postings_write,
positions_write,
positionsidx_write,
fieldnorm_reader,
)
}
@@ -135,6 +139,7 @@ impl<'a> FieldSerializer<'a> {
postings_write: &'a mut CountingWriter<WritePtr>,
positions_write: &'a mut CountingWriter<WritePtr>,
positionsidx_write: &'a mut CountingWriter<WritePtr>,
fieldnorm_reader: Option<FieldNormReader>,
) -> io::Result<FieldSerializer<'a>> {
let (term_freq_enabled, position_enabled): (bool, bool) = match field_type {
FieldType::Str(ref text_options) => {
@@ -148,8 +153,12 @@ impl<'a> FieldSerializer<'a> {
_ => (false, false),
};
let term_dictionary_builder = TermDictionaryBuilder::create(term_dictionary_write)?;
let postings_serializer =
PostingsSerializer::new(postings_write, term_freq_enabled, position_enabled);
let postings_serializer = PostingsSerializer::new(
postings_write,
term_freq_enabled,
position_enabled,
fieldnorm_reader,
);
let positions_serializer_opt = if position_enabled {
Some(PositionSerializer::new(positions_write, positionsidx_write))
} else {
@@ -182,8 +191,8 @@ impl<'a> FieldSerializer<'a> {
/// Starts the postings for a new term.
/// * term - the term. It needs to come after the previous term according
/// to the lexicographical order.
/// * doc_freq - return the number of document containing the term.
pub fn new_term(&mut self, term: &[u8]) -> io::Result<TermOrdinal> {
/// * term_doc_freq - return the number of document containing the term.
pub fn new_term(&mut self, term: &[u8], term_doc_freq: u32) -> io::Result<TermOrdinal> {
assert!(
!self.term_open,
"Called new_term, while the previous term was not closed."
@@ -194,6 +203,7 @@ impl<'a> FieldSerializer<'a> {
self.term_dictionary_builder.insert_key(term)?;
let term_ordinal = self.num_terms;
self.num_terms += 1;
self.postings_serializer.new_term(term_doc_freq);
Ok(term_ordinal)
}
@@ -307,6 +317,21 @@ pub struct PostingsSerializer<W: Write> {
termfreq_enabled: bool,
termfreq_sum_enabled: bool,
fieldnorm_reader: Option<FieldNormReader>,
bm25_weight: Option<BM25Weight>,
num_docs: u32, // Number of docs in the segment
avg_fieldnorm: f32, // Average number of term in the field for that segment.
// this value is used to compute the block wand information.
}
fn get_avg_fieldnorm(fieldnorm_reader: &FieldNormReader) -> f32 {
let num_docs = fieldnorm_reader.num_docs();
let sum_fieldnorm: f32 = (0u32..num_docs)
.map(|doc| fieldnorm_reader.fieldnorm(doc) as f32)
.sum();
sum_fieldnorm / (num_docs as f32)
}
impl<W: Write> PostingsSerializer<W> {
@@ -314,7 +339,16 @@ impl<W: Write> PostingsSerializer<W> {
write: W,
termfreq_enabled: bool,
termfreq_sum_enabled: bool,
fieldnorm_reader: Option<FieldNormReader>,
) -> PostingsSerializer<W> {
let avg_fieldnorm: f32 = fieldnorm_reader
.as_ref()
.map(get_avg_fieldnorm)
.unwrap_or(0f32);
let num_docs = fieldnorm_reader
.as_ref()
.map(|fieldnorm_reader| fieldnorm_reader.num_docs())
.unwrap_or(0u32);
PostingsSerializer {
output_write: CountingWriter::wrap(write),
@@ -327,6 +361,23 @@ impl<W: Write> PostingsSerializer<W> {
last_doc_id_encoded: 0u32,
termfreq_enabled,
termfreq_sum_enabled,
fieldnorm_reader,
bm25_weight: None,
num_docs,
avg_fieldnorm,
}
}
pub fn new_term(&mut self, term_doc_freq: u32) {
if self.termfreq_enabled && self.num_docs > 0 {
let bm25_weight = BM25Weight::for_one_term(
term_doc_freq as u64,
self.num_docs as u64,
self.avg_fieldnorm,
);
self.bm25_weight = Some(bm25_weight);
}
}
@@ -343,7 +394,6 @@ impl<W: Write> PostingsSerializer<W> {
self.postings_write.extend(block_encoded);
}
if self.termfreq_enabled {
// encode the term_freqs
let (num_bits, block_encoded): (u8, &[u8]) = self
.block_encoder
.compress_block_unsorted(&self.block.term_freqs());
@@ -353,6 +403,32 @@ impl<W: Write> PostingsSerializer<W> {
let sum_freq = self.block.term_freqs().iter().cloned().sum();
self.skip_write.write_total_term_freq(sum_freq);
}
let mut blockwand_params_opt = None;
if let Some(bm25_weight) = self.bm25_weight.as_ref() {
if let Some(fieldnorm_reader) = self.fieldnorm_reader.as_ref() {
let docs = self.block.doc_ids();
let term_freqs = self.block.term_freqs();
blockwand_params_opt = docs
.iter()
.cloned()
.map(|doc| fieldnorm_reader.fieldnorm_id(doc))
.zip(term_freqs.iter().cloned())
.max_by(
|(left_fieldnorm_id, left_term_freq),
(right_fieldnorm_id, right_term_freq)| {
let left_score =
bm25_weight.tf_factor(*left_fieldnorm_id, *left_term_freq);
let right_score =
bm25_weight.tf_factor(*right_fieldnorm_id, *right_term_freq);
left_score
.partial_cmp(&right_score)
.unwrap_or(Ordering::Equal)
},
);
}
}
let (fieldnorm_id, term_freq) = blockwand_params_opt.unwrap_or((0u8, 0u32));
self.skip_write.write_blockwand_max(fieldnorm_id, term_freq);
}
self.block.clear();
}
@@ -391,7 +467,7 @@ impl<W: Write> PostingsSerializer<W> {
}
self.block.clear();
}
if doc_freq >= USE_SKIP_INFO_LIMIT {
if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
let skip_data = self.skip_write.data();
VInt(skip_data.len() as u64).serialize(&mut self.output_write)?;
self.output_write.write_all(skip_data)?;
@@ -401,6 +477,7 @@ impl<W: Write> PostingsSerializer<W> {
}
self.skip_write.clear();
self.postings_write.clear();
self.bm25_weight = None;
Ok(())
}

View File

@@ -1,7 +1,9 @@
use crate::common::BinarySerializable;
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::common::{BinarySerializable, VInt};
use crate::directory::ReadOnlySource;
use crate::postings::compression::{compressed_block_size, COMPRESSION_BLOCK_SIZE};
use crate::query::BM25Weight;
use crate::schema::IndexRecordOption;
use crate::DocId;
use crate::{DocId, Score, TERMINATED};
use owned_read::OwnedRead;
pub struct SkipSerializer {
@@ -39,6 +41,11 @@ impl SkipSerializer {
.expect("Should never fail");
}
pub fn write_blockwand_max(&mut self, fieldnorm_id: u8, term_freq: u32) {
self.buffer.push(fieldnorm_id);
VInt(term_freq as u64).serialize_into_vec(&mut self.buffer);
}
pub fn data(&self) -> &[u8] {
&self.buffer[..]
}
@@ -50,80 +57,186 @@ impl SkipSerializer {
}
pub(crate) struct SkipReader {
doc: DocId,
last_doc_in_block: DocId,
pub(crate) last_doc_in_previous_block: DocId,
owned_read: OwnedRead,
doc_num_bits: u8,
tf_num_bits: u8,
tf_sum: u32,
skip_info: IndexRecordOption,
byte_offset: usize,
remaining_docs: u32, // number of docs remaining, including the
// documents in the current block.
block_info: BlockInfo,
position_offset: u64,
}
#[derive(Clone, Copy, Debug)]
pub(crate) enum BlockInfo {
BitPacked {
doc_num_bits: u8,
tf_num_bits: u8,
tf_sum: u32,
block_wand_fieldnorm_id: u8,
block_wand_term_freq: u32,
},
VInt {
num_docs: u32,
},
}
impl Default for BlockInfo {
fn default() -> Self {
BlockInfo::VInt { num_docs: 0u32 }
}
}
impl SkipReader {
pub fn new(data: OwnedRead, skip_info: IndexRecordOption) -> SkipReader {
SkipReader {
doc: 0u32,
owned_read: data,
pub fn new(data: ReadOnlySource, doc_freq: u32, skip_info: IndexRecordOption) -> SkipReader {
let mut skip_reader = SkipReader {
last_doc_in_block: if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
0
} else {
TERMINATED
},
last_doc_in_previous_block: 0u32,
owned_read: OwnedRead::new(data),
skip_info,
doc_num_bits: 0u8,
tf_num_bits: 0u8,
tf_sum: 0u32,
block_info: BlockInfo::VInt { num_docs: doc_freq },
byte_offset: 0,
remaining_docs: doc_freq,
position_offset: 0u64,
};
if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
skip_reader.read_block_info();
}
skip_reader
}
pub fn reset(&mut self, data: ReadOnlySource, doc_freq: u32) {
self.last_doc_in_block = if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
0
} else {
TERMINATED
};
self.last_doc_in_previous_block = 0u32;
self.owned_read = OwnedRead::new(data);
self.block_info = BlockInfo::VInt { num_docs: doc_freq };
self.byte_offset = 0;
self.remaining_docs = doc_freq;
self.position_offset = 0u64;
if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
self.read_block_info();
}
}
pub fn reset(&mut self, data: OwnedRead) {
self.doc = 0u32;
self.owned_read = data;
self.doc_num_bits = 0u8;
self.tf_num_bits = 0u8;
self.tf_sum = 0u32;
pub fn block_max_score(&self, bm25_weight: &BM25Weight) -> Option<Score> {
match self.block_info {
BlockInfo::BitPacked {
block_wand_fieldnorm_id,
block_wand_term_freq,
..
} => Some(bm25_weight.score(block_wand_fieldnorm_id, block_wand_term_freq)),
BlockInfo::VInt { .. } => None,
}
}
pub fn total_block_len(&self) -> usize {
(self.doc_num_bits + self.tf_num_bits) as usize * COMPRESSION_BLOCK_SIZE / 8
pub(crate) fn last_doc_in_block(&self) -> DocId {
self.last_doc_in_block
}
pub fn doc(&self) -> DocId {
self.doc
pub fn position_offset(&self) -> u64 {
self.position_offset
}
pub fn doc_num_bits(&self) -> u8 {
self.doc_num_bits
pub fn byte_offset(&self) -> usize {
self.byte_offset
}
/// Number of bits used to encode term frequencies
///
/// 0 if term frequencies are not enabled.
pub fn tf_num_bits(&self) -> u8 {
self.tf_num_bits
}
pub fn tf_sum(&self) -> u32 {
self.tf_sum
}
pub fn advance(&mut self) -> bool {
if self.owned_read.as_ref().is_empty() {
false
} else {
let doc_delta = u32::deserialize(&mut self.owned_read).expect("Skip data corrupted");
self.doc += doc_delta as DocId;
self.doc_num_bits = self.owned_read.get(0);
match self.skip_info {
IndexRecordOption::Basic => {
self.owned_read.advance(1);
}
IndexRecordOption::WithFreqs => {
self.tf_num_bits = self.owned_read.get(1);
self.owned_read.advance(2);
}
IndexRecordOption::WithFreqsAndPositions => {
self.tf_num_bits = self.owned_read.get(1);
self.owned_read.advance(2);
self.tf_sum =
u32::deserialize(&mut self.owned_read).expect("Failed reading tf_sum");
}
fn read_block_info(&mut self) {
let doc_delta = u32::deserialize(&mut self.owned_read).expect("Skip data corrupted");
self.last_doc_in_block += doc_delta as DocId;
let doc_num_bits = self.owned_read.get(0);
match self.skip_info {
IndexRecordOption::Basic => {
self.owned_read.advance(1);
self.block_info = BlockInfo::BitPacked {
doc_num_bits,
tf_num_bits: 0,
tf_sum: 0,
block_wand_fieldnorm_id: 0,
block_wand_term_freq: 0,
};
}
true
IndexRecordOption::WithFreqs => {
let tf_num_bits = self.owned_read.get(1);
let block_wand_fieldnorm_id = self.owned_read.get(2);
self.owned_read.advance(3);
let block_wand_term_freq =
VInt::deserialize_u64(&mut self.owned_read).unwrap() as u32;
self.block_info = BlockInfo::BitPacked {
doc_num_bits,
tf_num_bits,
tf_sum: 0,
block_wand_fieldnorm_id,
block_wand_term_freq,
};
}
IndexRecordOption::WithFreqsAndPositions => {
let tf_num_bits = self.owned_read.get(1);
self.owned_read.advance(2);
let tf_sum = u32::deserialize(&mut self.owned_read).expect("Failed reading tf_sum");
let block_wand_fieldnorm_id = self.owned_read.get(0);
self.owned_read.advance(1);
let block_wand_term_freq =
VInt::deserialize_u64(&mut self.owned_read).unwrap() as u32;
self.block_info = BlockInfo::BitPacked {
doc_num_bits,
tf_num_bits,
tf_sum,
block_wand_fieldnorm_id,
block_wand_term_freq,
};
}
}
}
pub fn block_info(&self) -> BlockInfo {
self.block_info
}
/// Advance the skip reader to the block that may contain the target.
///
/// If the target is larger than all documents, the skip_reader
/// then advance to the last Variable In block.
pub fn seek(&mut self, target: DocId) {
while self.last_doc_in_block() < target {
self.advance();
}
}
pub fn advance(&mut self) {
match self.block_info {
BlockInfo::BitPacked {
doc_num_bits,
tf_num_bits,
tf_sum,
..
} => {
self.remaining_docs -= COMPRESSION_BLOCK_SIZE as u32;
self.byte_offset += compressed_block_size(doc_num_bits + tf_num_bits);
self.position_offset += tf_sum as u64;
}
BlockInfo::VInt { num_docs} => {
debug_assert_eq!(num_docs, self.remaining_docs);
self.remaining_docs = 0;
self.byte_offset = std::usize::MAX;
}
}
self.last_doc_in_previous_block = self.last_doc_in_block;
if self.remaining_docs >= COMPRESSION_BLOCK_SIZE as u32 {
self.read_block_info();
} else {
self.last_doc_in_block = TERMINATED;
self.block_info = BlockInfo::VInt { num_docs: self.remaining_docs };
}
}
}
@@ -131,9 +244,11 @@ impl SkipReader {
#[cfg(test)]
mod tests {
use super::BlockInfo;
use super::IndexRecordOption;
use super::{SkipReader, SkipSerializer};
use owned_read::OwnedRead;
use crate::directory::ReadOnlySource;
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
#[test]
fn test_skip_with_freq() {
@@ -141,20 +256,47 @@ mod tests {
let mut skip_serializer = SkipSerializer::new();
skip_serializer.write_doc(1u32, 2u8);
skip_serializer.write_term_freq(3u8);
skip_serializer.write_blockwand_max(13u8, 3u32);
skip_serializer.write_doc(5u32, 5u8);
skip_serializer.write_term_freq(2u8);
skip_serializer.write_blockwand_max(8u8, 2u32);
skip_serializer.data().to_owned()
};
let mut skip_reader = SkipReader::new(OwnedRead::new(buf), IndexRecordOption::WithFreqs);
assert!(skip_reader.advance());
assert_eq!(skip_reader.doc(), 1u32);
assert_eq!(skip_reader.doc_num_bits(), 2u8);
assert_eq!(skip_reader.tf_num_bits(), 3u8);
assert!(skip_reader.advance());
assert_eq!(skip_reader.doc(), 5u32);
assert_eq!(skip_reader.doc_num_bits(), 5u8);
assert_eq!(skip_reader.tf_num_bits(), 2u8);
assert!(!skip_reader.advance());
let doc_freq = 3u32 + (COMPRESSION_BLOCK_SIZE * 2) as u32;
let mut skip_reader = SkipReader::new(
ReadOnlySource::new(buf),
doc_freq,
IndexRecordOption::WithFreqs,
);
assert_eq!(skip_reader.last_doc_in_block(), 1u32);
assert!(matches!(
skip_reader.block_info,
BlockInfo::BitPacked {
doc_num_bits: 2u8,
tf_num_bits: 3u8,
tf_sum: 0,
block_wand_fieldnorm_id: 13,
block_wand_term_freq: 3
}
));
skip_reader.advance();
assert_eq!(skip_reader.last_doc_in_block(), 5u32);
assert!(matches!(
skip_reader.block_info(),
BlockInfo::BitPacked {
doc_num_bits: 5u8,
tf_num_bits: 2u8,
tf_sum: 0,
block_wand_fieldnorm_id: 8,
block_wand_term_freq: 2
}
));
skip_reader.advance();
assert!(matches!(skip_reader.block_info(), BlockInfo::VInt { num_docs: 3u32 }));
skip_reader.advance();
assert!(matches!(skip_reader.block_info(), BlockInfo::VInt { num_docs: 0u32 }));
skip_reader.advance();
assert!(matches!(skip_reader.block_info(), BlockInfo::VInt { num_docs: 0u32 }));
}
#[test]
@@ -165,13 +307,68 @@ mod tests {
skip_serializer.write_doc(5u32, 5u8);
skip_serializer.data().to_owned()
};
let mut skip_reader = SkipReader::new(OwnedRead::new(buf), IndexRecordOption::Basic);
assert!(skip_reader.advance());
assert_eq!(skip_reader.doc(), 1u32);
assert_eq!(skip_reader.doc_num_bits(), 2u8);
assert!(skip_reader.advance());
assert_eq!(skip_reader.doc(), 5u32);
assert_eq!(skip_reader.doc_num_bits(), 5u8);
assert!(!skip_reader.advance());
let doc_freq = 3u32 + (COMPRESSION_BLOCK_SIZE * 2) as u32;
let mut skip_reader = SkipReader::new(
ReadOnlySource::from(buf),
doc_freq,
IndexRecordOption::Basic,
);
assert_eq!(skip_reader.last_doc_in_block(), 1u32);
assert!(matches!(
skip_reader.block_info(),
BlockInfo::BitPacked {
doc_num_bits: 2u8,
tf_num_bits: 0,
tf_sum: 0u32,
block_wand_fieldnorm_id: 0,
block_wand_term_freq: 0
}
));
skip_reader.advance();
assert_eq!(skip_reader.last_doc_in_block(), 5u32);
assert!(matches!(
skip_reader.block_info(),
BlockInfo::BitPacked {
doc_num_bits: 5u8,
tf_num_bits: 0,
tf_sum: 0u32,
block_wand_fieldnorm_id: 0,
block_wand_term_freq: 0
}
));
skip_reader.advance();
assert!(matches!(skip_reader.block_info(), BlockInfo::VInt { num_docs: 3u32 }));
skip_reader.advance();
assert!(matches!(skip_reader.block_info(), BlockInfo::VInt { num_docs: 0u32 }));
skip_reader.advance();
assert!(matches!(skip_reader.block_info(), BlockInfo::VInt { num_docs: 0u32 }));
}
#[test]
fn test_skip_multiple_of_block_size() {
let buf = {
let mut skip_serializer = SkipSerializer::new();
skip_serializer.write_doc(1u32, 2u8);
skip_serializer.data().to_owned()
};
let doc_freq = COMPRESSION_BLOCK_SIZE as u32;
let mut skip_reader = SkipReader::new(
ReadOnlySource::from(buf),
doc_freq,
IndexRecordOption::Basic,
);
assert_eq!(skip_reader.last_doc_in_block(), 1u32);
assert!(matches!(
skip_reader.block_info(),
BlockInfo::BitPacked {
doc_num_bits: 2u8,
tf_num_bits: 0,
tf_sum: 0u32,
block_wand_fieldnorm_id: 0,
block_wand_term_freq: 0
}
));
skip_reader.advance();
assert!(matches!(skip_reader.block_info(), BlockInfo::VInt { num_docs: 0u32 }));
}
}

View File

@@ -1,6 +1,4 @@
use murmurhash32;
use self::murmurhash32::murmurhash2;
use murmurhash32::murmurhash2;
use super::{Addr, MemoryArena};
use crate::postings::stacker::memory_arena::store;

View File

@@ -1,6 +1,7 @@
use crate::core::Searcher;
use crate::core::SegmentReader;
use crate::docset::DocSet;
use crate::docset::{DocSet, TERMINATED};
use crate::query::boost_query::BoostScorer;
use crate::query::explanation::does_not_match;
use crate::query::{Explanation, Query, Scorer, Weight};
use crate::DocId;
@@ -22,12 +23,12 @@ impl Query for AllQuery {
pub struct AllWeight;
impl Weight for AllWeight {
fn scorer(&self, reader: &SegmentReader) -> crate::Result<Box<dyn Scorer>> {
Ok(Box::new(AllScorer {
state: State::NotStarted,
fn scorer(&self, reader: &SegmentReader, boost: f32) -> crate::Result<Box<dyn Scorer>> {
let all_scorer = AllScorer {
doc: 0u32,
max_doc: reader.max_doc(),
}))
};
Ok(Box::new(BoostScorer::new(all_scorer, boost)))
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
@@ -38,39 +39,20 @@ impl Weight for AllWeight {
}
}
enum State {
NotStarted,
Started,
Finished,
}
/// Scorer associated to the `AllQuery` query.
pub struct AllScorer {
state: State,
doc: DocId,
max_doc: DocId,
}
impl DocSet for AllScorer {
fn advance(&mut self) -> bool {
match self.state {
State::NotStarted => {
self.state = State::Started;
self.doc = 0;
}
State::Started => {
self.doc += 1u32;
}
State::Finished => {
return false;
}
}
if self.doc < self.max_doc {
true
} else {
self.state = State::Finished;
false
fn advance(&mut self) -> DocId {
if self.doc + 1 >= self.max_doc {
self.doc = TERMINATED;
return TERMINATED;
}
self.doc += 1;
self.doc
}
fn doc(&self) -> DocId {
@@ -90,14 +72,13 @@ impl Scorer for AllScorer {
#[cfg(test)]
mod tests {
use super::AllQuery;
use crate::docset::TERMINATED;
use crate::query::Query;
use crate::schema::{Schema, TEXT};
use crate::Index;
#[test]
fn test_all_query() {
fn create_test_index() -> Index {
let mut schema_builder = Schema::builder();
let field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
@@ -108,25 +89,47 @@ mod tests {
index_writer.commit().unwrap();
index_writer.add_document(doc!(field=>"ccc"));
index_writer.commit().unwrap();
index
}
#[test]
fn test_all_query() {
let index = create_test_index();
let reader = index.reader().unwrap();
reader.reload().unwrap();
let searcher = reader.searcher();
let weight = AllQuery.weight(&searcher, false).unwrap();
{
let reader = searcher.segment_reader(0);
let mut scorer = weight.scorer(reader).unwrap();
assert!(scorer.advance());
let mut scorer = weight.scorer(reader, 1.0f32).unwrap();
assert_eq!(scorer.doc(), 0u32);
assert!(scorer.advance());
assert_eq!(scorer.advance(), 1u32);
assert_eq!(scorer.doc(), 1u32);
assert!(!scorer.advance());
assert_eq!(scorer.advance(), TERMINATED);
}
{
let reader = searcher.segment_reader(1);
let mut scorer = weight.scorer(reader).unwrap();
assert!(scorer.advance());
let mut scorer = weight.scorer(reader, 1.0f32).unwrap();
assert_eq!(scorer.doc(), 0u32);
assert!(!scorer.advance());
assert_eq!(scorer.advance(), TERMINATED);
}
}
#[test]
fn test_all_query_with_boost() {
let index = create_test_index();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let weight = AllQuery.weight(&searcher, false).unwrap();
let reader = searcher.segment_reader(0);
{
let mut scorer = weight.scorer(reader, 2.0f32).unwrap();
assert_eq!(scorer.doc(), 0u32);
assert_eq!(scorer.score(), 2.0f32);
}
{
let mut scorer = weight.scorer(reader, 1.5f32).unwrap();
assert_eq!(scorer.doc(), 0u32);
assert_eq!(scorer.score(), 1.5f32);
}
}
}

View File

@@ -6,8 +6,8 @@ use crate::query::{Scorer, Weight};
use crate::schema::{Field, IndexRecordOption};
use crate::termdict::{TermDictionary, TermStreamer};
use crate::DocId;
use crate::Result;
use crate::TantivyError;
use crate::{Result, SkipResult};
use std::sync::Arc;
use tantivy_fst::Automaton;
@@ -40,10 +40,9 @@ impl<A> Weight for AutomatonWeight<A>
where
A: Automaton + Send + Sync + 'static,
{
fn scorer(&self, reader: &SegmentReader) -> Result<Box<dyn Scorer>> {
fn scorer(&self, reader: &SegmentReader, boost: f32) -> Result<Box<dyn Scorer>> {
let max_doc = reader.max_doc();
let mut doc_bitset = BitSet::with_max_value(max_doc);
let inverted_index = reader.inverted_index(self.field);
let term_dict = inverted_index.terms();
let mut term_stream = self.automaton_stream(term_dict);
@@ -51,19 +50,25 @@ where
let term_info = term_stream.value();
let mut block_segment_postings = inverted_index
.read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic);
while block_segment_postings.advance() {
for &doc in block_segment_postings.docs() {
loop {
let docs = block_segment_postings.docs();
if docs.is_empty() {
break;
}
for &doc in docs {
doc_bitset.insert(doc);
}
block_segment_postings.advance();
}
}
let doc_bitset = BitSetDocSet::from(doc_bitset);
Ok(Box::new(ConstScorer::new(doc_bitset)))
let const_scorer = ConstScorer::new(doc_bitset, boost);
Ok(Box::new(const_scorer))
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation> {
let mut scorer = self.scorer(reader)?;
if scorer.skip_next(doc) == SkipResult::Reached {
let mut scorer = self.scorer(reader, 1.0f32)?;
if scorer.seek(doc) == doc {
Ok(Explanation::new("AutomatonScorer", 1.0f32))
} else {
Err(TantivyError::InvalidArgument(
@@ -72,3 +77,94 @@ where
}
}
}
#[cfg(test)]
mod tests {
use super::AutomatonWeight;
use crate::docset::TERMINATED;
use crate::query::Weight;
use crate::schema::{Schema, STRING};
use crate::Index;
use tantivy_fst::Automaton;
fn create_index() -> Index {
let mut schema = Schema::builder();
let title = schema.add_text_field("title", STRING);
let index = Index::create_in_ram(schema.build());
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
index_writer.add_document(doc!(title=>"abc"));
index_writer.add_document(doc!(title=>"bcd"));
index_writer.add_document(doc!(title=>"abcd"));
assert!(index_writer.commit().is_ok());
index
}
enum State {
Start,
NotMatching,
AfterA,
}
struct PrefixedByA;
impl Automaton for PrefixedByA {
type State = State;
fn start(&self) -> Self::State {
State::Start
}
fn is_match(&self, state: &Self::State) -> bool {
match *state {
State::AfterA => true,
_ => false,
}
}
fn accept(&self, state: &Self::State, byte: u8) -> Self::State {
match *state {
State::Start => {
if byte == b'a' {
State::AfterA
} else {
State::NotMatching
}
}
State::AfterA => State::AfterA,
State::NotMatching => State::NotMatching,
}
}
}
#[test]
fn test_automaton_weight() {
let index = create_index();
let field = index.schema().get_field("title").unwrap();
let automaton_weight = AutomatonWeight::new(field, PrefixedByA);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let mut scorer = automaton_weight
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap();
assert_eq!(scorer.doc(), 0u32);
assert_eq!(scorer.score(), 1.0f32);
assert_eq!(scorer.advance(), 2u32);
assert_eq!(scorer.doc(), 2u32);
assert_eq!(scorer.score(), 1.0f32);
assert_eq!(scorer.advance(), TERMINATED);
}
#[test]
fn test_automaton_weight_boost() {
let index = create_index();
let field = index.schema().get_field("title").unwrap();
let automaton_weight = AutomatonWeight::new(field, PrefixedByA);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let mut scorer = automaton_weight
.scorer(searcher.segment_reader(0u32), 1.32f32)
.unwrap();
assert_eq!(scorer.doc(), 0u32);
assert_eq!(scorer.score(), 1.32f32);
}
}

View File

@@ -1,7 +1,6 @@
use crate::common::{BitSet, TinySet};
use crate::docset::{DocSet, SkipResult};
use crate::docset::{DocSet, TERMINATED};
use crate::DocId;
use std::cmp::Ordering;
/// A `BitSetDocSet` makes it possible to iterate through a bitset as if it was a `DocSet`.
///
@@ -33,75 +32,50 @@ impl From<BitSet> for BitSetDocSet {
} else {
docs.tinyset(0)
};
BitSetDocSet {
let mut docset = BitSetDocSet {
docs,
cursor_bucket: 0,
cursor_tinybitset: first_tiny_bitset,
doc: 0u32,
}
};
docset.advance();
docset
}
}
impl DocSet for BitSetDocSet {
fn advance(&mut self) -> bool {
fn advance(&mut self) -> DocId {
if let Some(lower) = self.cursor_tinybitset.pop_lowest() {
self.doc = (self.cursor_bucket as u32 * 64u32) | lower;
return true;
return self.doc;
}
if let Some(cursor_bucket) = self.docs.first_non_empty_bucket(self.cursor_bucket + 1) {
self.go_to_bucket(cursor_bucket);
let lower = self.cursor_tinybitset.pop_lowest().unwrap();
self.doc = (cursor_bucket * 64u32) | lower;
true
self.doc
} else {
false
self.doc = TERMINATED;
TERMINATED
}
}
fn skip_next(&mut self, target: DocId) -> SkipResult {
// skip is required to advance.
if !self.advance() {
return SkipResult::End;
}
fn seek(&mut self, target: DocId) -> DocId {
let target_bucket = target / 64u32;
// Mask for all of the bits greater or equal
// to our target document.
match target_bucket.cmp(&self.cursor_bucket) {
Ordering::Greater => {
self.go_to_bucket(target_bucket);
let greater_filter: TinySet = TinySet::range_greater_or_equal(target);
self.cursor_tinybitset = self.cursor_tinybitset.intersect(greater_filter);
if !self.advance() {
SkipResult::End
} else if self.doc() == target {
SkipResult::Reached
} else {
debug_assert!(self.doc() > target);
SkipResult::OverStep
}
}
Ordering::Equal => loop {
match self.doc().cmp(&target) {
Ordering::Less => {
if !self.advance() {
return SkipResult::End;
}
}
Ordering::Equal => {
return SkipResult::Reached;
}
Ordering::Greater => {
debug_assert!(self.doc() > target);
return SkipResult::OverStep;
}
}
},
Ordering::Less => {
debug_assert!(self.doc() > target);
SkipResult::OverStep
}
if target_bucket > self.cursor_bucket {
self.go_to_bucket(target_bucket);
let greater_filter: TinySet = TinySet::range_greater_or_equal(target);
self.cursor_tinybitset = self.cursor_tinybitset.intersect(greater_filter);
self.advance();
}
let mut doc = self.doc();
while doc < target {
doc = self.advance();
}
doc
}
/// Returns the current document
@@ -122,7 +96,7 @@ impl DocSet for BitSetDocSet {
mod tests {
use super::BitSetDocSet;
use crate::common::BitSet;
use crate::docset::{DocSet, SkipResult};
use crate::docset::{DocSet, TERMINATED};
use crate::DocId;
fn create_docbitset(docs: &[DocId], max_doc: DocId) -> BitSetDocSet {
@@ -133,19 +107,24 @@ mod tests {
BitSetDocSet::from(docset)
}
#[test]
fn test_empty() {
let bitset = BitSet::with_max_value(1000);
let mut empty = BitSetDocSet::from(bitset);
assert_eq!(empty.advance(), TERMINATED)
}
fn test_go_through_sequential(docs: &[DocId]) {
let mut docset = create_docbitset(docs, 1_000u32);
for &doc in docs {
assert!(docset.advance());
assert_eq!(doc, docset.doc());
docset.advance();
}
assert!(!docset.advance());
assert!(!docset.advance());
assert_eq!(docset.advance(), TERMINATED);
}
#[test]
fn test_docbitset_sequential() {
test_go_through_sequential(&[]);
test_go_through_sequential(&[1, 2, 3]);
test_go_through_sequential(&[1, 2, 3, 4, 5, 63, 64, 65]);
test_go_through_sequential(&[63, 64, 65]);
@@ -156,64 +135,64 @@ mod tests {
fn test_docbitset_skip() {
{
let mut docset = create_docbitset(&[1, 5, 6, 7, 5112], 10_000);
assert_eq!(docset.skip_next(7), SkipResult::Reached);
assert_eq!(docset.seek(7), 7);
assert_eq!(docset.doc(), 7);
assert!(docset.advance(), 7);
assert_eq!(docset.advance(), 5112);
assert_eq!(docset.doc(), 5112);
assert!(!docset.advance());
assert_eq!(docset.advance(), TERMINATED);
}
{
let mut docset = create_docbitset(&[1, 5, 6, 7, 5112], 10_000);
assert_eq!(docset.skip_next(3), SkipResult::OverStep);
assert_eq!(docset.seek(3), 5);
assert_eq!(docset.doc(), 5);
assert!(docset.advance());
assert_eq!(docset.advance(), 6);
}
{
let mut docset = create_docbitset(&[5112], 10_000);
assert_eq!(docset.skip_next(5112), SkipResult::Reached);
assert_eq!(docset.seek(5112), 5112);
assert_eq!(docset.doc(), 5112);
assert!(!docset.advance());
assert_eq!(docset.advance(), TERMINATED);
}
{
let mut docset = create_docbitset(&[5112], 10_000);
assert_eq!(docset.skip_next(5113), SkipResult::End);
assert!(!docset.advance());
assert_eq!(docset.seek(5113), TERMINATED);
assert_eq!(docset.advance(), TERMINATED);
}
{
let mut docset = create_docbitset(&[5112], 10_000);
assert_eq!(docset.skip_next(5111), SkipResult::OverStep);
assert_eq!(docset.seek(5111), 5112);
assert_eq!(docset.doc(), 5112);
assert!(!docset.advance());
assert_eq!(docset.advance(), TERMINATED);
}
{
let mut docset = create_docbitset(&[1, 5, 6, 7, 5112, 5500, 6666], 10_000);
assert_eq!(docset.skip_next(5112), SkipResult::Reached);
assert_eq!(docset.seek(5112), 5112);
assert_eq!(docset.doc(), 5112);
assert!(docset.advance());
assert_eq!(docset.advance(), 5500);
assert_eq!(docset.doc(), 5500);
assert!(docset.advance());
assert_eq!(docset.advance(), 6666);
assert_eq!(docset.doc(), 6666);
assert!(!docset.advance());
assert_eq!(docset.advance(), TERMINATED);
}
{
let mut docset = create_docbitset(&[1, 5, 6, 7, 5112, 5500, 6666], 10_000);
assert_eq!(docset.skip_next(5111), SkipResult::OverStep);
assert_eq!(docset.seek(5111), 5112);
assert_eq!(docset.doc(), 5112);
assert!(docset.advance());
assert_eq!(docset.advance(), 5500);
assert_eq!(docset.doc(), 5500);
assert!(docset.advance());
assert_eq!(docset.advance(), 6666);
assert_eq!(docset.doc(), 6666);
assert!(!docset.advance());
assert_eq!(docset.advance(), TERMINATED);
}
{
let mut docset = create_docbitset(&[1, 5, 6, 7, 5112, 5513, 6666], 10_000);
assert_eq!(docset.skip_next(5111), SkipResult::OverStep);
assert_eq!(docset.seek(5111), 5112);
assert_eq!(docset.doc(), 5112);
assert!(docset.advance());
assert_eq!(docset.advance(), 5513);
assert_eq!(docset.doc(), 5513);
assert!(docset.advance());
assert_eq!(docset.advance(), 6666);
assert_eq!(docset.doc(), 6666);
assert!(!docset.advance());
assert_eq!(docset.advance(), TERMINATED);
}
}
}
@@ -223,6 +202,7 @@ mod bench {
use super::BitSet;
use super::BitSetDocSet;
use crate::docset::TERMINATED;
use crate::test;
use crate::tests;
use crate::DocSet;
@@ -257,7 +237,7 @@ mod bench {
}
b.iter(|| {
let mut docset = BitSetDocSet::from(bitset.clone());
while docset.advance() {}
while docset.advance() != TERMINATED {}
});
}
}

View File

@@ -3,11 +3,14 @@ use crate::query::Explanation;
use crate::Score;
use crate::Searcher;
use crate::Term;
use serde::Deserialize;
use serde::Serialize;
const K1: f32 = 1.2;
const B: f32 = 0.75;
fn idf(doc_freq: u64, doc_count: u64) -> f32 {
assert!(doc_count >= doc_freq, "{} >= {}", doc_count, doc_freq);
let x = ((doc_count - doc_freq) as f32 + 0.5) / (doc_freq as f32 + 0.5);
(1f32 + x).ln()
}
@@ -25,7 +28,12 @@ fn compute_tf_cache(average_fieldnorm: f32) -> [f32; 256] {
cache
}
#[derive(Clone)]
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct BM25Params {
pub idf: f32,
pub avg_fieldnorm: f32,
}
pub struct BM25Weight {
idf_explain: Explanation,
weight: f32,
@@ -34,6 +42,15 @@ pub struct BM25Weight {
}
impl BM25Weight {
pub fn boost_by(&self, boost: f32) -> BM25Weight {
BM25Weight {
idf_explain: self.idf_explain.clone(),
weight: self.weight * boost,
cache: self.cache,
average_fieldnorm: self.average_fieldnorm,
}
}
pub fn for_terms(searcher: &Searcher, terms: &[Term]) -> BM25Weight {
assert!(!terms.is_empty(), "BM25 requires at least one term");
let field = terms[0].field();
@@ -54,17 +71,9 @@ impl BM25Weight {
}
let average_fieldnorm = total_num_tokens as f32 / total_num_docs as f32;
let mut idf_explain: Explanation;
if terms.len() == 1 {
let term_doc_freq = searcher.doc_freq(&terms[0]);
let idf = idf(term_doc_freq, total_num_docs);
idf_explain =
Explanation::new("idf, computed as log(1 + (N - n + 0.5) / (n + 0.5))", idf);
idf_explain.add_const(
"n, number of docs containing this term",
term_doc_freq as f32,
);
idf_explain.add_const("N, total number of docs", total_num_docs as f32);
BM25Weight::for_one_term(term_doc_freq, total_num_docs, average_fieldnorm)
} else {
let idf = terms
.iter()
@@ -73,9 +82,21 @@ impl BM25Weight {
idf(term_doc_freq, total_num_docs)
})
.sum::<f32>();
idf_explain = Explanation::new("idf", idf);
let idf_explain = Explanation::new("idf", idf);
BM25Weight::new(idf_explain, average_fieldnorm)
}
BM25Weight::new(idf_explain, average_fieldnorm)
}
pub fn for_one_term(term_doc_freq: u64, total_num_docs: u64, avg_fieldnorm: f32) -> BM25Weight {
let idf = idf(term_doc_freq, total_num_docs);
let mut idf_explain =
Explanation::new("idf, computed as log(1 + (N - n + 0.5) / (n + 0.5))", idf);
idf_explain.add_const(
"n, number of docs containing this term",
term_doc_freq as f32,
);
idf_explain.add_const("N, total number of docs", total_num_docs as f32);
BM25Weight::new(idf_explain, avg_fieldnorm)
}
fn new(idf_explain: Explanation, average_fieldnorm: f32) -> BM25Weight {
@@ -90,15 +111,23 @@ impl BM25Weight {
#[inline(always)]
pub fn score(&self, fieldnorm_id: u8, term_freq: u32) -> Score {
let norm = self.cache[fieldnorm_id as usize];
self.weight * self.tf_factor(fieldnorm_id, term_freq)
}
pub fn max_score(&self) -> Score {
self.score(255u8, 2_013_265_944)
}
#[inline(always)]
pub(crate) fn tf_factor(&self, fieldnorm_id: u8, term_freq: u32) -> f32 {
let term_freq = term_freq as f32;
self.weight * term_freq / (term_freq + norm)
let norm = self.cache[fieldnorm_id as usize];
term_freq / (term_freq + norm)
}
pub fn explain(&self, fieldnorm_id: u8, term_freq: u32) -> Explanation {
// The explain format is directly copied from Lucene's.
// (So, Kudos to Lucene)
let score = self.score(fieldnorm_id, term_freq);
let norm = self.cache[fieldnorm_id as usize];
@@ -131,10 +160,10 @@ impl BM25Weight {
mod tests {
use super::idf;
use crate::tests::assert_nearly_equals;
use crate::assert_nearly_equals;
#[test]
fn test_idf() {
assert_nearly_equals(idf(1, 2), 0.6931472);
assert_nearly_equals!(idf(1, 2), std::f32::consts::LN_2);
}
}

View File

@@ -0,0 +1,206 @@
use crate::query::term_query::TermScorer;
use crate::query::Scorer;
use crate::{DocId, DocSet, Score, TERMINATED};
use std::ops::DerefMut;
use std::ops::Deref;
/// Takes a term_scorers sorted by their current doc() and a threshold and returns
/// Returns (pivot_len, pivot_ord) defined as follows:
/// - `pivot_doc` lowest document that has a chance of exceeding (>) the threshold score.
/// - `before_pivot_len` number of term_scorers such that term_scorer.doc() < pivot.
/// - `pivot_len` number of term_scorers such that term_scorer.doc() <= pivot.
///
/// We always have `before_pivot_len` < `pivot_len`.
///
/// None is returned if we establish that no document can exceed the threshold.
fn find_pivot_doc(term_scorers: &[TermScorerWithMaxScore], threshold: f32) -> Option<(usize, usize, DocId)> {
let mut max_score = 0.0f32;
let mut before_pivot_len = 0;
let mut pivot_doc = TERMINATED;
while before_pivot_len < term_scorers.len() {
let term_scorer = &term_scorers[before_pivot_len];
max_score += term_scorer.max_score;
if max_score > threshold {
pivot_doc = term_scorer.doc();
break;
}
before_pivot_len += 1;
}
if pivot_doc == TERMINATED {
return None;
}
// Right now i is an ordinal, we want a len.
let mut pivot_len = before_pivot_len + 1;
// Some other term_scorer may be positioned on the same document.
pivot_len += term_scorers[pivot_len..].iter()
.take_while(|term_scorer| term_scorer.doc() == pivot_doc)
.count();
Some((before_pivot_len, pivot_len, pivot_doc))
}
struct TermScorerWithMaxScore<'a> {
scorer: &'a mut TermScorer,
max_score: f32,
}
impl<'a> From<&'a mut TermScorer> for TermScorerWithMaxScore<'a> {
fn from(scorer: &'a mut TermScorer) -> Self {
let max_score = scorer.max_score();
TermScorerWithMaxScore {
scorer,
max_score
}
}
}
impl<'a> Deref for TermScorerWithMaxScore<'a> {
type Target = TermScorer;
fn deref(&self) -> &Self::Target {
self.scorer
}
}
impl<'a> DerefMut for TermScorerWithMaxScore<'a> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.scorer
}
}
// Before and after calling this method, scorers need to be sorted by their `.doc()`.
fn block_max_was_too_low_advance_one_scorer(scorers: &mut Vec<TermScorerWithMaxScore>, pivot_len: usize) {
let mut scorer_to_seek = pivot_len - 1;
let mut doc_to_seek_after = scorers[scorer_to_seek].doc();
for scorer_ord in (0..pivot_len - 1).rev() {
let scorer = &scorers[scorer_ord];
if scorer.last_doc_in_block() <= doc_to_seek_after {
doc_to_seek_after = scorer.last_doc_in_block();
scorer_to_seek = scorer_ord;
}
}
for scorer in &scorers[pivot_len..] {
if scorer.doc() <= doc_to_seek_after {
doc_to_seek_after = scorer.doc();
}
}
scorers[scorer_to_seek].seek(doc_to_seek_after + 1);
restore_ordering(scorers, scorer_to_seek);
}
// Given a list of term_scorers and a `ord` and assuming that `term_scorers[ord]` is sorted
// except term_scorers[ord] that might be in advance compared to its ranks,
// bubble up term_scorers[ord] in order to restore the ordering.
fn restore_ordering(term_scorers: &mut Vec<TermScorerWithMaxScore>, ord: usize) {
let doc = term_scorers[ord].doc();
for i in ord + 1..term_scorers.len() {
if term_scorers[i].doc() >= doc {
break;
}
term_scorers.swap(i, i - 1);
}
}
// Attempts to advance all term_scorers between `&term_scorers[0..before_len]` to the pivot.
// If this works, return true.
// If this fails (ie: one of the term_scorer does not contain `pivot_doc` and seek goes past the
// pivot), reorder the term_scorers to ensure the list is still sorted and returns `false`.
// If a term_scorer reach TERMINATED in the process return false remove the term_scorer and return.
fn align_scorers(term_scorers: &mut Vec<TermScorerWithMaxScore>, pivot_doc: DocId, before_pivot_len: usize) -> bool {
debug_assert_ne!(pivot_doc, TERMINATED);
for i in (0..before_pivot_len).rev() {
let new_doc = term_scorers[i].seek(pivot_doc);
if new_doc != pivot_doc {
if new_doc == TERMINATED {
term_scorers.swap_remove(i);
}
// We went past the pivot.
// We just go through the outer loop mechanic (Note that pivot is
// still a possible candidate).
//
// Termination is still guaranteed since we can only consider the same
// pivot at most term_scorers.len() - 1 times.
restore_ordering(term_scorers, i);
return false;
}
}
return true;
}
// Assumes terms_scorers[..pivot_len] are positioned on the same doc (pivot_doc).
// Advance term_scorers[..pivot_len] and out of these removes the terminated scores.
// Restores the ordering of term_scorers.
fn advance_all_scorers_on_pivot(term_scorers: &mut Vec<TermScorerWithMaxScore>, pivot_len: usize) {
let mut i = 0;
for _ in 0..pivot_len {
if term_scorers[i].advance() == TERMINATED {
term_scorers.swap_remove(i);
} else {
i += 1;
}
}
term_scorers.sort_by_key(|scorer| scorer.doc());
}
pub fn block_wand(
mut scorers: Vec<TermScorer>,
mut threshold: f32,
callback: &mut dyn FnMut(u32, Score) -> Score,
) {
let mut scorers: Vec<TermScorerWithMaxScore> = scorers.iter_mut().map(TermScorerWithMaxScore::from).collect();
scorers.sort_by_key(|scorer| scorer.doc());
loop {
// At this point we need to ensure that the scorers are sorted!
if let Some((before_pivot_len, pivot_len, pivot_doc)) = find_pivot_doc(&scorers[..], threshold) {
debug_assert_ne!(pivot_doc, TERMINATED);
debug_assert!(before_pivot_len < pivot_len);
let block_max_score_upperbound: Score = scorers[..pivot_len].iter_mut()
.map(|scorer| {
scorer.shallow_seek(pivot_doc);
scorer.block_max_score()
})
.sum();
// Beware after shallow advance, skip readers can be in advance compared to
// the segment posting lists.
//
// `block_segment_postings.load_block()` need to be called separately.
if block_max_score_upperbound <= threshold {
// Block max condition was not reached.
// We could get away by simply advancing the scorers to DocId + 1 but it would
// be inefficient. The optimization requires proper explanation and was
// isolated in a different function.
block_max_was_too_low_advance_one_scorer(&mut scorers, pivot_len);
continue;
}
// Block max condition is observed.
//
// Let's try and advance all scorers before the pivot to the pivot.
if !align_scorers(&mut scorers, pivot_doc, before_pivot_len) {
// At least of the scorer does not contain the pivot.
//
// Let's stop scoring this pivot and go through the pivot selection again.
// Note that the current pivot is not necessarily a bad candidate and it
// may be picked again.
continue;
}
// At this point, all scorers are positioned on the doc.
let score = scorers[..pivot_len]
.iter_mut()
.map(|scorer| scorer.score())
.sum();
if score > threshold {
threshold = callback(pivot_doc, score);
}
// let's advance all of the scorers that are currently positioned on the pivot.
advance_all_scorers_on_pivot(&mut scorers, pivot_len);
} else {
return;
}
}
}

View File

@@ -1,7 +1,9 @@
use crate::core::SegmentReader;
use crate::postings::FreqReadingOption;
use crate::query::explanation::does_not_match;
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner, SumWithCoordsCombiner};
use crate::query::term_query::TermScorer;
use crate::query::weight::{for_each_pruning_scorer, for_each_scorer};
use crate::query::EmptyScorer;
use crate::query::Exclude;
use crate::query::Occur;
@@ -10,16 +12,21 @@ use crate::query::Scorer;
use crate::query::Union;
use crate::query::Weight;
use crate::query::{intersect_scorers, Explanation};
use crate::{DocId, SkipResult};
use crate::{DocId, Score};
use std::collections::HashMap;
fn scorer_union<TScoreCombiner>(scorers: Vec<Box<dyn Scorer>>) -> Box<dyn Scorer>
enum SpecializedScorer {
TermUnion(Vec<TermScorer>),
Other(Box<dyn Scorer>),
}
fn scorer_union<TScoreCombiner>(scorers: Vec<Box<dyn Scorer>>) -> SpecializedScorer
where
TScoreCombiner: ScoreCombiner,
{
assert!(!scorers.is_empty());
if scorers.len() == 1 {
return scorers.into_iter().next().unwrap(); //< we checked the size beforehands
return SpecializedScorer::Other(scorers.into_iter().next().unwrap()); //< we checked the size beforehands
}
{
@@ -29,14 +36,30 @@ where
.into_iter()
.map(|scorer| *(scorer.downcast::<TermScorer>().map_err(|_| ()).unwrap()))
.collect();
let scorer: Box<dyn Scorer> =
Box::new(Union::<TermScorer, TScoreCombiner>::from(scorers));
return scorer;
if scorers
.iter()
.all(|scorer| scorer.freq_reading_option() == FreqReadingOption::ReadFreq)
{
// Block wand is only available iff we read frequencies.
return SpecializedScorer::TermUnion(scorers);
} else {
return SpecializedScorer::Other(Box::new(Union::<_, TScoreCombiner>::from(
scorers,
)));
}
}
}
SpecializedScorer::Other(Box::new(Union::<_, TScoreCombiner>::from(scorers)))
}
let scorer: Box<dyn Scorer> = Box::new(Union::<_, TScoreCombiner>::from(scorers));
scorer
fn into_box_scorer<TScoreCombiner: ScoreCombiner>(scorer: SpecializedScorer) -> Box<dyn Scorer> {
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let union_scorer = Union::<TermScorer, TScoreCombiner>::from(term_scorers);
Box::new(union_scorer)
}
SpecializedScorer::Other(scorer) => scorer,
}
}
pub struct BooleanWeight {
@@ -55,10 +78,11 @@ impl BooleanWeight {
fn per_occur_scorers(
&self,
reader: &SegmentReader,
boost: f32,
) -> crate::Result<HashMap<Occur, Vec<Box<dyn Scorer>>>> {
let mut per_occur_scorers: HashMap<Occur, Vec<Box<dyn Scorer>>> = HashMap::new();
for &(ref occur, ref subweight) in &self.weights {
let sub_scorer: Box<dyn Scorer> = subweight.scorer(reader)?;
let sub_scorer: Box<dyn Scorer> = subweight.scorer(reader, boost)?;
per_occur_scorers
.entry(*occur)
.or_insert_with(Vec::new)
@@ -70,41 +94,52 @@ impl BooleanWeight {
fn complex_scorer<TScoreCombiner: ScoreCombiner>(
&self,
reader: &SegmentReader,
) -> crate::Result<Box<dyn Scorer>> {
let mut per_occur_scorers = self.per_occur_scorers(reader)?;
boost: f32,
) -> crate::Result<SpecializedScorer> {
let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
let should_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
let should_scorer_opt: Option<SpecializedScorer> = per_occur_scorers
.remove(&Occur::Should)
.map(scorer_union::<TScoreCombiner>);
let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::MustNot)
.map(scorer_union::<TScoreCombiner>);
.map(scorer_union::<DoNothingCombiner>)
.map(into_box_scorer::<DoNothingCombiner>);
let must_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::Must)
.map(intersect_scorers);
let positive_scorer: Box<dyn Scorer> = match (should_scorer_opt, must_scorer_opt) {
let positive_scorer: SpecializedScorer = match (should_scorer_opt, must_scorer_opt) {
(Some(should_scorer), Some(must_scorer)) => {
if self.scoring_enabled {
Box::new(RequiredOptionalScorer::<_, _, TScoreCombiner>::new(
SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
Box<dyn Scorer>,
Box<dyn Scorer>,
TScoreCombiner,
>::new(
must_scorer,
should_scorer,
))
into_box_scorer::<TScoreCombiner>(should_scorer),
)))
} else {
must_scorer
SpecializedScorer::Other(must_scorer)
}
}
(None, Some(must_scorer)) => must_scorer,
(None, Some(must_scorer)) => SpecializedScorer::Other(must_scorer),
(Some(should_scorer), None) => should_scorer,
(None, None) => {
return Ok(Box::new(EmptyScorer));
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
}
};
if let Some(exclude_scorer) = exclude_scorer_opt {
Ok(Box::new(Exclude::new(positive_scorer, exclude_scorer)))
let positive_scorer_boxed: Box<dyn Scorer> =
into_box_scorer::<TScoreCombiner>(positive_scorer);
Ok(SpecializedScorer::Other(Box::new(Exclude::new(
positive_scorer_boxed,
exclude_scorer,
))))
} else {
Ok(positive_scorer)
}
@@ -112,7 +147,7 @@ impl BooleanWeight {
}
impl Weight for BooleanWeight {
fn scorer(&self, reader: &SegmentReader) -> crate::Result<Box<dyn Scorer>> {
fn scorer(&self, reader: &SegmentReader, boost: f32) -> crate::Result<Box<dyn Scorer>> {
if self.weights.is_empty() {
Ok(Box::new(EmptyScorer))
} else if self.weights.len() == 1 {
@@ -120,18 +155,22 @@ impl Weight for BooleanWeight {
if occur == Occur::MustNot {
Ok(Box::new(EmptyScorer))
} else {
weight.scorer(reader)
weight.scorer(reader, boost)
}
} else if self.scoring_enabled {
self.complex_scorer::<SumWithCoordsCombiner>(reader)
self.complex_scorer::<SumWithCoordsCombiner>(reader, boost)
.map(|specialized_scorer| {
into_box_scorer::<SumWithCoordsCombiner>(specialized_scorer)
})
} else {
self.complex_scorer::<DoNothingCombiner>(reader)
self.complex_scorer::<DoNothingCombiner>(reader, boost)
.map(into_box_scorer::<DoNothingCombiner>)
}
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader)?;
if scorer.skip_next(doc) != SkipResult::Reached {
let mut scorer = self.scorer(reader, 1.0f32)?;
if scorer.seek(doc) != doc {
return Err(does_not_match(doc));
}
if !self.scoring_enabled {
@@ -148,6 +187,53 @@ impl Weight for BooleanWeight {
}
Ok(explanation)
}
fn for_each(
&self,
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score),
) -> crate::Result<()> {
let scorer = self.complex_scorer::<SumWithCoordsCombiner>(reader, 1.0f32)?;
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let mut union_scorer =
Union::<TermScorer, SumWithCoordsCombiner>::from(term_scorers);
for_each_scorer(&mut union_scorer, callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_scorer(scorer.as_mut(), callback);
}
}
Ok(())
}
/// Calls `callback` with all of the `(doc, score)` for which score
/// is exceeding a given threshold.
///
/// This method is useful for the TopDocs collector.
/// For all docsets, the blanket implementation has the benefit
/// of prefiltering (doc, score) pairs, avoiding the
/// virtual dispatch cost.
///
/// More importantly, it makes it possible for scorers to implement
/// important optimization (e.g. BlockWAND for union).
fn for_each_pruning(
&self,
threshold: f32,
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score) -> Score,
) -> crate::Result<()> {
let scorer = self.complex_scorer::<SumWithCoordsCombiner>(reader, 1.0f32)?;
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
super::block_wand(term_scorers, threshold, callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_pruning_scorer(scorer.as_mut(), threshold, callback);
}
}
Ok(())
}
}
fn is_positive_occur(occur: Occur) -> bool {

View File

@@ -1,13 +1,17 @@
mod block_wand;
mod boolean_query;
mod boolean_weight;
pub(crate) use self::block_wand::block_wand;
pub use self::boolean_query::BooleanQuery;
#[cfg(test)]
mod tests {
use super::*;
use crate::assert_nearly_equals;
use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE;
use crate::collector::TopDocs;
use crate::query::score_combiner::SumWithCoordsCombiner;
use crate::query::term_query::TermScorer;
use crate::query::Intersection;
@@ -19,7 +23,7 @@ mod tests {
use crate::query::TermQuery;
use crate::schema::*;
use crate::Index;
use crate::{DocAddress, DocId};
use crate::{DocAddress, DocId, Score};
fn aux_test_helper() -> (Index, Field) {
let mut schema_builder = Schema::builder();
@@ -30,24 +34,11 @@ mod tests {
// writing the segment
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
{
let doc = doc!(text_field => "a b c");
index_writer.add_document(doc);
}
{
let doc = doc!(text_field => "a c");
index_writer.add_document(doc);
}
{
let doc = doc!(text_field => "b c");
index_writer.add_document(doc);
}
{
let doc = doc!(text_field => "a b c d");
index_writer.add_document(doc);
}
{
let doc = doc!(text_field => "d");
index_writer.add_document(doc);
index_writer.add_document(doc!(text_field => "a b c"));
index_writer.add_document(doc!(text_field => "a c"));
index_writer.add_document(doc!(text_field => "b c"));
index_writer.add_document(doc!(text_field => "a b c d"));
index_writer.add_document(doc!(text_field => "d"));
}
assert!(index_writer.commit().is_ok());
}
@@ -70,7 +61,9 @@ mod tests {
let query = query_parser.parse_query("+a").unwrap();
let searcher = index.reader().unwrap().searcher();
let weight = query.weight(&searcher, true).unwrap();
let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
let scorer = weight
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap();
assert!(scorer.is::<TermScorer>());
}
@@ -82,13 +75,17 @@ mod tests {
{
let query = query_parser.parse_query("+a +b +c").unwrap();
let weight = query.weight(&searcher, true).unwrap();
let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
let scorer = weight
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap();
assert!(scorer.is::<Intersection<TermScorer>>());
}
{
let query = query_parser.parse_query("+a +(b c)").unwrap();
let weight = query.weight(&searcher, true).unwrap();
let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
let scorer = weight
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap();
assert!(scorer.is::<Intersection<Box<dyn Scorer>>>());
}
}
@@ -101,7 +98,9 @@ mod tests {
{
let query = query_parser.parse_query("+a b").unwrap();
let weight = query.weight(&searcher, true).unwrap();
let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
let scorer = weight
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap();
assert!(scorer.is::<RequiredOptionalScorer<
Box<dyn Scorer>,
Box<dyn Scorer>,
@@ -111,7 +110,9 @@ mod tests {
{
let query = query_parser.parse_query("+a b").unwrap();
let weight = query.weight(&searcher, false).unwrap();
let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
let scorer = weight
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap();
assert!(scorer.is::<TermScorer>());
}
}
@@ -142,7 +143,6 @@ mod tests {
.map(|doc| doc.1)
.collect::<Vec<DocId>>()
};
{
let boolean_query = BooleanQuery::from(vec![(Occur::Must, make_term_query("a"))]);
assert_eq!(matching_docs(&boolean_query), vec![0, 1, 3]);
@@ -179,6 +179,96 @@ mod tests {
}
}
#[test]
pub fn test_boolean_query_two_excluded() {
let (index, text_field) = aux_test_helper();
let make_term_query = |text: &str| {
let term_query = TermQuery::new(
Term::from_field_text(text_field, text),
IndexRecordOption::Basic,
);
let query: Box<dyn Query> = Box::new(term_query);
query
};
let reader = index.reader().unwrap();
let matching_topdocs = |query: &dyn Query| {
reader
.searcher()
.search(query, &TopDocs::with_limit(3))
.unwrap()
};
let score_doc_4: Score; // score of doc 4 should not be influenced by exclusion
{
let boolean_query_no_excluded =
BooleanQuery::from(vec![(Occur::Must, make_term_query("d"))]);
let topdocs_no_excluded = matching_topdocs(&boolean_query_no_excluded);
assert_eq!(topdocs_no_excluded.len(), 2);
let (top_score, top_doc) = topdocs_no_excluded[0];
assert_eq!(top_doc, DocAddress(0, 4));
assert_eq!(topdocs_no_excluded[1].1, DocAddress(0, 3)); // ignore score of doc 3.
score_doc_4 = top_score;
}
{
let boolean_query_two_excluded = BooleanQuery::from(vec![
(Occur::Must, make_term_query("d")),
(Occur::MustNot, make_term_query("a")),
(Occur::MustNot, make_term_query("b")),
]);
let topdocs_excluded = matching_topdocs(&boolean_query_two_excluded);
assert_eq!(topdocs_excluded.len(), 1);
let (top_score, top_doc) = topdocs_excluded[0];
assert_eq!(top_doc, DocAddress(0, 4));
assert_eq!(top_score, score_doc_4);
}
}
#[test]
pub fn test_boolean_query_with_weight() {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
index_writer.add_document(doc!(text_field => "a b c"));
index_writer.add_document(doc!(text_field => "a c"));
index_writer.add_document(doc!(text_field => "b c"));
assert!(index_writer.commit().is_ok());
}
let term_a: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "a"),
IndexRecordOption::WithFreqs,
));
let term_b: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "b"),
IndexRecordOption::WithFreqs,
));
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let boolean_query =
BooleanQuery::from(vec![(Occur::Should, term_a), (Occur::Should, term_b)]);
let boolean_weight = boolean_query.weight(&searcher, true).unwrap();
{
let mut boolean_scorer = boolean_weight
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap();
assert_eq!(boolean_scorer.doc(), 0u32);
assert_nearly_equals!(boolean_scorer.score(), 0.84163445f32);
}
{
let mut boolean_scorer = boolean_weight
.scorer(searcher.segment_reader(0u32), 2.0f32)
.unwrap();
assert_eq!(boolean_scorer.doc(), 0u32);
assert_nearly_equals!(boolean_scorer.score(), 1.6832689f32);
}
}
#[test]
pub fn test_intersection_score() {
let (index, text_field) = aux_test_helper();
@@ -234,7 +324,7 @@ mod tests {
index_writer.add_document(doc!(
// tf = 1 1
title => "PDF Мастер Класс \"Морячок\" (Оксана Лифенко)",
// tf = 0 0
// tf = 0 0
text => "https://i.ibb.co/pzvHrDN/I3d U T6 Gg TM.jpg\nhttps://i.ibb.co/NFrb6v6/N0ls Z9nwjb U.jpg\nВ описание входит штаны, кофта, берет, матросский воротник. Описание продается в формате PDF, состоит из 12 страниц формата А4 и может быть напечатано на любом принтере.\nОписание предназначено для кукол BJD RealPuki от FairyLand, но может подойти и другим подобным куклам. Также вы можете вязать этот наряд из обычной пряжи, и он подойдет для куколок побольше.\nhttps://vk.com/market 95724412?w=product 95724412_2212"
));
for _ in 0..1_000 {
@@ -249,7 +339,9 @@ mod tests {
let query_parser = QueryParser::for_index(&index, vec![title, text]);
let query = query_parser.parse_query("Оксана Лифенко").unwrap();
let weight = query.weight(&searcher, true).unwrap();
let mut scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
let mut scorer = weight
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap();
scorer.advance();
let explanation = query.explain(&searcher, DocAddress(0u32, 0u32)).unwrap();

159
src/query/boost_query.rs Normal file
View File

@@ -0,0 +1,159 @@
use crate::fastfield::DeleteBitSet;
use crate::query::explanation::does_not_match;
use crate::query::{Explanation, Query, Scorer, Weight};
use crate::{DocId, DocSet, Searcher, SegmentReader, Term};
use std::collections::BTreeSet;
use std::fmt;
/// `BoostQuery` is a wrapper over a query used to boost its score.
///
/// The document set matched by the `BoostQuery` is strictly the same as the underlying query.
/// The score of each document, is the score of the underlying query multiplied by the `boost`
/// factor.
pub struct BoostQuery {
query: Box<dyn Query>,
boost: f32,
}
impl BoostQuery {
/// Builds a boost query.
pub fn new(query: Box<dyn Query>, boost: f32) -> BoostQuery {
BoostQuery { query, boost }
}
}
impl Clone for BoostQuery {
fn clone(&self) -> Self {
BoostQuery {
query: self.query.box_clone(),
boost: self.boost,
}
}
}
impl fmt::Debug for BoostQuery {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Boost(query={:?}, boost={})", self.query, self.boost)
}
}
impl Query for BoostQuery {
fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> crate::Result<Box<dyn Weight>> {
let weight_without_boost = self.query.weight(searcher, scoring_enabled)?;
let boosted_weight = if scoring_enabled {
Box::new(BoostWeight::new(weight_without_boost, self.boost))
} else {
weight_without_boost
};
Ok(boosted_weight)
}
fn query_terms(&self, term_set: &mut BTreeSet<Term>) {
self.query.query_terms(term_set)
}
}
pub(crate) struct BoostWeight {
weight: Box<dyn Weight>,
boost: f32,
}
impl BoostWeight {
pub fn new(weight: Box<dyn Weight>, boost: f32) -> Self {
BoostWeight { weight, boost }
}
}
impl Weight for BoostWeight {
fn scorer(&self, reader: &SegmentReader, boost: f32) -> crate::Result<Box<dyn Scorer>> {
self.weight.scorer(reader, boost * self.boost)
}
fn explain(&self, reader: &SegmentReader, doc: u32) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0f32)?;
if scorer.seek(doc) != doc {
return Err(does_not_match(doc));
}
let mut explanation =
Explanation::new(format!("Boost x{} of ...", self.boost), scorer.score());
let underlying_explanation = self.weight.explain(reader, doc)?;
explanation.add_detail(underlying_explanation);
Ok(explanation)
}
fn count(&self, reader: &SegmentReader) -> crate::Result<u32> {
self.weight.count(reader)
}
}
pub(crate) struct BoostScorer<S: Scorer> {
underlying: S,
boost: f32,
}
impl<S: Scorer> BoostScorer<S> {
pub fn new(underlying: S, boost: f32) -> BoostScorer<S> {
BoostScorer { underlying, boost }
}
}
impl<S: Scorer> DocSet for BoostScorer<S> {
fn advance(&mut self) -> DocId {
self.underlying.advance()
}
fn seek(&mut self, target: DocId) -> DocId {
self.underlying.seek(target)
}
fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize {
self.underlying.fill_buffer(buffer)
}
fn doc(&self) -> u32 {
self.underlying.doc()
}
fn size_hint(&self) -> u32 {
self.underlying.size_hint()
}
fn count(&mut self, delete_bitset: &DeleteBitSet) -> u32 {
self.underlying.count(delete_bitset)
}
fn count_including_deleted(&mut self) -> u32 {
self.underlying.count_including_deleted()
}
}
impl<S: Scorer> Scorer for BoostScorer<S> {
fn score(&mut self) -> f32 {
self.underlying.score() * self.boost
}
}
#[cfg(test)]
mod tests {
use super::BoostQuery;
use crate::query::{AllQuery, Query};
use crate::schema::Schema;
use crate::{DocAddress, Document, Index};
#[test]
fn test_boost_query_explain() {
let schema = Schema::builder().build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
index_writer.add_document(Document::new());
assert!(index_writer.commit().is_ok());
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let query = BoostQuery::new(Box::new(AllQuery), 0.2);
let explanation = query.explain(&searcher, DocAddress(0, 0u32)).unwrap();
assert_eq!(
explanation.to_pretty_json(),
"{\n \"value\": 0.2,\n \"description\": \"Boost x0.2 of ...\",\n \"details\": [\n {\n \"value\": 1.0,\n \"description\": \"AllQuery\"\n }\n ]\n}"
)
}
}

View File

@@ -1,4 +1,5 @@
use super::Scorer;
use crate::docset::TERMINATED;
use crate::query::explanation::does_not_match;
use crate::query::Weight;
use crate::query::{Explanation, Query};
@@ -33,7 +34,7 @@ impl Query for EmptyQuery {
/// It is useful for tests and handling edge cases.
pub struct EmptyWeight;
impl Weight for EmptyWeight {
fn scorer(&self, _reader: &SegmentReader) -> crate::Result<Box<dyn Scorer>> {
fn scorer(&self, _reader: &SegmentReader, _boost: f32) -> crate::Result<Box<dyn Scorer>> {
Ok(Box::new(EmptyScorer))
}
@@ -48,15 +49,12 @@ impl Weight for EmptyWeight {
pub struct EmptyScorer;
impl DocSet for EmptyScorer {
fn advance(&mut self) -> bool {
false
fn advance(&mut self) -> DocId {
TERMINATED
}
fn doc(&self) -> DocId {
panic!(
"You may not call .doc() on a scorer \
where the last call to advance() did not return true."
);
TERMINATED
}
fn size_hint(&self) -> u32 {
@@ -72,18 +70,15 @@ impl Scorer for EmptyScorer {
#[cfg(test)]
mod tests {
use crate::docset::TERMINATED;
use crate::query::EmptyScorer;
use crate::DocSet;
#[test]
fn test_empty_scorer() {
let mut empty_scorer = EmptyScorer;
assert!(!empty_scorer.advance());
}
#[test]
#[should_panic]
fn test_empty_scorer_panic_on_doc_call() {
EmptyScorer.doc();
assert_eq!(empty_scorer.doc(), TERMINATED);
assert_eq!(empty_scorer.advance(), TERMINATED);
assert_eq!(empty_scorer.doc(), TERMINATED);
}
}

View File

@@ -1,12 +1,11 @@
use crate::docset::{DocSet, SkipResult};
use crate::docset::{DocSet, TERMINATED};
use crate::query::Scorer;
use crate::DocId;
use crate::Score;
#[derive(Clone, Copy, Debug)]
enum State {
ExcludeOne(DocId),
Finished,
#[inline(always)]
fn is_within<TDocSetExclude: DocSet>(docset: &mut TDocSetExclude, doc: DocId) -> bool {
docset.doc() <= doc && docset.seek(doc) == doc
}
/// Filters a given `DocSet` by removing the docs from a given `DocSet`.
@@ -15,29 +14,6 @@ enum State {
pub struct Exclude<TDocSet, TDocSetExclude> {
underlying_docset: TDocSet,
excluding_docset: TDocSetExclude,
excluding_state: State,
}
impl<TDocSet, TDocSetExclude> Exclude<TDocSet, TDocSetExclude>
where
TDocSetExclude: DocSet,
{
/// Creates a new `ExcludeScorer`
pub fn new(
underlying_docset: TDocSet,
mut excluding_docset: TDocSetExclude,
) -> Exclude<TDocSet, TDocSetExclude> {
let state = if excluding_docset.advance() {
State::ExcludeOne(excluding_docset.doc())
} else {
State::Finished
};
Exclude {
underlying_docset,
excluding_docset,
excluding_state: state,
}
}
}
impl<TDocSet, TDocSetExclude> Exclude<TDocSet, TDocSetExclude>
@@ -45,33 +21,21 @@ where
TDocSet: DocSet,
TDocSetExclude: DocSet,
{
/// Returns true iff the doc is not removed.
///
/// The method has to be called with non strictly
/// increasing `doc`.
fn accept(&mut self) -> bool {
let doc = self.underlying_docset.doc();
match self.excluding_state {
State::ExcludeOne(excluded_doc) => {
if doc == excluded_doc {
return false;
}
if excluded_doc > doc {
return true;
}
match self.excluding_docset.skip_next(doc) {
SkipResult::OverStep => {
self.excluding_state = State::ExcludeOne(self.excluding_docset.doc());
true
}
SkipResult::End => {
self.excluding_state = State::Finished;
true
}
SkipResult::Reached => false,
}
/// Creates a new `ExcludeScorer`
pub fn new(
mut underlying_docset: TDocSet,
mut excluding_docset: TDocSetExclude,
) -> Exclude<TDocSet, TDocSetExclude> {
while underlying_docset.doc() != TERMINATED {
let target = underlying_docset.doc();
if !is_within(&mut excluding_docset, target) {
break;
}
State::Finished => true,
underlying_docset.advance();
}
Exclude {
underlying_docset,
excluding_docset,
}
}
}
@@ -81,27 +45,27 @@ where
TDocSet: DocSet,
TDocSetExclude: DocSet,
{
fn advance(&mut self) -> bool {
while self.underlying_docset.advance() {
if self.accept() {
return true;
fn advance(&mut self) -> DocId {
loop {
let candidate = self.underlying_docset.advance();
if candidate == TERMINATED {
return TERMINATED;
}
if !is_within(&mut self.excluding_docset, candidate) {
return candidate;
}
}
false
}
fn skip_next(&mut self, target: DocId) -> SkipResult {
let underlying_skip_result = self.underlying_docset.skip_next(target);
if underlying_skip_result == SkipResult::End {
return SkipResult::End;
fn seek(&mut self, target: DocId) -> DocId {
let candidate = self.underlying_docset.seek(target);
if candidate == TERMINATED {
return TERMINATED;
}
if self.accept() {
underlying_skip_result
} else if self.advance() {
SkipResult::OverStep
} else {
SkipResult::End
if !is_within(&mut self.excluding_docset, candidate) {
return candidate;
}
self.advance()
}
fn doc(&self) -> DocId {
@@ -141,8 +105,9 @@ mod tests {
VecDocSet::from(vec![1, 2, 3, 10, 16, 24]),
);
let mut els = vec![];
while exclude_scorer.advance() {
while exclude_scorer.doc() != TERMINATED {
els.push(exclude_scorer.doc());
exclude_scorer.advance();
}
assert_eq!(els, vec![5, 8, 15]);
}
@@ -156,7 +121,7 @@ mod tests {
VecDocSet::from(vec![1, 2, 3, 10, 16, 24]),
))
},
vec![1, 2, 5, 8, 10, 15, 24],
vec![5, 8, 10, 15, 24],
);
}

View File

@@ -1,4 +1,6 @@
use crate::{DocId, TantivyError};
use serde::Serialize;
use std::fmt;
pub(crate) fn does_not_match(doc: DocId) -> TantivyError {
TantivyError::InvalidArgument(format!("Document #({}) does not match", doc))
@@ -17,6 +19,12 @@ pub struct Explanation {
details: Vec<Explanation>,
}
impl fmt::Debug for Explanation {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Explanation({})", self.to_pretty_json())
}
}
impl Explanation {
/// Creates a new explanation object.
pub fn new<T: ToString>(description: T, value: f32) -> Explanation {

View File

@@ -2,14 +2,40 @@ use crate::query::{AutomatonWeight, Query, Weight};
use crate::schema::Term;
use crate::Searcher;
use crate::TantivyError::InvalidArgument;
use levenshtein_automata::{LevenshteinAutomatonBuilder, DFA};
use levenshtein_automata::{Distance, LevenshteinAutomatonBuilder, DFA};
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::ops::Range;
use tantivy_fst::Automaton;
pub(crate) struct DFAWrapper(pub DFA);
impl Automaton for DFAWrapper {
type State = u32;
fn start(&self) -> Self::State {
self.0.initial_state()
}
fn is_match(&self, state: &Self::State) -> bool {
match self.0.distance(*state) {
Distance::Exact(_) => true,
Distance::AtLeast(_) => false,
}
}
fn can_match(&self, state: &u32) -> bool {
*state != levenshtein_automata::SINK_STATE
}
fn accept(&self, state: &Self::State, byte: u8) -> Self::State {
self.0.transition(*state, byte)
}
}
/// A range of Levenshtein distances that we will build DFAs for our terms
/// The computation is exponential, so best keep it to low single digits
const VALID_LEVENSHTEIN_DISTANCE_RANGE: Range<u8> = (0..3);
const VALID_LEVENSHTEIN_DISTANCE_RANGE: Range<u8> = 0..3;
static LEV_BUILDER: Lazy<HashMap<(u8, bool), LevenshteinAutomatonBuilder>> = Lazy::new(|| {
let mut lev_builder_cache = HashMap::new();
@@ -91,7 +117,7 @@ impl FuzzyTermQuery {
}
}
/// Creates a new Fuzzy Query that treats transpositions as cost one rather than two
/// Creates a new Fuzzy Query of the Term prefix
pub fn new_prefix(term: Term, distance: u8, transposition_cost_one: bool) -> FuzzyTermQuery {
FuzzyTermQuery {
term,
@@ -101,13 +127,20 @@ impl FuzzyTermQuery {
}
}
fn specialized_weight(&self) -> crate::Result<AutomatonWeight<DFA>> {
fn specialized_weight(&self) -> crate::Result<AutomatonWeight<DFAWrapper>> {
// LEV_BUILDER is a HashMap, whose `get` method returns an Option
match LEV_BUILDER.get(&(self.distance, false)) {
// Unwrap the option and build the Ok(AutomatonWeight)
Some(automaton_builder) => {
let automaton = automaton_builder.build_dfa(self.term.text());
Ok(AutomatonWeight::new(self.term.field(), automaton))
let automaton = if self.prefix {
automaton_builder.build_prefix_dfa(self.term.text())
} else {
automaton_builder.build_dfa(self.term.text())
};
Ok(AutomatonWeight::new(
self.term.field(),
DFAWrapper(automaton),
))
}
None => Err(InvalidArgument(format!(
"Levenshtein distance of {} is not allowed. Choose a value in the {:?} range",
@@ -130,10 +163,10 @@ impl Query for FuzzyTermQuery {
#[cfg(test)]
mod test {
use super::FuzzyTermQuery;
use crate::assert_nearly_equals;
use crate::collector::TopDocs;
use crate::schema::Schema;
use crate::schema::TEXT;
use crate::tests::assert_nearly_equals;
use crate::Index;
use crate::Term;
@@ -155,6 +188,8 @@ mod test {
}
let reader = index.reader().unwrap();
let searcher = reader.searcher();
// passes because Levenshtein distance is 1 (substitute 'o' with 'a')
{
let term = Term::from_field_text(country_field, "japon");
@@ -164,7 +199,31 @@ mod test {
.unwrap();
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
let (score, _) = top_docs[0];
assert_nearly_equals(1f32, score);
assert_nearly_equals!(1f32, score);
}
// fails because non-prefix Levenshtein distance is more than 1 (add 'a' and 'n')
{
let term = Term::from_field_text(country_field, "jap");
let fuzzy_query = FuzzyTermQuery::new(term, 1, true);
let top_docs = searcher
.search(&fuzzy_query, &TopDocs::with_limit(2))
.unwrap();
assert_eq!(top_docs.len(), 0, "Expected no document");
}
// passes because prefix Levenshtein distance is 0
{
let term = Term::from_field_text(country_field, "jap");
let fuzzy_query = FuzzyTermQuery::new_prefix(term, 1, true);
let top_docs = searcher
.search(&fuzzy_query, &TopDocs::with_limit(2))
.unwrap();
assert_eq!(top_docs.len(), 1, "Expected only 1 document");
let (score, _) = top_docs[0];
assert_nearly_equals!(1f32, score);
}
}
}

View File

@@ -1,4 +1,4 @@
use crate::docset::{DocSet, SkipResult};
use crate::docset::{DocSet, TERMINATED};
use crate::query::term_query::TermScorer;
use crate::query::EmptyScorer;
use crate::query::Scorer;
@@ -20,12 +20,14 @@ pub fn intersect_scorers(mut scorers: Vec<Box<dyn Scorer>>) -> Box<dyn Scorer> {
if scorers.len() == 1 {
return scorers.pop().unwrap();
}
scorers.sort_by_key(|scorer| scorer.size_hint());
let doc = go_to_first_doc(&mut scorers[..]);
if doc == TERMINATED {
return Box::new(EmptyScorer);
}
// We know that we have at least 2 elements.
let num_docsets = scorers.len();
scorers.sort_by(|left, right| right.size_hint().cmp(&left.size_hint()));
let left = scorers.pop().unwrap();
let right = scorers.pop().unwrap();
scorers.reverse();
let left = scorers.remove(0);
let right = scorers.remove(0);
let all_term_scorers = [&left, &right]
.iter()
.all(|&scorer| scorer.is::<TermScorer>());
@@ -34,14 +36,12 @@ pub fn intersect_scorers(mut scorers: Vec<Box<dyn Scorer>>) -> Box<dyn Scorer> {
left: *(left.downcast::<TermScorer>().map_err(|_| ()).unwrap()),
right: *(right.downcast::<TermScorer>().map_err(|_| ()).unwrap()),
others: scorers,
num_docsets,
});
}
Box::new(Intersection {
left,
right,
others: scorers,
num_docsets,
})
}
@@ -50,22 +50,35 @@ pub struct Intersection<TDocSet: DocSet, TOtherDocSet: DocSet = Box<dyn Scorer>>
left: TDocSet,
right: TDocSet,
others: Vec<TOtherDocSet>,
num_docsets: usize,
}
fn go_to_first_doc<TDocSet: DocSet>(docsets: &mut [TDocSet]) -> DocId {
assert!(!docsets.is_empty());
let mut candidate = docsets.iter().map(TDocSet::doc).max().unwrap();
'outer: loop {
for docset in docsets.iter_mut() {
let seek_doc = docset.seek(candidate);
if seek_doc > candidate {
candidate = docset.doc();
continue 'outer;
}
}
return candidate;
}
}
impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
pub(crate) fn new(mut docsets: Vec<TDocSet>) -> Intersection<TDocSet, TDocSet> {
let num_docsets = docsets.len();
assert!(num_docsets >= 2);
docsets.sort_by(|left, right| right.size_hint().cmp(&left.size_hint()));
let left = docsets.pop().unwrap();
let right = docsets.pop().unwrap();
docsets.reverse();
docsets.sort_by_key(|docset| docset.size_hint());
go_to_first_doc(&mut docsets);
let left = docsets.remove(0);
let right = docsets.remove(0);
Intersection {
left,
right,
others: docsets,
num_docsets,
}
}
}
@@ -80,128 +93,49 @@ impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
}
}
impl<TDocSet: DocSet, TOtherDocSet: DocSet> Intersection<TDocSet, TOtherDocSet> {
pub(crate) fn docset_mut(&mut self, ord: usize) -> &mut dyn DocSet {
match ord {
0 => &mut self.left,
1 => &mut self.right,
n => &mut self.others[n - 2],
}
}
}
impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOtherDocSet> {
fn advance(&mut self) -> bool {
fn advance(&mut self) -> DocId {
let (left, right) = (&mut self.left, &mut self.right);
if !left.advance() {
return false;
}
let mut candidate = left.doc();
let mut other_candidate_ord: usize = usize::max_value();
let mut candidate = left.advance();
'outer: loop {
// In the first part we look for a document in the intersection
// of the two rarest `DocSet` in the intersection.
loop {
match right.skip_next(candidate) {
SkipResult::Reached => {
break;
}
SkipResult::OverStep => {
candidate = right.doc();
other_candidate_ord = usize::max_value();
}
SkipResult::End => {
return false;
}
}
match left.skip_next(candidate) {
SkipResult::Reached => {
break;
}
SkipResult::OverStep => {
candidate = left.doc();
other_candidate_ord = usize::max_value();
}
SkipResult::End => {
return false;
}
let right_doc = right.seek(candidate);
candidate = left.seek(right_doc);
if candidate == right_doc {
break;
}
}
debug_assert_eq!(left.doc(), right.doc());
// test the remaining scorers;
for (ord, docset) in self.others.iter_mut().enumerate() {
if ord == other_candidate_ord {
continue;
}
// `candidate_ord` is already at the
// right position.
//
// Calling `skip_next` would advance this docset
// and miss it.
match docset.skip_next(candidate) {
SkipResult::Reached => {}
SkipResult::OverStep => {
// this is not in the intersection,
// let's update our candidate.
candidate = docset.doc();
match left.skip_next(candidate) {
SkipResult::Reached => {
other_candidate_ord = ord;
}
SkipResult::OverStep => {
candidate = left.doc();
other_candidate_ord = usize::max_value();
}
SkipResult::End => {
return false;
}
}
continue 'outer;
}
SkipResult::End => {
return false;
}
for docset in self.others.iter_mut() {
let seek_doc = docset.seek(candidate);
if seek_doc > candidate {
candidate = left.seek(seek_doc);
continue 'outer;
}
}
return true;
debug_assert_eq!(candidate, self.left.doc());
debug_assert_eq!(candidate, self.right.doc());
debug_assert!(self.others.iter().all(|docset| docset.doc() == candidate));
return candidate;
}
}
fn skip_next(&mut self, target: DocId) -> SkipResult {
// We optimize skipping by skipping every single member
// of the intersection to target.
let mut current_target: DocId = target;
let mut current_ord = self.num_docsets;
'outer: loop {
for ord in 0..self.num_docsets {
let docset = self.docset_mut(ord);
if ord == current_ord {
continue;
}
match docset.skip_next(current_target) {
SkipResult::End => {
return SkipResult::End;
}
SkipResult::OverStep => {
// update the target
// for the remaining members of the intersection.
current_target = docset.doc();
current_ord = ord;
continue 'outer;
}
SkipResult::Reached => {}
}
}
if target == current_target {
return SkipResult::Reached;
} else {
assert!(current_target > target);
return SkipResult::OverStep;
}
fn seek(&mut self, target: DocId) -> DocId {
self.left.seek(target);
let mut docsets: Vec<&mut dyn DocSet> = vec![&mut self.left, &mut self.right];
for docset in &mut self.others {
docsets.push(docset);
}
let doc = go_to_first_doc(&mut docsets[..]);
debug_assert!(docsets.iter().all(|docset| docset.doc() == doc));
debug_assert!(doc >= target);
doc
}
fn doc(&self) -> DocId {
@@ -228,7 +162,7 @@ where
#[cfg(test)]
mod tests {
use super::Intersection;
use crate::docset::{DocSet, SkipResult};
use crate::docset::{DocSet, TERMINATED};
use crate::postings::tests::test_skip_against_unoptimized;
use crate::query::VecDocSet;
@@ -238,20 +172,18 @@ mod tests {
let left = VecDocSet::from(vec![1, 3, 9]);
let right = VecDocSet::from(vec![3, 4, 9, 18]);
let mut intersection = Intersection::new(vec![left, right]);
assert!(intersection.advance());
assert_eq!(intersection.doc(), 3);
assert!(intersection.advance());
assert_eq!(intersection.advance(), 9);
assert_eq!(intersection.doc(), 9);
assert!(!intersection.advance());
assert_eq!(intersection.advance(), TERMINATED);
}
{
let a = VecDocSet::from(vec![1, 3, 9]);
let b = VecDocSet::from(vec![3, 4, 9, 18]);
let c = VecDocSet::from(vec![1, 5, 9, 111]);
let mut intersection = Intersection::new(vec![a, b, c]);
assert!(intersection.advance());
assert_eq!(intersection.doc(), 9);
assert!(!intersection.advance());
assert_eq!(intersection.advance(), TERMINATED);
}
}
@@ -260,8 +192,8 @@ mod tests {
let left = VecDocSet::from(vec![0]);
let right = VecDocSet::from(vec![0]);
let mut intersection = Intersection::new(vec![left, right]);
assert!(intersection.advance());
assert_eq!(intersection.doc(), 0);
assert_eq!(intersection.advance(), TERMINATED);
}
#[test]
@@ -269,7 +201,7 @@ mod tests {
let left = VecDocSet::from(vec![0, 1, 2, 4]);
let right = VecDocSet::from(vec![2, 5]);
let mut intersection = Intersection::new(vec![left, right]);
assert_eq!(intersection.skip_next(2), SkipResult::Reached);
assert_eq!(intersection.seek(2), 2);
assert_eq!(intersection.doc(), 2);
}
@@ -312,7 +244,7 @@ mod tests {
let a = VecDocSet::from(vec![1, 3]);
let b = VecDocSet::from(vec![1, 4]);
let c = VecDocSet::from(vec![3, 9]);
let mut intersection = Intersection::new(vec![a, b, c]);
assert!(!intersection.advance());
let intersection = Intersection::new(vec![a, b, c]);
assert_eq!(intersection.doc(), TERMINATED);
}
}

View File

@@ -1,12 +1,11 @@
/*!
Query
*/
/*! Query Module */
mod all_query;
mod automaton_weight;
mod bitset;
mod bm25;
mod boolean_query;
mod boost_query;
mod empty_query;
mod exclude;
mod explanation;
@@ -27,6 +26,7 @@ mod weight;
mod vec_docset;
pub(crate) mod score_combiner;
pub(crate) use self::bm25::BM25Weight;
pub use self::intersection::Intersection;
pub use self::union::Union;
@@ -37,9 +37,12 @@ pub use self::all_query::{AllQuery, AllScorer, AllWeight};
pub use self::automaton_weight::AutomatonWeight;
pub use self::bitset::BitSetDocSet;
pub use self::boolean_query::BooleanQuery;
pub use self::boost_query::BoostQuery;
pub use self::empty_query::{EmptyQuery, EmptyScorer, EmptyWeight};
pub use self::exclude::Exclude;
pub use self::explanation::Explanation;
#[cfg(test)]
pub(crate) use self::fuzzy_query::DFAWrapper;
pub use self::fuzzy_query::FuzzyTermQuery;
pub use self::intersection::intersect_scorers;
pub use self::phrase_query::PhraseQuery;

View File

@@ -7,18 +7,18 @@ pub use self::phrase_scorer::PhraseScorer;
pub use self::phrase_weight::PhraseWeight;
#[cfg(test)]
mod tests {
pub mod tests {
use super::*;
use crate::assert_nearly_equals;
use crate::collector::tests::{TEST_COLLECTOR_WITHOUT_SCORE, TEST_COLLECTOR_WITH_SCORE};
use crate::core::Index;
use crate::error::TantivyError;
use crate::query::Weight;
use crate::schema::{Schema, Term, TEXT};
use crate::tests::assert_nearly_equals;
use crate::DocId;
use crate::{DocAddress, DocSet};
use crate::{DocAddress, TERMINATED};
fn create_index(texts: &[&'static str]) -> Index {
pub fn create_index(texts: &[&'static str]) -> Index {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
@@ -61,13 +61,30 @@ mod tests {
.map(|docaddr| docaddr.1)
.collect::<Vec<_>>()
};
assert_eq!(test_query(vec!["a", "b", "c"]), vec![2, 4]);
assert_eq!(test_query(vec!["a", "b"]), vec![1, 2, 3, 4]);
assert_eq!(test_query(vec!["a", "b", "c"]), vec![2, 4]);
assert_eq!(test_query(vec!["b", "b"]), vec![0, 1]);
assert!(test_query(vec!["g", "ewrwer"]).is_empty());
assert!(test_query(vec!["g", "a"]).is_empty());
}
#[test]
pub fn test_phrase_query_simple() -> crate::Result<()> {
let index = create_index(&["a b b d c g c", "a b a b c"]);
let text_field = index.schema().get_field("text").unwrap();
let searcher = index.reader()?.searcher();
let terms: Vec<Term> = vec!["a", "b", "c"]
.iter()
.map(|text| Term::from_field_text(text_field, text))
.collect();
let phrase_query = PhraseQuery::new(terms);
let phrase_weight = phrase_query.phrase_weight(&searcher, false)?;
let mut phrase_scorer = phrase_weight.scorer(searcher.segment_reader(0), 1.0f32)?;
assert_eq!(phrase_scorer.doc(), 1);
assert_eq!(phrase_scorer.advance(), TERMINATED);
Ok(())
}
#[test]
pub fn test_phrase_query_no_score() {
let index = create_index(&[
@@ -102,30 +119,6 @@ mod tests {
assert!(test_query(vec!["g", "a"]).is_empty());
}
#[test]
pub fn test_phrase_count() {
let index = create_index(&["a c", "a a b d a b c", " a b"]);
let schema = index.schema();
let text_field = schema.get_field("text").unwrap();
let searcher = index.reader().unwrap().searcher();
let phrase_query = PhraseQuery::new(vec![
Term::from_field_text(text_field, "a"),
Term::from_field_text(text_field, "b"),
]);
let phrase_weight = phrase_query.phrase_weight(&searcher, true).unwrap();
let mut phrase_scorer = phrase_weight
.phrase_scorer(searcher.segment_reader(0u32))
.unwrap()
.unwrap();
assert!(phrase_scorer.advance());
assert_eq!(phrase_scorer.doc(), 1);
assert_eq!(phrase_scorer.phrase_count(), 2);
assert!(phrase_scorer.advance());
assert_eq!(phrase_scorer.doc(), 2);
assert_eq!(phrase_scorer.phrase_count(), 1);
assert!(!phrase_scorer.advance());
}
#[test]
pub fn test_phrase_query_no_positions() {
let mut schema_builder = Schema::builder();
@@ -151,21 +144,16 @@ mod tests {
Term::from_field_text(text_field, "a"),
Term::from_field_text(text_field, "b"),
]);
match searcher
let search_result = searcher
.search(&phrase_query, &TEST_COLLECTOR_WITH_SCORE)
.map(|_| ())
.unwrap_err()
{
TantivyError::SchemaError(ref msg) => {
assert_eq!(
"Applied phrase query on field \"text\", which does not have positions indexed",
msg.as_str()
);
}
_ => {
panic!("Should have returned an error");
}
}
.map(|_| ());
assert!(matches!(
search_result,
Err(crate::TantivyError::SchemaError(msg))
if msg == "Applied phrase query on field \"text\", which does not have positions \
indexed"
));
}
#[test]
@@ -187,8 +175,8 @@ mod tests {
.to_vec()
};
let scores = test_query(vec!["a", "b"]);
assert_nearly_equals(scores[0], 0.40618482);
assert_nearly_equals(scores[1], 0.46844664);
assert_nearly_equals!(scores[0], 0.40618482);
assert_nearly_equals!(scores[1], 0.46844664);
}
#[test] // motivated by #234

View File

@@ -1,4 +1,4 @@
use crate::docset::{DocSet, SkipResult};
use crate::docset::{DocSet, TERMINATED};
use crate::fieldnorm::FieldNormReader;
use crate::postings::Postings;
use crate::query::bm25::BM25Weight;
@@ -25,12 +25,12 @@ impl<TPostings: Postings> PostingsWithOffset<TPostings> {
}
impl<TPostings: Postings> DocSet for PostingsWithOffset<TPostings> {
fn advance(&mut self) -> bool {
fn advance(&mut self) -> DocId {
self.postings.advance()
}
fn skip_next(&mut self, target: DocId) -> SkipResult {
self.postings.skip_next(target)
fn seek(&mut self, target: DocId) -> DocId {
self.postings.seek(target)
}
fn doc(&self) -> DocId {
@@ -149,7 +149,7 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
PostingsWithOffset::new(postings, (max_offset - offset) as u32)
})
.collect::<Vec<_>>();
PhraseScorer {
let mut scorer = PhraseScorer {
intersection_docset: Intersection::new(postings_with_offsets),
num_terms: num_docsets,
left: Vec::with_capacity(100),
@@ -158,7 +158,11 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
similarity_weight,
fieldnorm_reader,
score_needed,
};
if scorer.doc() != TERMINATED && !scorer.phrase_match() {
scorer.advance();
}
scorer
}
pub fn phrase_count(&self) -> u32 {
@@ -225,31 +229,22 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
}
impl<TPostings: Postings> DocSet for PhraseScorer<TPostings> {
fn advance(&mut self) -> bool {
while self.intersection_docset.advance() {
if self.phrase_match() {
return true;
fn advance(&mut self) -> DocId {
loop {
let doc = self.intersection_docset.advance();
if doc == TERMINATED || self.phrase_match() {
return doc;
}
}
false
}
fn skip_next(&mut self, target: DocId) -> SkipResult {
if self.intersection_docset.skip_next(target) == SkipResult::End {
return SkipResult::End;
}
if self.phrase_match() {
if self.doc() == target {
return SkipResult::Reached;
} else {
return SkipResult::OverStep;
}
}
if self.advance() {
SkipResult::OverStep
} else {
SkipResult::End
fn seek(&mut self, target: DocId) -> DocId {
debug_assert!(target >= self.doc());
let doc = self.intersection_docset.seek(target);
if doc == TERMINATED || self.phrase_match() {
return doc;
}
self.advance()
}
fn doc(&self) -> DocId {
@@ -272,7 +267,6 @@ impl<TPostings: Postings> Scorer for PhraseScorer<TPostings> {
#[cfg(test)]
mod tests {
use super::{intersection, intersection_count};
fn test_intersection_sym(left: &[u32], right: &[u32], expected: &[u32]) {

View File

@@ -9,8 +9,8 @@ use crate::query::Weight;
use crate::query::{EmptyScorer, Explanation};
use crate::schema::IndexRecordOption;
use crate::schema::Term;
use crate::Result;
use crate::{DocId, DocSet};
use crate::{Result, SkipResult};
pub struct PhraseWeight {
phrase_terms: Vec<(usize, Term)>,
@@ -37,11 +37,12 @@ impl PhraseWeight {
reader.get_fieldnorms_reader(field)
}
pub fn phrase_scorer(
fn phrase_scorer(
&self,
reader: &SegmentReader,
boost: f32,
) -> Result<Option<PhraseScorer<SegmentPostings>>> {
let similarity_weight = self.similarity_weight.clone();
let similarity_weight = self.similarity_weight.boost_by(boost);
let fieldnorm_reader = self.fieldnorm_reader(reader);
if reader.has_deletes() {
let mut term_postings_list = Vec::new();
@@ -84,8 +85,8 @@ impl PhraseWeight {
}
impl Weight for PhraseWeight {
fn scorer(&self, reader: &SegmentReader) -> Result<Box<dyn Scorer>> {
if let Some(scorer) = self.phrase_scorer(reader)? {
fn scorer(&self, reader: &SegmentReader, boost: f32) -> Result<Box<dyn Scorer>> {
if let Some(scorer) = self.phrase_scorer(reader, boost)? {
Ok(Box::new(scorer))
} else {
Ok(Box::new(EmptyScorer))
@@ -93,12 +94,12 @@ impl Weight for PhraseWeight {
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation> {
let scorer_opt = self.phrase_scorer(reader)?;
let scorer_opt = self.phrase_scorer(reader, 1.0f32)?;
if scorer_opt.is_none() {
return Err(does_not_match(doc));
}
let mut scorer = scorer_opt.unwrap();
if scorer.skip_next(doc) != SkipResult::Reached {
if scorer.seek(doc) != doc {
return Err(does_not_match(doc));
}
let fieldnorm_reader = self.fieldnorm_reader(reader);
@@ -109,3 +110,34 @@ impl Weight for PhraseWeight {
Ok(explanation)
}
}
#[cfg(test)]
mod tests {
use super::super::tests::create_index;
use crate::docset::TERMINATED;
use crate::query::PhraseQuery;
use crate::{DocSet, Term};
#[test]
pub fn test_phrase_count() {
let index = create_index(&["a c", "a a b d a b c", " a b"]);
let schema = index.schema();
let text_field = schema.get_field("text").unwrap();
let searcher = index.reader().unwrap().searcher();
let phrase_query = PhraseQuery::new(vec![
Term::from_field_text(text_field, "a"),
Term::from_field_text(text_field, "b"),
]);
let phrase_weight = phrase_query.phrase_weight(&searcher, true).unwrap();
let mut phrase_scorer = phrase_weight
.phrase_scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap()
.unwrap();
assert_eq!(phrase_scorer.doc(), 1);
assert_eq!(phrase_scorer.phrase_count(), 2);
assert_eq!(phrase_scorer.advance(), 2);
assert_eq!(phrase_scorer.doc(), 2);
assert_eq!(phrase_scorer.phrase_count(), 1);
assert_eq!(phrase_scorer.advance(), TERMINATED);
}
}

View File

@@ -21,6 +21,17 @@ pub enum LogicalLiteral {
pub enum LogicalAST {
Clause(Vec<(Occur, LogicalAST)>),
Leaf(Box<LogicalLiteral>),
Boost(Box<LogicalAST>, f32),
}
impl LogicalAST {
pub fn boost(self, boost: f32) -> LogicalAST {
if (boost - 1.0f32).abs() < std::f32::EPSILON {
self
} else {
LogicalAST::Boost(Box::new(self), boost)
}
}
}
fn occur_letter(occur: Occur) -> &'static str {
@@ -47,6 +58,7 @@ impl fmt::Debug for LogicalAST {
}
Ok(())
}
LogicalAST::Boost(ref ast, boost) => write!(formatter, "{:?}^{}", ast, boost),
LogicalAST::Leaf(ref literal) => write!(formatter, "{:?}", literal),
}
}

View File

@@ -1,6 +1,5 @@
use super::logical_ast::*;
use crate::core::Index;
use crate::query::AllQuery;
use crate::query::BooleanQuery;
use crate::query::EmptyQuery;
use crate::query::Occur;
@@ -8,11 +7,13 @@ use crate::query::PhraseQuery;
use crate::query::Query;
use crate::query::RangeQuery;
use crate::query::TermQuery;
use crate::query::{AllQuery, BoostQuery};
use crate::schema::{Facet, IndexRecordOption};
use crate::schema::{Field, Schema};
use crate::schema::{FieldType, Term};
use crate::tokenizer::TokenizerManager;
use std::borrow::Cow;
use std::collections::HashMap;
use std::num::{ParseFloatError, ParseIntError};
use std::ops::Bound;
use std::str::FromStr;
@@ -112,8 +113,9 @@ fn trim_ast(logical_ast: LogicalAST) -> Option<LogicalAST> {
/// The language covered by the current parser is extremely simple.
///
/// * simple terms: "e.g.: `Barack Obama` are simply tokenized using
/// tantivy's `StandardTokenizer`, hence becoming `["barack", "obama"]`.
/// The terms are then searched within the default terms of the query parser.
/// tantivy's [`SimpleTokenizer`](../tokenizer/struct.SimpleTokenizer.html), hence
/// becoming `["barack", "obama"]`. The terms are then searched within
/// the default terms of the query parser.
///
/// e.g. If `body` and `title` are default fields, our example terms are
/// `["title:barack", "body:barack", "title:obama", "body:obama"]`.
@@ -144,7 +146,6 @@ fn trim_ast(logical_ast: LogicalAST) -> Option<LogicalAST> {
///
/// * must terms: By prepending a term by a `+`, a term can be made required for the search.
///
///
/// * phrase terms: Quoted terms become phrase searches on fields that have positions indexed.
/// e.g., `title:"Barack Obama"` will only find documents that have "barack" immediately followed
/// by "obama".
@@ -158,12 +159,30 @@ fn trim_ast(logical_ast: LogicalAST) -> Option<LogicalAST> {
///
/// * all docs query: A plain `*` will match all documents in the index.
///
/// Parts of the queries can be boosted by appending `^boostfactor`.
/// For instance, `"SRE"^2.0 OR devops^0.4` will boost documents containing `SRE` instead of
/// devops. Negative boosts are not allowed.
///
/// It is also possible to define a boost for a some specific field, at the query parser level.
/// (See [`set_boost(...)`](#method.set_field_boost) ). Typically you may want to boost a title
/// field.
#[derive(Clone)]
pub struct QueryParser {
schema: Schema,
default_fields: Vec<Field>,
conjunction_by_default: bool,
tokenizer_manager: TokenizerManager,
boost: HashMap<Field, f32>,
}
fn all_negative(ast: &LogicalAST) -> bool {
match ast {
LogicalAST::Leaf(_) => false,
LogicalAST::Boost(ref child_ast, _) => all_negative(&*child_ast),
LogicalAST::Clause(children) => children
.iter()
.all(|(ref occur, child)| (*occur == Occur::MustNot) || all_negative(child)),
}
}
impl QueryParser {
@@ -181,6 +200,7 @@ impl QueryParser {
default_fields,
tokenizer_manager,
conjunction_by_default: false,
boost: Default::default(),
}
}
@@ -201,6 +221,17 @@ impl QueryParser {
self.conjunction_by_default = true;
}
/// Sets a boost for a specific field.
///
/// The parse query will automatically boost this field.
///
/// If the query defines a query boost through the query language (e.g: `country:France^3.0`),
/// the two boosts (the one defined in the query, and the one defined in the `QueryParser`)
/// are multiplied together.
pub fn set_field_boost(&mut self, field: Field, boost: f32) {
self.boost.insert(field, boost);
}
/// Parse a query
///
/// Note that `parse_query` returns an error if the input
@@ -233,8 +264,13 @@ impl QueryParser {
&self,
user_input_ast: UserInputAST,
) -> Result<LogicalAST, QueryParserError> {
let (occur, ast) = self.compute_logical_ast_with_occur(user_input_ast)?;
if occur == Occur::MustNot {
let ast = self.compute_logical_ast_with_occur(user_input_ast)?;
if let LogicalAST::Clause(children) = &ast {
if children.is_empty() {
return Ok(ast);
}
}
if all_negative(&ast) {
return Err(QueryParserError::AllButQueryForbidden);
}
Ok(ast)
@@ -390,30 +426,30 @@ impl QueryParser {
fn compute_logical_ast_with_occur(
&self,
user_input_ast: UserInputAST,
) -> Result<(Occur, LogicalAST), QueryParserError> {
) -> Result<LogicalAST, QueryParserError> {
match user_input_ast {
UserInputAST::Clause(sub_queries) => {
let default_occur = self.default_occur();
let mut logical_sub_queries: Vec<(Occur, LogicalAST)> = Vec::new();
for sub_query in sub_queries {
let (occur, sub_ast) = self.compute_logical_ast_with_occur(sub_query)?;
let new_occur = Occur::compose(default_occur, occur);
logical_sub_queries.push((new_occur, sub_ast));
for (occur_opt, sub_ast) in sub_queries {
let sub_ast = self.compute_logical_ast_with_occur(sub_ast)?;
let occur = occur_opt.unwrap_or(default_occur);
logical_sub_queries.push((occur, sub_ast));
}
Ok((Occur::Should, LogicalAST::Clause(logical_sub_queries)))
Ok(LogicalAST::Clause(logical_sub_queries))
}
UserInputAST::Unary(left_occur, subquery) => {
let (right_occur, logical_sub_queries) =
self.compute_logical_ast_with_occur(*subquery)?;
Ok((Occur::compose(left_occur, right_occur), logical_sub_queries))
}
UserInputAST::Leaf(leaf) => {
let result_ast = self.compute_logical_ast_from_leaf(*leaf)?;
Ok((Occur::Should, result_ast))
UserInputAST::Boost(ast, boost) => {
let ast = self.compute_logical_ast_with_occur(*ast)?;
Ok(ast.boost(boost))
}
UserInputAST::Leaf(leaf) => self.compute_logical_ast_from_leaf(*leaf),
}
}
fn field_boost(&self, field: Field) -> f32 {
self.boost.get(&field).cloned().unwrap_or(1.0f32)
}
fn compute_logical_ast_from_leaf(
&self,
leaf: UserInputLeaf,
@@ -439,7 +475,9 @@ impl QueryParser {
let mut asts: Vec<LogicalAST> = Vec::new();
for (field, phrase) in term_phrases {
if let Some(ast) = self.compute_logical_ast_for_leaf(field, &phrase)? {
asts.push(LogicalAST::Leaf(Box::new(ast)));
// Apply some field specific boost defined at the query parser level.
let boost = self.field_boost(field);
asts.push(LogicalAST::Leaf(Box::new(ast)).boost(boost));
}
}
let result_ast: LogicalAST = if asts.len() == 1 {
@@ -459,14 +497,16 @@ impl QueryParser {
let mut clauses = fields
.iter()
.map(|&field| {
let boost = self.field_boost(field);
let field_entry = self.schema.get_field_entry(field);
let value_type = field_entry.field_type().value_type();
Ok(LogicalAST::Leaf(Box::new(LogicalLiteral::Range {
let logical_ast = LogicalAST::Leaf(Box::new(LogicalLiteral::Range {
field,
value_type,
lower: self.resolve_bound(field, &lower)?,
upper: self.resolve_bound(field, &upper)?,
})))
}));
Ok(logical_ast.boost(boost))
})
.collect::<Result<Vec<_>, QueryParserError>>()?;
let result_ast = if clauses.len() == 1 {
@@ -519,6 +559,11 @@ fn convert_to_query(logical_ast: LogicalAST) -> Box<dyn Query> {
Some(LogicalAST::Leaf(trimmed_logical_literal)) => {
convert_literal_to_query(*trimmed_logical_literal)
}
Some(LogicalAST::Boost(ast, boost)) => {
let query = convert_to_query(*ast);
let boosted_query = BoostQuery::new(query, boost);
Box::new(boosted_query)
}
None => Box::new(EmptyQuery),
}
}
@@ -538,7 +583,7 @@ mod test {
use crate::Index;
use matches::assert_matches;
fn make_query_parser() -> QueryParser {
fn make_schema() -> Schema {
let mut schema_builder = Schema::builder();
let text_field_indexing = TextFieldIndexing::default()
.set_tokenizer("en_with_stop_words")
@@ -546,8 +591,8 @@ mod test {
let text_options = TextOptions::default()
.set_indexing_options(text_field_indexing)
.set_stored();
let title = schema_builder.add_text_field("title", TEXT);
let text = schema_builder.add_text_field("text", TEXT);
schema_builder.add_text_field("title", TEXT);
schema_builder.add_text_field("text", TEXT);
schema_builder.add_i64_field("signed", INDEXED);
schema_builder.add_u64_field("unsigned", INDEXED);
schema_builder.add_text_field("notindexed_text", STORED);
@@ -558,8 +603,15 @@ mod test {
schema_builder.add_date_field("date", INDEXED);
schema_builder.add_f64_field("float", INDEXED);
schema_builder.add_facet_field("facet");
let schema = schema_builder.build();
let default_fields = vec![title, text];
schema_builder.build()
}
fn make_query_parser() -> QueryParser {
let schema = make_schema();
let default_fields: Vec<Field> = vec!["title", "text"]
.into_iter()
.flat_map(|field_name| schema.get_field(field_name))
.collect();
let tokenizer_manager = TokenizerManager::default();
tokenizer_manager.register(
"en_with_stop_words",
@@ -601,6 +653,45 @@ mod test {
);
}
#[test]
pub fn test_parse_query_with_boost() {
let mut query_parser = make_query_parser();
let schema = make_schema();
let text_field = schema.get_field("text").unwrap();
query_parser.set_field_boost(text_field, 2.0f32);
let query = query_parser.parse_query("text:hello").unwrap();
assert_eq!(
format!("{:?}", query),
"Boost(query=TermQuery(Term(field=1,bytes=[104, 101, 108, 108, 111])), boost=2)"
);
}
#[test]
pub fn test_parse_query_range_with_boost() {
let mut query_parser = make_query_parser();
let schema = make_schema();
let title_field = schema.get_field("title").unwrap();
query_parser.set_field_boost(title_field, 2.0f32);
let query = query_parser.parse_query("title:[A TO B]").unwrap();
assert_eq!(
format!("{:?}", query),
"Boost(query=RangeQuery { field: Field(0), value_type: Str, left_bound: Included([97]), right_bound: Included([98]) }, boost=2)"
);
}
#[test]
pub fn test_parse_query_with_default_boost_and_custom_boost() {
let mut query_parser = make_query_parser();
let schema = make_schema();
let text_field = schema.get_field("text").unwrap();
query_parser.set_field_boost(text_field, 2.0f32);
let query = query_parser.parse_query("text:hello^2").unwrap();
assert_eq!(
format!("{:?}", query),
"Boost(query=Boost(query=TermQuery(Term(field=1,bytes=[104, 101, 108, 108, 111])), boost=2), boost=2)"
);
}
#[test]
pub fn test_parse_nonindexed_field_yields_error() {
let query_parser = make_query_parser();
@@ -699,6 +790,20 @@ mod test {
);
}
#[test]
fn test_parse_query_to_ast_ab_c() {
test_parse_query_to_logical_ast_helper(
"(+title:a +title:b) title:c",
"((+Term(field=0,bytes=[97]) +Term(field=0,bytes=[98])) Term(field=0,bytes=[99]))",
false,
);
test_parse_query_to_logical_ast_helper(
"(+title:a +title:b) title:c",
"(+(+Term(field=0,bytes=[97]) +Term(field=0,bytes=[98])) +Term(field=0,bytes=[99]))",
true,
);
}
#[test]
pub fn test_parse_query_to_ast_single_term() {
test_parse_query_to_logical_ast_helper(
@@ -718,11 +823,13 @@ mod test {
Term(field=1,bytes=[116, 105, 116, 105])))",
false,
);
assert_eq!(
parse_query_to_logical_ast("-title:toto", false)
.err()
.unwrap(),
QueryParserError::AllButQueryForbidden
}
#[test]
fn test_single_negative_term() {
assert_matches!(
parse_query_to_logical_ast("-title:toto", false),
Err(QueryParserError::AllButQueryForbidden)
);
}
@@ -882,6 +989,18 @@ mod test {
assert!(query_parser.parse_query("with_stop_words:the").is_ok());
}
#[test]
pub fn test_parse_query_single_negative_term_through_error() {
assert_matches!(
parse_query_to_logical_ast("-title:toto", true),
Err(QueryParserError::AllButQueryForbidden)
);
assert_matches!(
parse_query_to_logical_ast("-title:toto", false),
Err(QueryParserError::AllButQueryForbidden)
);
}
#[test]
pub fn test_parse_query_to_ast_conjunction() {
test_parse_query_to_logical_ast_helper(
@@ -901,12 +1020,6 @@ mod test {
Term(field=1,bytes=[116, 105, 116, 105])))",
true,
);
assert_eq!(
parse_query_to_logical_ast("-title:toto", true)
.err()
.unwrap(),
QueryParserError::AllButQueryForbidden
);
test_parse_query_to_logical_ast_helper(
"title:a b",
"(+Term(field=0,bytes=[97]) \
@@ -930,4 +1043,26 @@ mod test {
false
);
}
#[test]
fn test_and_default_regardless_of_default_conjunctive() {
for &default_conjunction in &[false, true] {
test_parse_query_to_logical_ast_helper(
"title:a AND title:b",
"(+Term(field=0,bytes=[97]) +Term(field=0,bytes=[98]))",
default_conjunction,
);
}
}
#[test]
fn test_or_default_conjunctive() {
for &default_conjunction in &[false, true] {
test_parse_query_to_logical_ast_helper(
"title:a OR title:b",
"(Term(field=0,bytes=[97]) Term(field=0,bytes=[98]))",
default_conjunction,
);
}
}
}

View File

@@ -10,7 +10,7 @@ use crate::schema::Type;
use crate::schema::{Field, IndexRecordOption, Term};
use crate::termdict::{TermDictionary, TermStreamer};
use crate::DocId;
use crate::{Result, SkipResult};
use crate::Result;
use std::collections::Bound;
use std::ops::Range;
@@ -289,7 +289,7 @@ impl RangeWeight {
}
impl Weight for RangeWeight {
fn scorer(&self, reader: &SegmentReader) -> Result<Box<dyn Scorer>> {
fn scorer(&self, reader: &SegmentReader, boost: f32) -> Result<Box<dyn Scorer>> {
let max_doc = reader.max_doc();
let mut doc_bitset = BitSet::with_max_value(max_doc);
@@ -300,19 +300,24 @@ impl Weight for RangeWeight {
let term_info = term_range.value();
let mut block_segment_postings = inverted_index
.read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic);
while block_segment_postings.advance() {
loop {
let docs = block_segment_postings.docs();
if docs.is_empty() {
break;
}
for &doc in block_segment_postings.docs() {
doc_bitset.insert(doc);
}
block_segment_postings.advance();
}
}
let doc_bitset = BitSetDocSet::from(doc_bitset);
Ok(Box::new(ConstScorer::new(doc_bitset)))
Ok(Box::new(ConstScorer::new(doc_bitset, boost)))
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation> {
let mut scorer = self.scorer(reader)?;
if scorer.skip_next(doc) != SkipResult::Reached {
let mut scorer = self.scorer(reader, 1.0f32)?;
if scorer.seek(doc) != doc {
return Err(does_not_match(doc));
}
Ok(Explanation::new("RangeQuery", 1.0f32))

View File

@@ -89,10 +89,10 @@ impl Query for RegexQuery {
#[cfg(test)]
mod test {
use super::RegexQuery;
use crate::assert_nearly_equals;
use crate::collector::TopDocs;
use crate::schema::TEXT;
use crate::schema::{Field, Schema};
use crate::tests::assert_nearly_equals;
use crate::{Index, IndexReader};
use std::sync::Arc;
use tantivy_fst::Regex;
@@ -129,7 +129,7 @@ mod test {
.unwrap();
assert_eq!(scored_docs.len(), 1, "Expected only 1 document");
let (score, _) = scored_docs[0];
assert_nearly_equals(1f32, score);
assert_nearly_equals!(1f32, score);
}
let top_docs = searcher
.search(&query_matching_zero, &TopDocs::with_limit(2))

View File

@@ -1,9 +1,8 @@
use crate::docset::{DocSet, SkipResult};
use crate::docset::DocSet;
use crate::query::score_combiner::ScoreCombiner;
use crate::query::Scorer;
use crate::DocId;
use crate::Score;
use std::cmp::Ordering;
use std::marker::PhantomData;
/// Given a required scorer and an optional scorer
@@ -17,7 +16,6 @@ pub struct RequiredOptionalScorer<TReqScorer, TOptScorer, TScoreCombiner> {
req_scorer: TReqScorer,
opt_scorer: TOptScorer,
score_cache: Option<Score>,
opt_finished: bool,
_phantom: PhantomData<TScoreCombiner>,
}
@@ -29,14 +27,12 @@ where
/// Creates a new `RequiredOptionalScorer`.
pub fn new(
req_scorer: TReqScorer,
mut opt_scorer: TOptScorer,
opt_scorer: TOptScorer,
) -> RequiredOptionalScorer<TReqScorer, TOptScorer, TScoreCombiner> {
let opt_finished = !opt_scorer.advance();
RequiredOptionalScorer {
req_scorer,
opt_scorer,
score_cache: None,
opt_finished,
_phantom: PhantomData,
}
}
@@ -48,7 +44,7 @@ where
TReqScorer: DocSet,
TOptScorer: DocSet,
{
fn advance(&mut self) -> bool {
fn advance(&mut self) -> DocId {
self.score_cache = None;
self.req_scorer.advance()
}
@@ -76,22 +72,8 @@ where
let doc = self.doc();
let mut score_combiner = TScoreCombiner::default();
score_combiner.update(&mut self.req_scorer);
if !self.opt_finished {
match self.opt_scorer.doc().cmp(&doc) {
Ordering::Greater => {}
Ordering::Equal => {
score_combiner.update(&mut self.opt_scorer);
}
Ordering::Less => match self.opt_scorer.skip_next(doc) {
SkipResult::Reached => {
score_combiner.update(&mut self.opt_scorer);
}
SkipResult::End => {
self.opt_finished = true;
}
SkipResult::OverStep => {}
},
}
if self.opt_scorer.doc() <= doc && self.opt_scorer.seek(doc) == doc {
score_combiner.update(&mut self.opt_scorer);
}
let score = score_combiner.score();
self.score_cache = Some(score);
@@ -102,7 +84,7 @@ where
#[cfg(test)]
mod tests {
use super::RequiredOptionalScorer;
use crate::docset::DocSet;
use crate::docset::{DocSet, TERMINATED};
use crate::postings::tests::test_skip_against_unoptimized;
use crate::query::score_combiner::{DoNothingCombiner, SumCombiner};
use crate::query::ConstScorer;
@@ -115,12 +97,13 @@ mod tests {
let req = vec![1, 3, 7];
let mut reqoptscorer: RequiredOptionalScorer<_, _, SumCombiner> =
RequiredOptionalScorer::new(
ConstScorer::new(VecDocSet::from(req.clone())),
ConstScorer::new(VecDocSet::from(vec![])),
ConstScorer::from(VecDocSet::from(req.clone())),
ConstScorer::from(VecDocSet::from(vec![])),
);
let mut docs = vec![];
while reqoptscorer.advance() {
while reqoptscorer.doc() != TERMINATED {
docs.push(reqoptscorer.doc());
reqoptscorer.advance();
}
assert_eq!(docs, req);
}
@@ -129,50 +112,49 @@ mod tests {
fn test_reqopt_scorer() {
let mut reqoptscorer: RequiredOptionalScorer<_, _, SumCombiner> =
RequiredOptionalScorer::new(
ConstScorer::new(VecDocSet::from(vec![1, 3, 7, 8, 9, 10, 13, 15])),
ConstScorer::new(VecDocSet::from(vec![1, 2, 7, 11, 12, 15])),
ConstScorer::new(VecDocSet::from(vec![1, 3, 7, 8, 9, 10, 13, 15]), 1.0f32),
ConstScorer::new(VecDocSet::from(vec![1, 2, 7, 11, 12, 15]), 1.0f32),
);
{
assert!(reqoptscorer.advance());
assert_eq!(reqoptscorer.doc(), 1);
assert_eq!(reqoptscorer.score(), 2f32);
}
{
assert!(reqoptscorer.advance());
assert_eq!(reqoptscorer.advance(), 3);
assert_eq!(reqoptscorer.doc(), 3);
assert_eq!(reqoptscorer.score(), 1f32);
}
{
assert!(reqoptscorer.advance());
assert_eq!(reqoptscorer.advance(), 7);
assert_eq!(reqoptscorer.doc(), 7);
assert_eq!(reqoptscorer.score(), 2f32);
}
{
assert!(reqoptscorer.advance());
assert_eq!(reqoptscorer.advance(), 8);
assert_eq!(reqoptscorer.doc(), 8);
assert_eq!(reqoptscorer.score(), 1f32);
}
{
assert!(reqoptscorer.advance());
assert_eq!(reqoptscorer.advance(), 9);
assert_eq!(reqoptscorer.doc(), 9);
assert_eq!(reqoptscorer.score(), 1f32);
}
{
assert!(reqoptscorer.advance());
assert_eq!(reqoptscorer.advance(), 10);
assert_eq!(reqoptscorer.doc(), 10);
assert_eq!(reqoptscorer.score(), 1f32);
}
{
assert!(reqoptscorer.advance());
assert_eq!(reqoptscorer.advance(), 13);
assert_eq!(reqoptscorer.doc(), 13);
assert_eq!(reqoptscorer.score(), 1f32);
}
{
assert!(reqoptscorer.advance());
assert_eq!(reqoptscorer.advance(), 15);
assert_eq!(reqoptscorer.doc(), 15);
assert_eq!(reqoptscorer.score(), 2f32);
}
assert!(!reqoptscorer.advance());
assert_eq!(reqoptscorer.advance(), TERMINATED);
}
#[test]
@@ -183,8 +165,8 @@ mod tests {
test_skip_against_unoptimized(
|| {
Box::new(RequiredOptionalScorer::<_, _, DoNothingCombiner>::new(
ConstScorer::new(VecDocSet::from(req_docs.clone())),
ConstScorer::new(VecDocSet::from(opt_docs.clone())),
ConstScorer::from(VecDocSet::from(req_docs.clone())),
ConstScorer::from(VecDocSet::from(opt_docs.clone())),
))
},
skip_docs,

View File

@@ -1,5 +1,4 @@
use crate::common::BitSet;
use crate::docset::{DocSet, SkipResult};
use crate::docset::DocSet;
use crate::DocId;
use crate::Score;
use downcast_rs::impl_downcast;
@@ -13,14 +12,6 @@ pub trait Scorer: downcast_rs::Downcast + DocSet + 'static {
///
/// This method will perform a bit of computation and is not cached.
fn score(&mut self) -> Score;
/// Iterates through all of the document matched by the DocSet
/// `DocSet` and push the scored documents to the collector.
fn for_each(&mut self, callback: &mut dyn FnMut(DocId, Score)) {
while self.advance() {
callback(self.doc(), self.score());
}
}
}
impl_downcast!(Scorer);
@@ -29,11 +20,6 @@ impl Scorer for Box<dyn Scorer> {
fn score(&mut self) -> Score {
self.deref_mut().score()
}
fn for_each(&mut self, callback: &mut dyn FnMut(DocId, Score)) {
let scorer = self.deref_mut();
scorer.for_each(callback);
}
}
/// Wraps a `DocSet` and simply returns a constant `Scorer`.
@@ -49,26 +35,24 @@ pub struct ConstScorer<TDocSet: DocSet> {
impl<TDocSet: DocSet> ConstScorer<TDocSet> {
/// Creates a new `ConstScorer`.
pub fn new(docset: TDocSet) -> ConstScorer<TDocSet> {
ConstScorer {
docset,
score: 1f32,
}
pub fn new(docset: TDocSet, score: f32) -> ConstScorer<TDocSet> {
ConstScorer { docset, score }
}
}
/// Sets the constant score to a different value.
pub fn set_score(&mut self, score: Score) {
self.score = score;
impl<TDocSet: DocSet> From<TDocSet> for ConstScorer<TDocSet> {
fn from(docset: TDocSet) -> Self {
ConstScorer::new(docset, 1.0f32)
}
}
impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> {
fn advance(&mut self) -> bool {
fn advance(&mut self) -> DocId {
self.docset.advance()
}
fn skip_next(&mut self, target: DocId) -> SkipResult {
self.docset.skip_next(target)
fn seek(&mut self, target: DocId) -> DocId {
self.docset.seek(target)
}
fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize {
@@ -82,14 +66,10 @@ impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> {
fn size_hint(&self) -> u32 {
self.docset.size_hint()
}
fn append_to_bitset(&mut self, bitset: &mut BitSet) {
self.docset.append_to_bitset(bitset);
}
}
impl<TDocSet: DocSet + 'static> Scorer for ConstScorer<TDocSet> {
fn score(&mut self) -> Score {
1f32
self.score
}
}

View File

@@ -9,13 +9,13 @@ pub use self::term_weight::TermWeight;
#[cfg(test)]
mod tests {
use crate::assert_nearly_equals;
use crate::collector::TopDocs;
use crate::docset::DocSet;
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::query::{Query, QueryParser, Scorer, TermQuery};
use crate::schema::{Field, IndexRecordOption, Schema, STRING, TEXT};
use crate::tests::assert_nearly_equals;
use crate::Index;
use crate::Term;
use crate::{Term, Index, TERMINATED};
#[test]
pub fn test_term_query_no_freq() {
@@ -26,10 +26,8 @@ mod tests {
{
// writing the segment
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
{
let doc = doc!(text_field => "a");
index_writer.add_document(doc);
}
let doc = doc!(text_field => "a");
index_writer.add_document(doc);
assert!(index_writer.commit().is_ok());
}
let searcher = index.reader().unwrap().searcher();
@@ -39,12 +37,46 @@ mod tests {
);
let term_weight = term_query.weight(&searcher, true).unwrap();
let segment_reader = searcher.segment_reader(0);
let mut term_scorer = term_weight.scorer(segment_reader).unwrap();
assert!(term_scorer.advance());
let mut term_scorer = term_weight.scorer(segment_reader, 1.0f32).unwrap();
assert_eq!(term_scorer.doc(), 0);
assert_eq!(term_scorer.score(), 0.28768212);
}
#[test]
pub fn test_term_query_multiple_of_block_len() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", STRING);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
// writing the segment
let mut index_writer = index.writer_with_num_threads(1, 3_000_000)?;
for _ in 0..COMPRESSION_BLOCK_SIZE {
let doc = doc!(text_field => "a");
index_writer.add_document(doc);
}
index_writer.commit()?;
}
let searcher = index.reader()?.searcher();
let term_query = TermQuery::new(
Term::from_field_text(text_field, "a"),
IndexRecordOption::Basic,
);
let term_weight = term_query.weight(&searcher, true)?;
let segment_reader = searcher.segment_reader(0);
let mut term_scorer = term_weight.scorer(segment_reader, 1.0f32)?;
for i in 0u32..COMPRESSION_BLOCK_SIZE as u32 {
assert_eq!(term_scorer.doc(), i);
if i == COMPRESSION_BLOCK_SIZE as u32 - 1u32 {
assert_eq!(term_scorer.advance(), TERMINATED);
} else {
assert_eq!(term_scorer.advance(), i + 1);
}
}
assert_eq!(term_scorer.doc(), TERMINATED);
Ok(())
}
#[test]
pub fn test_term_weight() {
let mut schema_builder = Schema::builder();
@@ -72,7 +104,7 @@ mod tests {
.unwrap();
assert_eq!(topdocs.len(), 1);
let (score, _) = topdocs[0];
assert_nearly_equals(0.77802235, score);
assert_nearly_equals!(0.77802235, score);
}
{
let term = Term::from_field_text(left_field, "left1");
@@ -82,9 +114,9 @@ mod tests {
.unwrap();
assert_eq!(top_docs.len(), 2);
let (score1, _) = top_docs[0];
assert_nearly_equals(0.27101856, score1);
assert_nearly_equals!(0.27101856, score1);
let (score2, _) = top_docs[1];
assert_nearly_equals(0.13736556, score2);
assert_nearly_equals!(0.13736556, score2);
}
{
let query_parser = QueryParser::for_index(&index, vec![]);
@@ -92,9 +124,9 @@ mod tests {
let top_docs = searcher.search(&query, &TopDocs::with_limit(2)).unwrap();
assert_eq!(top_docs.len(), 2);
let (score1, _) = top_docs[0];
assert_nearly_equals(0.9153879, score1);
assert_nearly_equals!(0.9153879, score1);
let (score2, _) = top_docs[1];
assert_nearly_equals(0.27101856, score2);
assert_nearly_equals!(0.27101856, score2);
}
}
@@ -115,6 +147,27 @@ mod tests {
assert_eq!(term_query.count(&*reader.searcher()).unwrap(), 1);
}
#[test]
fn test_term_query_simple_seek() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap();
index_writer.add_document(doc!(text_field=>"a"));
index_writer.add_document(doc!(text_field=>"a"));
index_writer.commit()?;
let term_a = Term::from_field_text(text_field, "a");
let term_query = TermQuery::new(term_a, IndexRecordOption::Basic);
let searcher = index.reader()?.searcher();
let term_weight = term_query.weight(&searcher, false)?;
let mut term_scorer = term_weight.scorer(searcher.segment_reader(0u32), 1.0f32)?;
assert_eq!(term_scorer.doc(), 0u32);
term_scorer.seek(1u32);
assert_eq!(term_scorer.doc(), 1u32);
Ok(())
}
#[test]
fn test_term_query_debug() {
let term_query = TermQuery::new(

View File

@@ -1,11 +1,11 @@
use crate::docset::{DocSet, SkipResult};
use crate::docset::DocSet;
use crate::query::{Explanation, Scorer};
use crate::DocId;
use crate::Score;
use crate::fieldnorm::FieldNormReader;
use crate::postings::Postings;
use crate::postings::SegmentPostings;
use crate::postings::{FreqReadingOption, Postings};
use crate::query::bm25::BM25Weight;
pub struct TermScorer {
@@ -26,13 +26,70 @@ impl TermScorer {
similarity_weight,
}
}
}
impl TermScorer {
pub(crate) fn shallow_seek(&mut self, target_doc: DocId) {
self.postings.block_cursor.shallow_seek(target_doc)
}
#[cfg(test)]
pub fn create_for_test(
doc_and_tfs: &[(DocId, u32)],
fieldnorm_vals: &[u32],
similarity_weight: BM25Weight,
) -> crate::Result<TermScorer> {
assert!(!doc_and_tfs.is_empty());
assert!(doc_and_tfs.len() <= fieldnorm_vals.len());
let doc_freq = doc_and_tfs.len();
let max_doc = doc_and_tfs.last().unwrap().0 + 1;
let mut fieldnorms: Vec<u32> = std::iter::repeat(1).take(max_doc as usize).collect();
for i in 0..doc_freq {
let doc = doc_and_tfs[i].0;
let fieldnorm = fieldnorm_vals[i];
fieldnorms[doc as usize] = fieldnorm;
}
let fieldnorm_reader = FieldNormReader::from(&fieldnorms[..]);
let segment_postings =
SegmentPostings::create_from_docs_and_tfs(doc_and_tfs, Some(fieldnorm_reader.clone()))?;
Ok(TermScorer::new(segment_postings, fieldnorm_reader, similarity_weight))
}
/// See `FreqReadingOption`.
pub(crate) fn freq_reading_option(&self) -> FreqReadingOption {
self.postings.block_cursor.freq_reading_option()
}
/// Returns the maximum score for the current block.
///
/// In some rare case, the result may not be exact. In this case a lower value is returned,
/// (and may lead us to return a lesser document).
///
/// At index time, we store the (fieldnorm_id, term frequency) pair that maximizes the
/// score assuming the average fieldnorm computed on this segment.
///
/// Though extremely rare, it is theoretically possible that the actual average fieldnorm
/// is different enough from the current segment average fieldnorm that the maximum over a
/// specific is achieved on a different document.
///
/// (The result is on the other hand guaranteed to be correct if there is only one segment).
pub fn block_max_score(&mut self) -> Score {
self.postings
.block_cursor
.block_max_score(&self.fieldnorm_reader, &self.similarity_weight)
}
pub fn term_freq(&self) -> u32 {
self.postings.term_freq()
}
pub fn doc_freq(&self) -> usize {
self.postings.doc_freq() as usize
}
pub fn fieldnorm_id(&self) -> u8 {
self.fieldnorm_reader.fieldnorm_id(self.doc())
}
@@ -42,15 +99,23 @@ impl TermScorer {
let term_freq = self.term_freq();
self.similarity_weight.explain(fieldnorm_id, term_freq)
}
pub fn max_score(&self) -> f32 {
self.similarity_weight.max_score()
}
pub fn last_doc_in_block(&self) -> DocId {
self.postings.block_cursor.skip_reader.last_doc_in_block()
}
}
impl DocSet for TermScorer {
fn advance(&mut self) -> bool {
fn advance(&mut self) -> DocId {
self.postings.advance()
}
fn skip_next(&mut self, target: DocId) -> SkipResult {
self.postings.skip_next(target)
fn seek(&mut self, target: DocId) -> DocId {
self.postings.seek(target)
}
fn doc(&self) -> DocId {
@@ -69,3 +134,99 @@ impl Scorer for TermScorer {
self.similarity_weight.score(fieldnorm_id, term_freq)
}
}
#[cfg(test)]
mod tests {
use crate::assert_nearly_equals;
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::query::term_query::TermScorer;
use crate::query::{BM25Weight, Scorer};
use crate::{DocId, DocSet, TERMINATED};
use proptest::prelude::*;
#[test]
fn test_term_scorer_max_score() -> crate::Result<()> {
let bm25_weight = BM25Weight::for_one_term(3, 6, 10f32);
let mut term_scorer =
TermScorer::create_for_test(&[(2, 3), (3, 12), (7, 8)], &[10, 12, 100], bm25_weight)?;
let max_scorer = term_scorer.max_score();
assert_eq!(max_scorer, 1.3990127f32);
assert_eq!(term_scorer.doc(), 2);
assert_eq!(term_scorer.term_freq(), 3);
assert_nearly_equals!(term_scorer.block_max_score(), 1.3676447f32);
assert_nearly_equals!(term_scorer.score(), 1.0892314f32);
assert_eq!(term_scorer.advance(), 3);
assert_eq!(term_scorer.doc(), 3);
assert_eq!(term_scorer.term_freq(), 12);
assert_nearly_equals!(term_scorer.score(), 1.3676447f32);
assert_eq!(term_scorer.advance(), 7);
assert_eq!(term_scorer.doc(), 7);
assert_eq!(term_scorer.term_freq(), 8);
assert_nearly_equals!(term_scorer.score(), 0.72015285f32);
assert_eq!(term_scorer.advance(), TERMINATED);
Ok(())
}
#[test]
fn test_term_scorer_shallow_advance() -> crate::Result<()> {
let bm25_weight = BM25Weight::for_one_term(300, 1024, 10f32);
let mut doc_and_tfs = vec![];
for i in 0u32..300u32 {
let doc = i * 10;
doc_and_tfs.push((doc, 1u32 + doc % 3u32));
}
let fieldnorms: Vec<u32> = std::iter::repeat(10u32).take(1024).collect();
let mut term_scorer =
TermScorer::create_for_test(&doc_and_tfs, &fieldnorms, bm25_weight)?;
assert_eq!(term_scorer.doc(), 0u32);
term_scorer.shallow_seek(1289);
assert_eq!(term_scorer.doc(), 0u32);
term_scorer.seek(1289);
assert_eq!(term_scorer.doc(), 1290);
Ok(())
}
proptest! {
#[test]
fn test_term_scorer_block_max_score(term_freqs_fieldnorms in proptest::collection::vec((1u32..10u32, 0u32..100u32), 80..300)) {
let term_doc_freq = term_freqs_fieldnorms.len();
let doc_tfs: Vec<(u32, u32)> = term_freqs_fieldnorms.iter()
.cloned()
.enumerate()
.map(|(doc, (tf, _))| (doc as u32, tf))
.collect();
let mut fieldnorms: Vec<u32> = vec![];
for i in 0..term_doc_freq {
let (tf, num_extra_terms) = term_freqs_fieldnorms[i];
fieldnorms.push(tf + num_extra_terms);
}
let average_fieldnorm = fieldnorms
.iter()
.cloned()
.sum::<u32>() as f32 / term_doc_freq as f32;
// Average fieldnorm is over the entire index,
// not necessarily the docs that are in the posting list.
// For this reason we multiply by 1.1 to make a realistic value.
let bm25_weight = BM25Weight::for_one_term(term_doc_freq as u64,
term_doc_freq as u64 * 10u64,
average_fieldnorm);
let mut term_scorer =
TermScorer::create_for_test(&doc_tfs[..], &fieldnorms[..], bm25_weight).unwrap();
let docs: Vec<DocId> = (0..term_doc_freq).map(|doc| doc as DocId).collect();
for block in docs.chunks(COMPRESSION_BLOCK_SIZE) {
let block_max_score = term_scorer.block_max_score();
let mut block_max_score_computed = 0.0f32;
for &doc in block {
assert_eq!(term_scorer.doc(), doc);
block_max_score_computed = block_max_score_computed.max(term_scorer.score());
term_scorer.advance();
}
assert_nearly_equals!(block_max_score_computed, block_max_score);
}
}
}
}

View File

@@ -4,12 +4,13 @@ use crate::docset::DocSet;
use crate::postings::SegmentPostings;
use crate::query::bm25::BM25Weight;
use crate::query::explanation::does_not_match;
use crate::query::weight::{for_each_pruning_scorer, for_each_scorer};
use crate::query::Weight;
use crate::query::{Explanation, Scorer};
use crate::schema::IndexRecordOption;
use crate::DocId;
use crate::Result;
use crate::Term;
use crate::{Result, SkipResult};
use crate::{DocId, Score};
pub struct TermWeight {
term: Term,
@@ -18,14 +19,14 @@ pub struct TermWeight {
}
impl Weight for TermWeight {
fn scorer(&self, reader: &SegmentReader) -> Result<Box<dyn Scorer>> {
let term_scorer = self.scorer_specialized(reader)?;
fn scorer(&self, reader: &SegmentReader, boost: f32) -> Result<Box<dyn Scorer>> {
let term_scorer = self.specialized_scorer(reader, boost)?;
Ok(Box::new(term_scorer))
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation> {
let mut scorer = self.scorer_specialized(reader)?;
if scorer.skip_next(doc) != SkipResult::Reached {
let mut scorer = self.specialized_scorer(reader, 1.0f32)?;
if scorer.seek(doc) != doc {
return Err(does_not_match(doc));
}
Ok(scorer.explain())
@@ -33,7 +34,7 @@ impl Weight for TermWeight {
fn count(&self, reader: &SegmentReader) -> Result<u32> {
if let Some(delete_bitset) = reader.delete_bitset() {
Ok(self.scorer(reader)?.count(delete_bitset))
Ok(self.scorer(reader, 1.0f32)?.count(delete_bitset))
} else {
let field = self.term.field();
Ok(reader
@@ -43,6 +44,39 @@ impl Weight for TermWeight {
.unwrap_or(0))
}
}
/// Iterates through all of the document matched by the DocSet
/// `DocSet` and push the scored documents to the collector.
fn for_each(
&self,
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score),
) -> crate::Result<()> {
let mut scorer = self.specialized_scorer(reader, 1.0f32)?;
for_each_scorer(&mut scorer, callback);
Ok(())
}
/// Calls `callback` with all of the `(doc, score)` for which score
/// is exceeding a given threshold.
///
/// This method is useful for the TopDocs collector.
/// For all docsets, the blanket implementation has the benefit
/// of prefiltering (doc, score) pairs, avoiding the
/// virtual dispatch cost.
///
/// More importantly, it makes it possible for scorers to implement
/// important optimization (e.g. BlockWAND for union).
fn for_each_pruning(
&self,
threshold: f32,
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score) -> Score,
) -> crate::Result<()> {
let mut scorer = self.scorer(reader, 1.0f32)?;
for_each_pruning_scorer(&mut scorer, threshold, callback);
Ok(())
}
}
impl TermWeight {
@@ -58,11 +92,11 @@ impl TermWeight {
}
}
fn scorer_specialized(&self, reader: &SegmentReader) -> Result<TermScorer> {
pub fn specialized_scorer(&self, reader: &SegmentReader, boost: f32) -> Result<TermScorer> {
let field = self.term.field();
let inverted_index = reader.inverted_index(field);
let fieldnorm_reader = reader.get_fieldnorms_reader(field);
let similarity_weight = self.similarity_weight.clone();
let similarity_weight = self.similarity_weight.boost_by(boost);
let postings_opt: Option<SegmentPostings> =
inverted_index.read_postings(&self.term, self.index_record_option);
if let Some(segment_postings) = postings_opt {

View File

@@ -1,10 +1,9 @@
use crate::common::TinySet;
use crate::docset::{DocSet, SkipResult};
use crate::docset::{DocSet, TERMINATED};
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
use crate::query::Scorer;
use crate::DocId;
use crate::Score;
use std::cmp::Ordering;
const HORIZON_NUM_TINYBITSETS: usize = 64;
const HORIZON: u32 = 64u32 * HORIZON_NUM_TINYBITSETS as u32;
@@ -47,17 +46,9 @@ where
fn from(docsets: Vec<TScorer>) -> Union<TScorer, TScoreCombiner> {
let non_empty_docsets: Vec<TScorer> = docsets
.into_iter()
.flat_map(
|mut docset| {
if docset.advance() {
Some(docset)
} else {
None
}
},
)
.filter(|docset| docset.doc() != TERMINATED)
.collect();
Union {
let mut union = Union {
docsets: non_empty_docsets,
bitsets: Box::new([TinySet::empty(); HORIZON_NUM_TINYBITSETS]),
scores: Box::new([TScoreCombiner::default(); HORIZON as usize]),
@@ -65,7 +56,13 @@ where
offset: 0,
doc: 0,
score: 0f32,
};
if union.refill() {
union.advance();
} else {
union.doc = TERMINATED;
}
union
}
}
@@ -86,7 +83,7 @@ fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
let delta = doc - min_doc;
bitsets[(delta / 64) as usize].insert_mut(delta % 64u32);
score_combiner[delta as usize].update(scorer);
if !scorer.advance() {
if scorer.advance() == TERMINATED {
// remove the docset, it has been entirely consumed.
return true;
}
@@ -99,6 +96,7 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Union<TScorer, TScoreCombin
if let Some(min_doc) = self.docsets.iter().map(DocSet::doc).min() {
self.offset = min_doc;
self.cursor = 0;
self.doc = min_doc;
refill(
&mut self.docsets,
&mut *self.bitsets,
@@ -133,50 +131,23 @@ where
TScorer: Scorer,
TScoreCombiner: ScoreCombiner,
{
fn advance(&mut self) -> bool {
fn advance(&mut self) -> DocId {
if self.advance_buffered() {
return true;
return self.doc;
}
if self.refill() {
self.advance();
true
} else {
false
if !self.refill() {
self.doc = TERMINATED;
return TERMINATED;
}
if !self.advance_buffered() {
return TERMINATED;
}
self.doc
}
fn count_including_deleted(&mut self) -> u32 {
let mut count = self.bitsets[self.cursor..HORIZON_NUM_TINYBITSETS]
.iter()
.map(|bitset| bitset.len())
.sum::<u32>();
for bitset in self.bitsets.iter_mut() {
bitset.clear();
}
while self.refill() {
count += self.bitsets.iter().map(|bitset| bitset.len()).sum::<u32>();
for bitset in self.bitsets.iter_mut() {
bitset.clear();
}
}
self.cursor = HORIZON_NUM_TINYBITSETS;
count
}
// TODO implement `count` efficiently.
fn skip_next(&mut self, target: DocId) -> SkipResult {
if !self.advance() {
return SkipResult::End;
}
match self.doc.cmp(&target) {
Ordering::Equal => {
return SkipResult::Reached;
}
Ordering::Greater => {
return SkipResult::OverStep;
}
Ordering::Less => {}
fn seek(&mut self, target: DocId) -> DocId {
if self.doc >= target {
return self.doc;
}
let gap = target - self.offset;
if gap < HORIZON {
@@ -194,18 +165,11 @@ where
// Advancing until we reach the end of the bucket
// or we reach a doc greater or equal to the target.
while self.advance() {
match self.doc().cmp(&target) {
Ordering::Equal => {
return SkipResult::Reached;
}
Ordering::Greater => {
return SkipResult::OverStep;
}
Ordering::Less => {}
}
let mut doc = self.doc();
while doc < target {
doc = self.advance();
}
SkipResult::End
doc
} else {
// clear the buffered info.
for obsolete_tinyset in self.bitsets.iter_mut() {
@@ -220,35 +184,55 @@ where
#[cfg_attr(feature = "cargo-clippy", allow(clippy::clippy::collapsible_if))]
unordered_drain_filter(&mut self.docsets, |docset| {
if docset.doc() < target {
if docset.skip_next(target) == SkipResult::End {
return true;
}
docset.seek(target);
}
false
docset.doc() == TERMINATED
});
// at this point all of the docsets
// are positionned on a doc >= to the target.
if self.refill() {
self.advance();
if self.doc() == target {
SkipResult::Reached
} else {
debug_assert!(self.doc() > target);
SkipResult::OverStep
}
} else {
SkipResult::End
if !self.refill() {
self.doc = TERMINATED;
return TERMINATED;
}
self.advance()
}
}
// TODO Also implement `count` with deletes efficiently.
fn doc(&self) -> DocId {
self.doc
}
fn size_hint(&self) -> u32 {
0u32
self.docsets
.iter()
.map(|docset| docset.size_hint())
.max()
.unwrap_or(0u32)
}
fn count_including_deleted(&mut self) -> u32 {
if self.doc == TERMINATED {
return 0;
}
let mut count = self.bitsets[self.cursor..HORIZON_NUM_TINYBITSETS]
.iter()
.map(|bitset| bitset.len())
.sum::<u32>()
+ 1;
for bitset in self.bitsets.iter_mut() {
bitset.clear();
}
while self.refill() {
count += self.bitsets.iter().map(|bitset| bitset.len()).sum::<u32>();
for bitset in self.bitsets.iter_mut() {
bitset.clear();
}
}
self.cursor = HORIZON_NUM_TINYBITSETS;
count
}
}
@@ -267,7 +251,7 @@ mod tests {
use super::Union;
use super::HORIZON;
use crate::docset::{DocSet, SkipResult};
use crate::docset::{DocSet, TERMINATED};
use crate::postings::tests::test_skip_against_unoptimized;
use crate::query::score_combiner::DoNothingCombiner;
use crate::query::ConstScorer;
@@ -290,18 +274,18 @@ mod tests {
vals.iter()
.cloned()
.map(VecDocSet::from)
.map(ConstScorer::new)
.map(|docset| ConstScorer::new(docset, 1.0f32))
.collect::<Vec<ConstScorer<VecDocSet>>>(),
)
};
let mut union: Union<_, DoNothingCombiner> = make_union();
let mut count = 0;
while union.advance() {
assert!(union_expected.advance());
while union.doc() != TERMINATED {
assert_eq!(union_expected.doc(), union.doc());
assert_eq!(union_expected.advance(), union.advance());
count += 1;
}
assert!(!union_expected.advance());
assert_eq!(union_expected.advance(), TERMINATED);
assert_eq!(count, make_union().count_including_deleted());
}
@@ -329,9 +313,7 @@ mod tests {
fn test_aux_union_skip(docs_list: &[Vec<DocId>], skip_targets: Vec<DocId>) {
let mut btree_set = BTreeSet::new();
for docs in docs_list {
for &doc in docs.iter() {
btree_set.insert(doc);
}
btree_set.extend(docs.iter().cloned());
}
let docset_factory = || {
let res: Box<dyn DocSet> = Box::new(Union::<_, DoNothingCombiner>::from(
@@ -339,17 +321,17 @@ mod tests {
.iter()
.map(|docs| docs.clone())
.map(VecDocSet::from)
.map(ConstScorer::new)
.map(|docset| ConstScorer::new(docset, 1.0f32))
.collect::<Vec<_>>(),
));
res
};
let mut docset = docset_factory();
for el in btree_set {
assert!(docset.advance());
assert_eq!(el, docset.doc());
docset.advance();
}
assert!(!docset.advance());
assert_eq!(docset.doc(), TERMINATED);
test_skip_against_unoptimized(docset_factory, skip_targets);
}
@@ -369,13 +351,13 @@ mod tests {
#[test]
fn test_union_skip_corner_case3() {
let mut docset = Union::<_, DoNothingCombiner>::from(vec![
ConstScorer::new(VecDocSet::from(vec![0u32, 5u32])),
ConstScorer::new(VecDocSet::from(vec![1u32, 4u32])),
ConstScorer::from(VecDocSet::from(vec![0u32, 5u32])),
ConstScorer::from(VecDocSet::from(vec![1u32, 4u32])),
]);
assert!(docset.advance());
assert_eq!(docset.doc(), 0u32);
assert_eq!(docset.skip_next(0u32), SkipResult::OverStep);
assert_eq!(docset.doc(), 1u32)
assert_eq!(docset.seek(0u32), 0u32);
assert_eq!(docset.seek(0u32), 0u32);
assert_eq!(docset.doc(), 0u32)
}
#[test]

View File

@@ -1,9 +1,8 @@
#![allow(dead_code)]
use crate::common::HasLen;
use crate::docset::DocSet;
use crate::docset::{DocSet, TERMINATED};
use crate::DocId;
use std::num::Wrapping;
/// Simulate a `Postings` objects from a `VecPostings`.
/// `VecPostings` only exist for testing purposes.
@@ -12,26 +11,30 @@ use std::num::Wrapping;
/// No positions are returned.
pub struct VecDocSet {
doc_ids: Vec<DocId>,
cursor: Wrapping<usize>,
cursor: usize,
}
impl From<Vec<DocId>> for VecDocSet {
fn from(doc_ids: Vec<DocId>) -> VecDocSet {
VecDocSet {
doc_ids,
cursor: Wrapping(usize::max_value()),
}
VecDocSet { doc_ids, cursor: 0 }
}
}
impl DocSet for VecDocSet {
fn advance(&mut self) -> bool {
self.cursor += Wrapping(1);
self.doc_ids.len() > self.cursor.0
fn advance(&mut self) -> DocId {
self.cursor += 1;
if self.cursor >= self.doc_ids.len() {
self.cursor = self.doc_ids.len();
return TERMINATED;
}
self.doc()
}
fn doc(&self) -> DocId {
self.doc_ids[self.cursor.0]
if self.cursor == self.doc_ids.len() {
return TERMINATED;
}
self.doc_ids[self.cursor]
}
fn size_hint(&self) -> u32 {
@@ -49,22 +52,21 @@ impl HasLen for VecDocSet {
pub mod tests {
use super::*;
use crate::docset::{DocSet, SkipResult};
use crate::docset::DocSet;
use crate::DocId;
#[test]
pub fn test_vec_postings() {
let doc_ids: Vec<DocId> = (0u32..1024u32).map(|e| e * 3).collect();
let mut postings = VecDocSet::from(doc_ids);
assert!(postings.advance());
assert_eq!(postings.doc(), 0u32);
assert!(postings.advance());
assert_eq!(postings.advance(), 3u32);
assert_eq!(postings.doc(), 3u32);
assert_eq!(postings.skip_next(14u32), SkipResult::OverStep);
assert_eq!(postings.seek(14u32), 15u32);
assert_eq!(postings.doc(), 15u32);
assert_eq!(postings.skip_next(300u32), SkipResult::Reached);
assert_eq!(postings.seek(300u32), 300u32);
assert_eq!(postings.doc(), 300u32);
assert_eq!(postings.skip_next(6000u32), SkipResult::End);
assert_eq!(postings.seek(6000u32), TERMINATED);
}
#[test]

View File

@@ -1,7 +1,45 @@
use super::Scorer;
use crate::core::SegmentReader;
use crate::query::Explanation;
use crate::{DocId, Result};
use crate::{DocId, Score, TERMINATED};
/// Iterates through all of the document matched by the DocSet
/// `DocSet` and push the scored documents to the collector.
pub(crate) fn for_each_scorer<TScorer: Scorer + ?Sized>(
scorer: &mut TScorer,
callback: &mut dyn FnMut(DocId, Score),
) {
let mut doc = scorer.doc();
while doc != TERMINATED {
callback(doc, scorer.score());
doc = scorer.advance();
}
}
/// Calls `callback` with all of the `(doc, score)` for which score
/// is exceeding a given threshold.
///
/// This method is useful for the TopDocs collector.
/// For all docsets, the blanket implementation has the benefit
/// of prefiltering (doc, score) pairs, avoiding the
/// virtual dispatch cost.
///
/// More importantly, it makes it possible for scorers to implement
/// important optimization (e.g. BlockWAND for union).
pub(crate) fn for_each_pruning_scorer<TScorer: Scorer + ?Sized>(
scorer: &mut TScorer,
mut threshold: f32,
callback: &mut dyn FnMut(DocId, Score) -> Score,
) {
let mut doc = scorer.doc();
while doc != TERMINATED {
let score = scorer.score();
if score > threshold {
threshold = callback(doc, score);
}
doc = scorer.advance();
}
}
/// A Weight is the specialization of a Query
/// for a given set of segments.
@@ -9,19 +47,55 @@ use crate::{DocId, Result};
/// See [`Query`](./trait.Query.html).
pub trait Weight: Send + Sync + 'static {
/// Returns the scorer for the given segment.
///
/// `boost` is a multiplier to apply to the score.
///
/// See [`Query`](./trait.Query.html).
fn scorer(&self, reader: &SegmentReader) -> Result<Box<dyn Scorer>>;
fn scorer(&self, reader: &SegmentReader, boost: f32) -> crate::Result<Box<dyn Scorer>>;
/// Returns an `Explanation` for the given document.
fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation>;
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation>;
/// Returns the number documents within the given `SegmentReader`.
fn count(&self, reader: &SegmentReader) -> Result<u32> {
let mut scorer = self.scorer(reader)?;
fn count(&self, reader: &SegmentReader) -> crate::Result<u32> {
let mut scorer = self.scorer(reader, 1.0f32)?;
if let Some(delete_bitset) = reader.delete_bitset() {
Ok(scorer.count(delete_bitset))
} else {
Ok(scorer.count_including_deleted())
}
}
/// Iterates through all of the document matched by the DocSet
/// `DocSet` and push the scored documents to the collector.
fn for_each(
&self,
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score),
) -> crate::Result<()> {
let mut scorer = self.scorer(reader, 1.0f32)?;
for_each_scorer(scorer.as_mut(), callback);
Ok(())
}
/// Calls `callback` with all of the `(doc, score)` for which score
/// is exceeding a given threshold.
///
/// This method is useful for the TopDocs collector.
/// For all docsets, the blanket implementation has the benefit
/// of prefiltering (doc, score) pairs, avoiding the
/// virtual dispatch cost.
///
/// More importantly, it makes it possible for scorers to implement
/// important optimization (e.g. BlockWAND for union).
fn for_each_pruning(
&self,
threshold: f32,
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score) -> Score,
) -> crate::Result<()> {
let mut scorer = self.scorer(reader, 1.0f32)?;
for_each_pruning_scorer(scorer.as_mut(), threshold, callback);
Ok(())
}
}

View File

@@ -9,6 +9,7 @@ use crate::directory::META_LOCK;
use crate::Index;
use crate::Searcher;
use crate::SegmentReader;
use std::convert::TryInto;
use std::sync::Arc;
/// Defines when a new version of the index should be reloaded.
@@ -21,7 +22,7 @@ pub enum ReloadPolicy {
/// The index is entirely reloaded manually.
/// All updates of the index should be manual.
///
/// No change is reflected automatically. You are required to call `.load_seacher()` manually.
/// No change is reflected automatically. You are required to call `IndexReader::reload()` manually.
Manual,
/// The index is reloaded within milliseconds after a new commit is available.
/// This is made possible by watching changes in the `meta.json` file.
@@ -60,7 +61,6 @@ impl IndexReaderBuilder {
/// Building the reader is a non-trivial operation that requires
/// to open different segment readers. It may take hundreds of milliseconds
/// of time and it may return an error.
/// TODO(pmasurel) Use the `TryInto` trait once it is available in stable.
pub fn try_into(self) -> crate::Result<IndexReader> {
let inner_reader = InnerIndexReader {
index: self.index,
@@ -113,6 +113,14 @@ impl IndexReaderBuilder {
}
}
impl TryInto<IndexReader> for IndexReaderBuilder {
type Error = crate::TantivyError;
fn try_into(self) -> crate::Result<IndexReader> {
IndexReaderBuilder::try_into(self)
}
}
struct InnerIndexReader {
num_searchers: usize,
searcher_pool: Pool<Searcher>,

View File

@@ -68,7 +68,9 @@ impl<T> Pool<T> {
/// After publish, all new `Searcher` acquired will be
/// of the new generation.
pub fn publish_new_generation(&self, items: Vec<T>) {
assert!(!items.is_empty());
let next_generation = self.next_generation.fetch_add(1, Ordering::SeqCst) + 1;
let num_items = items.len();
for item in items {
let gen_item = GenerationItem {
item,
@@ -77,6 +79,23 @@ impl<T> Pool<T> {
self.queue.push(gen_item);
}
self.advertise_generation(next_generation);
// Purge possible previous searchers.
//
// Assuming at this point no searcher is held more than duration T by the user,
// this guarantees that an obsolete searcher will not be uselessly held (and its associated
// mmap) for more than duration T.
//
// Proof: At this point, obsolete searcher that are held by the user will be held for less
// than T. When released, they will be dropped as their generation is detected obsolete.
//
// We still need to ensure that the searcher that are obsolete and in the pool get removed.
// The queue currently contains up to 2n searchers, in any random order.
//
// Half of them are obsoletes. By requesting `(n+1)` fresh searchers, we ensure that all
// searcher will be inspected.
for _ in 0..=num_items {
let _ = self.acquire();
}
}
/// At the exit of this method,

View File

@@ -3,8 +3,8 @@ use crate::common::BinarySerializable;
use crate::common::VInt;
use crate::tokenizer::PreTokenizedString;
use crate::DateTime;
use itertools::Itertools;
use std::io::{self, Read, Write};
use std::mem;
/// Tantivy's Document is the object that can
/// be indexed and then searched for.
@@ -16,7 +16,7 @@ use std::io::{self, Read, Write};
/// Documents are really just a list of couple `(field, value)`.
/// In this list, one field may appear more than once.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Default)]
pub struct Document {
field_values: Vec<FieldValue>,
}
@@ -131,12 +131,34 @@ impl Document {
pub fn get_sorted_field_values(&self) -> Vec<(Field, Vec<&FieldValue>)> {
let mut field_values: Vec<&FieldValue> = self.field_values().iter().collect();
field_values.sort_by_key(|field_value| field_value.field());
field_values
.into_iter()
.group_by(|field_value| field_value.field())
.into_iter()
.map(|(key, group)| (key, group.collect()))
.collect::<Vec<(Field, Vec<&FieldValue>)>>()
let mut grouped_field_values = vec![];
let mut current_field;
let mut current_group;
let mut field_values_it = field_values.into_iter();
if let Some(field_value) = field_values_it.next() {
current_field = field_value.field();
current_group = vec![field_value]
} else {
return grouped_field_values;
}
for field_value in field_values_it {
if field_value.field() == current_field {
current_group.push(field_value);
} else {
grouped_field_values.push((
current_field,
mem::replace(&mut current_group, vec![field_value]),
));
current_field = field_value.field();
}
}
grouped_field_values.push((current_field, current_group));
grouped_field_values
}
/// Returns all of the `FieldValue`s associated the given field

View File

@@ -125,7 +125,7 @@ impl Facet {
/// This function is the inverse of Facet::from(&str).
pub fn to_path_string(&self) -> String {
format!("{}", self.to_string())
format!("{}", self)
}
}

View File

@@ -5,18 +5,20 @@ use std::io::Write;
/// `Field` is represented by an unsigned 32-bit integer type
/// The schema holds the mapping between field names and `Field` objects.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Hash, Serialize, Deserialize)]
#[derive(
Copy, Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Hash, serde::Serialize, serde::Deserialize,
)]
pub struct Field(u32);
impl Field {
/// Create a new field object for the given FieldId.
pub fn from_field_id(field_id: u32) -> Field {
pub const fn from_field_id(field_id: u32) -> Field {
Field(field_id)
}
/// Returns a u32 identifying uniquely a field within a schema.
#[allow(clippy::trivially_copy_pass_by_ref)]
pub fn field_id(&self) -> u32 {
pub const fn field_id(&self) -> u32 {
self.0
}
}

View File

@@ -14,7 +14,7 @@ use std::fmt;
/// - a field name
/// - a field type, itself wrapping up options describing
/// how the field should be indexed.
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Debug, PartialEq)]
pub struct FieldEntry {
name: String,
field_type: FieldType,

View File

@@ -48,7 +48,7 @@ pub enum Type {
/// A `FieldType` describes the type (text, u64) of a field as well as
/// how it should be handled by tantivy.
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Debug, PartialEq)]
pub enum FieldType {
/// String field type configuration
Str(TextOptions),

View File

@@ -1,12 +1,10 @@
use crate::common::BinarySerializable;
use crate::schema::Field;
use crate::schema::Value;
use std::io;
use std::io::Read;
use std::io::Write;
use std::io::{self, Read, Write};
/// `FieldValue` holds together a `Field` and its `Value`.
#[derive(Debug, Clone, Ord, PartialEq, Eq, PartialOrd, Serialize, Deserialize)]
#[derive(Debug, Clone, Ord, PartialEq, Eq, PartialOrd, serde::Serialize, serde::Deserialize)]
pub struct FieldValue {
field: Field,
value: Value,

View File

@@ -1,3 +1,5 @@
use serde::{Deserialize, Serialize};
/// `IndexRecordOption` describes an amount information associated
/// to a given indexed field.
///

View File

@@ -1,4 +1,5 @@
use crate::schema::flags::{FastFlag, IndexedFlag, SchemaFlagList, StoredFlag};
use serde::{Deserialize, Serialize};
use std::ops::BitOr;
/// Express whether a field is single-value or multi-valued.

View File

@@ -1,4 +1,5 @@
use crate::schema::Value;
use serde::Serialize;
use std::collections::BTreeMap;
/// Internal representation of a document used for JSON

View File

@@ -1,11 +1,12 @@
use crate::schema::flags::SchemaFlagList;
use crate::schema::flags::StoredFlag;
use crate::schema::IndexRecordOption;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::ops::BitOr;
/// Define how a text field should be handled by tantivy.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct TextOptions {
indexing: Option<TextFieldIndexing>,
stored: bool,
@@ -50,7 +51,7 @@ impl Default for TextOptions {
/// - the amount of information that should be stored about the presence of a term in a document.
/// Essentially, should we store the term frequency and/or the positions (See [`IndexRecordOption`](./enum.IndexRecordOption.html)).
/// - the name of the `Tokenizer` that should be used to process the field.
#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct TextFieldIndexing {
record: IndexRecordOption,
tokenizer: Cow<'static, str>,
@@ -155,30 +156,17 @@ mod tests {
#[test]
fn test_field_options() {
{
let field_options = STORED | TEXT;
assert!(field_options.is_stored());
assert!(field_options.get_indexing_options().is_some());
}
{
let mut schema_builder = Schema::builder();
schema_builder.add_text_field("body", TEXT);
let schema = schema_builder.build();
let field = schema.get_field("body").unwrap();
let field_entry = schema.get_field_entry(field);
match field_entry.field_type() {
&FieldType::Str(ref text_options) => {
assert!(text_options.get_indexing_options().is_some());
assert_eq!(
text_options.get_indexing_options().unwrap().tokenizer(),
"default"
);
}
_ => {
panic!("");
}
}
}
let field_options = STORED | TEXT;
assert!(field_options.is_stored());
assert!(field_options.get_indexing_options().is_some());
let mut schema_builder = Schema::builder();
schema_builder.add_text_field("body", TEXT);
let schema = schema_builder.build();
let field = schema.get_field("body").unwrap();
let field_entry = schema.get_field_entry(field);
assert!(matches!(field_entry.field_type(),
&FieldType::Str(ref text_options)
if text_options.get_indexing_options().unwrap().tokenizer() == "default"));
}
#[test]

View File

@@ -11,6 +11,7 @@ under-count actual resultant space usage by up to 4095 bytes per file.
use crate::schema::Field;
use crate::SegmentComponent;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Indicates space usage in bytes

Some files were not shown because too many files have changed in this diff Show More