Compare commits

..

30 Commits

Author SHA1 Message Date
cong.xie
e9c6383b8c fix: adapt composite aggregation for Quickwit compatibility
Fixes on top of PR #2714 ("Add composite aggregation") to make it work
with Quickwit's current codebase and Postcard serialization:

- Rewrite SegmentCompositeCollector to match current
  SegmentAggregationCollector trait signatures (collect,
  add_intermediate_aggregation_result, prepare_max_bucket)
- Remove Clone derive from CompositeBucketCollector (incompatible
  with dyn SegmentAggregationCollector)
- Add custom serde for FxHashMap entries in
  IntermediateCompositeBucketResult (Postcard requires known sequence
  length)
- Rewrite AfterKey Serialize/Deserialize to output raw values instead
  of internal "type:value" format, matching Elasticsearch wire format
- Remove unused imports and tracing::warn calls

Made-with: Cursor
2026-03-13 15:38:25 -04:00
Remi Dettai
d662415b81 Add composite aggregation 2026-03-13 10:32:41 -04:00
cong.xie
2dc4e9ef78 fix: resolve remaining clippy errors in ddsketch
- Replace approximate PI/E constants with non-famous value in test
- Fix reversed empty range (2048..0) → (0..2048).rev() in store test

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-18 15:54:27 -05:00
cong.xie
aeea65f61d refactor: rewrite encoding.rs with idiomatic Rust
- Replace bare constants with FlagType and BinEncodingMode enums
- Use const fn for flag byte construction instead of raw bit ops
- Replace if-else chain with nested match in decode_from_java_bytes
- Use split_first() in read_byte for idiomatic slice consumption
- Use split_at in read_f64_le to avoid TryInto on edition 2018
- Use u64::from(next) instead of `next as u64` casts
- Extract assert_golden, assert_quantiles_match, bytes_to_hex helpers
  to reduce duplication across golden byte tests
- Fix edition-2018 assert! format string compatibility
- Clean up is_valid_flag_byte with let-else and match

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-18 15:49:12 -05:00
cong.xie
4211d5a1ed fix: resolve clippy warnings in vendored sketches-ddsketch
- manual_range_contains: use !(0.0..=1.0).contains(&q)
- identity_op: simplify (0 << 2) | FLAG_TYPE to just FLAG_TYPE
- manual_clamp: use .clamp(0, 8) instead of .max(0).min(8)
- manual_repeat_n: use repeat_n() instead of repeat().take()
- cast_abs_to_unsigned: use .unsigned_abs() instead of .abs() as usize

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-18 13:36:06 -05:00
cong.xie
d50c7a1daf Add Java source links for cross-language alignment comments
Reference the exact Java source files in DataDog/sketches-java for
Config::new(), Config::key(), Config::value(), Config::from_gamma(),
and Store::add_count() so readers can verify the alignment.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-18 13:25:12 -05:00
cong.xie
cf760fd5b6 fix: remove internal reference from code comment
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-18 12:59:25 -05:00
cong.xie
df04c7d8f1 fix: rustfmt nightly formatting for vendored sketches-ddsketch
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-18 12:53:01 -05:00
cong.xie
68626bf3a1 Vendor sketches-ddsketch with Java-compatible binary encoding
Fork sketches-ddsketch as a workspace member to add native Java binary
serialization (to_java_bytes/from_java_bytes) for DDSketch. This enables
pomsky to return raw DDSketch bytes that event-query can deserialize via
DDSketchWithExactSummaryStatistics.decode().

Key changes:
- Vendor sketches-ddsketch crate with encoding.rs implementing VarEncoding,
  flag bytes, and INDEX_DELTAS_AND_COUNTS store format
- Align Config::key() to floor-based indexing matching Java's LogarithmicMapping
- Add PercentilesCollector::to_sketch_bytes() for pomsky integration
- Cross-language golden byte tests verified byte-identical with Java output

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-18 11:36:21 -05:00
cong.xie
7eca33143e Remove Datadog-specific references from comments
This is an open-source repo — replace references to Datadog's event query
with generic cross-language compatibility descriptions.
2026-02-12 11:44:42 -05:00
cong.xie
698f073f88 fix fmt 2026-02-11 15:52:39 -05:00
cong.xie
cdd24b7ee5 Replace hyperloglogplus with Apache DataSketches HLL (lg_k=11)
Switch tantivy's cardinality aggregation from the hyperloglogplus crate
(HyperLogLog++ with p=16) to the official Apache DataSketches HLL
implementation (datasketches crate v0.2.0 with lg_k=11, Hll4).

This enables returning raw HLL sketch bytes from pomsky to Datadog's
event query, where they can be properly deserialized and merged using
the same DataSketches library (Java). The previous implementation
required pomsky to fabricate fake HLL sketches from scalar cardinality
estimates, which produced incorrect results when merged.

Changes:
- Cargo.toml: hyperloglogplus 0.4.1 -> datasketches 0.2.0
- CardinalityCollector: HyperLogLogPlus<u64, BuildSaltedHasher> -> HllSketch
- Custom Serde impl using HllSketch binary format (cross-shard compat)
- New to_sketch_bytes() for external consumers (pomsky)
- Salt preserved via (salt, value) tuple hashing for column type disambiguation
- Removed BuildSaltedHasher struct
- Added 4 new unit tests (serde roundtrip, merge, binary compat, salt)
2026-02-11 08:49:46 -05:00
trinity-1686a
5562ce6037 Merge pull request #2818 from Darkheir/fix/query_grammar_regex_between_parentheses 2026-02-11 11:39:58 +01:00
Metin Dumandag
09b6ececa7 Export fields of the PercentileValuesVecEntry (#2833)
Otherwise, there is no way to access these fields when not using the
json serialized form of the aggregation results.

This simple data struct is part of the public api,
so its fields should be accessible as well.
2026-02-11 11:31:07 +01:00
Moe
8018016e46 feat: add fast field support for Bytes type (#100) (#2830)
## What

Enable range queries and TopN sorting on `Bytes` fast fields, bringing them to parity with `Str` fields.

## Why

`BytesColumn` uses the same dictionary encoding as `StrColumn` internally, but range queries and TopN sorting were explicitly disabled for `Bytes`. This prevented use cases like storing lexicographically sortable binary data (e.g., arbitrary-precision decimals) that need efficient range filtering.

## How

1. **Enable range queries for Bytes** - Changed `is_type_valid_for_fastfield_range_query()` to return `true` for `Type::Bytes`
2. **Add BytesColumn handling in scorer** - Added a branch in `FastFieldRangeWeight::scorer()` to handle bytes fields using dictionary ordinal lookup (mirrors the existing `StrColumn` logic)
3. **Add SortByBytes** - New sort key computer for TopN queries on bytes columns

## Tests

- `test_bytes_field_ff_range_query` - Tests inclusive/exclusive bounds and unbounded ranges
- `test_sort_by_bytes_asc` / `test_sort_by_bytes_desc` - Tests lexicographic ordering in both directions
2026-02-11 11:26:18 +01:00
trinity-1686a
6bf185dc3f Merge pull request #2829 from quickwit-oss/cong.xie/add-intermediate-accessors 2026-02-10 17:07:24 +01:00
cong.xie
bb141abe22 feat(aggregation): add keys() accessor to IntermediateAggregationResults 2026-02-09 15:38:35 -05:00
cong.xie
f1c29ba972 resolve conflcit 2026-02-06 14:23:11 -05:00
cong.xie
ae0554a6a5 feat(aggregation): add public accessors for intermediate aggregation results
Add accessor methods to allow external crates to read intermediate
aggregation results without accessing pub(crate) fields:

- IntermediateAggregationResults: get(), remove()
- IntermediateTermBucketResult: entries(), sum_other_doc_count(), doc_count_error_upper_bound()
- IntermediateAverage: stats()
- IntermediateStats: count(), sum()
- IntermediateKey: Display impl for string conversion
2026-02-06 11:12:20 -05:00
cong.xie
0d7abe5d23 feat(aggregation): add public accessors for intermediate aggregation results
Add accessor methods to allow external crates to read intermediate
aggregation results without accessing pub(crate) fields:

- IntermediateAggregationResults: get(), get_mut(), remove()
- IntermediateTermBucketResult: entries(), sum_other_doc_count(), doc_count_error_upper_bound()
- IntermediateAverage: stats()
- IntermediateStats: count(), sum()
- IntermediateKey: Display impl for string conversion
2026-02-06 10:28:59 -05:00
PSeitz
28db952131 Add regex search and merge segments benchmark (#2826)
* add merge_segments benchmark

* add regex search bench
2026-02-02 17:28:02 +01:00
PSeitz
98ebbf922d faster exclude queries (#2825)
* faster exclude queries

Faster exclude queries with multiple terms.

Changes `Exclude` to be able to exclude multiple DocSets, instead of
putting the docsets into a union.
Use `seek_danger` in `Exclude`.

closes #2822

* replace unwrap with match
2026-01-30 17:06:41 +01:00
Paul Masurel
4a89e74597 Fix rfc3339 typos and add Claude Code skills (#2823)
Closes #2817
2026-01-30 12:00:28 +01:00
Alex Lazar
4d99e51e50 Bump oneshot to 0.1.13 per dependabot (#2821) 2026-01-30 11:42:01 +01:00
Darkheir
a55e4069e4 feat(query-grammar): Apply PR review suggestions
Signed-off-by: Darkheir <raphael.cohen@sekoia.io>
2026-01-28 14:13:55 +01:00
Darkheir
1fd30c62be fix(query-grammar): Fix regexes between parentheses
Signed-off-by: Darkheir <raphael.cohen@sekoia.io>
2026-01-28 10:37:51 +01:00
trinity-1686a
9b619998bd Merge pull request #2816 from evance-br/fix-closing-paren-elastic-range 2026-01-27 17:00:08 +01:00
Evance Soumaoro
765c448945 uncomment commented code when testing 2026-01-27 13:19:41 +00:00
Evance Soumaoro
943594ebaa uncomment commented code when testing 2026-01-27 13:08:38 +00:00
Evance Soumaoro
df17daae0d fix closing parenthesis error on elastic range queries for lenient parser 2026-01-27 13:01:14 +00:00
119 changed files with 9678 additions and 2633 deletions

View File

@@ -0,0 +1,125 @@
---
name: rationalize-deps
description: Analyze Cargo.toml dependencies and attempt to remove unused features to reduce compile times and binary size
---
# Rationalize Dependencies
This skill analyzes Cargo.toml dependencies to identify and remove unused features.
## Overview
Many crates enable features by default that may not be needed. This skill:
1. Identifies dependencies with default features enabled
2. Tests if `default-features = false` works
3. Identifies which specific features are actually needed
4. Verifies compilation after changes
## Step 1: Identify the target
Ask the user which crate(s) to analyze:
- A specific crate name (e.g., "tokio", "serde")
- A specific workspace member (e.g., "quickwit-search")
- "all" to scan the entire workspace
## Step 2: Analyze current dependencies
For the workspace Cargo.toml (`quickwit/Cargo.toml`), list dependencies that:
- Do NOT have `default-features = false`
- Have default features that might be unnecessary
Run: `cargo tree -p <crate> -f "{p} {f}" --edges features` to see what features are actually used.
## Step 3: For each candidate dependency
### 3a: Check the crate's default features
Look up the crate on crates.io or check its Cargo.toml to understand:
- What features are enabled by default
- What each feature provides
Use: `cargo metadata --format-version=1 | jq '.packages[] | select(.name == "<crate>") | .features'`
### 3b: Try disabling default features
Modify the dependency in `quickwit/Cargo.toml`:
From:
```toml
some-crate = { version = "1.0" }
```
To:
```toml
some-crate = { version = "1.0", default-features = false }
```
### 3c: Run cargo check
Run: `cargo check --workspace` (or target specific packages for faster feedback)
If compilation fails:
1. Read the error messages to identify which features are needed
2. Add only the required features explicitly:
```toml
some-crate = { version = "1.0", default-features = false, features = ["needed-feature"] }
```
3. Re-run cargo check
### 3d: Binary search for minimal features
If there are many default features, use binary search:
1. Start with no features
2. If it fails, add half the default features
3. Continue until you find the minimal set
## Step 4: Document findings
For each dependency analyzed, report:
- Original configuration
- New configuration (if changed)
- Features that were removed
- Any features that are required
## Step 5: Verify full build
After all changes, run:
```bash
cargo check --workspace --all-targets
cargo test --workspace --no-run
```
## Common Patterns
### Serde
Often only needs `derive`:
```toml
serde = { version = "1.0", default-features = false, features = ["derive", "std"] }
```
### Tokio
Identify which runtime features are actually used:
```toml
tokio = { version = "1.0", default-features = false, features = ["rt-multi-thread", "macros", "sync"] }
```
### Reqwest
Often doesn't need all TLS backends:
```toml
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "json"] }
```
## Rollback
If changes cause issues:
```bash
git checkout quickwit/Cargo.toml
cargo check --workspace
```
## Tips
- Start with large crates that have many default features (tokio, reqwest, hyper)
- Use `cargo bloat --crates` to identify large dependencies
- Check `cargo tree -d` for duplicate dependencies that might indicate feature conflicts
- Some features are needed only for tests - consider using `[dev-dependencies]` features

View File

@@ -0,0 +1,60 @@
---
name: simple-pr
description: Create a simple PR from staged changes with an auto-generated commit message
disable-model-invocation: true
---
# Simple PR
Follow these steps to create a simple PR from staged changes:
## Step 1: Check workspace state
Run: `git status`
Verify that all changes have been staged (no unstaged changes). If there are unstaged changes, abort and ask the user to stage their changes first with `git add`.
Also verify that we are on the `main` branch. If not, abort and ask the user to switch to main first.
## Step 2: Ensure main is up to date
Run: `git pull origin main`
This ensures we're working from the latest code.
## Step 3: Review staged changes
Run: `git diff --cached`
Review the staged changes to understand what the PR will contain.
## Step 4: Generate commit message
Based on the staged changes, generate a concise commit message (1-2 sentences) that describes the "why" rather than the "what".
Display the proposed commit message to the user and ask for confirmation before proceeding.
## Step 5: Create a new branch
Get the git username: `git config user.name | tr ' ' '-' | tr '[:upper:]' '[:lower:]'`
Create a short, descriptive branch name based on the changes (e.g., `fix-typo-in-readme`, `add-retry-logic`, `update-deps`).
Create and checkout the branch: `git checkout -b {username}/{short-descriptive-name}`
## Step 6: Commit changes
Commit with the message from step 3:
```
git commit -m "{commit-message}"
```
## Step 7: Push and open a PR
Push the branch and open a PR:
```
git push -u origin {branch-name}
gh pr create --title "{commit-message-title}" --body "{longer-description-if-needed}"
```
Report the PR URL to the user when complete.

View File

@@ -15,7 +15,7 @@ rust-version = "1.85"
exclude = ["benches/*.json", "benches/*.txt"]
[dependencies]
oneshot = "0.1.7"
oneshot = "0.1.13"
base64 = "0.22.0"
byteorder = "1.4.3"
crc32fast = "1.3.2"
@@ -64,8 +64,8 @@ query-grammar = { version = "0.25.0", path = "./query-grammar", package = "tanti
tantivy-bitpacker = { version = "0.9", path = "./bitpacker" }
common = { version = "0.10", path = "./common/", package = "tantivy-common" }
tokenizer-api = { version = "0.6", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
sketches-ddsketch = { version = "0.3.0", features = ["use_serde"] }
hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
sketches-ddsketch = { path = "./sketches-ddsketch", features = ["use_serde"] }
datasketches = "0.2.0"
futures-util = { version = "0.3.28", optional = true }
futures-channel = { version = "0.3.28", optional = true }
fnv = "1.0.7"
@@ -144,6 +144,7 @@ members = [
"sstable",
"tokenizer-api",
"columnar",
"sketches-ddsketch",
]
# Following the "fail" crate best practises, we isolate
@@ -193,3 +194,12 @@ harness = false
[[bench]]
name = "str_search_and_get"
harness = false
[[bench]]
name = "merge_segments"
harness = false
[[bench]]
name = "regex_all_terms"
harness = false

View File

@@ -1,5 +1,6 @@
use binggan::plugins::PeakMemAllocPlugin;
use binggan::{black_box, InputGroup, PeakMemAlloc, INSTRUMENTED_SYSTEM};
use common::DateTime;
use rand::distr::weighted::WeightedIndex;
use rand::rngs::StdRng;
use rand::seq::IndexedRandom;
@@ -70,6 +71,12 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
register!(group, composite_term_many_page_1000);
register!(group, composite_term_many_page_1000_with_avg_sub_agg);
register!(group, composite_term_few);
register!(group, composite_histogram);
register!(group, composite_histogram_calendar);
register!(group, cardinality_agg);
register!(group, terms_status_with_cardinality_agg);
@@ -313,6 +320,75 @@ fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn composite_term_few(index: &Index) {
let agg_req = json!({
"my_ctf": {
"composite": {
"sources": [
{ "text_few_terms": { "terms": { "field": "text_few_terms" } } }
],
"size": 1000
}
},
});
execute_agg(index, agg_req);
}
fn composite_term_many_page_1000(index: &Index) {
let agg_req = json!({
"my_ctmp1000": {
"composite": {
"sources": [
{ "text_many_terms": { "terms": { "field": "text_many_terms" } } }
],
"size": 1000
}
},
});
execute_agg(index, agg_req);
}
fn composite_term_many_page_1000_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_ctmp1000wasa": {
"composite": {
"sources": [
{ "text_many_terms": { "terms": { "field": "text_many_terms" } } }
],
"size": 1000,
},
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn composite_histogram(index: &Index) {
let agg_req = json!({
"my_ch": {
"composite": {
"sources": [
{ "f64_histogram": { "histogram": { "field": "score_f64", "interval": 1 } } }
],
"size": 1000
}
},
});
execute_agg(index, agg_req);
}
fn composite_histogram_calendar(index: &Index) {
let agg_req = json!({
"my_chc": {
"composite": {
"sources": [
{ "time_histogram": { "date_histogram": { "field": "timestamp", "calendar_interval": "month" } } }
],
"size": 1000
}
},
});
execute_agg(index, agg_req);
}
fn execute_agg(index: &Index, agg_req: serde_json::Value) {
let agg_req: Aggregations = serde_json::from_value(agg_req).unwrap();
@@ -504,6 +580,7 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
let date_field = schema_builder.add_date_field("timestamp", FAST);
// use tmp dir
let index = if reuse_index {
Index::create_in_dir("agg_bench", schema_builder.build())?
@@ -593,6 +670,7 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
score_field => val as u64,
score_field_f64 => lg_norm.sample(&mut rng),
score_field_i64 => val as i64,
date_field => DateTime::from_timestamp_millis((val * 1_000_000.) as i64),
))?;
if cardinality == Cardinality::OptionalSparse {
for _ in 0..20 {

224
benches/merge_segments.rs Normal file
View File

@@ -0,0 +1,224 @@
// Benchmarks segment merging
//
// Notes:
// - Input segments are kept intact (no deletes / no IndexWriter merge).
// - Output is written to a `NullDirectory` that discards all files except
// fieldnorms (needed for merging).
use std::collections::HashMap;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use binggan::{black_box, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::directory::error::{DeleteError, OpenReadError, OpenWriteError};
use tantivy::directory::{
AntiCallToken, Directory, FileHandle, OwnedBytes, TerminatingWrite, WatchCallback, WatchHandle,
WritePtr,
};
use tantivy::indexer::{merge_filtered_segments, NoMergePolicy};
use tantivy::schema::{Schema, TEXT};
use tantivy::{doc, HasLen, Index, IndexSettings, Segment};
#[derive(Clone, Default, Debug)]
struct NullDirectory {
blobs: Arc<RwLock<HashMap<PathBuf, OwnedBytes>>>,
}
struct NullWriter;
impl Write for NullWriter {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl TerminatingWrite for NullWriter {
fn terminate_ref(&mut self, _token: AntiCallToken) -> io::Result<()> {
Ok(())
}
}
struct InMemoryWriter {
path: PathBuf,
buffer: Vec<u8>,
blobs: Arc<RwLock<HashMap<PathBuf, OwnedBytes>>>,
}
impl Write for InMemoryWriter {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.buffer.extend_from_slice(buf);
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl TerminatingWrite for InMemoryWriter {
fn terminate_ref(&mut self, _token: AntiCallToken) -> io::Result<()> {
let bytes = OwnedBytes::new(std::mem::take(&mut self.buffer));
self.blobs.write().unwrap().insert(self.path.clone(), bytes);
Ok(())
}
}
#[derive(Debug, Default)]
struct NullFileHandle;
impl HasLen for NullFileHandle {
fn len(&self) -> usize {
0
}
}
impl FileHandle for NullFileHandle {
fn read_bytes(&self, _range: std::ops::Range<usize>) -> io::Result<OwnedBytes> {
unimplemented!()
}
}
impl Directory for NullDirectory {
fn get_file_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>, OpenReadError> {
if let Some(bytes) = self.blobs.read().unwrap().get(path) {
return Ok(Arc::new(bytes.clone()));
}
Ok(Arc::new(NullFileHandle))
}
fn delete(&self, _path: &Path) -> Result<(), DeleteError> {
Ok(())
}
fn exists(&self, _path: &Path) -> Result<bool, OpenReadError> {
Ok(true)
}
fn open_write(&self, path: &Path) -> Result<WritePtr, OpenWriteError> {
let path_buf = path.to_path_buf();
if path.to_string_lossy().ends_with(".fieldnorm") {
let writer = InMemoryWriter {
path: path_buf,
buffer: Vec::new(),
blobs: Arc::clone(&self.blobs),
};
Ok(io::BufWriter::new(Box::new(writer)))
} else {
Ok(io::BufWriter::new(Box::new(NullWriter)))
}
}
fn atomic_read(&self, path: &Path) -> Result<Vec<u8>, OpenReadError> {
if let Some(bytes) = self.blobs.read().unwrap().get(path) {
return Ok(bytes.as_slice().to_vec());
}
Err(OpenReadError::FileDoesNotExist(path.to_path_buf()))
}
fn atomic_write(&self, _path: &Path, _data: &[u8]) -> io::Result<()> {
Ok(())
}
fn sync_directory(&self) -> io::Result<()> {
Ok(())
}
fn watch(&self, _watch_callback: WatchCallback) -> tantivy::Result<WatchHandle> {
Ok(WatchHandle::empty())
}
}
struct MergeScenario {
#[allow(dead_code)]
index: Index,
segments: Vec<Segment>,
settings: IndexSettings,
label: String,
}
fn build_index(
num_segments: usize,
docs_per_segment: usize,
tokens_per_doc: usize,
vocab_size: usize,
) -> MergeScenario {
let mut schema_builder = Schema::builder();
let body = schema_builder.add_text_field("body", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
assert!(vocab_size > 0);
let total_tokens = num_segments * docs_per_segment * tokens_per_doc;
let use_unique_terms = vocab_size >= total_tokens;
let mut rng = StdRng::from_seed([7u8; 32]);
let mut next_token_id: u64 = 0;
{
let mut writer = index.writer_with_num_threads(1, 256_000_000).unwrap();
writer.set_merge_policy(Box::new(NoMergePolicy));
for _ in 0..num_segments {
for _ in 0..docs_per_segment {
let mut tokens = Vec::with_capacity(tokens_per_doc);
for _ in 0..tokens_per_doc {
let token_id = if use_unique_terms {
let id = next_token_id;
next_token_id += 1;
id
} else {
rng.random_range(0..vocab_size as u64)
};
tokens.push(format!("term_{token_id}"));
}
writer.add_document(doc!(body => tokens.join(" "))).unwrap();
}
writer.commit().unwrap();
}
}
let segments = index.searchable_segments().unwrap();
let settings = index.settings().clone();
let label = format!(
"segments={}, docs/seg={}, tokens/doc={}, vocab={}",
num_segments, docs_per_segment, tokens_per_doc, vocab_size
);
MergeScenario {
index,
segments,
settings,
label,
}
}
fn main() {
let scenarios = vec![
build_index(8, 50_000, 12, 8),
build_index(16, 50_000, 12, 8),
build_index(16, 100_000, 12, 8),
build_index(8, 50_000, 8, 8 * 50_000 * 8),
];
let mut runner = BenchRunner::new();
for scenario in scenarios {
let mut group = runner.new_group();
group.set_name(format!("merge_segments inv_index — {}", scenario.label));
let segments = scenario.segments.clone();
let settings = scenario.settings.clone();
group.register("merge", move |_| {
let output_dir = NullDirectory::default();
let filter_doc_ids = vec![None; segments.len()];
let merged_index =
merge_filtered_segments(&segments, settings.clone(), filter_doc_ids, output_dir)
.unwrap();
black_box(merged_index);
});
group.run();
}
}

113
benches/regex_all_terms.rs Normal file
View File

@@ -0,0 +1,113 @@
// Benchmarks regex query that matches all terms in a synthetic index.
//
// Corpus model:
// - N unique terms: t000000, t000001, ...
// - M docs
// - K tokens per doc: doc i gets terms derived from (i, token_index)
//
// Query:
// - Regex "t.*" to match all terms
//
// Run with:
// - cargo bench --bench regex_all_terms
//
use std::fmt::Write;
use binggan::{black_box, BenchRunner};
use tantivy::collector::Count;
use tantivy::query::RegexQuery;
use tantivy::schema::{Schema, TEXT};
use tantivy::{doc, Index, ReloadPolicy};
const HEAP_SIZE_BYTES: usize = 200_000_000;
#[derive(Clone, Copy)]
struct BenchConfig {
num_terms: usize,
num_docs: usize,
tokens_per_doc: usize,
}
fn main() {
let configs = default_configs();
let mut runner = BenchRunner::new();
for config in configs {
let (index, text_field) = build_index(config, HEAP_SIZE_BYTES);
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.expect("reader");
let searcher = reader.searcher();
let query = RegexQuery::from_pattern("t.*", text_field).expect("regex query");
let mut group = runner.new_group();
group.set_name(format!(
"regex_all_terms_t{}_d{}_k{}",
config.num_terms, config.num_docs, config.tokens_per_doc
));
group.register("regex_count", move |_| {
let count = searcher.search(&query, &Count).expect("search");
black_box(count);
});
group.run();
}
}
fn default_configs() -> Vec<BenchConfig> {
vec![
BenchConfig {
num_terms: 10_000,
num_docs: 100_000,
tokens_per_doc: 1,
},
BenchConfig {
num_terms: 10_000,
num_docs: 100_000,
tokens_per_doc: 8,
},
BenchConfig {
num_terms: 100_000,
num_docs: 100_000,
tokens_per_doc: 1,
},
BenchConfig {
num_terms: 100_000,
num_docs: 100_000,
tokens_per_doc: 8,
},
]
}
fn build_index(config: BenchConfig, heap_size_bytes: usize) -> (Index, tantivy::schema::Field) {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let term_width = config.num_terms.to_string().len();
{
let mut writer = index
.writer_with_num_threads(1, heap_size_bytes)
.expect("writer");
let mut buffer = String::new();
for doc_id in 0..config.num_docs {
buffer.clear();
for token_idx in 0..config.tokens_per_doc {
if token_idx > 0 {
buffer.push(' ');
}
let term_id = (doc_id * config.tokens_per_doc + token_idx) % config.num_terms;
write!(&mut buffer, "t{term_id:0term_width$}").expect("write token");
}
writer
.add_document(doc!(text_field => buffer.as_str()))
.expect("add_document");
}
writer.commit().expect("commit");
}
(index, text_field)
}

View File

@@ -45,7 +45,7 @@ fn build_shared_indices(num_docs: usize, distribution: &str) -> BenchIndex {
match distribution {
"dense_random" => {
for _doc_id in 0..num_docs {
let suffix = rng.random_range(0u64..1000u64);
let suffix = rng.gen_range(0u64..1000u64);
let str_val = format!("str_{:03}", suffix);
writer
@@ -71,7 +71,7 @@ fn build_shared_indices(num_docs: usize, distribution: &str) -> BenchIndex {
}
"sparse_random" => {
for _doc_id in 0..num_docs {
let suffix = rng.random_range(0u64..1000000u64);
let suffix = rng.gen_range(0u64..1000000u64);
let str_val = format!("str_{:07}", suffix);
writer

View File

@@ -31,7 +31,7 @@ pub use u64_based::{
serialize_and_load_u64_based_column_values, serialize_u64_based_column_values,
};
pub use u128_based::{
CompactSpaceU64Accessor, open_u128_as_compact_u64, open_u128_mapped,
CompactHit, CompactSpaceU64Accessor, open_u128_as_compact_u64, open_u128_mapped,
serialize_column_values_u128,
};
pub use vec_column::VecColumn;

View File

@@ -292,6 +292,19 @@ impl BinarySerializable for IPCodecParams {
}
}
/// Represents the result of looking up a u128 value in the compact space.
///
/// If a value is outside the compact space, the next compact value is returned.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompactHit {
/// The value exists in the compact space
Exact(u32),
/// The value does not exist in the compact space, but the next higher value does
Next(u32),
/// The value is greater than the maximum compact value
AfterLast,
}
/// Exposes the compact space compressed values as u64.
///
/// This allows faster access to the values, as u64 is faster to work with than u128.
@@ -309,6 +322,11 @@ impl CompactSpaceU64Accessor {
pub fn compact_to_u128(&self, compact: u32) -> u128 {
self.0.compact_to_u128(compact)
}
/// Finds the next compact space value for a given u128 value.
pub fn u128_to_next_compact(&self, value: u128) -> CompactHit {
self.0.u128_to_next_compact(value)
}
}
impl ColumnValues<u64> for CompactSpaceU64Accessor {
@@ -430,6 +448,26 @@ impl CompactSpaceDecompressor {
Ok(decompressor)
}
/// Finds the next compact space value for a given u128 value
pub fn u128_to_next_compact(&self, value: u128) -> CompactHit {
// Try to convert to compact space
match self.u128_to_compact(value) {
// Value is in compact space, return its compact representation
Ok(compact) => CompactHit::Exact(compact),
// Value is not in compact space
Err(pos) => {
if pos >= self.params.compact_space.ranges_mapping.len() {
// Value is beyond all ranges, no next value exists
CompactHit::AfterLast
} else {
// Get the next range and return its start compact value
let next_range = &self.params.compact_space.ranges_mapping[pos];
CompactHit::Next(next_range.compact_start)
}
}
}
}
/// Converting to compact space for the decompressor is more complex, since we may get values
/// which are outside the compact space. e.g. if we map
/// 1000 => 5
@@ -823,6 +861,41 @@ mod tests {
let _data = test_aux_vals(vals);
}
#[test]
fn test_u128_to_next_compact() {
let vals = &[100u128, 200u128, 1_000_000_000u128, 1_000_000_100u128];
let mut data = test_aux_vals(vals);
let _header = U128Header::deserialize(&mut data);
let decomp = CompactSpaceDecompressor::open(data).unwrap();
// Test value that's already in a range
let compact_100 = decomp.u128_to_compact(100).unwrap();
assert_eq!(
decomp.u128_to_next_compact(100),
CompactHit::Exact(compact_100)
);
// Test value between two ranges
let compact_million = decomp.u128_to_compact(1_000_000_000).unwrap();
assert_eq!(
decomp.u128_to_next_compact(250),
CompactHit::Next(compact_million)
);
// Test value before the first range
assert_eq!(
decomp.u128_to_next_compact(50),
CompactHit::Next(compact_100)
);
// Test value after the last range
assert_eq!(
decomp.u128_to_next_compact(10_000_000_000),
CompactHit::AfterLast
);
}
use proptest::prelude::*;
fn num_strategy() -> impl Strategy<Value = u128> {

View File

@@ -7,7 +7,7 @@ mod compact_space;
use common::{BinarySerializable, OwnedBytes, VInt};
pub use compact_space::{
CompactSpaceCompressor, CompactSpaceDecompressor, CompactSpaceU64Accessor,
CompactHit, CompactSpaceCompressor, CompactSpaceDecompressor, CompactSpaceU64Accessor,
};
use crate::column_values::monotonic_map_column;

View File

@@ -59,7 +59,7 @@ pub struct RowAddr {
pub row_id: RowId,
}
pub use sstable::Dictionary;
pub use sstable::{Dictionary, TermOrdHit};
pub type Streamer<'a> = sstable::Streamer<'a, VoidSSTable>;
pub use common::DateTime;

View File

@@ -178,11 +178,13 @@ impl TinySet {
#[derive(Clone)]
pub struct BitSet {
tinysets: Box<[TinySet]>,
len: u64,
max_value: u32,
}
impl std::fmt::Debug for BitSet {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BitSet")
.field("len", &self.len)
.field("max_value", &self.max_value)
.finish()
}
@@ -210,6 +212,7 @@ impl BitSet {
let tinybitsets = vec![TinySet::empty(); num_buckets as usize].into_boxed_slice();
BitSet {
tinysets: tinybitsets,
len: 0,
max_value,
}
}
@@ -227,6 +230,7 @@ impl BitSet {
}
BitSet {
tinysets: tinybitsets,
len: max_value as u64,
max_value,
}
}
@@ -245,19 +249,17 @@ impl BitSet {
/// Intersect with tinysets
fn intersect_update_with_iter(&mut self, other: impl Iterator<Item = TinySet>) {
self.len = 0;
for (left, right) in self.tinysets.iter_mut().zip(other) {
*left = left.intersect(right);
self.len += left.len() as u64;
}
}
/// Returns the number of elements in the `BitSet`.
#[inline]
pub fn len(&self) -> usize {
self.tinysets
.iter()
.copied()
.map(|tinyset| tinyset.len())
.sum::<u32>() as usize
self.len as usize
}
/// Inserts an element in the `BitSet`
@@ -266,7 +268,7 @@ impl BitSet {
// we do not check saturated els.
let higher = el / 64u32;
let lower = el % 64u32;
self.tinysets[higher as usize].insert_mut(lower);
self.len += u64::from(self.tinysets[higher as usize].insert_mut(lower));
}
/// Inserts an element in the `BitSet`
@@ -275,7 +277,7 @@ impl BitSet {
// we do not check saturated els.
let higher = el / 64u32;
let lower = el % 64u32;
self.tinysets[higher as usize].remove_mut(lower);
self.len -= u64::from(self.tinysets[higher as usize].remove_mut(lower));
}
/// Returns true iff the elements is in the `BitSet`.
@@ -297,9 +299,6 @@ impl BitSet {
.map(|delta_bucket| bucket + delta_bucket as u32)
}
/// Returns the maximum number of elements in the bitset.
///
/// Warning: The largest element the bitset can contain is `max_value - 1`.
#[inline]
pub fn max_value(&self) -> u32 {
self.max_value

View File

@@ -60,7 +60,7 @@ At indexing, tantivy will try to interpret number and strings as different type
priority order.
Numbers will be interpreted as u64, i64 and f64 in that order.
Strings will be interpreted as rfc3999 dates or simple strings.
Strings will be interpreted as rfc3339 dates or simple strings.
The first working type is picked and is the only term that is emitted for indexing.
Note this interpretation happens on a per-document basis, and there is no effort to try to sniff
@@ -81,7 +81,7 @@ Will be interpreted as
(my_path.my_segment, String, 233) or (my_path.my_segment, u64, 233)
```
Likewise, we need to emit two tokens if the query contains an rfc3999 date.
Likewise, we need to emit two tokens if the query contains an rfc3339 date.
Indeed the date could have been actually a single token inside the text of a document at ingestion time. Generally speaking, we will always at least emit a string token in query parsing, and sometimes more.
If one more json field is defined, things get even more complicated.

View File

@@ -91,10 +91,46 @@ fn main() -> tantivy::Result<()> {
}
}
// Some other powerful operations (especially `.seek`) may be useful to consume these
// A `Term` is a text token associated with a field.
// Let's go through all docs containing the term `title:the` and access their position
let term_the = Term::from_field_text(title, "the");
// Some other powerful operations (especially `.skip_to`) may be useful to consume these
// posting lists rapidly.
// You can check for them in the [`DocSet`](https://docs.rs/tantivy/~0/tantivy/trait.DocSet.html) trait
// and the [`Postings`](https://docs.rs/tantivy/~0/tantivy/trait.Postings.html) trait
// Also, for some VERY specific high performance use case like an OLAP analysis of logs,
// you can get better performance by accessing directly the blocks of doc ids.
for segment_reader in searcher.segment_readers() {
// A segment contains different data structure.
// Inverted index stands for the combination of
// - the term dictionary
// - the inverted lists associated with each terms and their positions
let inverted_index = segment_reader.inverted_index(title)?;
// This segment posting object is like a cursor over the documents matching the term.
// The `IndexRecordOption` arguments tells tantivy we will be interested in both term
// frequencies and positions.
//
// If you don't need all this information, you may get better performance by decompressing
// less information.
if let Some(mut block_segment_postings) =
inverted_index.read_block_postings(&term_the, IndexRecordOption::Basic)?
{
loop {
let docs = block_segment_postings.docs();
if docs.is_empty() {
break;
}
// Once again these docs MAY contains deleted documents as well.
let docs = block_segment_postings.docs();
// Prints `Docs [0, 2].`
println!("Docs {docs:?}");
block_segment_postings.advance();
}
}
}
Ok(())
}

View File

@@ -560,7 +560,7 @@ fn range_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
(
(
value((), tag(">=")),
map(word_infallible("", false), |(bound, err)| {
map(word_infallible(")", false), |(bound, err)| {
(
(
bound
@@ -574,7 +574,7 @@ fn range_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
),
(
value((), tag("<=")),
map(word_infallible("", false), |(bound, err)| {
map(word_infallible(")", false), |(bound, err)| {
(
(
UserInputBound::Unbounded,
@@ -588,7 +588,7 @@ fn range_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
),
(
value((), tag(">")),
map(word_infallible("", false), |(bound, err)| {
map(word_infallible(")", false), |(bound, err)| {
(
(
bound
@@ -602,7 +602,7 @@ fn range_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
),
(
value((), tag("<")),
map(word_infallible("", false), |(bound, err)| {
map(word_infallible(")", false), |(bound, err)| {
(
(
UserInputBound::Unbounded,
@@ -704,7 +704,11 @@ fn regex(inp: &str) -> IResult<&str, UserInputLeaf> {
many1(alt((preceded(char('\\'), char('/')), none_of("/")))),
char('/'),
),
peek(alt((multispace1, eof))),
peek(alt((
value((), multispace1),
value((), char(')')),
value((), eof),
))),
),
|elements| UserInputLeaf::Regex {
field: None,
@@ -721,8 +725,12 @@ fn regex_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
opt_i_err(char('/'), "missing delimiter /"),
),
opt_i_err(
peek(alt((multispace1, eof))),
"expected whitespace or end of input",
peek(alt((
value((), multispace1),
value((), char(')')),
value((), eof),
))),
"expected whitespace, closing parenthesis, or end of input",
),
)(inp)
{
@@ -1323,6 +1331,14 @@ mod test {
test_parse_query_to_ast_helper("<a", "{\"*\" TO \"a\"}");
test_parse_query_to_ast_helper("<=a", "{\"*\" TO \"a\"]");
test_parse_query_to_ast_helper("<=bsd", "{\"*\" TO \"bsd\"]");
test_parse_query_to_ast_helper("(<=42)", "{\"*\" TO \"42\"]");
test_parse_query_to_ast_helper("(<=42 )", "{\"*\" TO \"42\"]");
test_parse_query_to_ast_helper("(age:>5)", "\"age\":{\"5\" TO \"*\"}");
test_parse_query_to_ast_helper(
"(title:bar AND age:>12)",
"(+\"title\":bar +\"age\":{\"12\" TO \"*\"})",
);
}
#[test]
@@ -1699,6 +1715,10 @@ mod test {
test_parse_query_to_ast_helper("foo:(A OR B)", "(?\"foo\":A ?\"foo\":B)");
test_parse_query_to_ast_helper("foo:(A* OR B*)", "(?\"foo\":A* ?\"foo\":B*)");
test_parse_query_to_ast_helper("foo:(*A OR *B)", "(?\"foo\":*A ?\"foo\":*B)");
// Regexes between parentheses
test_parse_query_to_ast_helper("foo:(/A.*/)", "\"foo\":/A.*/");
test_parse_query_to_ast_helper("foo:(/A.*/ OR /B.*/)", "(?\"foo\":/A.*/ ?\"foo\":/B.*/)");
}
#[test]

View File

@@ -66,6 +66,7 @@ impl UserInputLeaf {
}
UserInputLeaf::Range { field, .. } if field.is_none() => *field = Some(default_field),
UserInputLeaf::Set { field, .. } if field.is_none() => *field = Some(default_field),
UserInputLeaf::Regex { field, .. } if field.is_none() => *field = Some(default_field),
_ => (), // field was already set, do nothing
}
}

View File

@@ -0,0 +1,27 @@
[package]
name = "sketches-ddsketch"
version = "0.3.0"
authors = ["Mike Heffner <mikeh@fesnel.com>"]
edition = "2018"
license = "Apache-2.0"
readme = "README.md"
repository = "https://github.com/mheffner/rust-sketches-ddsketch"
homepage = "https://github.com/mheffner/rust-sketches-ddsketch"
description = """
A direct port of the Golang DDSketch implementation.
"""
exclude = [".gitignore"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
serde = { package = "serde", version = "1.0", optional = true, features = ["derive", "serde_derive"] }
[dev-dependencies]
approx = "0.5.1"
rand = "0.8.5"
rand_distr = "0.4.3"
[features]
use_serde = ["serde", "serde/derive"]

201
sketches-ddsketch/LICENSE Normal file
View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [2019] [Mike Heffner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -0,0 +1,11 @@
clean:
cargo clean
test:
cargo test
test_logs:
cargo test -- --nocapture
test_performance:
cargo test --release --jobs 1 test_performance -- --ignored --nocapture

View File

@@ -0,0 +1,37 @@
# sketches-ddsketch
This is a direct port of the [Golang](https://github.com/DataDog/sketches-go)
[DDSketch](https://arxiv.org/pdf/1908.10693.pdf) quantile sketch implementation
to Rust. DDSketch is a fully-mergeable quantile sketch with relative-error
guarantees and is extremely fast.
# DDSketch
* Sketch size automatically grows as needed, starting with 128 bins.
* Extremely fast sample insertion and sketch merges.
## Usage
```rust
use sketches_ddsketch::{Config, DDSketch};
let config = Config::defaults();
let mut sketch = DDSketch::new(c);
sketch.add(1.0);
sketch.add(1.0);
sketch.add(1.0);
// Get p=50%
let quantile = sketch.quantile(0.5).unwrap();
assert_eq!(quantile, Some(1.0));
```
## Performance
No performance tuning has been done with this implementation of the port, so we
would expect similar profiles to the original implementation.
Out of the box we see can achieve over 70M sample inserts/sec and 350K sketch
merges/sec. All tests run on a single core Intel i7 processor with 4.2Ghz max
clock.

View File

@@ -0,0 +1,98 @@
#[cfg(feature = "use_serde")]
use serde::{Deserialize, Serialize};
const DEFAULT_MAX_BINS: u32 = 2048;
const DEFAULT_ALPHA: f64 = 0.01;
const DEFAULT_MIN_VALUE: f64 = 1.0e-9;
/// The configuration struct for constructing a `DDSketch`
#[derive(Copy, Clone, Debug, PartialEq)]
#[cfg_attr(feature = "use_serde", derive(Serialize, Deserialize))]
pub struct Config {
pub max_num_bins: u32,
pub gamma: f64,
pub(crate) gamma_ln: f64,
pub(crate) min_value: f64,
pub offset: i32,
}
fn log_gamma(value: f64, gamma_ln: f64) -> f64 {
value.ln() / gamma_ln
}
impl Config {
/// Construct a new `Config` struct with specific parameters. If you are unsure of how to
/// configure this, the `defaults` method constructs a `Config` with built-in defaults.
///
/// `max_num_bins` is the max number of bins the DDSketch will grow to, in steps of 128 bins.
pub fn new(alpha: f64, max_num_bins: u32, min_value: f64) -> Self {
// Aligned with Java's LogarithmicMapping / LogLikeIndexMapping:
// gamma = (1 + alpha) / (1 - alpha) (correctingFactor=1 for LogarithmicMapping)
// gamma_ln = gamma.ln() (not ln_1p, to match Java's Math.log(gamma))
// See: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/mapping/LogLikeIndexMapping.java (gamma() static method)
// See: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/mapping/LogarithmicMapping.java (constructor, correctingFactor()=1)
let gamma = (1.0 + alpha) / (1.0 - alpha);
let gamma_ln = gamma.ln();
Config {
max_num_bins,
gamma,
gamma_ln,
min_value,
offset: 1 - (log_gamma(min_value, gamma_ln) as i32),
}
}
/// Return a `Config` using built-in default settings
pub fn defaults() -> Self {
Self::new(DEFAULT_ALPHA, DEFAULT_MAX_BINS, DEFAULT_MIN_VALUE)
}
pub fn key(&self, v: f64) -> i32 {
// Aligned with Java's LogLikeIndexMapping.index(): floor-based indexing.
// Java uses `(int) index` / `(int) index - 1` which is equivalent to floor().
// See: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/mapping/LogLikeIndexMapping.java (index() method)
self.log_gamma(v).floor() as i32
}
pub fn value(&self, key: i32) -> f64 {
// Aligned with Java's LogLikeIndexMapping.value():
// lowerBound(index) * (1 + relativeAccuracy)
// = logInverse((index - indexOffset) / multiplier) * (1 + relativeAccuracy)
// = gamma^key * 2*gamma/(gamma+1)
// See: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/mapping/LogLikeIndexMapping.java (value() and lowerBound() methods)
self.pow_gamma(key) * (2.0 * self.gamma / (1.0 + self.gamma))
}
pub fn log_gamma(&self, value: f64) -> f64 {
log_gamma(value, self.gamma_ln)
}
pub fn pow_gamma(&self, key: i32) -> f64 {
((key as f64) * self.gamma_ln).exp()
}
pub fn min_possible(&self) -> f64 {
self.min_value
}
/// Reconstruct a Config from a gamma value (as decoded from the binary format).
/// Uses default max_num_bins and min_value.
/// See Java: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/mapping/LogarithmicMapping.java (LogarithmicMapping(double gamma, double indexOffset) constructor)
pub(crate) fn from_gamma(gamma: f64) -> Self {
let gamma_ln = gamma.ln();
Config {
max_num_bins: DEFAULT_MAX_BINS,
gamma,
gamma_ln,
min_value: DEFAULT_MIN_VALUE,
offset: 1 - (log_gamma(DEFAULT_MIN_VALUE, gamma_ln) as i32),
}
}
}
impl Default for Config {
fn default() -> Self {
Self::new(DEFAULT_ALPHA, DEFAULT_MAX_BINS, DEFAULT_MIN_VALUE)
}
}

View File

@@ -0,0 +1,385 @@
use std::{error, fmt};
#[cfg(feature = "use_serde")]
use serde::{Deserialize, Serialize};
use crate::config::Config;
use crate::store::Store;
type Result<T> = std::result::Result<T, DDSketchError>;
/// General error type for DDSketch, represents either an invalid quantile or an
/// incompatible merge operation.
#[derive(Debug, Clone)]
pub enum DDSketchError {
Quantile,
Merge,
}
impl fmt::Display for DDSketchError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
DDSketchError::Quantile => {
write!(f, "Invalid quantile, must be between 0 and 1 (inclusive)")
}
DDSketchError::Merge => write!(f, "Can not merge sketches with different configs"),
}
}
}
impl error::Error for DDSketchError {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
// Generic
None
}
}
/// This struct represents a [DDSketch](https://arxiv.org/pdf/1908.10693.pdf)
#[derive(Clone)]
#[cfg_attr(feature = "use_serde", derive(Serialize, Deserialize))]
pub struct DDSketch {
pub(crate) config: Config,
pub(crate) store: Store,
pub(crate) negative_store: Store,
pub(crate) min: f64,
pub(crate) max: f64,
pub(crate) sum: f64,
pub(crate) zero_count: u64,
}
impl Default for DDSketch {
fn default() -> Self {
Self::new(Default::default())
}
}
// XXX: functions should return Option<> in the case of empty
impl DDSketch {
/// Construct a `DDSketch`. Requires a `Config` specifying the parameters of the sketch
pub fn new(config: Config) -> Self {
DDSketch {
config,
store: Store::new(config.max_num_bins as usize),
negative_store: Store::new(config.max_num_bins as usize),
min: f64::INFINITY,
max: f64::NEG_INFINITY,
sum: 0.0,
zero_count: 0,
}
}
/// Add the sample to the sketch
pub fn add(&mut self, v: f64) {
if v > self.config.min_possible() {
let key = self.config.key(v);
self.store.add(key);
} else if v < -self.config.min_possible() {
let key = self.config.key(-v);
self.negative_store.add(key);
} else {
self.zero_count += 1;
}
if v < self.min {
self.min = v;
}
if self.max < v {
self.max = v;
}
self.sum += v;
}
/// Return the quantile value for quantiles between 0.0 and 1.0. Result is an error, represented
/// as DDSketchError::Quantile if the requested quantile is outside of that range.
///
/// If the sketch is empty the result is None, else Some(v) for the quantile value.
pub fn quantile(&self, q: f64) -> Result<Option<f64>> {
if !(0.0..=1.0).contains(&q) {
return Err(DDSketchError::Quantile);
}
if self.empty() {
return Ok(None);
}
if q == 0.0 {
return Ok(Some(self.min));
} else if q == 1.0 {
return Ok(Some(self.max));
}
let rank = (q * (self.count() as f64 - 1.0)) as u64;
let quantile;
if rank < self.negative_store.count() {
let reversed_rank = self.negative_store.count() - rank - 1;
let key = self.negative_store.key_at_rank(reversed_rank);
quantile = -self.config.value(key);
} else if rank < self.zero_count + self.negative_store.count() {
quantile = 0.0;
} else {
let key = self
.store
.key_at_rank(rank - self.zero_count - self.negative_store.count());
quantile = self.config.value(key);
}
Ok(Some(quantile))
}
/// Returns the minimum value seen, or None if sketch is empty
pub fn min(&self) -> Option<f64> {
if self.empty() {
None
} else {
Some(self.min)
}
}
/// Returns the maximum value seen, or None if sketch is empty
pub fn max(&self) -> Option<f64> {
if self.empty() {
None
} else {
Some(self.max)
}
}
/// Returns the sum of values seen, or None if sketch is empty
pub fn sum(&self) -> Option<f64> {
if self.empty() {
None
} else {
Some(self.sum)
}
}
/// Returns the number of values added to the sketch
pub fn count(&self) -> usize {
(self.store.count() + self.zero_count + self.negative_store.count()) as usize
}
/// Returns the length of the underlying `Store`. This is mainly only useful for understanding
/// how much the sketch has grown given the inserted values.
pub fn length(&self) -> usize {
self.store.length() as usize + self.negative_store.length() as usize
}
/// Merge the contents of another sketch into this one. The sketch that is merged into this one
/// is unchanged after the merge.
pub fn merge(&mut self, o: &DDSketch) -> Result<()> {
if self.config != o.config {
return Err(DDSketchError::Merge);
}
let was_empty = self.store.count() == 0;
// Merge the stores
self.store.merge(&o.store);
self.negative_store.merge(&o.negative_store);
self.zero_count += o.zero_count;
// Need to ensure we don't override min/max with initializers
// if either store were empty
if was_empty {
self.min = o.min;
self.max = o.max;
} else if o.store.count() > 0 {
if o.min < self.min {
self.min = o.min
}
if o.max > self.max {
self.max = o.max;
}
}
self.sum += o.sum;
Ok(())
}
fn empty(&self) -> bool {
self.count() == 0
}
/// Encode this sketch into the Java-compatible binary format used by
/// `com.datadoghq.sketch.ddsketch.DDSketchWithExactSummaryStatistics`.
pub fn to_java_bytes(&self) -> Vec<u8> {
crate::encoding::encode_to_java_bytes(self)
}
/// Decode a sketch from the Java-compatible binary format.
/// Accepts bytes produced by Java's `DDSketchWithExactSummaryStatistics.encode()`
/// with or without the `0x02` version prefix.
pub fn from_java_bytes(
bytes: &[u8],
) -> std::result::Result<Self, crate::encoding::DecodeError> {
crate::encoding::decode_from_java_bytes(bytes)
}
}
#[cfg(test)]
mod tests {
use approx::assert_relative_eq;
use crate::{Config, DDSketch};
#[test]
fn test_add_zero() {
let alpha = 0.01;
let c = Config::new(alpha, 2048, 10e-9);
let mut dd = DDSketch::new(c);
dd.add(0.0);
}
#[test]
fn test_quartiles() {
let alpha = 0.01;
let c = Config::new(alpha, 2048, 10e-9);
let mut dd = DDSketch::new(c);
// Initialize sketch with {1.0, 2.0, 3.0, 4.0}
for i in 1..5 {
dd.add(i as f64);
}
// We expect the following mappings from quantile to value:
// [0,0.33]: 1.0, (0.34,0.66]: 2.0, (0.67,0.99]: 3.0, (0.99, 1.0]: 4.0
let test_cases = vec![
(0.0, 1.0),
(0.25, 1.0),
(0.33, 1.0),
(0.34, 2.0),
(0.5, 2.0),
(0.66, 2.0),
(0.67, 3.0),
(0.75, 3.0),
(0.99, 3.0),
(1.0, 4.0),
];
for (q, val) in test_cases {
assert_relative_eq!(dd.quantile(q).unwrap().unwrap(), val, max_relative = alpha);
}
}
#[test]
fn test_neg_quartiles() {
let alpha = 0.01;
let c = Config::new(alpha, 2048, 10e-9);
let mut dd = DDSketch::new(c);
// Initialize sketch with {1.0, 2.0, 3.0, 4.0}
for i in 1..5 {
dd.add(-i as f64);
}
let test_cases = vec![
(0.0, -4.0),
(0.25, -4.0),
(0.5, -3.0),
(0.75, -2.0),
(1.0, -1.0),
];
for (q, val) in test_cases {
assert_relative_eq!(dd.quantile(q).unwrap().unwrap(), val, max_relative = alpha);
}
}
#[test]
fn test_simple_quantile() {
let c = Config::defaults();
let mut dd = DDSketch::new(c);
for i in 1..101 {
dd.add(i as f64);
}
assert_eq!(dd.quantile(0.95).unwrap().unwrap().ceil(), 95.0);
assert!(dd.quantile(-1.01).is_err());
assert!(dd.quantile(1.01).is_err());
}
#[test]
fn test_empty_sketch() {
let c = Config::defaults();
let dd = DDSketch::new(c);
assert_eq!(dd.quantile(0.98).unwrap(), None);
assert_eq!(dd.max(), None);
assert_eq!(dd.min(), None);
assert_eq!(dd.sum(), None);
assert_eq!(dd.count(), 0);
assert!(dd.quantile(1.01).is_err());
}
#[test]
fn test_basic_histogram_data() {
let values = &[
0.754225035,
0.752900282,
0.752812246,
0.752602367,
0.754310155,
0.753525981,
0.752981082,
0.752715536,
0.751667941,
0.755079054,
0.753528150,
0.755188464,
0.752508723,
0.750064549,
0.753960428,
0.751139298,
0.752523560,
0.753253428,
0.753498342,
0.751858358,
0.752104636,
0.753841300,
0.754467374,
0.753814334,
0.750881719,
0.753182556,
0.752576884,
0.753945708,
0.753571911,
0.752314573,
0.752586651,
];
let c = Config::defaults();
let mut dd = DDSketch::new(c);
for value in values {
dd.add(*value);
}
assert_eq!(dd.max(), Some(0.755188464));
assert_eq!(dd.min(), Some(0.750064549));
assert_eq!(dd.count(), 31);
assert_eq!(dd.sum(), Some(23.343630625000003));
assert!(dd.quantile(0.25).unwrap().is_some());
assert!(dd.quantile(0.5).unwrap().is_some());
assert!(dd.quantile(0.75).unwrap().is_some());
}
#[test]
fn test_length() {
let mut dd = DDSketch::default();
assert_eq!(dd.length(), 0);
dd.add(1.0);
assert_eq!(dd.length(), 128);
dd.add(2.0);
dd.add(3.0);
assert_eq!(dd.length(), 128);
dd.add(-1.0);
assert_eq!(dd.length(), 256);
dd.add(-2.0);
dd.add(-3.0);
assert_eq!(dd.length(), 256);
}
}

View File

@@ -0,0 +1,813 @@
//! Java-compatible binary encoding/decoding for DDSketch.
//!
//! This module implements the binary format used by the Java
//! `com.datadoghq.sketch.ddsketch.DDSketchWithExactSummaryStatistics` class
//! from the DataDog/sketches-java library. It enables cross-language
//! serialization so that sketches produced in Rust can be deserialized
//! and merged by Java consumers.
use std::fmt;
use crate::config::Config;
use crate::ddsketch::DDSketch;
use crate::store::Store;
// ---------------------------------------------------------------------------
// Flag byte layout
//
// Each flag byte packs a 2-bit type ordinal in the low bits and a 6-bit
// subflag in the upper bits: (subflag << 2) | type_ordinal
// See: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/encoding/Flag.java
// ---------------------------------------------------------------------------
/// The 2-bit type field occupying the low bits of every flag byte.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FlagType {
SketchFeatures = 0,
PositiveStore = 1,
IndexMapping = 2,
NegativeStore = 3,
}
impl FlagType {
fn from_byte(b: u8) -> Option<Self> {
match b & 0x03 {
0 => Some(Self::SketchFeatures),
1 => Some(Self::PositiveStore),
2 => Some(Self::IndexMapping),
3 => Some(Self::NegativeStore),
_ => None,
}
}
}
/// Construct a flag byte from a subflag and a type.
const fn flag(subflag: u8, flag_type: FlagType) -> u8 {
(subflag << 2) | (flag_type as u8)
}
// Pre-computed flag bytes for the sketch features we encode/decode.
const FLAG_INDEX_MAPPING_LOG: u8 = flag(0, FlagType::IndexMapping); // 0x02
const FLAG_ZERO_COUNT: u8 = flag(1, FlagType::SketchFeatures); // 0x04
const FLAG_COUNT: u8 = flag(0x28, FlagType::SketchFeatures); // 0xA0
const FLAG_SUM: u8 = flag(0x21, FlagType::SketchFeatures); // 0x84
const FLAG_MIN: u8 = flag(0x22, FlagType::SketchFeatures); // 0x88
const FLAG_MAX: u8 = flag(0x23, FlagType::SketchFeatures); // 0x8C
/// BinEncodingMode subflags for store flag bytes.
/// See: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/encoding/BinEncodingMode.java
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BinEncodingMode {
IndexDeltasAndCounts = 1,
IndexDeltas = 2,
ContiguousCounts = 3,
}
impl BinEncodingMode {
fn from_subflag(subflag: u8) -> Option<Self> {
match subflag {
1 => Some(Self::IndexDeltasAndCounts),
2 => Some(Self::IndexDeltas),
3 => Some(Self::ContiguousCounts),
_ => None,
}
}
}
const VAR_DOUBLE_ROTATE_DISTANCE: u32 = 6;
const MAX_VAR_LEN_64: usize = 9;
const DEFAULT_MAX_BINS: u32 = 2048;
// ---------------------------------------------------------------------------
// Error type
// ---------------------------------------------------------------------------
#[derive(Debug, Clone)]
pub enum DecodeError {
UnexpectedEof,
InvalidFlag(u8),
InvalidData(String),
}
impl fmt::Display for DecodeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::UnexpectedEof => write!(f, "unexpected end of input"),
Self::InvalidFlag(b) => write!(f, "invalid flag byte: 0x{b:02X}"),
Self::InvalidData(msg) => write!(f, "invalid data: {msg}"),
}
}
}
impl std::error::Error for DecodeError {}
// ---------------------------------------------------------------------------
// VarEncoding — bit-exact port of Java VarEncodingHelper
// See: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/encoding/VarEncodingHelper.java
// ---------------------------------------------------------------------------
fn encode_unsigned_var_long(out: &mut Vec<u8>, mut value: u64) {
let length = ((63 - value.leading_zeros() as i32) / 7).clamp(0, 8);
for _ in 0..length {
out.push((value as u8) | 0x80);
value >>= 7;
}
out.push(value as u8);
}
fn decode_unsigned_var_long(input: &mut &[u8]) -> Result<u64, DecodeError> {
let mut value: u64 = 0;
let mut shift: u32 = 0;
loop {
let next = read_byte(input)?;
if next < 0x80 || shift == 56 {
return Ok(value | (u64::from(next) << shift));
}
value |= (u64::from(next) & 0x7F) << shift;
shift += 7;
}
}
/// ZigZag encode then var-long encode.
fn encode_signed_var_long(out: &mut Vec<u8>, value: i64) {
let encoded = ((value >> 63) ^ (value << 1)) as u64;
encode_unsigned_var_long(out, encoded);
}
fn decode_signed_var_long(input: &mut &[u8]) -> Result<i64, DecodeError> {
let encoded = decode_unsigned_var_long(input)?;
Ok(((encoded >> 1) as i64) ^ -((encoded & 1) as i64))
}
fn double_to_var_bits(value: f64) -> u64 {
let bits = f64::to_bits(value + 1.0).wrapping_sub(f64::to_bits(1.0));
bits.rotate_left(VAR_DOUBLE_ROTATE_DISTANCE)
}
fn var_bits_to_double(bits: u64) -> f64 {
f64::from_bits(
bits.rotate_right(VAR_DOUBLE_ROTATE_DISTANCE)
.wrapping_add(f64::to_bits(1.0)),
) - 1.0
}
fn encode_var_double(out: &mut Vec<u8>, value: f64) {
let mut bits = double_to_var_bits(value);
for _ in 0..MAX_VAR_LEN_64 - 1 {
let next = (bits >> 57) as u8;
bits <<= 7;
if bits == 0 {
out.push(next);
return;
}
out.push(next | 0x80);
}
out.push((bits >> 56) as u8);
}
fn decode_var_double(input: &mut &[u8]) -> Result<f64, DecodeError> {
let mut bits: u64 = 0;
let mut shift: i32 = 57; // 8*8 - 7
loop {
let next = read_byte(input)?;
if shift == 1 {
bits |= u64::from(next);
break;
}
if next < 0x80 {
bits |= u64::from(next) << shift;
break;
}
bits |= (u64::from(next) & 0x7F) << shift;
shift -= 7;
}
Ok(var_bits_to_double(bits))
}
// ---------------------------------------------------------------------------
// Byte-level helpers
// ---------------------------------------------------------------------------
fn read_byte(input: &mut &[u8]) -> Result<u8, DecodeError> {
match input.split_first() {
Some((&byte, rest)) => {
*input = rest;
Ok(byte)
}
None => Err(DecodeError::UnexpectedEof),
}
}
fn write_f64_le(out: &mut Vec<u8>, value: f64) {
out.extend_from_slice(&value.to_le_bytes());
}
fn read_f64_le(input: &mut &[u8]) -> Result<f64, DecodeError> {
if input.len() < 8 {
return Err(DecodeError::UnexpectedEof);
}
let (bytes, rest) = input.split_at(8);
*input = rest;
// bytes is guaranteed to be length 8 by the split_at above.
let arr = [
bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
];
Ok(f64::from_le_bytes(arr))
}
// ---------------------------------------------------------------------------
// Store encoding/decoding
// See: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/store/DenseStore.java (encode/decode methods)
// ---------------------------------------------------------------------------
/// Collect non-zero bins in the store as (absolute_index, count) pairs.
///
/// Allocation is acceptable here: this runs once per encode and the Vec
/// has at most `max_num_bins` entries.
fn collect_non_zero_bins(store: &Store) -> Vec<(i32, u64)> {
if store.count == 0 {
return Vec::new();
}
let start = (store.min_key - store.offset) as usize;
let end = ((store.max_key - store.offset + 1) as usize).min(store.bins.len());
store.bins[start..end]
.iter()
.enumerate()
.filter(|&(_, &count)| count > 0)
.map(|(i, &count)| (start as i32 + i as i32 + store.offset, count))
.collect()
}
fn encode_store(out: &mut Vec<u8>, store: &Store, flag_type: FlagType) {
let bins = collect_non_zero_bins(store);
if bins.is_empty() {
return;
}
out.push(flag(BinEncodingMode::IndexDeltasAndCounts as u8, flag_type));
encode_unsigned_var_long(out, bins.len() as u64);
let mut prev_index: i64 = 0;
for &(index, count) in &bins {
encode_signed_var_long(out, i64::from(index) - prev_index);
encode_var_double(out, count as f64);
prev_index = i64::from(index);
}
}
fn decode_store(input: &mut &[u8], subflag: u8, bin_limit: usize) -> Result<Store, DecodeError> {
let mode = BinEncodingMode::from_subflag(subflag).ok_or_else(|| {
DecodeError::InvalidData(format!("unknown bin encoding mode subflag: {subflag}"))
})?;
let num_bins = decode_unsigned_var_long(input)? as usize;
let mut store = Store::new(bin_limit);
match mode {
BinEncodingMode::IndexDeltasAndCounts => {
let mut index: i64 = 0;
for _ in 0..num_bins {
index += decode_signed_var_long(input)?;
let count = decode_var_double(input)?;
store.add_count(index as i32, count as u64);
}
}
BinEncodingMode::IndexDeltas => {
let mut index: i64 = 0;
for _ in 0..num_bins {
index += decode_signed_var_long(input)?;
store.add_count(index as i32, 1);
}
}
BinEncodingMode::ContiguousCounts => {
let start_index = decode_signed_var_long(input)?;
let index_delta = decode_signed_var_long(input)?;
let mut index = start_index;
for _ in 0..num_bins {
let count = decode_var_double(input)?;
store.add_count(index as i32, count as u64);
index += index_delta;
}
}
}
Ok(store)
}
// ---------------------------------------------------------------------------
// Top-level encode / decode
// ---------------------------------------------------------------------------
/// Encode a DDSketch into the Java-compatible binary format.
///
/// The output follows the encoding order of
/// `DDSketchWithExactSummaryStatistics.encode()` then `DDSketch.encode()`:
///
/// 1. Summary statistics: COUNT, MIN, MAX (if count > 0)
/// 2. SUM (if sum != 0)
/// 3. Index mapping (LOG layout): gamma, indexOffset
/// 4. Zero count (if > 0)
/// 5. Positive store bins
/// 6. Negative store bins
pub fn encode_to_java_bytes(sketch: &DDSketch) -> Vec<u8> {
let mut out = Vec::new();
let count = sketch.count() as f64;
// Summary statistics (DDSketchWithExactSummaryStatistics.encode)
if count != 0.0 {
out.push(FLAG_COUNT);
encode_var_double(&mut out, count);
out.push(FLAG_MIN);
write_f64_le(&mut out, sketch.min);
out.push(FLAG_MAX);
write_f64_le(&mut out, sketch.max);
}
if sketch.sum != 0.0 {
out.push(FLAG_SUM);
write_f64_le(&mut out, sketch.sum);
}
// DDSketch.encode: index mapping + zero count + stores
out.push(FLAG_INDEX_MAPPING_LOG);
write_f64_le(&mut out, sketch.config.gamma);
write_f64_le(&mut out, 0.0_f64);
if sketch.zero_count != 0 {
out.push(FLAG_ZERO_COUNT);
encode_var_double(&mut out, sketch.zero_count as f64);
}
encode_store(&mut out, &sketch.store, FlagType::PositiveStore);
encode_store(&mut out, &sketch.negative_store, FlagType::NegativeStore);
out
}
/// Decode a DDSketch from the Java-compatible binary format.
///
/// Accepts bytes with or without a `0x02` version prefix.
pub fn decode_from_java_bytes(bytes: &[u8]) -> Result<DDSketch, DecodeError> {
if bytes.is_empty() {
return Err(DecodeError::UnexpectedEof);
}
let mut input = bytes;
// Skip optional version prefix (0x02 followed by a valid flag byte).
if input.len() >= 2 && input[0] == 0x02 && is_valid_flag_byte(input[1]) {
input = &input[1..];
}
let mut gamma: Option<f64> = None;
let mut zero_count: f64 = 0.0;
let mut sum: f64 = 0.0;
let mut min: f64 = f64::INFINITY;
let mut max: f64 = f64::NEG_INFINITY;
let mut positive_store: Option<Store> = None;
let mut negative_store: Option<Store> = None;
while !input.is_empty() {
let flag_byte = read_byte(&mut input)?;
let flag_type =
FlagType::from_byte(flag_byte).ok_or(DecodeError::InvalidFlag(flag_byte))?;
let subflag = flag_byte >> 2;
match flag_type {
FlagType::IndexMapping => {
gamma = Some(read_f64_le(&mut input)?);
let _index_offset = read_f64_le(&mut input)?;
}
FlagType::SketchFeatures => match flag_byte {
FLAG_ZERO_COUNT => zero_count += decode_var_double(&mut input)?,
FLAG_COUNT => {
let _count = decode_var_double(&mut input)?;
}
FLAG_SUM => sum = read_f64_le(&mut input)?,
FLAG_MIN => min = read_f64_le(&mut input)?,
FLAG_MAX => max = read_f64_le(&mut input)?,
_ => return Err(DecodeError::InvalidFlag(flag_byte)),
},
FlagType::PositiveStore => {
positive_store = Some(decode_store(
&mut input,
subflag,
DEFAULT_MAX_BINS as usize,
)?);
}
FlagType::NegativeStore => {
negative_store = Some(decode_store(
&mut input,
subflag,
DEFAULT_MAX_BINS as usize,
)?);
}
}
}
let g = gamma.unwrap_or_else(|| Config::defaults().gamma);
let config = Config::from_gamma(g);
let store = positive_store.unwrap_or_else(|| Store::new(config.max_num_bins as usize));
let neg = negative_store.unwrap_or_else(|| Store::new(config.max_num_bins as usize));
Ok(DDSketch {
config,
store,
negative_store: neg,
min,
max,
sum,
zero_count: zero_count as u64,
})
}
/// Check whether a byte is a valid flag byte for the DDSketch binary format.
fn is_valid_flag_byte(b: u8) -> bool {
// Known sketch-feature flags
if matches!(
b,
FLAG_ZERO_COUNT | FLAG_COUNT | FLAG_SUM | FLAG_MIN | FLAG_MAX | FLAG_INDEX_MAPPING_LOG
) {
return true;
}
let Some(flag_type) = FlagType::from_byte(b) else {
return false;
};
let subflag = b >> 2;
match flag_type {
FlagType::PositiveStore | FlagType::NegativeStore => (1..=3).contains(&subflag),
FlagType::IndexMapping => subflag <= 4, // LOG=0, LOG_LINEAR=1 .. LOG_QUARTIC=4
_ => false,
}
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
use crate::{Config, DDSketch};
// --- VarEncoding unit tests ---
#[test]
fn test_unsigned_var_long_zero() {
let mut buf = Vec::new();
encode_unsigned_var_long(&mut buf, 0);
assert_eq!(buf, [0x00]);
let mut input = buf.as_slice();
assert_eq!(decode_unsigned_var_long(&mut input).unwrap(), 0);
assert!(input.is_empty());
}
#[test]
fn test_unsigned_var_long_small() {
let mut buf = Vec::new();
encode_unsigned_var_long(&mut buf, 1);
assert_eq!(buf, [0x01]);
let mut input = buf.as_slice();
assert_eq!(decode_unsigned_var_long(&mut input).unwrap(), 1);
}
#[test]
fn test_unsigned_var_long_128() {
let mut buf = Vec::new();
encode_unsigned_var_long(&mut buf, 128);
assert_eq!(buf, [0x80, 0x01]);
let mut input = buf.as_slice();
assert_eq!(decode_unsigned_var_long(&mut input).unwrap(), 128);
}
#[test]
fn test_unsigned_var_long_roundtrip() {
for v in [0u64, 1, 127, 128, 255, 256, 16383, 16384, u64::MAX] {
let mut buf = Vec::new();
encode_unsigned_var_long(&mut buf, v);
let mut input = buf.as_slice();
let decoded = decode_unsigned_var_long(&mut input).unwrap();
assert_eq!(decoded, v, "roundtrip failed for {}", v);
assert!(input.is_empty());
}
}
#[test]
fn test_signed_var_long_roundtrip() {
for v in [0i64, 1, -1, 63, -64, 64, -65, i64::MAX, i64::MIN] {
let mut buf = Vec::new();
encode_signed_var_long(&mut buf, v);
let mut input = buf.as_slice();
let decoded = decode_signed_var_long(&mut input).unwrap();
assert_eq!(decoded, v, "roundtrip failed for {}", v);
assert!(input.is_empty());
}
}
#[test]
fn test_var_double_roundtrip() {
for v in [0.0, 1.0, 2.0, 5.0, 15.0, 42.0, 100.0, 1e-9, 1e15, 0.5, 7.77] {
let mut buf = Vec::new();
encode_var_double(&mut buf, v);
let mut input = buf.as_slice();
let decoded = decode_var_double(&mut input).unwrap();
assert!(
(decoded - v).abs() < 1e-15 || decoded == v,
"roundtrip failed for {}: got {}",
v,
decoded,
);
assert!(input.is_empty());
}
}
#[test]
fn test_var_double_small_integers() {
let mut buf = Vec::new();
encode_var_double(&mut buf, 1.0);
assert_eq!(buf.len(), 1, "VarDouble(1.0) should be 1 byte");
buf.clear();
encode_var_double(&mut buf, 5.0);
assert_eq!(buf.len(), 1, "VarDouble(5.0) should be 1 byte");
}
// --- DDSketch encode/decode roundtrip tests ---
#[test]
fn test_encode_empty_sketch() {
let sketch = DDSketch::new(Config::defaults());
let bytes = sketch.to_java_bytes();
assert!(!bytes.is_empty());
let decoded = DDSketch::from_java_bytes(&bytes).unwrap();
assert_eq!(decoded.count(), 0);
assert_eq!(decoded.min(), None);
assert_eq!(decoded.max(), None);
assert_eq!(decoded.sum(), None);
}
#[test]
fn test_encode_simple_sketch() {
let mut sketch = DDSketch::new(Config::defaults());
for v in [1.0, 2.0, 3.0, 4.0, 5.0] {
sketch.add(v);
}
let bytes = sketch.to_java_bytes();
let decoded = DDSketch::from_java_bytes(&bytes).unwrap();
assert_eq!(decoded.count(), 5);
assert_eq!(decoded.min(), Some(1.0));
assert_eq!(decoded.max(), Some(5.0));
assert_eq!(decoded.sum(), Some(15.0));
assert_quantiles_match(&sketch, &decoded, &[0.5, 0.9, 0.95, 0.99]);
}
#[test]
fn test_encode_single_value() {
let mut sketch = DDSketch::new(Config::defaults());
sketch.add(42.0);
let bytes = sketch.to_java_bytes();
let decoded = DDSketch::from_java_bytes(&bytes).unwrap();
assert_eq!(decoded.count(), 1);
assert_eq!(decoded.min(), Some(42.0));
assert_eq!(decoded.max(), Some(42.0));
assert_eq!(decoded.sum(), Some(42.0));
}
#[test]
fn test_encode_negative_values() {
let mut sketch = DDSketch::new(Config::defaults());
for v in [-3.0, -1.0, 2.0, 5.0] {
sketch.add(v);
}
let bytes = sketch.to_java_bytes();
let decoded = DDSketch::from_java_bytes(&bytes).unwrap();
assert_eq!(decoded.count(), 4);
assert_eq!(decoded.min(), Some(-3.0));
assert_eq!(decoded.max(), Some(5.0));
assert_eq!(decoded.sum(), Some(3.0));
assert_quantiles_match(&sketch, &decoded, &[0.0, 0.25, 0.5, 0.75, 1.0]);
}
#[test]
fn test_encode_with_zero_value() {
let mut sketch = DDSketch::new(Config::defaults());
for v in [0.0, 1.0, 2.0] {
sketch.add(v);
}
let bytes = sketch.to_java_bytes();
let decoded = DDSketch::from_java_bytes(&bytes).unwrap();
assert_eq!(decoded.count(), 3);
assert_eq!(decoded.min(), Some(0.0));
assert_eq!(decoded.max(), Some(2.0));
assert_eq!(decoded.sum(), Some(3.0));
assert_eq!(decoded.zero_count, 1);
}
#[test]
fn test_encode_large_range() {
let mut sketch = DDSketch::new(Config::defaults());
sketch.add(0.001);
sketch.add(1_000_000.0);
let bytes = sketch.to_java_bytes();
let decoded = DDSketch::from_java_bytes(&bytes).unwrap();
assert_eq!(decoded.count(), 2);
assert_eq!(decoded.min(), Some(0.001));
assert_eq!(decoded.max(), Some(1_000_000.0));
}
#[test]
fn test_encode_with_version_prefix() {
let mut sketch = DDSketch::new(Config::defaults());
for v in [1.0, 2.0, 3.0] {
sketch.add(v);
}
let bytes = sketch.to_java_bytes();
// Simulate Java's toByteArrayV2: prepend 0x02
let mut v2_bytes = vec![0x02];
v2_bytes.extend_from_slice(&bytes);
let decoded = DDSketch::from_java_bytes(&v2_bytes).unwrap();
assert_eq!(decoded.count(), 3);
assert_eq!(decoded.min(), Some(1.0));
assert_eq!(decoded.max(), Some(3.0));
}
#[test]
fn test_byte_level_encoding() {
let mut sketch = DDSketch::new(Config::defaults());
sketch.add(1.0);
let bytes = sketch.to_java_bytes();
assert_eq!(bytes[0], FLAG_COUNT, "first byte should be COUNT flag");
assert!(
bytes.contains(&FLAG_INDEX_MAPPING_LOG),
"should contain index mapping flag"
);
}
// --- Cross-language golden byte tests ---
//
// Golden bytes generated by Java's DDSketchWithExactSummaryStatistics.encode()
// using LogarithmicMapping(0.01) + CollapsingLowestDenseStore(2048).
const GOLDEN_SIMPLE: &str = "a00588000000000000f03f8c0000000000001440840000000000002e4002fd4a815abf52f03f000000000000000005050002440228021e021602";
const GOLDEN_SINGLE: &str = "a0028800000000000045408c000000000000454084000000000000454002fd4a815abf52f03f00000000000000000501f40202";
const GOLDEN_NEGATIVE: &str = "a084408800000000000008c08c000000000000144084000000000000084002fd4a815abf52f03f0000000000000000050244025c02070200026c02";
const GOLDEN_ZERO: &str = "a0048800000000000000008c000000000000004084000000000000084002fd4a815abf52f03f00000000000000000402050200024402";
const GOLDEN_EMPTY: &str = "02fd4a815abf52f03f0000000000000000";
const GOLDEN_MANY: &str = "a08d1488000000000000f03f8c0000000000005940840000000000bab34002fd4a815abf52f03f000000000000000005550002440228021e021602120210020c020c020c0208020a020802060208020602060206020602040206020402040204020402040204020402040204020202040202020402020204020202020204020202020202020402020202020202020202020202020202020202020202020202020202020203020202020202020302020202020302020202020302020203020202030202020302030202020302030203020202030203020302030202";
fn hex_to_bytes(hex: &str) -> Vec<u8> {
(0..hex.len())
.step_by(2)
.map(|i| u8::from_str_radix(&hex[i..i + 2], 16).unwrap())
.collect()
}
fn bytes_to_hex(bytes: &[u8]) -> String {
bytes.iter().map(|b| format!("{b:02x}")).collect()
}
fn assert_golden(label: &str, sketch: &DDSketch, golden_hex: &str) {
let bytes = sketch.to_java_bytes();
let expected = hex_to_bytes(golden_hex);
assert_eq!(
bytes,
expected,
"Rust encoding doesn't match Java golden bytes for {}.\nRust: {}\nJava: {}",
label,
bytes_to_hex(&bytes),
golden_hex,
);
}
fn assert_quantiles_match(a: &DDSketch, b: &DDSketch, quantiles: &[f64]) {
for &q in quantiles {
let va = a.quantile(q).unwrap().unwrap();
let vb = b.quantile(q).unwrap().unwrap();
assert!(
(va - vb).abs() / va.abs().max(1e-15) < 1e-12,
"quantile({}) mismatch: {} vs {}",
q,
va,
vb,
);
}
}
#[test]
fn test_cross_language_simple() {
let mut sketch = DDSketch::new(Config::defaults());
for v in [1.0, 2.0, 3.0, 4.0, 5.0] {
sketch.add(v);
}
assert_golden("SIMPLE", &sketch, GOLDEN_SIMPLE);
}
#[test]
fn test_cross_language_single() {
let mut sketch = DDSketch::new(Config::defaults());
sketch.add(42.0);
assert_golden("SINGLE", &sketch, GOLDEN_SINGLE);
}
#[test]
fn test_cross_language_negative() {
let mut sketch = DDSketch::new(Config::defaults());
for v in [-3.0, -1.0, 2.0, 5.0] {
sketch.add(v);
}
assert_golden("NEGATIVE", &sketch, GOLDEN_NEGATIVE);
}
#[test]
fn test_cross_language_zero() {
let mut sketch = DDSketch::new(Config::defaults());
for v in [0.0, 1.0, 2.0] {
sketch.add(v);
}
assert_golden("ZERO", &sketch, GOLDEN_ZERO);
}
#[test]
fn test_cross_language_empty() {
let sketch = DDSketch::new(Config::defaults());
assert_golden("EMPTY", &sketch, GOLDEN_EMPTY);
}
#[test]
fn test_cross_language_many() {
let mut sketch = DDSketch::new(Config::defaults());
for i in 1..=100 {
sketch.add(i as f64);
}
assert_golden("MANY", &sketch, GOLDEN_MANY);
}
#[test]
fn test_decode_java_golden_bytes() {
for (name, hex) in [
("SIMPLE", GOLDEN_SIMPLE),
("SINGLE", GOLDEN_SINGLE),
("NEGATIVE", GOLDEN_NEGATIVE),
("ZERO", GOLDEN_ZERO),
("EMPTY", GOLDEN_EMPTY),
("MANY", GOLDEN_MANY),
] {
let bytes = hex_to_bytes(hex);
let result = DDSketch::from_java_bytes(&bytes);
assert!(
result.is_ok(),
"failed to decode {}: {:?}",
name,
result.err()
);
}
}
#[test]
fn test_encode_decode_many_values() {
let mut sketch = DDSketch::new(Config::defaults());
for i in 1..=100 {
sketch.add(i as f64);
}
let bytes = sketch.to_java_bytes();
let decoded = DDSketch::from_java_bytes(&bytes).unwrap();
assert_eq!(decoded.count(), 100);
assert_eq!(decoded.min(), Some(1.0));
assert_eq!(decoded.max(), Some(100.0));
assert_eq!(decoded.sum(), Some(5050.0));
let alpha = 0.01;
let orig_p95 = sketch.quantile(0.95).unwrap().unwrap();
let dec_p95 = decoded.quantile(0.95).unwrap().unwrap();
assert!(
(orig_p95 - dec_p95).abs() / orig_p95 < alpha,
"p95 mismatch: {} vs {}",
orig_p95,
dec_p95,
);
}
}

View File

@@ -0,0 +1,52 @@
//! This crate provides a direct port of the [Golang](https://github.com/DataDog/sketches-go)
//! [DDSketch](https://arxiv.org/pdf/1908.10693.pdf) implementation to Rust. All efforts
//! have been made to keep this as close to the original implementation as possible, with a few
//! tweaks to get closer to idiomatic Rust.
//!
//! # Usage
//!
//! Add multiple samples to a DDSketch and invoke the `quantile` method to pull any quantile from
//! 0.0* to *1.0*.
//!
//! ```rust
//! use sketches_ddsketch::{Config, DDSketch};
//!
//! let c = Config::defaults();
//! let mut d = DDSketch::new(c);
//!
//! d.add(1.0);
//! d.add(1.0);
//! d.add(1.0);
//!
//! let q = d.quantile(0.50).unwrap();
//!
//! assert!(q < Some(1.02));
//! assert!(q > Some(0.98));
//! ```
//!
//! Sketches can also be merged.
//!
//! ```rust
//! use sketches_ddsketch::{Config, DDSketch};
//!
//! let c = Config::defaults();
//! let mut d1 = DDSketch::new(c);
//! let mut d2 = DDSketch::new(c);
//!
//! d1.add(1.0);
//! d2.add(2.0);
//! d2.add(2.0);
//!
//! d1.merge(&d2);
//!
//! assert_eq!(d1.count(), 3);
//! ```
pub use self::config::Config;
pub use self::ddsketch::{DDSketch, DDSketchError};
pub use self::encoding::DecodeError;
mod config;
mod ddsketch;
pub mod encoding;
mod store;

View File

@@ -0,0 +1,252 @@
#[cfg(feature = "use_serde")]
use serde::{Deserialize, Serialize};
const CHUNK_SIZE: i32 = 128;
// Divide the `dividend` by the `divisor`, rounding towards positive infinity.
//
// Similar to the nightly only `std::i32::div_ceil`.
fn div_ceil(dividend: i32, divisor: i32) -> i32 {
(dividend + divisor - 1) / divisor
}
/// CollapsingLowestDenseStore
#[derive(Clone, Debug)]
#[cfg_attr(feature = "use_serde", derive(Serialize, Deserialize))]
pub struct Store {
pub(crate) bins: Vec<u64>,
pub(crate) count: u64,
pub(crate) min_key: i32,
pub(crate) max_key: i32,
pub(crate) offset: i32,
pub(crate) bin_limit: usize,
is_collapsed: bool,
}
impl Store {
pub fn new(bin_limit: usize) -> Self {
Store {
bins: Vec::new(),
count: 0,
min_key: i32::MAX,
max_key: i32::MIN,
offset: 0,
bin_limit,
is_collapsed: false,
}
}
/// Return the number of bins.
pub fn length(&self) -> i32 {
self.bins.len() as i32
}
pub fn is_empty(&self) -> bool {
self.bins.is_empty()
}
pub fn add(&mut self, key: i32) {
let idx = self.get_index(key);
self.bins[idx] += 1;
self.count += 1;
}
/// See Java: https://github.com/DataDog/sketches-java/blob/master/src/main/java/com/datadoghq/sketch/ddsketch/store/DenseStore.java (add(int index, double count) method)
pub(crate) fn add_count(&mut self, key: i32, count: u64) {
let idx = self.get_index(key);
self.bins[idx] += count;
self.count += count;
}
fn get_index(&mut self, key: i32) -> usize {
if key < self.min_key {
if self.is_collapsed {
return 0;
}
self.extend_range(key, None);
if self.is_collapsed {
return 0;
}
} else if key > self.max_key {
self.extend_range(key, None);
}
(key - self.offset) as usize
}
fn extend_range(&mut self, key: i32, second_key: Option<i32>) {
let second_key = second_key.unwrap_or(key);
let new_min_key = i32::min(key, i32::min(second_key, self.min_key));
let new_max_key = i32::max(key, i32::max(second_key, self.max_key));
if self.is_empty() {
let new_len = self.get_new_length(new_min_key, new_max_key);
self.bins.resize(new_len, 0);
self.offset = new_min_key;
self.adjust(new_min_key, new_max_key);
} else if new_min_key >= self.min_key && new_max_key < self.offset + self.length() {
self.min_key = new_min_key;
self.max_key = new_max_key;
} else {
// Grow bins
let new_length = self.get_new_length(new_min_key, new_max_key);
if new_length > self.length() as usize {
self.bins.resize(new_length, 0);
}
self.adjust(new_min_key, new_max_key);
}
}
fn get_new_length(&self, new_min_key: i32, new_max_key: i32) -> usize {
let desired_length = new_max_key - new_min_key + 1;
usize::min(
(CHUNK_SIZE * div_ceil(desired_length, CHUNK_SIZE)) as usize,
self.bin_limit,
)
}
fn adjust(&mut self, new_min_key: i32, new_max_key: i32) {
if new_max_key - new_min_key + 1 > self.length() {
let new_min_key = new_max_key - self.length() + 1;
if new_min_key >= self.max_key {
// Put everything in the first bin.
self.offset = new_min_key;
self.min_key = new_min_key;
self.bins.fill(0);
self.bins[0] = self.count;
} else {
let shift = self.offset - new_min_key;
if shift < 0 {
let collapse_start_index = (self.min_key - self.offset) as usize;
let collapse_end_index = (new_min_key - self.offset) as usize;
let collapsed_count: u64 = self.bins[collapse_start_index..collapse_end_index]
.iter()
.sum();
let zero_len = (new_min_key - self.min_key) as usize;
self.bins.splice(
collapse_start_index..collapse_end_index,
std::iter::repeat_n(0, zero_len),
);
self.bins[collapse_end_index] += collapsed_count;
}
self.min_key = new_min_key;
self.shift_bins(shift);
}
self.max_key = new_max_key;
self.is_collapsed = true;
} else {
self.center_bins(new_min_key, new_max_key);
self.min_key = new_min_key;
self.max_key = new_max_key;
}
}
fn shift_bins(&mut self, shift: i32) {
if shift > 0 {
let shift = shift as usize;
self.bins.rotate_right(shift);
for idx in 0..shift {
self.bins[idx] = 0;
}
} else {
let shift = shift.unsigned_abs() as usize;
for idx in 0..shift {
self.bins[idx] = 0;
}
self.bins.rotate_left(shift);
}
self.offset -= shift;
}
fn center_bins(&mut self, new_min_key: i32, new_max_key: i32) {
let middle_key = new_min_key + (new_max_key - new_min_key + 1) / 2;
let shift = self.offset + self.length() / 2 - middle_key;
self.shift_bins(shift)
}
pub fn key_at_rank(&self, rank: u64) -> i32 {
let mut n = 0;
for (i, bin) in self.bins.iter().enumerate() {
n += *bin;
if n > rank {
return i as i32 + self.offset;
}
}
self.max_key
}
pub fn count(&self) -> u64 {
self.count
}
pub fn merge(&mut self, other: &Store) {
if other.count == 0 {
return;
}
if self.count == 0 {
self.copy(other);
return;
}
if other.min_key < self.min_key || other.max_key > self.max_key {
self.extend_range(other.min_key, Some(other.max_key));
}
let collapse_start_index = other.min_key - other.offset;
let mut collapse_end_index = i32::min(self.min_key, other.max_key + 1) - other.offset;
if collapse_end_index > collapse_start_index {
let collapsed_count: u64 = self.bins
[collapse_start_index as usize..collapse_end_index as usize]
.iter()
.sum();
self.bins[0] += collapsed_count;
} else {
collapse_end_index = collapse_start_index;
}
for key in (collapse_end_index + other.offset)..(other.max_key + 1) {
self.bins[(key - self.offset) as usize] += other.bins[(key - other.offset) as usize]
}
self.count += other.count;
}
fn copy(&mut self, o: &Store) {
self.bins = o.bins.clone();
self.count = o.count;
self.min_key = o.min_key;
self.max_key = o.max_key;
self.offset = o.offset;
self.bin_limit = o.bin_limit;
self.is_collapsed = o.is_collapsed;
}
}
#[cfg(test)]
mod tests {
use crate::store::Store;
#[test]
fn test_simple_store() {
let mut s = Store::new(2048);
for i in 0..2048 {
s.add(i);
}
}
#[test]
fn test_simple_store_rev() {
let mut s = Store::new(2048);
for i in (0..2048).rev() {
s.add(i);
}
}
}

View File

@@ -0,0 +1,88 @@
use std::cmp::Ordering;
use std::f64::NAN;
pub struct Dataset {
values: Vec<f64>,
sum: f64,
sorted: bool,
}
fn cmp_f64(a: &f64, b: &f64) -> Ordering {
assert!(!a.is_nan() && !b.is_nan());
if a < b {
return Ordering::Less;
} else if a > b {
return Ordering::Greater;
} else {
return Ordering::Equal;
}
}
impl Dataset {
pub fn new() -> Self {
Dataset {
values: Vec::new(),
sum: 0.0,
sorted: false,
}
}
pub fn add(&mut self, value: f64) {
self.values.push(value);
self.sum += value;
self.sorted = false;
}
// pub fn quantile(&mut self, q: f64) -> f64 {
// self.lower_quantile(q)
// }
pub fn lower_quantile(&mut self, q: f64) -> f64 {
if q < 0.0 || q > 1.0 || self.values.len() == 0 {
return NAN;
}
self.sort();
let rank = q * (self.values.len() - 1) as f64;
self.values[rank.floor() as usize]
}
pub fn upper_quantile(&mut self, q: f64) -> f64 {
if q < 0.0 || q > 1.0 || self.values.len() == 0 {
return NAN;
}
self.sort();
let rank = q * (self.values.len() - 1) as f64;
self.values[rank.ceil() as usize]
}
pub fn min(&mut self) -> f64 {
self.sort();
self.values[0]
}
pub fn max(&mut self) -> f64 {
self.sort();
self.values[self.values.len() - 1]
}
pub fn sum(&self) -> f64 {
self.sum
}
pub fn count(&self) -> usize {
self.values.len()
}
fn sort(&mut self) {
if self.sorted {
return;
}
self.values.sort_by(cmp_f64);
self.sorted = true;
}
}

View File

@@ -0,0 +1,100 @@
extern crate rand;
extern crate rand_distr;
use rand::prelude::*;
pub trait Generator {
fn generate(&mut self) -> f64;
}
// Constant generator
//
pub struct Constant {
value: f64,
}
impl Constant {
pub fn new(value: f64) -> Self {
Constant { value }
}
}
impl Generator for Constant {
fn generate(&mut self) -> f64 {
self.value
}
}
// Linear generator
//
pub struct Linear {
current_value: f64,
step: f64,
}
impl Linear {
pub fn new(start_value: f64, step: f64) -> Self {
Linear {
current_value: start_value,
step,
}
}
}
impl Generator for Linear {
fn generate(&mut self) -> f64 {
let value = self.current_value;
self.current_value += self.step;
value
}
}
// Normal distribution generator
//
pub struct Normal {
distr: rand_distr::Normal<f64>,
}
impl Normal {
pub fn new(mean: f64, stddev: f64) -> Self {
Normal {
distr: rand_distr::Normal::new(mean, stddev).unwrap(),
}
}
}
impl Generator for Normal {
fn generate(&mut self) -> f64 {
self.distr.sample(&mut rand::thread_rng())
}
}
// Lognormal distribution generator
//
pub struct Lognormal {
distr: rand_distr::LogNormal<f64>,
}
impl Lognormal {
pub fn new(mean: f64, stddev: f64) -> Self {
Lognormal {
distr: rand_distr::LogNormal::new(mean, stddev).unwrap(),
}
}
}
impl Generator for Lognormal {
fn generate(&mut self) -> f64 {
self.distr.sample(&mut rand::thread_rng())
}
}
// Exponential distribution generator
//
pub struct Exponential {
distr: rand_distr::Exp<f64>,
}
impl Exponential {
pub fn new(lambda: f64) -> Self {
Exponential {
distr: rand_distr::Exp::new(lambda).unwrap(),
}
}
}
impl Generator for Exponential {
fn generate(&mut self) -> f64 {
self.distr.sample(&mut rand::thread_rng())
}
}

View File

@@ -0,0 +1,2 @@
pub mod dataset;
pub mod generator;

View File

@@ -0,0 +1,316 @@
mod common;
use std::time::Instant;
use common::dataset::Dataset;
use common::generator;
use common::generator::Generator;
use sketches_ddsketch::{Config, DDSketch};
const TEST_ALPHA: f64 = 0.01;
const TEST_MAX_BINS: u32 = 1024;
const TEST_MIN_VALUE: f64 = 1.0e-9;
// Used for float equality
const TEST_ERROR_THRESH: f64 = 1.0e-9;
const TEST_SIZES: [usize; 5] = [3, 5, 10, 100, 1000];
const TEST_QUANTILES: [f64; 10] = [0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99, 0.999, 1.0];
#[test]
fn test_constant() {
evaluate_sketches(|| Box::new(generator::Constant::new(42.0)));
}
#[test]
fn test_linear() {
evaluate_sketches(|| Box::new(generator::Linear::new(0.0, 1.0)));
}
#[test]
fn test_normal() {
evaluate_sketches(|| Box::new(generator::Normal::new(35.0, 1.0)));
}
#[test]
fn test_lognormal() {
evaluate_sketches(|| Box::new(generator::Lognormal::new(0.0, 2.0)));
}
#[test]
fn test_exponential() {
evaluate_sketches(|| Box::new(generator::Exponential::new(2.0)));
}
fn evaluate_test_sizes(f: impl Fn(usize)) {
for sz in &TEST_SIZES {
f(*sz);
}
}
fn evaluate_sketches(gen_factory: impl Fn() -> Box<dyn generator::Generator>) {
evaluate_test_sizes(|sz: usize| {
let mut generator = gen_factory();
evaluate_sketch(sz, &mut generator);
});
}
fn new_config() -> Config {
Config::new(TEST_ALPHA, TEST_MAX_BINS, TEST_MIN_VALUE)
}
fn assert_float_eq(a: f64, b: f64) {
assert!((a - b).abs() < TEST_ERROR_THRESH, "{} != {}", a, b);
}
fn evaluate_sketch(count: usize, generator: &mut Box<dyn generator::Generator>) {
let c = new_config();
let mut g = DDSketch::new(c);
let mut d = Dataset::new();
for _i in 0..count {
let value = generator.generate();
g.add(value);
d.add(value);
}
compare_sketches(&mut d, &g);
}
fn compare_sketches(d: &mut Dataset, g: &DDSketch) {
for q in &TEST_QUANTILES {
let lower = d.lower_quantile(*q);
let upper = d.upper_quantile(*q);
let min_expected;
if lower < 0.0 {
min_expected = lower * (1.0 + TEST_ALPHA);
} else {
min_expected = lower * (1.0 - TEST_ALPHA);
}
let max_expected;
if upper > 0.0 {
max_expected = upper * (1.0 + TEST_ALPHA);
} else {
max_expected = upper * (1.0 - TEST_ALPHA);
}
let quantile = g.quantile(*q).unwrap().unwrap();
assert!(
min_expected <= quantile,
"Lower than min, quantile: {}, wanted {} <= {}",
*q,
min_expected,
quantile
);
assert!(
quantile <= max_expected,
"Higher than max, quantile: {}, wanted {} <= {}",
*q,
quantile,
max_expected
);
// verify that calls do not modify result (not mut so not possible?)
let quantile2 = g.quantile(*q).unwrap().unwrap();
assert_eq!(quantile, quantile2);
}
assert_eq!(g.min().unwrap(), d.min());
assert_eq!(g.max().unwrap(), d.max());
assert_float_eq(g.sum().unwrap(), d.sum());
assert_eq!(g.count(), d.count());
}
#[test]
fn test_merge_normal() {
evaluate_test_sizes(|sz: usize| {
let c = new_config();
let mut d = Dataset::new();
let mut g1 = DDSketch::new(c);
let mut generator1 = generator::Normal::new(35.0, 1.0);
for _ in (0..sz).step_by(3) {
let value = generator1.generate();
g1.add(value);
d.add(value);
}
let mut g2 = DDSketch::new(c);
let mut generator2 = generator::Normal::new(50.0, 2.0);
for _ in (1..sz).step_by(3) {
let value = generator2.generate();
g2.add(value);
d.add(value);
}
g1.merge(&g2).unwrap();
let mut g3 = DDSketch::new(c);
let mut generator3 = generator::Normal::new(40.0, 0.5);
for _ in (2..sz).step_by(3) {
let value = generator3.generate();
g3.add(value);
d.add(value);
}
g1.merge(&g3).unwrap();
compare_sketches(&mut d, &g1);
});
}
#[test]
fn test_merge_empty() {
evaluate_test_sizes(|sz: usize| {
let c = new_config();
let mut d = Dataset::new();
let mut g1 = DDSketch::new(c);
let mut g2 = DDSketch::new(c);
let mut generator = generator::Exponential::new(5.0);
for _ in 0..sz {
let value = generator.generate();
g2.add(value);
d.add(value);
}
g1.merge(&g2).unwrap();
compare_sketches(&mut d, &g1);
let g3 = DDSketch::new(c);
g2.merge(&g3).unwrap();
compare_sketches(&mut d, &g2);
});
}
#[test]
fn test_merge_mixed() {
evaluate_test_sizes(|sz: usize| {
let c = new_config();
let mut d = Dataset::new();
let mut g1 = DDSketch::new(c);
let mut generator1 = generator::Normal::new(100.0, 1.0);
for _ in (0..sz).step_by(3) {
let value = generator1.generate();
g1.add(value);
d.add(value);
}
let mut g2 = DDSketch::new(c);
let mut generator2 = generator::Exponential::new(5.0);
for _ in (1..sz).step_by(3) {
let value = generator2.generate();
g2.add(value);
d.add(value);
}
g1.merge(&g2).unwrap();
let mut g3 = DDSketch::new(c);
let mut generator3 = generator::Exponential::new(0.1);
for _ in (2..sz).step_by(3) {
let value = generator3.generate();
g3.add(value);
d.add(value);
}
g1.merge(&g3).unwrap();
compare_sketches(&mut d, &g1);
})
}
#[test]
fn test_merge_incompatible() {
let c1 = Config::new(TEST_ALPHA, TEST_MAX_BINS, TEST_MIN_VALUE);
let c2 = Config::new(TEST_ALPHA * 2.0, TEST_MAX_BINS, TEST_MIN_VALUE);
let mut d1 = DDSketch::new(c1);
let d2 = DDSketch::new(c2);
assert!(d1.merge(&d2).is_err());
let c3 = Config::new(TEST_ALPHA, TEST_MAX_BINS, TEST_MIN_VALUE * 10.0);
let d3 = DDSketch::new(c3);
assert!(d1.merge(&d3).is_err());
let c4 = Config::new(TEST_ALPHA, TEST_MAX_BINS * 2, TEST_MIN_VALUE);
let d4 = DDSketch::new(c4);
assert!(d1.merge(&d4).is_err());
// the same should work
let c5 = Config::new(TEST_ALPHA, TEST_MAX_BINS, TEST_MIN_VALUE);
let dsame = DDSketch::new(c5);
assert!(d1.merge(&dsame).is_ok());
}
#[test]
#[ignore]
fn test_performance_insert() {
let c = Config::defaults();
let mut g = DDSketch::new(c);
let mut gen = generator::Normal::new(1000.0, 500.0);
let count = 300_000_000;
let mut values = Vec::new();
for _ in 0..count {
values.push(gen.generate());
}
let start_time = Instant::now();
for value in values {
g.add(value);
}
// This simply ensures the operations don't get optimzed out as ignored
let quantile = g.quantile(0.50).unwrap().unwrap();
let elapsed = start_time.elapsed().as_micros() as f64;
let elapsed = elapsed / 1_000_000.0;
println!(
"RESULT: p50={:.2} => Added {}M samples in {:2} secs ({:.2}M samples/sec)",
quantile,
count / 1_000_000,
elapsed,
(count as f64) / 1_000_000.0 / elapsed
);
}
#[test]
#[ignore]
fn test_performance_merge() {
let c = Config::defaults();
let mut gen = generator::Normal::new(1000.0, 500.0);
let merge_count = 500_000;
let sample_count = 1_000;
let mut sketches = Vec::new();
for _ in 0..merge_count {
let mut d = DDSketch::new(c);
for _ in 0..sample_count {
d.add(gen.generate());
}
sketches.push(d);
}
let mut base = DDSketch::new(c);
let start_time = Instant::now();
for sketch in &sketches {
base.merge(sketch).unwrap();
}
let elapsed = start_time.elapsed().as_micros() as f64;
let elapsed = elapsed / 1_000_000.0;
println!(
"RESULT: Merged {} sketches in {:2} secs ({:.2} merges/sec)",
merge_count,
elapsed,
(merge_count as f64) / elapsed
);
}

View File

@@ -95,11 +95,21 @@ pub(crate) fn get_all_ff_reader_or_empty(
allowed_column_types: Option<&[ColumnType]>,
fallback_type: ColumnType,
) -> crate::Result<Vec<(columnar::Column<u64>, ColumnType)>> {
let ff_fields = reader.fast_fields();
let mut ff_field_with_type =
ff_fields.u64_lenient_for_type_all(allowed_column_types, field_name)?;
let mut ff_field_with_type = get_all_ff_readers(reader, field_name, allowed_column_types)?;
if ff_field_with_type.is_empty() {
ff_field_with_type.push((Column::build_empty_column(reader.num_docs()), fallback_type));
}
Ok(ff_field_with_type)
}
/// Get all fast field reader.
pub(crate) fn get_all_ff_readers(
reader: &SegmentReader,
field_name: &str,
allowed_column_types: Option<&[ColumnType]>,
) -> crate::Result<Vec<(columnar::Column<u64>, ColumnType)>> {
let ff_fields = reader.fast_fields();
let ff_field_with_type =
ff_fields.u64_lenient_for_type_all(allowed_column_types, field_name)?;
Ok(ff_field_with_type)
}

View File

@@ -9,11 +9,12 @@ use crate::aggregation::accessor_helpers::{
get_numeric_or_date_column_types,
};
use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
pub use crate::aggregation::bucket::{CompositeAggReqData, CompositeSourceAccessors};
use crate::aggregation::bucket::{
build_segment_filter_collector, build_segment_range_collector, FilterAggReqData,
HistogramAggReqData, HistogramBounds, IncludeExcludeParam, MissingTermAggReqData,
RangeAggReqData, SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
TermsAggregationInternal,
build_segment_filter_collector, build_segment_range_collector, CompositeAggregation,
FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam,
MissingTermAggReqData, RangeAggReqData, SegmentCompositeCollector, SegmentHistogramCollector,
TermMissingAgg, TermsAggReqData, TermsAggregation, TermsAggregationInternal,
};
use crate::aggregation::metric::{
build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
@@ -73,6 +74,12 @@ impl AggregationsSegmentCtx {
self.per_request.filter_req_data.push(Some(Box::new(data)));
self.per_request.filter_req_data.len() - 1
}
pub(crate) fn push_composite_req_data(&mut self, data: CompositeAggReqData) -> usize {
self.per_request
.composite_req_data
.push(Some(Box::new(data)));
self.per_request.composite_req_data.len() - 1
}
#[inline]
pub(crate) fn get_term_req_data(&self, idx: usize) -> &TermsAggReqData {
@@ -108,6 +115,12 @@ impl AggregationsSegmentCtx {
.as_deref()
.expect("range_req_data slot is empty (taken)")
}
#[inline]
pub(crate) fn get_composite_req_data(&self, idx: usize) -> &CompositeAggReqData {
self.per_request.composite_req_data[idx]
.as_deref()
.expect("composite_req_data slot is empty (taken)")
}
// ---------- mutable getters ----------
@@ -130,8 +143,14 @@ impl AggregationsSegmentCtx {
.as_deref_mut()
.expect("histogram_req_data slot is empty (taken)")
}
#[inline]
pub(crate) fn get_composite_req_data_mut(&mut self, idx: usize) -> &mut CompositeAggReqData {
self.per_request.composite_req_data[idx]
.as_deref_mut()
.expect("composite_req_data slot is empty (taken)")
}
// ---------- take / put (terms, histogram, range) ----------
// ---------- take / put (terms, histogram, range, composite) ----------
/// Move out the boxed Histogram request at `idx`, leaving `None`.
#[inline]
@@ -181,6 +200,25 @@ impl AggregationsSegmentCtx {
debug_assert!(self.per_request.filter_req_data[idx].is_none());
self.per_request.filter_req_data[idx] = Some(value);
}
/// Move out the Composite request at `idx`.
#[inline]
pub(crate) fn take_composite_req_data(&mut self, idx: usize) -> Box<CompositeAggReqData> {
self.per_request.composite_req_data[idx]
.take()
.expect("composite_req_data slot is empty (taken)")
}
/// Put back a Composite request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_composite_req_data(
&mut self,
idx: usize,
value: Box<CompositeAggReqData>,
) {
debug_assert!(self.per_request.composite_req_data[idx].is_none());
self.per_request.composite_req_data[idx] = Some(value);
}
}
/// Each type of aggregation has its own request data struct. This struct holds
@@ -200,6 +238,8 @@ pub struct PerRequestAggSegCtx {
pub range_req_data: Vec<Option<Box<RangeAggReqData>>>,
/// FilterAggReqData contains the request data for a filter aggregation.
pub filter_req_data: Vec<Option<Box<FilterAggReqData>>>,
/// CompositeAggReqData contains the request data for a composite aggregation.
pub composite_req_data: Vec<Option<Box<CompositeAggReqData>>>,
/// Shared by avg, min, max, sum, stats, extended_stats, count
pub stats_metric_req_data: Vec<MetricAggReqData>,
/// CardinalityAggReqData contains the request data for a cardinality aggregation.
@@ -255,6 +295,11 @@ impl PerRequestAggSegCtx {
.iter()
.map(|t| t.get_memory_consumption())
.sum::<usize>()
+ self
.composite_req_data
.iter()
.map(|t| t.as_ref().unwrap().get_memory_consumption())
.sum::<usize>()
+ self.agg_tree.len() * std::mem::size_of::<AggRefNode>()
}
@@ -291,6 +336,11 @@ impl PerRequestAggSegCtx {
.expect("filter_req_data slot is empty (taken)")
.name
.as_str(),
AggKind::Composite => &self.composite_req_data[idx]
.as_deref()
.expect("composite_req_data slot is empty (taken)")
.name
.as_str(),
}
}
@@ -417,6 +467,9 @@ pub(crate) fn build_segment_agg_collector(
)?)),
AggKind::Range => Ok(build_segment_range_collector(req, node)?),
AggKind::Filter => build_segment_filter_collector(req, node),
AggKind::Composite => Ok(Box::new(SegmentCompositeCollector::from_req_and_validate(
req, node,
)?)),
}
}
@@ -447,6 +500,7 @@ pub enum AggKind {
DateHistogram,
Range,
Filter,
Composite,
}
impl AggKind {
@@ -462,6 +516,7 @@ impl AggKind {
AggKind::DateHistogram => "DateHistogram",
AggKind::Range => "Range",
AggKind::Filter => "Filter",
AggKind::Composite => "Composite",
}
}
}
@@ -740,6 +795,14 @@ fn build_nodes(
children,
}])
}
AggregationVariants::Composite(composite_req) => Ok(vec![build_composite_node(
agg_name,
reader,
segment_ordinal,
data,
&req.sub_aggregation,
composite_req,
)?]),
}
}
@@ -935,6 +998,35 @@ fn build_terms_or_cardinality_nodes(
Ok(nodes)
}
fn build_composite_node(
agg_name: &str,
reader: &SegmentReader,
segment_ordinal: SegmentOrdinal,
data: &mut AggregationsSegmentCtx,
sub_aggs: &Aggregations,
req: &CompositeAggregation,
) -> crate::Result<AggRefNode> {
let mut composite_accessors = Vec::with_capacity(req.sources.len());
for source in &req.sources {
let source_after_key_opt = req.after.get(source.name()).map(|k| &k.0);
let source_accessor =
CompositeSourceAccessors::build_for_source(reader, source, source_after_key_opt)?;
composite_accessors.push(source_accessor);
}
let agg = CompositeAggReqData {
name: agg_name.to_string(),
req: req.clone(),
composite_accessors,
};
let idx = data.push_composite_req_data(agg);
let children = build_children(sub_aggs, reader, segment_ordinal, data)?;
Ok(AggRefNode {
kind: AggKind::Composite,
idx_in_req_data: idx,
children,
})
}
/// Builds a single BitSet of allowed term ordinals for a string dictionary column according to
/// include/exclude parameters.
fn build_allowed_term_ids_for_str(

View File

@@ -40,6 +40,7 @@ use super::metric::{
MaxAggregation, MinAggregation, PercentilesAggregationReq, StatsAggregation, SumAggregation,
TopHitsAggregationReq,
};
use crate::aggregation::bucket::CompositeAggregation;
/// The top-level aggregation request structure, which contains [`Aggregation`] and their user
/// defined names. It is also used in buckets aggregations to define sub-aggregations.
@@ -134,6 +135,9 @@ pub enum AggregationVariants {
/// Filter documents into a single bucket.
#[serde(rename = "filter")]
Filter(FilterAggregation),
/// Put data into multi level paginated buckets.
#[serde(rename = "composite")]
Composite(CompositeAggregation),
// Metric aggregation types
/// Computes the average of the extracted values.
@@ -180,6 +184,11 @@ impl AggregationVariants {
AggregationVariants::Histogram(histogram) => vec![histogram.field.as_str()],
AggregationVariants::DateHistogram(histogram) => vec![histogram.field.as_str()],
AggregationVariants::Filter(filter) => filter.get_fast_field_names(),
AggregationVariants::Composite(composite) => composite
.sources
.iter()
.map(|source_map| source_map.field())
.collect(),
AggregationVariants::Average(avg) => vec![avg.field_name()],
AggregationVariants::Count(count) => vec![count.field_name()],
AggregationVariants::Max(max) => vec![max.field_name()],
@@ -214,6 +223,12 @@ impl AggregationVariants {
_ => None,
}
}
pub(crate) fn as_composite(&self) -> Option<&CompositeAggregation> {
match &self {
AggregationVariants::Composite(composite) => Some(composite),
_ => None,
}
}
pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> {
match &self {
AggregationVariants::Percentiles(percentile_req) => Some(percentile_req),

View File

@@ -13,6 +13,8 @@ use super::metric::{
ExtendedStats, PercentilesMetricResult, SingleMetricResult, Stats, TopHitsMetricResult,
};
use super::{AggregationError, Key};
use crate::aggregation::bucket::AfterKey;
use crate::aggregation::intermediate_agg_result::CompositeIntermediateKey;
use crate::TantivyError;
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
@@ -158,6 +160,16 @@ pub enum BucketResult {
},
/// This is the filter result - a single bucket with sub-aggregations
Filter(FilterBucketResult),
/// This is the composite aggregation result
Composite {
/// The buckets
///
/// See [`CompositeAggregation`](super::bucket::CompositeAggregation)
buckets: Vec<CompositeBucketEntry>,
/// The key to start after when paginating
#[serde(skip_serializing_if = "FxHashMap::is_empty")]
after_key: FxHashMap<String, AfterKey>,
},
}
impl BucketResult {
@@ -179,6 +191,9 @@ impl BucketResult {
// Only count sub-aggregation buckets
filter_result.sub_aggregations.get_bucket_count()
}
BucketResult::Composite { buckets, .. } => {
buckets.iter().map(|bucket| bucket.get_bucket_count()).sum()
}
}
}
}
@@ -337,3 +352,130 @@ pub struct FilterBucketResult {
#[serde(flatten)]
pub sub_aggregations: AggregationResults,
}
/// The JSON mappable key to identify a composite bucket.
///
/// This is similar to `Key`, but composite keys can also be boolean and null.
///
/// Note the type information loss compared to `CompositeIntermediateKey`.
/// Pagination is performed using `AfterKey`, which encodes type information.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CompositeKey {
/// Boolean key
Bool(bool),
/// String key
Str(String),
/// `i64` key
I64(i64),
/// `u64` key
U64(u64),
/// `f64` key
F64(f64),
/// Null key
Null,
}
impl Eq for CompositeKey {}
impl std::hash::Hash for CompositeKey {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
match self {
Self::Bool(val) => val.hash(state),
Self::Str(text) => text.hash(state),
Self::F64(val) => val.to_bits().hash(state),
Self::U64(val) => val.hash(state),
Self::I64(val) => val.hash(state),
Self::Null => {}
}
}
}
impl PartialEq for CompositeKey {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Bool(l), Self::Bool(r)) => l == r,
(Self::Str(l), Self::Str(r)) => l == r,
(Self::F64(l), Self::F64(r)) => l.to_bits() == r.to_bits(),
(Self::I64(l), Self::I64(r)) => l == r,
(Self::U64(l), Self::U64(r)) => l == r,
(Self::Null, Self::Null) => true,
(
Self::Bool(_)
| Self::Str(_)
| Self::F64(_)
| Self::I64(_)
| Self::U64(_)
| Self::Null,
_,
) => false,
}
}
}
impl From<CompositeIntermediateKey> for CompositeKey {
fn from(value: CompositeIntermediateKey) -> Self {
match value {
CompositeIntermediateKey::Str(s) => Self::Str(s),
CompositeIntermediateKey::IpAddr(s) => {
// Prefer to use the IPv4 representation if possible
if let Some(ip) = s.to_ipv4_mapped() {
Self::Str(ip.to_string())
} else {
Self::Str(s.to_string())
}
}
CompositeIntermediateKey::F64(f) => Self::F64(f),
CompositeIntermediateKey::Bool(f) => Self::Bool(f),
CompositeIntermediateKey::U64(f) => Self::U64(f),
CompositeIntermediateKey::I64(f) => Self::I64(f),
CompositeIntermediateKey::DateTime(f) => Self::I64(f / 1_000_000), // Convert ns to ms
CompositeIntermediateKey::Null => Self::Null,
}
}
}
/// This is the default entry for a bucket, which contains a composite key, count, and optionally
/// sub-aggregations.
/// ...
/// "my_composite": {
/// "buckets": [
/// {
/// "key": {
/// "date": 1494201600000,
/// "product": "rocky"
/// },
/// "doc_count": 5
/// },
/// {
/// "key": {
/// "date": 1494201600000,
/// "product": "balboa"
/// },
/// "doc_count": 2
/// },
/// {
/// "key": {
/// "date": 1494201700000,
/// "product": "john"
/// },
/// "doc_count": 3
/// }
/// ]
/// }
/// ...
/// }
/// ```
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CompositeBucketEntry {
/// The identifier of the bucket.
pub key: FxHashMap<String, CompositeKey>,
/// Number of documents in the bucket.
pub doc_count: u64,
#[serde(flatten)]
/// Sub-aggregations in this bucket.
pub sub_aggregation: AggregationResults,
}
impl CompositeBucketEntry {
pub(crate) fn get_bucket_count(&self) -> u64 {
1 + self.sub_aggregation.get_bucket_count()
}
}

View File

@@ -0,0 +1,515 @@
use std::fmt::Debug;
use std::net::Ipv6Addr;
use columnar::column_values::{CompactHit, CompactSpaceU64Accessor};
use columnar::{Column, ColumnType, MonotonicallyMappableToU64, StrColumn, TermOrdHit};
use crate::aggregation::accessor_helpers::{get_all_ff_readers, get_numeric_or_date_column_types};
use crate::aggregation::bucket::composite::numeric_types::num_proj;
use crate::aggregation::bucket::composite::numeric_types::num_proj::ProjectedNumber;
use crate::aggregation::bucket::composite::ToTypePaginationOrder;
use crate::aggregation::bucket::{
parse_into_milliseconds, CalendarInterval, CompositeAggregation, CompositeAggregationSource,
MissingOrder, Order,
};
use crate::aggregation::intermediate_agg_result::CompositeIntermediateKey;
use crate::{SegmentReader, TantivyError};
/// Contains all information required by the SegmentCompositeCollector to perform the
/// composite aggregation on a segment.
pub struct CompositeAggReqData {
/// The name of the aggregation.
pub name: String,
/// The normalized term aggregation request.
pub req: CompositeAggregation,
/// Accessors for each source, each source can have multiple accessors (columns).
pub composite_accessors: Vec<CompositeSourceAccessors>,
}
impl CompositeAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
+ self.composite_accessors.len() * std::mem::size_of::<CompositeSourceAccessors>()
}
}
/// Accessors for a single column in a composite source.
pub struct CompositeAccessor {
/// The fast field column
pub column: Column<u64>,
/// The column type
pub column_type: ColumnType,
/// Term dictionary if the column type is Str
///
/// Only used by term sources
pub str_dict_column: Option<StrColumn>,
/// Parsed date interval for date histogram sources
pub date_histogram_interval: PrecomputedDateInterval,
}
/// Accessors to all the columns that belong to the field of a composite source.
pub struct CompositeSourceAccessors {
/// The accessors for this source
pub accessors: Vec<CompositeAccessor>,
/// The key after which to start collecting results. Applies to the first
/// column of the source.
pub after_key: PrecomputedAfterKey,
/// The column index the after_key applies to. The after_key only applies to
/// one column. Columns before should be skipped. Columns after should be
/// kept without comparison to the after_key.
pub after_key_accessor_idx: usize,
/// Whether to skip missing values because of the after_key. Skipping only
/// applies if the value for previous columns were exactly equal to the
/// corresponding after keys (is_on_after_key).
pub skip_missing: bool,
/// The after key was set to null to indicate that the last collected key
/// was a missing value.
pub is_after_key_explicit_missing: bool,
}
impl CompositeSourceAccessors {
/// Creates a new set of accessors for the composite source.
///
/// Precomputes some values to make collection faster.
pub fn build_for_source(
reader: &SegmentReader,
source: &CompositeAggregationSource,
// First option is None when no after key was set in the query, the
// second option is None when the after key was set but its value for
// this source was set to `null`
source_after_key_opt: Option<&CompositeIntermediateKey>,
) -> crate::Result<Self> {
let is_after_key_explicit_missing = source_after_key_opt
.map(|after_key| matches!(after_key, CompositeIntermediateKey::Null))
.unwrap_or(false);
let mut skip_missing = false;
if let Some(CompositeIntermediateKey::Null) = source_after_key_opt {
if !source.missing_bucket() {
return Err(TantivyError::InvalidArgument(
"the 'after' key for a source cannot be null when 'missing_bucket' is false"
.to_string(),
));
}
} else if source_after_key_opt.is_some() {
// if missing buckets come first and we have a non null after key, we skip missing
if MissingOrder::First == source.missing_order() {
skip_missing = true;
}
if MissingOrder::Default == source.missing_order() && Order::Asc == source.order() {
skip_missing = true;
}
};
match source {
CompositeAggregationSource::Terms(source) => {
let allowed_column_types = [
ColumnType::I64,
ColumnType::U64,
ColumnType::F64,
ColumnType::Str,
ColumnType::DateTime,
ColumnType::Bool,
ColumnType::IpAddr,
// ColumnType::Bytes Unsupported
];
let mut columns_and_types =
get_all_ff_readers(reader, &source.field, Some(&allowed_column_types))?;
// Sort columns by their pagination order and determine which to skip
columns_and_types.sort_by_key(|(_, col_type)| col_type.column_pagination_order());
if source.order == Order::Desc {
columns_and_types.reverse();
}
let after_key_accessor_idx = find_first_column_to_collect(
&columns_and_types,
source_after_key_opt,
source.missing_order,
source.order,
)?;
let source_collectors: Vec<CompositeAccessor> = columns_and_types
.into_iter()
.map(|(column, column_type)| {
Ok(CompositeAccessor {
column,
column_type,
str_dict_column: reader.fast_fields().str(&source.field)?,
date_histogram_interval: PrecomputedDateInterval::NotApplicable,
})
})
.collect::<crate::Result<_>>()?;
let after_key = if let Some(first_col) =
source_collectors.get(after_key_accessor_idx)
{
match source_after_key_opt {
Some(after_key) => PrecomputedAfterKey::precompute(
&first_col,
after_key,
&source.field,
source.missing_order,
source.order,
)?,
None => {
precompute_missing_after_key(false, source.missing_order, source.order)
}
}
} else {
// if no columns, we don't care about the after_key
PrecomputedAfterKey::Next(0)
};
Ok(CompositeSourceAccessors {
accessors: source_collectors,
is_after_key_explicit_missing,
skip_missing,
after_key,
after_key_accessor_idx,
})
}
CompositeAggregationSource::Histogram(source) => {
let column_and_types: Vec<(Column, ColumnType)> = get_all_ff_readers(
reader,
&source.field,
Some(get_numeric_or_date_column_types()),
)?;
let source_collectors: Vec<CompositeAccessor> = column_and_types
.into_iter()
.map(|(column, column_type)| {
Ok(CompositeAccessor {
column,
column_type,
str_dict_column: None,
date_histogram_interval: PrecomputedDateInterval::NotApplicable,
})
})
.collect::<crate::Result<_>>()?;
let after_key = match source_after_key_opt {
Some(CompositeIntermediateKey::F64(key)) => {
let normalized_key = *key / source.interval;
num_proj::f64_to_i64(normalized_key).into()
}
Some(CompositeIntermediateKey::Null) => {
precompute_missing_after_key(true, source.missing_order, source.order)
}
None => precompute_missing_after_key(true, source.missing_order, source.order),
_ => {
return Err(crate::TantivyError::InvalidArgument(
"After key type invalid for interval composite source".to_string(),
));
}
};
Ok(CompositeSourceAccessors {
accessors: source_collectors,
is_after_key_explicit_missing,
skip_missing,
after_key,
after_key_accessor_idx: 0,
})
}
CompositeAggregationSource::DateHistogram(source) => {
let column_and_types =
get_all_ff_readers(reader, &source.field, Some(&[ColumnType::DateTime]))?;
let date_histogram_interval =
PrecomputedDateInterval::from_date_histogram_source_intervals(
&source.fixed_interval,
source.calendar_interval,
)?;
let source_collectors: Vec<CompositeAccessor> = column_and_types
.into_iter()
.map(|(column, column_type)| {
Ok(CompositeAccessor {
column,
column_type,
str_dict_column: None,
date_histogram_interval,
})
})
.collect::<crate::Result<_>>()?;
let after_key = match source_after_key_opt {
Some(CompositeIntermediateKey::DateTime(key)) => {
PrecomputedAfterKey::Exact(key.to_u64())
}
Some(CompositeIntermediateKey::Null) => {
precompute_missing_after_key(true, source.missing_order, source.order)
}
None => precompute_missing_after_key(true, source.missing_order, source.order),
_ => {
return Err(crate::TantivyError::InvalidArgument(
"After key type invalid for interval composite source".to_string(),
));
}
};
Ok(CompositeSourceAccessors {
accessors: source_collectors,
is_after_key_explicit_missing,
skip_missing,
after_key,
after_key_accessor_idx: 0,
})
}
}
}
}
/// Finds the index of the first column we should start collecting from to
/// resume the pagination from the after_key.
fn find_first_column_to_collect<T>(
sorted_columns: &[(T, ColumnType)],
after_key_opt: Option<&CompositeIntermediateKey>,
missing_order: MissingOrder,
order: Order,
) -> crate::Result<usize> {
let after_key = match after_key_opt {
None => return Ok(0), // No pagination, start from beginning
Some(key) => key,
};
// Handle null after_key (we were on a missing value last time)
if matches!(after_key, CompositeIntermediateKey::Null) {
return match (missing_order, order) {
// Missing values come first, so all columns remain
(MissingOrder::First, _) | (MissingOrder::Default, Order::Asc) => Ok(0),
// Missing values come last, so all columns are done
(MissingOrder::Last, _) | (MissingOrder::Default, Order::Desc) => {
Ok(sorted_columns.len())
}
};
}
// Find the first column whose type order matches or follows the after_key's
// type in the pagination sequence
let after_key_column_order = after_key.column_pagination_order();
for (idx, (_, col_type)) in sorted_columns.iter().enumerate() {
let col_order = col_type.column_pagination_order();
let is_first_to_collect = match order {
Order::Asc => col_order >= after_key_column_order,
Order::Desc => col_order <= after_key_column_order,
};
if is_first_to_collect {
return Ok(idx);
}
}
// All columns are before the after_key, nothing left to collect
Ok(sorted_columns.len())
}
fn precompute_missing_after_key(
is_after_key_explicit_missing: bool,
missing_order: MissingOrder,
order: Order,
) -> PrecomputedAfterKey {
let after_last = PrecomputedAfterKey::AfterLast;
let before_first = PrecomputedAfterKey::Next(0);
match (is_after_key_explicit_missing, missing_order, order) {
(true, MissingOrder::First, Order::Asc) => before_first,
(true, MissingOrder::First, Order::Desc) => after_last,
(true, MissingOrder::Last, Order::Asc) => after_last,
(true, MissingOrder::Last, Order::Desc) => before_first,
(true, MissingOrder::Default, Order::Asc) => before_first,
(true, MissingOrder::Default, Order::Desc) => after_last,
(false, _, Order::Asc) => before_first,
(false, _, Order::Desc) => after_last,
}
}
/// A parsed representation of the date interval for date histogram sources
#[derive(Clone, Copy, Debug)]
pub enum PrecomputedDateInterval {
/// This is not a date histogram source
NotApplicable,
/// Source was configured with a fixed interval
FixedNanoseconds(i64),
/// Source was configured with a calendar interval
Calendar(CalendarInterval),
}
impl PrecomputedDateInterval {
/// Validates the date histogram source interval fields and parses a date interval from them.
pub fn from_date_histogram_source_intervals(
fixed_interval: &Option<String>,
calendar_interval: Option<CalendarInterval>,
) -> crate::Result<Self> {
match (fixed_interval, calendar_interval) {
(Some(_), Some(_)) | (None, None) => Err(TantivyError::InvalidArgument(
"date histogram source must one and only one of fixed_interval or \
calendar_interval set"
.to_string(),
)),
(Some(fixed_interval), None) => {
let fixed_interval_ms = parse_into_milliseconds(&fixed_interval)?;
Ok(PrecomputedDateInterval::FixedNanoseconds(
fixed_interval_ms * 1_000_000,
))
}
(None, Some(calendar_interval)) => {
Ok(PrecomputedDateInterval::Calendar(calendar_interval))
}
}
}
}
/// The after key projected to the u64 column space
///
/// Some column types (term, IP) might not have an exact representation of the
/// specified after key
#[derive(Debug)]
pub enum PrecomputedAfterKey {
/// The after key could be exactly represented in the column space.
Exact(u64),
/// The after key could not be exactly represented exactly represented, so
/// this is the next closest one.
Next(u64),
/// The after key could not be represented in the column space, it is
/// greater than all value
AfterLast,
}
impl From<TermOrdHit> for PrecomputedAfterKey {
fn from(hit: TermOrdHit) -> Self {
match hit {
TermOrdHit::Exact(ord) => PrecomputedAfterKey::Exact(ord),
// TermOrdHit represents AfterLast as Next(u64::MAX), we keep it as is
TermOrdHit::Next(ord) => PrecomputedAfterKey::Next(ord),
}
}
}
impl From<CompactHit> for PrecomputedAfterKey {
fn from(hit: CompactHit) -> Self {
match hit {
CompactHit::Exact(ord) => PrecomputedAfterKey::Exact(ord as u64),
CompactHit::Next(ord) => PrecomputedAfterKey::Next(ord as u64),
CompactHit::AfterLast => PrecomputedAfterKey::AfterLast,
}
}
}
impl<T: MonotonicallyMappableToU64> From<ProjectedNumber<T>> for PrecomputedAfterKey {
fn from(num: ProjectedNumber<T>) -> Self {
match num {
ProjectedNumber::Exact(number) => PrecomputedAfterKey::Exact(number.to_u64()),
ProjectedNumber::Next(number) => PrecomputedAfterKey::Next(number.to_u64()),
ProjectedNumber::AfterLast => PrecomputedAfterKey::AfterLast,
}
}
}
// /!\ These operators only makes sense if both values are in the same column space
impl PrecomputedAfterKey {
pub fn equals(&self, column_value: u64) -> bool {
match self {
PrecomputedAfterKey::Exact(v) => *v == column_value,
PrecomputedAfterKey::Next(_) => false,
PrecomputedAfterKey::AfterLast => false,
}
}
pub fn gt(&self, column_value: u64) -> bool {
match self {
PrecomputedAfterKey::Exact(v) => *v > column_value,
PrecomputedAfterKey::Next(v) => *v > column_value,
PrecomputedAfterKey::AfterLast => true,
}
}
pub fn lt(&self, column_value: u64) -> bool {
match self {
PrecomputedAfterKey::Exact(v) => *v < column_value,
// a value equal to the next is greater than the after key
PrecomputedAfterKey::Next(v) => *v <= column_value,
PrecomputedAfterKey::AfterLast => false,
}
}
fn precompute_ip_addr(column: &Column<u64>, key: &Ipv6Addr) -> crate::Result<Self> {
let compact_space_accessor = column
.values
.clone()
.downcast_arc::<CompactSpaceU64Accessor>()
.map_err(|_| {
TantivyError::AggregationError(crate::aggregation::AggregationError::InternalError(
"type mismatch: could not downcast to CompactSpaceU64Accessor".to_string(),
))
})?;
let ip_u128 = key.to_bits();
let ip_next_compact = compact_space_accessor.u128_to_next_compact(ip_u128);
Ok(ip_next_compact.into())
}
fn precompute_term_ord(
str_dict_column: &Option<StrColumn>,
key: &str,
field: &str,
) -> crate::Result<Self> {
let dict = str_dict_column
.as_ref()
.expect("dictionary missing for str accessor")
.dictionary();
let next_ord = dict.term_ord_or_next(key).map_err(|_| {
TantivyError::InvalidArgument(format!(
"failed to lookup after_key '{}' for field '{}'",
key, field
))
})?;
Ok(next_ord.into())
}
/// Projects the after key into the column space of the given accessor.
///
/// The computed after key will not take care of skipping entire columns
/// when the after key type is ordered after the accessor's type, that
/// should be performed earlier.
pub fn precompute(
composite_accessor: &CompositeAccessor,
source_after_key: &CompositeIntermediateKey,
field: &str,
missing_order: MissingOrder,
order: Order,
) -> crate::Result<Self> {
use CompositeIntermediateKey as CIKey;
let precomputed_key = match (composite_accessor.column_type, source_after_key) {
(ColumnType::Bytes, _) => panic!("unsupported"),
// null after key
(_, CIKey::Null) => precompute_missing_after_key(false, missing_order, order),
// numerical
(ColumnType::I64, CIKey::I64(k)) => PrecomputedAfterKey::Exact(k.to_u64()),
(ColumnType::I64, CIKey::U64(k)) => num_proj::u64_to_i64(*k).into(),
(ColumnType::I64, CIKey::F64(k)) => num_proj::f64_to_i64(*k).into(),
(ColumnType::U64, CIKey::I64(k)) => num_proj::i64_to_u64(*k).into(),
(ColumnType::U64, CIKey::U64(k)) => PrecomputedAfterKey::Exact(*k),
(ColumnType::U64, CIKey::F64(k)) => num_proj::f64_to_u64(*k).into(),
(ColumnType::F64, CIKey::I64(k)) => num_proj::i64_to_f64(*k).into(),
(ColumnType::F64, CIKey::U64(k)) => num_proj::u64_to_f64(*k).into(),
(ColumnType::F64, CIKey::F64(k)) => PrecomputedAfterKey::Exact(k.to_u64()),
// boolean
(ColumnType::Bool, CIKey::Bool(key)) => PrecomputedAfterKey::Exact(key.to_u64()),
// string
(ColumnType::Str, CIKey::Str(key)) => PrecomputedAfterKey::precompute_term_ord(
&composite_accessor.str_dict_column,
key,
field,
)?,
// date time
(ColumnType::DateTime, CIKey::DateTime(key)) => {
PrecomputedAfterKey::Exact(key.to_u64())
}
// ip address
(ColumnType::IpAddr, CIKey::IpAddr(key)) => {
PrecomputedAfterKey::precompute_ip_addr(&composite_accessor.column, key)?
}
// assume the column's type is ordered after the after_key's type
_ => PrecomputedAfterKey::keep_all(order),
};
Ok(precomputed_key)
}
fn keep_all(order: Order) -> Self {
match order {
Order::Asc => PrecomputedAfterKey::Next(0),
Order::Desc => PrecomputedAfterKey::Next(u64::MAX),
}
}
}

View File

@@ -0,0 +1,140 @@
use time::convert::{Day, Nanosecond};
use time::{Time, UtcDateTime};
const NS_IN_DAY: i64 = Nanosecond::per_t::<i128>(Day) as i64;
/// Computes the timestamp in nanoseconds corresponding to the beginning of the
/// year (January 1st at midnight UTC).
pub(super) fn try_year_bucket(timestamp_ns: i64) -> crate::Result<i64> {
year_bucket_using_time_crate(timestamp_ns).map_err(|e| {
crate::TantivyError::InvalidArgument(format!(
"Failed to compute year bucket for timestamp {}: {}",
timestamp_ns,
e.to_string()
))
})
}
/// Computes the timestamp in nanoseconds corresponding to the beginning of the
/// month (1st at midnight UTC).
pub(super) fn try_month_bucket(timestamp_ns: i64) -> crate::Result<i64> {
month_bucket_using_time_crate(timestamp_ns).map_err(|e| {
crate::TantivyError::InvalidArgument(format!(
"Failed to compute month bucket for timestamp {}: {}",
timestamp_ns,
e.to_string()
))
})
}
/// Computes the timestamp in nanoseconds corresponding to the beginning of the
/// week (Monday at midnight UTC).
pub(super) fn week_bucket(timestamp_ns: i64) -> i64 {
// 1970-01-01 was a Thursday (weekday = 4)
let days_since_epoch = timestamp_ns.div_euclid(NS_IN_DAY);
// Find the weekday: 0=Monday, ..., 6=Sunday
let weekday = (days_since_epoch + 3).rem_euclid(7);
let monday_days_since_epoch = days_since_epoch - weekday;
monday_days_since_epoch * NS_IN_DAY
}
fn year_bucket_using_time_crate(timestamp_ns: i64) -> Result<i64, time::Error> {
let timestamp_ns = UtcDateTime::from_unix_timestamp_nanos(timestamp_ns as i128)?
.replace_ordinal(1)?
.replace_time(Time::MIDNIGHT)
.unix_timestamp_nanos();
Ok(timestamp_ns as i64)
}
fn month_bucket_using_time_crate(timestamp_ns: i64) -> Result<i64, time::Error> {
let timestamp_ns = UtcDateTime::from_unix_timestamp_nanos(timestamp_ns as i128)?
.replace_day(1)?
.replace_time(Time::MIDNIGHT)
.unix_timestamp_nanos();
Ok(timestamp_ns as i64)
}
#[cfg(test)]
mod tests {
use std::i64;
use time::format_description::well_known::Iso8601;
use time::UtcDateTime;
use super::*;
fn ts_ns(iso: &str) -> i64 {
UtcDateTime::parse(iso, &Iso8601::DEFAULT)
.unwrap()
.unix_timestamp_nanos() as i64
}
#[test]
fn test_year_bucket() {
let ts = ts_ns("1970-01-01T00:00:00Z");
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1970-01-01T00:00:00Z"));
let ts = ts_ns("1970-06-01T10:00:01.010Z");
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1970-01-01T00:00:00Z"));
let ts = ts_ns("2008-12-31T23:59:59.999999999Z"); // leap year
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("2008-01-01T00:00:00Z"));
let ts = ts_ns("2008-01-01T00:00:00Z"); // leap year
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("2008-01-01T00:00:00Z"));
let ts = ts_ns("2010-12-31T23:59:59.999999999Z");
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("2010-01-01T00:00:00Z"));
let ts = ts_ns("1972-06-01T00:10:00Z");
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1972-01-01T00:00:00Z"));
}
#[test]
fn test_month_bucket() {
let ts = ts_ns("1970-01-15T00:00:00Z");
let res = try_month_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1970-01-01T00:00:00Z"));
let ts = ts_ns("1970-02-01T00:00:00Z");
let res = try_month_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1970-02-01T00:00:00Z"));
let ts = ts_ns("2000-01-31T23:59:59.999999999Z");
let res = try_month_bucket(ts).unwrap();
assert_eq!(res, ts_ns("2000-01-01T00:00:00Z"));
}
#[test]
fn test_week_bucket() {
let ts = ts_ns("1970-01-05T00:00:00Z"); // Monday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));
let ts = ts_ns("1970-01-05T23:59:59Z"); // Monday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));
let ts = ts_ns("1970-01-07T01:13:00Z"); // Wednesday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));
let ts = ts_ns("1970-01-11T23:59:59.999999999Z"); // Sunday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));
let ts = ts_ns("2025-10-16T10:41:59.010Z"); // Thursday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("2025-10-13T00:00:00Z"));
let ts = ts_ns("1970-01-01T00:00:00Z"); // Thursday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1969-12-29T00:00:00Z")); // Negative
}
}

View File

@@ -0,0 +1,595 @@
use std::fmt::Debug;
use std::net::Ipv6Addr;
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{
Column, ColumnType, Dictionary, MonotonicallyMappableToU128, MonotonicallyMappableToU64,
NumericalValue, StrColumn,
};
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::composite::accessors::{
CompositeAccessor, CompositeAggReqData, PrecomputedDateInterval,
};
use crate::aggregation::bucket::composite::calendar_interval;
use crate::aggregation::bucket::composite::map::{DynArrayHeapMap, MAX_DYN_ARRAY_SIZE};
use crate::aggregation::bucket::{
CalendarInterval, CompositeAggregationSource, MissingOrder, Order,
};
use crate::aggregation::intermediate_agg_result::{
CompositeIntermediateKey, IntermediateAggregationResult, IntermediateAggregationResults,
IntermediateBucketResult, IntermediateCompositeBucketEntry, IntermediateCompositeBucketResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::BucketId;
use crate::TantivyError;
#[derive(Debug)]
struct CompositeBucketCollector {
count: u32,
}
impl CompositeBucketCollector {
fn new() -> Self {
CompositeBucketCollector { count: 0 }
}
#[inline]
fn collect(&mut self) {
self.count += 1;
}
}
/// The value is represented as a tuple of:
/// - the column index or missing value sentinel
/// - if the value is present, store the accessor index + 1
/// - if the value is missing, store 0 (for missing first) or u8::MAX (for missing last)
/// - the fast field value u64 representation
/// - 0 if the field is missing
/// - regular u64 repr if the ordering is ascending
/// - bitwise NOT of the u64 repr if the ordering is descending
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
struct InternalValueRepr(u8, u64);
impl InternalValueRepr {
#[inline]
fn new_term(raw: u64, accessor_idx: u8, order: Order) -> Self {
match order {
Order::Asc => InternalValueRepr(accessor_idx + 1, raw),
Order::Desc => InternalValueRepr(accessor_idx + 1, !raw),
}
}
/// For histogram, the source column does not matter
#[inline]
fn new_histogram(raw: u64, order: Order) -> Self {
match order {
Order::Asc => InternalValueRepr(1, raw),
Order::Desc => InternalValueRepr(1, !raw),
}
}
#[inline]
fn new_missing(order: Order, missing_order: MissingOrder) -> Self {
let column_idx = match (missing_order, order) {
(MissingOrder::First, _) => 0,
(MissingOrder::Last, _) => u8::MAX,
(MissingOrder::Default, Order::Asc) => 0,
(MissingOrder::Default, Order::Desc) => u8::MAX,
};
InternalValueRepr(column_idx, 0)
}
#[inline]
fn decode(self, order: Order) -> Option<(u8, u64)> {
if self.0 == u8::MAX || self.0 == 0 {
return None;
}
match order {
Order::Asc => Some((self.0 - 1, self.1)),
Order::Desc => Some((self.0 - 1, !self.1)),
}
}
}
/// The collector puts values from the fast field into the correct buckets and
/// does a conversion to the correct datatype.
#[derive(Debug)]
pub struct SegmentCompositeCollector {
buckets: DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>,
accessor_idx: usize,
}
impl SegmentAggregationCollector for SegmentCompositeCollector {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
_parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data
.get_composite_req_data(self.accessor_idx)
.name
.clone();
let buckets = self.into_intermediate_bucket_result(agg_data)?;
results.push(
name,
IntermediateAggregationResult::Bucket(IntermediateBucketResult::Composite { buckets }),
)?;
Ok(())
}
#[inline]
fn collect(
&mut self,
_parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mem_pre = self.get_memory_consumption();
let composite_agg_data = agg_data.take_composite_req_data(self.accessor_idx);
for doc in docs {
let mut sub_level_values = SmallVec::new();
recursive_key_visitor(
*doc,
agg_data,
&composite_agg_data,
0,
&mut sub_level_values,
&mut self.buckets,
true,
)?;
}
agg_data.put_back_composite_req_data(self.accessor_idx, composite_agg_data);
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
agg_data.context.limits.add_memory_consumed(mem_delta)?;
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
_max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
Ok(())
}
fn flush(&mut self, _agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
Ok(())
}
}
impl SegmentCompositeCollector {
fn get_memory_consumption(&self) -> u64 {
self.buckets.memory_consumption()
}
pub(crate) fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
validate_req(req_data, node.idx_in_req_data)?;
if !node.children.is_empty() {
let _sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
}
let composite_req_data = req_data.get_composite_req_data(node.idx_in_req_data);
Ok(SegmentCompositeCollector {
buckets: DynArrayHeapMap::try_new(composite_req_data.req.sources.len())?,
accessor_idx: node.idx_in_req_data,
})
}
#[inline]
pub(crate) fn into_intermediate_bucket_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateCompositeBucketResult> {
let mut dict: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry> =
Default::default();
dict.reserve(self.buckets.size());
let composite_data = agg_data.get_composite_req_data(self.accessor_idx);
let buckets = std::mem::replace(
&mut self.buckets,
DynArrayHeapMap::try_new(composite_data.req.sources.len())
.expect("already validated source count"),
);
for (key_internal_repr, agg) in buckets.into_iter() {
let key = resolve_key(&key_internal_repr, composite_data)?;
let sub_aggregation_res = IntermediateAggregationResults::default();
dict.insert(
key,
IntermediateCompositeBucketEntry {
doc_count: agg.count,
sub_aggregation: sub_aggregation_res,
},
);
}
Ok(IntermediateCompositeBucketResult {
entries: dict,
target_size: composite_data.req.size,
orders: composite_data
.req
.sources
.iter()
.map(|source| match source {
CompositeAggregationSource::Terms(t) => (t.order, t.missing_order),
CompositeAggregationSource::Histogram(h) => (h.order, h.missing_order),
CompositeAggregationSource::DateHistogram(d) => (d.order, d.missing_order),
})
.collect(),
})
}
}
fn validate_req(req_data: &mut AggregationsSegmentCtx, accessor_idx: usize) -> crate::Result<()> {
let composite_data = req_data.get_composite_req_data(accessor_idx);
let req = &composite_data.req;
if req.sources.is_empty() {
return Err(TantivyError::InvalidArgument(
"composite aggregation must have at least one source".to_string(),
));
}
if req.size == 0 {
return Err(TantivyError::InvalidArgument(
"composite aggregation 'size' must be > 0".to_string(),
));
}
let column_types_for_sources = composite_data.composite_accessors.iter().map(|item| {
item.accessors
.iter()
.map(|a| a.column_type)
.collect::<Vec<_>>()
});
for column_types in column_types_for_sources {
if column_types.len() > MAX_DYN_ARRAY_SIZE {
return Err(TantivyError::InvalidArgument(format!(
"composite aggregation source supports maximum {MAX_DYN_ARRAY_SIZE} sources",
)));
}
if column_types.contains(&ColumnType::Bytes) {
return Err(TantivyError::InvalidArgument(
"composite aggregation does not support 'bytes' field type".to_string(),
));
}
}
Ok(())
}
fn collect_bucket_with_limit(
agg_data: &mut AggregationsSegmentCtx,
composite_agg_data: &CompositeAggReqData,
buckets: &mut DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>,
key: &[InternalValueRepr],
) -> crate::Result<()> {
if (buckets.size() as u32) < composite_agg_data.req.size {
buckets
.get_or_insert_with(key, CompositeBucketCollector::new)
.collect();
return Ok(());
}
if let Some(entry) = buckets.get_mut(key) {
entry.collect();
return Ok(());
}
if let Some(highest_key) = buckets.peek_highest() {
if key < highest_key {
buckets.evict_highest();
buckets
.get_or_insert_with(key, CompositeBucketCollector::new)
.collect();
}
}
let _ = agg_data;
Ok(())
}
/// Converts the composite key from its internal column space representation
/// (segment specific) into its intermediate form.
fn resolve_key(
internal_key: &[InternalValueRepr],
agg_data: &CompositeAggReqData,
) -> crate::Result<Vec<CompositeIntermediateKey>> {
internal_key
.into_iter()
.enumerate()
.map(|(idx, val)| {
resolve_internal_value_repr(
*val,
&agg_data.req.sources[idx],
&agg_data.composite_accessors[idx].accessors,
)
})
.collect()
}
fn resolve_internal_value_repr(
internal_value_repr: InternalValueRepr,
source: &CompositeAggregationSource,
composite_accessors: &[CompositeAccessor],
) -> crate::Result<CompositeIntermediateKey> {
let decoded_value_opt = match source {
CompositeAggregationSource::Terms(source) => internal_value_repr.decode(source.order),
CompositeAggregationSource::Histogram(source) => internal_value_repr.decode(source.order),
CompositeAggregationSource::DateHistogram(source) => {
internal_value_repr.decode(source.order)
}
};
let Some((decoded_accessor_idx, val)) = decoded_value_opt else {
return Ok(CompositeIntermediateKey::Null);
};
let key = match source {
CompositeAggregationSource::Terms(_) => {
let CompositeAccessor {
column_type,
str_dict_column,
column,
..
} = &composite_accessors[decoded_accessor_idx as usize];
resolve_term(val, column_type, str_dict_column, column)?
}
CompositeAggregationSource::Histogram(source) => {
CompositeIntermediateKey::F64(i64::from_u64(val) as f64 * source.interval)
}
CompositeAggregationSource::DateHistogram(_) => {
CompositeIntermediateKey::DateTime(i64::from_u64(val))
}
};
Ok(key)
}
fn resolve_term(
val: u64,
column_type: &ColumnType,
str_dict_column: &Option<StrColumn>,
column: &Column,
) -> crate::Result<CompositeIntermediateKey> {
let key = if *column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let term_dict = str_dict_column
.as_ref()
.map(|el| el.dictionary())
.unwrap_or_else(|| &fallback_dict);
let mut buffer = Vec::new();
term_dict.ord_to_term(val, &mut buffer)?;
CompositeIntermediateKey::Str(
String::from_utf8(buffer.to_vec()).expect("could not convert to String"),
)
} else if *column_type == ColumnType::DateTime {
let val = i64::from_u64(val);
CompositeIntermediateKey::DateTime(val)
} else if *column_type == ColumnType::Bool {
let val = bool::from_u64(val);
CompositeIntermediateKey::Bool(val)
} else if *column_type == ColumnType::IpAddr {
let compact_space_accessor = column
.values
.clone()
.downcast_arc::<CompactSpaceU64Accessor>()
.map_err(|_| {
TantivyError::AggregationError(crate::aggregation::AggregationError::InternalError(
"Type mismatch: Could not downcast to CompactSpaceU64Accessor".to_string(),
))
})?;
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
let val = Ipv6Addr::from_u128(val);
CompositeIntermediateKey::IpAddr(val)
} else {
if *column_type == ColumnType::U64 {
CompositeIntermediateKey::U64(val)
} else if *column_type == ColumnType::I64 {
CompositeIntermediateKey::I64(i64::from_u64(val))
} else {
let val = f64::from_u64(val);
let val: NumericalValue = val.into();
match val.normalize() {
NumericalValue::U64(val) => CompositeIntermediateKey::U64(val),
NumericalValue::I64(val) => CompositeIntermediateKey::I64(val),
NumericalValue::F64(val) => CompositeIntermediateKey::F64(val),
}
}
};
Ok(key)
}
/// Depth-first walk of the accessors to build the composite key combinations
/// and update the buckets.
fn recursive_key_visitor(
doc_id: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
composite_agg_data: &CompositeAggReqData,
source_idx_for_recursion: usize,
sub_level_values: &mut SmallVec<[InternalValueRepr; MAX_DYN_ARRAY_SIZE]>,
buckets: &mut DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>,
is_on_after_key: bool,
) -> crate::Result<()> {
if source_idx_for_recursion == composite_agg_data.req.sources.len() {
if !is_on_after_key {
collect_bucket_with_limit(
agg_data,
composite_agg_data,
buckets,
sub_level_values,
)?;
}
return Ok(());
}
let current_level_accessors = &composite_agg_data.composite_accessors[source_idx_for_recursion];
let current_level_source = &composite_agg_data.req.sources[source_idx_for_recursion];
let mut missing = true;
for (accessor_idx, accessor) in current_level_accessors.accessors.iter().enumerate() {
let values = accessor.column.values_for_doc(doc_id);
for value in values {
missing = false;
match current_level_source {
CompositeAggregationSource::Terms(_) => {
let preceeds_after_key_type =
accessor_idx < current_level_accessors.after_key_accessor_idx;
if is_on_after_key && preceeds_after_key_type {
break;
}
let matches_after_key_type =
accessor_idx == current_level_accessors.after_key_accessor_idx;
if matches_after_key_type && is_on_after_key {
let should_skip = match current_level_source.order() {
Order::Asc => current_level_accessors.after_key.gt(value),
Order::Desc => current_level_accessors.after_key.lt(value),
};
if should_skip {
continue;
}
}
sub_level_values.push(InternalValueRepr::new_term(
value,
accessor_idx as u8,
current_level_source.order(),
));
let still_on_after_key =
matches_after_key_type && current_level_accessors.after_key.equals(value);
recursive_key_visitor(
doc_id,
agg_data,
composite_agg_data,
source_idx_for_recursion + 1,
sub_level_values,
buckets,
is_on_after_key && still_on_after_key,
)?;
sub_level_values.pop();
}
CompositeAggregationSource::Histogram(source) => {
let float_value = match accessor.column_type {
ColumnType::U64 => value as f64,
ColumnType::I64 => i64::from_u64(value) as f64,
ColumnType::DateTime => i64::from_u64(value) as f64 / 1_000_000.,
ColumnType::F64 => f64::from_u64(value),
_ => {
panic!(
"unexpected type {:?}. This should not happen",
accessor.column_type
)
}
};
let bucket_index = (float_value / source.interval).floor() as i64;
let bucket_value = i64::to_u64(bucket_index);
if is_on_after_key {
let should_skip = match current_level_source.order() {
Order::Asc => current_level_accessors.after_key.gt(bucket_value),
Order::Desc => current_level_accessors.after_key.lt(bucket_value),
};
if should_skip {
continue;
}
}
sub_level_values.push(InternalValueRepr::new_histogram(
bucket_value,
current_level_source.order(),
));
let still_on_after_key = current_level_accessors.after_key.equals(bucket_value);
recursive_key_visitor(
doc_id,
agg_data,
composite_agg_data,
source_idx_for_recursion + 1,
sub_level_values,
buckets,
is_on_after_key && still_on_after_key,
)?;
sub_level_values.pop();
}
CompositeAggregationSource::DateHistogram(_) => {
let value_ns = match accessor.column_type {
ColumnType::DateTime => i64::from_u64(value),
_ => {
panic!(
"unexpected type {:?}. This should not happen",
accessor.column_type
)
}
};
let bucket_index = match accessor.date_histogram_interval {
PrecomputedDateInterval::FixedNanoseconds(fixed_interval_ns) => {
(value_ns / fixed_interval_ns) * fixed_interval_ns
}
PrecomputedDateInterval::Calendar(CalendarInterval::Year) => {
calendar_interval::try_year_bucket(value_ns)?
}
PrecomputedDateInterval::Calendar(CalendarInterval::Month) => {
calendar_interval::try_month_bucket(value_ns)?
}
PrecomputedDateInterval::Calendar(CalendarInterval::Week) => {
calendar_interval::week_bucket(value_ns)
}
PrecomputedDateInterval::NotApplicable => {
panic!("interval not precomputed for date histogram source")
}
};
let bucket_value = i64::to_u64(bucket_index);
if is_on_after_key {
let should_skip = match current_level_source.order() {
Order::Asc => current_level_accessors.after_key.gt(bucket_value),
Order::Desc => current_level_accessors.after_key.lt(bucket_value),
};
if should_skip {
continue;
}
}
sub_level_values.push(InternalValueRepr::new_histogram(
bucket_value,
current_level_source.order(),
));
let still_on_after_key = current_level_accessors.after_key.equals(bucket_value);
recursive_key_visitor(
doc_id,
agg_data,
composite_agg_data,
source_idx_for_recursion + 1,
sub_level_values,
buckets,
is_on_after_key && still_on_after_key,
)?;
sub_level_values.pop();
}
};
}
}
if missing && current_level_source.missing_bucket() {
if is_on_after_key && current_level_accessors.skip_missing {
return Ok(());
}
sub_level_values.push(InternalValueRepr::new_missing(
current_level_source.order(),
current_level_source.missing_order(),
));
recursive_key_visitor(
doc_id,
agg_data,
composite_agg_data,
source_idx_for_recursion + 1,
sub_level_values,
buckets,
is_on_after_key && current_level_accessors.is_after_key_explicit_missing,
)?;
sub_level_values.pop();
}
Ok(())
}

View File

@@ -0,0 +1,364 @@
use std::collections::BinaryHeap;
use std::fmt::Debug;
use std::hash::Hash;
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use crate::TantivyError;
/// Map backed by a hash map for fast access and a binary heap to track the
/// highest key. The key is an array of fixed size S.
#[derive(Clone, Debug)]
struct ArrayHeapMap<K: Ord, V, const S: usize> {
pub(crate) buckets: FxHashMap<[K; S], V>,
pub(crate) heap: BinaryHeap<[K; S]>,
}
impl<K: Ord, V, const S: usize> Default for ArrayHeapMap<K, V, S> {
fn default() -> Self {
ArrayHeapMap {
buckets: FxHashMap::default(),
heap: BinaryHeap::default(),
}
}
}
impl<K: Eq + Hash + Clone + Ord, V, const S: usize> ArrayHeapMap<K, V, S> {
/// Panics if the length of `key` is not S.
fn get_or_insert_with<F: FnOnce() -> V>(&mut self, key: &[K], f: F) -> &mut V {
let key_array: &[K; S] = key.try_into().expect("Key length mismatch");
self.buckets.entry(key_array.clone()).or_insert_with(|| {
self.heap.push(key_array.clone());
f()
})
}
/// Panics if the length of `key` is not S.
fn get_mut(&mut self, key: &[K]) -> Option<&mut V> {
let key_array: &[K; S] = key.try_into().expect("Key length mismatch");
self.buckets.get_mut(key_array)
}
fn peek_highest(&self) -> Option<&[K]> {
self.heap.peek().map(|k_array| k_array.as_slice())
}
fn evict_highest(&mut self) {
if let Some(highest) = self.heap.pop() {
self.buckets.remove(&highest);
}
}
fn memory_consumption(&self) -> u64 {
let key_size = std::mem::size_of::<[K; S]>();
let map_size = (key_size + std::mem::size_of::<V>()) * self.buckets.capacity();
let heap_size = key_size * self.heap.capacity();
(map_size + heap_size) as u64
}
}
impl<K: Copy + Ord + Clone + 'static, V: 'static, const S: usize> ArrayHeapMap<K, V, S> {
fn into_iter(self) -> Box<dyn Iterator<Item = (SmallVec<[K; MAX_DYN_ARRAY_SIZE]>, V)>> {
Box::new(
self.buckets
.into_iter()
.map(|(k, v)| (SmallVec::from_slice(&k), v)),
)
}
fn values_mut<'a>(&'a mut self) -> Box<dyn Iterator<Item = &'a mut V> + 'a> {
Box::new(self.buckets.values_mut())
}
}
pub(super) const MAX_DYN_ARRAY_SIZE: usize = 16;
const MAX_DYN_ARRAY_SIZE_PLUS_ONE: usize = MAX_DYN_ARRAY_SIZE + 1;
/// A map optimized for memory footprint, fast access and efficient eviction of
/// the highest key.
///
/// Keys are inlined arrays of size 1 to [MAX_DYN_ARRAY_SIZE] but for a given
/// instance the key size is fixed. This allows to avoid heap allocations for the
/// keys.
#[derive(Clone, Debug)]
pub(super) struct DynArrayHeapMap<K: Ord, V>(DynArrayHeapMapInner<K, V>);
/// Wrapper around ArrayHeapMap to dynamically dispatch on the array size.
#[derive(Clone, Debug)]
enum DynArrayHeapMapInner<K: Ord, V> {
Dim1(ArrayHeapMap<K, V, 1>),
Dim2(ArrayHeapMap<K, V, 2>),
Dim3(ArrayHeapMap<K, V, 3>),
Dim4(ArrayHeapMap<K, V, 4>),
Dim5(ArrayHeapMap<K, V, 5>),
Dim6(ArrayHeapMap<K, V, 6>),
Dim7(ArrayHeapMap<K, V, 7>),
Dim8(ArrayHeapMap<K, V, 8>),
Dim9(ArrayHeapMap<K, V, 9>),
Dim10(ArrayHeapMap<K, V, 10>),
Dim11(ArrayHeapMap<K, V, 11>),
Dim12(ArrayHeapMap<K, V, 12>),
Dim13(ArrayHeapMap<K, V, 13>),
Dim14(ArrayHeapMap<K, V, 14>),
Dim15(ArrayHeapMap<K, V, 15>),
Dim16(ArrayHeapMap<K, V, 16>),
}
impl<K: Ord, V> DynArrayHeapMap<K, V> {
/// Creates a new heap map with dynamic array keys of size `key_dimension`.
pub(super) fn try_new(key_dimension: usize) -> crate::Result<Self> {
let inner = match key_dimension {
0 => {
return Err(TantivyError::InvalidArgument(
"DynArrayHeapMap dimension must be at least 1".to_string(),
))
}
1 => DynArrayHeapMapInner::Dim1(ArrayHeapMap::default()),
2 => DynArrayHeapMapInner::Dim2(ArrayHeapMap::default()),
3 => DynArrayHeapMapInner::Dim3(ArrayHeapMap::default()),
4 => DynArrayHeapMapInner::Dim4(ArrayHeapMap::default()),
5 => DynArrayHeapMapInner::Dim5(ArrayHeapMap::default()),
6 => DynArrayHeapMapInner::Dim6(ArrayHeapMap::default()),
7 => DynArrayHeapMapInner::Dim7(ArrayHeapMap::default()),
8 => DynArrayHeapMapInner::Dim8(ArrayHeapMap::default()),
9 => DynArrayHeapMapInner::Dim9(ArrayHeapMap::default()),
10 => DynArrayHeapMapInner::Dim10(ArrayHeapMap::default()),
11 => DynArrayHeapMapInner::Dim11(ArrayHeapMap::default()),
12 => DynArrayHeapMapInner::Dim12(ArrayHeapMap::default()),
13 => DynArrayHeapMapInner::Dim13(ArrayHeapMap::default()),
14 => DynArrayHeapMapInner::Dim14(ArrayHeapMap::default()),
15 => DynArrayHeapMapInner::Dim15(ArrayHeapMap::default()),
16 => DynArrayHeapMapInner::Dim16(ArrayHeapMap::default()),
MAX_DYN_ARRAY_SIZE_PLUS_ONE.. => {
return Err(TantivyError::InvalidArgument(format!(
"DynArrayHeapMap supports maximum {MAX_DYN_ARRAY_SIZE} dimensions, got \
{key_dimension}",
)))
}
};
Ok(DynArrayHeapMap(inner))
}
/// Number of elements in the map. This is not the dimension of the keys.
pub(super) fn size(&self) -> usize {
match &self.0 {
DynArrayHeapMapInner::Dim1(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim2(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim3(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim4(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim5(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim6(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim7(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim8(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim9(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim10(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim11(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim12(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim13(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim14(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim15(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim16(map) => map.buckets.len(),
}
}
}
impl<K: Ord + Hash + Clone, V> DynArrayHeapMap<K, V> {
/// Get a mutable reference to the value corresponding to `key` or inserts a new
/// value created by calling `f`.
///
/// Panics if the length of `key` does not match the key dimension of the map.
pub(super) fn get_or_insert_with<F: FnOnce() -> V>(&mut self, key: &[K], f: F) -> &mut V {
match &mut self.0 {
DynArrayHeapMapInner::Dim1(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim2(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim3(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim4(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim5(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim6(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim7(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim8(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim9(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim10(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim11(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim12(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim13(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim14(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim15(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim16(map) => map.get_or_insert_with(key, f),
}
}
/// Returns a mutable reference to the value corresponding to `key`.
///
/// Panics if the length of `key` does not match the key dimension of the map.
pub fn get_mut(&mut self, key: &[K]) -> Option<&mut V> {
match &mut self.0 {
DynArrayHeapMapInner::Dim1(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim2(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim3(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim4(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim5(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim6(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim7(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim8(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim9(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim10(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim11(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim12(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim13(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim14(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim15(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim16(map) => map.get_mut(key),
}
}
/// Returns a reference to the highest key in the map.
pub(super) fn peek_highest(&self) -> Option<&[K]> {
match &self.0 {
DynArrayHeapMapInner::Dim1(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim2(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim3(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim4(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim5(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim6(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim7(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim8(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim9(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim10(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim11(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim12(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim13(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim14(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim15(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim16(map) => map.peek_highest(),
}
}
/// Removes the entry with the highest key from the map.
pub(super) fn evict_highest(&mut self) {
match &mut self.0 {
DynArrayHeapMapInner::Dim1(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim2(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim3(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim4(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim5(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim6(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim7(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim8(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim9(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim10(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim11(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim12(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim13(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim14(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim15(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim16(map) => map.evict_highest(),
}
}
pub(crate) fn memory_consumption(&self) -> u64 {
match &self.0 {
DynArrayHeapMapInner::Dim1(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim2(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim3(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim4(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim5(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim6(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim7(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim8(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim9(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim10(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim11(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim12(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim13(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim14(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim15(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim16(map) => map.memory_consumption(),
}
}
}
impl<K: Ord + Clone + Copy + 'static, V: 'static> DynArrayHeapMap<K, V> {
/// Turns this map into an iterator over key-value pairs.
pub fn into_iter(self) -> impl Iterator<Item = (SmallVec<[K; MAX_DYN_ARRAY_SIZE]>, V)> {
match self.0 {
DynArrayHeapMapInner::Dim1(map) => map.into_iter(),
DynArrayHeapMapInner::Dim2(map) => map.into_iter(),
DynArrayHeapMapInner::Dim3(map) => map.into_iter(),
DynArrayHeapMapInner::Dim4(map) => map.into_iter(),
DynArrayHeapMapInner::Dim5(map) => map.into_iter(),
DynArrayHeapMapInner::Dim6(map) => map.into_iter(),
DynArrayHeapMapInner::Dim7(map) => map.into_iter(),
DynArrayHeapMapInner::Dim8(map) => map.into_iter(),
DynArrayHeapMapInner::Dim9(map) => map.into_iter(),
DynArrayHeapMapInner::Dim10(map) => map.into_iter(),
DynArrayHeapMapInner::Dim11(map) => map.into_iter(),
DynArrayHeapMapInner::Dim12(map) => map.into_iter(),
DynArrayHeapMapInner::Dim13(map) => map.into_iter(),
DynArrayHeapMapInner::Dim14(map) => map.into_iter(),
DynArrayHeapMapInner::Dim15(map) => map.into_iter(),
DynArrayHeapMapInner::Dim16(map) => map.into_iter(),
}
}
/// Returns an iterator over mutable references to the values in the map.
pub(super) fn values_mut(&mut self) -> impl Iterator<Item = &mut V> {
match &mut self.0 {
DynArrayHeapMapInner::Dim1(map) => map.values_mut(),
DynArrayHeapMapInner::Dim2(map) => map.values_mut(),
DynArrayHeapMapInner::Dim3(map) => map.values_mut(),
DynArrayHeapMapInner::Dim4(map) => map.values_mut(),
DynArrayHeapMapInner::Dim5(map) => map.values_mut(),
DynArrayHeapMapInner::Dim6(map) => map.values_mut(),
DynArrayHeapMapInner::Dim7(map) => map.values_mut(),
DynArrayHeapMapInner::Dim8(map) => map.values_mut(),
DynArrayHeapMapInner::Dim9(map) => map.values_mut(),
DynArrayHeapMapInner::Dim10(map) => map.values_mut(),
DynArrayHeapMapInner::Dim11(map) => map.values_mut(),
DynArrayHeapMapInner::Dim12(map) => map.values_mut(),
DynArrayHeapMapInner::Dim13(map) => map.values_mut(),
DynArrayHeapMapInner::Dim14(map) => map.values_mut(),
DynArrayHeapMapInner::Dim15(map) => map.values_mut(),
DynArrayHeapMapInner::Dim16(map) => map.values_mut(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dyn_array_heap_map() {
let mut map = DynArrayHeapMap::<u32, &str>::try_new(2).unwrap();
// insert
let key1 = [1u32, 2u32];
let key2 = [2u32, 1u32];
map.get_or_insert_with(&key1, || "a");
map.get_or_insert_with(&key2, || "b");
assert_eq!(map.size(), 2);
// evict highest
assert_eq!(map.peek_highest(), Some(&key2[..]));
map.evict_highest();
assert_eq!(map.size(), 1);
assert_eq!(map.peek_highest(), Some(&key1[..]));
// mutable iterator
{
let mut mut_iter = map.values_mut();
let v = mut_iter.next().unwrap();
assert_eq!(*v, "a");
*v = "c";
assert_eq!(mut_iter.next(), None);
}
// into_iter
let mut iter = map.into_iter();
let (k, v) = iter.next().unwrap();
assert_eq!(k.as_slice(), &key1);
assert_eq!(v, "c");
assert_eq!(iter.next(), None);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,460 @@
/// This modules helps comparing numerical values of different types (i64, u64
/// and f64).
pub(super) mod num_cmp {
use std::cmp::Ordering;
use crate::TantivyError;
pub fn cmp_i64_f64(left_i: i64, right_f: f64) -> crate::Result<Ordering> {
if right_f.is_nan() {
return Err(TantivyError::InvalidArgument(
"NaN comparison is not supported".to_string(),
));
}
// If right_f is < i64::MIN then left_i > right_f (i64::MIN=-2^63 can be
// exactly represented as f64)
if right_f < i64::MIN as f64 {
return Ok(Ordering::Greater);
}
// If right_f is >= i64::MAX then left_i < right_f (i64::MAX=2^63-1 cannot
// be exactly represented as f64)
if right_f >= i64::MAX as f64 {
return Ok(Ordering::Less);
}
// Now right_f is in (i64::MIN, i64::MAX), so `right_f as i64` is
// well-defined (truncation toward 0)
let right_as_i = right_f as i64;
let result = match left_i.cmp(&right_as_i) {
Ordering::Less => Ordering::Less,
Ordering::Greater => Ordering::Greater,
Ordering::Equal => {
// they have the same integer part, compare the fraction
let rem = right_f - (right_as_i as f64);
if rem == 0.0 {
Ordering::Equal
} else if right_f > 0.0 {
Ordering::Less
} else {
Ordering::Greater
}
}
};
Ok(result)
}
pub fn cmp_u64_f64(left_u: u64, right_f: f64) -> crate::Result<Ordering> {
if right_f.is_nan() {
return Err(TantivyError::InvalidArgument(
"NaN comparison is not supported".to_string(),
));
}
// Negative floats are always less than any u64 >= 0
if right_f < 0.0 {
return Ok(Ordering::Greater);
}
// If right_f is >= u64::MAX then left_u < right_f (u64::MAX=2^64-1 cannot be exactly)
let max_as_f = u64::MAX as f64;
if right_f > max_as_f {
return Ok(Ordering::Less);
}
// Now right_f is in (0, u64::MAX), so `right_f as u64` is well-defined
// (truncation toward 0)
let right_as_u = right_f as u64;
let result = match left_u.cmp(&right_as_u) {
Ordering::Less => Ordering::Less,
Ordering::Greater => Ordering::Greater,
Ordering::Equal => {
// they have the same integer part, compare the fraction
let rem = right_f - (right_as_u as f64);
if rem == 0.0 {
Ordering::Equal
} else {
Ordering::Less
}
}
};
Ok(result)
}
pub fn cmp_i64_u64(left_i: i64, right_u: u64) -> Ordering {
if left_i < 0 {
Ordering::Less
} else {
let left_as_u = left_i as u64;
left_as_u.cmp(&right_u)
}
}
}
/// This modules helps projecting numerical values to other numerical types.
/// When the target value space cannot exactly represent the source value, the
/// next representable value is returned (or AfterLast if the source value is
/// larger than the largest representable value).
///
/// All functions in this module assume that f64 values are not NaN.
pub(super) mod num_proj {
#[derive(Debug, PartialEq)]
pub enum ProjectedNumber<T> {
Exact(T),
Next(T),
AfterLast,
}
pub fn i64_to_u64(value: i64) -> ProjectedNumber<u64> {
if value < 0 {
ProjectedNumber::Next(0)
} else {
ProjectedNumber::Exact(value as u64)
}
}
pub fn u64_to_i64(value: u64) -> ProjectedNumber<i64> {
if value > i64::MAX as u64 {
ProjectedNumber::AfterLast
} else {
ProjectedNumber::Exact(value as i64)
}
}
pub fn f64_to_u64(value: f64) -> ProjectedNumber<u64> {
if value < 0.0 {
ProjectedNumber::Next(0)
} else if value > u64::MAX as f64 {
ProjectedNumber::AfterLast
} else if value.fract() == 0.0 {
ProjectedNumber::Exact(value as u64)
} else {
// casting f64 to u64 truncates toward zero
ProjectedNumber::Next(value as u64 + 1)
}
}
pub fn f64_to_i64(value: f64) -> ProjectedNumber<i64> {
if value < (i64::MIN as f64) {
return ProjectedNumber::Next(i64::MIN);
} else if value >= (i64::MAX as f64) {
return ProjectedNumber::AfterLast;
} else if value.fract() == 0.0 {
ProjectedNumber::Exact(value as i64)
} else if value > 0.0 {
// casting f64 to i64 truncates toward zero
ProjectedNumber::Next(value as i64 + 1)
} else {
ProjectedNumber::Next(value as i64)
}
}
pub fn i64_to_f64(value: i64) -> ProjectedNumber<f64> {
let value_f = value as f64;
let k_roundtrip = value_f as i64;
if k_roundtrip == value {
// between -2^53 and 2^53 all i64 are exactly represented as f64
ProjectedNumber::Exact(value_f)
} else {
// for very large/small i64 values, it is approximated to the closest f64
if k_roundtrip > value {
ProjectedNumber::Next(value_f)
} else {
ProjectedNumber::Next(value_f.next_up())
}
}
}
pub fn u64_to_f64(value: u64) -> ProjectedNumber<f64> {
let value_f = value as f64;
let k_roundtrip = value_f as u64;
if k_roundtrip == value {
// between 0 and 2^53 all u64 are exactly represented as f64
ProjectedNumber::Exact(value_f)
} else if k_roundtrip > value {
ProjectedNumber::Next(value_f)
} else {
ProjectedNumber::Next(value_f.next_up())
}
}
}
#[cfg(test)]
mod num_cmp_tests {
use std::cmp::Ordering;
use super::num_cmp::*;
#[test]
fn test_cmp_u64_f64() {
// Basic comparisons
assert_eq!(cmp_u64_f64(5, 5.0).unwrap(), Ordering::Equal);
assert_eq!(cmp_u64_f64(5, 6.0).unwrap(), Ordering::Less);
assert_eq!(cmp_u64_f64(6, 5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_u64_f64(0, 0.0).unwrap(), Ordering::Equal);
assert_eq!(cmp_u64_f64(0, 0.1).unwrap(), Ordering::Less);
// Negative float values should always be less than any u64
assert_eq!(cmp_u64_f64(0, -0.1).unwrap(), Ordering::Greater);
assert_eq!(cmp_u64_f64(5, -5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_u64_f64(u64::MAX, -1e20).unwrap(), Ordering::Greater);
// Tests with extreme values
assert_eq!(cmp_u64_f64(u64::MAX, 1e20).unwrap(), Ordering::Less);
// Precision edge cases: large u64 that loses precision when converted to f64
// => 2^54, exactly represented as f64
let large_f64 = 18_014_398_509_481_984.0;
let large_u64 = 18_014_398_509_481_984;
// prove that large_u64 is exactly represented as f64
assert_eq!(large_u64 as f64, large_f64);
assert_eq!(cmp_u64_f64(large_u64, large_f64).unwrap(), Ordering::Equal);
// => (2^54 + 1) cannot be exactly represented in f64
let large_u64_plus_1 = 18_014_398_509_481_985;
// prove that it is represented as f64 by large_f64
assert_eq!(large_u64_plus_1 as f64, large_f64);
assert_eq!(
cmp_u64_f64(large_u64_plus_1, large_f64).unwrap(),
Ordering::Greater
);
// => (2^54 - 1) cannot be exactly represented in f64
let large_u64_minus_1 = 18_014_398_509_481_983;
// prove that it is also represented as f64 by large_f64
assert_eq!(large_u64_minus_1 as f64, large_f64);
assert_eq!(
cmp_u64_f64(large_u64_minus_1, large_f64).unwrap(),
Ordering::Less
);
// NaN comparison results in an error
assert!(cmp_u64_f64(0, f64::NAN).is_err());
}
#[test]
fn test_cmp_i64_f64() {
// Basic comparisons
assert_eq!(cmp_i64_f64(5, 5.0).unwrap(), Ordering::Equal);
assert_eq!(cmp_i64_f64(5, 6.0).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(6, 5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_i64_f64(-5, -5.0).unwrap(), Ordering::Equal);
assert_eq!(cmp_i64_f64(-5, -4.0).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(-4, -5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_i64_f64(-5, 5.0).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(5, -5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_i64_f64(0, -0.1).unwrap(), Ordering::Greater);
assert_eq!(cmp_i64_f64(0, 0.1).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(-1, -0.5).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(-1, 0.0).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(0, 0.0).unwrap(), Ordering::Equal);
// Tests with extreme values
assert_eq!(cmp_i64_f64(i64::MAX, 1e20).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(i64::MIN, -1e20).unwrap(), Ordering::Greater);
// Precision edge cases: large i64 that loses precision when converted to f64
// => 2^54, exactly represented as f64
let large_f64 = 18_014_398_509_481_984.0;
let large_i64 = 18_014_398_509_481_984;
// prove that large_i64 is exactly represented as f64
assert_eq!(large_i64 as f64, large_f64);
assert_eq!(cmp_i64_f64(large_i64, large_f64).unwrap(), Ordering::Equal);
// => (1_i64 << 54) + 1 cannot be exactly represented in f64
let large_i64_plus_1 = 18_014_398_509_481_985;
// prove that it is represented as f64 by large_f64
assert_eq!(large_i64_plus_1 as f64, large_f64);
assert_eq!(
cmp_i64_f64(large_i64_plus_1, large_f64).unwrap(),
Ordering::Greater
);
// => (1_i64 << 54) - 1 cannot be exactly represented in f64
let large_i64_minus_1 = 18_014_398_509_481_983;
// prove that it is also represented as f64 by large_f64
assert_eq!(large_i64_minus_1 as f64, large_f64);
assert_eq!(
cmp_i64_f64(large_i64_minus_1, large_f64).unwrap(),
Ordering::Less
);
// Same precision edge case but with negative values
// => -2^54, exactly represented as f64
let large_neg_f64 = -18_014_398_509_481_984.0;
let large_neg_i64 = -18_014_398_509_481_984;
// prove that large_neg_i64 is exactly represented as f64
assert_eq!(large_neg_i64 as f64, large_neg_f64);
assert_eq!(
cmp_i64_f64(large_neg_i64, large_neg_f64).unwrap(),
Ordering::Equal
);
// => (-2^54 + 1) cannot be exactly represented in f64
let large_neg_i64_plus_1 = -18_014_398_509_481_985;
// prove that it is represented as f64 by large_neg_f64
assert_eq!(large_neg_i64_plus_1 as f64, large_neg_f64);
assert_eq!(
cmp_i64_f64(large_neg_i64_plus_1, large_neg_f64).unwrap(),
Ordering::Less
);
// => (-2^54 - 1) cannot be exactly represented in f64
let large_neg_i64_minus_1 = -18_014_398_509_481_983;
// prove that it is also represented as f64 by large_neg_f64
assert_eq!(large_neg_i64_minus_1 as f64, large_neg_f64);
assert_eq!(
cmp_i64_f64(large_neg_i64_minus_1, large_neg_f64).unwrap(),
Ordering::Greater
);
// NaN comparison results in an error
assert!(cmp_i64_f64(0, f64::NAN).is_err());
}
#[test]
fn test_cmp_i64_u64() {
// Test with negative i64 values (should always be less than any u64)
assert_eq!(cmp_i64_u64(-1, 0), Ordering::Less);
assert_eq!(cmp_i64_u64(i64::MIN, 0), Ordering::Less);
assert_eq!(cmp_i64_u64(i64::MIN, u64::MAX), Ordering::Less);
// Test with positive i64 values
assert_eq!(cmp_i64_u64(0, 0), Ordering::Equal);
assert_eq!(cmp_i64_u64(1, 0), Ordering::Greater);
assert_eq!(cmp_i64_u64(1, 1), Ordering::Equal);
assert_eq!(cmp_i64_u64(0, 1), Ordering::Less);
assert_eq!(cmp_i64_u64(5, 10), Ordering::Less);
assert_eq!(cmp_i64_u64(10, 5), Ordering::Greater);
// Test with values near i64::MAX and u64 conversion
assert_eq!(cmp_i64_u64(i64::MAX, i64::MAX as u64), Ordering::Equal);
assert_eq!(cmp_i64_u64(i64::MAX, (i64::MAX as u64) + 1), Ordering::Less);
assert_eq!(cmp_i64_u64(i64::MAX, u64::MAX), Ordering::Less);
}
}
#[cfg(test)]
mod num_proj_tests {
use super::num_proj::{self, ProjectedNumber};
#[test]
fn test_i64_to_u64() {
assert_eq!(num_proj::i64_to_u64(-1), ProjectedNumber::Next(0));
assert_eq!(num_proj::i64_to_u64(i64::MIN), ProjectedNumber::Next(0));
assert_eq!(num_proj::i64_to_u64(0), ProjectedNumber::Exact(0));
assert_eq!(num_proj::i64_to_u64(42), ProjectedNumber::Exact(42));
assert_eq!(
num_proj::i64_to_u64(i64::MAX),
ProjectedNumber::Exact(i64::MAX as u64)
);
}
#[test]
fn test_u64_to_i64() {
assert_eq!(num_proj::u64_to_i64(0), ProjectedNumber::Exact(0));
assert_eq!(num_proj::u64_to_i64(42), ProjectedNumber::Exact(42));
assert_eq!(
num_proj::u64_to_i64(i64::MAX as u64),
ProjectedNumber::Exact(i64::MAX)
);
assert_eq!(
num_proj::u64_to_i64((i64::MAX as u64) + 1),
ProjectedNumber::AfterLast
);
assert_eq!(num_proj::u64_to_i64(u64::MAX), ProjectedNumber::AfterLast);
}
#[test]
fn test_f64_to_u64() {
assert_eq!(num_proj::f64_to_u64(-1e25), ProjectedNumber::Next(0));
assert_eq!(num_proj::f64_to_u64(-0.1), ProjectedNumber::Next(0));
assert_eq!(num_proj::f64_to_u64(1e20), ProjectedNumber::AfterLast);
assert_eq!(
num_proj::f64_to_u64(f64::INFINITY),
ProjectedNumber::AfterLast
);
assert_eq!(num_proj::f64_to_u64(0.0), ProjectedNumber::Exact(0));
assert_eq!(num_proj::f64_to_u64(42.0), ProjectedNumber::Exact(42));
assert_eq!(num_proj::f64_to_u64(0.5), ProjectedNumber::Next(1));
assert_eq!(num_proj::f64_to_u64(42.1), ProjectedNumber::Next(43));
}
#[test]
fn test_f64_to_i64() {
assert_eq!(num_proj::f64_to_i64(-1e20), ProjectedNumber::Next(i64::MIN));
assert_eq!(
num_proj::f64_to_i64(f64::NEG_INFINITY),
ProjectedNumber::Next(i64::MIN)
);
assert_eq!(num_proj::f64_to_i64(1e20), ProjectedNumber::AfterLast);
assert_eq!(
num_proj::f64_to_i64(f64::INFINITY),
ProjectedNumber::AfterLast
);
assert_eq!(num_proj::f64_to_i64(0.0), ProjectedNumber::Exact(0));
assert_eq!(num_proj::f64_to_i64(42.0), ProjectedNumber::Exact(42));
assert_eq!(num_proj::f64_to_i64(-42.0), ProjectedNumber::Exact(-42));
assert_eq!(num_proj::f64_to_i64(0.5), ProjectedNumber::Next(1));
assert_eq!(num_proj::f64_to_i64(42.1), ProjectedNumber::Next(43));
assert_eq!(num_proj::f64_to_i64(-0.5), ProjectedNumber::Next(0));
assert_eq!(num_proj::f64_to_i64(-42.1), ProjectedNumber::Next(-42));
}
#[test]
fn test_i64_to_f64() {
assert_eq!(num_proj::i64_to_f64(0), ProjectedNumber::Exact(0.0));
assert_eq!(num_proj::i64_to_f64(42), ProjectedNumber::Exact(42.0));
assert_eq!(num_proj::i64_to_f64(-42), ProjectedNumber::Exact(-42.0));
let max_exact = 9_007_199_254_740_992; // 2^53
assert_eq!(
num_proj::i64_to_f64(max_exact),
ProjectedNumber::Exact(max_exact as f64)
);
// Test values that cannot be exactly represented as f64 (integers above 2^53)
let large_i64 = 9_007_199_254_740_993; // 2^53 + 1
let closest_f64 = 9_007_199_254_740_992.0;
assert_eq!(large_i64 as f64, closest_f64);
if let ProjectedNumber::Next(val) = num_proj::i64_to_f64(large_i64) {
// Verify that the returned float is different from the direct cast
assert!(val > closest_f64);
assert!(val - closest_f64 < 2. * f64::EPSILON * closest_f64);
} else {
panic!("Expected ProjectedNumber::Next for large_i64");
}
// Test with very large negative value
let large_neg_i64 = -9_007_199_254_740_993; // -(2^53 + 1)
let closest_neg_f64 = -9_007_199_254_740_992.0;
assert_eq!(large_neg_i64 as f64, closest_neg_f64);
if let ProjectedNumber::Next(val) = num_proj::i64_to_f64(large_neg_i64) {
// Verify that the returned float is the closest representable f64
assert_eq!(val, closest_neg_f64);
} else {
panic!("Expected ProjectedNumber::Next for large_neg_i64");
}
}
#[test]
fn test_u64_to_f64() {
assert_eq!(num_proj::u64_to_f64(0), ProjectedNumber::Exact(0.0));
assert_eq!(num_proj::u64_to_f64(42), ProjectedNumber::Exact(42.0));
// Test the largest u64 value that can be exactly represented as f64 (2^53)
let max_exact = 9_007_199_254_740_992; // 2^53
assert_eq!(
num_proj::u64_to_f64(max_exact),
ProjectedNumber::Exact(max_exact as f64)
);
// Test values that cannot be exactly represented as f64 (integers above 2^53)
let large_u64 = 9_007_199_254_740_993; // 2^53 + 1
let closest_f64 = 9_007_199_254_740_992.0;
assert_eq!(large_u64 as f64, closest_f64);
if let ProjectedNumber::Next(val) = num_proj::u64_to_f64(large_u64) {
// Verify that the returned float is different from the direct cast
assert!(val > closest_f64);
assert!(val - closest_f64 < 2. * f64::EPSILON * closest_f64);
} else {
panic!("Expected ProjectedNumber::Next for large_u64");
}
}
}

View File

@@ -207,7 +207,7 @@ fn parse_offset_into_milliseconds(input: &str) -> Result<i64, AggregationError>
}
}
fn parse_into_milliseconds(input: &str) -> Result<i64, AggregationError> {
pub(crate) fn parse_into_milliseconds(input: &str) -> Result<i64, AggregationError> {
let split_boundary = input
.as_bytes()
.iter()

View File

@@ -22,6 +22,7 @@
//! - [Range](RangeAggregation)
//! - [Terms](TermsAggregation)
mod composite;
mod filter;
mod histogram;
mod range;
@@ -31,6 +32,7 @@ mod term_missing_agg;
use std::collections::HashMap;
use std::fmt;
pub use composite::*;
pub use filter::*;
pub use histogram::*;
pub use range::*;

View File

@@ -25,9 +25,12 @@ use super::metric::{
use super::segment_agg_result::AggregationLimitsGuard;
use super::{format_date, AggregationError, Key, SerializedKey};
use crate::aggregation::agg_result::{
AggregationResults, BucketEntries, BucketEntry, FilterBucketResult,
AggregationResults, BucketEntries, BucketEntry, CompositeBucketEntry, FilterBucketResult,
};
use crate::aggregation::bucket::{
composite_intermediate_key_ordering, CompositeAggregation, MissingOrder,
TermsAggregationInternal,
};
use crate::aggregation::bucket::TermsAggregationInternal;
use crate::aggregation::metric::CardinalityCollector;
use crate::TantivyError;
@@ -90,6 +93,19 @@ impl From<IntermediateKey> for Key {
impl Eq for IntermediateKey {}
impl std::fmt::Display for IntermediateKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
IntermediateKey::Str(val) => f.write_str(val),
IntermediateKey::F64(val) => f.write_str(&val.to_string()),
IntermediateKey::U64(val) => f.write_str(&val.to_string()),
IntermediateKey::I64(val) => f.write_str(&val.to_string()),
IntermediateKey::Bool(val) => f.write_str(&val.to_string()),
IntermediateKey::IpAddr(val) => f.write_str(&val.to_string()),
}
}
}
impl std::hash::Hash for IntermediateKey {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
@@ -105,6 +121,21 @@ impl std::hash::Hash for IntermediateKey {
}
impl IntermediateAggregationResults {
/// Returns a reference to the intermediate aggregation result for the given key.
pub fn get(&self, key: &str) -> Option<&IntermediateAggregationResult> {
self.aggs_res.get(key)
}
/// Removes and returns the intermediate aggregation result for the given key.
pub fn remove(&mut self, key: &str) -> Option<IntermediateAggregationResult> {
self.aggs_res.remove(key)
}
/// Returns an iterator over the keys in the intermediate aggregation results.
pub fn keys(&self) -> impl Iterator<Item = &String> {
self.aggs_res.keys()
}
/// Add a result
pub fn push(&mut self, key: String, value: IntermediateAggregationResult) -> crate::Result<()> {
let entry = self.aggs_res.entry(key);
@@ -218,6 +249,11 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult
is_date_agg: true,
})
}
Composite(_) => {
IntermediateAggregationResult::Bucket(IntermediateBucketResult::Composite {
buckets: Default::default(),
})
}
Average(_) => IntermediateAggregationResult::Metric(IntermediateMetricResult::Average(
IntermediateAverage::default(),
)),
@@ -445,6 +481,11 @@ pub enum IntermediateBucketResult {
/// Sub-aggregation results
sub_aggregations: IntermediateAggregationResults,
},
/// Composite aggregation
Composite {
/// The composite buckets
buckets: IntermediateCompositeBucketResult,
},
}
impl IntermediateBucketResult {
@@ -540,6 +581,13 @@ impl IntermediateBucketResult {
sub_aggregations: final_sub_aggregations,
}))
}
IntermediateBucketResult::Composite { buckets } => buckets.into_final_result(
req.agg
.as_composite()
.expect("unexpected aggregation, expected composite aggregation"),
req.sub_aggregation(),
limits,
),
}
}
@@ -606,6 +654,16 @@ impl IntermediateBucketResult {
*doc_count_left += doc_count_right;
sub_aggs_left.merge_fruits(sub_aggs_right)?;
}
(
IntermediateBucketResult::Composite {
buckets: buckets_left,
},
IntermediateBucketResult::Composite {
buckets: buckets_right,
},
) => {
buckets_left.merge_fruits(buckets_right)?;
}
(IntermediateBucketResult::Range(_), _) => {
panic!("try merge on different types")
}
@@ -618,6 +676,9 @@ impl IntermediateBucketResult {
(IntermediateBucketResult::Filter { .. }, _) => {
panic!("try merge on different types")
}
(IntermediateBucketResult::Composite { .. }, _) => {
panic!("try merge on different types")
}
}
Ok(())
}
@@ -639,6 +700,21 @@ pub struct IntermediateTermBucketResult {
}
impl IntermediateTermBucketResult {
/// Returns a reference to the map of bucket entries keyed by [`IntermediateKey`].
pub fn entries(&self) -> &FxHashMap<IntermediateKey, IntermediateTermBucketEntry> {
&self.entries
}
/// Returns the count of documents not included in the returned buckets.
pub fn sum_other_doc_count(&self) -> u64 {
self.sum_other_doc_count
}
/// Returns the upper bound of the error on document counts in the returned buckets.
pub fn doc_count_error_upper_bound(&self) -> u64 {
self.doc_count_error_upper_bound
}
pub(crate) fn into_final_result(
self,
req: &TermsAggregation,
@@ -820,7 +896,7 @@ impl IntermediateRangeBucketEntry {
};
// If we have a date type on the histogram buckets, we add the `key_as_string` field as
// rfc339
// rfc3339
if column_type == Some(ColumnType::DateTime) {
if let Some(val) = range_bucket_entry.to {
let key_as_string = format_date(val as i64)?;
@@ -846,6 +922,212 @@ pub struct IntermediateTermBucketEntry {
pub sub_aggregation: IntermediateAggregationResults,
}
/// Entry for the composite bucket.
pub type IntermediateCompositeBucketEntry = IntermediateTermBucketEntry;
/// The fully typed key for composite aggregation
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum CompositeIntermediateKey {
/// Bool key
Bool(bool),
/// String key
Str(String),
/// Float key
F64(f64),
/// Signed integer key
I64(i64),
/// Unsigned integer key
U64(u64),
/// DateTime key, nanoseconds since epoch
DateTime(i64),
/// IP Address key
IpAddr(Ipv6Addr),
/// Missing value key
Null,
}
impl Eq for CompositeIntermediateKey {}
impl std::hash::Hash for CompositeIntermediateKey {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
match self {
CompositeIntermediateKey::Bool(val) => val.hash(state),
CompositeIntermediateKey::Str(text) => text.hash(state),
CompositeIntermediateKey::F64(val) => val.to_bits().hash(state),
CompositeIntermediateKey::U64(val) => val.hash(state),
CompositeIntermediateKey::I64(val) => val.hash(state),
CompositeIntermediateKey::DateTime(val) => val.hash(state),
CompositeIntermediateKey::IpAddr(val) => val.hash(state),
CompositeIntermediateKey::Null => {}
}
}
}
/// Composite aggregation page.
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateCompositeBucketResult {
#[serde(
serialize_with = "serialize_composite_entries",
deserialize_with = "deserialize_composite_entries"
)]
pub(crate) entries: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry>,
pub(crate) target_size: u32,
pub(crate) orders: Vec<(Order, MissingOrder)>,
}
fn serialize_composite_entries<S>(
entries: &FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeSeq;
let mut seq = serializer.serialize_seq(Some(entries.len()))?;
for (k, v) in entries {
seq.serialize_element(&(k, v))?;
}
seq.end()
}
fn deserialize_composite_entries<'de, D>(
deserializer: D,
) -> Result<FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry>, D::Error>
where
D: serde::Deserializer<'de>,
{
let vec: Vec<(Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry)> =
serde::Deserialize::deserialize(deserializer)?;
Ok(vec.into_iter().collect())
}
impl IntermediateCompositeBucketResult {
pub(crate) fn into_final_result(
self,
req: &CompositeAggregation,
sub_aggregation_req: &Aggregations,
limits: &mut AggregationLimitsGuard,
) -> crate::Result<BucketResult> {
let trimmed_entry_vec =
trim_composite_buckets(self.entries, &self.orders, self.target_size)?;
let after_key = if trimmed_entry_vec.len() == req.size as usize {
trimmed_entry_vec
.last()
.map(|bucket| {
let (intermediate_key, _entry) = bucket;
intermediate_key
.iter()
.enumerate()
.map(|(idx, intermediate_key)| {
let source = &req.sources[idx];
(source.name().to_string(), intermediate_key.clone().into())
})
.collect()
})
.unwrap()
} else {
FxHashMap::default()
};
let buckets = trimmed_entry_vec
.into_iter()
.map(|(intermediate_key, entry)| {
let key = intermediate_key
.into_iter()
.enumerate()
.map(|(idx, intermediate_key)| {
let source = &req.sources[idx];
(source.name().to_string(), intermediate_key.into())
})
.collect();
Ok(CompositeBucketEntry {
key,
doc_count: entry.doc_count as u64,
sub_aggregation: entry
.sub_aggregation
.into_final_result_internal(sub_aggregation_req, limits)?,
})
})
.collect::<crate::Result<Vec<_>>>()?;
Ok(BucketResult::Composite { after_key, buckets })
}
fn merge_fruits(&mut self, other: IntermediateCompositeBucketResult) -> crate::Result<()> {
merge_maps(&mut self.entries, other.entries)?;
if self.entries.len() as u32 > 2 * self.target_size {
// 2x factor used to avoid trimming too often (expensive operation)
// an optimal threshold could probably be figured out
self.trim()?;
}
Ok(())
}
/// Trim the composite buckets to the target size, according to the ordering.
///
/// Returns an error if the ordering comparison fails.
pub(crate) fn trim(&mut self) -> crate::Result<()> {
if self.entries.len() as u32 <= self.target_size {
return Ok(());
}
let sorted_entries = trim_composite_buckets(
std::mem::take(&mut self.entries),
&self.orders,
self.target_size,
)?;
self.entries = sorted_entries.into_iter().collect();
Ok(())
}
}
fn trim_composite_buckets(
entries: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry>,
orders: &[(Order, MissingOrder)],
target_size: u32,
) -> crate::Result<
Vec<(
Vec<CompositeIntermediateKey>,
IntermediateCompositeBucketEntry,
)>,
> {
let mut entries: Vec<_> = entries.into_iter().collect();
let mut sort_error: Option<TantivyError> = None;
entries.sort_by(|(left_key, _), (right_key, _)| {
// Only attempt sorting if we haven't encountered an error yet
if sort_error.is_some() {
return Ordering::Equal; // Return a default, we'll handle the error after sorting
}
for i in 0..orders.len() {
match composite_intermediate_key_ordering(
&left_key[i],
&right_key[i],
orders[i].0,
orders[i].1,
) {
Ok(ordering) if ordering != Ordering::Equal => return ordering,
Ok(_) => continue, // Equal, try next key
Err(err) => {
sort_error = Some(err);
break;
}
}
}
Ordering::Equal
});
// If we encountered an error during sorting, return it now
if let Some(err) = sort_error {
return Err(err);
}
entries.truncate(target_size as usize);
Ok(entries)
}
impl MergeFruits for IntermediateTermBucketEntry {
fn merge_fruits(&mut self, other: IntermediateTermBucketEntry) -> crate::Result<()> {
self.doc_count += other.doc_count;

View File

@@ -55,6 +55,12 @@ impl IntermediateAverage {
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Returns a reference to the underlying [`IntermediateStats`].
pub fn stats(&self) -> &IntermediateStats {
&self.stats
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateAverage) {
self.stats.merge_fruits(other.stats);

View File

@@ -1,12 +1,11 @@
use std::collections::hash_map::DefaultHasher;
use std::hash::{BuildHasher, Hasher};
use std::hash::Hash;
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{Column, ColumnType, Dictionary, StrColumn};
use common::f64_to_u64;
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
use datasketches::hll::{HllSketch, HllType, HllUnion};
use rustc_hash::FxHashSet;
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
@@ -16,29 +15,17 @@ use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
#[derive(Clone, Debug, Serialize, Deserialize)]
struct BuildSaltedHasher {
salt: u8,
}
impl BuildHasher for BuildSaltedHasher {
type Hasher = DefaultHasher;
fn build_hasher(&self) -> Self::Hasher {
let mut hasher = DefaultHasher::new();
hasher.write_u8(self.salt);
hasher
}
}
/// Log2 of the number of registers for the HLL sketch.
/// 2^11 = 2048 registers, giving ~2.3% relative error and ~1KB per sketch (Hll4).
const LG_K: u8 = 11;
/// # Cardinality
///
/// The cardinality aggregation allows for computing an estimate
/// of the number of different values in a data set based on the
/// HyperLogLog++ algorithm. This is particularly useful for understanding the
/// uniqueness of values in a large dataset where counting each unique value
/// individually would be computationally expensive.
/// Apache DataSketches HyperLogLog algorithm. This is particularly useful for
/// understanding the uniqueness of values in a large dataset where counting
/// each unique value individually would be computationally expensive.
///
/// For example, you might use a cardinality aggregation to estimate the number
/// of unique visitors to a website by aggregating on a field that contains
@@ -184,7 +171,7 @@ impl SegmentCardinalityCollectorBucket {
term_ids.sort_unstable();
dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| {
self.cardinality.sketch.insert_any(&term);
self.cardinality.insert(term);
Ok(())
})?;
if has_missing {
@@ -195,17 +182,17 @@ impl SegmentCardinalityCollectorBucket {
);
match missing_key {
Key::Str(missing) => {
self.cardinality.sketch.insert_any(&missing);
self.cardinality.insert(missing.as_str());
}
Key::F64(val) => {
let val = f64_to_u64(*val);
self.cardinality.sketch.insert_any(&val);
self.cardinality.insert(val);
}
Key::U64(val) => {
self.cardinality.sketch.insert_any(&val);
self.cardinality.insert(*val);
}
Key::I64(val) => {
self.cardinality.sketch.insert_any(&val);
self.cardinality.insert(*val);
}
}
}
@@ -296,11 +283,11 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
})?;
for val in col_block_accessor.iter_vals() {
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
bucket.cardinality.sketch.insert_any(&val);
bucket.cardinality.insert(val);
}
} else {
for val in col_block_accessor.iter_vals() {
bucket.cardinality.sketch.insert_any(&val);
bucket.cardinality.insert(val);
}
}
@@ -321,11 +308,18 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
/// The percentiles collector used during segment collection and for merging results.
#[derive(Clone, Debug)]
/// The cardinality collector used during segment collection and for merging results.
/// Uses Apache DataSketches HLL (lg_k=11, Hll4) for compact binary serialization
/// and cross-language compatibility (e.g. Java `datasketches` library).
pub struct CardinalityCollector {
sketch: HyperLogLogPlus<u64, BuildSaltedHasher>,
sketch: HllSketch,
/// Salt derived from `ColumnType`, used to differentiate values of different column types
/// that map to the same u64 (e.g. bool `false` = 0 vs i64 `0`).
/// Not serialized — only needed during insertion, not after sketch registers are populated.
salt: u8,
}
impl Default for CardinalityCollector {
fn default() -> Self {
Self::new(0)
@@ -338,25 +332,52 @@ impl PartialEq for CardinalityCollector {
}
}
impl CardinalityCollector {
/// Compute the final cardinality estimate.
pub fn finalize(self) -> Option<f64> {
Some(self.sketch.clone().count().trunc())
impl Serialize for CardinalityCollector {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let bytes = self.sketch.serialize();
serializer.serialize_bytes(&bytes)
}
}
impl<'de> Deserialize<'de> for CardinalityCollector {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let bytes: Vec<u8> = Deserialize::deserialize(deserializer)?;
let sketch = HllSketch::deserialize(&bytes).map_err(serde::de::Error::custom)?;
Ok(Self { sketch, salt: 0 })
}
}
impl CardinalityCollector {
fn new(salt: u8) -> Self {
Self {
sketch: HyperLogLogPlus::new(16, BuildSaltedHasher { salt }).unwrap(),
sketch: HllSketch::new(LG_K, HllType::Hll4),
salt,
}
}
pub(crate) fn merge_fruits(&mut self, right: CardinalityCollector) -> crate::Result<()> {
self.sketch.merge(&right.sketch).map_err(|err| {
TantivyError::AggregationError(AggregationError::InternalError(format!(
"Error while merging cardinality {err:?}"
)))
})?;
/// Insert a value into the HLL sketch, salted by the column type.
/// The salt ensures that identical u64 values from different column types
/// (e.g. bool `false` vs i64 `0`) are counted as distinct.
pub(crate) fn insert<T: Hash>(&mut self, value: T) {
self.sketch.update((self.salt, value));
}
/// Compute the final cardinality estimate.
pub fn finalize(self) -> Option<f64> {
Some(self.sketch.estimate().trunc())
}
/// Serialize the HLL sketch to its compact binary representation.
/// The format is cross-language compatible with Apache DataSketches (Java, C++, Python).
pub fn to_sketch_bytes(&self) -> Vec<u8> {
self.sketch.serialize()
}
pub(crate) fn merge_fruits(&mut self, right: CardinalityCollector) -> crate::Result<()> {
let mut union = HllUnion::new(LG_K);
union.update(&self.sketch);
union.update(&right.sketch);
self.sketch = union.get_result(HllType::Hll4);
Ok(())
}
}
@@ -518,4 +539,75 @@ mod tests {
Ok(())
}
#[test]
fn cardinality_collector_serde_roundtrip() {
use super::CardinalityCollector;
let mut collector = CardinalityCollector::default();
collector.insert("hello");
collector.insert("world");
collector.insert("hello"); // duplicate
let serialized = serde_json::to_vec(&collector).unwrap();
let deserialized: CardinalityCollector = serde_json::from_slice(&serialized).unwrap();
let original_estimate = collector.finalize().unwrap();
let roundtrip_estimate = deserialized.finalize().unwrap();
assert_eq!(original_estimate, roundtrip_estimate);
assert_eq!(original_estimate, 2.0);
}
#[test]
fn cardinality_collector_merge() {
use super::CardinalityCollector;
let mut left = CardinalityCollector::default();
left.insert("a");
left.insert("b");
let mut right = CardinalityCollector::default();
right.insert("b");
right.insert("c");
left.merge_fruits(right).unwrap();
let estimate = left.finalize().unwrap();
assert_eq!(estimate, 3.0);
}
#[test]
fn cardinality_collector_serialize_deserialize_binary() {
use datasketches::hll::HllSketch;
use super::CardinalityCollector;
let mut collector = CardinalityCollector::default();
collector.insert("apple");
collector.insert("banana");
collector.insert("cherry");
let bytes = collector.to_sketch_bytes();
let deserialized = HllSketch::deserialize(&bytes).unwrap();
assert!((deserialized.estimate() - 3.0).abs() < 0.01);
}
#[test]
fn cardinality_collector_salt_differentiates_types() {
use super::CardinalityCollector;
// Without salt, same u64 value from different column types would collide
let mut collector_bool = CardinalityCollector::new(5); // e.g. ColumnType::Bool
collector_bool.insert(0u64); // false
collector_bool.insert(1u64); // true
let mut collector_i64 = CardinalityCollector::new(2); // e.g. ColumnType::I64
collector_i64.insert(0u64);
collector_i64.insert(1u64);
// Merge them
collector_bool.merge_fruits(collector_i64).unwrap();
let estimate = collector_bool.finalize().unwrap();
// Should be 4 because salt makes (5, 0) != (2, 0) and (5, 1) != (2, 1)
assert_eq!(estimate, 4.0);
}
}

View File

@@ -107,8 +107,11 @@ pub enum PercentileValues {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// The entry when requesting percentiles with keyed: false
pub struct PercentileValuesVecEntry {
key: f64,
value: f64,
/// Percentile
pub key: f64,
/// Value at the percentile
pub value: f64,
}
/// Single-metric aggregations use this common result structure.

View File

@@ -222,6 +222,12 @@ impl PercentilesCollector {
self.sketch.add(val);
}
/// Encode the underlying DDSketch to Java-compatible binary format
/// for cross-language serialization with Java consumers.
pub fn to_sketch_bytes(&self) -> Vec<u8> {
self.sketch.to_java_bytes()
}
pub(crate) fn merge_fruits(&mut self, right: PercentilesCollector) -> crate::Result<()> {
self.sketch.merge(&right.sketch).map_err(|err| {
TantivyError::AggregationError(AggregationError::InternalError(format!(
@@ -610,11 +616,11 @@ mod tests {
assert_eq!(
res["range_with_stats"]["buckets"][0]["percentiles"]["values"]["1.0"],
5.0028295751107414
5.002829575110705
);
assert_eq!(
res["range_with_stats"]["buckets"][0]["percentiles"]["values"]["99.0"],
10.07469668951144
10.07469668951133
);
Ok(())
@@ -659,8 +665,8 @@ mod tests {
let res = exec_request_with_query(agg_req, &index, None)?;
assert_eq!(res["percentiles"]["values"]["1.0"], 5.0028295751107414);
assert_eq!(res["percentiles"]["values"]["99.0"], 10.07469668951144);
assert_eq!(res["percentiles"]["values"]["1.0"], 5.002829575110705);
assert_eq!(res["percentiles"]["values"]["99.0"], 10.07469668951133);
Ok(())
}

View File

@@ -110,6 +110,16 @@ impl Default for IntermediateStats {
}
impl IntermediateStats {
/// Returns the number of values collected.
pub fn count(&self) -> u64 {
self.count
}
/// Returns the sum of all values collected.
pub fn sum(&self) -> f64 {
self.sum
}
/// Merges the other stats intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateStats) {
self.count += other.count;

View File

@@ -1,229 +0,0 @@
/// Codec specific to postings data.
pub mod postings;
/// Standard tantivy codec. This is the codec you use by default.
pub mod standard;
use std::io;
pub use standard::StandardCodec;
use crate::codec::postings::PostingsCodec;
use crate::fieldnorm::FieldNormReader;
use crate::postings::{Postings, TermInfo};
use crate::query::score_combiner::DoNothingCombiner;
use crate::query::term_query::TermScorer;
use crate::query::{box_scorer, Bm25Weight, BufferedUnionScorer, Scorer, SumCombiner};
use crate::schema::IndexRecordOption;
use crate::{DocId, InvertedIndexReader, Score};
/// Codecs describes how data is layed out on disk.
///
/// For the moment, only postings codec can be custom.
pub trait Codec: Clone + std::fmt::Debug + Send + Sync + 'static {
/// The specific postings type used by this codec.
type PostingsCodec: PostingsCodec;
/// ID of the codec. It should be unique to your codec.
/// Make it human-readable, descriptive, short and unique.
const ID: &'static str;
/// Load codec based on the codec configuration.
fn from_json_props(json_value: &serde_json::Value) -> crate::Result<Self>;
/// Get codec configuration.
fn to_json_props(&self) -> serde_json::Value;
/// Returns the postings codec.
fn postings_codec(&self) -> &Self::PostingsCodec;
}
/// Object-safe codec is a Codec that can be used in a trait object.
///
/// The point of it is to offer a way to use a codec without a proliferation of generics.
pub trait ObjectSafeCodec: 'static + Send + Sync {
/// Loads a type-erased Postings object for the given term.
///
/// If the schema used to build the index did not provide enough
/// information to match the requested `option`, a Postings is still
/// returned in a best-effort manner.
fn load_postings_type_erased(
&self,
term_info: &TermInfo,
option: IndexRecordOption,
inverted_index_reader: &InvertedIndexReader,
) -> io::Result<Box<dyn Postings>>;
/// Loads a type-erased TermScorer object for the given term.
///
/// If the schema used to build the index did not provide enough
/// information to match the requested `option`, a TermScorer is still
/// returned in a best-effort manner.
///
/// The point of this contraption is that the return TermScorer is backed,
/// not by Box<dyn Postings> but by the codec's concrete Postings type.
fn load_term_scorer_type_erased(
&self,
term_info: &TermInfo,
option: IndexRecordOption,
inverted_index_reader: &InvertedIndexReader,
fieldnorm_reader: FieldNormReader,
similarity_weight: Bm25Weight,
) -> io::Result<Box<dyn Scorer>>;
/// Loads a type-erased PhraseScorer object for the given term.
///
/// If the schema used to build the index did not provide enough
/// information to match the requested `option`, a TermScorer is still
/// returned in a best-effort manner.
///
/// The point of this contraption is that the return PhraseScorer is backed,
/// not by Box<dyn Postings> but by the codec's concrete Postings type.
fn new_phrase_scorer_type_erased(
&self,
term_infos: &[(usize, TermInfo)],
similarity_weight: Option<Bm25Weight>,
fieldnorm_reader: FieldNormReader,
slop: u32,
inverted_index_reader: &InvertedIndexReader,
) -> io::Result<Box<dyn Scorer>>;
/// Performs a for_each_pruning operation on the given scorer.
///
/// The function will go through matching documents and call the callback
/// function for all docs with a score exceeding the threshold.
///
/// The function itself will return a larger threshold value,
/// meant to update the threshold value.
///
/// If the codec and the scorer allow it, this function can rely on
/// optimizations like the block-max wand.
fn for_each_pruning(
&self,
threshold: Score,
scorer: Box<dyn Scorer>,
callback: &mut dyn FnMut(DocId, Score) -> Score,
);
/// Builds a union scorer possibly specialized if
/// all scorers are `Term<Self::Postings>`.
fn build_union_scorer_with_sum_combiner(
&self,
scorers: Vec<Box<dyn Scorer>>,
num_docs: DocId,
score_combiner_type: SumOrDoNothingCombiner,
) -> Box<dyn Scorer>;
}
impl<TCodec: Codec> ObjectSafeCodec for TCodec {
fn load_postings_type_erased(
&self,
term_info: &TermInfo,
option: IndexRecordOption,
inverted_index_reader: &InvertedIndexReader,
) -> io::Result<Box<dyn Postings>> {
let postings = inverted_index_reader
.read_postings_from_terminfo_specialized(term_info, option, self)?;
Ok(Box::new(postings))
}
fn load_term_scorer_type_erased(
&self,
term_info: &TermInfo,
option: IndexRecordOption,
inverted_index_reader: &InvertedIndexReader,
fieldnorm_reader: FieldNormReader,
similarity_weight: Bm25Weight,
) -> io::Result<Box<dyn Scorer>> {
let scorer = inverted_index_reader.new_term_scorer_specialized(
term_info,
option,
fieldnorm_reader,
similarity_weight,
self,
)?;
Ok(box_scorer(scorer))
}
fn new_phrase_scorer_type_erased(
&self,
term_infos: &[(usize, TermInfo)],
similarity_weight: Option<Bm25Weight>,
fieldnorm_reader: FieldNormReader,
slop: u32,
inverted_index_reader: &InvertedIndexReader,
) -> io::Result<Box<dyn Scorer>> {
let scorer = inverted_index_reader.new_phrase_scorer_type_specialized(
term_infos,
similarity_weight,
fieldnorm_reader,
slop,
self,
)?;
Ok(box_scorer(scorer))
}
fn build_union_scorer_with_sum_combiner(
&self,
scorers: Vec<Box<dyn Scorer>>,
num_docs: DocId,
sum_or_do_nothing_combiner: SumOrDoNothingCombiner,
) -> Box<dyn Scorer> {
if !scorers.iter().all(|scorer| {
scorer.is::<TermScorer<<<Self as Codec>::PostingsCodec as PostingsCodec>::Postings>>()
}) {
return box_scorer(BufferedUnionScorer::build(
scorers,
SumCombiner::default,
num_docs,
));
}
let specialized_scorers: Vec<
TermScorer<<<Self as Codec>::PostingsCodec as PostingsCodec>::Postings>,
> = scorers
.into_iter()
.map(|scorer| {
*scorer.downcast::<TermScorer<_>>().ok().expect(
"Downcast failed despite the fact we already checked the type was correct",
)
})
.collect();
match sum_or_do_nothing_combiner {
SumOrDoNothingCombiner::Sum => box_scorer(BufferedUnionScorer::build(
specialized_scorers,
SumCombiner::default,
num_docs,
)),
SumOrDoNothingCombiner::DoNothing => box_scorer(BufferedUnionScorer::build(
specialized_scorers,
DoNothingCombiner::default,
num_docs,
)),
}
}
fn for_each_pruning(
&self,
threshold: Score,
scorer: Box<dyn Scorer>,
callback: &mut dyn FnMut(DocId, Score) -> Score,
) {
let accerelerated_foreach_pruning_res =
<TCodec as Codec>::PostingsCodec::try_accelerated_for_each_pruning(
threshold, scorer, callback,
);
if let Err(mut scorer) = accerelerated_foreach_pruning_res {
// No acceleration available. We need to do things manually.
scorer.for_each_pruning(threshold, callback);
}
}
}
/// SumCombiner or DoNothingCombiner
#[derive(Copy, Clone)]
pub enum SumOrDoNothingCombiner {
/// Sum scores together
Sum,
/// Do not track any score.
DoNothing,
}

View File

@@ -1,123 +0,0 @@
use std::io;
/// Block-max WAND algorithm.
pub mod block_wand;
use common::OwnedBytes;
use crate::fieldnorm::FieldNormReader;
use crate::postings::Postings;
use crate::query::{Bm25Weight, Scorer};
use crate::schema::IndexRecordOption;
use crate::{DocId, Score};
/// Postings codec.
pub trait PostingsCodec: Send + Sync + 'static {
/// Serializer type for the postings codec.
type PostingsSerializer: PostingsSerializer;
/// Postings type for the postings codec.
type Postings: Postings + Clone;
/// Creates a new postings serializer.
fn new_serializer(
&self,
avg_fieldnorm: Score,
mode: IndexRecordOption,
fieldnorm_reader: Option<FieldNormReader>,
) -> Self::PostingsSerializer;
/// Loads postings
///
/// Record option is the option that was passed at indexing time.
/// Requested option is the option that is requested.
///
/// For instance, we may have term_freq in the posting list
/// but we can skip decompressing as we read the posting list.
///
/// If record option does not support the requested option,
/// this method does NOT return an error and will in fact restrict
/// requested_option to what is available.
fn load_postings(
&self,
doc_freq: u32,
postings_data: OwnedBytes,
record_option: IndexRecordOption,
requested_option: IndexRecordOption,
positions_data: Option<OwnedBytes>,
) -> io::Result<Self::Postings>;
/// If your codec supports different ways to accelerate `for_each_pruning` that's
/// where you should implement it.
///
/// Returning `Err(scorer)` without mutating the scorer nor calling the callback function,
/// is never "wrong". It just leaves the responsability to the caller to call a fallback
/// implementation on the scorer.
///
/// If your codec supports BlockMax-Wand, you just need to have your
/// postings implement `PostingsWithBlockMax` and copy what is done in the StandardPostings
/// codec to enable it.
fn try_accelerated_for_each_pruning(
_threshold: Score,
scorer: Box<dyn Scorer>,
_callback: &mut dyn FnMut(DocId, Score) -> Score,
) -> Result<(), Box<dyn Scorer>> {
Err(scorer)
}
}
/// A postings serializer is a listener that is in charge of serializing postings
///
/// IO is done only once per postings, once all of the data has been received.
/// A serializer will therefore contain internal buffers.
///
/// A serializer is created once and recycled for all postings.
///
/// Clients should use PostingsSerializer as follows.
/// ```rust,no_run
/// // First postings list
/// serializer.new_term(2, true);
/// serializer.write_doc(2, 1);
/// serializer.write_doc(6, 2);
/// serializer.close_term(3);
/// serializer.clear();
/// // Second postings list
/// serializer.new_term(1, true);
/// serializer.write_doc(3, 1);
/// serializer.close_term(3);
/// ```
pub trait PostingsSerializer {
/// The term_doc_freq here is the number of documents
/// in the postings lists.
///
/// It can be used to compute the idf that will be used for the
/// blockmax parameters.
///
/// If not available (e.g. if we do not collect `term_frequencies`
/// blockwand is disabled), the term_doc_freq passed will be set 0.
fn new_term(&mut self, term_doc_freq: u32, record_term_freq: bool);
/// Records a new document id for the current term.
/// The serializer may ignore it.
fn write_doc(&mut self, doc_id: DocId, term_freq: u32);
/// Closes the current term and writes the postings list associated.
fn close_term(&mut self, doc_freq: u32, wrt: &mut impl io::Write) -> io::Result<()>;
}
/// A light complement interface to Postings to allow block-max wand acceleration.
pub trait PostingsWithBlockMax: Postings {
/// Moves the postings to the block containign `target_doc` and returns
/// an upperbound of the score for documents in the block.
///
/// `Warning`: Calling this method may leave the postings in an invalid state.
/// callers are required to call seek before calling any other of the
/// `Postings` method (like doc / advance etc.).
fn seek_block_max(
&mut self,
target_doc: crate::DocId,
fieldnorm_reader: &FieldNormReader,
similarity_weight: &Bm25Weight,
) -> Score;
/// Returns the last document in the current block (or Terminated if this
/// is the last block).
fn last_doc_in_block(&self) -> crate::DocId;
}

View File

@@ -1,35 +0,0 @@
use serde::{Deserialize, Serialize};
use crate::codec::standard::postings::StandardPostingsCodec;
use crate::codec::Codec;
/// Tantivy's default postings codec.
pub mod postings;
/// Tantivy's default codec.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct StandardCodec;
impl Codec for StandardCodec {
type PostingsCodec = StandardPostingsCodec;
const ID: &'static str = "tantivy-default";
fn from_json_props(json_value: &serde_json::Value) -> crate::Result<Self> {
if !json_value.is_null() {
return Err(crate::TantivyError::InvalidArgument(format!(
"Codec property for the StandardCodec are unexpected. expected null, got {}",
json_value.as_str().unwrap_or("null")
)));
}
Ok(StandardCodec)
}
fn to_json_props(&self) -> serde_json::Value {
serde_json::Value::Null
}
fn postings_codec(&self) -> &Self::PostingsCodec {
&StandardPostingsCodec
}
}

View File

@@ -1,50 +0,0 @@
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::DocId;
pub struct Block {
doc_ids: [DocId; COMPRESSION_BLOCK_SIZE],
term_freqs: [u32; COMPRESSION_BLOCK_SIZE],
len: usize,
}
impl Block {
pub fn new() -> Self {
Block {
doc_ids: [0u32; COMPRESSION_BLOCK_SIZE],
term_freqs: [0u32; COMPRESSION_BLOCK_SIZE],
len: 0,
}
}
pub fn doc_ids(&self) -> &[DocId] {
&self.doc_ids[..self.len]
}
pub fn term_freqs(&self) -> &[u32] {
&self.term_freqs[..self.len]
}
pub fn clear(&mut self) {
self.len = 0;
}
pub fn append_doc(&mut self, doc: DocId, term_freq: u32) {
let len = self.len;
self.doc_ids[len] = doc;
self.term_freqs[len] = term_freq;
self.len = len + 1;
}
pub fn is_full(&self) -> bool {
self.len == COMPRESSION_BLOCK_SIZE
}
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn last_doc(&self) -> DocId {
assert_eq!(self.len, COMPRESSION_BLOCK_SIZE);
self.doc_ids[COMPRESSION_BLOCK_SIZE - 1]
}
}

View File

@@ -1,164 +0,0 @@
use std::io;
use crate::codec::postings::block_wand::{block_wand, block_wand_single_scorer};
use crate::codec::postings::PostingsCodec;
use crate::codec::standard::postings::block_segment_postings::BlockSegmentPostings;
pub use crate::codec::standard::postings::segment_postings::SegmentPostings;
use crate::fieldnorm::FieldNormReader;
use crate::positions::PositionReader;
use crate::query::term_query::TermScorer;
use crate::query::{BufferedUnionScorer, Scorer, SumCombiner};
use crate::schema::IndexRecordOption;
use crate::{DocSet as _, Score, TERMINATED};
mod block;
mod block_segment_postings;
mod segment_postings;
mod skip;
mod standard_postings_serializer;
pub use segment_postings::SegmentPostings as StandardPostings;
pub use standard_postings_serializer::StandardPostingsSerializer;
/// The default postings codec for tantivy.
pub struct StandardPostingsCodec;
#[expect(clippy::enum_variant_names)]
#[derive(Debug, PartialEq, Clone, Copy, Eq)]
pub(crate) enum FreqReadingOption {
NoFreq,
SkipFreq,
ReadFreq,
}
impl PostingsCodec for StandardPostingsCodec {
type PostingsSerializer = StandardPostingsSerializer;
type Postings = SegmentPostings;
fn new_serializer(
&self,
avg_fieldnorm: Score,
mode: IndexRecordOption,
fieldnorm_reader: Option<FieldNormReader>,
) -> Self::PostingsSerializer {
StandardPostingsSerializer::new(avg_fieldnorm, mode, fieldnorm_reader)
}
fn load_postings(
&self,
doc_freq: u32,
postings_data: common::OwnedBytes,
record_option: IndexRecordOption,
requested_option: IndexRecordOption,
positions_data_opt: Option<common::OwnedBytes>,
) -> io::Result<Self::Postings> {
// Rationalize record_option/requested_option.
let requested_option = requested_option.downgrade(record_option);
let block_segment_postings =
BlockSegmentPostings::open(doc_freq, postings_data, record_option, requested_option)?;
let position_reader = positions_data_opt.map(PositionReader::open).transpose()?;
Ok(SegmentPostings::from_block_postings(
block_segment_postings,
position_reader,
))
}
fn try_accelerated_for_each_pruning(
mut threshold: Score,
mut scorer: Box<dyn Scorer>,
callback: &mut dyn FnMut(crate::DocId, Score) -> Score,
) -> Result<(), Box<dyn Scorer>> {
scorer = match scorer.downcast::<TermScorer<Self::Postings>>() {
Ok(term_scorer) => {
block_wand_single_scorer(*term_scorer, threshold, callback);
return Ok(());
}
Err(scorer) => scorer,
};
let mut union_scorer =
scorer.downcast::<BufferedUnionScorer<Box<dyn Scorer>, SumCombiner>>()?;
if !union_scorer
.scorers()
.iter()
.all(|scorer| scorer.is::<TermScorer<Self::Postings>>())
{
return Err(union_scorer);
}
let doc = union_scorer.doc();
if doc == TERMINATED {
return Ok(());
}
let score = union_scorer.score();
if score > threshold {
threshold = callback(doc, score);
}
let boxed_scorers: Vec<Box<dyn Scorer>> = union_scorer.into_scorers();
let scorers: Vec<TermScorer<Self::Postings>> = boxed_scorers
.into_iter()
.map(|scorer| {
*scorer.downcast::<TermScorer<Self::Postings>>().ok().expect(
"Downcast failed despite the fact we already checked the type was correct",
)
})
.collect();
block_wand(scorers, threshold, callback);
Ok(())
}
}
#[cfg(test)]
mod tests {
use common::OwnedBytes;
use super::*;
use crate::codec::postings::PostingsSerializer as _;
use crate::postings::Postings as _;
fn test_segment_postings_tf_aux(num_docs: u32, include_term_freq: bool) -> SegmentPostings {
let mut postings_serializer =
StandardPostingsCodec.new_serializer(1.0f32, IndexRecordOption::WithFreqs, None);
let mut buffer = Vec::new();
postings_serializer.new_term(num_docs, include_term_freq);
for i in 0..num_docs {
postings_serializer.write_doc(i, 2);
}
postings_serializer
.close_term(num_docs, &mut buffer)
.unwrap();
StandardPostingsCodec
.load_postings(
num_docs,
OwnedBytes::new(buffer),
IndexRecordOption::WithFreqs,
IndexRecordOption::WithFreqs,
None,
)
.unwrap()
}
#[test]
fn test_segment_postings_small_block_with_and_without_freq() {
let small_block_without_term_freq = test_segment_postings_tf_aux(1, false);
assert!(!small_block_without_term_freq.has_freq());
assert_eq!(small_block_without_term_freq.doc(), 0);
assert_eq!(small_block_without_term_freq.term_freq(), 1);
let small_block_with_term_freq = test_segment_postings_tf_aux(1, true);
assert!(small_block_with_term_freq.has_freq());
assert_eq!(small_block_with_term_freq.doc(), 0);
assert_eq!(small_block_with_term_freq.term_freq(), 2);
}
#[test]
fn test_segment_postings_large_block_with_and_without_freq() {
let large_block_without_term_freq = test_segment_postings_tf_aux(128, false);
assert!(!large_block_without_term_freq.has_freq());
assert_eq!(large_block_without_term_freq.doc(), 0);
assert_eq!(large_block_without_term_freq.term_freq(), 1);
let large_block_with_term_freq = test_segment_postings_tf_aux(128, true);
assert!(large_block_with_term_freq.has_freq());
assert_eq!(large_block_with_term_freq.doc(), 0);
assert_eq!(large_block_with_term_freq.term_freq(), 2);
}
}

View File

@@ -1,184 +0,0 @@
use std::cmp::Ordering;
use std::io::{self, Write as _};
use common::{BinarySerializable as _, VInt};
use crate::codec::postings::PostingsSerializer;
use crate::codec::standard::postings::block::Block;
use crate::codec::standard::postings::skip::SkipSerializer;
use crate::fieldnorm::FieldNormReader;
use crate::postings::compression::{BlockEncoder, VIntEncoder as _, COMPRESSION_BLOCK_SIZE};
use crate::query::Bm25Weight;
use crate::schema::IndexRecordOption;
use crate::{DocId, Score};
/// Serializer object for tantivy's default postings format.
pub struct StandardPostingsSerializer {
last_doc_id_encoded: u32,
block_encoder: BlockEncoder,
block: Box<Block>,
postings_write: Vec<u8>,
skip_write: SkipSerializer,
mode: IndexRecordOption,
fieldnorm_reader: Option<FieldNormReader>,
bm25_weight: Option<Bm25Weight>,
avg_fieldnorm: Score, /* Average number of term in the field for that segment.
* this value is used to compute the block wand information. */
term_has_freq: bool,
}
impl StandardPostingsSerializer {
pub(crate) fn new(
avg_fieldnorm: Score,
mode: IndexRecordOption,
fieldnorm_reader: Option<FieldNormReader>,
) -> StandardPostingsSerializer {
Self {
last_doc_id_encoded: 0,
block_encoder: BlockEncoder::new(),
block: Box::new(Block::new()),
postings_write: Vec::new(),
skip_write: SkipSerializer::new(),
mode,
fieldnorm_reader,
bm25_weight: None,
avg_fieldnorm,
term_has_freq: false,
}
}
}
impl PostingsSerializer for StandardPostingsSerializer {
fn new_term(&mut self, term_doc_freq: u32, record_term_freq: bool) {
self.clear();
self.term_has_freq = self.mode.has_freq() && record_term_freq;
if !self.term_has_freq {
return;
}
let num_docs_in_segment: u64 =
if let Some(fieldnorm_reader) = self.fieldnorm_reader.as_ref() {
fieldnorm_reader.num_docs() as u64
} else {
return;
};
if num_docs_in_segment == 0 {
return;
}
self.bm25_weight = Some(Bm25Weight::for_one_term_without_explain(
term_doc_freq as u64,
num_docs_in_segment,
self.avg_fieldnorm,
));
}
fn write_doc(&mut self, doc_id: DocId, term_freq: u32) {
self.block.append_doc(doc_id, term_freq);
if self.block.is_full() {
self.write_block();
}
}
fn close_term(&mut self, doc_freq: u32, output_write: &mut impl io::Write) -> io::Result<()> {
if !self.block.is_empty() {
// we have doc ids waiting to be written
// this happens when the number of doc ids is
// not a perfect multiple of our block size.
//
// In that case, the remaining part is encoded
// using variable int encoding.
{
let block_encoded = self
.block_encoder
.compress_vint_sorted(self.block.doc_ids(), self.last_doc_id_encoded);
self.postings_write.write_all(block_encoded)?;
}
// ... Idem for term frequencies
if self.term_has_freq {
let block_encoded = self
.block_encoder
.compress_vint_unsorted(self.block.term_freqs());
self.postings_write.write_all(block_encoded)?;
}
self.block.clear();
}
if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
let skip_data = self.skip_write.data();
VInt(skip_data.len() as u64).serialize(output_write)?;
output_write.write_all(skip_data)?;
}
output_write.write_all(&self.postings_write[..])?;
self.skip_write.clear();
self.postings_write.clear();
self.bm25_weight = None;
Ok(())
}
}
impl StandardPostingsSerializer {
fn clear(&mut self) {
self.bm25_weight = None;
self.block.clear();
self.last_doc_id_encoded = 0;
}
fn write_block(&mut self) {
{
// encode the doc ids
let (num_bits, block_encoded): (u8, &[u8]) = self
.block_encoder
.compress_block_sorted(self.block.doc_ids(), self.last_doc_id_encoded);
self.last_doc_id_encoded = self.block.last_doc();
self.skip_write
.write_doc(self.last_doc_id_encoded, num_bits);
// last el block 0, offset block 1,
self.postings_write.extend(block_encoded);
}
if self.term_has_freq {
let (num_bits, block_encoded): (u8, &[u8]) = self
.block_encoder
.compress_block_unsorted(self.block.term_freqs(), true);
self.postings_write.extend(block_encoded);
self.skip_write.write_term_freq(num_bits);
if self.mode.has_positions() {
// We serialize the sum of term freqs within the skip information
// in order to navigate through positions.
let sum_freq = self.block.term_freqs().iter().cloned().sum();
self.skip_write.write_total_term_freq(sum_freq);
}
let mut blockwand_params = (0u8, 0u32);
if let Some(bm25_weight) = self.bm25_weight.as_ref() {
if let Some(fieldnorm_reader) = self.fieldnorm_reader.as_ref() {
let docs = self.block.doc_ids().iter().cloned();
let term_freqs = self.block.term_freqs().iter().cloned();
let fieldnorms = docs.map(|doc| fieldnorm_reader.fieldnorm_id(doc));
blockwand_params = fieldnorms
.zip(term_freqs)
.max_by(
|(left_fieldnorm_id, left_term_freq),
(right_fieldnorm_id, right_term_freq)| {
let left_score =
bm25_weight.tf_factor(*left_fieldnorm_id, *left_term_freq);
let right_score =
bm25_weight.tf_factor(*right_fieldnorm_id, *right_term_freq);
left_score
.partial_cmp(&right_score)
.unwrap_or(Ordering::Equal)
},
)
.unwrap();
}
}
let (fieldnorm_id, term_freq) = blockwand_params;
self.skip_write.write_blockwand_max(fieldnorm_id, term_freq);
}
self.block.clear();
}
}

View File

@@ -1,4 +1,5 @@
mod order;
mod sort_by_bytes;
mod sort_by_erased_type;
mod sort_by_score;
mod sort_by_static_fast_value;
@@ -6,6 +7,7 @@ mod sort_by_string;
mod sort_key_computer;
pub use order::*;
pub use sort_by_bytes::SortByBytes;
pub use sort_by_erased_type::SortByErasedType;
pub use sort_by_score::SortBySimilarityScore;
pub use sort_by_static_fast_value::SortByStaticFastValue;

View File

@@ -0,0 +1,168 @@
use columnar::BytesColumn;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::termdict::TermOrdinal;
use crate::{DocId, Score};
/// Sort by the first value of a bytes column.
///
/// If the field is multivalued, only the first value is considered.
///
/// Documents that do not have this value are still considered.
/// Their sort key will simply be `None`.
#[derive(Debug, Clone)]
pub struct SortByBytes {
column_name: String,
}
impl SortByBytes {
/// Creates a new sort by bytes sort key computer.
pub fn for_field(column_name: impl ToString) -> Self {
SortByBytes {
column_name: column_name.to_string(),
}
}
}
impl SortKeyComputer for SortByBytes {
type SortKey = Option<Vec<u8>>;
type Child = ByBytesColumnSegmentSortKeyComputer;
type Comparator = NaturalComparator;
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let bytes_column_opt = segment_reader.fast_fields().bytes(&self.column_name)?;
Ok(ByBytesColumnSegmentSortKeyComputer { bytes_column_opt })
}
}
/// Segment-level sort key computer for bytes columns.
pub struct ByBytesColumnSegmentSortKeyComputer {
bytes_column_opt: Option<BytesColumn>,
}
impl SegmentSortKeyComputer for ByBytesColumnSegmentSortKeyComputer {
type SortKey = Option<Vec<u8>>;
type SegmentSortKey = Option<TermOrdinal>;
type SegmentComparator = NaturalComparator;
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
let bytes_column = self.bytes_column_opt.as_ref()?;
bytes_column.ords().first(doc)
}
fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<Vec<u8>> {
// TODO: Individual lookups to the dictionary like this are very likely to repeatedly
// decompress the same blocks. See https://github.com/quickwit-oss/tantivy/issues/2776
let term_ord = term_ord_opt?;
let bytes_column = self.bytes_column_opt.as_ref()?;
let mut bytes = Vec::new();
bytes_column
.dictionary()
.ord_to_term(term_ord, &mut bytes)
.ok()?;
Some(bytes)
}
}
#[cfg(test)]
mod tests {
use super::SortByBytes;
use crate::collector::TopDocs;
use crate::query::AllQuery;
use crate::schema::{BytesOptions, Schema, FAST, INDEXED};
use crate::{Index, IndexWriter, Order, TantivyDocument};
#[test]
fn test_sort_by_bytes_asc() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let bytes_field = schema_builder
.add_bytes_field("data", BytesOptions::default().set_fast().set_indexed());
let id_field = schema_builder.add_u64_field("id", FAST | INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// Insert documents with byte values in non-sorted order
let test_data: Vec<(u64, Vec<u8>)> = vec![
(1, vec![0x02, 0x00]),
(2, vec![0x00, 0x10]),
(3, vec![0x01, 0x00]),
(4, vec![0x00, 0x20]),
];
for (id, bytes) in &test_data {
let mut doc = TantivyDocument::new();
doc.add_u64(id_field, *id);
doc.add_bytes(bytes_field, bytes);
index_writer.add_document(doc)?;
}
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
// Sort ascending by bytes
let top_docs =
TopDocs::with_limit(10).order_by((SortByBytes::for_field("data"), Order::Asc));
let results: Vec<(Option<Vec<u8>>, _)> = searcher.search(&AllQuery, &top_docs)?;
// Expected order: [0x00,0x10], [0x00,0x20], [0x01,0x00], [0x02,0x00]
let sorted_bytes: Vec<Option<Vec<u8>>> = results.into_iter().map(|(b, _)| b).collect();
assert_eq!(
sorted_bytes,
vec![
Some(vec![0x00, 0x10]),
Some(vec![0x00, 0x20]),
Some(vec![0x01, 0x00]),
Some(vec![0x02, 0x00]),
]
);
Ok(())
}
#[test]
fn test_sort_by_bytes_desc() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let bytes_field = schema_builder
.add_bytes_field("data", BytesOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
let test_data: Vec<Vec<u8>> = vec![vec![0x00, 0x10], vec![0x02, 0x00], vec![0x01, 0x00]];
for bytes in &test_data {
let mut doc = TantivyDocument::new();
doc.add_bytes(bytes_field, bytes);
index_writer.add_document(doc)?;
}
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
// Sort descending by bytes
let top_docs =
TopDocs::with_limit(10).order_by((SortByBytes::for_field("data"), Order::Desc));
let results: Vec<(Option<Vec<u8>>, _)> = searcher.search(&AllQuery, &top_docs)?;
// Expected order (descending): [0x02,0x00], [0x01,0x00], [0x00,0x10]
let sorted_bytes: Vec<Option<Vec<u8>>> = results.into_iter().map(|(b, _)| b).collect();
assert_eq!(
sorted_bytes,
vec![
Some(vec![0x02, 0x00]),
Some(vec![0x01, 0x00]),
Some(vec![0x00, 0x10]),
]
);
Ok(())
}
}

View File

@@ -1,7 +1,7 @@
use columnar::{ColumnType, MonotonicallyMappableToU64};
use crate::collector::sort_key::{
NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString,
NaturalComparator, SortByBytes, SortBySimilarityScore, SortByStaticFastValue, SortByString,
};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::FastFieldNotAvailableError;
@@ -114,6 +114,16 @@ impl SortKeyComputer for SortByErasedType {
},
})
}
ColumnType::Bytes => {
let computer = SortByBytes::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<Vec<u8>>| {
val.map(OwnedValue::Bytes).unwrap_or(OwnedValue::Null)
},
})
}
ColumnType::U64 => {
let computer = SortByStaticFastValue::<u64>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
@@ -281,6 +291,65 @@ mod tests {
);
}
#[test]
fn test_sort_by_owned_bytes() {
let mut schema_builder = Schema::builder();
let data_field = schema_builder.add_bytes_field("data", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer
.add_document(doc!(data_field => vec![0x03u8, 0x00]))
.unwrap();
writer
.add_document(doc!(data_field => vec![0x01u8, 0x00]))
.unwrap();
writer
.add_document(doc!(data_field => vec![0x02u8, 0x00]))
.unwrap();
writer.add_document(doc!()).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
// Sort descending (Natural - highest first)
let collector = TopDocs::with_limit(10)
.order_by((SortByErasedType::for_field("data"), ComparatorEnum::Natural));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![
OwnedValue::Bytes(vec![0x03, 0x00]),
OwnedValue::Bytes(vec![0x02, 0x00]),
OwnedValue::Bytes(vec![0x01, 0x00]),
OwnedValue::Null
]
);
// Sort ascending (ReverseNoneLower - lowest first, nulls last)
let collector = TopDocs::with_limit(10).order_by((
SortByErasedType::for_field("data"),
ComparatorEnum::ReverseNoneLower,
));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![
OwnedValue::Bytes(vec![0x01, 0x00]),
OwnedValue::Bytes(vec![0x02, 0x00]),
OwnedValue::Bytes(vec![0x03, 0x00]),
OwnedValue::Null
]
);
}
#[test]
fn test_sort_by_owned_reverse() {
let mut schema_builder = Schema::builder();

View File

@@ -4,7 +4,7 @@ use common::{replace_in_place, JsonPathWriter};
use rustc_hash::FxHashMap;
use crate::indexer::indexing_term::IndexingTerm;
use crate::postings::{IndexingContext, IndexingPosition, PostingsWriter as _, PostingsWriterEnum};
use crate::postings::{IndexingContext, IndexingPosition, PostingsWriter};
use crate::schema::document::{ReferenceValue, ReferenceValueLeaf, Value};
use crate::schema::{Type, DATE_TIME_PRECISION_INDEXED};
use crate::time::format_description::well_known::Rfc3339;
@@ -80,7 +80,7 @@ fn index_json_object<'a, V: Value<'a>>(
text_analyzer: &mut TextAnalyzer,
term_buffer: &mut IndexingTerm,
json_path_writer: &mut JsonPathWriter,
postings_writer: &mut PostingsWriterEnum,
postings_writer: &mut dyn PostingsWriter,
ctx: &mut IndexingContext,
positions_per_path: &mut IndexingPositionsPerPath,
) {
@@ -110,7 +110,7 @@ pub(crate) fn index_json_value<'a, V: Value<'a>>(
text_analyzer: &mut TextAnalyzer,
term_buffer: &mut IndexingTerm,
json_path_writer: &mut JsonPathWriter,
postings_writer: &mut PostingsWriterEnum,
postings_writer: &mut dyn PostingsWriter,
ctx: &mut IndexingContext,
positions_per_path: &mut IndexingPositionsPerPath,
) {

View File

@@ -1,6 +1,4 @@
use std::ops::{Deref as _, DerefMut as _};
use common::BitSet;
use std::borrow::{Borrow, BorrowMut};
use crate::fastfield::AliveBitSet;
use crate::DocId;
@@ -132,19 +130,6 @@ pub trait DocSet: Send {
buffer.len()
}
/// Fills the given bitset with the documents in the docset.
///
/// If the docset max_doc is smaller than the largest doc, this function might not consume the
/// docset entirely.
fn fill_bitset(&mut self, bitset: &mut BitSet) {
let bitset_max_value: u32 = bitset.max_value();
let mut doc = self.doc();
while doc < bitset_max_value {
bitset.insert(doc);
doc = self.advance();
}
}
/// Returns the current document
/// Right after creating a new `DocSet`, the docset points to the first document.
///
@@ -248,57 +233,51 @@ impl DocSet for &mut dyn DocSet {
fn count_including_deleted(&mut self) -> u32 {
(**self).count_including_deleted()
}
fn fill_bitset(&mut self, bitset: &mut BitSet) {
(**self).fill_bitset(bitset);
}
}
impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
#[inline]
fn advance(&mut self) -> DocId {
self.deref_mut().advance()
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.advance()
}
#[inline]
fn seek(&mut self, target: DocId) -> DocId {
self.deref_mut().seek(target)
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.seek(target)
}
fn seek_danger(&mut self, target: DocId) -> SeekDangerResult {
self.deref_mut().seek_danger(target)
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.seek_danger(target)
}
#[inline]
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {
self.deref_mut().fill_buffer(buffer)
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.fill_buffer(buffer)
}
#[inline]
fn doc(&self) -> DocId {
self.deref().doc()
let unboxed: &TDocSet = self.borrow();
unboxed.doc()
}
#[inline]
fn size_hint(&self) -> u32 {
self.deref().size_hint()
let unboxed: &TDocSet = self.borrow();
unboxed.size_hint()
}
#[inline]
fn cost(&self) -> u64 {
self.deref().cost()
let unboxed: &TDocSet = self.borrow();
unboxed.cost()
}
#[inline]
fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
self.deref_mut().count(alive_bitset)
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.count(alive_bitset)
}
fn count_including_deleted(&mut self) -> u32 {
self.deref_mut().count_including_deleted()
}
fn fill_bitset(&mut self, bitset: &mut BitSet) {
self.deref_mut().fill_bitset(bitset);
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.count_including_deleted()
}
}

View File

@@ -1,49 +0,0 @@
use std::borrow::Cow;
use serde::{Deserialize, Serialize};
use crate::codec::{Codec, StandardCodec};
/// A Codec configuration is just a serializable object.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct CodecConfiguration {
codec_id: Cow<'static, str>,
#[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
props: serde_json::Value,
}
impl CodecConfiguration {
/// Returns true if the codec is the standard codec.
pub fn is_standard(&self) -> bool {
self.codec_id == StandardCodec::ID && self.props.is_null()
}
/// Creates a codec instance from the configuration.
///
/// If the codec id does not match the code's name, an error is returned.
pub fn to_codec<C: Codec>(&self) -> crate::Result<C> {
if self.codec_id != C::ID {
return Err(crate::TantivyError::InvalidArgument(format!(
"Codec id mismatch: expected {}, got {}",
C::ID,
self.codec_id
)));
}
C::from_json_props(&self.props)
}
}
impl<'a, C: Codec> From<&'a C> for CodecConfiguration {
fn from(codec: &'a C) -> Self {
CodecConfiguration {
codec_id: Cow::Borrowed(C::ID),
props: codec.to_json_props(),
}
}
}
impl Default for CodecConfiguration {
fn default() -> Self {
CodecConfiguration::from(&StandardCodec)
}
}

View File

@@ -8,14 +8,12 @@ use std::thread::available_parallelism;
use super::segment::Segment;
use super::segment_reader::merge_field_meta_data;
use super::{FieldMetadata, IndexSettings};
use crate::codec::StandardCodec;
use crate::core::{Executor, META_FILEPATH};
use crate::directory::error::OpenReadError;
#[cfg(feature = "mmap")]
use crate::directory::MmapDirectory;
use crate::directory::{Directory, ManagedDirectory, RamDirectory, INDEX_WRITER_LOCK};
use crate::error::{DataCorruption, TantivyError};
use crate::index::codec_configuration::CodecConfiguration;
use crate::index::{IndexMeta, SegmentId, SegmentMeta, SegmentMetaInventory};
use crate::indexer::index_writer::{
IndexWriterOptions, MAX_NUM_THREAD, MEMORY_BUDGET_NUM_BYTES_MIN,
@@ -61,7 +59,6 @@ fn save_new_metas(
schema: Schema,
index_settings: IndexSettings,
directory: &dyn Directory,
codec: CodecConfiguration,
) -> crate::Result<()> {
save_metas(
&IndexMeta {
@@ -70,7 +67,6 @@ fn save_new_metas(
schema,
opstamp: 0u64,
payload: None,
codec,
},
directory,
)?;
@@ -105,21 +101,18 @@ fn save_new_metas(
/// };
/// let index = Index::builder().schema(schema).settings(settings).create_in_ram();
/// ```
pub struct IndexBuilder<Codec: crate::codec::Codec = StandardCodec> {
pub struct IndexBuilder {
schema: Option<Schema>,
index_settings: IndexSettings,
tokenizer_manager: TokenizerManager,
fast_field_tokenizer_manager: TokenizerManager,
codec: Codec,
}
impl Default for IndexBuilder<StandardCodec> {
impl Default for IndexBuilder {
fn default() -> Self {
IndexBuilder::new()
}
}
impl IndexBuilder<StandardCodec> {
impl IndexBuilder {
/// Creates a new `IndexBuilder`
pub fn new() -> Self {
Self {
@@ -127,21 +120,6 @@ impl IndexBuilder<StandardCodec> {
index_settings: IndexSettings::default(),
tokenizer_manager: TokenizerManager::default(),
fast_field_tokenizer_manager: TokenizerManager::default(),
codec: StandardCodec,
}
}
}
impl<Codec: crate::codec::Codec> IndexBuilder<Codec> {
/// Set the codec
#[must_use]
pub fn codec<NewCodec: crate::codec::Codec>(self, codec: NewCodec) -> IndexBuilder<NewCodec> {
IndexBuilder {
schema: self.schema,
index_settings: self.index_settings,
tokenizer_manager: self.tokenizer_manager,
fast_field_tokenizer_manager: self.fast_field_tokenizer_manager,
codec,
}
}
@@ -176,7 +154,7 @@ impl<Codec: crate::codec::Codec> IndexBuilder<Codec> {
/// The index will be allocated in anonymous memory.
/// This is useful for indexing small set of documents
/// for instances like unit test or temporary in memory index.
pub fn create_in_ram(self) -> Result<Index<Codec>, TantivyError> {
pub fn create_in_ram(self) -> Result<Index, TantivyError> {
let ram_directory = RamDirectory::create();
self.create(ram_directory)
}
@@ -187,7 +165,7 @@ impl<Codec: crate::codec::Codec> IndexBuilder<Codec> {
/// If a previous index was in this directory, it returns an
/// [`TantivyError::IndexAlreadyExists`] error.
#[cfg(feature = "mmap")]
pub fn create_in_dir<P: AsRef<Path>>(self, directory_path: P) -> crate::Result<Index<Codec>> {
pub fn create_in_dir<P: AsRef<Path>>(self, directory_path: P) -> crate::Result<Index> {
let mmap_directory: Box<dyn Directory> = Box::new(MmapDirectory::open(directory_path)?);
if Index::exists(&*mmap_directory)? {
return Err(TantivyError::IndexAlreadyExists);
@@ -208,7 +186,7 @@ impl<Codec: crate::codec::Codec> IndexBuilder<Codec> {
self,
dir: impl Into<Box<dyn Directory>>,
mem_budget: usize,
) -> crate::Result<SingleSegmentIndexWriter<Codec, D>> {
) -> crate::Result<SingleSegmentIndexWriter<D>> {
let index = self.create(dir)?;
let index_simple_writer = SingleSegmentIndexWriter::new(index, mem_budget)?;
Ok(index_simple_writer)
@@ -224,7 +202,7 @@ impl<Codec: crate::codec::Codec> IndexBuilder<Codec> {
/// For other unit tests, prefer the [`RamDirectory`], see:
/// [`IndexBuilder::create_in_ram()`].
#[cfg(feature = "mmap")]
pub fn create_from_tempdir(self) -> crate::Result<Index<Codec>> {
pub fn create_from_tempdir(self) -> crate::Result<Index> {
let mmap_directory: Box<dyn Directory> = Box::new(MmapDirectory::create_from_tempdir()?);
self.create(mmap_directory)
}
@@ -237,15 +215,12 @@ impl<Codec: crate::codec::Codec> IndexBuilder<Codec> {
}
/// Opens or creates a new index in the provided directory
pub fn open_or_create<T: Into<Box<dyn Directory>>>(
self,
dir: T,
) -> crate::Result<Index<Codec>> {
pub fn open_or_create<T: Into<Box<dyn Directory>>>(self, dir: T) -> crate::Result<Index> {
let dir: Box<dyn Directory> = dir.into();
if !Index::exists(&*dir)? {
return self.create(dir);
}
let mut index: Index<Codec> = Index::<Codec>::open_with_codec(dir)?;
let mut index = Index::open(dir)?;
index.set_tokenizers(self.tokenizer_manager.clone());
if index.schema() == self.get_expect_schema()? {
Ok(index)
@@ -269,25 +244,18 @@ impl<Codec: crate::codec::Codec> IndexBuilder<Codec> {
/// Creates a new index given an implementation of the trait `Directory`.
///
/// If a directory previously existed, it will be erased.
pub fn create<T: Into<Box<dyn Directory>>>(self, dir: T) -> crate::Result<Index<Codec>> {
self.create_avoid_monomorphization(dir.into())
}
fn create_avoid_monomorphization(self, dir: Box<dyn Directory>) -> crate::Result<Index<Codec>> {
fn create<T: Into<Box<dyn Directory>>>(self, dir: T) -> crate::Result<Index> {
self.validate()?;
let dir = dir.into();
let directory = ManagedDirectory::wrap(dir)?;
let codec: CodecConfiguration = CodecConfiguration::from(&self.codec);
save_new_metas(
self.get_expect_schema()?,
self.index_settings.clone(),
&directory,
codec,
)?;
let schema = self.get_expect_schema()?;
let mut metas = IndexMeta::with_schema_and_codec(schema, &self.codec);
let mut metas = IndexMeta::with_schema(self.get_expect_schema()?);
metas.index_settings = self.index_settings;
let mut index: Index<Codec> =
Index::<Codec>::open_from_metas(directory, &metas, SegmentMetaInventory::default())?;
let mut index = Index::open_from_metas(directory, &metas, SegmentMetaInventory::default());
index.set_tokenizers(self.tokenizer_manager);
index.set_fast_field_tokenizers(self.fast_field_tokenizer_manager);
Ok(index)
@@ -296,7 +264,7 @@ impl<Codec: crate::codec::Codec> IndexBuilder<Codec> {
/// Search Index
#[derive(Clone)]
pub struct Index<Codec: crate::codec::Codec = crate::codec::StandardCodec> {
pub struct Index {
directory: ManagedDirectory,
schema: Schema,
settings: IndexSettings,
@@ -304,7 +272,6 @@ pub struct Index<Codec: crate::codec::Codec = crate::codec::StandardCodec> {
tokenizers: TokenizerManager,
fast_field_tokenizers: TokenizerManager,
inventory: SegmentMetaInventory,
codec: Codec,
}
impl Index {
@@ -312,6 +279,41 @@ impl Index {
pub fn builder() -> IndexBuilder {
IndexBuilder::new()
}
/// Examines the directory to see if it contains an index.
///
/// Effectively, it only checks for the presence of the `meta.json` file.
pub fn exists(dir: &dyn Directory) -> Result<bool, OpenReadError> {
dir.exists(&META_FILEPATH)
}
/// Accessor to the search executor.
///
/// This pool is used by default when calling `searcher.search(...)`
/// to perform search on the individual segments.
///
/// By default the executor is single thread, and simply runs in the calling thread.
pub fn search_executor(&self) -> &Executor {
&self.executor
}
/// Replace the default single thread search executor pool
/// by a thread pool with a given number of threads.
pub fn set_multithread_executor(&mut self, num_threads: usize) -> crate::Result<()> {
self.executor = Executor::multi_thread(num_threads, "tantivy-search-")?;
Ok(())
}
/// Custom thread pool by a outer thread pool.
pub fn set_executor(&mut self, executor: Executor) {
self.executor = executor;
}
/// Replace the default single thread search executor pool
/// by a thread pool with as many threads as there are CPUs on the system.
pub fn set_default_multithread_executor(&mut self) -> crate::Result<()> {
let default_num_threads = available_parallelism()?.get();
self.set_multithread_executor(default_num_threads)
}
/// Creates a new index using the [`RamDirectory`].
///
@@ -322,13 +324,6 @@ impl Index {
IndexBuilder::new().schema(schema).create_in_ram().unwrap()
}
/// Examines the directory to see if it contains an index.
///
/// Effectively, it only checks for the presence of the `meta.json` file.
pub fn exists(directory: &dyn Directory) -> Result<bool, OpenReadError> {
directory.exists(&META_FILEPATH)
}
/// Creates a new index in a given filepath.
/// The index will use the [`MmapDirectory`].
///
@@ -375,108 +370,20 @@ impl Index {
schema: Schema,
settings: IndexSettings,
) -> crate::Result<Index> {
Self::create_to_avoid_monomorphization(dir.into(), schema, settings)
}
fn create_to_avoid_monomorphization(
dir: Box<dyn Directory>,
schema: Schema,
settings: IndexSettings,
) -> crate::Result<Index> {
let dir: Box<dyn Directory> = dir.into();
let mut builder = IndexBuilder::new().schema(schema);
builder = builder.settings(settings);
builder.create(dir)
}
/// Opens a new directory from an index path.
#[cfg(feature = "mmap")]
pub fn open_in_dir<P: AsRef<Path>>(directory_path: P) -> crate::Result<Index> {
Self::open_in_dir_to_avoid_monomorphization(directory_path.as_ref())
}
#[cfg(feature = "mmap")]
#[inline(never)]
fn open_in_dir_to_avoid_monomorphization(directory_path: &Path) -> crate::Result<Index> {
let mmap_directory = MmapDirectory::open(directory_path)?;
Index::open(mmap_directory)
}
/// Open the index using the provided directory
pub fn open<T: Into<Box<dyn Directory>>>(directory: T) -> crate::Result<Index> {
Index::<StandardCodec>::open_with_codec(directory.into())
}
}
impl<Codec: crate::codec::Codec> Index<Codec> {
/// Returns a version of this index with the standard codec.
/// This is useful when you need to pass the index to APIs that
/// don't care about the codec (e.g., for reading).
pub(crate) fn with_standard_codec(&self) -> Index<StandardCodec> {
Index {
directory: self.directory.clone(),
schema: self.schema.clone(),
settings: self.settings.clone(),
executor: self.executor.clone(),
tokenizers: self.tokenizers.clone(),
fast_field_tokenizers: self.fast_field_tokenizers.clone(),
inventory: self.inventory.clone(),
codec: StandardCodec,
}
}
/// Open the index using the provided directory
#[inline(never)]
pub fn open_with_codec(directory: Box<dyn Directory>) -> crate::Result<Index<Codec>> {
let directory = ManagedDirectory::wrap(directory)?;
let inventory = SegmentMetaInventory::default();
let metas = load_metas(&directory, &inventory)?;
let index: Index<Codec> = Index::<Codec>::open_from_metas(directory, &metas, inventory)?;
Ok(index)
}
/// Accessor to the codec.
pub fn codec(&self) -> &Codec {
&self.codec
}
/// Accessor to the search executor.
///
/// This pool is used by default when calling `searcher.search(...)`
/// to perform search on the individual segments.
///
/// By default the executor is single thread, and simply runs in the calling thread.
pub fn search_executor(&self) -> &Executor {
&self.executor
}
/// Replace the default single thread search executor pool
/// by a thread pool with a given number of threads.
pub fn set_multithread_executor(&mut self, num_threads: usize) -> crate::Result<()> {
self.executor = Executor::multi_thread(num_threads, "tantivy-search-")?;
Ok(())
}
/// Custom thread pool by a outer thread pool.
pub fn set_executor(&mut self, executor: Executor) {
self.executor = executor;
}
/// Replace the default single thread search executor pool
/// by a thread pool with as many threads as there are CPUs on the system.
pub fn set_default_multithread_executor(&mut self) -> crate::Result<()> {
let default_num_threads = available_parallelism()?.get();
self.set_multithread_executor(default_num_threads)
}
/// Creates a new index given a directory and an [`IndexMeta`].
fn open_from_metas<C: crate::codec::Codec>(
fn open_from_metas(
directory: ManagedDirectory,
metas: &IndexMeta,
inventory: SegmentMetaInventory,
) -> crate::Result<Index<C>> {
) -> Index {
let schema = metas.schema.clone();
let codec = metas.codec.to_codec::<C>()?;
Ok(Index {
Index {
settings: metas.index_settings.clone(),
directory,
schema,
@@ -484,8 +391,7 @@ impl<Codec: crate::codec::Codec> Index<Codec> {
fast_field_tokenizers: TokenizerManager::default(),
executor: Executor::single_thread(),
inventory,
codec,
})
}
}
/// Setter for the tokenizer manager.
@@ -541,7 +447,7 @@ impl<Codec: crate::codec::Codec> Index<Codec> {
/// Create a default [`IndexReader`] for the given index.
///
/// See [`Index.reader_builder()`].
pub fn reader(&self) -> crate::Result<IndexReader<Codec>> {
pub fn reader(&self) -> crate::Result<IndexReader> {
self.reader_builder().try_into()
}
@@ -549,10 +455,17 @@ impl<Codec: crate::codec::Codec> Index<Codec> {
///
/// Most project should create at most one reader for a given index.
/// This method is typically called only once per `Index` instance.
pub fn reader_builder(&self) -> IndexReaderBuilder<Codec> {
pub fn reader_builder(&self) -> IndexReaderBuilder {
IndexReaderBuilder::new(self.clone())
}
/// Opens a new directory from an index path.
#[cfg(feature = "mmap")]
pub fn open_in_dir<P: AsRef<Path>>(directory_path: P) -> crate::Result<Index> {
let mmap_directory = MmapDirectory::open(directory_path)?;
Index::open(mmap_directory)
}
/// Returns the list of the segment metas tracked by the index.
///
/// Such segments can of course be part of the index,
@@ -593,6 +506,16 @@ impl<Codec: crate::codec::Codec> Index<Codec> {
self.inventory.new_segment_meta(segment_id, max_doc)
}
/// Open the index using the provided directory
pub fn open<T: Into<Box<dyn Directory>>>(directory: T) -> crate::Result<Index> {
let directory = directory.into();
let directory = ManagedDirectory::wrap(directory)?;
let inventory = SegmentMetaInventory::default();
let metas = load_metas(&directory, &inventory)?;
let index = Index::open_from_metas(directory, &metas, inventory);
Ok(index)
}
/// Reads the index meta file from the directory.
pub fn load_metas(&self) -> crate::Result<IndexMeta> {
load_metas(self.directory(), &self.inventory)
@@ -616,7 +539,7 @@ impl<Codec: crate::codec::Codec> Index<Codec> {
pub fn writer_with_options<D: Document>(
&self,
options: IndexWriterOptions,
) -> crate::Result<IndexWriter<Codec, D>> {
) -> crate::Result<IndexWriter<D>> {
let directory_lock = self
.directory
.acquire_lock(&INDEX_WRITER_LOCK)
@@ -658,7 +581,7 @@ impl<Codec: crate::codec::Codec> Index<Codec> {
&self,
num_threads: usize,
overall_memory_budget_in_bytes: usize,
) -> crate::Result<IndexWriter<Codec, D>> {
) -> crate::Result<IndexWriter<D>> {
let memory_arena_in_bytes_per_thread = overall_memory_budget_in_bytes / num_threads;
let options = IndexWriterOptions::builder()
.num_worker_threads(num_threads)
@@ -672,7 +595,7 @@ impl<Codec: crate::codec::Codec> Index<Codec> {
/// That index writer only simply has a single thread and a memory budget of 15 MB.
/// Using a single thread gives us a deterministic allocation of DocId.
#[cfg(test)]
pub fn writer_for_tests<D: Document>(&self) -> crate::Result<IndexWriter<Codec, D>> {
pub fn writer_for_tests<D: Document>(&self) -> crate::Result<IndexWriter<D>> {
self.writer_with_num_threads(1, MEMORY_BUDGET_NUM_BYTES_MIN)
}
@@ -690,7 +613,7 @@ impl<Codec: crate::codec::Codec> Index<Codec> {
pub fn writer<D: Document>(
&self,
memory_budget_in_bytes: usize,
) -> crate::Result<IndexWriter<Codec, D>> {
) -> crate::Result<IndexWriter<D>> {
let mut num_threads = std::cmp::min(available_parallelism()?.get(), MAX_NUM_THREAD);
let memory_budget_num_bytes_per_thread = memory_budget_in_bytes / num_threads;
if memory_budget_num_bytes_per_thread < MEMORY_BUDGET_NUM_BYTES_MIN {
@@ -717,7 +640,7 @@ impl<Codec: crate::codec::Codec> Index<Codec> {
}
/// Returns the list of segments that are searchable
pub fn searchable_segments(&self) -> crate::Result<Vec<Segment<Codec>>> {
pub fn searchable_segments(&self) -> crate::Result<Vec<Segment>> {
Ok(self
.searchable_segment_metas()?
.into_iter()
@@ -726,12 +649,12 @@ impl<Codec: crate::codec::Codec> Index<Codec> {
}
#[doc(hidden)]
pub fn segment(&self, segment_meta: SegmentMeta) -> Segment<Codec> {
pub fn segment(&self, segment_meta: SegmentMeta) -> Segment {
Segment::for_index(self.clone(), segment_meta)
}
/// Creates a new segment.
pub fn new_segment(&self) -> Segment<Codec> {
pub fn new_segment(&self) -> Segment {
let segment_meta = self
.inventory
.new_segment_meta(SegmentId::generate_random(), 0);
@@ -785,7 +708,7 @@ impl<Codec: crate::codec::Codec> Index<Codec> {
}
impl fmt::Debug for Index {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Index({:?})", self.directory)
}
}

View File

@@ -5,8 +5,7 @@ use std::path::PathBuf;
use serde::{Deserialize, Serialize};
use super::SegmentComponent;
use crate::codec::Codec;
use crate::index::{CodecConfiguration, SegmentId};
use crate::index::SegmentId;
use crate::schema::Schema;
use crate::store::Compressor;
use crate::{Inventory, Opstamp, TrackedObject};
@@ -287,10 +286,8 @@ pub struct IndexMeta {
/// This payload is entirely unused by tantivy.
#[serde(skip_serializing_if = "Option::is_none")]
pub payload: Option<String>,
/// Codec configuration for the index.
#[serde(skip_serializing_if = "CodecConfiguration::is_standard")]
pub codec: CodecConfiguration,
}
#[derive(Deserialize, Debug)]
struct UntrackedIndexMeta {
pub segments: Vec<InnerSegmentMeta>,
@@ -300,8 +297,6 @@ struct UntrackedIndexMeta {
pub opstamp: Opstamp,
#[serde(skip_serializing_if = "Option::is_none")]
pub payload: Option<String>,
#[serde(default)]
pub codec: CodecConfiguration,
}
impl UntrackedIndexMeta {
@@ -316,7 +311,6 @@ impl UntrackedIndexMeta {
schema: self.schema,
opstamp: self.opstamp,
payload: self.payload,
codec: self.codec,
}
}
}
@@ -327,14 +321,13 @@ impl IndexMeta {
///
/// This new index does not contains any segments.
/// Opstamp will the value `0u64`.
pub fn with_schema_and_codec<C: Codec>(schema: Schema, codec: &C) -> IndexMeta {
pub fn with_schema(schema: Schema) -> IndexMeta {
IndexMeta {
index_settings: IndexSettings::default(),
segments: vec![],
schema,
opstamp: 0u64,
payload: None,
codec: CodecConfiguration::from(codec),
}
}
@@ -385,38 +378,14 @@ mod tests {
schema,
opstamp: 0u64,
payload: None,
codec: Default::default(),
};
let json_value: serde_json::Value =
serde_json::to_value(&index_metas).expect("serialization failed");
let json = serde_json::ser::to_string(&index_metas).expect("serialization failed");
assert_eq!(
&json_value,
&serde_json::json!(
{
"index_settings": {
"docstore_compression": "none",
"docstore_blocksize": 16384
},
"segments": [],
"schema": [
{
"name": "text",
"type": "text",
"options": {
"indexing": {
"record": "position",
"fieldnorms": true,
"tokenizer": "default"
},
"stored": false,
"fast": false
}
}
],
"opstamp": 0
})
json,
r#"{"index_settings":{"docstore_compression":"none","docstore_blocksize":16384},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#
);
let deser_meta: UntrackedIndexMeta = serde_json::from_value(json_value).unwrap();
let deser_meta: UntrackedIndexMeta = serde_json::from_str(&json).unwrap();
assert_eq!(index_metas.index_settings, deser_meta.index_settings);
assert_eq!(index_metas.schema, deser_meta.schema);
assert_eq!(index_metas.opstamp, deser_meta.opstamp);
@@ -442,39 +411,14 @@ mod tests {
schema,
opstamp: 0u64,
payload: None,
codec: Default::default(),
};
let json_value = serde_json::to_value(&index_metas).expect("serialization failed");
let json = serde_json::ser::to_string(&index_metas).expect("serialization failed");
assert_eq!(
&json_value,
&serde_json::json!(
{
"index_settings": {
"docstore_compression": "zstd(compression_level=4)",
"docstore_blocksize": 1000000
},
"segments": [],
"schema": [
{
"name": "text",
"type": "text",
"options": {
"indexing": {
"record": "position",
"fieldnorms": true,
"tokenizer": "default"
},
"stored": false,
"fast": false
}
}
],
"opstamp": 0
}
)
json,
r#"{"index_settings":{"docstore_compression":"zstd(compression_level=4)","docstore_blocksize":1000000},"segments":[],"schema":[{"name":"text","type":"text","options":{"indexing":{"record":"position","fieldnorms":true,"tokenizer":"default"},"stored":false,"fast":false}}],"opstamp":0}"#
);
let deser_meta: UntrackedIndexMeta = serde_json::from_value(json_value).unwrap();
let deser_meta: UntrackedIndexMeta = serde_json::from_str(&json).unwrap();
assert_eq!(index_metas.index_settings, deser_meta.index_settings);
assert_eq!(index_metas.schema, deser_meta.schema);
assert_eq!(index_metas.opstamp, deser_meta.opstamp);

View File

@@ -1,8 +1,7 @@
use std::io;
use std::sync::Arc;
use common::json_path_writer::JSON_END_OF_PATH;
use common::{BinarySerializable, ByteCount, OwnedBytes};
use common::{BinarySerializable, ByteCount};
#[cfg(feature = "quickwit")]
use futures_util::{FutureExt, StreamExt, TryStreamExt};
#[cfg(feature = "quickwit")]
@@ -10,13 +9,9 @@ use itertools::Itertools;
#[cfg(feature = "quickwit")]
use tantivy_fst::automaton::{AlwaysMatch, Automaton};
use crate::codec::postings::PostingsCodec;
use crate::codec::{Codec, ObjectSafeCodec, StandardCodec};
use crate::directory::FileSlice;
use crate::fieldnorm::FieldNormReader;
use crate::postings::{Postings, TermInfo};
use crate::query::term_query::TermScorer;
use crate::query::{Bm25Weight, PhraseScorer, Scorer};
use crate::positions::PositionReader;
use crate::postings::{BlockSegmentPostings, SegmentPostings, TermInfo};
use crate::schema::{IndexRecordOption, Term, Type};
use crate::termdict::TermDictionary;
@@ -38,7 +33,6 @@ pub struct InvertedIndexReader {
positions_file_slice: FileSlice,
record_option: IndexRecordOption,
total_num_tokens: u64,
codec: Arc<dyn ObjectSafeCodec>,
}
/// Object that records the amount of space used by a field in an inverted index.
@@ -74,7 +68,6 @@ impl InvertedIndexReader {
postings_file_slice: FileSlice,
positions_file_slice: FileSlice,
record_option: IndexRecordOption,
codec: Arc<dyn ObjectSafeCodec>,
) -> io::Result<InvertedIndexReader> {
let (total_num_tokens_slice, postings_body) = postings_file_slice.split(8);
let total_num_tokens = u64::deserialize(&mut total_num_tokens_slice.read_bytes()?)?;
@@ -84,7 +77,6 @@ impl InvertedIndexReader {
positions_file_slice,
record_option,
total_num_tokens,
codec,
})
}
@@ -97,7 +89,6 @@ impl InvertedIndexReader {
positions_file_slice: FileSlice::empty(),
record_option,
total_num_tokens: 0u64,
codec: Arc::new(StandardCodec),
}
}
@@ -169,98 +160,61 @@ impl InvertedIndexReader {
Ok(fields)
}
pub(crate) fn new_term_scorer_specialized<C: Codec>(
&self,
term_info: &TermInfo,
option: IndexRecordOption,
fieldnorm_reader: FieldNormReader,
similarity_weight: Bm25Weight,
codec: &C,
) -> io::Result<TermScorer<<<C as Codec>::PostingsCodec as PostingsCodec>::Postings>> {
let postings = self.read_postings_from_terminfo_specialized(term_info, option, codec)?;
let term_scorer = TermScorer::new(postings, fieldnorm_reader, similarity_weight);
Ok(term_scorer)
}
pub(crate) fn new_phrase_scorer_type_specialized<C: Codec>(
&self,
term_infos: &[(usize, TermInfo)],
similarity_weight_opt: Option<Bm25Weight>,
fieldnorm_reader: FieldNormReader,
slop: u32,
codec: &C,
) -> io::Result<PhraseScorer<<<C as Codec>::PostingsCodec as PostingsCodec>::Postings>> {
let mut offset_and_term_postings: Vec<(
usize,
<<C as Codec>::PostingsCodec as PostingsCodec>::Postings,
)> = Vec::with_capacity(term_infos.len());
for (offset, term_info) in term_infos {
let postings = self.read_postings_from_terminfo_specialized(
term_info,
IndexRecordOption::WithFreqsAndPositions,
codec,
)?;
offset_and_term_postings.push((*offset, postings));
}
let phrase_scorer = PhraseScorer::new(
offset_and_term_postings,
similarity_weight_opt,
fieldnorm_reader,
slop,
);
Ok(phrase_scorer)
}
/// Build a new term scorer.
pub fn new_term_scorer(
&self,
term_info: &TermInfo,
option: IndexRecordOption,
fieldnorm_reader: FieldNormReader,
similarity_weight: Bm25Weight,
) -> io::Result<Box<dyn Scorer>> {
let term_scorer = self.codec.load_term_scorer_type_erased(
term_info,
option,
self,
fieldnorm_reader,
similarity_weight,
)?;
Ok(term_scorer)
}
/// Returns a postings object specific with a concrete type.
/// Resets the block segment to another position of the postings
/// file.
///
/// This requires you to provied the actual codec.
pub fn read_postings_from_terminfo_specialized<C: Codec>(
/// This is useful for enumerating through a list of terms,
/// and consuming the associated posting lists while avoiding
/// reallocating a [`BlockSegmentPostings`].
///
/// # Warning
///
/// This does not reset the positions list.
pub fn reset_block_postings_from_terminfo(
&self,
term_info: &TermInfo,
block_postings: &mut BlockSegmentPostings,
) -> io::Result<()> {
let postings_slice = self
.postings_file_slice
.slice(term_info.postings_range.clone());
let postings_bytes = postings_slice.read_bytes()?;
block_postings.reset(term_info.doc_freq, postings_bytes)?;
Ok(())
}
/// Returns a block postings given a `Term`.
/// This method is for an advanced usage only.
///
/// Most users should prefer using [`Self::read_postings()`] instead.
pub fn read_block_postings(
&self,
term: &Term,
option: IndexRecordOption,
codec: &C,
) -> io::Result<<<C as Codec>::PostingsCodec as PostingsCodec>::Postings> {
let option = option.downgrade(self.record_option);
) -> io::Result<Option<BlockSegmentPostings>> {
self.get_term_info(term)?
.map(move |term_info| self.read_block_postings_from_terminfo(&term_info, option))
.transpose()
}
/// Returns a block postings given a `term_info`.
/// This method is for an advanced usage only.
///
/// Most users should prefer using [`Self::read_postings()`] instead.
pub fn read_block_postings_from_terminfo(
&self,
term_info: &TermInfo,
requested_option: IndexRecordOption,
) -> io::Result<BlockSegmentPostings> {
let postings_data = self
.postings_file_slice
.slice(term_info.postings_range.clone())
.read_bytes()?;
let positions_data: Option<OwnedBytes> = if option.has_positions() {
let positions_data = self
.positions_file_slice
.slice(term_info.positions_range.clone())
.read_bytes()?;
Some(positions_data)
} else {
None
};
let postings: <<C as Codec>::PostingsCodec as PostingsCodec>::Postings =
codec.postings_codec().load_postings(
term_info.doc_freq,
postings_data,
self.record_option,
option,
positions_data,
)?;
Ok(postings)
.slice(term_info.postings_range.clone());
BlockSegmentPostings::open(
term_info.doc_freq,
postings_data,
self.record_option,
requested_option,
)
}
/// Returns a posting object given a `term_info`.
@@ -271,9 +225,25 @@ impl InvertedIndexReader {
&self,
term_info: &TermInfo,
option: IndexRecordOption,
) -> io::Result<Box<dyn Postings>> {
self.codec
.load_postings_type_erased(term_info, option, self)
) -> io::Result<SegmentPostings> {
let option = option.downgrade(self.record_option);
let block_postings = self.read_block_postings_from_terminfo(term_info, option)?;
let position_reader = {
if option.has_positions() {
let positions_data = self
.positions_file_slice
.read_bytes_slice(term_info.positions_range.clone())?;
let position_reader = PositionReader::open(positions_data)?;
Some(position_reader)
} else {
None
}
};
Ok(SegmentPostings::from_block_postings(
block_postings,
position_reader,
))
}
/// Returns the total number of tokens recorded for all documents
@@ -296,7 +266,7 @@ impl InvertedIndexReader {
&self,
term: &Term,
option: IndexRecordOption,
) -> io::Result<Option<Box<dyn Postings>>> {
) -> io::Result<Option<SegmentPostings>> {
self.get_term_info(term)?
.map(move |term_info| self.read_postings_from_terminfo(&term_info, option))
.transpose()

View File

@@ -2,7 +2,6 @@
//!
//! It contains `Index` and `Segment`, where a `Index` consists of one or more `Segment`s.
mod codec_configuration;
mod index;
mod index_meta;
mod inverted_index_reader;
@@ -11,7 +10,6 @@ mod segment_component;
mod segment_id;
mod segment_reader;
pub use self::codec_configuration::CodecConfiguration;
pub use self::index::{Index, IndexBuilder};
pub(crate) use self::index_meta::SegmentMetaInventory;
pub use self::index_meta::{IndexMeta, IndexSettings, Order, SegmentMeta};

View File

@@ -2,7 +2,6 @@ use std::fmt;
use std::path::PathBuf;
use super::SegmentComponent;
use crate::codec::StandardCodec;
use crate::directory::error::{OpenReadError, OpenWriteError};
use crate::directory::{Directory, FileSlice, WritePtr};
use crate::index::{Index, SegmentId, SegmentMeta};
@@ -11,25 +10,25 @@ use crate::Opstamp;
/// A segment is a piece of the index.
#[derive(Clone)]
pub struct Segment<C: crate::codec::Codec = StandardCodec> {
index: Index<C>,
pub struct Segment {
index: Index,
meta: SegmentMeta,
}
impl<C: crate::codec::Codec> fmt::Debug for Segment<C> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
impl fmt::Debug for Segment {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Segment({:?})", self.id().uuid_string())
}
}
impl<C: crate::codec::Codec> Segment<C> {
impl Segment {
/// Creates a new segment given an `Index` and a `SegmentId`
pub(crate) fn for_index(index: Index<C>, meta: SegmentMeta) -> Segment<C> {
pub(crate) fn for_index(index: Index, meta: SegmentMeta) -> Segment {
Segment { index, meta }
}
/// Returns the index the segment belongs to.
pub fn index(&self) -> &Index<C> {
pub fn index(&self) -> &Index {
&self.index
}
@@ -47,7 +46,7 @@ impl<C: crate::codec::Codec> Segment<C> {
///
/// This method is only used when updating `max_doc` from 0
/// as we finalize a fresh new segment.
pub fn with_max_doc(self, max_doc: u32) -> Segment<C> {
pub fn with_max_doc(self, max_doc: u32) -> Segment {
Segment {
index: self.index,
meta: self.meta.with_max_doc(max_doc),
@@ -56,7 +55,7 @@ impl<C: crate::codec::Codec> Segment<C> {
#[doc(hidden)]
#[must_use]
pub fn with_delete_meta(self, num_deleted_docs: u32, opstamp: Opstamp) -> Segment<C> {
pub fn with_delete_meta(self, num_deleted_docs: u32, opstamp: Opstamp) -> Segment {
Segment {
index: self.index,
meta: self.meta.with_delete_meta(num_deleted_docs, opstamp),

View File

@@ -6,7 +6,6 @@ use common::{ByteCount, HasLen};
use fnv::FnvHashMap;
use itertools::Itertools;
use crate::codec::ObjectSafeCodec;
use crate::directory::{CompositeFile, FileSlice};
use crate::error::DataCorruption;
use crate::fastfield::{intersect_alive_bitsets, AliveBitSet, FacetReader, FastFieldReaders};
@@ -48,7 +47,6 @@ pub struct SegmentReader {
store_file: FileSlice,
alive_bitset_opt: Option<AliveBitSet>,
schema: Schema,
codec: Arc<dyn ObjectSafeCodec>,
}
impl SegmentReader {
@@ -69,11 +67,6 @@ impl SegmentReader {
&self.schema
}
/// Returns the index codec.
pub fn codec(&self) -> &dyn ObjectSafeCodec {
&*self.codec
}
/// Return the number of documents that have been
/// deleted in the segment.
pub fn num_deleted_docs(&self) -> DocId {
@@ -147,16 +140,15 @@ impl SegmentReader {
}
/// Open a new segment for reading.
pub fn open<C: crate::codec::Codec>(segment: &Segment<C>) -> crate::Result<SegmentReader> {
pub fn open(segment: &Segment) -> crate::Result<SegmentReader> {
Self::open_with_custom_alive_set(segment, None)
}
/// Open a new segment for reading.
pub fn open_with_custom_alive_set<C: crate::codec::Codec>(
segment: &Segment<C>,
pub fn open_with_custom_alive_set(
segment: &Segment,
custom_bitset: Option<AliveBitSet>,
) -> crate::Result<SegmentReader> {
let codec: Arc<dyn ObjectSafeCodec> = Arc::new(segment.index().codec().clone());
let termdict_file = segment.open_read(SegmentComponent::Terms)?;
let termdict_composite = CompositeFile::open(&termdict_file)?;
@@ -212,7 +204,6 @@ impl SegmentReader {
alive_bitset_opt,
positions_composite,
schema,
codec,
})
}
@@ -282,7 +273,6 @@ impl SegmentReader {
postings_file,
positions_file,
record_option,
self.codec.clone(),
)?);
// by releasing the lock in between, we may end up opening the inverting index

View File

@@ -9,7 +9,6 @@ use smallvec::smallvec;
use super::operation::{AddOperation, UserOperation};
use super::segment_updater::SegmentUpdater;
use super::{AddBatch, AddBatchReceiver, AddBatchSender, PreparedCommit};
use crate::codec::{Codec, StandardCodec};
use crate::directory::{DirectoryLock, GarbageCollectionResult, TerminatingWrite};
use crate::error::TantivyError;
use crate::fastfield::write_alive_bitset;
@@ -69,12 +68,12 @@ pub struct IndexWriterOptions {
/// indexing queue.
/// Each indexing thread builds its own independent [`Segment`], via
/// a `SegmentWriter` object.
pub struct IndexWriter<C: Codec = StandardCodec, D: Document = TantivyDocument> {
pub struct IndexWriter<D: Document = TantivyDocument> {
// the lock is just used to bind the
// lifetime of the lock with that of the IndexWriter.
_directory_lock: Option<DirectoryLock>,
index: Index<C>,
index: Index,
options: IndexWriterOptions,
@@ -83,7 +82,7 @@ pub struct IndexWriter<C: Codec = StandardCodec, D: Document = TantivyDocument>
index_writer_status: IndexWriterStatus<D>,
operation_sender: AddBatchSender<D>,
segment_updater: SegmentUpdater<C>,
segment_updater: SegmentUpdater,
worker_id: usize,
@@ -129,8 +128,8 @@ fn compute_deleted_bitset(
/// is `==` target_opstamp.
/// For instance, there was no delete operation between the state of the `segment_entry` and
/// the `target_opstamp`, `segment_entry` is not updated.
pub fn advance_deletes<C: Codec>(
mut segment: Segment<C>,
pub fn advance_deletes(
mut segment: Segment,
segment_entry: &mut SegmentEntry,
target_opstamp: Opstamp,
) -> crate::Result<()> {
@@ -180,11 +179,11 @@ pub fn advance_deletes<C: Codec>(
Ok(())
}
fn index_documents<C: crate::codec::Codec, D: Document>(
fn index_documents<D: Document>(
memory_budget: usize,
segment: Segment<C>,
segment: Segment,
grouped_document_iterator: &mut dyn Iterator<Item = AddBatch<D>>,
segment_updater: &SegmentUpdater<C>,
segment_updater: &SegmentUpdater,
mut delete_cursor: DeleteCursor,
) -> crate::Result<()> {
let mut segment_writer = SegmentWriter::for_segment(memory_budget, segment.clone())?;
@@ -227,8 +226,8 @@ fn index_documents<C: crate::codec::Codec, D: Document>(
}
/// `doc_opstamps` is required to be non-empty.
fn apply_deletes<C: crate::codec::Codec>(
segment: &Segment<C>,
fn apply_deletes(
segment: &Segment,
delete_cursor: &mut DeleteCursor,
doc_opstamps: &[Opstamp],
) -> crate::Result<Option<BitSet>> {
@@ -263,7 +262,7 @@ fn apply_deletes<C: crate::codec::Codec>(
})
}
impl<C: Codec, D: Document> IndexWriter<C, D> {
impl<D: Document> IndexWriter<D> {
/// Create a new index writer. Attempts to acquire a lockfile.
///
/// The lockfile should be deleted on drop, but it is possible
@@ -279,7 +278,7 @@ impl<C: Codec, D: Document> IndexWriter<C, D> {
/// If the memory arena per thread is too small or too big, returns
/// `TantivyError::InvalidArgument`
pub(crate) fn new(
index: &Index<C>,
index: &Index,
options: IndexWriterOptions,
directory_lock: DirectoryLock,
) -> crate::Result<Self> {
@@ -346,7 +345,7 @@ impl<C: Codec, D: Document> IndexWriter<C, D> {
}
/// Accessor to the index.
pub fn index(&self) -> &Index<C> {
pub fn index(&self) -> &Index {
&self.index
}
@@ -394,7 +393,7 @@ impl<C: Codec, D: Document> IndexWriter<C, D> {
/// It is safe to start writing file associated with the new `Segment`.
/// These will not be garbage collected as long as an instance object of
/// `SegmentMeta` object associated with the new `Segment` is "alive".
pub fn new_segment(&self) -> Segment<C> {
pub fn new_segment(&self) -> Segment {
self.index.new_segment()
}
@@ -616,7 +615,7 @@ impl<C: Codec, D: Document> IndexWriter<C, D> {
/// It is also possible to add a payload to the `commit`
/// using this API.
/// See [`PreparedCommit::set_payload()`].
pub fn prepare_commit(&mut self) -> crate::Result<PreparedCommit<'_, C, D>> {
pub fn prepare_commit(&mut self) -> crate::Result<PreparedCommit<'_, D>> {
// Here, because we join all of the worker threads,
// all of the segment update for this commit have been
// sent.
@@ -666,7 +665,7 @@ impl<C: Codec, D: Document> IndexWriter<C, D> {
self.prepare_commit()?.commit()
}
pub(crate) fn segment_updater(&self) -> &SegmentUpdater<C> {
pub(crate) fn segment_updater(&self) -> &SegmentUpdater {
&self.segment_updater
}
@@ -805,7 +804,7 @@ impl<C: Codec, D: Document> IndexWriter<C, D> {
}
}
impl<C: Codec, D: Document> Drop for IndexWriter<C, D> {
impl<D: Document> Drop for IndexWriter<D> {
fn drop(&mut self) {
self.segment_updater.kill();
self.drop_sender();

View File

@@ -1,10 +1,9 @@
#[cfg(test)]
mod tests {
use crate::codec::StandardCodec;
use crate::collector::TopDocs;
use crate::fastfield::AliveBitSet;
use crate::index::Index;
use crate::postings::{DocFreq, Postings};
use crate::postings::Postings;
use crate::query::QueryParser;
use crate::schema::{
self, BytesOptions, Facet, FacetOptions, IndexRecordOption, NumericOptions,
@@ -122,26 +121,21 @@ mod tests {
let my_text_field = index.schema().get_field("text_field").unwrap();
let term_a = Term::from_field_text(my_text_field, "text");
let inverted_index = segment_reader.inverted_index(my_text_field).unwrap();
let term_info = inverted_index.get_term_info(&term_a).unwrap().unwrap();
let mut postings = inverted_index
.read_postings_from_terminfo_specialized(
&term_info,
IndexRecordOption::WithFreqsAndPositions,
&StandardCodec,
)
.read_postings(&term_a, IndexRecordOption::WithFreqsAndPositions)
.unwrap()
.unwrap();
assert_eq!(postings.doc_freq(), DocFreq::Exact(2));
assert_eq!(postings.doc_freq(), 2);
let fallback_bitset = AliveBitSet::for_test_from_deleted_docs(&[0], 100);
assert_eq!(
crate::indexer::merger::doc_freq_given_deletes(
&postings,
postings.doc_freq_given_deletes(
segment_reader.alive_bitset().unwrap_or(&fallback_bitset)
),
2
);
assert_eq!(postings.term_freq(), 1);
let mut output = Vec::new();
let mut output = vec![];
postings.positions(&mut output);
assert_eq!(output, vec![1]);
postings.advance();

View File

@@ -7,8 +7,6 @@ use common::ReadOnlyBitSet;
use itertools::Itertools;
use measure_time::debug_time;
use crate::codec::postings::PostingsCodec;
use crate::codec::{Codec, StandardCodec};
use crate::directory::WritePtr;
use crate::docset::{DocSet, TERMINATED};
use crate::error::DataCorruption;
@@ -17,7 +15,7 @@ use crate::fieldnorm::{FieldNormReader, FieldNormReaders, FieldNormsSerializer,
use crate::index::{Segment, SegmentComponent, SegmentReader};
use crate::indexer::doc_id_mapping::{MappingType, SegmentDocIdMapping};
use crate::indexer::SegmentSerializer;
use crate::postings::{InvertedIndexSerializer, Postings};
use crate::postings::{InvertedIndexSerializer, Postings, SegmentPostings};
use crate::schema::{value_type_to_column_type, Field, FieldType, Schema};
use crate::store::StoreWriter;
use crate::termdict::{TermMerger, TermOrdinal};
@@ -78,11 +76,10 @@ fn estimate_total_num_tokens(readers: &[SegmentReader], field: Field) -> crate::
Ok(total_num_tokens)
}
pub struct IndexMerger<C: Codec = StandardCodec> {
pub struct IndexMerger {
schema: Schema,
pub(crate) readers: Vec<SegmentReader>,
max_doc: u32,
codec: C,
}
struct DeltaComputer {
@@ -147,8 +144,8 @@ fn extract_fast_field_required_columns(schema: &Schema) -> Vec<(String, ColumnTy
.collect()
}
impl<C: Codec> IndexMerger<C> {
pub fn open(schema: Schema, segments: &[Segment<C>]) -> crate::Result<IndexMerger<C>> {
impl IndexMerger {
pub fn open(schema: Schema, segments: &[Segment]) -> crate::Result<IndexMerger> {
let alive_bitset = segments.iter().map(|_| None).collect_vec();
Self::open_with_custom_alive_set(schema, segments, alive_bitset)
}
@@ -165,15 +162,11 @@ impl<C: Codec> IndexMerger<C> {
// This can be used to merge but also apply an additional filter.
// One use case is demux, which is basically taking a list of
// segments and partitions them e.g. by a value in a field.
//
// # Panics if segments is empty.
pub fn open_with_custom_alive_set(
schema: Schema,
segments: &[Segment<C>],
segments: &[Segment],
alive_bitset_opt: Vec<Option<AliveBitSet>>,
) -> crate::Result<IndexMerger<C>> {
assert!(!segments.is_empty());
let codec = segments[0].index().codec().clone();
) -> crate::Result<IndexMerger> {
let mut readers = vec![];
for (segment, new_alive_bitset_opt) in segments.iter().zip(alive_bitset_opt) {
if segment.meta().num_docs() > 0 {
@@ -196,7 +189,6 @@ impl<C: Codec> IndexMerger<C> {
schema,
readers,
max_doc,
codec,
})
}
@@ -295,7 +287,7 @@ impl<C: Codec> IndexMerger<C> {
&self,
indexed_field: Field,
_field_type: &FieldType,
serializer: &mut InvertedIndexSerializer<C>,
serializer: &mut InvertedIndexSerializer,
fieldnorm_reader: Option<FieldNormReader>,
doc_id_mapping: &SegmentDocIdMapping,
) -> crate::Result<()> {
@@ -363,10 +355,7 @@ impl<C: Codec> IndexMerger<C> {
indexed. Have you modified the schema?",
);
let mut segment_postings_containing_the_term: Vec<(
usize,
<C::PostingsCodec as PostingsCodec>::Postings,
)> = Vec::with_capacity(self.readers.len());
let mut segment_postings_containing_the_term: Vec<(usize, SegmentPostings)> = vec![];
while merged_terms.advance() {
segment_postings_containing_the_term.clear();
@@ -378,24 +367,17 @@ impl<C: Codec> IndexMerger<C> {
for (segment_ord, term_info) in merged_terms.current_segment_ords_and_term_infos() {
let segment_reader = &self.readers[segment_ord];
let inverted_index: &InvertedIndexReader = &field_readers[segment_ord];
let postings = inverted_index.read_postings_from_terminfo_specialized(
&term_info,
segment_postings_option,
&self.codec,
)?;
let segment_postings = inverted_index
.read_postings_from_terminfo(&term_info, segment_postings_option)?;
let alive_bitset_opt = segment_reader.alive_bitset();
let doc_freq = if let Some(alive_bitset) = alive_bitset_opt {
doc_freq_given_deletes(&postings, alive_bitset)
segment_postings.doc_freq_given_deletes(alive_bitset)
} else {
// We do not an exact document frequency here.
match postings.doc_freq() {
crate::postings::DocFreq::Approximate(_) => exact_doc_freq(&postings),
crate::postings::DocFreq::Exact(doc_freq) => doc_freq,
}
segment_postings.doc_freq()
};
if doc_freq > 0u32 {
total_doc_freq += doc_freq;
segment_postings_containing_the_term.push((segment_ord, postings));
segment_postings_containing_the_term.push((segment_ord, segment_postings));
}
}
@@ -413,7 +395,11 @@ impl<C: Codec> IndexMerger<C> {
assert!(!segment_postings_containing_the_term.is_empty());
let has_term_freq = {
let has_term_freq = segment_postings_containing_the_term[0].1.has_freq();
let has_term_freq = !segment_postings_containing_the_term[0]
.1
.block_cursor
.freqs()
.is_empty();
for (_, postings) in &segment_postings_containing_the_term[1..] {
// This may look at a strange way to test whether we have term freq or not.
// With JSON object, the schema is not sufficient to know whether a term
@@ -429,7 +415,7 @@ impl<C: Codec> IndexMerger<C> {
//
// Overall the reliable way to know if we have actual frequencies loaded or not
// is to check whether the actual decoded array is empty or not.
if postings.has_freq() != has_term_freq {
if has_term_freq == postings.block_cursor.freqs().is_empty() {
return Err(DataCorruption::comment_only(
"Term freqs are inconsistent across segments",
)
@@ -481,7 +467,7 @@ impl<C: Codec> IndexMerger<C> {
fn write_postings(
&self,
serializer: &mut InvertedIndexSerializer<C>,
serializer: &mut InvertedIndexSerializer,
fieldnorm_readers: FieldNormReaders,
doc_id_mapping: &SegmentDocIdMapping,
) -> crate::Result<()> {
@@ -539,7 +525,7 @@ impl<C: Codec> IndexMerger<C> {
///
/// # Returns
/// The number of documents in the resulting segment.
pub fn write(&self, mut serializer: SegmentSerializer<C>) -> crate::Result<u32> {
pub fn write(&self, mut serializer: SegmentSerializer) -> crate::Result<u32> {
let doc_id_mapping = self.get_doc_id_from_concatenated_data()?;
debug!("write-fieldnorms");
if let Some(fieldnorms_serializer) = serializer.extract_fieldnorms_serializer() {
@@ -567,43 +553,6 @@ impl<C: Codec> IndexMerger<C> {
}
}
/// Compute the number of non-deleted documents.
///
/// This method will clone and scan through the posting lists.
/// (this is a rather expensive operation).
pub(crate) fn doc_freq_given_deletes<P: Postings + Clone>(
postings: &P,
alive_bitset: &AliveBitSet,
) -> u32 {
let mut docset = postings.clone();
let mut doc_freq = 0;
loop {
let doc = docset.doc();
if doc == TERMINATED {
return doc_freq;
}
if alive_bitset.is_alive(doc) {
doc_freq += 1u32;
}
docset.advance();
}
}
/// If the postings is not able to inform us of the document frequency,
/// we just scan through it.
pub(crate) fn exact_doc_freq<P: Postings + Clone>(postings: &P) -> u32 {
let mut docset = postings.clone();
let mut doc_freq = 0;
loop {
let doc = docset.doc();
if doc == TERMINATED {
return doc_freq;
}
doc_freq += 1u32;
docset.advance();
}
}
#[cfg(test)]
mod tests {
@@ -612,16 +561,12 @@ mod tests {
use proptest::strategy::Strategy;
use schema::FAST;
use crate::codec::postings::PostingsCodec;
use crate::codec::standard::postings::StandardPostingsCodec;
use crate::collector::tests::{
BytesFastFieldTestCollector, FastFieldTestCollector, TEST_COLLECTOR_WITH_SCORE,
};
use crate::collector::{Count, FacetCollector};
use crate::fastfield::AliveBitSet;
use crate::index::{Index, SegmentId};
use crate::indexer::NoMergePolicy;
use crate::postings::{DocFreq, Postings as _};
use crate::query::{AllQuery, BooleanQuery, EnableScoring, Scorer, TermQuery};
use crate::schema::{
Facet, FacetOptions, IndexRecordOption, NumericOptions, TantivyDocument, Term,
@@ -1573,10 +1518,10 @@ mod tests {
let searcher = reader.searcher();
let mut term_scorer = term_query
.specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
.term_scorer_for_test(searcher.segment_reader(0u32), 1.0)
.term_scorer_for_test(searcher.segment_reader(0u32), 1.0)?
.unwrap();
assert_eq!(term_scorer.doc(), 0);
assert_nearly_equals!(term_scorer.seek_block_max(0), 0.0079681855);
assert_nearly_equals!(term_scorer.block_max_score(), 0.0079681855);
assert_nearly_equals!(term_scorer.score(), 0.0079681855);
for _ in 0..81 {
writer.add_document(doc!(text=>"hello happy tax payer"))?;
@@ -1589,13 +1534,13 @@ mod tests {
for segment_reader in searcher.segment_readers() {
let mut term_scorer = term_query
.specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
.term_scorer_for_test(segment_reader, 1.0)
.term_scorer_for_test(segment_reader, 1.0)?
.unwrap();
// the difference compared to before is intrinsic to the bm25 formula. no worries
// there.
for doc in segment_reader.doc_ids_alive() {
assert_eq!(term_scorer.doc(), doc);
assert_nearly_equals!(term_scorer.seek_block_max(doc), 0.003478312);
assert_nearly_equals!(term_scorer.block_max_score(), 0.003478312);
assert_nearly_equals!(term_scorer.score(), 0.003478312);
term_scorer.advance();
}
@@ -1615,12 +1560,12 @@ mod tests {
let segment_reader = searcher.segment_reader(0u32);
let mut term_scorer = term_query
.specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
.term_scorer_for_test(segment_reader, 1.0)
.term_scorer_for_test(segment_reader, 1.0)?
.unwrap();
// the difference compared to before is intrinsic to the bm25 formula. no worries there.
for doc in segment_reader.doc_ids_alive() {
assert_eq!(term_scorer.doc(), doc);
assert_nearly_equals!(term_scorer.seek_block_max(doc), 0.003478312);
assert_nearly_equals!(term_scorer.block_max_score(), 0.003478312);
assert_nearly_equals!(term_scorer.score(), 0.003478312);
term_scorer.advance();
}
@@ -1634,16 +1579,4 @@ mod tests {
assert!(((super::MAX_DOC_LIMIT - 1) as i32) >= 0);
assert!((super::MAX_DOC_LIMIT as i32) < 0);
}
#[test]
fn test_doc_freq_given_delete() {
let docs =
<StandardPostingsCodec as PostingsCodec>::Postings::create_from_docs(&[0, 2, 10]);
assert_eq!(docs.doc_freq(), DocFreq::Exact(3));
let alive_bitset = AliveBitSet::for_test_from_deleted_docs(&[2], 12);
assert_eq!(super::doc_freq_given_deletes(&docs, &alive_bitset), 2);
let all_deleted =
AliveBitSet::for_test_from_deleted_docs(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], 12);
assert_eq!(super::doc_freq_given_deletes(&docs, &all_deleted), 0);
}
}

View File

@@ -1,17 +1,16 @@
use super::IndexWriter;
use crate::codec::Codec;
use crate::schema::document::Document;
use crate::{FutureResult, Opstamp, TantivyDocument};
/// A prepared commit
pub struct PreparedCommit<'a, C: Codec, D: Document = TantivyDocument> {
index_writer: &'a mut IndexWriter<C, D>,
pub struct PreparedCommit<'a, D: Document = TantivyDocument> {
index_writer: &'a mut IndexWriter<D>,
payload: Option<String>,
opstamp: Opstamp,
}
impl<'a, C: Codec, D: Document> PreparedCommit<'a, C, D> {
pub(crate) fn new(index_writer: &'a mut IndexWriter<C, D>, opstamp: Opstamp) -> Self {
impl<'a, D: Document> PreparedCommit<'a, D> {
pub(crate) fn new(index_writer: &'a mut IndexWriter<D>, opstamp: Opstamp) -> Self {
Self {
index_writer,
payload: None,

View File

@@ -8,17 +8,17 @@ use crate::store::StoreWriter;
/// Segment serializer is in charge of laying out on disk
/// the data accumulated and sorted by the `SegmentWriter`.
pub struct SegmentSerializer<C: crate::codec::Codec> {
segment: Segment<C>,
pub struct SegmentSerializer {
segment: Segment,
pub(crate) store_writer: StoreWriter,
fast_field_write: WritePtr,
fieldnorms_serializer: Option<FieldNormsSerializer>,
postings_serializer: InvertedIndexSerializer<C>,
postings_serializer: InvertedIndexSerializer,
}
impl<C: crate::codec::Codec> SegmentSerializer<C> {
impl SegmentSerializer {
/// Creates a new `SegmentSerializer`.
pub fn for_segment(mut segment: Segment<C>) -> crate::Result<SegmentSerializer<C>> {
pub fn for_segment(mut segment: Segment) -> crate::Result<SegmentSerializer> {
let settings = segment.index().settings().clone();
let store_writer = {
let store_write = segment.open_write(SegmentComponent::Store)?;
@@ -50,12 +50,12 @@ impl<C: crate::codec::Codec> SegmentSerializer<C> {
self.store_writer.mem_usage()
}
pub fn segment(&self) -> &Segment<C> {
pub fn segment(&self) -> &Segment {
&self.segment
}
/// Accessor to the `PostingsSerializer`.
pub fn get_postings_serializer(&mut self) -> &mut InvertedIndexSerializer<C> {
pub fn get_postings_serializer(&mut self) -> &mut InvertedIndexSerializer {
&mut self.postings_serializer
}

View File

@@ -10,13 +10,10 @@ use std::sync::{Arc, RwLock};
use rayon::{ThreadPool, ThreadPoolBuilder};
use super::segment_manager::SegmentManager;
use crate::codec::Codec;
use crate::core::META_FILEPATH;
use crate::directory::{Directory, DirectoryClone, GarbageCollectionResult};
use crate::fastfield::AliveBitSet;
use crate::index::{
CodecConfiguration, Index, IndexMeta, IndexSettings, Segment, SegmentId, SegmentMeta,
};
use crate::index::{Index, IndexMeta, IndexSettings, Segment, SegmentId, SegmentMeta};
use crate::indexer::delete_queue::DeleteCursor;
use crate::indexer::index_writer::advance_deletes;
use crate::indexer::merge_operation::MergeOperationInventory;
@@ -64,10 +61,10 @@ pub(crate) fn save_metas(metas: &IndexMeta, directory: &dyn Directory) -> crate:
// We voluntarily pass a merge_operation ref to guarantee that
// the merge_operation is alive during the process
#[derive(Clone)]
pub(crate) struct SegmentUpdater<C: Codec>(Arc<InnerSegmentUpdater<C>>);
pub(crate) struct SegmentUpdater(Arc<InnerSegmentUpdater>);
impl<C: Codec> Deref for SegmentUpdater<C> {
type Target = InnerSegmentUpdater<C>;
impl Deref for SegmentUpdater {
type Target = InnerSegmentUpdater;
#[inline]
fn deref(&self) -> &Self::Target {
@@ -75,8 +72,8 @@ impl<C: Codec> Deref for SegmentUpdater<C> {
}
}
fn garbage_collect_files<C: Codec>(
segment_updater: SegmentUpdater<C>,
fn garbage_collect_files(
segment_updater: SegmentUpdater,
) -> crate::Result<GarbageCollectionResult> {
info!("Running garbage collection");
let mut index = segment_updater.index.clone();
@@ -87,8 +84,8 @@ fn garbage_collect_files<C: Codec>(
/// Merges a list of segments the list of segment givens in the `segment_entries`.
/// This function happens in the calling thread and is computationally expensive.
fn merge<Codec: crate::codec::Codec>(
index: &Index<Codec>,
fn merge(
index: &Index,
mut segment_entries: Vec<SegmentEntry>,
target_opstamp: Opstamp,
) -> crate::Result<Option<SegmentEntry>> {
@@ -111,13 +108,13 @@ fn merge<Codec: crate::codec::Codec>(
let delete_cursor = segment_entries[0].delete_cursor().clone();
let segments: Vec<Segment<Codec>> = segment_entries
let segments: Vec<Segment> = segment_entries
.iter()
.map(|segment_entry| index.segment(segment_entry.meta().clone()))
.collect();
// An IndexMerger is like a "view" of our merged segments.
let merger: IndexMerger<Codec> = IndexMerger::open(index.schema(), &segments[..])?;
let merger: IndexMerger = IndexMerger::open(index.schema(), &segments[..])?;
// ... we just serialize this index merger in our new segment to merge the segments.
let segment_serializer = SegmentSerializer::for_segment(merged_segment.clone())?;
@@ -142,10 +139,10 @@ fn merge<Codec: crate::codec::Codec>(
/// meant to work if you have an `IndexWriter` running for the origin indices, or
/// the destination `Index`.
#[doc(hidden)]
pub fn merge_indices<Codec: crate::codec::Codec>(
indices: &[Index<Codec>],
output_directory: Box<dyn Directory>,
) -> crate::Result<Index<Codec>> {
pub fn merge_indices<T: Into<Box<dyn Directory>>>(
indices: &[Index],
output_directory: T,
) -> crate::Result<Index> {
if indices.is_empty() {
// If there are no indices to merge, there is no need to do anything.
return Err(crate::TantivyError::InvalidArgument(
@@ -166,7 +163,7 @@ pub fn merge_indices<Codec: crate::codec::Codec>(
));
}
let mut segments: Vec<Segment<Codec>> = Vec::new();
let mut segments: Vec<Segment> = Vec::new();
for index in indices {
segments.extend(index.searchable_segments()?);
}
@@ -188,12 +185,12 @@ pub fn merge_indices<Codec: crate::codec::Codec>(
/// meant to work if you have an `IndexWriter` running for the origin indices, or
/// the destination `Index`.
#[doc(hidden)]
pub fn merge_filtered_segments<C: crate::codec::Codec, T: Into<Box<dyn Directory>>>(
segments: &[Segment<C>],
pub fn merge_filtered_segments<T: Into<Box<dyn Directory>>>(
segments: &[Segment],
target_settings: IndexSettings,
filter_doc_ids: Vec<Option<AliveBitSet>>,
output_directory: T,
) -> crate::Result<Index<C>> {
) -> crate::Result<Index> {
if segments.is_empty() {
// If there are no indices to merge, there is no need to do anything.
return Err(crate::TantivyError::InvalidArgument(
@@ -214,15 +211,14 @@ pub fn merge_filtered_segments<C: crate::codec::Codec, T: Into<Box<dyn Directory
));
}
let mut merged_index: Index<C> = Index::builder()
.schema(target_schema.clone())
.codec(segments[0].index().codec().clone())
.settings(target_settings.clone())
.create(output_directory.into())?;
let mut merged_index = Index::create(
output_directory,
target_schema.clone(),
target_settings.clone(),
)?;
let merged_segment = merged_index.new_segment();
let merged_segment_id = merged_segment.id();
let merger: IndexMerger<C> =
let merger: IndexMerger =
IndexMerger::open_with_custom_alive_set(merged_index.schema(), segments, filter_doc_ids)?;
let segment_serializer = SegmentSerializer::for_segment(merged_segment)?;
let num_docs = merger.write(segment_serializer)?;
@@ -239,7 +235,6 @@ pub fn merge_filtered_segments<C: crate::codec::Codec, T: Into<Box<dyn Directory
))
.trim_end()
);
let codec_configuration = CodecConfiguration::from(segments[0].index().codec());
let index_meta = IndexMeta {
index_settings: target_settings, // index_settings of all segments should be the same
@@ -247,7 +242,6 @@ pub fn merge_filtered_segments<C: crate::codec::Codec, T: Into<Box<dyn Directory
schema: target_schema,
opstamp: 0u64,
payload: Some(stats),
codec: codec_configuration,
};
// save the meta.json
@@ -256,7 +250,7 @@ pub fn merge_filtered_segments<C: crate::codec::Codec, T: Into<Box<dyn Directory
Ok(merged_index)
}
pub(crate) struct InnerSegmentUpdater<C: Codec> {
pub(crate) struct InnerSegmentUpdater {
// we keep a copy of the current active IndexMeta to
// avoid loading the file every time we need it in the
// `SegmentUpdater`.
@@ -267,7 +261,7 @@ pub(crate) struct InnerSegmentUpdater<C: Codec> {
pool: ThreadPool,
merge_thread_pool: ThreadPool,
index: Index<C>,
index: Index,
segment_manager: SegmentManager,
merge_policy: RwLock<Arc<dyn MergePolicy>>,
killed: AtomicBool,
@@ -275,13 +269,13 @@ pub(crate) struct InnerSegmentUpdater<C: Codec> {
merge_operations: MergeOperationInventory,
}
impl<Codec: crate::codec::Codec> SegmentUpdater<Codec> {
impl SegmentUpdater {
pub fn create(
index: Index<Codec>,
index: Index,
stamper: Stamper,
delete_cursor: &DeleteCursor,
num_merge_threads: usize,
) -> crate::Result<Self> {
) -> crate::Result<SegmentUpdater> {
let segments = index.searchable_segment_metas()?;
let segment_manager = SegmentManager::from_segments(segments, delete_cursor);
let pool = ThreadPoolBuilder::new()
@@ -410,14 +404,12 @@ impl<Codec: crate::codec::Codec> SegmentUpdater<Codec> {
//
// Segment 1 from disk 1, Segment 1 from disk 2, etc.
committed_segment_metas.sort_by_key(|segment_meta| -(segment_meta.max_doc() as i32));
let codec = CodecConfiguration::from(index.codec());
let index_meta = IndexMeta {
index_settings: index.settings().clone(),
segments: committed_segment_metas,
schema: index.schema(),
opstamp,
payload: commit_message,
codec,
};
// TODO add context to the error.
save_metas(&index_meta, directory.box_clone().borrow_mut())?;
@@ -451,7 +443,7 @@ impl<Codec: crate::codec::Codec> SegmentUpdater<Codec> {
opstamp: Opstamp,
payload: Option<String>,
) -> FutureResult<Opstamp> {
let segment_updater: SegmentUpdater<Codec> = self.clone();
let segment_updater: SegmentUpdater = self.clone();
self.schedule_task(move || {
let segment_entries = segment_updater.purge_deletes(opstamp)?;
segment_updater.segment_manager.commit(segment_entries);
@@ -710,7 +702,6 @@ impl<Codec: crate::codec::Codec> SegmentUpdater<Codec> {
#[cfg(test)]
mod tests {
use super::merge_indices;
use crate::codec::StandardCodec;
use crate::collector::TopDocs;
use crate::directory::RamDirectory;
use crate::fastfield::AliveBitSet;
@@ -924,7 +915,7 @@ mod tests {
#[test]
fn test_merge_empty_indices_array() {
let merge_result = merge_indices::<StandardCodec>(&[], Box::new(RamDirectory::default()));
let merge_result = merge_indices(&[], RamDirectory::default());
assert!(merge_result.is_err());
}
@@ -951,10 +942,7 @@ mod tests {
};
// mismatched schema index list
let result = merge_indices(
&[first_index, second_index],
Box::new(RamDirectory::default()),
);
let result = merge_indices(&[first_index, second_index], RamDirectory::default());
assert!(result.is_err());
Ok(())

View File

@@ -4,7 +4,6 @@ use itertools::Itertools;
use tokenizer_api::BoxTokenStream;
use super::operation::AddOperation;
use crate::codec::Codec;
use crate::fastfield::FastFieldsWriter;
use crate::fieldnorm::{FieldNormReaders, FieldNormsWriter};
use crate::index::{Segment, SegmentComponent};
@@ -13,7 +12,7 @@ use crate::indexer::segment_serializer::SegmentSerializer;
use crate::json_utils::{index_json_value, IndexingPositionsPerPath};
use crate::postings::{
compute_table_memory_size, serialize_postings, IndexingContext, IndexingPosition,
PerFieldPostingsWriter, PostingsWriter, PostingsWriterEnum,
PerFieldPostingsWriter, PostingsWriter,
};
use crate::schema::document::{Document, Value};
use crate::schema::{FieldEntry, FieldType, Schema, DATE_TIME_PRECISION_INDEXED};
@@ -46,11 +45,11 @@ fn compute_initial_table_size(per_thread_memory_budget: usize) -> crate::Result<
///
/// They creates the postings list in anonymous memory.
/// The segment is laid on disk when the segment gets `finalized`.
pub struct SegmentWriter<Codec: crate::codec::Codec> {
pub struct SegmentWriter {
pub(crate) max_doc: DocId,
pub(crate) ctx: IndexingContext,
pub(crate) per_field_postings_writers: PerFieldPostingsWriter,
pub(crate) segment_serializer: SegmentSerializer<Codec>,
pub(crate) segment_serializer: SegmentSerializer,
pub(crate) fast_field_writers: FastFieldsWriter,
pub(crate) fieldnorms_writer: FieldNormsWriter,
pub(crate) json_path_writer: JsonPathWriter,
@@ -61,7 +60,7 @@ pub struct SegmentWriter<Codec: crate::codec::Codec> {
schema: Schema,
}
impl<Codec: crate::codec::Codec> SegmentWriter<Codec> {
impl SegmentWriter {
/// Creates a new `SegmentWriter`
///
/// The arguments are defined as follows
@@ -71,10 +70,7 @@ impl<Codec: crate::codec::Codec> SegmentWriter<Codec> {
/// behavior as a memory limit.
/// - segment: The segment being written
/// - schema
pub fn for_segment(
memory_budget_in_bytes: usize,
segment: Segment<Codec>,
) -> crate::Result<Self> {
pub fn for_segment(memory_budget_in_bytes: usize, segment: Segment) -> crate::Result<Self> {
let schema = segment.schema();
let tokenizer_manager = segment.index().tokenizers().clone();
let tokenizer_manager_fast_field = segment.index().fast_field_tokenizer().clone();
@@ -173,7 +169,7 @@ impl<Codec: crate::codec::Codec> SegmentWriter<Codec> {
}
let (term_buffer, ctx) = (&mut self.term_buffer, &mut self.ctx);
let postings_writer: &mut PostingsWriterEnum =
let postings_writer: &mut dyn PostingsWriter =
self.per_field_postings_writers.get_for_field_mut(field);
term_buffer.clear_with_field(field);
@@ -390,13 +386,13 @@ impl<Codec: crate::codec::Codec> SegmentWriter<Codec> {
/// to the `SegmentSerializer`.
///
/// `doc_id_map` is used to map to the new doc_id order.
fn remap_and_write<C: Codec>(
fn remap_and_write(
schema: Schema,
per_field_postings_writers: &PerFieldPostingsWriter,
ctx: IndexingContext,
fast_field_writers: FastFieldsWriter,
fieldnorms_writer: &FieldNormsWriter,
mut serializer: SegmentSerializer<C>,
mut serializer: SegmentSerializer,
) -> crate::Result<()> {
debug!("remap-and-write");
if let Some(fieldnorms_serializer) = serializer.extract_fieldnorms_serializer() {

View File

@@ -1,7 +1,5 @@
use std::marker::PhantomData;
use crate::codec::StandardCodec;
use crate::index::CodecConfiguration;
use crate::indexer::operation::AddOperation;
use crate::indexer::segment_updater::save_metas;
use crate::indexer::SegmentWriter;
@@ -9,25 +7,22 @@ use crate::schema::document::Document;
use crate::{Directory, Index, IndexMeta, Opstamp, Segment, TantivyDocument};
#[doc(hidden)]
pub struct SingleSegmentIndexWriter<
Codec: crate::codec::Codec = StandardCodec,
D: Document = TantivyDocument,
> {
segment_writer: SegmentWriter<Codec>,
segment: Segment<Codec>,
pub struct SingleSegmentIndexWriter<D: Document = TantivyDocument> {
segment_writer: SegmentWriter,
segment: Segment,
opstamp: Opstamp,
_doc: PhantomData<D>,
_phantom: PhantomData<D>,
}
impl<Codec: crate::codec::Codec, D: Document> SingleSegmentIndexWriter<Codec, D> {
pub fn new(index: Index<Codec>, mem_budget: usize) -> crate::Result<Self> {
impl<D: Document> SingleSegmentIndexWriter<D> {
pub fn new(index: Index, mem_budget: usize) -> crate::Result<Self> {
let segment = index.new_segment();
let segment_writer = SegmentWriter::for_segment(mem_budget, segment.clone())?;
Ok(Self {
segment_writer,
segment,
opstamp: 0,
_doc: PhantomData,
_phantom: PhantomData,
})
}
@@ -42,10 +37,10 @@ impl<Codec: crate::codec::Codec, D: Document> SingleSegmentIndexWriter<Codec, D>
.add_document(AddOperation { opstamp, document })
}
pub fn finalize(self) -> crate::Result<Index<Codec>> {
pub fn finalize(self) -> crate::Result<Index> {
let max_doc = self.segment_writer.max_doc();
self.segment_writer.finalize()?;
let segment: Segment<Codec> = self.segment.with_max_doc(max_doc);
let segment: Segment = self.segment.with_max_doc(max_doc);
let index = segment.index();
let index_meta = IndexMeta {
index_settings: index.settings().clone(),
@@ -53,7 +48,6 @@ impl<Codec: crate::codec::Codec, D: Document> SingleSegmentIndexWriter<Codec, D>
schema: index.schema(),
opstamp: 0,
payload: None,
codec: CodecConfiguration::from(index.codec()),
};
save_metas(&index_meta, index.directory())?;
index.directory().sync_directory()?;

View File

@@ -166,9 +166,6 @@ mod functional_test;
#[macro_use]
mod macros;
/// Tantivy codecs describes how data is layed out on disk.
pub mod codec;
mod future_result;
// Re-exports

View File

@@ -1,19 +1,28 @@
use std::io;
use common::{OwnedBytes, VInt};
use common::VInt;
use crate::codec::standard::postings::skip::{BlockInfo, SkipReader};
use crate::codec::standard::postings::FreqReadingOption;
use crate::directory::{FileSlice, OwnedBytes};
use crate::fieldnorm::FieldNormReader;
use crate::postings::compression::{BlockDecoder, VIntDecoder as _, COMPRESSION_BLOCK_SIZE};
use crate::postings::compression::{BlockDecoder, VIntDecoder, COMPRESSION_BLOCK_SIZE};
use crate::postings::{BlockInfo, FreqReadingOption, SkipReader};
use crate::query::Bm25Weight;
use crate::schema::IndexRecordOption;
use crate::{DocId, Score, TERMINATED};
fn max_score<I: Iterator<Item = Score>>(mut it: I) -> Option<Score> {
it.next().map(|first| it.fold(first, Score::max))
}
/// `BlockSegmentPostings` is a cursor iterating over blocks
/// of documents.
///
/// # Warning
///
/// While it is useful for some very specific high-performance
/// use cases, you should prefer using `SegmentPostings` for most usage.
#[derive(Clone)]
pub(crate) struct BlockSegmentPostings {
pub struct BlockSegmentPostings {
pub(crate) doc_decoder: BlockDecoder,
block_loaded: bool,
freq_decoder: BlockDecoder,
@@ -79,7 +88,7 @@ fn split_into_skips_and_postings(
}
impl BlockSegmentPostings {
/// Opens a `StandardPostingsReader`.
/// Opens a `BlockSegmentPostings`.
/// `doc_freq` is the number of documents in the posting list.
/// `record_option` represents the amount of data available according to the schema.
/// `requested_option` is the amount of data requested by the user.
@@ -87,10 +96,11 @@ impl BlockSegmentPostings {
/// term frequency blocks.
pub(crate) fn open(
doc_freq: u32,
bytes: OwnedBytes,
data: FileSlice,
mut record_option: IndexRecordOption,
requested_option: IndexRecordOption,
) -> io::Result<BlockSegmentPostings> {
let bytes = data.read_bytes()?;
let (skip_data_opt, postings_data) = split_into_skips_and_postings(doc_freq, bytes)?;
let skip_reader = match skip_data_opt {
Some(skip_data) => {
@@ -128,86 +138,6 @@ impl BlockSegmentPostings {
block_segment_postings.load_block();
Ok(block_segment_postings)
}
}
fn max_score<I: Iterator<Item = Score>>(mut it: I) -> Option<Score> {
it.next().map(|first| it.fold(first, Score::max))
}
impl BlockSegmentPostings {
/// Returns the overall number of documents in the block postings.
/// It does not take in account whether documents are deleted or not.
///
/// This `doc_freq` is simply the sum of the length of all of the blocks
/// length, and it does not take in account deleted documents.
pub fn doc_freq(&self) -> u32 {
self.doc_freq
}
/// Returns the array of docs in the current block.
///
/// Before the first call to `.advance()`, the block
/// returned by `.docs()` is empty.
#[inline]
pub fn docs(&self) -> &[DocId] {
debug_assert!(self.block_loaded);
self.doc_decoder.output_array()
}
/// Return the document at index `idx` of the block.
#[inline]
pub fn doc(&self, idx: usize) -> u32 {
self.doc_decoder.output(idx)
}
/// Return the array of `term freq` in the block.
#[inline]
pub fn freqs(&self) -> &[u32] {
debug_assert!(self.block_loaded);
self.freq_decoder.output_array()
}
/// Return the frequency at index `idx` of the block.
#[inline]
pub fn freq(&self, idx: usize) -> u32 {
debug_assert!(self.block_loaded);
self.freq_decoder.output(idx)
}
/// Position on a block that may contains `target_doc`.
///
/// If all docs are smaller than target, the block loaded may be empty,
/// or be the last an incomplete VInt block.
pub fn seek(&mut self, target_doc: DocId) -> usize {
// Move to the block that might contain our document.
self.seek_block_without_loading(target_doc);
self.load_block();
// At this point we are on the block that might contain our document.
let doc = self.doc_decoder.seek_within_block(target_doc);
// The last block is not full and padded with TERMINATED,
// so we are guaranteed to have at least one value (real or padding)
// that is >= target_doc.
debug_assert!(doc < COMPRESSION_BLOCK_SIZE);
// `doc` is now the first element >= `target_doc`.
// If all docs are smaller than target, the current block is incomplete and padded
// with TERMINATED. After the search, the cursor points to the first TERMINATED.
doc
}
pub fn position_offset(&self) -> u64 {
self.skip_reader.position_offset()
}
/// Advance to the next block.
pub fn advance(&mut self) {
self.skip_reader.advance();
self.block_loaded = false;
self.block_max_score_cache = None;
self.load_block();
}
/// Returns the block_max_score for the current block.
/// It does not require the block to be loaded. For instance, it is ok to call this method
@@ -230,7 +160,7 @@ impl BlockSegmentPostings {
}
// this is the last block of the segment posting list.
// If it is actually loaded, we can compute block max manually.
if self.block_loaded {
if self.block_is_loaded() {
let docs = self.doc_decoder.output_array().iter().cloned();
let freqs = self.freq_decoder.output_array().iter().cloned();
let bm25_scores = docs.zip(freqs).map(|(doc, term_freq)| {
@@ -247,25 +177,112 @@ impl BlockSegmentPostings {
// We do not cache it however, so that it gets computed when once block is loaded.
bm25_weight.max_score()
}
}
impl BlockSegmentPostings {
/// Returns an empty segment postings object
pub fn empty() -> BlockSegmentPostings {
BlockSegmentPostings {
doc_decoder: BlockDecoder::with_val(TERMINATED),
block_loaded: true,
freq_decoder: BlockDecoder::with_val(1),
freq_reading_option: FreqReadingOption::NoFreq,
block_max_score_cache: None,
doc_freq: 0,
data: OwnedBytes::empty(),
skip_reader: SkipReader::new(OwnedBytes::empty(), 0, IndexRecordOption::Basic),
}
pub(crate) fn freq_reading_option(&self) -> FreqReadingOption {
self.freq_reading_option
}
pub(crate) fn skip_reader(&self) -> &SkipReader {
&self.skip_reader
// Resets the block segment postings on another position
// in the postings file.
//
// This is useful for enumerating through a list of terms,
// and consuming the associated posting lists while avoiding
// reallocating a `BlockSegmentPostings`.
//
// # Warning
//
// This does not reset the positions list.
pub(crate) fn reset(&mut self, doc_freq: u32, postings_data: OwnedBytes) -> io::Result<()> {
let (skip_data_opt, postings_data) =
split_into_skips_and_postings(doc_freq, postings_data)?;
self.data = postings_data;
self.block_max_score_cache = None;
self.block_loaded = false;
if let Some(skip_data) = skip_data_opt {
self.skip_reader.reset(skip_data, doc_freq);
} else {
self.skip_reader.reset(OwnedBytes::empty(), doc_freq);
}
self.doc_freq = doc_freq;
self.load_block();
Ok(())
}
/// Returns the overall number of documents in the block postings.
/// It does not take in account whether documents are deleted or not.
///
/// This `doc_freq` is simply the sum of the length of all of the blocks
/// length, and it does not take in account deleted documents.
pub fn doc_freq(&self) -> u32 {
self.doc_freq
}
/// Returns the array of docs in the current block.
///
/// Before the first call to `.advance()`, the block
/// returned by `.docs()` is empty.
#[inline]
pub fn docs(&self) -> &[DocId] {
debug_assert!(self.block_is_loaded());
self.doc_decoder.output_array()
}
/// Return the document at index `idx` of the block.
#[inline]
pub fn doc(&self, idx: usize) -> u32 {
self.doc_decoder.output(idx)
}
/// Return the array of `term freq` in the block.
#[inline]
pub fn freqs(&self) -> &[u32] {
debug_assert!(self.block_is_loaded());
self.freq_decoder.output_array()
}
/// Return the frequency at index `idx` of the block.
#[inline]
pub fn freq(&self, idx: usize) -> u32 {
debug_assert!(self.block_is_loaded());
self.freq_decoder.output(idx)
}
/// Returns the length of the current block.
///
/// All blocks have a length of `NUM_DOCS_PER_BLOCK`,
/// except the last block that may have a length
/// of any number between 1 and `NUM_DOCS_PER_BLOCK - 1`
#[inline]
pub fn block_len(&self) -> usize {
debug_assert!(self.block_is_loaded());
self.doc_decoder.output_len
}
/// Position on a block that may contains `target_doc`.
///
/// If all docs are smaller than target, the block loaded may be empty,
/// or be the last an incomplete VInt block.
pub fn seek(&mut self, target_doc: DocId) -> usize {
// Move to the block that might contain our document.
self.seek_block(target_doc);
self.load_block();
// At this point we are on the block that might contain our document.
let doc = self.doc_decoder.seek_within_block(target_doc);
// The last block is not full and padded with TERMINATED,
// so we are guaranteed to have at least one value (real or padding)
// that is >= target_doc.
debug_assert!(doc < COMPRESSION_BLOCK_SIZE);
// `doc` is now the first element >= `target_doc`.
// If all docs are smaller than target, the current block is incomplete and padded
// with TERMINATED. After the search, the cursor points to the first TERMINATED.
doc
}
pub(crate) fn position_offset(&self) -> u64 {
self.skip_reader.position_offset()
}
/// Dangerous API! This calls seeks the next block on the skip list,
@@ -274,15 +291,19 @@ impl BlockSegmentPostings {
/// `.load_block()` needs to be called manually afterwards.
/// If all docs are smaller than target, the block loaded may be empty,
/// or be the last an incomplete VInt block.
pub(crate) fn seek_block_without_loading(&mut self, target_doc: DocId) {
pub(crate) fn seek_block(&mut self, target_doc: DocId) {
if self.skip_reader.seek(target_doc) {
self.block_max_score_cache = None;
self.block_loaded = false;
}
}
pub(crate) fn block_is_loaded(&self) -> bool {
self.block_loaded
}
pub(crate) fn load_block(&mut self) {
if self.block_loaded {
if self.block_is_loaded() {
return;
}
let offset = self.skip_reader.byte_offset();
@@ -330,40 +351,68 @@ impl BlockSegmentPostings {
}
self.block_loaded = true;
}
/// Advance to the next block.
pub fn advance(&mut self) {
self.skip_reader.advance();
self.block_loaded = false;
self.block_max_score_cache = None;
self.load_block();
}
/// Returns an empty segment postings object
pub fn empty() -> BlockSegmentPostings {
BlockSegmentPostings {
doc_decoder: BlockDecoder::with_val(TERMINATED),
block_loaded: true,
freq_decoder: BlockDecoder::with_val(1),
freq_reading_option: FreqReadingOption::NoFreq,
block_max_score_cache: None,
doc_freq: 0,
data: OwnedBytes::empty(),
skip_reader: SkipReader::new(OwnedBytes::empty(), 0, IndexRecordOption::Basic),
}
}
pub(crate) fn skip_reader(&self) -> &SkipReader {
&self.skip_reader
}
}
#[cfg(test)]
mod tests {
use common::OwnedBytes;
use common::HasLen;
use super::BlockSegmentPostings;
use crate::codec::postings::PostingsSerializer;
use crate::codec::standard::postings::segment_postings::SegmentPostings;
use crate::codec::standard::postings::StandardPostingsSerializer;
use crate::docset::{DocSet, TERMINATED};
use crate::index::Index;
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::schema::IndexRecordOption;
use crate::postings::postings::Postings;
use crate::postings::SegmentPostings;
use crate::schema::{IndexRecordOption, Schema, Term, INDEXED};
use crate::DocId;
#[cfg(test)]
fn build_block_postings(docs: &[u32]) -> BlockSegmentPostings {
let doc_freq = docs.len() as u32;
let mut postings_serializer =
StandardPostingsSerializer::new(1.0f32, IndexRecordOption::Basic, None);
postings_serializer.new_term(docs.len() as u32, false);
for doc in docs {
postings_serializer.write_doc(*doc, 1u32);
}
let mut buffer: Vec<u8> = Vec::new();
postings_serializer
.close_term(doc_freq, &mut buffer)
.unwrap();
BlockSegmentPostings::open(
doc_freq,
OwnedBytes::new(buffer),
IndexRecordOption::Basic,
IndexRecordOption::Basic,
)
.unwrap()
#[test]
fn test_empty_segment_postings() {
let mut postings = SegmentPostings::empty();
assert_eq!(postings.doc(), TERMINATED);
assert_eq!(postings.advance(), TERMINATED);
assert_eq!(postings.advance(), TERMINATED);
assert_eq!(postings.doc_freq(), 0);
assert_eq!(postings.len(), 0);
}
#[test]
fn test_empty_postings_doc_returns_terminated() {
let mut postings = SegmentPostings::empty();
assert_eq!(postings.doc(), TERMINATED);
assert_eq!(postings.advance(), TERMINATED);
}
#[test]
fn test_empty_postings_doc_term_freq_returns_0() {
let postings = SegmentPostings::empty();
assert_eq!(postings.term_freq(), 1);
}
#[test]
@@ -378,7 +427,7 @@ mod tests {
#[test]
fn test_block_segment_postings() -> crate::Result<()> {
let mut block_segments = build_block_postings(&(0..100_000).collect::<Vec<u32>>());
let mut block_segments = build_block_postings(&(0..100_000).collect::<Vec<u32>>())?;
let mut offset: u32 = 0u32;
// checking that the `doc_freq` is correct
assert_eq!(block_segments.doc_freq(), 100_000);
@@ -403,7 +452,7 @@ mod tests {
doc_ids.push(129);
doc_ids.push(130);
{
let block_segments = build_block_postings(&doc_ids);
let block_segments = build_block_postings(&doc_ids)?;
let mut docset = SegmentPostings::from_block_postings(block_segments, None);
assert_eq!(docset.seek(128), 129);
assert_eq!(docset.doc(), 129);
@@ -412,7 +461,7 @@ mod tests {
assert_eq!(docset.advance(), TERMINATED);
}
{
let block_segments = build_block_postings(&doc_ids);
let block_segments = build_block_postings(&doc_ids).unwrap();
let mut docset = SegmentPostings::from_block_postings(block_segments, None);
assert_eq!(docset.seek(129), 129);
assert_eq!(docset.doc(), 129);
@@ -421,7 +470,7 @@ mod tests {
assert_eq!(docset.advance(), TERMINATED);
}
{
let block_segments = build_block_postings(&doc_ids);
let block_segments = build_block_postings(&doc_ids)?;
let mut docset = SegmentPostings::from_block_postings(block_segments, None);
assert_eq!(docset.doc(), 0);
assert_eq!(docset.seek(131), TERMINATED);
@@ -430,13 +479,38 @@ mod tests {
Ok(())
}
fn build_block_postings(docs: &[DocId]) -> crate::Result<BlockSegmentPostings> {
let mut schema_builder = Schema::builder();
let int_field = schema_builder.add_u64_field("id", INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests()?;
let mut last_doc = 0u32;
for &doc in docs {
for _ in last_doc..doc {
index_writer.add_document(doc!(int_field=>1u64))?;
}
index_writer.add_document(doc!(int_field=>0u64))?;
last_doc = doc + 1;
}
index_writer.commit()?;
let searcher = index.reader()?.searcher();
let segment_reader = searcher.segment_reader(0);
let inverted_index = segment_reader.inverted_index(int_field).unwrap();
let term = Term::from_field_u64(int_field, 0u64);
let term_info = inverted_index.get_term_info(&term)?.unwrap();
let block_postings = inverted_index
.read_block_postings_from_terminfo(&term_info, IndexRecordOption::Basic)?;
Ok(block_postings)
}
#[test]
fn test_block_segment_postings_seek() -> crate::Result<()> {
let mut docs = Vec::new();
let mut docs = vec![0];
for i in 0..1300 {
docs.push((i * i / 100) + i);
}
let mut block_postings = build_block_postings(&docs[..]);
let mut block_postings = build_block_postings(&docs[..])?;
for i in &[0, 424, 10000] {
block_postings.seek(*i);
let docs = block_postings.docs();
@@ -447,4 +521,40 @@ mod tests {
assert_eq!(block_postings.doc(COMPRESSION_BLOCK_SIZE - 1), TERMINATED);
Ok(())
}
#[test]
fn test_reset_block_segment_postings() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let int_field = schema_builder.add_u64_field("id", INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests()?;
// create two postings list, one containing even number,
// the other containing odd numbers.
for i in 0..6 {
let doc = doc!(int_field=> (i % 2) as u64);
index_writer.add_document(doc)?;
}
index_writer.commit()?;
let searcher = index.reader()?.searcher();
let segment_reader = searcher.segment_reader(0);
let mut block_segments;
{
let term = Term::from_field_u64(int_field, 0u64);
let inverted_index = segment_reader.inverted_index(int_field)?;
let term_info = inverted_index.get_term_info(&term)?.unwrap();
block_segments = inverted_index
.read_block_postings_from_terminfo(&term_info, IndexRecordOption::Basic)?;
}
assert_eq!(block_segments.docs(), &[0, 2, 4]);
{
let term = Term::from_field_u64(int_field, 1u64);
let inverted_index = segment_reader.inverted_index(int_field)?;
let term_info = inverted_index.get_term_info(&term)?.unwrap();
inverted_index.reset_block_postings_from_terminfo(&term_info, &mut block_segments)?;
}
assert_eq!(block_segments.docs(), &[1, 3, 5]);
Ok(())
}
}

View File

@@ -3,7 +3,6 @@ use std::io;
use common::json_path_writer::JSON_END_OF_PATH;
use stacker::Addr;
use crate::codec::Codec;
use crate::indexer::indexing_term::IndexingTerm;
use crate::indexer::path_to_unordered_id::OrderedPathId;
use crate::postings::postings_writer::SpecializedPostingsWriter;
@@ -23,6 +22,12 @@ pub(crate) struct JsonPostingsWriter<Rec: Recorder> {
non_str_posting_writer: SpecializedPostingsWriter<DocIdRecorder>,
}
impl<Rec: Recorder> From<JsonPostingsWriter<Rec>> for Box<dyn PostingsWriter> {
fn from(json_postings_writer: JsonPostingsWriter<Rec>) -> Box<dyn PostingsWriter> {
Box::new(json_postings_writer)
}
}
impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
#[inline]
fn subscribe(
@@ -53,12 +58,12 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
}
/// The actual serialization format is handled by the `PostingsSerializer`.
fn serialize<C: Codec>(
fn serialize(
&self,
ordered_term_addrs: &[(Field, OrderedPathId, &[u8], Addr)],
ordered_id_to_path: &[&str],
ctx: &IndexingContext,
serializer: &mut FieldSerializer<C>,
serializer: &mut FieldSerializer,
) -> io::Result<()> {
let mut term_buffer = JsonTermSerializer(Vec::with_capacity(48));
let mut buffer_lender = BufferLender::default();

View File

@@ -1,5 +1,5 @@
use crate::docset::{DocSet, TERMINATED};
use crate::postings::{DocFreq, Postings};
use crate::postings::{Postings, SegmentPostings};
use crate::DocId;
/// `LoadedPostings` is a `DocSet` and `Postings` implementation.
@@ -25,16 +25,16 @@ impl LoadedPostings {
/// Creates a new `LoadedPostings` from a `SegmentPostings`.
///
/// It will also preload positions, if positions are available in the SegmentPostings.
pub fn load(postings: &mut Box<dyn Postings>) -> LoadedPostings {
let num_docs: usize = u32::from(postings.doc_freq()) as usize;
pub fn load(segment_postings: &mut SegmentPostings) -> LoadedPostings {
let num_docs = segment_postings.doc_freq() as usize;
let mut doc_ids = Vec::with_capacity(num_docs);
let mut positions = Vec::with_capacity(num_docs);
let mut position_offsets = Vec::with_capacity(num_docs);
while postings.doc() != TERMINATED {
while segment_postings.doc() != TERMINATED {
position_offsets.push(positions.len() as u32);
doc_ids.push(postings.doc());
postings.append_positions_with_offset(0, &mut positions);
postings.advance();
doc_ids.push(segment_postings.doc());
segment_postings.append_positions_with_offset(0, &mut positions);
segment_postings.advance();
}
position_offsets.push(positions.len() as u32);
LoadedPostings {
@@ -101,14 +101,6 @@ impl Postings for LoadedPostings {
output.push(*pos + offset);
}
}
fn has_freq(&self) -> bool {
true
}
fn doc_freq(&self) -> DocFreq {
DocFreq::Exact(self.doc_ids.len() as u32)
}
}
#[cfg(test)]

View File

@@ -4,6 +4,7 @@ mod block_search;
pub(crate) use self::block_search::branchless_binary_search;
mod block_segment_postings;
pub(crate) mod compression;
mod indexing_context;
mod json_postings_writer;
@@ -12,22 +13,32 @@ mod per_field_postings_writer;
mod postings;
mod postings_writer;
mod recorder;
mod segment_postings;
mod serializer;
mod skip;
mod term_info;
pub(crate) use loaded_postings::LoadedPostings;
pub use postings::DocFreq;
pub(crate) use stacker::compute_table_memory_size;
pub use self::block_segment_postings::BlockSegmentPostings;
pub(crate) use self::indexing_context::IndexingContext;
pub(crate) use self::per_field_postings_writer::PerFieldPostingsWriter;
pub use self::postings::Postings;
pub(crate) use self::postings_writer::{
serialize_postings, IndexingPosition, PostingsWriter, PostingsWriterEnum,
};
pub(crate) use self::postings_writer::{serialize_postings, IndexingPosition, PostingsWriter};
pub use self::segment_postings::SegmentPostings;
pub use self::serializer::{FieldSerializer, InvertedIndexSerializer};
pub(crate) use self::skip::{BlockInfo, SkipReader};
pub use self::term_info::TermInfo;
#[expect(clippy::enum_variant_names)]
#[derive(Debug, PartialEq, Clone, Copy, Eq)]
pub(crate) enum FreqReadingOption {
NoFreq,
SkipFreq,
ReadFreq,
}
#[cfg(test)]
pub(crate) mod tests {
use std::mem;
@@ -38,7 +49,6 @@ pub(crate) mod tests {
use crate::index::{Index, SegmentComponent, SegmentReader};
use crate::indexer::operation::AddOperation;
use crate::indexer::SegmentWriter;
use crate::postings::DocFreq;
use crate::query::Scorer;
use crate::schema::{
Field, IndexRecordOption, Schema, Term, TextFieldIndexing, TextOptions, INDEXED, TEXT,
@@ -269,11 +279,11 @@ pub(crate) mod tests {
}
{
let term_a = Term::from_field_text(text_field, "a");
let mut postings_a: Box<dyn Postings> = segment_reader
let mut postings_a = segment_reader
.inverted_index(term_a.field())?
.read_postings(&term_a, IndexRecordOption::WithFreqsAndPositions)?
.unwrap();
assert_eq!(postings_a.doc_freq(), DocFreq::Exact(1000));
assert_eq!(postings_a.len(), 1000);
assert_eq!(postings_a.doc(), 0);
assert_eq!(postings_a.term_freq(), 6);
postings_a.positions(&mut positions);
@@ -296,7 +306,7 @@ pub(crate) mod tests {
.inverted_index(term_e.field())?
.read_postings(&term_e, IndexRecordOption::WithFreqsAndPositions)?
.unwrap();
assert_eq!(postings_e.doc_freq(), DocFreq::Exact(1000 - 2));
assert_eq!(postings_e.len(), 1000 - 2);
for i in 2u32..1000u32 {
assert_eq!(postings_e.term_freq(), i);
postings_e.positions(&mut positions);

View File

@@ -1,15 +1,16 @@
use crate::postings::json_postings_writer::JsonPostingsWriter;
use crate::postings::postings_writer::{PostingsWriterEnum, SpecializedPostingsWriter};
use crate::postings::postings_writer::SpecializedPostingsWriter;
use crate::postings::recorder::{DocIdRecorder, TermFrequencyRecorder, TfAndPositionRecorder};
use crate::postings::PostingsWriter;
use crate::schema::{Field, FieldEntry, FieldType, IndexRecordOption, Schema};
pub(crate) struct PerFieldPostingsWriter {
per_field_postings_writers: Vec<PostingsWriterEnum>,
per_field_postings_writers: Vec<Box<dyn PostingsWriter>>,
}
impl PerFieldPostingsWriter {
pub fn for_schema(schema: &Schema) -> Self {
let per_field_postings_writers: Vec<PostingsWriterEnum> = schema
let per_field_postings_writers = schema
.fields()
.map(|(_, field_entry)| posting_writer_from_field_entry(field_entry))
.collect();
@@ -18,16 +19,16 @@ impl PerFieldPostingsWriter {
}
}
pub(crate) fn get_for_field(&self, field: Field) -> &PostingsWriterEnum {
&self.per_field_postings_writers[field.field_id() as usize]
pub(crate) fn get_for_field(&self, field: Field) -> &dyn PostingsWriter {
self.per_field_postings_writers[field.field_id() as usize].as_ref()
}
pub(crate) fn get_for_field_mut(&mut self, field: Field) -> &mut PostingsWriterEnum {
&mut self.per_field_postings_writers[field.field_id() as usize]
pub(crate) fn get_for_field_mut(&mut self, field: Field) -> &mut dyn PostingsWriter {
self.per_field_postings_writers[field.field_id() as usize].as_mut()
}
}
fn posting_writer_from_field_entry(field_entry: &FieldEntry) -> PostingsWriterEnum {
fn posting_writer_from_field_entry(field_entry: &FieldEntry) -> Box<dyn PostingsWriter> {
match *field_entry.field_type() {
FieldType::Str(ref text_options) => text_options
.get_indexing_options()
@@ -50,7 +51,7 @@ fn posting_writer_from_field_entry(field_entry: &FieldEntry) -> PostingsWriterEn
| FieldType::Date(_)
| FieldType::Bytes(_)
| FieldType::IpAddr(_)
| FieldType::Facet(_) => <SpecializedPostingsWriter<DocIdRecorder>>::default().into(),
| FieldType::Facet(_) => Box::<SpecializedPostingsWriter<DocIdRecorder>>::default(),
FieldType::JsonObject(ref json_object_options) => {
if let Some(text_indexing_option) = json_object_options.get_text_indexing_options() {
match text_indexing_option.index_option() {

View File

@@ -1,25 +1,5 @@
use crate::docset::DocSet;
/// Result of the doc_freq method.
///
/// Postings can inform us that the document frequency is approximate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DocFreq {
/// The document frequency is approximate.
Approximate(u32),
/// The document frequency is exact.
Exact(u32),
}
impl From<DocFreq> for u32 {
fn from(doc_freq: DocFreq) -> Self {
match doc_freq {
DocFreq::Approximate(approximate_doc_freq) => approximate_doc_freq,
DocFreq::Exact(doc_freq) => doc_freq,
}
}
}
/// Postings (also called inverted list)
///
/// For a given term, it is the list of doc ids of the doc
@@ -34,9 +14,6 @@ pub trait Postings: DocSet + 'static {
/// The number of times the term appears in the document.
fn term_freq(&self) -> u32;
/// Returns the number of documents containing the term in the segment.
fn doc_freq(&self) -> DocFreq;
/// Returns the positions offsetted with a given value.
/// It is not necessary to clear the `output` before calling this method.
/// The output vector will be resized to the `term_freq`.
@@ -54,16 +31,6 @@ pub trait Postings: DocSet + 'static {
fn positions(&mut self, output: &mut Vec<u32>) {
self.positions_with_offset(0u32, output);
}
/// Returns true if the term_frequency is available.
///
/// This is a tricky question, because on JSON fields, it is possible
/// for a text term to have term freq, whereas a number term in the field has none.
///
/// This function returns whether the actual term has term frequencies or not.
/// In this above JSON field example, `has_freq` should return true for the
/// earlier and false for the latter.
fn has_freq(&self) -> bool;
}
impl Postings for Box<dyn Postings> {
@@ -74,12 +41,4 @@ impl Postings for Box<dyn Postings> {
fn append_positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
(**self).append_positions_with_offset(offset, output);
}
fn has_freq(&self) -> bool {
(**self).has_freq()
}
fn doc_freq(&self) -> DocFreq {
(**self).doc_freq()
}
}

View File

@@ -4,14 +4,10 @@ use std::ops::Range;
use stacker::Addr;
use crate::codec::Codec;
use crate::fieldnorm::FieldNormReaders;
use crate::indexer::indexing_term::IndexingTerm;
use crate::indexer::path_to_unordered_id::OrderedPathId;
use crate::postings::json_postings_writer::JsonPostingsWriter;
use crate::postings::recorder::{
BufferLender, DocIdRecorder, Recorder, TermFrequencyRecorder, TfAndPositionRecorder,
};
use crate::postings::recorder::{BufferLender, Recorder};
use crate::postings::{
FieldSerializer, IndexingContext, InvertedIndexSerializer, PerFieldPostingsWriter,
};
@@ -49,12 +45,12 @@ fn make_field_partition(
/// Serialize the inverted index.
/// It pushes all term, one field at a time, towards the
/// postings serializer.
pub(crate) fn serialize_postings<C: Codec>(
pub(crate) fn serialize_postings(
ctx: IndexingContext,
schema: Schema,
per_field_postings_writers: &PerFieldPostingsWriter,
fieldnorm_readers: FieldNormReaders,
serializer: &mut InvertedIndexSerializer<C>,
serializer: &mut InvertedIndexSerializer,
) -> crate::Result<()> {
// Replace unordered ids by ordered ids to be able to sort
let unordered_id_to_ordered_id: Vec<OrderedPathId> =
@@ -104,141 +100,6 @@ pub(crate) struct IndexingPosition {
pub end_position: u32,
}
pub enum PostingsWriterEnum {
DocId(SpecializedPostingsWriter<DocIdRecorder>),
DocIdTf(SpecializedPostingsWriter<TermFrequencyRecorder>),
DocTfAndPosition(SpecializedPostingsWriter<TfAndPositionRecorder>),
JsonDocId(JsonPostingsWriter<DocIdRecorder>),
JsonDocIdTf(JsonPostingsWriter<TermFrequencyRecorder>),
JsonDocTfAndPosition(JsonPostingsWriter<TfAndPositionRecorder>),
}
impl From<SpecializedPostingsWriter<DocIdRecorder>> for PostingsWriterEnum {
fn from(doc_id_recorder_writer: SpecializedPostingsWriter<DocIdRecorder>) -> Self {
PostingsWriterEnum::DocId(doc_id_recorder_writer)
}
}
impl From<SpecializedPostingsWriter<TermFrequencyRecorder>> for PostingsWriterEnum {
fn from(doc_id_tf_recorder_writer: SpecializedPostingsWriter<TermFrequencyRecorder>) -> Self {
PostingsWriterEnum::DocIdTf(doc_id_tf_recorder_writer)
}
}
impl From<SpecializedPostingsWriter<TfAndPositionRecorder>> for PostingsWriterEnum {
fn from(
doc_id_tf_and_positions_recorder_writer: SpecializedPostingsWriter<TfAndPositionRecorder>,
) -> Self {
PostingsWriterEnum::DocTfAndPosition(doc_id_tf_and_positions_recorder_writer)
}
}
impl From<JsonPostingsWriter<DocIdRecorder>> for PostingsWriterEnum {
fn from(doc_id_recorder_writer: JsonPostingsWriter<DocIdRecorder>) -> Self {
PostingsWriterEnum::JsonDocId(doc_id_recorder_writer)
}
}
impl From<JsonPostingsWriter<TermFrequencyRecorder>> for PostingsWriterEnum {
fn from(doc_id_tf_recorder_writer: JsonPostingsWriter<TermFrequencyRecorder>) -> Self {
PostingsWriterEnum::JsonDocIdTf(doc_id_tf_recorder_writer)
}
}
impl From<JsonPostingsWriter<TfAndPositionRecorder>> for PostingsWriterEnum {
fn from(
doc_id_tf_and_positions_recorder_writer: JsonPostingsWriter<TfAndPositionRecorder>,
) -> Self {
PostingsWriterEnum::JsonDocTfAndPosition(doc_id_tf_and_positions_recorder_writer)
}
}
impl PostingsWriter for PostingsWriterEnum {
fn subscribe(&mut self, doc: DocId, pos: u32, term: &IndexingTerm, ctx: &mut IndexingContext) {
match self {
PostingsWriterEnum::DocId(writer) => writer.subscribe(doc, pos, term, ctx),
PostingsWriterEnum::DocIdTf(writer) => writer.subscribe(doc, pos, term, ctx),
PostingsWriterEnum::DocTfAndPosition(writer) => writer.subscribe(doc, pos, term, ctx),
PostingsWriterEnum::JsonDocId(writer) => writer.subscribe(doc, pos, term, ctx),
PostingsWriterEnum::JsonDocIdTf(writer) => writer.subscribe(doc, pos, term, ctx),
PostingsWriterEnum::JsonDocTfAndPosition(writer) => {
writer.subscribe(doc, pos, term, ctx)
}
}
}
fn serialize<C: Codec>(
&self,
term_addrs: &[(Field, OrderedPathId, &[u8], Addr)],
ordered_id_to_path: &[&str],
ctx: &IndexingContext,
serializer: &mut FieldSerializer<C>,
) -> io::Result<()> {
match self {
PostingsWriterEnum::DocId(writer) => {
writer.serialize(term_addrs, ordered_id_to_path, ctx, serializer)
}
PostingsWriterEnum::DocIdTf(writer) => {
writer.serialize(term_addrs, ordered_id_to_path, ctx, serializer)
}
PostingsWriterEnum::DocTfAndPosition(writer) => {
writer.serialize(term_addrs, ordered_id_to_path, ctx, serializer)
}
PostingsWriterEnum::JsonDocId(writer) => {
writer.serialize(term_addrs, ordered_id_to_path, ctx, serializer)
}
PostingsWriterEnum::JsonDocIdTf(writer) => {
writer.serialize(term_addrs, ordered_id_to_path, ctx, serializer)
}
PostingsWriterEnum::JsonDocTfAndPosition(writer) => {
writer.serialize(term_addrs, ordered_id_to_path, ctx, serializer)
}
}
}
/// Tokenize a text and subscribe all of its token.
fn index_text(
&mut self,
doc_id: DocId,
token_stream: &mut dyn TokenStream,
term_buffer: &mut IndexingTerm,
ctx: &mut IndexingContext,
indexing_position: &mut IndexingPosition,
) {
match self {
PostingsWriterEnum::DocId(writer) => {
writer.index_text(doc_id, token_stream, term_buffer, ctx, indexing_position)
}
PostingsWriterEnum::DocIdTf(writer) => {
writer.index_text(doc_id, token_stream, term_buffer, ctx, indexing_position)
}
PostingsWriterEnum::DocTfAndPosition(writer) => {
writer.index_text(doc_id, token_stream, term_buffer, ctx, indexing_position)
}
PostingsWriterEnum::JsonDocId(writer) => {
writer.index_text(doc_id, token_stream, term_buffer, ctx, indexing_position)
}
PostingsWriterEnum::JsonDocIdTf(writer) => {
writer.index_text(doc_id, token_stream, term_buffer, ctx, indexing_position)
}
PostingsWriterEnum::JsonDocTfAndPosition(writer) => {
writer.index_text(doc_id, token_stream, term_buffer, ctx, indexing_position)
}
}
}
fn total_num_tokens(&self) -> u64 {
match self {
PostingsWriterEnum::DocId(writer) => writer.total_num_tokens(),
PostingsWriterEnum::DocIdTf(writer) => writer.total_num_tokens(),
PostingsWriterEnum::DocTfAndPosition(writer) => writer.total_num_tokens(),
PostingsWriterEnum::JsonDocId(writer) => writer.total_num_tokens(),
PostingsWriterEnum::JsonDocIdTf(writer) => writer.total_num_tokens(),
PostingsWriterEnum::JsonDocTfAndPosition(writer) => writer.total_num_tokens(),
}
}
}
/// The `PostingsWriter` is in charge of receiving documenting
/// and building a `Segment` in anonymous memory.
///
@@ -255,12 +116,12 @@ pub(crate) trait PostingsWriter: Send + Sync {
/// Serializes the postings on disk.
/// The actual serialization format is handled by the `PostingsSerializer`.
fn serialize<C: Codec>(
fn serialize(
&self,
term_addrs: &[(Field, OrderedPathId, &[u8], Addr)],
ordered_id_to_path: &[&str],
ctx: &IndexingContext,
serializer: &mut FieldSerializer<C>,
serializer: &mut FieldSerializer,
) -> io::Result<()>;
/// Tokenize a text and subscribe all of its token.
@@ -310,14 +171,22 @@ pub(crate) struct SpecializedPostingsWriter<Rec: Recorder> {
_recorder_type: PhantomData<Rec>,
}
impl<Rec: Recorder> From<SpecializedPostingsWriter<Rec>> for Box<dyn PostingsWriter> {
fn from(
specialized_postings_writer: SpecializedPostingsWriter<Rec>,
) -> Box<dyn PostingsWriter> {
Box::new(specialized_postings_writer)
}
}
impl<Rec: Recorder> SpecializedPostingsWriter<Rec> {
#[inline]
pub(crate) fn serialize_one_term<C: Codec>(
pub(crate) fn serialize_one_term(
term: &[u8],
addr: Addr,
buffer_lender: &mut BufferLender,
ctx: &IndexingContext,
serializer: &mut FieldSerializer<C>,
serializer: &mut FieldSerializer,
) -> io::Result<()> {
let recorder: Rec = ctx.term_index.read(addr);
let term_doc_freq = recorder.term_doc_freq().unwrap_or(0u32);
@@ -358,12 +227,12 @@ impl<Rec: Recorder> PostingsWriter for SpecializedPostingsWriter<Rec> {
});
}
fn serialize<C: Codec>(
fn serialize(
&self,
term_addrs: &[(Field, OrderedPathId, &[u8], Addr)],
_ordered_id_to_path: &[&str],
ctx: &IndexingContext,
serializer: &mut FieldSerializer<C>,
serializer: &mut FieldSerializer,
) -> io::Result<()> {
let mut buffer_lender = BufferLender::default();
for (_field, _path_id, term, addr) in term_addrs {

View File

@@ -1,7 +1,6 @@
use common::read_u32_vint;
use stacker::{ExpUnrolledLinkedList, MemoryArena};
use crate::codec::Codec;
use crate::postings::FieldSerializer;
use crate::DocId;
@@ -68,10 +67,10 @@ pub(crate) trait Recorder: Copy + Default + Send + Sync + 'static {
/// Close the document. It will help record the term frequency.
fn close_doc(&mut self, arena: &mut MemoryArena);
/// Pushes the postings information to the serializer.
fn serialize<C: Codec>(
fn serialize(
&self,
arena: &MemoryArena,
serializer: &mut FieldSerializer<C>,
serializer: &mut FieldSerializer<'_>,
buffer_lender: &mut BufferLender,
);
/// Returns the number of document containing this term.
@@ -111,10 +110,10 @@ impl Recorder for DocIdRecorder {
#[inline]
fn close_doc(&mut self, _arena: &mut MemoryArena) {}
fn serialize<C: Codec>(
fn serialize(
&self,
arena: &MemoryArena,
serializer: &mut FieldSerializer<C>,
serializer: &mut FieldSerializer<'_>,
buffer_lender: &mut BufferLender,
) {
let buffer = buffer_lender.lend_u8();
@@ -179,10 +178,10 @@ impl Recorder for TermFrequencyRecorder {
self.current_tf = 0;
}
fn serialize<C: Codec>(
fn serialize(
&self,
arena: &MemoryArena,
serializer: &mut FieldSerializer<C>,
serializer: &mut FieldSerializer<'_>,
buffer_lender: &mut BufferLender,
) {
let buffer = buffer_lender.lend_u8();
@@ -236,10 +235,10 @@ impl Recorder for TfAndPositionRecorder {
self.stack.writer(arena).write_u32_vint(POSITION_END);
}
fn serialize<C: Codec>(
fn serialize(
&self,
arena: &MemoryArena,
serializer: &mut FieldSerializer<C>,
serializer: &mut FieldSerializer<'_>,
buffer_lender: &mut BufferLender,
) {
let (buffer_u8, buffer_positions) = buffer_lender.lend_all();

View File

@@ -1,14 +1,11 @@
use common::BitSet;
use common::HasLen;
use super::BlockSegmentPostings;
use crate::codec::postings::PostingsWithBlockMax;
use crate::docset::DocSet;
use crate::fieldnorm::FieldNormReader;
use crate::fastfield::AliveBitSet;
use crate::positions::PositionReader;
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::postings::{DocFreq, Postings};
use crate::query::Bm25Weight;
use crate::{DocId, Score};
use crate::postings::{BlockSegmentPostings, Postings};
use crate::{DocId, TERMINATED};
/// `SegmentPostings` represents the inverted list or postings associated with
/// a term in a `Segment`.
@@ -32,6 +29,31 @@ impl SegmentPostings {
}
}
/// Compute the number of non-deleted documents.
///
/// This method will clone and scan through the posting lists.
/// (this is a rather expensive operation).
pub fn doc_freq_given_deletes(&self, alive_bitset: &AliveBitSet) -> u32 {
let mut docset = self.clone();
let mut doc_freq = 0;
loop {
let doc = docset.doc();
if doc == TERMINATED {
return doc_freq;
}
if alive_bitset.is_alive(doc) {
doc_freq += 1u32;
}
docset.advance();
}
}
/// Returns the overall number of documents in the block postings.
/// It does not take in account whether documents are deleted or not.
pub fn doc_freq(&self) -> u32 {
self.block_cursor.doc_freq()
}
/// Creates a segment postings object with the given documents
/// and no frequency encoded.
///
@@ -42,19 +64,13 @@ impl SegmentPostings {
/// buffer with the serialized data.
#[cfg(test)]
pub fn create_from_docs(docs: &[u32]) -> SegmentPostings {
use common::OwnedBytes;
use crate::directory::FileSlice;
use crate::postings::serializer::PostingsSerializer;
use crate::schema::IndexRecordOption;
let mut buffer = Vec::new();
{
use crate::codec::postings::PostingsSerializer;
let mut postings_serializer =
crate::codec::standard::postings::StandardPostingsSerializer::new(
0.0,
IndexRecordOption::Basic,
None,
);
PostingsSerializer::new(0.0, IndexRecordOption::Basic, None);
postings_serializer.new_term(docs.len() as u32, false);
for &doc in docs {
postings_serializer.write_doc(doc, 1u32);
@@ -65,7 +81,7 @@ impl SegmentPostings {
}
let block_segment_postings = BlockSegmentPostings::open(
docs.len() as u32,
OwnedBytes::new(buffer),
FileSlice::from(buffer),
IndexRecordOption::Basic,
IndexRecordOption::Basic,
)
@@ -79,11 +95,9 @@ impl SegmentPostings {
doc_and_tfs: &[(u32, u32)],
fieldnorms: Option<&[u32]>,
) -> SegmentPostings {
use common::OwnedBytes;
use crate::codec::postings::PostingsSerializer as _;
use crate::codec::standard::postings::StandardPostingsSerializer;
use crate::directory::FileSlice;
use crate::fieldnorm::FieldNormReader;
use crate::postings::serializer::PostingsSerializer;
use crate::schema::IndexRecordOption;
use crate::Score;
let mut buffer: Vec<u8> = Vec::new();
@@ -100,7 +114,7 @@ impl SegmentPostings {
total_num_tokens as Score / fieldnorms.len() as Score
})
.unwrap_or(0.0);
let mut postings_serializer = StandardPostingsSerializer::new(
let mut postings_serializer = PostingsSerializer::new(
average_field_norm,
IndexRecordOption::WithFreqs,
fieldnorm_reader,
@@ -114,7 +128,7 @@ impl SegmentPostings {
.unwrap();
let block_segment_postings = BlockSegmentPostings::open(
doc_and_tfs.len() as u32,
OwnedBytes::new(buffer),
FileSlice::from(buffer),
IndexRecordOption::WithFreqs,
IndexRecordOption::WithFreqs,
)
@@ -144,6 +158,7 @@ impl DocSet for SegmentPostings {
// next needs to be called a first time to point to the correct element.
#[inline]
fn advance(&mut self) -> DocId {
debug_assert!(self.block_cursor.block_is_loaded());
if self.cur == COMPRESSION_BLOCK_SIZE - 1 {
self.cur = 0;
self.block_cursor.advance();
@@ -182,31 +197,13 @@ impl DocSet for SegmentPostings {
}
fn size_hint(&self) -> u32 {
self.doc_freq().into()
self.len() as u32
}
}
fn fill_bitset(&mut self, bitset: &mut BitSet) {
let bitset_max_value: DocId = bitset.max_value();
loop {
let docs = self.block_cursor.docs();
let Some(&last_doc) = docs.last() else {
break;
};
if last_doc < bitset_max_value {
// All docs are within the range of the bitset
for &doc in docs {
bitset.insert(doc);
}
} else {
for &doc in docs {
if doc < bitset_max_value {
bitset.insert(doc);
}
}
break;
}
self.block_cursor.advance();
}
impl HasLen for SegmentPostings {
fn len(&self) -> usize {
self.block_cursor.doc_freq() as usize
}
}
@@ -232,13 +229,6 @@ impl Postings for SegmentPostings {
self.block_cursor.freq(self.cur)
}
/// Returns the overall number of documents in the block postings.
/// It does not take in account whether documents are deleted or not.
#[inline(always)]
fn doc_freq(&self) -> DocFreq {
DocFreq::Exact(self.block_cursor.doc_freq())
}
fn append_positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
let term_freq = self.term_freq();
let prev_len = output.len();
@@ -262,42 +252,24 @@ impl Postings for SegmentPostings {
}
}
}
fn has_freq(&self) -> bool {
!self.block_cursor.freqs().is_empty()
}
}
impl PostingsWithBlockMax for SegmentPostings {
fn seek_block_max(
&mut self,
target_doc: crate::DocId,
fieldnorm_reader: &FieldNormReader,
similarity_weight: &Bm25Weight,
) -> Score {
self.block_cursor.seek_block_without_loading(target_doc);
self.block_cursor
.block_max_score(fieldnorm_reader, similarity_weight)
}
fn last_doc_in_block(&self) -> crate::DocId {
self.block_cursor.skip_reader().last_doc_in_block()
}
}
#[cfg(test)]
mod tests {
use common::HasLen;
use super::SegmentPostings;
use crate::docset::{DocSet, TERMINATED};
use crate::postings::Postings;
use crate::fastfield::AliveBitSet;
use crate::postings::postings::Postings;
#[test]
fn test_empty_segment_postings() {
let mut postings = SegmentPostings::empty();
assert_eq!(postings.doc(), TERMINATED);
assert_eq!(postings.advance(), TERMINATED);
assert_eq!(postings.advance(), TERMINATED);
assert_eq!(postings.doc_freq(), crate::postings::DocFreq::Exact(0));
assert_eq!(postings.len(), 0);
}
#[test]
@@ -312,4 +284,15 @@ mod tests {
let postings = SegmentPostings::empty();
assert_eq!(postings.term_freq(), 1);
}
#[test]
fn test_doc_freq() {
let docs = SegmentPostings::create_from_docs(&[0, 2, 10]);
assert_eq!(docs.doc_freq(), 3);
let alive_bitset = AliveBitSet::for_test_from_deleted_docs(&[2], 12);
assert_eq!(docs.doc_freq_given_deletes(&alive_bitset), 2);
let all_deleted =
AliveBitSet::for_test_from_deleted_docs(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], 12);
assert_eq!(docs.doc_freq_given_deletes(&all_deleted), 0);
}
}

View File

@@ -1,14 +1,16 @@
use std::cmp::Ordering;
use std::io::{self, Write};
use common::{BinarySerializable, CountingWriter};
use common::{BinarySerializable, CountingWriter, VInt};
use super::TermInfo;
use crate::codec::postings::PostingsSerializer;
use crate::codec::Codec;
use crate::directory::{CompositeWrite, WritePtr};
use crate::fieldnorm::FieldNormReader;
use crate::index::Segment;
use crate::positions::PositionSerializer;
use crate::postings::compression::{BlockEncoder, VIntEncoder, COMPRESSION_BLOCK_SIZE};
use crate::postings::skip::SkipSerializer;
use crate::query::Bm25Weight;
use crate::schema::{Field, FieldEntry, FieldType, IndexRecordOption, Schema};
use crate::termdict::TermDictionaryBuilder;
use crate::{DocId, Score};
@@ -44,27 +46,22 @@ use crate::{DocId, Score};
///
/// A description of the serialization format is
/// [available here](https://fulmicoton.gitbooks.io/tantivy-doc/content/inverted-index.html).
pub struct InvertedIndexSerializer<C: Codec> {
pub struct InvertedIndexSerializer {
terms_write: CompositeWrite<WritePtr>,
postings_write: CompositeWrite<WritePtr>,
positions_write: CompositeWrite<WritePtr>,
schema: Schema,
codec: C,
}
use crate::codec::postings::PostingsCodec;
impl<C: Codec> InvertedIndexSerializer<C> {
impl InvertedIndexSerializer {
/// Open a new `InvertedIndexSerializer` for the given segment
pub fn open(segment: &mut Segment<C>) -> crate::Result<InvertedIndexSerializer<C>> {
pub fn open(segment: &mut Segment) -> crate::Result<InvertedIndexSerializer> {
use crate::index::SegmentComponent::{Positions, Postings, Terms};
let codec = segment.index().codec().clone();
let inv_index_serializer = InvertedIndexSerializer {
terms_write: CompositeWrite::wrap(segment.open_write(Terms)?),
postings_write: CompositeWrite::wrap(segment.open_write(Postings)?),
positions_write: CompositeWrite::wrap(segment.open_write(Positions)?),
schema: segment.schema(),
codec,
};
Ok(inv_index_serializer)
}
@@ -78,7 +75,7 @@ impl<C: Codec> InvertedIndexSerializer<C> {
field: Field,
total_num_tokens: u64,
fieldnorm_reader: Option<FieldNormReader>,
) -> io::Result<FieldSerializer<'_, C>> {
) -> io::Result<FieldSerializer<'_>> {
let field_entry: &FieldEntry = self.schema.get_field_entry(field);
let term_dictionary_write = self.terms_write.for_field(field);
let postings_write = self.postings_write.for_field(field);
@@ -91,7 +88,6 @@ impl<C: Codec> InvertedIndexSerializer<C> {
postings_write,
positions_write,
fieldnorm_reader,
&self.codec,
)
}
@@ -106,9 +102,9 @@ impl<C: Codec> InvertedIndexSerializer<C> {
/// The field serializer is in charge of
/// the serialization of a specific field.
pub struct FieldSerializer<'a, C: Codec> {
pub struct FieldSerializer<'a> {
term_dictionary_builder: TermDictionaryBuilder<&'a mut CountingWriter<WritePtr>>,
postings_serializer: <C::PostingsCodec as PostingsCodec>::PostingsSerializer,
postings_serializer: PostingsSerializer,
positions_serializer_opt: Option<PositionSerializer<&'a mut CountingWriter<WritePtr>>>,
current_term_info: TermInfo,
term_open: bool,
@@ -116,7 +112,7 @@ pub struct FieldSerializer<'a, C: Codec> {
postings_start_offset: u64,
}
impl<'a, C: Codec> FieldSerializer<'a, C> {
impl<'a> FieldSerializer<'a> {
fn create(
field_type: &FieldType,
total_num_tokens: u64,
@@ -124,8 +120,7 @@ impl<'a, C: Codec> FieldSerializer<'a, C> {
postings_write: &'a mut CountingWriter<WritePtr>,
positions_write: &'a mut CountingWriter<WritePtr>,
fieldnorm_reader: Option<FieldNormReader>,
codec: &C,
) -> io::Result<FieldSerializer<'a, C>> {
) -> io::Result<FieldSerializer<'a>> {
total_num_tokens.serialize(postings_write)?;
let index_record_option = field_type
.index_record_option()
@@ -135,11 +130,8 @@ impl<'a, C: Codec> FieldSerializer<'a, C> {
.as_ref()
.map(|ff_reader| total_num_tokens as Score / ff_reader.num_docs() as Score)
.unwrap_or(0.0);
let postings_serializer = codec.postings_codec().new_serializer(
average_fieldnorm,
index_record_option,
fieldnorm_reader,
);
let postings_serializer =
PostingsSerializer::new(average_fieldnorm, index_record_option, fieldnorm_reader);
let positions_serializer_opt = if index_record_option.has_positions() {
Some(PositionSerializer::new(positions_write))
} else {
@@ -192,6 +184,7 @@ impl<'a, C: Codec> FieldSerializer<'a, C> {
"Called new_term, while the previous term was not closed."
);
self.term_open = true;
self.postings_serializer.clear();
self.current_term_info = self.current_term_info();
self.term_dictionary_builder.insert_key(term)?;
self.postings_serializer
@@ -255,3 +248,223 @@ impl<'a, C: Codec> FieldSerializer<'a, C> {
Ok(())
}
}
struct Block {
doc_ids: [DocId; COMPRESSION_BLOCK_SIZE],
term_freqs: [u32; COMPRESSION_BLOCK_SIZE],
len: usize,
}
impl Block {
fn new() -> Self {
Block {
doc_ids: [0u32; COMPRESSION_BLOCK_SIZE],
term_freqs: [0u32; COMPRESSION_BLOCK_SIZE],
len: 0,
}
}
fn doc_ids(&self) -> &[DocId] {
&self.doc_ids[..self.len]
}
fn term_freqs(&self) -> &[u32] {
&self.term_freqs[..self.len]
}
fn clear(&mut self) {
self.len = 0;
}
fn append_doc(&mut self, doc: DocId, term_freq: u32) {
let len = self.len;
self.doc_ids[len] = doc;
self.term_freqs[len] = term_freq;
self.len = len + 1;
}
fn is_full(&self) -> bool {
self.len == COMPRESSION_BLOCK_SIZE
}
fn is_empty(&self) -> bool {
self.len == 0
}
fn last_doc(&self) -> DocId {
assert_eq!(self.len, COMPRESSION_BLOCK_SIZE);
self.doc_ids[COMPRESSION_BLOCK_SIZE - 1]
}
}
pub struct PostingsSerializer {
last_doc_id_encoded: u32,
block_encoder: BlockEncoder,
block: Box<Block>,
postings_write: Vec<u8>,
skip_write: SkipSerializer,
mode: IndexRecordOption,
fieldnorm_reader: Option<FieldNormReader>,
bm25_weight: Option<Bm25Weight>,
avg_fieldnorm: Score, /* Average number of term in the field for that segment.
* this value is used to compute the block wand information. */
term_has_freq: bool,
}
impl PostingsSerializer {
pub fn new(
avg_fieldnorm: Score,
mode: IndexRecordOption,
fieldnorm_reader: Option<FieldNormReader>,
) -> PostingsSerializer {
PostingsSerializer {
block_encoder: BlockEncoder::new(),
block: Box::new(Block::new()),
postings_write: Vec::new(),
skip_write: SkipSerializer::new(),
last_doc_id_encoded: 0u32,
mode,
fieldnorm_reader,
bm25_weight: None,
avg_fieldnorm,
term_has_freq: false,
}
}
pub fn new_term(&mut self, term_doc_freq: u32, record_term_freq: bool) {
self.bm25_weight = None;
self.term_has_freq = self.mode.has_freq() && record_term_freq;
if !self.term_has_freq {
return;
}
let num_docs_in_segment: u64 =
if let Some(fieldnorm_reader) = self.fieldnorm_reader.as_ref() {
fieldnorm_reader.num_docs() as u64
} else {
return;
};
if num_docs_in_segment == 0 {
return;
}
self.bm25_weight = Some(Bm25Weight::for_one_term_without_explain(
term_doc_freq as u64,
num_docs_in_segment,
self.avg_fieldnorm,
));
}
fn write_block(&mut self) {
{
// encode the doc ids
let (num_bits, block_encoded): (u8, &[u8]) = self
.block_encoder
.compress_block_sorted(self.block.doc_ids(), self.last_doc_id_encoded);
self.last_doc_id_encoded = self.block.last_doc();
self.skip_write
.write_doc(self.last_doc_id_encoded, num_bits);
// last el block 0, offset block 1,
self.postings_write.extend(block_encoded);
}
if self.term_has_freq {
let (num_bits, block_encoded): (u8, &[u8]) = self
.block_encoder
.compress_block_unsorted(self.block.term_freqs(), true);
self.postings_write.extend(block_encoded);
self.skip_write.write_term_freq(num_bits);
if self.mode.has_positions() {
// We serialize the sum of term freqs within the skip information
// in order to navigate through positions.
let sum_freq = self.block.term_freqs().iter().cloned().sum();
self.skip_write.write_total_term_freq(sum_freq);
}
let mut blockwand_params = (0u8, 0u32);
if let Some(bm25_weight) = self.bm25_weight.as_ref() {
if let Some(fieldnorm_reader) = self.fieldnorm_reader.as_ref() {
let docs = self.block.doc_ids().iter().cloned();
let term_freqs = self.block.term_freqs().iter().cloned();
let fieldnorms = docs.map(|doc| fieldnorm_reader.fieldnorm_id(doc));
blockwand_params = fieldnorms
.zip(term_freqs)
.max_by(
|(left_fieldnorm_id, left_term_freq),
(right_fieldnorm_id, right_term_freq)| {
let left_score =
bm25_weight.tf_factor(*left_fieldnorm_id, *left_term_freq);
let right_score =
bm25_weight.tf_factor(*right_fieldnorm_id, *right_term_freq);
left_score
.partial_cmp(&right_score)
.unwrap_or(Ordering::Equal)
},
)
.unwrap();
}
}
let (fieldnorm_id, term_freq) = blockwand_params;
self.skip_write.write_blockwand_max(fieldnorm_id, term_freq);
}
self.block.clear();
}
pub fn write_doc(&mut self, doc_id: DocId, term_freq: u32) {
self.block.append_doc(doc_id, term_freq);
if self.block.is_full() {
self.write_block();
}
}
pub fn close_term(
&mut self,
doc_freq: u32,
output_write: &mut impl std::io::Write,
) -> io::Result<()> {
if !self.block.is_empty() {
// we have doc ids waiting to be written
// this happens when the number of doc ids is
// not a perfect multiple of our block size.
//
// In that case, the remaining part is encoded
// using variable int encoding.
{
let block_encoded = self
.block_encoder
.compress_vint_sorted(self.block.doc_ids(), self.last_doc_id_encoded);
self.postings_write.write_all(block_encoded)?;
}
// ... Idem for term frequencies
if self.term_has_freq {
let block_encoded = self
.block_encoder
.compress_vint_unsorted(self.block.term_freqs());
self.postings_write.write_all(block_encoded)?;
}
self.block.clear();
}
if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
let skip_data = self.skip_write.data();
VInt(skip_data.len() as u64).serialize(output_write)?;
output_write.write_all(skip_data)?;
}
output_write.write_all(&self.postings_write[..])?;
self.skip_write.clear();
self.postings_write.clear();
self.bm25_weight = None;
Ok(())
}
fn clear(&mut self) {
self.block.clear();
self.last_doc_id_encoded = 0;
}
}

View File

@@ -142,6 +142,23 @@ impl SkipReader {
skip_reader
}
pub fn reset(&mut self, data: OwnedBytes, doc_freq: u32) {
self.last_doc_in_block = if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
0
} else {
TERMINATED
};
self.last_doc_in_previous_block = 0u32;
self.owned_read = data;
self.block_info = BlockInfo::VInt { num_docs: doc_freq };
self.byte_offset = 0;
self.remaining_docs = doc_freq;
self.position_offset = 0u64;
if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
self.read_block_info();
}
}
// Returns the block max score for this block if available.
//
// The block max score is available for all full bitpacked block,

View File

@@ -2,7 +2,7 @@ use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN, TERMINATED};
use crate::index::SegmentReader;
use crate::query::boost_query::BoostScorer;
use crate::query::explanation::does_not_match;
use crate::query::{box_scorer, EnableScoring, Explanation, Query, Scorer, Weight};
use crate::query::{EnableScoring, Explanation, Query, Scorer, Weight};
use crate::{DocId, Score};
/// Query that matches all of the documents.
@@ -24,9 +24,9 @@ impl Weight for AllWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let all_scorer = AllScorer::new(reader.max_doc());
if boost != 1.0 {
Ok(box_scorer(BoostScorer::new(all_scorer, boost)))
Ok(Box::new(BoostScorer::new(all_scorer, boost)))
} else {
Ok(box_scorer(all_scorer))
Ok(Box::new(all_scorer))
}
}

View File

@@ -10,7 +10,7 @@ use crate::postings::TermInfo;
use crate::query::{BitSetDocSet, ConstScorer, Explanation, Scorer, Weight};
use crate::schema::{Field, IndexRecordOption};
use crate::termdict::{TermDictionary, TermStreamer};
use crate::{DocId, DocSet, Score, TantivyError};
use crate::{DocId, Score, TantivyError};
/// A weight struct for Fuzzy Term and Regex Queries
pub struct AutomatonWeight<A> {
@@ -92,9 +92,18 @@ where
let mut term_stream = self.automaton_stream(term_dict)?;
while term_stream.advance() {
let term_info = term_stream.value();
let mut block_segment_postings =
inverted_index.read_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
block_segment_postings.fill_bitset(&mut doc_bitset);
let mut block_segment_postings = inverted_index
.read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
loop {
let docs = block_segment_postings.docs();
if docs.is_empty() {
break;
}
for &doc in docs {
doc_bitset.insert(doc);
}
block_segment_postings.advance();
}
}
let doc_bitset = BitSetDocSet::from(doc_bitset);
let const_scorer = ConstScorer::new(doc_bitset, boost);

View File

@@ -24,13 +24,6 @@ impl BitSetDocSet {
self.cursor_bucket = bucket_addr;
self.cursor_tinybitset = self.docs.tinyset(bucket_addr);
}
/// Returns the number of documents in the bitset.
///
/// This call is not free: it will bitcount the number of bits in the bitset.
pub fn doc_freq(&self) -> u32 {
self.docs.len() as u32
}
}
impl From<BitSet> for BitSetDocSet {

View File

@@ -1,6 +1,5 @@
use std::ops::{Deref, DerefMut};
use crate::codec::postings::PostingsWithBlockMax;
use crate::query::term_query::TermScorer;
use crate::query::Scorer;
use crate::{DocId, DocSet, Score, TERMINATED};
@@ -14,8 +13,8 @@ use crate::{DocId, DocSet, Score, TERMINATED};
/// We always have `before_pivot_len` < `pivot_len`.
///
/// `None` is returned if we establish that no document can exceed the threshold.
fn find_pivot_doc<TPostings: PostingsWithBlockMax>(
term_scorers: &[TermScorerWithMaxScore<TPostings>],
fn find_pivot_doc(
term_scorers: &[TermScorerWithMaxScore],
threshold: Score,
) -> Option<(usize, usize, DocId)> {
let mut max_score = 0.0;
@@ -47,8 +46,8 @@ fn find_pivot_doc<TPostings: PostingsWithBlockMax>(
/// the next doc candidate defined by the min of `last_doc_in_block + 1` for
/// scorer in scorers[..pivot_len] and `scorer.doc()` for scorer in scorers[pivot_len..].
/// Note: before and after calling this method, scorers need to be sorted by their `.doc()`.
fn block_max_was_too_low_advance_one_scorer<TPostings: PostingsWithBlockMax>(
scorers: &mut [TermScorerWithMaxScore<TPostings>],
fn block_max_was_too_low_advance_one_scorer(
scorers: &mut [TermScorerWithMaxScore],
pivot_len: usize,
) {
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
@@ -83,10 +82,7 @@ fn block_max_was_too_low_advance_one_scorer<TPostings: PostingsWithBlockMax>(
// Given a list of term_scorers and a `ord` and assuming that `term_scorers[ord]` is sorted
// except term_scorers[ord] that might be in advance compared to its ranks,
// bubble up term_scorers[ord] in order to restore the ordering.
fn restore_ordering<TPostings: PostingsWithBlockMax>(
term_scorers: &mut [TermScorerWithMaxScore<TPostings>],
ord: usize,
) {
fn restore_ordering(term_scorers: &mut [TermScorerWithMaxScore], ord: usize) {
let doc = term_scorers[ord].doc();
for i in ord + 1..term_scorers.len() {
if term_scorers[i].doc() >= doc {
@@ -101,10 +97,9 @@ fn restore_ordering<TPostings: PostingsWithBlockMax>(
// If this works, return true.
// If this fails (ie: one of the term_scorer does not contain `pivot_doc` and seek goes past the
// pivot), reorder the term_scorers to ensure the list is still sorted and returns `false`.
// If a term_scorer reach TERMINATED in the process return false remove the term_scorer and
// return.
fn align_scorers<TPostings: PostingsWithBlockMax>(
term_scorers: &mut Vec<TermScorerWithMaxScore<TPostings>>,
// If a term_scorer reach TERMINATED in the process return false remove the term_scorer and return.
fn align_scorers(
term_scorers: &mut Vec<TermScorerWithMaxScore>,
pivot_doc: DocId,
before_pivot_len: usize,
) -> bool {
@@ -131,10 +126,7 @@ fn align_scorers<TPostings: PostingsWithBlockMax>(
// Assumes terms_scorers[..pivot_len] are positioned on the same doc (pivot_doc).
// Advance term_scorers[..pivot_len] and out of these removes the terminated scores.
// Restores the ordering of term_scorers.
fn advance_all_scorers_on_pivot<TPostings: PostingsWithBlockMax>(
term_scorers: &mut Vec<TermScorerWithMaxScore<TPostings>>,
pivot_len: usize,
) {
fn advance_all_scorers_on_pivot(term_scorers: &mut Vec<TermScorerWithMaxScore>, pivot_len: usize) {
for term_scorer in &mut term_scorers[..pivot_len] {
term_scorer.advance();
}
@@ -153,12 +145,12 @@ fn advance_all_scorers_on_pivot<TPostings: PostingsWithBlockMax>(
/// Implements the WAND (Weak AND) algorithm for dynamic pruning
/// described in the paper "Faster Top-k Document Retrieval Using Block-Max Indexes".
/// Link: <http://engineering.nyu.edu/~suel/papers/bmw.pdf>
pub fn block_wand<TPostings: PostingsWithBlockMax>(
mut scorers: Vec<TermScorer<TPostings>>,
pub fn block_wand(
mut scorers: Vec<TermScorer>,
mut threshold: Score,
callback: &mut dyn FnMut(u32, Score) -> Score,
) {
let mut scorers: Vec<TermScorerWithMaxScore<TPostings>> = scorers
let mut scorers: Vec<TermScorerWithMaxScore> = scorers
.iter_mut()
.map(TermScorerWithMaxScore::from)
.collect();
@@ -174,7 +166,10 @@ pub fn block_wand<TPostings: PostingsWithBlockMax>(
let block_max_score_upperbound: Score = scorers[..pivot_len]
.iter_mut()
.map(|scorer| scorer.seek_block_max(pivot_doc))
.map(|scorer| {
scorer.seek_block(pivot_doc);
scorer.block_max_score()
})
.sum();
// Beware after shallow advance, skip readers can be in advance compared to
@@ -225,22 +220,21 @@ pub fn block_wand<TPostings: PostingsWithBlockMax>(
/// - On a block, advance until the end and execute `callback` when the doc score is greater or
/// equal to the `threshold`.
pub fn block_wand_single_scorer(
mut scorer: TermScorer<impl PostingsWithBlockMax>,
mut scorer: TermScorer,
mut threshold: Score,
callback: &mut dyn FnMut(u32, Score) -> Score,
) {
let mut doc = scorer.doc();
let mut block_max_score = scorer.seek_block_max(doc);
loop {
// We position the scorer on a block that can reach
// the threshold.
while block_max_score < threshold {
while scorer.block_max_score() < threshold {
let last_doc_in_block = scorer.last_doc_in_block();
if last_doc_in_block == TERMINATED {
return;
}
doc = last_doc_in_block + 1;
block_max_score = scorer.seek_block_max(doc);
scorer.seek_block(doc);
}
// Seek will effectively load that block.
doc = scorer.seek(doc);
@@ -262,33 +256,31 @@ pub fn block_wand_single_scorer(
}
}
doc += 1;
block_max_score = scorer.seek_block_max(doc);
scorer.seek_block(doc);
}
}
struct TermScorerWithMaxScore<'a, TPostings: PostingsWithBlockMax> {
scorer: &'a mut TermScorer<TPostings>,
struct TermScorerWithMaxScore<'a> {
scorer: &'a mut TermScorer,
max_score: Score,
}
impl<'a, TPostings: PostingsWithBlockMax> From<&'a mut TermScorer<TPostings>>
for TermScorerWithMaxScore<'a, TPostings>
{
fn from(scorer: &'a mut TermScorer<TPostings>) -> Self {
impl<'a> From<&'a mut TermScorer> for TermScorerWithMaxScore<'a> {
fn from(scorer: &'a mut TermScorer) -> Self {
let max_score = scorer.max_score();
TermScorerWithMaxScore { scorer, max_score }
}
}
impl<TPostings: PostingsWithBlockMax> Deref for TermScorerWithMaxScore<'_, TPostings> {
type Target = TermScorer<TPostings>;
impl Deref for TermScorerWithMaxScore<'_> {
type Target = TermScorer;
fn deref(&self) -> &Self::Target {
self.scorer
}
}
impl<TPostings: PostingsWithBlockMax> DerefMut for TermScorerWithMaxScore<'_, TPostings> {
impl DerefMut for TermScorerWithMaxScore<'_> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.scorer
}

View File

@@ -1,18 +1,24 @@
use std::collections::HashMap;
use crate::codec::{ObjectSafeCodec, SumOrDoNothingCombiner};
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
use crate::index::SegmentReader;
use crate::postings::FreqReadingOption;
use crate::query::disjunction::Disjunction;
use crate::query::explanation::does_not_match;
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
use crate::query::weight::for_each_docset_buffered;
use crate::query::term_query::TermScorer;
use crate::query::weight::{for_each_docset_buffered, for_each_pruning_scorer, for_each_scorer};
use crate::query::{
box_scorer, intersect_scorers, AllScorer, BufferedUnionScorer, EmptyScorer, Exclude,
Explanation, Occur, RequiredOptionalScorer, Scorer, SumCombiner, Weight,
intersect_scorers, AllScorer, BufferedUnionScorer, EmptyScorer, Exclude, Explanation, Occur,
RequiredOptionalScorer, Scorer, Weight,
};
use crate::{DocId, Score};
enum SpecializedScorer {
TermUnion(Vec<TermScorer>),
Other(Box<dyn Scorer>),
}
fn scorer_disjunction<TScoreCombiner>(
scorers: Vec<Box<dyn Scorer>>,
score_combiner: TScoreCombiner,
@@ -26,7 +32,7 @@ where
if scorers.len() == 1 {
return scorers.into_iter().next().unwrap(); // Safe unwrap.
}
box_scorer(Disjunction::new(
Box::new(Disjunction::new(
scorers,
score_combiner,
minimum_match_required,
@@ -38,41 +44,57 @@ fn scorer_union<TScoreCombiner>(
scorers: Vec<Box<dyn Scorer>>,
score_combiner_fn: impl Fn() -> TScoreCombiner,
num_docs: u32,
codec: &dyn ObjectSafeCodec,
) -> Box<dyn Scorer>
) -> SpecializedScorer
where
TScoreCombiner: ScoreCombiner,
{
match scorers.len() {
0 => box_scorer(EmptyScorer),
1 => scorers.into_iter().next().unwrap(),
_ => {
let combiner_opt: Option<SumOrDoNothingCombiner> = if std::any::TypeId::of::<
TScoreCombiner,
>() == std::any::TypeId::of::<
SumCombiner,
>() {
Some(SumOrDoNothingCombiner::Sum)
} else if std::any::TypeId::of::<TScoreCombiner>()
== std::any::TypeId::of::<DoNothingCombiner>()
assert!(!scorers.is_empty());
if scorers.len() == 1 {
return SpecializedScorer::Other(scorers.into_iter().next().unwrap()); //< we checked the size beforehand
}
{
let is_all_term_queries = scorers.iter().all(|scorer| scorer.is::<TermScorer>());
if is_all_term_queries {
let scorers: Vec<TermScorer> = scorers
.into_iter()
.map(|scorer| *(scorer.downcast::<TermScorer>().map_err(|_| ()).unwrap()))
.collect();
if scorers
.iter()
.all(|scorer| scorer.freq_reading_option() == FreqReadingOption::ReadFreq)
{
Some(SumOrDoNothingCombiner::DoNothing)
// Block wand is only available if we read frequencies.
return SpecializedScorer::TermUnion(scorers);
} else {
None
};
if let Some(combiner) = combiner_opt {
let scorer =
codec.build_union_scorer_with_sum_combiner(scorers, num_docs, combiner);
scorer
} else {
box_scorer(BufferedUnionScorer::build(
return SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
scorers,
score_combiner_fn,
num_docs,
))
)));
}
}
}
SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
scorers,
score_combiner_fn,
num_docs,
)))
}
fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
scorer: SpecializedScorer,
score_combiner_fn: impl Fn() -> TScoreCombiner,
num_docs: u32,
) -> Box<dyn Scorer> {
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let union_scorer =
BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs);
Box::new(union_scorer)
}
SpecializedScorer::Other(scorer) => scorer,
}
}
/// Returns the effective MUST scorer, accounting for removed AllScorers.
@@ -88,7 +110,7 @@ fn effective_must_scorer(
if must_scorers.is_empty() {
if removed_all_scorer_count > 0 {
// Had AllScorer(s) only - all docs match
Some(box_scorer(AllScorer::new(max_doc)))
Some(Box::new(AllScorer::new(max_doc)))
} else {
// No MUST constraint at all
None
@@ -106,26 +128,28 @@ fn effective_must_scorer(
/// When `scoring_enabled` is false, we can just return AllScorer alone since
/// we don't need score contributions from the should_scorer.
fn effective_should_scorer_for_union<TScoreCombiner: ScoreCombiner>(
should_scorer: Box<dyn Scorer>,
should_scorer: SpecializedScorer,
removed_all_scorer_count: usize,
max_doc: DocId,
num_docs: u32,
score_combiner_fn: impl Fn() -> TScoreCombiner,
scoring_enabled: bool,
) -> Box<dyn Scorer> {
) -> SpecializedScorer {
if removed_all_scorer_count > 0 {
if scoring_enabled {
// Need to union to get score contributions from both
let all_scorers: Vec<Box<dyn Scorer>> =
vec![should_scorer, box_scorer(AllScorer::new(max_doc))];
box_scorer(BufferedUnionScorer::build(
let all_scorers: Vec<Box<dyn Scorer>> = vec![
into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
Box::new(AllScorer::new(max_doc)),
];
SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
all_scorers,
score_combiner_fn,
num_docs,
))
)))
} else {
// Scoring disabled - AllScorer alone is sufficient
box_scorer(AllScorer::new(max_doc))
SpecializedScorer::Other(Box::new(AllScorer::new(max_doc)))
}
} else {
should_scorer
@@ -136,9 +160,9 @@ enum ShouldScorersCombinationMethod {
// Should scorers are irrelevant.
Ignored,
// Only contributes to final score.
Optional(Box<dyn Scorer>),
Optional(SpecializedScorer),
// Regardless of score, the should scorers may impact whether a document is matching or not.
Required(Box<dyn Scorer>),
Required(SpecializedScorer),
}
/// Weight associated to the `BoolQuery`.
@@ -200,7 +224,7 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
reader: &SegmentReader,
boost: Score,
score_combiner_fn: impl Fn() -> TComplexScoreCombiner,
) -> crate::Result<Box<dyn Scorer>> {
) -> crate::Result<SpecializedScorer> {
let num_docs = reader.num_docs();
let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
@@ -210,7 +234,7 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
let must_special_scorer_counts = remove_and_count_all_and_empty_scorers(&mut must_scorers);
if must_special_scorer_counts.num_empty_scorers > 0 {
return Ok(box_scorer(EmptyScorer));
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
}
let mut should_scorers = per_occur_scorers.remove(&Occur::Should).unwrap_or_default();
@@ -225,7 +249,7 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
if exclude_special_scorer_counts.num_all_scorers > 0 {
// We exclude all documents at one point.
return Ok(box_scorer(EmptyScorer));
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
}
let effective_minimum_number_should_match = self
@@ -237,7 +261,7 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
if effective_minimum_number_should_match > num_of_should_scorers {
// We don't have enough scorers to satisfy the minimum number of should matches.
// The request will match no documents.
return Ok(box_scorer(EmptyScorer));
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
}
match effective_minimum_number_should_match {
0 if num_of_should_scorers == 0 => ShouldScorersCombinationMethod::Ignored,
@@ -245,13 +269,11 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
should_scorers,
&score_combiner_fn,
num_docs,
reader.codec(),
)),
1 => ShouldScorersCombinationMethod::Required(scorer_union(
should_scorers,
&score_combiner_fn,
num_docs,
reader.codec(),
)),
n if num_of_should_scorers == n => {
// When num_of_should_scorers equals the number of should clauses,
@@ -259,26 +281,16 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
must_scorers.append(&mut should_scorers);
ShouldScorersCombinationMethod::Ignored
}
_ => ShouldScorersCombinationMethod::Required(scorer_disjunction(
should_scorers,
score_combiner_fn(),
effective_minimum_number_should_match,
_ => ShouldScorersCombinationMethod::Required(SpecializedScorer::Other(
scorer_disjunction(
should_scorers,
score_combiner_fn(),
effective_minimum_number_should_match,
),
)),
}
};
let exclude_scorer_opt: Option<Box<dyn Scorer>> = if exclude_scorers.is_empty() {
None
} else {
let exclude_scorers_union: Box<dyn Scorer> = scorer_union(
exclude_scorers,
DoNothingCombiner::default,
num_docs,
reader.codec(),
);
Some(exclude_scorers_union)
};
let include_scorer = match (should_scorers, must_scorers) {
(ShouldScorersCombinationMethod::Ignored, must_scorers) => {
// No SHOULD clauses (or they were absorbed into MUST).
@@ -291,8 +303,8 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
reader.max_doc(),
num_docs,
)
.unwrap_or_else(|| box_scorer(EmptyScorer));
boxed_scorer
.unwrap_or_else(|| Box::new(EmptyScorer));
SpecializedScorer::Other(boxed_scorer)
}
(ShouldScorersCombinationMethod::Optional(should_scorer), must_scorers) => {
// Optional SHOULD: contributes to scoring but not required for matching.
@@ -317,12 +329,16 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
Some(must_scorer) => {
// Has MUST constraint: SHOULD only affects scoring.
if self.scoring_enabled {
box_scorer(RequiredOptionalScorer::<_, _, TScoreCombiner>::new(
SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
_,
_,
TScoreCombiner,
>::new(
must_scorer,
should_scorer,
))
into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
)))
} else {
must_scorer
SpecializedScorer::Other(must_scorer)
}
}
}
@@ -342,16 +358,33 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
}
Some(must_scorer) => {
// Has MUST constraint: intersect MUST with SHOULD.
intersect_scorers(vec![must_scorer, should_scorer], num_docs)
let should_boxed =
into_box_scorer(should_scorer, &score_combiner_fn, num_docs);
SpecializedScorer::Other(intersect_scorers(
vec![must_scorer, should_boxed],
num_docs,
))
}
}
}
};
if let Some(exclude_scorer) = exclude_scorer_opt {
Ok(box_scorer(Exclude::new(include_scorer, exclude_scorer)))
} else {
Ok(include_scorer)
if exclude_scorers.is_empty() {
return Ok(include_scorer);
}
let include_scorer_boxed = into_box_scorer(include_scorer, &score_combiner_fn, num_docs);
let scorer: Box<dyn Scorer> = if exclude_scorers.len() == 1 {
let exclude_scorer = exclude_scorers.pop().unwrap();
match exclude_scorer.downcast::<TermScorer>() {
// Cast to TermScorer succeeded
Ok(exclude_scorer) => Box::new(Exclude::new(include_scorer_boxed, *exclude_scorer)),
// We get back the original Box<dyn Scorer>
Err(exclude_scorer) => Box::new(Exclude::new(include_scorer_boxed, exclude_scorer)),
}
} else {
Box::new(Exclude::new(include_scorer_boxed, exclude_scorers))
};
Ok(SpecializedScorer::Other(scorer))
}
}
@@ -381,6 +414,7 @@ fn remove_and_count_all_and_empty_scorers(
impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombiner> {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let num_docs = reader.num_docs();
if self.weights.is_empty() {
Ok(Box::new(EmptyScorer))
} else if self.weights.len() == 1 {
@@ -392,8 +426,14 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
}
} else if self.scoring_enabled {
self.complex_scorer(reader, boost, &self.score_combiner_fn)
.map(|specialized_scorer| {
into_box_scorer(specialized_scorer, &self.score_combiner_fn, num_docs)
})
} else {
self.complex_scorer(reader, boost, DoNothingCombiner::default)
.map(|specialized_scorer| {
into_box_scorer(specialized_scorer, DoNothingCombiner::default, num_docs)
})
}
}
@@ -422,8 +462,20 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score),
) -> crate::Result<()> {
let mut scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
scorer.for_each(callback);
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let mut union_scorer = BufferedUnionScorer::build(
term_scorers,
&self.score_combiner_fn,
reader.num_docs(),
);
for_each_scorer(&mut union_scorer, callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_scorer(scorer.as_mut(), callback);
}
}
Ok(())
}
@@ -432,9 +484,22 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
reader: &SegmentReader,
callback: &mut dyn FnMut(&[DocId]),
) -> crate::Result<()> {
let mut scorer = self.complex_scorer(reader, 1.0, || DoNothingCombiner)?;
let scorer = self.complex_scorer(reader, 1.0, || DoNothingCombiner)?;
let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN];
for_each_docset_buffered(scorer.as_mut(), &mut buffer, callback);
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let mut union_scorer = BufferedUnionScorer::build(
term_scorers,
&self.score_combiner_fn,
reader.num_docs(),
);
for_each_docset_buffered(&mut union_scorer, &mut buffer, callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_docset_buffered(scorer.as_mut(), &mut buffer, callback);
}
}
Ok(())
}
@@ -455,7 +520,14 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
callback: &mut dyn FnMut(DocId, Score) -> Score,
) -> crate::Result<()> {
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
reader.codec().for_each_pruning(threshold, scorer, callback);
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
super::block_wand(term_scorers, threshold, callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_pruning_scorer(scorer.as_mut(), threshold, callback);
}
}
Ok(())
}
}

View File

@@ -1,6 +1,8 @@
mod block_wand;
mod boolean_query;
mod boolean_weight;
pub(crate) use self::block_wand::{block_wand, block_wand_single_scorer};
pub use self::boolean_query::BooleanQuery;
pub use self::boolean_weight::BooleanWeight;

View File

@@ -1,7 +1,7 @@
use std::fmt;
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
use crate::query::{box_scorer, EnableScoring, Explanation, Query, Scorer, Weight};
use crate::query::{EnableScoring, Explanation, Query, Scorer, Weight};
use crate::{DocId, DocSet, Score, SegmentReader, TantivyError, Term};
/// `ConstScoreQuery` is a wrapper over a query to provide a constant score.
@@ -65,10 +65,7 @@ impl ConstWeight {
impl Weight for ConstWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let inner_scorer = self.weight.scorer(reader, boost)?;
Ok(box_scorer(ConstScorer::new(
inner_scorer,
boost * self.score,
)))
Ok(Box::new(ConstScorer::new(inner_scorer, boost * self.score)))
}
fn explain(&self, reader: &SegmentReader, doc: u32) -> crate::Result<Explanation> {

View File

@@ -2,7 +2,7 @@ use super::Scorer;
use crate::docset::TERMINATED;
use crate::index::SegmentReader;
use crate::query::explanation::does_not_match;
use crate::query::{box_scorer, EnableScoring, Explanation, Query, Weight};
use crate::query::{EnableScoring, Explanation, Query, Weight};
use crate::{DocId, DocSet, Score, Searcher};
/// `EmptyQuery` is a dummy `Query` in which no document matches.
@@ -27,7 +27,7 @@ impl Query for EmptyQuery {
pub struct EmptyWeight;
impl Weight for EmptyWeight {
fn scorer(&self, _reader: &SegmentReader, _boost: Score) -> crate::Result<Box<dyn Scorer>> {
Ok(box_scorer(EmptyScorer))
Ok(Box::new(EmptyScorer))
}
fn explain(&self, _reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {

View File

@@ -1,48 +1,71 @@
use crate::docset::{DocSet, TERMINATED};
use crate::docset::{DocSet, SeekDangerResult, TERMINATED};
use crate::query::Scorer;
use crate::{DocId, Score};
#[inline]
fn is_within<TDocSetExclude: DocSet>(docset: &mut TDocSetExclude, doc: DocId) -> bool {
docset.doc() <= doc && docset.seek(doc) == doc
}
/// Filters a given `DocSet` by removing the docs from a given `DocSet`.
/// An exclusion set is a set of documents
/// that should be excluded from a given DocSet.
///
/// The excluding docset has no impact on scoring.
pub struct Exclude<TDocSet, TDocSetExclude> {
underlying_docset: TDocSet,
excluding_docset: TDocSetExclude,
/// It can be a single DocSet, or a Vec of DocSets.
pub trait ExclusionSet: Send {
/// Returns `true` if the given `doc` is in the exclusion set.
fn contains(&mut self, doc: DocId) -> bool;
}
impl<TDocSet, TDocSetExclude> Exclude<TDocSet, TDocSetExclude>
impl<TDocSet: DocSet> ExclusionSet for TDocSet {
#[inline]
fn contains(&mut self, doc: DocId) -> bool {
self.seek_danger(doc) == SeekDangerResult::Found
}
}
impl<TDocSet: DocSet> ExclusionSet for Vec<TDocSet> {
#[inline]
fn contains(&mut self, doc: DocId) -> bool {
for docset in self.iter_mut() {
if docset.seek_danger(doc) == SeekDangerResult::Found {
return true;
}
}
false
}
}
/// Filters a given `DocSet` by removing the docs from an exclusion set.
///
/// The excluding docsets have no impact on scoring.
pub struct Exclude<TDocSet, TExclusionSet> {
underlying_docset: TDocSet,
exclusion_set: TExclusionSet,
}
impl<TDocSet, TExclusionSet> Exclude<TDocSet, TExclusionSet>
where
TDocSet: DocSet,
TDocSetExclude: DocSet,
TExclusionSet: ExclusionSet,
{
/// Creates a new `ExcludeScorer`
pub fn new(
mut underlying_docset: TDocSet,
mut excluding_docset: TDocSetExclude,
) -> Exclude<TDocSet, TDocSetExclude> {
mut exclusion_set: TExclusionSet,
) -> Exclude<TDocSet, TExclusionSet> {
while underlying_docset.doc() != TERMINATED {
let target = underlying_docset.doc();
if !is_within(&mut excluding_docset, target) {
if !exclusion_set.contains(target) {
break;
}
underlying_docset.advance();
}
Exclude {
underlying_docset,
excluding_docset,
exclusion_set,
}
}
}
impl<TDocSet, TDocSetExclude> DocSet for Exclude<TDocSet, TDocSetExclude>
impl<TDocSet, TExclusionSet> DocSet for Exclude<TDocSet, TExclusionSet>
where
TDocSet: DocSet,
TDocSetExclude: DocSet,
TExclusionSet: ExclusionSet,
{
fn advance(&mut self) -> DocId {
loop {
@@ -50,7 +73,7 @@ where
if candidate == TERMINATED {
return TERMINATED;
}
if !is_within(&mut self.excluding_docset, candidate) {
if !self.exclusion_set.contains(candidate) {
return candidate;
}
}
@@ -61,7 +84,7 @@ where
if candidate == TERMINATED {
return TERMINATED;
}
if !is_within(&mut self.excluding_docset, candidate) {
if !self.exclusion_set.contains(candidate) {
return candidate;
}
self.advance()
@@ -79,10 +102,10 @@ where
}
}
impl<TScorer, TDocSetExclude> Scorer for Exclude<TScorer, TDocSetExclude>
impl<TScorer, TExclusionSet> Scorer for Exclude<TScorer, TExclusionSet>
where
TScorer: Scorer,
TDocSetExclude: DocSet + 'static,
TExclusionSet: ExclusionSet + 'static,
{
#[inline]
fn score(&mut self) -> Score {

View File

@@ -3,7 +3,7 @@ use core::fmt::Debug;
use columnar::{ColumnIndex, DynamicColumn};
use common::BitSet;
use super::{box_scorer, ConstScorer, EmptyScorer};
use super::{ConstScorer, EmptyScorer};
use crate::docset::{DocSet, TERMINATED};
use crate::index::SegmentReader;
use crate::query::all_query::AllScorer;
@@ -117,7 +117,7 @@ impl Weight for ExistsWeight {
}
}
if non_empty_columns.is_empty() {
return Ok(box_scorer(EmptyScorer));
return Ok(Box::new(EmptyScorer));
}
// If any column is full, all docs match.
@@ -128,9 +128,9 @@ impl Weight for ExistsWeight {
{
let all_scorer = AllScorer::new(max_doc);
if boost != 1.0f32 {
return Ok(box_scorer(BoostScorer::new(all_scorer, boost)));
return Ok(Box::new(BoostScorer::new(all_scorer, boost)));
} else {
return Ok(box_scorer(all_scorer));
return Ok(Box::new(all_scorer));
}
}
@@ -138,7 +138,7 @@ impl Weight for ExistsWeight {
// NOTE: A lower number may be better for very sparse columns
if non_empty_columns.len() < 4 {
let docset = ExistsDocSet::new(non_empty_columns, reader.max_doc());
return Ok(box_scorer(ConstScorer::new(docset, boost)));
return Ok(Box::new(ConstScorer::new(docset, boost)));
}
// If we have many dynamic columns, precompute a bitset of matching docs
@@ -162,7 +162,7 @@ impl Weight for ExistsWeight {
}
}
let docset = BitSetDocSet::from(doc_bitset);
Ok(box_scorer(ConstScorer::new(docset, boost)))
Ok(Box::new(ConstScorer::new(docset, boost)))
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {

View File

@@ -1,7 +1,7 @@
use super::size_hint::estimate_intersection;
use crate::docset::{DocSet, SeekDangerResult, TERMINATED};
use crate::query::term_query::TermScorer;
use crate::query::{box_scorer, EmptyScorer, Scorer};
use crate::query::{EmptyScorer, Scorer};
use crate::{DocId, Score};
/// Returns the intersection scorer.
@@ -20,7 +20,7 @@ pub fn intersect_scorers(
num_docs_segment: u32,
) -> Box<dyn Scorer> {
if scorers.is_empty() {
return box_scorer(EmptyScorer);
return Box::new(EmptyScorer);
}
if scorers.len() == 1 {
return scorers.pop().unwrap();
@@ -29,7 +29,7 @@ pub fn intersect_scorers(
scorers.sort_by_key(|scorer| scorer.cost());
let doc = go_to_first_doc(&mut scorers[..]);
if doc == TERMINATED {
return box_scorer(EmptyScorer);
return Box::new(EmptyScorer);
}
// We know that we have at least 2 elements.
let left = scorers.remove(0);
@@ -38,14 +38,14 @@ pub fn intersect_scorers(
.iter()
.all(|&scorer| scorer.is::<TermScorer>());
if all_term_scorers {
return box_scorer(Intersection {
return Box::new(Intersection {
left: *(left.downcast::<TermScorer>().map_err(|_| ()).unwrap()),
right: *(right.downcast::<TermScorer>().map_err(|_| ()).unwrap()),
others: scorers,
num_docs: num_docs_segment,
});
}
box_scorer(Intersection {
Box::new(Intersection {
left,
right,
others: scorers,

View File

@@ -24,7 +24,7 @@ mod reqopt_scorer;
mod scorer;
mod set_query;
mod size_hint;
pub(crate) mod term_query;
mod term_query;
mod union;
mod weight;
@@ -43,7 +43,7 @@ pub use self::boost_query::{BoostQuery, BoostWeight};
pub use self::const_score_query::{ConstScoreQuery, ConstScorer};
pub use self::disjunction_max_query::DisjunctionMaxQuery;
pub use self::empty_query::{EmptyQuery, EmptyScorer, EmptyWeight};
pub use self::exclude::Exclude;
pub use self::exclude::{Exclude, ExclusionSet};
pub use self::exist_query::ExistsQuery;
pub use self::explanation::Explanation;
#[cfg(test)]
@@ -54,14 +54,13 @@ pub use self::more_like_this::{MoreLikeThisQuery, MoreLikeThisQueryBuilder};
pub use self::phrase_prefix_query::PhrasePrefixQuery;
pub use self::phrase_query::regex_phrase_query::{wildcard_query_to_regex_str, RegexPhraseQuery};
pub use self::phrase_query::PhraseQuery;
pub(crate) use self::phrase_query::PhraseScorer;
pub use self::query::{EnableScoring, Query, QueryClone};
pub use self::query_parser::{QueryParser, QueryParserError};
pub use self::range_query::*;
pub use self::regex_query::RegexQuery;
pub use self::reqopt_scorer::RequiredOptionalScorer;
pub use self::score_combiner::{DisjunctionMaxCombiner, ScoreCombiner, SumCombiner};
pub use self::scorer::{box_scorer, Scorer};
pub use self::scorer::Scorer;
pub use self::set_query::TermSetQuery;
pub use self::term_query::TermQuery;
pub use self::union::BufferedUnionScorer;

View File

@@ -2,7 +2,7 @@ use crate::docset::{DocSet, SeekDangerResult, TERMINATED};
use crate::fieldnorm::FieldNormReader;
use crate::postings::Postings;
use crate::query::bm25::Bm25Weight;
use crate::query::phrase_query::{intersection_exists, PhraseScorer};
use crate::query::phrase_query::{intersection_count, PhraseScorer};
use crate::query::Scorer;
use crate::{DocId, Score};
@@ -100,6 +100,7 @@ pub struct PhrasePrefixScorer<TPostings: Postings> {
phrase_scorer: PhraseKind<TPostings>,
suffixes: Vec<TPostings>,
suffix_offset: u32,
phrase_count: u32,
suffix_position_buffer: Vec<u32>,
}
@@ -143,6 +144,7 @@ impl<TPostings: Postings> PhrasePrefixScorer<TPostings> {
phrase_scorer,
suffixes,
suffix_offset: (max_offset - suffix_pos) as u32,
phrase_count: 0,
suffix_position_buffer: Vec::with_capacity(100),
};
if phrase_prefix_scorer.doc() != TERMINATED && !phrase_prefix_scorer.matches_prefix() {
@@ -151,7 +153,12 @@ impl<TPostings: Postings> PhrasePrefixScorer<TPostings> {
phrase_prefix_scorer
}
pub fn phrase_count(&self) -> u32 {
self.phrase_count
}
fn matches_prefix(&mut self) -> bool {
let mut count = 0;
let current_doc = self.doc();
let pos_matching = self.phrase_scorer.get_intersection();
for suffix in &mut self.suffixes {
@@ -161,12 +168,11 @@ impl<TPostings: Postings> PhrasePrefixScorer<TPostings> {
let doc = suffix.seek(current_doc);
if doc == current_doc {
suffix.positions_with_offset(self.suffix_offset, &mut self.suffix_position_buffer);
if intersection_exists(pos_matching, &self.suffix_position_buffer) {
return true;
}
count += intersection_count(pos_matching, &self.suffix_position_buffer);
}
}
false
self.phrase_count = count as u32;
count != 0
}
}

View File

@@ -1,11 +1,12 @@
use super::{prefix_end, PhrasePrefixScorer};
use crate::fieldnorm::FieldNormReader;
use crate::index::SegmentReader;
use crate::postings::Postings;
use crate::postings::SegmentPostings;
use crate::query::bm25::Bm25Weight;
use crate::query::{box_scorer, EmptyScorer, Scorer, Weight};
use crate::query::explanation::does_not_match;
use crate::query::{EmptyScorer, Explanation, Scorer, Weight};
use crate::schema::{IndexRecordOption, Term};
use crate::Score;
use crate::{DocId, DocSet, Score};
pub struct PhrasePrefixWeight {
phrase_terms: Vec<(usize, Term)>,
@@ -45,13 +46,13 @@ impl PhrasePrefixWeight {
&self,
reader: &SegmentReader,
boost: Score,
) -> crate::Result<Option<Box<dyn Scorer>>> {
) -> crate::Result<Option<PhrasePrefixScorer<SegmentPostings>>> {
let similarity_weight_opt = self
.similarity_weight_opt
.as_ref()
.map(|similarity_weight| similarity_weight.boost_by(boost));
let fieldnorm_reader = self.fieldnorm_reader(reader)?;
let mut term_postings_list: Vec<(usize, Box<dyn Postings>)> = Vec::new();
let mut term_postings_list = Vec::new();
for &(offset, ref term) in &self.phrase_terms {
if let Some(postings) = reader
.inverted_index(term.field())?
@@ -102,32 +103,49 @@ impl PhrasePrefixWeight {
}
}
Ok(Some(box_scorer(PhrasePrefixScorer::new(
Ok(Some(PhrasePrefixScorer::new(
term_postings_list,
similarity_weight_opt,
fieldnorm_reader,
suffixes,
self.prefix.0,
))))
)))
}
}
impl Weight for PhrasePrefixWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
if let Some(scorer) = self.phrase_scorer(reader, boost)? {
Ok(scorer)
Ok(Box::new(scorer))
} else {
Ok(box_scorer(EmptyScorer))
Ok(Box::new(EmptyScorer))
}
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let scorer_opt = self.phrase_scorer(reader, 1.0)?;
if scorer_opt.is_none() {
return Err(does_not_match(doc));
}
let mut scorer = scorer_opt.unwrap();
if scorer.seek(doc) != doc {
return Err(does_not_match(doc));
}
let fieldnorm_reader = self.fieldnorm_reader(reader)?;
let fieldnorm_id = fieldnorm_reader.fieldnorm_id(doc);
let phrase_count = scorer.phrase_count();
let mut explanation = Explanation::new("Phrase Prefix Scorer", scorer.score());
if let Some(similarity_weight) = self.similarity_weight_opt.as_ref() {
explanation.add_detail(similarity_weight.explain(fieldnorm_id, phrase_count));
}
Ok(explanation)
}
}
#[cfg(test)]
mod tests {
use crate::docset::TERMINATED;
use crate::index::Index;
use crate::postings::Postings;
use crate::query::phrase_prefix_query::PhrasePrefixScorer;
use crate::query::{EnableScoring, PhrasePrefixQuery, Query};
use crate::schema::{Schema, TEXT};
use crate::{DocSet, IndexWriter, Term};
@@ -168,14 +186,14 @@ mod tests {
.phrase_prefix_query_weight(enable_scoring)
.unwrap()
.unwrap();
let mut phrase_scorer_boxed = phrase_weight
let mut phrase_scorer = phrase_weight
.phrase_scorer(searcher.segment_reader(0u32), 1.0)?
.unwrap();
let phrase_scorer: &mut PhrasePrefixScorer<Box<dyn Postings>> =
phrase_scorer_boxed.as_any_mut().downcast_mut().unwrap();
assert_eq!(phrase_scorer.doc(), 1);
assert_eq!(phrase_scorer.phrase_count(), 2);
assert_eq!(phrase_scorer.advance(), 2);
assert_eq!(phrase_scorer.doc(), 2);
assert_eq!(phrase_scorer.phrase_count(), 1);
assert_eq!(phrase_scorer.advance(), TERMINATED);
Ok(())
}
@@ -195,15 +213,14 @@ mod tests {
.phrase_prefix_query_weight(enable_scoring)
.unwrap()
.unwrap();
let mut phrase_scorer_boxed = phrase_weight
let mut phrase_scorer = phrase_weight
.phrase_scorer(searcher.segment_reader(0u32), 1.0)?
.unwrap();
let phrase_scorer = phrase_scorer_boxed
.downcast_mut::<PhrasePrefixScorer<Box<dyn Postings>>>()
.unwrap();
assert_eq!(phrase_scorer.doc(), 1);
assert_eq!(phrase_scorer.phrase_count(), 2);
assert_eq!(phrase_scorer.advance(), 2);
assert_eq!(phrase_scorer.doc(), 2);
assert_eq!(phrase_scorer.phrase_count(), 1);
assert_eq!(phrase_scorer.advance(), TERMINATED);
Ok(())
}

View File

@@ -5,7 +5,7 @@ pub mod regex_phrase_query;
mod regex_phrase_weight;
pub use self::phrase_query::PhraseQuery;
pub(crate) use self::phrase_scorer::intersection_exists;
pub(crate) use self::phrase_scorer::intersection_count;
pub use self::phrase_scorer::PhraseScorer;
pub use self::phrase_weight::PhraseWeight;

Some files were not shown because too many files have changed in this diff Show More