Compare commits

..

1 Commits

Author SHA1 Message Date
Pascal Seitz
d1555fe9f8 SegmentReader as trait 2026-01-27 13:56:40 +01:00
104 changed files with 809 additions and 6275 deletions

View File

@@ -1,125 +0,0 @@
---
name: rationalize-deps
description: Analyze Cargo.toml dependencies and attempt to remove unused features to reduce compile times and binary size
---
# Rationalize Dependencies
This skill analyzes Cargo.toml dependencies to identify and remove unused features.
## Overview
Many crates enable features by default that may not be needed. This skill:
1. Identifies dependencies with default features enabled
2. Tests if `default-features = false` works
3. Identifies which specific features are actually needed
4. Verifies compilation after changes
## Step 1: Identify the target
Ask the user which crate(s) to analyze:
- A specific crate name (e.g., "tokio", "serde")
- A specific workspace member (e.g., "quickwit-search")
- "all" to scan the entire workspace
## Step 2: Analyze current dependencies
For the workspace Cargo.toml (`quickwit/Cargo.toml`), list dependencies that:
- Do NOT have `default-features = false`
- Have default features that might be unnecessary
Run: `cargo tree -p <crate> -f "{p} {f}" --edges features` to see what features are actually used.
## Step 3: For each candidate dependency
### 3a: Check the crate's default features
Look up the crate on crates.io or check its Cargo.toml to understand:
- What features are enabled by default
- What each feature provides
Use: `cargo metadata --format-version=1 | jq '.packages[] | select(.name == "<crate>") | .features'`
### 3b: Try disabling default features
Modify the dependency in `quickwit/Cargo.toml`:
From:
```toml
some-crate = { version = "1.0" }
```
To:
```toml
some-crate = { version = "1.0", default-features = false }
```
### 3c: Run cargo check
Run: `cargo check --workspace` (or target specific packages for faster feedback)
If compilation fails:
1. Read the error messages to identify which features are needed
2. Add only the required features explicitly:
```toml
some-crate = { version = "1.0", default-features = false, features = ["needed-feature"] }
```
3. Re-run cargo check
### 3d: Binary search for minimal features
If there are many default features, use binary search:
1. Start with no features
2. If it fails, add half the default features
3. Continue until you find the minimal set
## Step 4: Document findings
For each dependency analyzed, report:
- Original configuration
- New configuration (if changed)
- Features that were removed
- Any features that are required
## Step 5: Verify full build
After all changes, run:
```bash
cargo check --workspace --all-targets
cargo test --workspace --no-run
```
## Common Patterns
### Serde
Often only needs `derive`:
```toml
serde = { version = "1.0", default-features = false, features = ["derive", "std"] }
```
### Tokio
Identify which runtime features are actually used:
```toml
tokio = { version = "1.0", default-features = false, features = ["rt-multi-thread", "macros", "sync"] }
```
### Reqwest
Often doesn't need all TLS backends:
```toml
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "json"] }
```
## Rollback
If changes cause issues:
```bash
git checkout quickwit/Cargo.toml
cargo check --workspace
```
## Tips
- Start with large crates that have many default features (tokio, reqwest, hyper)
- Use `cargo bloat --crates` to identify large dependencies
- Check `cargo tree -d` for duplicate dependencies that might indicate feature conflicts
- Some features are needed only for tests - consider using `[dev-dependencies]` features

View File

@@ -1,60 +0,0 @@
---
name: simple-pr
description: Create a simple PR from staged changes with an auto-generated commit message
disable-model-invocation: true
---
# Simple PR
Follow these steps to create a simple PR from staged changes:
## Step 1: Check workspace state
Run: `git status`
Verify that all changes have been staged (no unstaged changes). If there are unstaged changes, abort and ask the user to stage their changes first with `git add`.
Also verify that we are on the `main` branch. If not, abort and ask the user to switch to main first.
## Step 2: Ensure main is up to date
Run: `git pull origin main`
This ensures we're working from the latest code.
## Step 3: Review staged changes
Run: `git diff --cached`
Review the staged changes to understand what the PR will contain.
## Step 4: Generate commit message
Based on the staged changes, generate a concise commit message (1-2 sentences) that describes the "why" rather than the "what".
Display the proposed commit message to the user and ask for confirmation before proceeding.
## Step 5: Create a new branch
Get the git username: `git config user.name | tr ' ' '-' | tr '[:upper:]' '[:lower:]'`
Create a short, descriptive branch name based on the changes (e.g., `fix-typo-in-readme`, `add-retry-logic`, `update-deps`).
Create and checkout the branch: `git checkout -b {username}/{short-descriptive-name}`
## Step 6: Commit changes
Commit with the message from step 3:
```
git commit -m "{commit-message}"
```
## Step 7: Push and open a PR
Push the branch and open a PR:
```
git push -u origin {branch-name}
gh pr create --title "{commit-message-title}" --body "{longer-description-if-needed}"
```
Report the PR URL to the user when complete.

View File

@@ -15,7 +15,7 @@ rust-version = "1.85"
exclude = ["benches/*.json", "benches/*.txt"]
[dependencies]
oneshot = "0.1.13"
oneshot = "0.1.7"
base64 = "0.22.0"
byteorder = "1.4.3"
crc32fast = "1.3.2"
@@ -47,7 +47,7 @@ rustc-hash = "2.0.0"
thiserror = "2.0.1"
htmlescape = "0.3.1"
fail = { version = "0.5.0", optional = true }
time = { version = "0.3.47", features = ["serde-well-known"] }
time = { version = "0.3.35", features = ["serde-well-known"] }
smallvec = "1.8.0"
rayon = "1.5.2"
lru = "0.16.3"
@@ -64,8 +64,8 @@ query-grammar = { version = "0.25.0", path = "./query-grammar", package = "tanti
tantivy-bitpacker = { version = "0.9", path = "./bitpacker" }
common = { version = "0.10", path = "./common/", package = "tantivy-common" }
tokenizer-api = { version = "0.6", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
sketches-ddsketch = { git = "https://github.com/quickwit-oss/rust-sketches-ddsketch.git", rev = "555caf1", features = ["use_serde"] }
datasketches = "0.2.0"
sketches-ddsketch = { version = "0.3.0", features = ["use_serde"] }
hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
futures-util = { version = "0.3.28", optional = true }
futures-channel = { version = "0.3.28", optional = true }
fnv = "1.0.7"
@@ -86,7 +86,7 @@ futures = "0.3.21"
paste = "1.0.11"
more-asserts = "0.3.1"
rand_distr = "0.5"
time = { version = "0.3.47", features = ["serde-well-known", "macros"] }
time = { version = "0.3.10", features = ["serde-well-known", "macros"] }
postcard = { version = "1.0.4", features = [
"use-std",
], default-features = false }
@@ -193,11 +193,3 @@ harness = false
[[bench]]
name = "str_search_and_get"
harness = false
[[bench]]
name = "merge_segments"
harness = false
[[bench]]
name = "regex_all_terms"
harness = false

View File

@@ -10,7 +10,7 @@ use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::{AllQuery, TermQuery};
use tantivy::schema::{IndexRecordOption, Schema, TextFieldIndexing, FAST, STRING};
use tantivy::{doc, DateTime, Index, Term};
use tantivy::{doc, Index, Term};
#[global_allocator]
pub static GLOBAL: &PeakMemAlloc<std::alloc::System> = &INSTRUMENTED_SYSTEM;
@@ -70,12 +70,6 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
register!(group, composite_term_many_page_1000);
register!(group, composite_term_many_page_1000_with_avg_sub_agg);
register!(group, composite_term_few);
register!(group, composite_histogram);
register!(group, composite_histogram_calendar);
register!(group, cardinality_agg);
register!(group, terms_status_with_cardinality_agg);
@@ -320,75 +314,6 @@ fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn composite_term_few(index: &Index) {
let agg_req = json!({
"my_ctf": {
"composite": {
"sources": [
{ "text_few_terms": { "terms": { "field": "text_few_terms" } } }
],
"size": 1000
}
},
});
execute_agg(index, agg_req);
}
fn composite_term_many_page_1000(index: &Index) {
let agg_req = json!({
"my_ctmp1000": {
"composite": {
"sources": [
{ "text_many_terms": { "terms": { "field": "text_many_terms" } } }
],
"size": 1000
}
},
});
execute_agg(index, agg_req);
}
fn composite_term_many_page_1000_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_ctmp1000wasa": {
"composite": {
"sources": [
{ "text_many_terms": { "terms": { "field": "text_many_terms" } } }
],
"size": 1000,
},
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn composite_histogram(index: &Index) {
let agg_req = json!({
"my_ch": {
"composite": {
"sources": [
{ "f64_histogram": { "histogram": { "field": "score_f64", "interval": 1 } } }
],
"size": 1000
}
},
});
execute_agg(index, agg_req);
}
fn composite_histogram_calendar(index: &Index) {
let agg_req = json!({
"my_chc": {
"composite": {
"sources": [
{ "time_histogram": { "date_histogram": { "field": "timestamp", "calendar_interval": "month" } } }
],
"size": 1000
}
},
});
execute_agg(index, agg_req);
}
fn execute_agg(index: &Index, agg_req: serde_json::Value) {
let agg_req: Aggregations = serde_json::from_value(agg_req).unwrap();
let collector = get_collector(agg_req);
@@ -571,7 +496,6 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
let text_field_all_unique_terms =
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST);
let text_field_few_terms_status =
schema_builder.add_text_field("text_few_terms_status", STRING | FAST);
let text_field_1000_terms_zipf =
@@ -580,7 +504,6 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
let date_field = schema_builder.add_date_field("timestamp", FAST);
// use tmp dir
let index = if reuse_index {
Index::create_in_dir("agg_bench", schema_builder.build())?
@@ -600,7 +523,6 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
let log_level_distribution =
WeightedIndex::new(status_field_data.iter().map(|item| item.1)).unwrap();
let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"];
let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap();
let many_terms_data = (0..150_000)
@@ -636,8 +558,6 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
text_field_all_unique_terms => "coolo",
text_field_many_terms => "cool",
text_field_many_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms_status => log_level_sample_a,
text_field_few_terms_status => log_level_sample_b,
text_field_1000_terms_zipf => term_1000_a.as_str(),
@@ -668,13 +588,11 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
json_field => json,
text_field_all_unique_terms => format!("unique_term_{}", rng.random::<u64>()),
text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms_status => status_field_data[log_level_distribution.sample(&mut rng)].0,
text_field_1000_terms_zipf => terms_1000[zipf_1000.sample(&mut rng) as usize - 1].as_str(),
score_field => val as u64,
score_field_f64 => lg_norm.sample(&mut rng),
score_field_i64 => val as i64,
date_field => DateTime::from_timestamp_millis((val * 1_000_000.) as i64),
))?;
if cardinality == Cardinality::OptionalSparse {
for _ in 0..20 {

View File

@@ -1,224 +0,0 @@
// Benchmarks segment merging
//
// Notes:
// - Input segments are kept intact (no deletes / no IndexWriter merge).
// - Output is written to a `NullDirectory` that discards all files except
// fieldnorms (needed for merging).
use std::collections::HashMap;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use binggan::{black_box, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::directory::error::{DeleteError, OpenReadError, OpenWriteError};
use tantivy::directory::{
AntiCallToken, Directory, FileHandle, OwnedBytes, TerminatingWrite, WatchCallback, WatchHandle,
WritePtr,
};
use tantivy::indexer::{merge_filtered_segments, NoMergePolicy};
use tantivy::schema::{Schema, TEXT};
use tantivy::{doc, HasLen, Index, IndexSettings, Segment};
#[derive(Clone, Default, Debug)]
struct NullDirectory {
blobs: Arc<RwLock<HashMap<PathBuf, OwnedBytes>>>,
}
struct NullWriter;
impl Write for NullWriter {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl TerminatingWrite for NullWriter {
fn terminate_ref(&mut self, _token: AntiCallToken) -> io::Result<()> {
Ok(())
}
}
struct InMemoryWriter {
path: PathBuf,
buffer: Vec<u8>,
blobs: Arc<RwLock<HashMap<PathBuf, OwnedBytes>>>,
}
impl Write for InMemoryWriter {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.buffer.extend_from_slice(buf);
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl TerminatingWrite for InMemoryWriter {
fn terminate_ref(&mut self, _token: AntiCallToken) -> io::Result<()> {
let bytes = OwnedBytes::new(std::mem::take(&mut self.buffer));
self.blobs.write().unwrap().insert(self.path.clone(), bytes);
Ok(())
}
}
#[derive(Debug, Default)]
struct NullFileHandle;
impl HasLen for NullFileHandle {
fn len(&self) -> usize {
0
}
}
impl FileHandle for NullFileHandle {
fn read_bytes(&self, _range: std::ops::Range<usize>) -> io::Result<OwnedBytes> {
unimplemented!()
}
}
impl Directory for NullDirectory {
fn get_file_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>, OpenReadError> {
if let Some(bytes) = self.blobs.read().unwrap().get(path) {
return Ok(Arc::new(bytes.clone()));
}
Ok(Arc::new(NullFileHandle))
}
fn delete(&self, _path: &Path) -> Result<(), DeleteError> {
Ok(())
}
fn exists(&self, _path: &Path) -> Result<bool, OpenReadError> {
Ok(true)
}
fn open_write(&self, path: &Path) -> Result<WritePtr, OpenWriteError> {
let path_buf = path.to_path_buf();
if path.to_string_lossy().ends_with(".fieldnorm") {
let writer = InMemoryWriter {
path: path_buf,
buffer: Vec::new(),
blobs: Arc::clone(&self.blobs),
};
Ok(io::BufWriter::new(Box::new(writer)))
} else {
Ok(io::BufWriter::new(Box::new(NullWriter)))
}
}
fn atomic_read(&self, path: &Path) -> Result<Vec<u8>, OpenReadError> {
if let Some(bytes) = self.blobs.read().unwrap().get(path) {
return Ok(bytes.as_slice().to_vec());
}
Err(OpenReadError::FileDoesNotExist(path.to_path_buf()))
}
fn atomic_write(&self, _path: &Path, _data: &[u8]) -> io::Result<()> {
Ok(())
}
fn sync_directory(&self) -> io::Result<()> {
Ok(())
}
fn watch(&self, _watch_callback: WatchCallback) -> tantivy::Result<WatchHandle> {
Ok(WatchHandle::empty())
}
}
struct MergeScenario {
#[allow(dead_code)]
index: Index,
segments: Vec<Segment>,
settings: IndexSettings,
label: String,
}
fn build_index(
num_segments: usize,
docs_per_segment: usize,
tokens_per_doc: usize,
vocab_size: usize,
) -> MergeScenario {
let mut schema_builder = Schema::builder();
let body = schema_builder.add_text_field("body", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
assert!(vocab_size > 0);
let total_tokens = num_segments * docs_per_segment * tokens_per_doc;
let use_unique_terms = vocab_size >= total_tokens;
let mut rng = StdRng::from_seed([7u8; 32]);
let mut next_token_id: u64 = 0;
{
let mut writer = index.writer_with_num_threads(1, 256_000_000).unwrap();
writer.set_merge_policy(Box::new(NoMergePolicy));
for _ in 0..num_segments {
for _ in 0..docs_per_segment {
let mut tokens = Vec::with_capacity(tokens_per_doc);
for _ in 0..tokens_per_doc {
let token_id = if use_unique_terms {
let id = next_token_id;
next_token_id += 1;
id
} else {
rng.random_range(0..vocab_size as u64)
};
tokens.push(format!("term_{token_id}"));
}
writer.add_document(doc!(body => tokens.join(" "))).unwrap();
}
writer.commit().unwrap();
}
}
let segments = index.searchable_segments().unwrap();
let settings = index.settings().clone();
let label = format!(
"segments={}, docs/seg={}, tokens/doc={}, vocab={}",
num_segments, docs_per_segment, tokens_per_doc, vocab_size
);
MergeScenario {
index,
segments,
settings,
label,
}
}
fn main() {
let scenarios = vec![
build_index(8, 50_000, 12, 8),
build_index(16, 50_000, 12, 8),
build_index(16, 100_000, 12, 8),
build_index(8, 50_000, 8, 8 * 50_000 * 8),
];
let mut runner = BenchRunner::new();
for scenario in scenarios {
let mut group = runner.new_group();
group.set_name(format!("merge_segments inv_index — {}", scenario.label));
let segments = scenario.segments.clone();
let settings = scenario.settings.clone();
group.register("merge", move |_| {
let output_dir = NullDirectory::default();
let filter_doc_ids = vec![None; segments.len()];
let merged_index =
merge_filtered_segments(&segments, settings.clone(), filter_doc_ids, output_dir)
.unwrap();
black_box(merged_index);
});
group.run();
}
}

View File

@@ -1,113 +0,0 @@
// Benchmarks regex query that matches all terms in a synthetic index.
//
// Corpus model:
// - N unique terms: t000000, t000001, ...
// - M docs
// - K tokens per doc: doc i gets terms derived from (i, token_index)
//
// Query:
// - Regex "t.*" to match all terms
//
// Run with:
// - cargo bench --bench regex_all_terms
//
use std::fmt::Write;
use binggan::{black_box, BenchRunner};
use tantivy::collector::Count;
use tantivy::query::RegexQuery;
use tantivy::schema::{Schema, TEXT};
use tantivy::{doc, Index, ReloadPolicy};
const HEAP_SIZE_BYTES: usize = 200_000_000;
#[derive(Clone, Copy)]
struct BenchConfig {
num_terms: usize,
num_docs: usize,
tokens_per_doc: usize,
}
fn main() {
let configs = default_configs();
let mut runner = BenchRunner::new();
for config in configs {
let (index, text_field) = build_index(config, HEAP_SIZE_BYTES);
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.expect("reader");
let searcher = reader.searcher();
let query = RegexQuery::from_pattern("t.*", text_field).expect("regex query");
let mut group = runner.new_group();
group.set_name(format!(
"regex_all_terms_t{}_d{}_k{}",
config.num_terms, config.num_docs, config.tokens_per_doc
));
group.register("regex_count", move |_| {
let count = searcher.search(&query, &Count).expect("search");
black_box(count);
});
group.run();
}
}
fn default_configs() -> Vec<BenchConfig> {
vec![
BenchConfig {
num_terms: 10_000,
num_docs: 100_000,
tokens_per_doc: 1,
},
BenchConfig {
num_terms: 10_000,
num_docs: 100_000,
tokens_per_doc: 8,
},
BenchConfig {
num_terms: 100_000,
num_docs: 100_000,
tokens_per_doc: 1,
},
BenchConfig {
num_terms: 100_000,
num_docs: 100_000,
tokens_per_doc: 8,
},
]
}
fn build_index(config: BenchConfig, heap_size_bytes: usize) -> (Index, tantivy::schema::Field) {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let term_width = config.num_terms.to_string().len();
{
let mut writer = index
.writer_with_num_threads(1, heap_size_bytes)
.expect("writer");
let mut buffer = String::new();
for doc_id in 0..config.num_docs {
buffer.clear();
for token_idx in 0..config.tokens_per_doc {
if token_idx > 0 {
buffer.push(' ');
}
let term_id = (doc_id * config.tokens_per_doc + token_idx) % config.num_terms;
write!(&mut buffer, "t{term_id:0term_width$}").expect("write token");
}
writer
.add_document(doc!(text_field => buffer.as_str()))
.expect("add_document");
}
writer.commit().expect("commit");
}
(index, text_field)
}

View File

@@ -59,7 +59,7 @@ pub struct RowAddr {
pub row_id: RowId,
}
pub use sstable::{Dictionary, TermOrdHit};
pub use sstable::Dictionary;
pub type Streamer<'a> = sstable::Streamer<'a, VoidSSTable>;
pub use common::DateTime;

View File

@@ -15,10 +15,11 @@ repository = "https://github.com/quickwit-oss/tantivy"
byteorder = "1.4.3"
ownedbytes = { version= "0.9", path="../ownedbytes" }
async-trait = "0.1"
time = { version = "0.3.47", features = ["serde-well-known"] }
time = { version = "0.3.10", features = ["serde-well-known"] }
serde = { version = "1.0.136", features = ["derive"] }
[dev-dependencies]
binggan = "0.14.0"
proptest = "1.0.0"
rand = "0.9"

View File

@@ -62,9 +62,7 @@ impl<W: TerminatingWrite> TerminatingWrite for CountingWriter<W> {
pub struct AntiCallToken(());
/// Trait used to indicate when no more write need to be done on a writer
///
/// Thread-safety is enforced at the call sites that require it.
pub trait TerminatingWrite: Write {
pub trait TerminatingWrite: Write + Send + Sync {
/// Indicate that the writer will no longer be used. Internally call terminate_ref.
fn terminate(mut self) -> io::Result<()>
where Self: Sized {

View File

@@ -60,7 +60,7 @@ At indexing, tantivy will try to interpret number and strings as different type
priority order.
Numbers will be interpreted as u64, i64 and f64 in that order.
Strings will be interpreted as rfc3339 dates or simple strings.
Strings will be interpreted as rfc3999 dates or simple strings.
The first working type is picked and is the only term that is emitted for indexing.
Note this interpretation happens on a per-document basis, and there is no effort to try to sniff
@@ -81,7 +81,7 @@ Will be interpreted as
(my_path.my_segment, String, 233) or (my_path.my_segment, u64, 233)
```
Likewise, we need to emit two tokens if the query contains an rfc3339 date.
Likewise, we need to emit two tokens if the query contains an rfc3999 date.
Indeed the date could have been actually a single token inside the text of a document at ingestion time. Generally speaking, we will always at least emit a string token in query parsing, and sometimes more.
If one more json field is defined, things get even more complicated.

View File

@@ -70,7 +70,7 @@ impl Collector for StatsCollector {
fn for_segment(
&self,
_segment_local_id: u32,
segment_reader: &SegmentReader,
segment_reader: &dyn SegmentReader,
) -> tantivy::Result<StatsSegmentCollector> {
let fast_field_reader = segment_reader.fast_fields().u64(&self.field)?;
Ok(StatsSegmentCollector {

View File

@@ -65,7 +65,7 @@ fn main() -> tantivy::Result<()> {
);
let top_docs_by_custom_score =
// Call TopDocs with a custom tweak score
TopDocs::with_limit(2).tweak_score(move |segment_reader: &SegmentReader| {
TopDocs::with_limit(2).tweak_score(move |segment_reader: &dyn SegmentReader| {
let ingredient_reader = segment_reader.facet_reader("ingredient").unwrap();
let facet_dict = ingredient_reader.facet_dict();

View File

@@ -43,7 +43,7 @@ impl DynamicPriceColumn {
}
}
pub fn price_for_segment(&self, segment_reader: &SegmentReader) -> Option<Arc<Vec<Price>>> {
pub fn price_for_segment(&self, segment_reader: &dyn SegmentReader) -> Option<Arc<Vec<Price>>> {
let segment_key = (segment_reader.segment_id(), segment_reader.delete_opstamp());
self.price_cache.read().unwrap().get(&segment_key).cloned()
}
@@ -157,7 +157,7 @@ fn main() -> tantivy::Result<()> {
let query = query_parser.parse_query("cooking")?;
let searcher = reader.searcher();
let score_by_price = move |segment_reader: &SegmentReader| {
let score_by_price = move |segment_reader: &dyn SegmentReader| {
let price = price_dynamic_column
.price_for_segment(segment_reader)
.unwrap();

View File

@@ -560,7 +560,7 @@ fn range_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
(
(
value((), tag(">=")),
map(word_infallible(")", false), |(bound, err)| {
map(word_infallible("", false), |(bound, err)| {
(
(
bound
@@ -574,7 +574,7 @@ fn range_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
),
(
value((), tag("<=")),
map(word_infallible(")", false), |(bound, err)| {
map(word_infallible("", false), |(bound, err)| {
(
(
UserInputBound::Unbounded,
@@ -588,7 +588,7 @@ fn range_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
),
(
value((), tag(">")),
map(word_infallible(")", false), |(bound, err)| {
map(word_infallible("", false), |(bound, err)| {
(
(
bound
@@ -602,7 +602,7 @@ fn range_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
),
(
value((), tag("<")),
map(word_infallible(")", false), |(bound, err)| {
map(word_infallible("", false), |(bound, err)| {
(
(
UserInputBound::Unbounded,
@@ -704,11 +704,7 @@ fn regex(inp: &str) -> IResult<&str, UserInputLeaf> {
many1(alt((preceded(char('\\'), char('/')), none_of("/")))),
char('/'),
),
peek(alt((
value((), multispace1),
value((), char(')')),
value((), eof),
))),
peek(alt((multispace1, eof))),
),
|elements| UserInputLeaf::Regex {
field: None,
@@ -725,12 +721,8 @@ fn regex_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
opt_i_err(char('/'), "missing delimiter /"),
),
opt_i_err(
peek(alt((
value((), multispace1),
value((), char(')')),
value((), eof),
))),
"expected whitespace, closing parenthesis, or end of input",
peek(alt((multispace1, eof))),
"expected whitespace or end of input",
),
)(inp)
{
@@ -1331,14 +1323,6 @@ mod test {
test_parse_query_to_ast_helper("<a", "{\"*\" TO \"a\"}");
test_parse_query_to_ast_helper("<=a", "{\"*\" TO \"a\"]");
test_parse_query_to_ast_helper("<=bsd", "{\"*\" TO \"bsd\"]");
test_parse_query_to_ast_helper("(<=42)", "{\"*\" TO \"42\"]");
test_parse_query_to_ast_helper("(<=42 )", "{\"*\" TO \"42\"]");
test_parse_query_to_ast_helper("(age:>5)", "\"age\":{\"5\" TO \"*\"}");
test_parse_query_to_ast_helper(
"(title:bar AND age:>12)",
"(+\"title\":bar +\"age\":{\"12\" TO \"*\"})",
);
}
#[test]
@@ -1715,10 +1699,6 @@ mod test {
test_parse_query_to_ast_helper("foo:(A OR B)", "(?\"foo\":A ?\"foo\":B)");
test_parse_query_to_ast_helper("foo:(A* OR B*)", "(?\"foo\":A* ?\"foo\":B*)");
test_parse_query_to_ast_helper("foo:(*A OR *B)", "(?\"foo\":*A ?\"foo\":*B)");
// Regexes between parentheses
test_parse_query_to_ast_helper("foo:(/A.*/)", "\"foo\":/A.*/");
test_parse_query_to_ast_helper("foo:(/A.*/ OR /B.*/)", "(?\"foo\":/A.*/ ?\"foo\":/B.*/)");
}
#[test]

View File

@@ -66,7 +66,6 @@ impl UserInputLeaf {
}
UserInputLeaf::Range { field, .. } if field.is_none() => *field = Some(default_field),
UserInputLeaf::Set { field, .. } if field.is_none() => *field = Some(default_field),
UserInputLeaf::Regex { field, .. } if field.is_none() => *field = Some(default_field),
_ => (), // field was already set, do nothing
}
}

View File

@@ -57,7 +57,7 @@ pub(crate) fn get_numeric_or_date_column_types() -> &'static [ColumnType] {
/// Get fast field reader or empty as default.
pub(crate) fn get_ff_reader(
reader: &SegmentReader,
reader: &dyn SegmentReader,
field_name: &str,
allowed_column_types: Option<&[ColumnType]>,
) -> crate::Result<(columnar::Column<u64>, ColumnType)> {
@@ -74,7 +74,7 @@ pub(crate) fn get_ff_reader(
}
pub(crate) fn get_dynamic_columns(
reader: &SegmentReader,
reader: &dyn SegmentReader,
field_name: &str,
) -> crate::Result<Vec<columnar::DynamicColumn>> {
let ff_fields = reader.fast_fields().dynamic_column_handles(field_name)?;
@@ -90,7 +90,7 @@ pub(crate) fn get_dynamic_columns(
///
/// Is guaranteed to return at least one column.
pub(crate) fn get_all_ff_reader_or_empty(
reader: &SegmentReader,
reader: &dyn SegmentReader,
field_name: &str,
allowed_column_types: Option<&[ColumnType]>,
fallback_type: ColumnType,

View File

@@ -10,10 +10,9 @@ use crate::aggregation::accessor_helpers::{
};
use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
use crate::aggregation::bucket::{
build_segment_filter_collector, build_segment_range_collector, CompositeAggReqData,
CompositeAggregation, CompositeSourceAccessors, FilterAggReqData, HistogramAggReqData,
HistogramBounds, IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData,
SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
build_segment_filter_collector, build_segment_range_collector, FilterAggReqData,
HistogramAggReqData, HistogramBounds, IncludeExcludeParam, MissingTermAggReqData,
RangeAggReqData, SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
TermsAggregationInternal,
};
use crate::aggregation::metric::{
@@ -74,12 +73,6 @@ impl AggregationsSegmentCtx {
self.per_request.filter_req_data.push(Some(Box::new(data)));
self.per_request.filter_req_data.len() - 1
}
pub(crate) fn push_composite_req_data(&mut self, data: CompositeAggReqData) -> usize {
self.per_request
.composite_req_data
.push(Some(Box::new(data)));
self.per_request.composite_req_data.len() - 1
}
#[inline]
pub(crate) fn get_term_req_data(&self, idx: usize) -> &TermsAggReqData {
@@ -115,12 +108,6 @@ impl AggregationsSegmentCtx {
.as_deref()
.expect("range_req_data slot is empty (taken)")
}
#[inline]
pub(crate) fn get_composite_req_data(&self, idx: usize) -> &CompositeAggReqData {
self.per_request.composite_req_data[idx]
.as_deref()
.expect("composite_req_data slot is empty (taken)")
}
// ---------- mutable getters ----------
@@ -194,25 +181,6 @@ impl AggregationsSegmentCtx {
debug_assert!(self.per_request.filter_req_data[idx].is_none());
self.per_request.filter_req_data[idx] = Some(value);
}
/// Move out the Composite request at `idx`.
#[inline]
pub(crate) fn take_composite_req_data(&mut self, idx: usize) -> Box<CompositeAggReqData> {
self.per_request.composite_req_data[idx]
.take()
.expect("composite_req_data slot is empty (taken)")
}
/// Put back a Composite request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_composite_req_data(
&mut self,
idx: usize,
value: Box<CompositeAggReqData>,
) {
debug_assert!(self.per_request.composite_req_data[idx].is_none());
self.per_request.composite_req_data[idx] = Some(value);
}
}
/// Each type of aggregation has its own request data struct. This struct holds
@@ -240,8 +208,6 @@ pub struct PerRequestAggSegCtx {
pub top_hits_req_data: Vec<TopHitsAggReqData>,
/// MissingTermAggReqData contains the request data for a missing term aggregation.
pub missing_term_req_data: Vec<MissingTermAggReqData>,
/// CompositeAggReqData contains the request data for a composite aggregation.
pub composite_req_data: Vec<Option<Box<CompositeAggReqData>>>,
/// Request tree used to build collectors.
pub agg_tree: Vec<AggRefNode>,
@@ -289,11 +255,6 @@ impl PerRequestAggSegCtx {
.iter()
.map(|t| t.get_memory_consumption())
.sum::<usize>()
+ self
.composite_req_data
.iter()
.map(|b| b.as_ref().map(|d| d.get_memory_consumption()).unwrap_or(0))
.sum::<usize>()
+ self.agg_tree.len() * std::mem::size_of::<AggRefNode>()
}
@@ -330,11 +291,6 @@ impl PerRequestAggSegCtx {
.expect("filter_req_data slot is empty (taken)")
.name
.as_str(),
AggKind::Composite => self.composite_req_data[idx]
.as_deref()
.expect("composite_req_data slot is empty (taken)")
.name
.as_str(),
}
}
@@ -461,11 +417,6 @@ pub(crate) fn build_segment_agg_collector(
)?)),
AggKind::Range => Ok(build_segment_range_collector(req, node)?),
AggKind::Filter => build_segment_filter_collector(req, node),
AggKind::Composite => Ok(Box::new(
crate::aggregation::bucket::SegmentCompositeCollector::from_req_and_validate(
req, node,
)?,
)),
}
}
@@ -496,7 +447,6 @@ pub enum AggKind {
DateHistogram,
Range,
Filter,
Composite,
}
impl AggKind {
@@ -512,7 +462,6 @@ impl AggKind {
AggKind::DateHistogram => "DateHistogram",
AggKind::Range => "Range",
AggKind::Filter => "Filter",
AggKind::Composite => "Composite",
}
}
}
@@ -520,7 +469,7 @@ impl AggKind {
/// Build AggregationsData by walking the request tree.
pub(crate) fn build_aggregations_data_from_req(
aggs: &Aggregations,
reader: &SegmentReader,
reader: &dyn SegmentReader,
segment_ordinal: SegmentOrdinal,
context: AggContextParams,
) -> crate::Result<AggregationsSegmentCtx> {
@@ -540,7 +489,7 @@ pub(crate) fn build_aggregations_data_from_req(
fn build_nodes(
agg_name: &str,
req: &Aggregation,
reader: &SegmentReader,
reader: &dyn SegmentReader,
segment_ordinal: SegmentOrdinal,
data: &mut AggregationsSegmentCtx,
is_top_level: bool,
@@ -760,14 +709,6 @@ fn build_nodes(
children,
}])
}
AggregationVariants::Composite(composite_req) => Ok(vec![build_composite_node(
agg_name,
reader,
segment_ordinal,
data,
&req.sub_aggregation,
composite_req,
)?]),
AggregationVariants::Filter(filter_req) => {
// Build the query and evaluator upfront
let schema = reader.schema();
@@ -787,7 +728,6 @@ fn build_nodes(
let idx_in_req_data = data.push_filter_req_data(FilterAggReqData {
name: agg_name.to_string(),
req: filter_req.clone(),
segment_reader: reader.clone(),
evaluator,
matching_docs_buffer,
is_top_level,
@@ -802,38 +742,9 @@ fn build_nodes(
}
}
fn build_composite_node(
agg_name: &str,
reader: &SegmentReader,
_segment_ordinal: SegmentOrdinal,
data: &mut AggregationsSegmentCtx,
sub_aggs: &Aggregations,
req: &CompositeAggregation,
) -> crate::Result<AggRefNode> {
let mut composite_accessors = Vec::with_capacity(req.sources.len());
for source in &req.sources {
let source_after_key_opt = req.after.get(source.name()).map(|k| &k.0);
let source_accessor =
CompositeSourceAccessors::build_for_source(reader, source, source_after_key_opt)?;
composite_accessors.push(source_accessor);
}
let agg = CompositeAggReqData {
name: agg_name.to_string(),
req: req.clone(),
composite_accessors,
};
let idx = data.push_composite_req_data(agg);
let children = build_children(sub_aggs, reader, _segment_ordinal, data)?;
Ok(AggRefNode {
kind: AggKind::Composite,
idx_in_req_data: idx,
children,
})
}
fn build_children(
aggs: &Aggregations,
reader: &SegmentReader,
reader: &dyn SegmentReader,
segment_ordinal: SegmentOrdinal,
data: &mut AggregationsSegmentCtx,
) -> crate::Result<Vec<AggRefNode>> {
@@ -852,7 +763,7 @@ fn build_children(
}
fn get_term_agg_accessors(
reader: &SegmentReader,
reader: &dyn SegmentReader,
field_name: &str,
missing: &Option<Key>,
) -> crate::Result<Vec<(Column<u64>, ColumnType)>> {
@@ -905,7 +816,7 @@ fn build_terms_or_cardinality_nodes(
agg_name: &str,
field_name: &str,
missing: &Option<Key>,
reader: &SegmentReader,
reader: &dyn SegmentReader,
segment_ordinal: SegmentOrdinal,
data: &mut AggregationsSegmentCtx,
sub_aggs: &Aggregations,

View File

@@ -32,8 +32,8 @@ use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use super::bucket::{
CompositeAggregation, DateHistogramAggregationReq, FilterAggregation, HistogramAggregation,
RangeAggregation, TermsAggregation,
DateHistogramAggregationReq, FilterAggregation, HistogramAggregation, RangeAggregation,
TermsAggregation,
};
use super::metric::{
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
@@ -134,9 +134,6 @@ pub enum AggregationVariants {
/// Filter documents into a single bucket.
#[serde(rename = "filter")]
Filter(FilterAggregation),
/// Multi-dimensional, paginable bucket aggregation.
#[serde(rename = "composite")]
Composite(CompositeAggregation),
// Metric aggregation types
/// Computes the average of the extracted values.
@@ -183,11 +180,6 @@ impl AggregationVariants {
AggregationVariants::Histogram(histogram) => vec![histogram.field.as_str()],
AggregationVariants::DateHistogram(histogram) => vec![histogram.field.as_str()],
AggregationVariants::Filter(filter) => filter.get_fast_field_names(),
AggregationVariants::Composite(composite) => composite
.sources
.iter()
.map(|source| source.field())
.collect(),
AggregationVariants::Average(avg) => vec![avg.field_name()],
AggregationVariants::Count(count) => vec![count.field_name()],
AggregationVariants::Max(max) => vec![max.field_name()],
@@ -222,12 +214,6 @@ impl AggregationVariants {
_ => None,
}
}
pub(crate) fn as_composite(&self) -> Option<&CompositeAggregation> {
match &self {
AggregationVariants::Composite(composite) => Some(composite),
_ => None,
}
}
pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> {
match &self {
AggregationVariants::Percentiles(percentile_req) => Some(percentile_req),

View File

@@ -9,12 +9,10 @@ use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use super::bucket::GetDocCount;
use super::intermediate_agg_result::CompositeIntermediateKey;
use super::metric::{
ExtendedStats, PercentilesMetricResult, SingleMetricResult, Stats, TopHitsMetricResult,
};
use super::{AggregationError, Key};
use crate::aggregation::bucket::AfterKey;
use crate::TantivyError;
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
@@ -160,14 +158,6 @@ pub enum BucketResult {
},
/// This is the filter result - a single bucket with sub-aggregations
Filter(FilterBucketResult),
/// This is the composite result
Composite {
/// The buckets
buckets: Vec<CompositeBucketEntry>,
/// The key to start after when paginating
#[serde(skip_serializing_if = "FxHashMap::is_empty")]
after_key: FxHashMap<String, AfterKey>,
},
}
impl BucketResult {
@@ -189,9 +179,6 @@ impl BucketResult {
// Only count sub-aggregation buckets
filter_result.sub_aggregations.get_bucket_count()
}
BucketResult::Composite { buckets, .. } => {
buckets.iter().map(|bucket| bucket.get_bucket_count()).sum()
}
}
}
}
@@ -350,87 +337,3 @@ pub struct FilterBucketResult {
#[serde(flatten)]
pub sub_aggregations: AggregationResults,
}
/// Note the type information loss compared to `CompositeIntermediateKey`.
/// Pagination is performed using `AfterKey`, which encodes type information.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CompositeKey {
/// Boolean key
Bool(bool),
/// String key
Str(String),
/// `i64` key
I64(i64),
/// `u64` key
U64(u64),
/// `f64` key
F64(f64),
/// Null key
Null,
}
impl Eq for CompositeKey {}
impl std::hash::Hash for CompositeKey {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
match self {
Self::Bool(val) => val.hash(state),
Self::Str(text) => text.hash(state),
Self::F64(val) => val.to_bits().hash(state),
Self::U64(val) => val.hash(state),
Self::I64(val) => val.hash(state),
Self::Null => {}
}
}
}
impl PartialEq for CompositeKey {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Bool(l), Self::Bool(r)) => l == r,
(Self::Str(l), Self::Str(r)) => l == r,
(Self::F64(l), Self::F64(r)) => l.to_bits() == r.to_bits(),
(Self::I64(l), Self::I64(r)) => l == r,
(Self::U64(l), Self::U64(r)) => l == r,
(Self::Null, Self::Null) => true,
_ => false,
}
}
}
impl From<CompositeIntermediateKey> for CompositeKey {
fn from(value: CompositeIntermediateKey) -> Self {
match value {
CompositeIntermediateKey::Str(s) => Self::Str(s),
CompositeIntermediateKey::IpAddr(s) => {
if let Some(ip) = s.to_ipv4_mapped() {
Self::Str(ip.to_string())
} else {
Self::Str(s.to_string())
}
}
CompositeIntermediateKey::F64(f) => Self::F64(f),
CompositeIntermediateKey::Bool(f) => Self::Bool(f),
CompositeIntermediateKey::U64(f) => Self::U64(f),
CompositeIntermediateKey::I64(f) => Self::I64(f),
CompositeIntermediateKey::DateTime(f) => Self::I64(f / 1_000_000), // ns to ms
CompositeIntermediateKey::Null => Self::Null,
}
}
}
/// Composite bucket entry with a multi-dimensional key.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CompositeBucketEntry {
/// The identifier of the bucket.
pub key: FxHashMap<String, CompositeKey>,
/// Number of documents in the bucket.
pub doc_count: u64,
#[serde(flatten)]
/// Sub-aggregations in this bucket.
pub sub_aggregation: AggregationResults,
}
impl CompositeBucketEntry {
pub(crate) fn get_bucket_count(&self) -> u64 {
1 + self.sub_aggregation.get_bucket_count()
}
}

View File

@@ -1,548 +0,0 @@
use std::net::Ipv6Addr;
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{Column, ColumnType, MonotonicallyMappableToU64, StrColumn, TermOrdHit};
use crate::aggregation::accessor_helpers::get_numeric_or_date_column_types;
use crate::aggregation::bucket::composite::numeric_types::num_proj;
use crate::aggregation::bucket::composite::numeric_types::num_proj::ProjectedNumber;
use crate::aggregation::bucket::composite::ToTypePaginationOrder;
use crate::aggregation::bucket::{
parse_into_milliseconds, CalendarInterval, CompositeAggregation, CompositeAggregationSource,
MissingOrder, Order,
};
use crate::aggregation::intermediate_agg_result::CompositeIntermediateKey;
use crate::{SegmentReader, TantivyError};
/// Contains all information required by the SegmentCompositeCollector to perform the
/// composite aggregation on a segment.
pub struct CompositeAggReqData {
/// The name of the aggregation.
pub name: String,
/// The normalized term aggregation request.
pub req: CompositeAggregation,
/// Accessors for each source, each source can have multiple accessors (columns).
pub composite_accessors: Vec<CompositeSourceAccessors>,
}
impl CompositeAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
+ self.composite_accessors.len() * std::mem::size_of::<CompositeSourceAccessors>()
}
}
/// Accessors for a single column in a composite source.
pub struct CompositeAccessor {
/// The fast field column
pub column: Column<u64>,
/// The column type
pub column_type: ColumnType,
/// Term dictionary if the column type is Str
///
/// Only used by term sources
pub str_dict_column: Option<StrColumn>,
/// Parsed date interval for date histogram sources
pub date_histogram_interval: PrecomputedDateInterval,
}
/// Accessors to all the columns that belong to the field of a composite source.
pub struct CompositeSourceAccessors {
/// The accessors for this source
pub accessors: Vec<CompositeAccessor>,
/// The key after which to start collecting results. Applies to the first
/// column of the source.
pub after_key: PrecomputedAfterKey,
/// The column index the after_key applies to. The after_key only applies to
/// one column. Columns before should be skipped. Columns after should be
/// kept without comparison to the after_key.
pub after_key_accessor_idx: usize,
/// Whether to skip missing values because of the after_key. Skipping only
/// applies if the value for previous columns were exactly equal to the
/// corresponding after keys (is_on_after_key).
pub skip_missing: bool,
/// The after key was set to null to indicate that the last collected key
/// was a missing value.
pub is_after_key_explicit_missing: bool,
}
impl CompositeSourceAccessors {
/// Creates a new set of accessors for the composite source.
///
/// Precomputes some values to make collection faster.
pub fn build_for_source(
reader: &SegmentReader,
source: &CompositeAggregationSource,
// First option is None when no after key was set in the query, the
// second option is None when the after key was set but its value for
// this source was set to `null`
source_after_key_opt: Option<&CompositeIntermediateKey>,
) -> crate::Result<Self> {
let is_after_key_explicit_missing = source_after_key_opt
.map(|after_key| matches!(after_key, CompositeIntermediateKey::Null))
.unwrap_or(false);
let mut skip_missing = false;
if let Some(CompositeIntermediateKey::Null) = source_after_key_opt {
if !source.missing_bucket() {
return Err(TantivyError::InvalidArgument(
"the 'after' key for a source cannot be null when 'missing_bucket' is false"
.to_string(),
));
}
} else if source_after_key_opt.is_some() {
// if missing buckets come first and we have a non null after key, we skip missing
if MissingOrder::First == source.missing_order() {
skip_missing = true;
}
if MissingOrder::Default == source.missing_order() && Order::Asc == source.order() {
skip_missing = true;
}
};
match source {
CompositeAggregationSource::Terms(source) => {
let allowed_column_types = [
ColumnType::I64,
ColumnType::U64,
ColumnType::F64,
ColumnType::Str,
ColumnType::DateTime,
ColumnType::Bool,
ColumnType::IpAddr,
// ColumnType::Bytes Unsupported
];
let mut columns_and_types = reader
.fast_fields()
.u64_lenient_for_type_all(Some(&allowed_column_types), &source.field)?;
// Sort columns by their pagination order and determine which to skip
columns_and_types.sort_by_key(|(_, col_type): &(Column, ColumnType)| {
col_type.column_pagination_order()
});
if source.order == Order::Desc {
columns_and_types.reverse();
}
let after_key_accessor_idx = find_first_column_to_collect(
&columns_and_types,
source_after_key_opt,
source.missing_order,
source.order,
)?;
let source_collectors: Vec<CompositeAccessor> = columns_and_types
.into_iter()
.map(|(column, column_type)| {
Ok(CompositeAccessor {
column,
column_type,
str_dict_column: reader.fast_fields().str(&source.field)?,
date_histogram_interval: PrecomputedDateInterval::NotApplicable,
})
})
.collect::<crate::Result<_>>()?;
let after_key = if let Some(first_col) =
source_collectors.get(after_key_accessor_idx)
{
match source_after_key_opt {
Some(after_key) => PrecomputedAfterKey::precompute(
&first_col,
after_key,
&source.field,
source.missing_order,
source.order,
)?,
None => {
precompute_missing_after_key(false, source.missing_order, source.order)
}
}
} else {
// if no columns, we don't care about the after_key
PrecomputedAfterKey::Next(0)
};
Ok(CompositeSourceAccessors {
accessors: source_collectors,
is_after_key_explicit_missing,
skip_missing,
after_key,
after_key_accessor_idx,
})
}
CompositeAggregationSource::Histogram(source) => {
let column_and_types: Vec<(Column, ColumnType)> =
reader.fast_fields().u64_lenient_for_type_all(
Some(get_numeric_or_date_column_types()),
&source.field,
)?;
let source_collectors: Vec<CompositeAccessor> = column_and_types
.into_iter()
.map(|(column, column_type)| {
Ok(CompositeAccessor {
column,
column_type,
str_dict_column: None,
date_histogram_interval: PrecomputedDateInterval::NotApplicable,
})
})
.collect::<crate::Result<_>>()?;
let after_key = match source_after_key_opt {
Some(CompositeIntermediateKey::F64(key)) => {
let normalized_key = *key / source.interval;
num_proj::f64_to_i64(normalized_key).into()
}
Some(CompositeIntermediateKey::Null) => {
precompute_missing_after_key(true, source.missing_order, source.order)
}
None => precompute_missing_after_key(true, source.missing_order, source.order),
_ => {
return Err(crate::TantivyError::InvalidArgument(
"After key type invalid for interval composite source".to_string(),
));
}
};
Ok(CompositeSourceAccessors {
accessors: source_collectors,
is_after_key_explicit_missing,
skip_missing,
after_key,
after_key_accessor_idx: 0,
})
}
CompositeAggregationSource::DateHistogram(source) => {
let column_and_types = reader
.fast_fields()
.u64_lenient_for_type_all(Some(&[ColumnType::DateTime]), &source.field)?;
let date_histogram_interval =
PrecomputedDateInterval::from_date_histogram_source_intervals(
&source.fixed_interval,
source.calendar_interval,
)?;
let source_collectors: Vec<CompositeAccessor> = column_and_types
.into_iter()
.map(|(column, column_type)| {
Ok(CompositeAccessor {
column,
column_type,
str_dict_column: None,
date_histogram_interval,
})
})
.collect::<crate::Result<_>>()?;
let after_key = match source_after_key_opt {
Some(CompositeIntermediateKey::DateTime(key)) => {
PrecomputedAfterKey::Exact(key.to_u64())
}
Some(CompositeIntermediateKey::Null) => {
precompute_missing_after_key(true, source.missing_order, source.order)
}
None => precompute_missing_after_key(true, source.missing_order, source.order),
_ => {
return Err(crate::TantivyError::InvalidArgument(
"After key type invalid for interval composite source".to_string(),
));
}
};
Ok(CompositeSourceAccessors {
accessors: source_collectors,
is_after_key_explicit_missing,
skip_missing,
after_key,
after_key_accessor_idx: 0,
})
}
}
}
}
/// Finds the index of the first column we should start collecting from to
/// resume the pagination from the after_key.
fn find_first_column_to_collect<T>(
sorted_columns: &[(T, ColumnType)],
after_key_opt: Option<&CompositeIntermediateKey>,
missing_order: MissingOrder,
order: Order,
) -> crate::Result<usize> {
let after_key = match after_key_opt {
None => return Ok(0), // No pagination, start from beginning
Some(key) => key,
};
// Handle null after_key (we were on a missing value last time)
if matches!(after_key, CompositeIntermediateKey::Null) {
return match (missing_order, order) {
// Missing values come first, so all columns remain
(MissingOrder::First, _) | (MissingOrder::Default, Order::Asc) => Ok(0),
// Missing values come last, so all columns are done
(MissingOrder::Last, _) | (MissingOrder::Default, Order::Desc) => {
Ok(sorted_columns.len())
}
};
}
// Find the first column whose type order matches or follows the after_key's
// type in the pagination sequence
let after_key_column_order = after_key.column_pagination_order();
for (idx, (_, col_type)) in sorted_columns.iter().enumerate() {
let col_order = col_type.column_pagination_order();
let is_first_to_collect = match order {
Order::Asc => col_order >= after_key_column_order,
Order::Desc => col_order <= after_key_column_order,
};
if is_first_to_collect {
return Ok(idx);
}
}
// All columns are before the after_key, nothing left to collect
Ok(sorted_columns.len())
}
fn precompute_missing_after_key(
is_after_key_explicit_missing: bool,
missing_order: MissingOrder,
order: Order,
) -> PrecomputedAfterKey {
let after_last = PrecomputedAfterKey::AfterLast;
let before_first = PrecomputedAfterKey::Next(0);
match (is_after_key_explicit_missing, missing_order, order) {
(true, MissingOrder::First, Order::Asc) => before_first,
(true, MissingOrder::First, Order::Desc) => after_last,
(true, MissingOrder::Last, Order::Asc) => after_last,
(true, MissingOrder::Last, Order::Desc) => before_first,
(true, MissingOrder::Default, Order::Asc) => before_first,
(true, MissingOrder::Default, Order::Desc) => after_last,
(false, _, Order::Asc) => before_first,
(false, _, Order::Desc) => after_last,
}
}
/// A parsed representation of the date interval for date histogram sources
#[derive(Clone, Copy, Debug)]
pub enum PrecomputedDateInterval {
/// This is not a date histogram source
NotApplicable,
/// Source was configured with a fixed interval
FixedNanoseconds(i64),
/// Source was configured with a calendar interval
Calendar(CalendarInterval),
}
impl PrecomputedDateInterval {
/// Validates the date histogram source interval fields and parses a date interval from them.
pub fn from_date_histogram_source_intervals(
fixed_interval: &Option<String>,
calendar_interval: Option<CalendarInterval>,
) -> crate::Result<Self> {
match (fixed_interval, calendar_interval) {
(Some(_), Some(_)) | (None, None) => Err(TantivyError::InvalidArgument(
"date histogram source must one and only one of fixed_interval or \
calendar_interval set"
.to_string(),
)),
(Some(fixed_interval), None) => {
let fixed_interval_ms = parse_into_milliseconds(&fixed_interval)?;
Ok(PrecomputedDateInterval::FixedNanoseconds(
fixed_interval_ms * 1_000_000,
))
}
(None, Some(calendar_interval)) => {
Ok(PrecomputedDateInterval::Calendar(calendar_interval))
}
}
}
}
/// The after key projected to the u64 column space
///
/// Some column types (term, IP) might not have an exact representation of the
/// specified after key
#[derive(Debug)]
pub enum PrecomputedAfterKey {
/// The after key could be exactly represented in the column space.
Exact(u64),
/// The after key could not be exactly represented exactly represented, so
/// this is the next closest one.
Next(u64),
/// The after key could not be represented in the column space, it is
/// greater than all value
AfterLast,
}
impl From<TermOrdHit> for PrecomputedAfterKey {
fn from(hit: TermOrdHit) -> Self {
match hit {
TermOrdHit::Exact(ord) => PrecomputedAfterKey::Exact(ord),
// TermOrdHit represents AfterLast as Next(u64::MAX), we keep it as is
TermOrdHit::Next(ord) => PrecomputedAfterKey::Next(ord),
}
}
}
impl<T: MonotonicallyMappableToU64> From<ProjectedNumber<T>> for PrecomputedAfterKey {
fn from(num: ProjectedNumber<T>) -> Self {
match num {
ProjectedNumber::Exact(number) => PrecomputedAfterKey::Exact(number.to_u64()),
ProjectedNumber::Next(number) => PrecomputedAfterKey::Next(number.to_u64()),
ProjectedNumber::AfterLast => PrecomputedAfterKey::AfterLast,
}
}
}
// /!\ These operators only makes sense if both values are in the same column space
impl PrecomputedAfterKey {
pub fn equals(&self, column_value: u64) -> bool {
match self {
PrecomputedAfterKey::Exact(v) => *v == column_value,
PrecomputedAfterKey::Next(_) => false,
PrecomputedAfterKey::AfterLast => false,
}
}
pub fn gt(&self, column_value: u64) -> bool {
match self {
PrecomputedAfterKey::Exact(v) => *v > column_value,
PrecomputedAfterKey::Next(v) => *v > column_value,
PrecomputedAfterKey::AfterLast => true,
}
}
pub fn lt(&self, column_value: u64) -> bool {
match self {
PrecomputedAfterKey::Exact(v) => *v < column_value,
// a value equal to the next is greater than the after key
PrecomputedAfterKey::Next(v) => *v <= column_value,
PrecomputedAfterKey::AfterLast => false,
}
}
fn precompute_ip_addr(column: &Column<u64>, key: &Ipv6Addr) -> crate::Result<Self> {
// For IP addresses we need to find the compact space value.
// We try to convert via the column's min/max range scan.
// Since CompactSpaceU64Accessor::u128_to_compact is not public,
// we search linearly for the exact u64 value by scanning column values.
let ip_u128 = key.to_bits();
// Scan for matching value - IP columns are typically small
let num_vals = column.values.num_vals();
let mut found_exact = false;
let mut exact_compact = 0u64;
let mut best_next: Option<u64> = None;
for doc_id in 0..num_vals {
let val = column.values.get_val(doc_id);
// We need the CompactSpaceU64Accessor to convert compact to u128
let compact_accessor = column
.values
.clone()
.downcast_arc::<CompactSpaceU64Accessor>()
.map_err(|_| {
TantivyError::AggregationError(
crate::aggregation::AggregationError::InternalError(
"type mismatch: could not downcast to CompactSpaceU64Accessor"
.to_string(),
),
)
})?;
let val_u128 = compact_accessor.compact_to_u128(val as u32);
if val_u128 == ip_u128 {
found_exact = true;
exact_compact = val;
break;
} else if val_u128 > ip_u128 {
match best_next {
None => best_next = Some(val),
Some(current_best) => {
let current_u128 = compact_accessor.compact_to_u128(current_best as u32);
if val_u128 < current_u128 {
best_next = Some(val);
}
}
}
}
}
if found_exact {
Ok(PrecomputedAfterKey::Exact(exact_compact))
} else if let Some(next) = best_next {
Ok(PrecomputedAfterKey::Next(next))
} else {
Ok(PrecomputedAfterKey::AfterLast)
}
}
fn precompute_term_ord(
str_dict_column: &Option<StrColumn>,
key: &str,
field: &str,
) -> crate::Result<Self> {
let dict = str_dict_column
.as_ref()
.expect("dictionary missing for str accessor")
.dictionary();
let next_ord = dict.term_ord_or_next(key).map_err(|_| {
TantivyError::InvalidArgument(format!(
"failed to lookup after_key '{}' for field '{}'",
key, field
))
})?;
Ok(next_ord.into())
}
/// Projects the after key into the column space of the given accessor.
///
/// The computed after key will not take care of skipping entire columns
/// when the after key type is ordered after the accessor's type, that
/// should be performed earlier.
pub fn precompute(
composite_accessor: &CompositeAccessor,
source_after_key: &CompositeIntermediateKey,
field: &str,
missing_order: MissingOrder,
order: Order,
) -> crate::Result<Self> {
use CompositeIntermediateKey as CIKey;
let precomputed_key = match (composite_accessor.column_type, source_after_key) {
(ColumnType::Bytes, _) => panic!("unsupported"),
// null after key
(_, CIKey::Null) => precompute_missing_after_key(false, missing_order, order),
// numerical
(ColumnType::I64, CIKey::I64(k)) => PrecomputedAfterKey::Exact(k.to_u64()),
(ColumnType::I64, CIKey::U64(k)) => num_proj::u64_to_i64(*k).into(),
(ColumnType::I64, CIKey::F64(k)) => num_proj::f64_to_i64(*k).into(),
(ColumnType::U64, CIKey::I64(k)) => num_proj::i64_to_u64(*k).into(),
(ColumnType::U64, CIKey::U64(k)) => PrecomputedAfterKey::Exact(*k),
(ColumnType::U64, CIKey::F64(k)) => num_proj::f64_to_u64(*k).into(),
(ColumnType::F64, CIKey::I64(k)) => num_proj::i64_to_f64(*k).into(),
(ColumnType::F64, CIKey::U64(k)) => num_proj::u64_to_f64(*k).into(),
(ColumnType::F64, CIKey::F64(k)) => PrecomputedAfterKey::Exact(k.to_u64()),
// boolean
(ColumnType::Bool, CIKey::Bool(key)) => PrecomputedAfterKey::Exact(key.to_u64()),
// string
(ColumnType::Str, CIKey::Str(key)) => PrecomputedAfterKey::precompute_term_ord(
&composite_accessor.str_dict_column,
key,
field,
)?,
// date time
(ColumnType::DateTime, CIKey::DateTime(key)) => {
PrecomputedAfterKey::Exact(key.to_u64())
}
// ip address
(ColumnType::IpAddr, CIKey::IpAddr(key)) => {
PrecomputedAfterKey::precompute_ip_addr(&composite_accessor.column, key)?
}
// assume the column's type is ordered after the after_key's type
_ => PrecomputedAfterKey::keep_all(order),
};
Ok(precomputed_key)
}
fn keep_all(order: Order) -> Self {
match order {
Order::Asc => PrecomputedAfterKey::Next(0),
Order::Desc => PrecomputedAfterKey::Next(u64::MAX),
}
}
}

View File

@@ -1,140 +0,0 @@
use time::convert::{Day, Nanosecond};
use time::{Time, UtcDateTime};
const NS_IN_DAY: i64 = Nanosecond::per_t::<i128>(Day) as i64;
/// Computes the timestamp in nanoseconds corresponding to the beginning of the
/// year (January 1st at midnight UTC).
pub(super) fn try_year_bucket(timestamp_ns: i64) -> crate::Result<i64> {
year_bucket_using_time_crate(timestamp_ns).map_err(|e| {
crate::TantivyError::InvalidArgument(format!(
"Failed to compute year bucket for timestamp {}: {}",
timestamp_ns,
e.to_string()
))
})
}
/// Computes the timestamp in nanoseconds corresponding to the beginning of the
/// month (1st at midnight UTC).
pub(super) fn try_month_bucket(timestamp_ns: i64) -> crate::Result<i64> {
month_bucket_using_time_crate(timestamp_ns).map_err(|e| {
crate::TantivyError::InvalidArgument(format!(
"Failed to compute month bucket for timestamp {}: {}",
timestamp_ns,
e.to_string()
))
})
}
/// Computes the timestamp in nanoseconds corresponding to the beginning of the
/// week (Monday at midnight UTC).
pub(super) fn week_bucket(timestamp_ns: i64) -> i64 {
// 1970-01-01 was a Thursday (weekday = 4)
let days_since_epoch = timestamp_ns.div_euclid(NS_IN_DAY);
// Find the weekday: 0=Monday, ..., 6=Sunday
let weekday = (days_since_epoch + 3).rem_euclid(7);
let monday_days_since_epoch = days_since_epoch - weekday;
monday_days_since_epoch * NS_IN_DAY
}
fn year_bucket_using_time_crate(timestamp_ns: i64) -> Result<i64, time::Error> {
let timestamp_ns = UtcDateTime::from_unix_timestamp_nanos(timestamp_ns as i128)?
.replace_ordinal(1)?
.replace_time(Time::MIDNIGHT)
.unix_timestamp_nanos();
Ok(timestamp_ns as i64)
}
fn month_bucket_using_time_crate(timestamp_ns: i64) -> Result<i64, time::Error> {
let timestamp_ns = UtcDateTime::from_unix_timestamp_nanos(timestamp_ns as i128)?
.replace_day(1)?
.replace_time(Time::MIDNIGHT)
.unix_timestamp_nanos();
Ok(timestamp_ns as i64)
}
#[cfg(test)]
mod tests {
use std::i64;
use time::format_description::well_known::Iso8601;
use time::UtcDateTime;
use super::*;
fn ts_ns(iso: &str) -> i64 {
UtcDateTime::parse(iso, &Iso8601::DEFAULT)
.unwrap()
.unix_timestamp_nanos() as i64
}
#[test]
fn test_year_bucket() {
let ts = ts_ns("1970-01-01T00:00:00Z");
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1970-01-01T00:00:00Z"));
let ts = ts_ns("1970-06-01T10:00:01.010Z");
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1970-01-01T00:00:00Z"));
let ts = ts_ns("2008-12-31T23:59:59.999999999Z"); // leap year
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("2008-01-01T00:00:00Z"));
let ts = ts_ns("2008-01-01T00:00:00Z"); // leap year
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("2008-01-01T00:00:00Z"));
let ts = ts_ns("2010-12-31T23:59:59.999999999Z");
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("2010-01-01T00:00:00Z"));
let ts = ts_ns("1972-06-01T00:10:00Z");
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1972-01-01T00:00:00Z"));
}
#[test]
fn test_month_bucket() {
let ts = ts_ns("1970-01-15T00:00:00Z");
let res = try_month_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1970-01-01T00:00:00Z"));
let ts = ts_ns("1970-02-01T00:00:00Z");
let res = try_month_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1970-02-01T00:00:00Z"));
let ts = ts_ns("2000-01-31T23:59:59.999999999Z");
let res = try_month_bucket(ts).unwrap();
assert_eq!(res, ts_ns("2000-01-01T00:00:00Z"));
}
#[test]
fn test_week_bucket() {
let ts = ts_ns("1970-01-05T00:00:00Z"); // Monday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));
let ts = ts_ns("1970-01-05T23:59:59Z"); // Monday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));
let ts = ts_ns("1970-01-07T01:13:00Z"); // Wednesday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));
let ts = ts_ns("1970-01-11T23:59:59.999999999Z"); // Sunday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));
let ts = ts_ns("2025-10-16T10:41:59.010Z"); // Thursday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("2025-10-13T00:00:00Z"));
let ts = ts_ns("1970-01-01T00:00:00Z"); // Thursday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1969-12-29T00:00:00Z")); // Negative
}
}

View File

@@ -1,674 +0,0 @@
use std::fmt::Debug;
use std::mem;
use std::net::Ipv6Addr;
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{
Column, ColumnType, Dictionary, MonotonicallyMappableToU128, MonotonicallyMappableToU64,
NumericalValue, StrColumn,
};
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::composite::accessors::{
CompositeAccessor, CompositeAggReqData, PrecomputedDateInterval,
};
use crate::aggregation::bucket::composite::calendar_interval;
use crate::aggregation::bucket::composite::map::{DynArrayHeapMap, MAX_DYN_ARRAY_SIZE};
use crate::aggregation::bucket::{
CalendarInterval, CompositeAggregationSource, MissingOrder, Order,
};
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardSubAggCache};
use crate::aggregation::intermediate_agg_result::{
CompositeIntermediateKey, IntermediateAggregationResult, IntermediateAggregationResults,
IntermediateBucketResult, IntermediateCompositeBucketEntry, IntermediateCompositeBucketResult,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
use crate::TantivyError;
#[derive(Clone, Debug)]
struct CompositeBucketCollector {
count: u32,
bucket_id: BucketId,
}
/// Compact sortable representation of a single source value within a composite key.
///
/// The struct encodes both the column identity and the fast field value in a way
/// that preserves the desired sort order via the derived `Ord` implementation
/// (fields are compared top-to-bottom: `sort_key` first, then `encoded_value`).
///
/// ## `sort_key` encoding
/// - `0` — missing value, sorted first
/// - `1..=254` — present value; the original accessor index is `sort_key - 1`
/// - `u8::MAX` (255) — missing value, sorted last
///
/// ## `encoded_value` encoding
/// - `0` when the field is missing
/// - The raw u64 fast-field representation when order is ascending
/// - Bitwise NOT of the raw u64 when order is descending
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
struct InternalValueRepr {
/// Column index biased by +1 (so 0 and u8::MAX are reserved for missing sentinels).
sort_key: u8,
/// Fast field value, possibly bit-flipped for descending order.
encoded_value: u64,
}
impl InternalValueRepr {
#[inline]
fn new_term(raw: u64, accessor_idx: u8, order: Order) -> Self {
let encoded_value = match order {
Order::Asc => raw,
Order::Desc => !raw,
};
InternalValueRepr {
sort_key: accessor_idx + 1,
encoded_value,
}
}
/// For histogram sources the column index is irrelevant (always 1).
#[inline]
fn new_histogram(raw: u64, order: Order) -> Self {
let encoded_value = match order {
Order::Asc => raw,
Order::Desc => !raw,
};
InternalValueRepr {
sort_key: 1,
encoded_value,
}
}
#[inline]
fn new_missing(order: Order, missing_order: MissingOrder) -> Self {
let sort_key = match (missing_order, order) {
(MissingOrder::First, _) | (MissingOrder::Default, Order::Asc) => 0,
(MissingOrder::Last, _) | (MissingOrder::Default, Order::Desc) => u8::MAX,
};
InternalValueRepr {
sort_key,
encoded_value: 0,
}
}
/// Decode back to `(accessor_idx, raw_value)`.
/// Returns `None` when the value represents a missing field.
#[inline]
fn decode(self, order: Order) -> Option<(u8, u64)> {
if self.sort_key == 0 || self.sort_key == u8::MAX {
return None;
}
let raw = match order {
Order::Asc => self.encoded_value,
Order::Desc => !self.encoded_value,
};
Some((self.sort_key - 1, raw))
}
}
/// The collector puts values from the fast field into the correct buckets and
/// does a conversion to the correct datatype.
#[derive(Debug)]
pub struct SegmentCompositeCollector {
/// One DynArrayHeapMap per parent bucket.
parent_buckets: Vec<DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>>,
accessor_idx: usize,
sub_agg: Option<CachedSubAggs<HighCardSubAggCache>>,
bucket_id_provider: BucketIdProvider,
/// Number of sources, needed when creating new DynArrayHeapMaps.
num_sources: usize,
}
impl SegmentAggregationCollector for SegmentCompositeCollector {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data
.get_composite_req_data(self.accessor_idx)
.name
.clone();
let buckets = self.into_intermediate_bucket_result(agg_data, parent_bucket_id)?;
results.push(
name,
IntermediateAggregationResult::Bucket(IntermediateBucketResult::Composite { buckets }),
)?;
Ok(())
}
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mem_pre = self.get_memory_consumption();
let composite_agg_data = agg_data.take_composite_req_data(self.accessor_idx);
for doc in docs {
let mut sub_level_values = SmallVec::new();
recursive_key_visitor(
*doc,
&composite_agg_data,
0,
&mut sub_level_values,
&mut self.parent_buckets[parent_bucket_id as usize],
true,
&mut self.sub_agg,
&mut self.bucket_id_provider,
)?;
}
agg_data.put_back_composite_req_data(self.accessor_idx, composite_agg_data);
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.check_flush_local(agg_data)?;
}
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
agg_data.context.limits.add_memory_consumed(mem_delta)?;
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.flush(agg_data)?;
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let required_len = max_bucket as usize + 1;
while self.parent_buckets.len() < required_len {
let map = DynArrayHeapMap::try_new(self.num_sources)?;
self.parent_buckets.push(map);
}
Ok(())
}
}
impl SegmentCompositeCollector {
fn get_memory_consumption(&self) -> u64 {
self.parent_buckets
.iter()
.map(|m| m.memory_consumption())
.sum()
}
pub(crate) fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
validate_req(req_data, node.idx_in_req_data)?;
let has_sub_aggregations = !node.children.is_empty();
let sub_agg = if has_sub_aggregations {
let sub_agg_collector = build_segment_agg_collectors(req_data, &node.children)?;
Some(CachedSubAggs::new(sub_agg_collector))
} else {
None
};
let composite_req_data = req_data.get_composite_req_data(node.idx_in_req_data);
let num_sources = composite_req_data.req.sources.len();
Ok(SegmentCompositeCollector {
parent_buckets: vec![DynArrayHeapMap::try_new(num_sources)?],
accessor_idx: node.idx_in_req_data,
sub_agg,
bucket_id_provider: BucketIdProvider::default(),
num_sources,
})
}
#[inline]
fn into_intermediate_bucket_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
parent_bucket_id: BucketId,
) -> crate::Result<IntermediateCompositeBucketResult> {
let empty_map = DynArrayHeapMap::try_new(self.num_sources)?;
let heap_map = mem::replace(
&mut self.parent_buckets[parent_bucket_id as usize],
empty_map,
);
let mut dict: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry> =
Default::default();
dict.reserve(heap_map.size());
let composite_data = agg_data.get_composite_req_data(self.accessor_idx);
for (key_internal_repr, agg) in heap_map.into_iter() {
let key = resolve_key(&key_internal_repr, composite_data)?;
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
agg.bucket_id,
)?;
}
dict.insert(
key,
IntermediateCompositeBucketEntry {
doc_count: agg.count,
sub_aggregation: sub_aggregation_res,
},
);
}
Ok(IntermediateCompositeBucketResult {
entries: dict,
target_size: composite_data.req.size,
orders: composite_data
.req
.sources
.iter()
.map(|source| match source {
CompositeAggregationSource::Terms(t) => (t.order, t.missing_order),
CompositeAggregationSource::Histogram(h) => (h.order, h.missing_order),
CompositeAggregationSource::DateHistogram(d) => (d.order, d.missing_order),
})
.collect(),
})
}
}
fn validate_req(req_data: &mut AggregationsSegmentCtx, accessor_idx: usize) -> crate::Result<()> {
let composite_data = req_data.get_composite_req_data(accessor_idx);
let req = &composite_data.req;
if req.sources.is_empty() {
return Err(TantivyError::InvalidArgument(
"composite aggregation must have at least one source".to_string(),
));
}
if req.size == 0 {
return Err(TantivyError::InvalidArgument(
"composite aggregation 'size' must be > 0".to_string(),
));
}
let column_types_for_sources = composite_data.composite_accessors.iter().map(|item| {
item.accessors
.iter()
.map(|a| a.column_type)
.collect::<Vec<_>>()
});
for column_types in column_types_for_sources {
if column_types.len() > MAX_DYN_ARRAY_SIZE {
return Err(TantivyError::InvalidArgument(format!(
"composite aggregation source supports maximum {MAX_DYN_ARRAY_SIZE} sources",
)));
}
if column_types.contains(&ColumnType::Bytes) {
return Err(TantivyError::InvalidArgument(
"composite aggregation does not support 'bytes' field type".to_string(),
));
}
}
Ok(())
}
fn collect_bucket_with_limit(
doc_id: crate::DocId,
limit_num_buckets: usize,
buckets: &mut DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>,
key: &[InternalValueRepr],
sub_agg: &mut Option<CachedSubAggs<HighCardSubAggCache>>,
bucket_id_provider: &mut BucketIdProvider,
) {
let mut record_in_bucket = |bucket: &mut CompositeBucketCollector| {
bucket.count += 1;
if let Some(sub_agg) = sub_agg {
sub_agg.push(bucket.bucket_id, doc_id);
}
};
// We still have room for buckets, just insert
if buckets.size() < limit_num_buckets {
let bucket = buckets.get_or_insert_with(key, || CompositeBucketCollector {
count: 0,
bucket_id: bucket_id_provider.next_bucket_id(),
});
record_in_bucket(bucket);
return;
}
// Map is full, but we can still update the bucket if it already exists
if let Some(bucket) = buckets.get_mut(key) {
record_in_bucket(bucket);
return;
}
// Check if the item qualifies to enter the top-k, and evict the highest if it does
if let Some(highest_key) = buckets.peek_highest() {
if key < highest_key {
buckets.evict_highest();
let bucket = buckets.get_or_insert_with(key, || CompositeBucketCollector {
count: 0,
bucket_id: bucket_id_provider.next_bucket_id(),
});
record_in_bucket(bucket);
}
}
}
/// Converts the composite key from its internal column space representation
/// (segment specific) into its intermediate form.
fn resolve_key(
internal_key: &[InternalValueRepr],
agg_data: &CompositeAggReqData,
) -> crate::Result<Vec<CompositeIntermediateKey>> {
internal_key
.iter()
.enumerate()
.map(|(idx, val)| {
resolve_internal_value_repr(
*val,
&agg_data.req.sources[idx],
&agg_data.composite_accessors[idx].accessors,
)
})
.collect()
}
fn resolve_internal_value_repr(
internal_value_repr: InternalValueRepr,
source: &CompositeAggregationSource,
composite_accessors: &[CompositeAccessor],
) -> crate::Result<CompositeIntermediateKey> {
let decoded_value_opt = match source {
CompositeAggregationSource::Terms(source) => internal_value_repr.decode(source.order),
CompositeAggregationSource::Histogram(source) => internal_value_repr.decode(source.order),
CompositeAggregationSource::DateHistogram(source) => {
internal_value_repr.decode(source.order)
}
};
let Some((decoded_accessor_idx, val)) = decoded_value_opt else {
return Ok(CompositeIntermediateKey::Null);
};
let key = match source {
CompositeAggregationSource::Terms(_) => {
let CompositeAccessor {
column_type,
str_dict_column,
column,
..
} = &composite_accessors[decoded_accessor_idx as usize];
resolve_term(val, column_type, str_dict_column, column)?
}
CompositeAggregationSource::Histogram(source) => {
CompositeIntermediateKey::F64(i64::from_u64(val) as f64 * source.interval)
}
CompositeAggregationSource::DateHistogram(_) => {
CompositeIntermediateKey::DateTime(i64::from_u64(val))
}
};
Ok(key)
}
fn resolve_term(
val: u64,
column_type: &ColumnType,
str_dict_column: &Option<StrColumn>,
column: &Column,
) -> crate::Result<CompositeIntermediateKey> {
let key = if *column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let term_dict = str_dict_column
.as_ref()
.map(|el| el.dictionary())
.unwrap_or_else(|| &fallback_dict);
let mut buffer = Vec::new();
term_dict.ord_to_term(val, &mut buffer)?;
CompositeIntermediateKey::Str(
String::from_utf8(buffer.to_vec()).expect("could not convert to String"),
)
} else if *column_type == ColumnType::DateTime {
let val = i64::from_u64(val);
CompositeIntermediateKey::DateTime(val)
} else if *column_type == ColumnType::Bool {
let val = bool::from_u64(val);
CompositeIntermediateKey::Bool(val)
} else if *column_type == ColumnType::IpAddr {
let compact_space_accessor = column
.values
.clone()
.downcast_arc::<CompactSpaceU64Accessor>()
.map_err(|_| {
TantivyError::AggregationError(crate::aggregation::AggregationError::InternalError(
"Type mismatch: Could not downcast to CompactSpaceU64Accessor".to_string(),
))
})?;
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
let val = Ipv6Addr::from_u128(val);
CompositeIntermediateKey::IpAddr(val)
} else if *column_type == ColumnType::U64 {
CompositeIntermediateKey::U64(val)
} else if *column_type == ColumnType::I64 {
CompositeIntermediateKey::I64(i64::from_u64(val))
} else {
let val = f64::from_u64(val);
let val: NumericalValue = val.into();
match val.normalize() {
NumericalValue::U64(val) => CompositeIntermediateKey::U64(val),
NumericalValue::I64(val) => CompositeIntermediateKey::I64(val),
NumericalValue::F64(val) => CompositeIntermediateKey::F64(val),
}
};
Ok(key)
}
/// Depth-first walk of the accessors to build the composite key combinations
/// and update the buckets.
fn recursive_key_visitor(
doc_id: crate::DocId,
composite_agg_data: &CompositeAggReqData,
source_idx_for_recursion: usize,
sub_level_values: &mut SmallVec<[InternalValueRepr; MAX_DYN_ARRAY_SIZE]>,
buckets: &mut DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>,
// whether we need to consider the after_key in the following levels
is_on_after_key: bool,
sub_agg: &mut Option<CachedSubAggs<HighCardSubAggCache>>,
bucket_id_provider: &mut BucketIdProvider,
) -> crate::Result<()> {
if source_idx_for_recursion == composite_agg_data.req.sources.len() {
if !is_on_after_key {
collect_bucket_with_limit(
doc_id,
composite_agg_data.req.size as usize,
buckets,
sub_level_values,
sub_agg,
bucket_id_provider,
);
}
return Ok(());
}
let current_level_accessors = &composite_agg_data.composite_accessors[source_idx_for_recursion];
let current_level_source = &composite_agg_data.req.sources[source_idx_for_recursion];
let mut missing = true;
for (accessor_idx, accessor) in current_level_accessors.accessors.iter().enumerate() {
let values = accessor.column.values_for_doc(doc_id);
for value in values {
missing = false;
match current_level_source {
CompositeAggregationSource::Terms(_) => {
let preceeds_after_key_type =
accessor_idx < current_level_accessors.after_key_accessor_idx;
if is_on_after_key && preceeds_after_key_type {
break;
}
let matches_after_key_type =
accessor_idx == current_level_accessors.after_key_accessor_idx;
if matches_after_key_type && is_on_after_key {
let should_skip = match current_level_source.order() {
Order::Asc => current_level_accessors.after_key.gt(value),
Order::Desc => current_level_accessors.after_key.lt(value),
};
if should_skip {
continue;
}
}
sub_level_values.push(InternalValueRepr::new_term(
value,
accessor_idx as u8,
current_level_source.order(),
));
let still_on_after_key =
matches_after_key_type && current_level_accessors.after_key.equals(value);
recursive_key_visitor(
doc_id,
composite_agg_data,
source_idx_for_recursion + 1,
sub_level_values,
buckets,
is_on_after_key && still_on_after_key,
sub_agg,
bucket_id_provider,
)?;
sub_level_values.pop();
}
CompositeAggregationSource::Histogram(source) => {
let float_value = match accessor.column_type {
ColumnType::U64 => value as f64,
ColumnType::I64 => i64::from_u64(value) as f64,
ColumnType::DateTime => i64::from_u64(value) as f64 / 1_000_000.,
ColumnType::F64 => f64::from_u64(value),
_ => {
panic!(
"unexpected type {:?}. This should not happen",
accessor.column_type
)
}
};
let bucket_index = (float_value / source.interval).floor() as i64;
let bucket_value = i64::to_u64(bucket_index);
if is_on_after_key {
let should_skip = match current_level_source.order() {
Order::Asc => current_level_accessors.after_key.gt(bucket_value),
Order::Desc => current_level_accessors.after_key.lt(bucket_value),
};
if should_skip {
continue;
}
}
sub_level_values.push(InternalValueRepr::new_histogram(
bucket_value,
current_level_source.order(),
));
let still_on_after_key = current_level_accessors.after_key.equals(bucket_value);
recursive_key_visitor(
doc_id,
composite_agg_data,
source_idx_for_recursion + 1,
sub_level_values,
buckets,
is_on_after_key && still_on_after_key,
sub_agg,
bucket_id_provider,
)?;
sub_level_values.pop();
}
CompositeAggregationSource::DateHistogram(_) => {
let value_ns = match accessor.column_type {
ColumnType::DateTime => i64::from_u64(value),
_ => {
panic!(
"unexpected type {:?}. This should not happen",
accessor.column_type
)
}
};
let bucket_index = match accessor.date_histogram_interval {
PrecomputedDateInterval::FixedNanoseconds(fixed_interval_ns) => {
(value_ns / fixed_interval_ns) * fixed_interval_ns
}
PrecomputedDateInterval::Calendar(CalendarInterval::Year) => {
calendar_interval::try_year_bucket(value_ns)?
}
PrecomputedDateInterval::Calendar(CalendarInterval::Month) => {
calendar_interval::try_month_bucket(value_ns)?
}
PrecomputedDateInterval::Calendar(CalendarInterval::Week) => {
calendar_interval::week_bucket(value_ns)
}
PrecomputedDateInterval::NotApplicable => {
panic!("interval not precomputed for date histogram source")
}
};
let bucket_value = i64::to_u64(bucket_index);
if is_on_after_key {
let should_skip = match current_level_source.order() {
Order::Asc => current_level_accessors.after_key.gt(bucket_value),
Order::Desc => current_level_accessors.after_key.lt(bucket_value),
};
if should_skip {
continue;
}
}
sub_level_values.push(InternalValueRepr::new_histogram(
bucket_value,
current_level_source.order(),
));
let still_on_after_key = current_level_accessors.after_key.equals(bucket_value);
recursive_key_visitor(
doc_id,
composite_agg_data,
source_idx_for_recursion + 1,
sub_level_values,
buckets,
is_on_after_key && still_on_after_key,
sub_agg,
bucket_id_provider,
)?;
sub_level_values.pop();
}
};
}
}
if missing && current_level_source.missing_bucket() {
if is_on_after_key && current_level_accessors.skip_missing {
return Ok(());
}
sub_level_values.push(InternalValueRepr::new_missing(
current_level_source.order(),
current_level_source.missing_order(),
));
recursive_key_visitor(
doc_id,
composite_agg_data,
source_idx_for_recursion + 1,
sub_level_values,
buckets,
is_on_after_key && current_level_accessors.is_after_key_explicit_missing,
sub_agg,
bucket_id_provider,
)?;
sub_level_values.pop();
}
Ok(())
}

View File

@@ -1,329 +0,0 @@
use std::collections::BinaryHeap;
use std::fmt::Debug;
use std::hash::Hash;
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use crate::TantivyError;
/// Map backed by a hash map for fast access and a binary heap to track the
/// highest key. The key is an array of fixed size S.
#[derive(Clone, Debug)]
struct ArrayHeapMap<K: Ord, V, const S: usize> {
pub(crate) buckets: FxHashMap<[K; S], V>,
pub(crate) heap: BinaryHeap<[K; S]>,
}
impl<K: Ord, V, const S: usize> Default for ArrayHeapMap<K, V, S> {
fn default() -> Self {
ArrayHeapMap {
buckets: FxHashMap::default(),
heap: BinaryHeap::default(),
}
}
}
impl<K: Eq + Hash + Clone + Ord, V, const S: usize> ArrayHeapMap<K, V, S> {
/// Panics if the length of `key` is not S.
fn get_or_insert_with<F: FnOnce() -> V>(&mut self, key: &[K], f: F) -> &mut V {
let key_array: &[K; S] = key.try_into().expect("Key length mismatch");
self.buckets.entry(key_array.clone()).or_insert_with(|| {
self.heap.push(key_array.clone());
f()
})
}
/// Panics if the length of `key` is not S.
fn get_mut(&mut self, key: &[K]) -> Option<&mut V> {
let key_array: &[K; S] = key.try_into().expect("Key length mismatch");
self.buckets.get_mut(key_array)
}
fn peek_highest(&self) -> Option<&[K]> {
self.heap.peek().map(|k_array| k_array.as_slice())
}
fn evict_highest(&mut self) {
if let Some(highest) = self.heap.pop() {
self.buckets.remove(&highest);
}
}
fn memory_consumption(&self) -> u64 {
let key_size = std::mem::size_of::<[K; S]>();
let map_size = (key_size + std::mem::size_of::<V>()) * self.buckets.capacity();
let heap_size = key_size * self.heap.capacity();
(map_size + heap_size) as u64
}
}
impl<K: Copy + Ord + Clone + 'static, V: 'static, const S: usize> ArrayHeapMap<K, V, S> {
fn into_iter(self) -> Box<dyn Iterator<Item = (SmallVec<[K; MAX_DYN_ARRAY_SIZE]>, V)>> {
Box::new(
self.buckets
.into_iter()
.map(|(k, v)| (SmallVec::from_slice(&k), v)),
)
}
}
pub(super) const MAX_DYN_ARRAY_SIZE: usize = 16;
const MAX_DYN_ARRAY_SIZE_PLUS_ONE: usize = MAX_DYN_ARRAY_SIZE + 1;
/// A map optimized for memory footprint, fast access and efficient eviction of
/// the highest key.
///
/// Keys are inlined arrays of size 1 to [MAX_DYN_ARRAY_SIZE] but for a given
/// instance the key size is fixed. This allows to avoid heap allocations for the
/// keys.
#[derive(Clone, Debug)]
pub(super) struct DynArrayHeapMap<K: Ord, V>(DynArrayHeapMapInner<K, V>);
/// Wrapper around ArrayHeapMap to dynamically dispatch on the array size.
#[derive(Clone, Debug)]
enum DynArrayHeapMapInner<K: Ord, V> {
Dim1(ArrayHeapMap<K, V, 1>),
Dim2(ArrayHeapMap<K, V, 2>),
Dim3(ArrayHeapMap<K, V, 3>),
Dim4(ArrayHeapMap<K, V, 4>),
Dim5(ArrayHeapMap<K, V, 5>),
Dim6(ArrayHeapMap<K, V, 6>),
Dim7(ArrayHeapMap<K, V, 7>),
Dim8(ArrayHeapMap<K, V, 8>),
Dim9(ArrayHeapMap<K, V, 9>),
Dim10(ArrayHeapMap<K, V, 10>),
Dim11(ArrayHeapMap<K, V, 11>),
Dim12(ArrayHeapMap<K, V, 12>),
Dim13(ArrayHeapMap<K, V, 13>),
Dim14(ArrayHeapMap<K, V, 14>),
Dim15(ArrayHeapMap<K, V, 15>),
Dim16(ArrayHeapMap<K, V, 16>),
}
impl<K: Ord, V> DynArrayHeapMap<K, V> {
/// Creates a new heap map with dynamic array keys of size `key_dimension`.
pub(super) fn try_new(key_dimension: usize) -> crate::Result<Self> {
let inner = match key_dimension {
0 => {
return Err(TantivyError::InvalidArgument(
"DynArrayHeapMap dimension must be at least 1".to_string(),
))
}
1 => DynArrayHeapMapInner::Dim1(ArrayHeapMap::default()),
2 => DynArrayHeapMapInner::Dim2(ArrayHeapMap::default()),
3 => DynArrayHeapMapInner::Dim3(ArrayHeapMap::default()),
4 => DynArrayHeapMapInner::Dim4(ArrayHeapMap::default()),
5 => DynArrayHeapMapInner::Dim5(ArrayHeapMap::default()),
6 => DynArrayHeapMapInner::Dim6(ArrayHeapMap::default()),
7 => DynArrayHeapMapInner::Dim7(ArrayHeapMap::default()),
8 => DynArrayHeapMapInner::Dim8(ArrayHeapMap::default()),
9 => DynArrayHeapMapInner::Dim9(ArrayHeapMap::default()),
10 => DynArrayHeapMapInner::Dim10(ArrayHeapMap::default()),
11 => DynArrayHeapMapInner::Dim11(ArrayHeapMap::default()),
12 => DynArrayHeapMapInner::Dim12(ArrayHeapMap::default()),
13 => DynArrayHeapMapInner::Dim13(ArrayHeapMap::default()),
14 => DynArrayHeapMapInner::Dim14(ArrayHeapMap::default()),
15 => DynArrayHeapMapInner::Dim15(ArrayHeapMap::default()),
16 => DynArrayHeapMapInner::Dim16(ArrayHeapMap::default()),
MAX_DYN_ARRAY_SIZE_PLUS_ONE.. => {
return Err(TantivyError::InvalidArgument(format!(
"DynArrayHeapMap supports maximum {MAX_DYN_ARRAY_SIZE} dimensions, got \
{key_dimension}",
)))
}
};
Ok(DynArrayHeapMap(inner))
}
/// Number of elements in the map. This is not the dimension of the keys.
pub(super) fn size(&self) -> usize {
match &self.0 {
DynArrayHeapMapInner::Dim1(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim2(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim3(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim4(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim5(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim6(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim7(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim8(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim9(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim10(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim11(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim12(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim13(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim14(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim15(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim16(map) => map.buckets.len(),
}
}
}
impl<K: Ord + Hash + Clone, V> DynArrayHeapMap<K, V> {
/// Get a mutable reference to the value corresponding to `key` or inserts a new
/// value created by calling `f`.
///
/// Panics if the length of `key` does not match the key dimension of the map.
pub(super) fn get_or_insert_with<F: FnOnce() -> V>(&mut self, key: &[K], f: F) -> &mut V {
match &mut self.0 {
DynArrayHeapMapInner::Dim1(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim2(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim3(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim4(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim5(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim6(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim7(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim8(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim9(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim10(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim11(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim12(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim13(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim14(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim15(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim16(map) => map.get_or_insert_with(key, f),
}
}
/// Returns a mutable reference to the value corresponding to `key`.
///
/// Panics if the length of `key` does not match the key dimension of the map.
pub fn get_mut(&mut self, key: &[K]) -> Option<&mut V> {
match &mut self.0 {
DynArrayHeapMapInner::Dim1(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim2(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim3(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim4(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim5(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim6(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim7(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim8(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim9(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim10(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim11(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim12(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim13(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim14(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim15(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim16(map) => map.get_mut(key),
}
}
/// Returns a reference to the highest key in the map.
pub(super) fn peek_highest(&self) -> Option<&[K]> {
match &self.0 {
DynArrayHeapMapInner::Dim1(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim2(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim3(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim4(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim5(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim6(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim7(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim8(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim9(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim10(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim11(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim12(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim13(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim14(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim15(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim16(map) => map.peek_highest(),
}
}
/// Removes the entry with the highest key from the map.
pub(super) fn evict_highest(&mut self) {
match &mut self.0 {
DynArrayHeapMapInner::Dim1(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim2(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim3(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim4(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim5(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim6(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim7(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim8(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim9(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim10(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim11(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim12(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim13(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim14(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim15(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim16(map) => map.evict_highest(),
}
}
pub(crate) fn memory_consumption(&self) -> u64 {
match &self.0 {
DynArrayHeapMapInner::Dim1(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim2(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim3(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim4(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim5(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim6(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim7(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim8(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim9(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim10(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim11(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim12(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim13(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim14(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim15(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim16(map) => map.memory_consumption(),
}
}
}
impl<K: Ord + Clone + Copy + 'static, V: 'static> DynArrayHeapMap<K, V> {
/// Turns this map into an iterator over key-value pairs.
pub fn into_iter(self) -> impl Iterator<Item = (SmallVec<[K; MAX_DYN_ARRAY_SIZE]>, V)> {
match self.0 {
DynArrayHeapMapInner::Dim1(map) => map.into_iter(),
DynArrayHeapMapInner::Dim2(map) => map.into_iter(),
DynArrayHeapMapInner::Dim3(map) => map.into_iter(),
DynArrayHeapMapInner::Dim4(map) => map.into_iter(),
DynArrayHeapMapInner::Dim5(map) => map.into_iter(),
DynArrayHeapMapInner::Dim6(map) => map.into_iter(),
DynArrayHeapMapInner::Dim7(map) => map.into_iter(),
DynArrayHeapMapInner::Dim8(map) => map.into_iter(),
DynArrayHeapMapInner::Dim9(map) => map.into_iter(),
DynArrayHeapMapInner::Dim10(map) => map.into_iter(),
DynArrayHeapMapInner::Dim11(map) => map.into_iter(),
DynArrayHeapMapInner::Dim12(map) => map.into_iter(),
DynArrayHeapMapInner::Dim13(map) => map.into_iter(),
DynArrayHeapMapInner::Dim14(map) => map.into_iter(),
DynArrayHeapMapInner::Dim15(map) => map.into_iter(),
DynArrayHeapMapInner::Dim16(map) => map.into_iter(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dyn_array_heap_map() {
let mut map = DynArrayHeapMap::<u32, &str>::try_new(2).unwrap();
// insert
let key1 = [1u32, 2u32];
let key2 = [2u32, 1u32];
map.get_or_insert_with(&key1, || "a");
map.get_or_insert_with(&key2, || "b");
assert_eq!(map.size(), 2);
// evict highest
assert_eq!(map.peek_highest(), Some(&key2[..]));
map.evict_highest();
assert_eq!(map.size(), 1);
assert_eq!(map.peek_highest(), Some(&key1[..]));
// into_iter
let mut iter = map.into_iter();
let (k, v) = iter.next().unwrap();
assert_eq!(k.as_slice(), &key1);
assert_eq!(v, "c");
assert_eq!(iter.next(), None);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,460 +0,0 @@
/// This modules helps comparing numerical values of different types (i64, u64
/// and f64).
pub(super) mod num_cmp {
use std::cmp::Ordering;
use crate::TantivyError;
pub fn cmp_i64_f64(left_i: i64, right_f: f64) -> crate::Result<Ordering> {
if right_f.is_nan() {
return Err(TantivyError::InvalidArgument(
"NaN comparison is not supported".to_string(),
));
}
// If right_f is < i64::MIN then left_i > right_f (i64::MIN=-2^63 can be
// exactly represented as f64)
if right_f < i64::MIN as f64 {
return Ok(Ordering::Greater);
}
// If right_f is >= i64::MAX then left_i < right_f (i64::MAX=2^63-1 cannot
// be exactly represented as f64)
if right_f >= i64::MAX as f64 {
return Ok(Ordering::Less);
}
// Now right_f is in (i64::MIN, i64::MAX), so `right_f as i64` is
// well-defined (truncation toward 0)
let right_as_i = right_f as i64;
let result = match left_i.cmp(&right_as_i) {
Ordering::Less => Ordering::Less,
Ordering::Greater => Ordering::Greater,
Ordering::Equal => {
// they have the same integer part, compare the fraction
let rem = right_f - (right_as_i as f64);
if rem == 0.0 {
Ordering::Equal
} else if right_f > 0.0 {
Ordering::Less
} else {
Ordering::Greater
}
}
};
Ok(result)
}
pub fn cmp_u64_f64(left_u: u64, right_f: f64) -> crate::Result<Ordering> {
if right_f.is_nan() {
return Err(TantivyError::InvalidArgument(
"NaN comparison is not supported".to_string(),
));
}
// Negative floats are always less than any u64 >= 0
if right_f < 0.0 {
return Ok(Ordering::Greater);
}
// If right_f is >= u64::MAX then left_u < right_f (u64::MAX=2^64-1 cannot be exactly)
let max_as_f = u64::MAX as f64;
if right_f > max_as_f {
return Ok(Ordering::Less);
}
// Now right_f is in (0, u64::MAX), so `right_f as u64` is well-defined
// (truncation toward 0)
let right_as_u = right_f as u64;
let result = match left_u.cmp(&right_as_u) {
Ordering::Less => Ordering::Less,
Ordering::Greater => Ordering::Greater,
Ordering::Equal => {
// they have the same integer part, compare the fraction
let rem = right_f - (right_as_u as f64);
if rem == 0.0 {
Ordering::Equal
} else {
Ordering::Less
}
}
};
Ok(result)
}
pub fn cmp_i64_u64(left_i: i64, right_u: u64) -> Ordering {
if left_i < 0 {
Ordering::Less
} else {
let left_as_u = left_i as u64;
left_as_u.cmp(&right_u)
}
}
}
/// This modules helps projecting numerical values to other numerical types.
/// When the target value space cannot exactly represent the source value, the
/// next representable value is returned (or AfterLast if the source value is
/// larger than the largest representable value).
///
/// All functions in this module assume that f64 values are not NaN.
pub(super) mod num_proj {
#[derive(Debug, PartialEq)]
pub enum ProjectedNumber<T> {
Exact(T),
Next(T),
AfterLast,
}
pub fn i64_to_u64(value: i64) -> ProjectedNumber<u64> {
if value < 0 {
ProjectedNumber::Next(0)
} else {
ProjectedNumber::Exact(value as u64)
}
}
pub fn u64_to_i64(value: u64) -> ProjectedNumber<i64> {
if value > i64::MAX as u64 {
ProjectedNumber::AfterLast
} else {
ProjectedNumber::Exact(value as i64)
}
}
pub fn f64_to_u64(value: f64) -> ProjectedNumber<u64> {
if value < 0.0 {
ProjectedNumber::Next(0)
} else if value > u64::MAX as f64 {
ProjectedNumber::AfterLast
} else if value.fract() == 0.0 {
ProjectedNumber::Exact(value as u64)
} else {
// casting f64 to u64 truncates toward zero
ProjectedNumber::Next(value as u64 + 1)
}
}
pub fn f64_to_i64(value: f64) -> ProjectedNumber<i64> {
if value < (i64::MIN as f64) {
return ProjectedNumber::Next(i64::MIN);
} else if value >= (i64::MAX as f64) {
return ProjectedNumber::AfterLast;
} else if value.fract() == 0.0 {
ProjectedNumber::Exact(value as i64)
} else if value > 0.0 {
// casting f64 to i64 truncates toward zero
ProjectedNumber::Next(value as i64 + 1)
} else {
ProjectedNumber::Next(value as i64)
}
}
pub fn i64_to_f64(value: i64) -> ProjectedNumber<f64> {
let value_f = value as f64;
let k_roundtrip = value_f as i64;
if k_roundtrip == value {
// between -2^53 and 2^53 all i64 are exactly represented as f64
ProjectedNumber::Exact(value_f)
} else {
// for very large/small i64 values, it is approximated to the closest f64
if k_roundtrip > value {
ProjectedNumber::Next(value_f)
} else {
ProjectedNumber::Next(value_f.next_up())
}
}
}
pub fn u64_to_f64(value: u64) -> ProjectedNumber<f64> {
let value_f = value as f64;
let k_roundtrip = value_f as u64;
if k_roundtrip == value {
// between 0 and 2^53 all u64 are exactly represented as f64
ProjectedNumber::Exact(value_f)
} else if k_roundtrip > value {
ProjectedNumber::Next(value_f)
} else {
ProjectedNumber::Next(value_f.next_up())
}
}
}
#[cfg(test)]
mod num_cmp_tests {
use std::cmp::Ordering;
use super::num_cmp::*;
#[test]
fn test_cmp_u64_f64() {
// Basic comparisons
assert_eq!(cmp_u64_f64(5, 5.0).unwrap(), Ordering::Equal);
assert_eq!(cmp_u64_f64(5, 6.0).unwrap(), Ordering::Less);
assert_eq!(cmp_u64_f64(6, 5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_u64_f64(0, 0.0).unwrap(), Ordering::Equal);
assert_eq!(cmp_u64_f64(0, 0.1).unwrap(), Ordering::Less);
// Negative float values should always be less than any u64
assert_eq!(cmp_u64_f64(0, -0.1).unwrap(), Ordering::Greater);
assert_eq!(cmp_u64_f64(5, -5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_u64_f64(u64::MAX, -1e20).unwrap(), Ordering::Greater);
// Tests with extreme values
assert_eq!(cmp_u64_f64(u64::MAX, 1e20).unwrap(), Ordering::Less);
// Precision edge cases: large u64 that loses precision when converted to f64
// => 2^54, exactly represented as f64
let large_f64 = 18_014_398_509_481_984.0;
let large_u64 = 18_014_398_509_481_984;
// prove that large_u64 is exactly represented as f64
assert_eq!(large_u64 as f64, large_f64);
assert_eq!(cmp_u64_f64(large_u64, large_f64).unwrap(), Ordering::Equal);
// => (2^54 + 1) cannot be exactly represented in f64
let large_u64_plus_1 = 18_014_398_509_481_985;
// prove that it is represented as f64 by large_f64
assert_eq!(large_u64_plus_1 as f64, large_f64);
assert_eq!(
cmp_u64_f64(large_u64_plus_1, large_f64).unwrap(),
Ordering::Greater
);
// => (2^54 - 1) cannot be exactly represented in f64
let large_u64_minus_1 = 18_014_398_509_481_983;
// prove that it is also represented as f64 by large_f64
assert_eq!(large_u64_minus_1 as f64, large_f64);
assert_eq!(
cmp_u64_f64(large_u64_minus_1, large_f64).unwrap(),
Ordering::Less
);
// NaN comparison results in an error
assert!(cmp_u64_f64(0, f64::NAN).is_err());
}
#[test]
fn test_cmp_i64_f64() {
// Basic comparisons
assert_eq!(cmp_i64_f64(5, 5.0).unwrap(), Ordering::Equal);
assert_eq!(cmp_i64_f64(5, 6.0).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(6, 5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_i64_f64(-5, -5.0).unwrap(), Ordering::Equal);
assert_eq!(cmp_i64_f64(-5, -4.0).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(-4, -5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_i64_f64(-5, 5.0).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(5, -5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_i64_f64(0, -0.1).unwrap(), Ordering::Greater);
assert_eq!(cmp_i64_f64(0, 0.1).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(-1, -0.5).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(-1, 0.0).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(0, 0.0).unwrap(), Ordering::Equal);
// Tests with extreme values
assert_eq!(cmp_i64_f64(i64::MAX, 1e20).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(i64::MIN, -1e20).unwrap(), Ordering::Greater);
// Precision edge cases: large i64 that loses precision when converted to f64
// => 2^54, exactly represented as f64
let large_f64 = 18_014_398_509_481_984.0;
let large_i64 = 18_014_398_509_481_984;
// prove that large_i64 is exactly represented as f64
assert_eq!(large_i64 as f64, large_f64);
assert_eq!(cmp_i64_f64(large_i64, large_f64).unwrap(), Ordering::Equal);
// => (1_i64 << 54) + 1 cannot be exactly represented in f64
let large_i64_plus_1 = 18_014_398_509_481_985;
// prove that it is represented as f64 by large_f64
assert_eq!(large_i64_plus_1 as f64, large_f64);
assert_eq!(
cmp_i64_f64(large_i64_plus_1, large_f64).unwrap(),
Ordering::Greater
);
// => (1_i64 << 54) - 1 cannot be exactly represented in f64
let large_i64_minus_1 = 18_014_398_509_481_983;
// prove that it is also represented as f64 by large_f64
assert_eq!(large_i64_minus_1 as f64, large_f64);
assert_eq!(
cmp_i64_f64(large_i64_minus_1, large_f64).unwrap(),
Ordering::Less
);
// Same precision edge case but with negative values
// => -2^54, exactly represented as f64
let large_neg_f64 = -18_014_398_509_481_984.0;
let large_neg_i64 = -18_014_398_509_481_984;
// prove that large_neg_i64 is exactly represented as f64
assert_eq!(large_neg_i64 as f64, large_neg_f64);
assert_eq!(
cmp_i64_f64(large_neg_i64, large_neg_f64).unwrap(),
Ordering::Equal
);
// => (-2^54 + 1) cannot be exactly represented in f64
let large_neg_i64_plus_1 = -18_014_398_509_481_985;
// prove that it is represented as f64 by large_neg_f64
assert_eq!(large_neg_i64_plus_1 as f64, large_neg_f64);
assert_eq!(
cmp_i64_f64(large_neg_i64_plus_1, large_neg_f64).unwrap(),
Ordering::Less
);
// => (-2^54 - 1) cannot be exactly represented in f64
let large_neg_i64_minus_1 = -18_014_398_509_481_983;
// prove that it is also represented as f64 by large_neg_f64
assert_eq!(large_neg_i64_minus_1 as f64, large_neg_f64);
assert_eq!(
cmp_i64_f64(large_neg_i64_minus_1, large_neg_f64).unwrap(),
Ordering::Greater
);
// NaN comparison results in an error
assert!(cmp_i64_f64(0, f64::NAN).is_err());
}
#[test]
fn test_cmp_i64_u64() {
// Test with negative i64 values (should always be less than any u64)
assert_eq!(cmp_i64_u64(-1, 0), Ordering::Less);
assert_eq!(cmp_i64_u64(i64::MIN, 0), Ordering::Less);
assert_eq!(cmp_i64_u64(i64::MIN, u64::MAX), Ordering::Less);
// Test with positive i64 values
assert_eq!(cmp_i64_u64(0, 0), Ordering::Equal);
assert_eq!(cmp_i64_u64(1, 0), Ordering::Greater);
assert_eq!(cmp_i64_u64(1, 1), Ordering::Equal);
assert_eq!(cmp_i64_u64(0, 1), Ordering::Less);
assert_eq!(cmp_i64_u64(5, 10), Ordering::Less);
assert_eq!(cmp_i64_u64(10, 5), Ordering::Greater);
// Test with values near i64::MAX and u64 conversion
assert_eq!(cmp_i64_u64(i64::MAX, i64::MAX as u64), Ordering::Equal);
assert_eq!(cmp_i64_u64(i64::MAX, (i64::MAX as u64) + 1), Ordering::Less);
assert_eq!(cmp_i64_u64(i64::MAX, u64::MAX), Ordering::Less);
}
}
#[cfg(test)]
mod num_proj_tests {
use super::num_proj::{self, ProjectedNumber};
#[test]
fn test_i64_to_u64() {
assert_eq!(num_proj::i64_to_u64(-1), ProjectedNumber::Next(0));
assert_eq!(num_proj::i64_to_u64(i64::MIN), ProjectedNumber::Next(0));
assert_eq!(num_proj::i64_to_u64(0), ProjectedNumber::Exact(0));
assert_eq!(num_proj::i64_to_u64(42), ProjectedNumber::Exact(42));
assert_eq!(
num_proj::i64_to_u64(i64::MAX),
ProjectedNumber::Exact(i64::MAX as u64)
);
}
#[test]
fn test_u64_to_i64() {
assert_eq!(num_proj::u64_to_i64(0), ProjectedNumber::Exact(0));
assert_eq!(num_proj::u64_to_i64(42), ProjectedNumber::Exact(42));
assert_eq!(
num_proj::u64_to_i64(i64::MAX as u64),
ProjectedNumber::Exact(i64::MAX)
);
assert_eq!(
num_proj::u64_to_i64((i64::MAX as u64) + 1),
ProjectedNumber::AfterLast
);
assert_eq!(num_proj::u64_to_i64(u64::MAX), ProjectedNumber::AfterLast);
}
#[test]
fn test_f64_to_u64() {
assert_eq!(num_proj::f64_to_u64(-1e25), ProjectedNumber::Next(0));
assert_eq!(num_proj::f64_to_u64(-0.1), ProjectedNumber::Next(0));
assert_eq!(num_proj::f64_to_u64(1e20), ProjectedNumber::AfterLast);
assert_eq!(
num_proj::f64_to_u64(f64::INFINITY),
ProjectedNumber::AfterLast
);
assert_eq!(num_proj::f64_to_u64(0.0), ProjectedNumber::Exact(0));
assert_eq!(num_proj::f64_to_u64(42.0), ProjectedNumber::Exact(42));
assert_eq!(num_proj::f64_to_u64(0.5), ProjectedNumber::Next(1));
assert_eq!(num_proj::f64_to_u64(42.1), ProjectedNumber::Next(43));
}
#[test]
fn test_f64_to_i64() {
assert_eq!(num_proj::f64_to_i64(-1e20), ProjectedNumber::Next(i64::MIN));
assert_eq!(
num_proj::f64_to_i64(f64::NEG_INFINITY),
ProjectedNumber::Next(i64::MIN)
);
assert_eq!(num_proj::f64_to_i64(1e20), ProjectedNumber::AfterLast);
assert_eq!(
num_proj::f64_to_i64(f64::INFINITY),
ProjectedNumber::AfterLast
);
assert_eq!(num_proj::f64_to_i64(0.0), ProjectedNumber::Exact(0));
assert_eq!(num_proj::f64_to_i64(42.0), ProjectedNumber::Exact(42));
assert_eq!(num_proj::f64_to_i64(-42.0), ProjectedNumber::Exact(-42));
assert_eq!(num_proj::f64_to_i64(0.5), ProjectedNumber::Next(1));
assert_eq!(num_proj::f64_to_i64(42.1), ProjectedNumber::Next(43));
assert_eq!(num_proj::f64_to_i64(-0.5), ProjectedNumber::Next(0));
assert_eq!(num_proj::f64_to_i64(-42.1), ProjectedNumber::Next(-42));
}
#[test]
fn test_i64_to_f64() {
assert_eq!(num_proj::i64_to_f64(0), ProjectedNumber::Exact(0.0));
assert_eq!(num_proj::i64_to_f64(42), ProjectedNumber::Exact(42.0));
assert_eq!(num_proj::i64_to_f64(-42), ProjectedNumber::Exact(-42.0));
let max_exact = 9_007_199_254_740_992; // 2^53
assert_eq!(
num_proj::i64_to_f64(max_exact),
ProjectedNumber::Exact(max_exact as f64)
);
// Test values that cannot be exactly represented as f64 (integers above 2^53)
let large_i64 = 9_007_199_254_740_993; // 2^53 + 1
let closest_f64 = 9_007_199_254_740_992.0;
assert_eq!(large_i64 as f64, closest_f64);
if let ProjectedNumber::Next(val) = num_proj::i64_to_f64(large_i64) {
// Verify that the returned float is different from the direct cast
assert!(val > closest_f64);
assert!(val - closest_f64 < 2. * f64::EPSILON * closest_f64);
} else {
panic!("Expected ProjectedNumber::Next for large_i64");
}
// Test with very large negative value
let large_neg_i64 = -9_007_199_254_740_993; // -(2^53 + 1)
let closest_neg_f64 = -9_007_199_254_740_992.0;
assert_eq!(large_neg_i64 as f64, closest_neg_f64);
if let ProjectedNumber::Next(val) = num_proj::i64_to_f64(large_neg_i64) {
// Verify that the returned float is the closest representable f64
assert_eq!(val, closest_neg_f64);
} else {
panic!("Expected ProjectedNumber::Next for large_neg_i64");
}
}
#[test]
fn test_u64_to_f64() {
assert_eq!(num_proj::u64_to_f64(0), ProjectedNumber::Exact(0.0));
assert_eq!(num_proj::u64_to_f64(42), ProjectedNumber::Exact(42.0));
// Test the largest u64 value that can be exactly represented as f64 (2^53)
let max_exact = 9_007_199_254_740_992; // 2^53
assert_eq!(
num_proj::u64_to_f64(max_exact),
ProjectedNumber::Exact(max_exact as f64)
);
// Test values that cannot be exactly represented as f64 (integers above 2^53)
let large_u64 = 9_007_199_254_740_993; // 2^53 + 1
let closest_f64 = 9_007_199_254_740_992.0;
assert_eq!(large_u64 as f64, closest_f64);
if let ProjectedNumber::Next(val) = num_proj::u64_to_f64(large_u64) {
// Verify that the returned float is different from the direct cast
assert!(val > closest_f64);
assert!(val - closest_f64 < 2. * f64::EPSILON * closest_f64);
} else {
panic!("Expected ProjectedNumber::Next for large_u64");
}
}
}

View File

@@ -401,8 +401,6 @@ pub struct FilterAggReqData {
pub name: String,
/// The filter aggregation
pub req: FilterAggregation,
/// The segment reader
pub segment_reader: SegmentReader,
/// Document evaluator for the filter query (precomputed BitSet)
/// This is built once when the request data is created
pub evaluator: DocumentQueryEvaluator,
@@ -414,9 +412,8 @@ pub struct FilterAggReqData {
impl FilterAggReqData {
pub(crate) fn get_memory_consumption(&self) -> usize {
// Estimate: name + segment reader reference + bitset + buffer capacity
// Estimate: name + bitset + buffer capacity
self.name.len()
+ std::mem::size_of::<SegmentReader>()
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
+ std::mem::size_of::<bool>()
@@ -438,7 +435,7 @@ impl DocumentQueryEvaluator {
pub(crate) fn new(
query: Box<dyn Query>,
schema: Schema,
segment_reader: &SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<Self> {
let max_doc = segment_reader.max_doc();

View File

@@ -207,7 +207,7 @@ fn parse_offset_into_milliseconds(input: &str) -> Result<i64, AggregationError>
}
}
pub(crate) fn parse_into_milliseconds(input: &str) -> Result<i64, AggregationError> {
fn parse_into_milliseconds(input: &str) -> Result<i64, AggregationError> {
let split_boundary = input
.as_bytes()
.iter()

View File

@@ -22,7 +22,6 @@
//! - [Range](RangeAggregation)
//! - [Terms](TermsAggregation)
mod composite;
mod filter;
mod histogram;
mod range;
@@ -32,7 +31,6 @@ mod term_missing_agg;
use std::collections::HashMap;
use std::fmt;
pub use composite::*;
pub use filter::*;
pub use histogram::*;
pub use range::*;

View File

@@ -66,7 +66,7 @@ impl Collector for DistributedAggregationCollector {
fn for_segment(
&self,
segment_local_id: crate::SegmentOrdinal,
reader: &crate::SegmentReader,
reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
AggregationSegmentCollector::from_agg_req_and_reader(
&self.agg,
@@ -96,7 +96,7 @@ impl Collector for AggregationCollector {
fn for_segment(
&self,
segment_local_id: crate::SegmentOrdinal,
reader: &crate::SegmentReader,
reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
AggregationSegmentCollector::from_agg_req_and_reader(
&self.agg,
@@ -145,7 +145,7 @@ impl AggregationSegmentCollector {
/// reader. Also includes validation, e.g. checking field types and existence.
pub fn from_agg_req_and_reader(
agg: &Aggregations,
reader: &SegmentReader,
reader: &dyn SegmentReader,
segment_ordinal: SegmentOrdinal,
context: &AggContextParams,
) -> crate::Result<Self> {

View File

@@ -15,9 +15,8 @@ use serde::{Deserialize, Serialize};
use super::agg_req::{Aggregation, AggregationVariants, Aggregations};
use super::agg_result::{AggregationResult, BucketResult, MetricResult, RangeBucketEntry};
use super::bucket::{
composite_intermediate_key_ordering, cut_off_buckets, get_agg_name_and_property,
intermediate_histogram_buckets_to_final_buckets, CompositeAggregation, GetDocCount,
MissingOrder, Order, OrderTarget, RangeAggregation, TermsAggregation,
cut_off_buckets, get_agg_name_and_property, intermediate_histogram_buckets_to_final_buckets,
GetDocCount, Order, OrderTarget, RangeAggregation, TermsAggregation,
};
use super::metric::{
IntermediateAverage, IntermediateCount, IntermediateExtendedStats, IntermediateMax,
@@ -26,7 +25,7 @@ use super::metric::{
use super::segment_agg_result::AggregationLimitsGuard;
use super::{format_date, AggregationError, Key, SerializedKey};
use crate::aggregation::agg_result::{
AggregationResults, BucketEntries, BucketEntry, CompositeBucketEntry, FilterBucketResult,
AggregationResults, BucketEntries, BucketEntry, FilterBucketResult,
};
use crate::aggregation::bucket::TermsAggregationInternal;
use crate::aggregation::metric::CardinalityCollector;
@@ -91,19 +90,6 @@ impl From<IntermediateKey> for Key {
impl Eq for IntermediateKey {}
impl std::fmt::Display for IntermediateKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
IntermediateKey::Str(val) => f.write_str(val),
IntermediateKey::F64(val) => f.write_str(&val.to_string()),
IntermediateKey::U64(val) => f.write_str(&val.to_string()),
IntermediateKey::I64(val) => f.write_str(&val.to_string()),
IntermediateKey::Bool(val) => f.write_str(&val.to_string()),
IntermediateKey::IpAddr(val) => f.write_str(&val.to_string()),
}
}
}
impl std::hash::Hash for IntermediateKey {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
@@ -119,21 +105,6 @@ impl std::hash::Hash for IntermediateKey {
}
impl IntermediateAggregationResults {
/// Returns a reference to the intermediate aggregation result for the given key.
pub fn get(&self, key: &str) -> Option<&IntermediateAggregationResult> {
self.aggs_res.get(key)
}
/// Removes and returns the intermediate aggregation result for the given key.
pub fn remove(&mut self, key: &str) -> Option<IntermediateAggregationResult> {
self.aggs_res.remove(key)
}
/// Returns an iterator over the keys in the intermediate aggregation results.
pub fn keys(&self) -> impl Iterator<Item = &String> {
self.aggs_res.keys()
}
/// Add a result
pub fn push(&mut self, key: String, value: IntermediateAggregationResult) -> crate::Result<()> {
let entry = self.aggs_res.entry(key);
@@ -281,11 +252,6 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult
doc_count: 0,
sub_aggregations: IntermediateAggregationResults::default(),
}),
Composite(_) => {
IntermediateAggregationResult::Bucket(IntermediateBucketResult::Composite {
buckets: IntermediateCompositeBucketResult::default(),
})
}
}
}
@@ -479,11 +445,6 @@ pub enum IntermediateBucketResult {
/// Sub-aggregation results
sub_aggregations: IntermediateAggregationResults,
},
/// Composite aggregation
Composite {
/// The composite buckets
buckets: IntermediateCompositeBucketResult,
},
}
impl IntermediateBucketResult {
@@ -579,13 +540,6 @@ impl IntermediateBucketResult {
sub_aggregations: final_sub_aggregations,
}))
}
IntermediateBucketResult::Composite { buckets } => {
let composite_req = req
.agg
.as_composite()
.expect("unexpected aggregation, expected composite aggregation");
buckets.into_final_result(composite_req, req.sub_aggregation(), limits)
}
}
}
@@ -652,16 +606,6 @@ impl IntermediateBucketResult {
*doc_count_left += doc_count_right;
sub_aggs_left.merge_fruits(sub_aggs_right)?;
}
(
IntermediateBucketResult::Composite {
buckets: composite_left,
},
IntermediateBucketResult::Composite {
buckets: composite_right,
},
) => {
composite_left.merge_fruits(composite_right)?;
}
(IntermediateBucketResult::Range(_), _) => {
panic!("try merge on different types")
}
@@ -674,9 +618,6 @@ impl IntermediateBucketResult {
(IntermediateBucketResult::Filter { .. }, _) => {
panic!("try merge on different types")
}
(IntermediateBucketResult::Composite { .. }, _) => {
panic!("try merge on different types")
}
}
Ok(())
}
@@ -698,21 +639,6 @@ pub struct IntermediateTermBucketResult {
}
impl IntermediateTermBucketResult {
/// Returns a reference to the map of bucket entries keyed by [`IntermediateKey`].
pub fn entries(&self) -> &FxHashMap<IntermediateKey, IntermediateTermBucketEntry> {
&self.entries
}
/// Returns the count of documents not included in the returned buckets.
pub fn sum_other_doc_count(&self) -> u64 {
self.sum_other_doc_count
}
/// Returns the upper bound of the error on document counts in the returned buckets.
pub fn doc_count_error_upper_bound(&self) -> u64 {
self.doc_count_error_upper_bound
}
pub(crate) fn into_final_result(
self,
req: &TermsAggregation,
@@ -894,7 +820,7 @@ impl IntermediateRangeBucketEntry {
};
// If we have a date type on the histogram buckets, we add the `key_as_string` field as
// rfc3339
// rfc339
if column_type == Some(ColumnType::DateTime) {
if let Some(val) = range_bucket_entry.to {
let key_as_string = format_date(val as i64)?;
@@ -945,176 +871,6 @@ impl MergeFruits for IntermediateHistogramBucketEntry {
}
}
/// Entry for the composite bucket.
pub type IntermediateCompositeBucketEntry = IntermediateTermBucketEntry;
/// The fully typed key for composite aggregation
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum CompositeIntermediateKey {
/// Bool key
Bool(bool),
/// String key
Str(String),
/// Float key
F64(f64),
/// Signed integer key
I64(i64),
/// Unsigned integer key
U64(u64),
/// DateTime key, nanoseconds since epoch
DateTime(i64),
/// IP Address key
IpAddr(Ipv6Addr),
/// Missing value key
Null,
}
impl Eq for CompositeIntermediateKey {}
impl std::hash::Hash for CompositeIntermediateKey {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
match self {
CompositeIntermediateKey::Bool(val) => val.hash(state),
CompositeIntermediateKey::Str(text) => text.hash(state),
CompositeIntermediateKey::F64(val) => val.to_bits().hash(state),
CompositeIntermediateKey::U64(val) => val.hash(state),
CompositeIntermediateKey::I64(val) => val.hash(state),
CompositeIntermediateKey::DateTime(val) => val.hash(state),
CompositeIntermediateKey::IpAddr(val) => val.hash(state),
CompositeIntermediateKey::Null => {}
}
}
}
/// Composite aggregation page.
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateCompositeBucketResult {
pub(crate) entries: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry>,
pub(crate) target_size: u32,
pub(crate) orders: Vec<(Order, MissingOrder)>,
}
impl IntermediateCompositeBucketResult {
pub(crate) fn into_final_result(
self,
req: &CompositeAggregation,
sub_aggregation_req: &Aggregations,
limits: &mut AggregationLimitsGuard,
) -> crate::Result<BucketResult> {
let trimmed_entry_vec =
trim_composite_buckets(self.entries, &self.orders, self.target_size)?;
let after_key = if trimmed_entry_vec.len() == req.size as usize {
trimmed_entry_vec
.last()
.map(|bucket| {
let (intermediate_key, _entry) = bucket;
intermediate_key
.iter()
.enumerate()
.map(|(idx, intermediate_key)| {
let source = &req.sources[idx];
(source.name().to_string(), intermediate_key.clone().into())
})
.collect()
})
.unwrap()
} else {
FxHashMap::default()
};
let buckets = trimmed_entry_vec
.into_iter()
.map(|(intermediate_key, entry)| {
let key = intermediate_key
.into_iter()
.enumerate()
.map(|(idx, intermediate_key)| {
let source = &req.sources[idx];
(source.name().to_string(), intermediate_key.into())
})
.collect();
Ok(CompositeBucketEntry {
key,
doc_count: entry.doc_count as u64,
sub_aggregation: entry
.sub_aggregation
.into_final_result_internal(sub_aggregation_req, limits)?,
})
})
.collect::<crate::Result<Vec<_>>>()?;
Ok(BucketResult::Composite { after_key, buckets })
}
fn merge_fruits(&mut self, other: IntermediateCompositeBucketResult) -> crate::Result<()> {
merge_maps(&mut self.entries, other.entries)?;
if self.entries.len() as u32 > 2 * self.target_size {
self.trim()?;
}
Ok(())
}
/// Trim the composite buckets to the target size, according to the ordering.
pub(crate) fn trim(&mut self) -> crate::Result<()> {
if self.entries.len() as u32 <= self.target_size {
return Ok(());
}
let sorted_entries = trim_composite_buckets(
std::mem::take(&mut self.entries),
&self.orders,
self.target_size,
)?;
self.entries = sorted_entries.into_iter().collect();
Ok(())
}
}
fn trim_composite_buckets(
entries: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry>,
orders: &[(Order, MissingOrder)],
target_size: u32,
) -> crate::Result<
Vec<(
Vec<CompositeIntermediateKey>,
IntermediateCompositeBucketEntry,
)>,
> {
let mut entries: Vec<_> = entries.into_iter().collect();
let mut sort_error: Option<TantivyError> = None;
entries.sort_by(|(left_key, _), (right_key, _)| {
if sort_error.is_some() {
return Ordering::Equal;
}
for idx in 0..orders.len() {
match composite_intermediate_key_ordering(
&left_key[idx],
&right_key[idx],
orders[idx].0,
orders[idx].1,
) {
Ok(ordering) if ordering != Ordering::Equal => return ordering,
Ok(_) => continue,
Err(err) => {
sort_error = Some(err);
break;
}
}
}
Ordering::Equal
});
if let Some(err) = sort_error {
return Err(err);
}
entries.truncate(target_size as usize);
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;

View File

@@ -55,12 +55,6 @@ impl IntermediateAverage {
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Returns a reference to the underlying [`IntermediateStats`].
pub fn stats(&self) -> &IntermediateStats {
&self.stats
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateAverage) {
self.stats.merge_fruits(other.stats);

View File

@@ -1,11 +1,12 @@
use std::hash::Hash;
use std::collections::hash_map::DefaultHasher;
use std::hash::{BuildHasher, Hasher};
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{Column, ColumnType, Dictionary, StrColumn};
use common::f64_to_u64;
use datasketches::hll::{HllSketch, HllType, HllUnion};
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
use rustc_hash::FxHashSet;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
@@ -15,17 +16,29 @@ use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
/// Log2 of the number of registers for the HLL sketch.
/// 2^11 = 2048 registers, giving ~2.3% relative error and ~1KB per sketch (Hll4).
const LG_K: u8 = 11;
#[derive(Clone, Debug, Serialize, Deserialize)]
struct BuildSaltedHasher {
salt: u8,
}
impl BuildHasher for BuildSaltedHasher {
type Hasher = DefaultHasher;
fn build_hasher(&self) -> Self::Hasher {
let mut hasher = DefaultHasher::new();
hasher.write_u8(self.salt);
hasher
}
}
/// # Cardinality
///
/// The cardinality aggregation allows for computing an estimate
/// of the number of different values in a data set based on the
/// Apache DataSketches HyperLogLog algorithm. This is particularly useful for
/// understanding the uniqueness of values in a large dataset where counting
/// each unique value individually would be computationally expensive.
/// HyperLogLog++ algorithm. This is particularly useful for understanding the
/// uniqueness of values in a large dataset where counting each unique value
/// individually would be computationally expensive.
///
/// For example, you might use a cardinality aggregation to estimate the number
/// of unique visitors to a website by aggregating on a field that contains
@@ -171,7 +184,7 @@ impl SegmentCardinalityCollectorBucket {
term_ids.sort_unstable();
dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| {
self.cardinality.insert(term);
self.cardinality.sketch.insert_any(&term);
Ok(())
})?;
if has_missing {
@@ -182,17 +195,17 @@ impl SegmentCardinalityCollectorBucket {
);
match missing_key {
Key::Str(missing) => {
self.cardinality.insert(missing.as_str());
self.cardinality.sketch.insert_any(&missing);
}
Key::F64(val) => {
let val = f64_to_u64(*val);
self.cardinality.insert(val);
self.cardinality.sketch.insert_any(&val);
}
Key::U64(val) => {
self.cardinality.insert(*val);
self.cardinality.sketch.insert_any(&val);
}
Key::I64(val) => {
self.cardinality.insert(*val);
self.cardinality.sketch.insert_any(&val);
}
}
}
@@ -283,11 +296,11 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
})?;
for val in col_block_accessor.iter_vals() {
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
bucket.cardinality.insert(val);
bucket.cardinality.sketch.insert_any(&val);
}
} else {
for val in col_block_accessor.iter_vals() {
bucket.cardinality.insert(val);
bucket.cardinality.sketch.insert_any(&val);
}
}
@@ -308,18 +321,11 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
}
}
#[derive(Clone, Debug)]
/// The cardinality collector used during segment collection and for merging results.
/// Uses Apache DataSketches HLL (lg_k=11, Hll4) for compact binary serialization
/// and cross-language compatibility (e.g. Java `datasketches` library).
#[derive(Clone, Debug, Serialize, Deserialize)]
/// The percentiles collector used during segment collection and for merging results.
pub struct CardinalityCollector {
sketch: HllSketch,
/// Salt derived from `ColumnType`, used to differentiate values of different column types
/// that map to the same u64 (e.g. bool `false` = 0 vs i64 `0`).
/// Not serialized — only needed during insertion, not after sketch registers are populated.
salt: u8,
sketch: HyperLogLogPlus<u64, BuildSaltedHasher>,
}
impl Default for CardinalityCollector {
fn default() -> Self {
Self::new(0)
@@ -332,52 +338,25 @@ impl PartialEq for CardinalityCollector {
}
}
impl Serialize for CardinalityCollector {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let bytes = self.sketch.serialize();
serializer.serialize_bytes(&bytes)
}
}
impl<'de> Deserialize<'de> for CardinalityCollector {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let bytes: Vec<u8> = Deserialize::deserialize(deserializer)?;
let sketch = HllSketch::deserialize(&bytes).map_err(serde::de::Error::custom)?;
Ok(Self { sketch, salt: 0 })
}
}
impl CardinalityCollector {
/// Compute the final cardinality estimate.
pub fn finalize(self) -> Option<f64> {
Some(self.sketch.clone().count().trunc())
}
fn new(salt: u8) -> Self {
Self {
sketch: HllSketch::new(LG_K, HllType::Hll4),
salt,
sketch: HyperLogLogPlus::new(16, BuildSaltedHasher { salt }).unwrap(),
}
}
/// Insert a value into the HLL sketch, salted by the column type.
/// The salt ensures that identical u64 values from different column types
/// (e.g. bool `false` vs i64 `0`) are counted as distinct.
pub(crate) fn insert<T: Hash>(&mut self, value: T) {
self.sketch.update((self.salt, value));
}
/// Compute the final cardinality estimate.
pub fn finalize(self) -> Option<f64> {
Some(self.sketch.estimate().trunc())
}
/// Serialize the HLL sketch to its compact binary representation.
/// The format is cross-language compatible with Apache DataSketches (Java, C++, Python).
pub fn to_sketch_bytes(&self) -> Vec<u8> {
self.sketch.serialize()
}
pub(crate) fn merge_fruits(&mut self, right: CardinalityCollector) -> crate::Result<()> {
let mut union = HllUnion::new(LG_K);
union.update(&self.sketch);
union.update(&right.sketch);
self.sketch = union.get_result(HllType::Hll4);
self.sketch.merge(&right.sketch).map_err(|err| {
TantivyError::AggregationError(AggregationError::InternalError(format!(
"Error while merging cardinality {err:?}"
)))
})?;
Ok(())
}
}
@@ -539,75 +518,4 @@ mod tests {
Ok(())
}
#[test]
fn cardinality_collector_serde_roundtrip() {
use super::CardinalityCollector;
let mut collector = CardinalityCollector::default();
collector.insert("hello");
collector.insert("world");
collector.insert("hello"); // duplicate
let serialized = serde_json::to_vec(&collector).unwrap();
let deserialized: CardinalityCollector = serde_json::from_slice(&serialized).unwrap();
let original_estimate = collector.finalize().unwrap();
let roundtrip_estimate = deserialized.finalize().unwrap();
assert_eq!(original_estimate, roundtrip_estimate);
assert_eq!(original_estimate, 2.0);
}
#[test]
fn cardinality_collector_merge() {
use super::CardinalityCollector;
let mut left = CardinalityCollector::default();
left.insert("a");
left.insert("b");
let mut right = CardinalityCollector::default();
right.insert("b");
right.insert("c");
left.merge_fruits(right).unwrap();
let estimate = left.finalize().unwrap();
assert_eq!(estimate, 3.0);
}
#[test]
fn cardinality_collector_serialize_deserialize_binary() {
use datasketches::hll::HllSketch;
use super::CardinalityCollector;
let mut collector = CardinalityCollector::default();
collector.insert("apple");
collector.insert("banana");
collector.insert("cherry");
let bytes = collector.to_sketch_bytes();
let deserialized = HllSketch::deserialize(&bytes).unwrap();
assert!((deserialized.estimate() - 3.0).abs() < 0.01);
}
#[test]
fn cardinality_collector_salt_differentiates_types() {
use super::CardinalityCollector;
// Without salt, same u64 value from different column types would collide
let mut collector_bool = CardinalityCollector::new(5); // e.g. ColumnType::Bool
collector_bool.insert(0u64); // false
collector_bool.insert(1u64); // true
let mut collector_i64 = CardinalityCollector::new(2); // e.g. ColumnType::I64
collector_i64.insert(0u64);
collector_i64.insert(1u64);
// Merge them
collector_bool.merge_fruits(collector_i64).unwrap();
let estimate = collector_bool.finalize().unwrap();
// Should be 4 because salt makes (5, 0) != (2, 0) and (5, 1) != (2, 1)
assert_eq!(estimate, 4.0);
}
}

View File

@@ -107,11 +107,8 @@ pub enum PercentileValues {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// The entry when requesting percentiles with keyed: false
pub struct PercentileValuesVecEntry {
/// Percentile
pub key: f64,
/// Value at the percentile
pub value: f64,
key: f64,
value: f64,
}
/// Single-metric aggregations use this common result structure.

View File

@@ -222,12 +222,6 @@ impl PercentilesCollector {
self.sketch.add(val);
}
/// Encode the underlying DDSketch to Java-compatible binary format
/// for cross-language serialization with Java consumers.
pub fn to_sketch_bytes(&self) -> Vec<u8> {
self.sketch.to_java_bytes()
}
pub(crate) fn merge_fruits(&mut self, right: PercentilesCollector) -> crate::Result<()> {
self.sketch.merge(&right.sketch).map_err(|err| {
TantivyError::AggregationError(AggregationError::InternalError(format!(
@@ -331,7 +325,7 @@ mod tests {
use crate::aggregation::AggregationCollector;
use crate::query::AllQuery;
use crate::schema::{Schema, FAST};
use crate::{assert_nearly_equals, Index};
use crate::Index;
#[test]
fn test_aggregation_percentiles_empty_index() -> crate::Result<()> {
@@ -614,16 +608,12 @@ mod tests {
let res = exec_request_with_query(agg_req, &index, None)?;
assert_eq!(res["range_with_stats"]["buckets"][0]["doc_count"], 3);
assert_nearly_equals!(
res["range_with_stats"]["buckets"][0]["percentiles"]["values"]["1.0"]
.as_f64()
.unwrap(),
assert_eq!(
res["range_with_stats"]["buckets"][0]["percentiles"]["values"]["1.0"],
5.0028295751107414
);
assert_nearly_equals!(
res["range_with_stats"]["buckets"][0]["percentiles"]["values"]["99.0"]
.as_f64()
.unwrap(),
assert_eq!(
res["range_with_stats"]["buckets"][0]["percentiles"]["values"]["99.0"],
10.07469668951144
);
@@ -669,14 +659,8 @@ mod tests {
let res = exec_request_with_query(agg_req, &index, None)?;
assert_nearly_equals!(
res["percentiles"]["values"]["1.0"].as_f64().unwrap(),
5.0028295751107414
);
assert_nearly_equals!(
res["percentiles"]["values"]["99.0"].as_f64().unwrap(),
10.07469668951144
);
assert_eq!(res["percentiles"]["values"]["1.0"], 5.0028295751107414);
assert_eq!(res["percentiles"]["values"]["99.0"], 10.07469668951144);
Ok(())
}

View File

@@ -110,16 +110,6 @@ impl Default for IntermediateStats {
}
impl IntermediateStats {
/// Returns the number of values collected.
pub fn count(&self) -> u64 {
self.count
}
/// Returns the sum of all values collected.
pub fn sum(&self) -> f64 {
self.sum
}
/// Merges the other stats intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateStats) {
self.count += other.count;

View File

@@ -43,7 +43,7 @@ impl Collector for Count {
fn for_segment(
&self,
_: SegmentOrdinal,
_: &SegmentReader,
_: &dyn SegmentReader,
) -> crate::Result<SegmentCountCollector> {
Ok(SegmentCountCollector::default())
}

View File

@@ -1,7 +1,7 @@
use std::collections::HashSet;
use super::{Collector, SegmentCollector};
use crate::{DocAddress, DocId, Score};
use crate::{DocAddress, DocId, Score, SegmentReader};
/// Collectors that returns the set of DocAddress that matches the query.
///
@@ -15,7 +15,7 @@ impl Collector for DocSetCollector {
fn for_segment(
&self,
segment_local_id: crate::SegmentOrdinal,
_segment: &crate::SegmentReader,
_segment: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
Ok(DocSetChildCollector {
segment_local_id,

View File

@@ -265,7 +265,7 @@ impl Collector for FacetCollector {
fn for_segment(
&self,
_: SegmentOrdinal,
reader: &SegmentReader,
reader: &dyn SegmentReader,
) -> crate::Result<FacetSegmentCollector> {
let facet_reader = reader.facet_reader(&self.field_name)?;
let facet_dict = facet_reader.facet_dict();

View File

@@ -113,7 +113,7 @@ where
fn for_segment(
&self,
segment_local_id: u32,
segment_reader: &SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let column_opt = segment_reader.fast_fields().column_opt(&self.field)?;
@@ -287,7 +287,7 @@ where
fn for_segment(
&self,
segment_local_id: u32,
segment_reader: &SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let column_opt = segment_reader.fast_fields().bytes(&self.field)?;

View File

@@ -6,7 +6,7 @@ use fastdivide::DividerU64;
use crate::collector::{Collector, SegmentCollector};
use crate::fastfield::{FastFieldNotAvailableError, FastValue};
use crate::schema::Type;
use crate::{DocId, Score};
use crate::{DocId, Score, SegmentReader};
/// Histogram builds an histogram of the values of a fastfield for the
/// collected DocSet.
@@ -110,7 +110,7 @@ impl Collector for HistogramCollector {
fn for_segment(
&self,
_segment_local_id: crate::SegmentOrdinal,
segment: &crate::SegmentReader,
segment: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let column_opt = segment.fast_fields().u64_lenient(&self.field)?;
let (column, _column_type) = column_opt.ok_or_else(|| FastFieldNotAvailableError {

View File

@@ -156,7 +156,7 @@ pub trait Collector: Sync + Send {
fn for_segment(
&self,
segment_local_id: SegmentOrdinal,
segment: &SegmentReader,
segment: &dyn SegmentReader,
) -> crate::Result<Self::Child>;
/// Returns true iff the collector requires to compute scores for documents.
@@ -174,7 +174,7 @@ pub trait Collector: Sync + Send {
&self,
weight: &dyn Weight,
segment_ord: u32,
reader: &SegmentReader,
reader: &dyn SegmentReader,
) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> {
let with_scoring = self.requires_scoring();
let mut segment_collector = self.for_segment(segment_ord, reader)?;
@@ -186,7 +186,7 @@ pub trait Collector: Sync + Send {
pub(crate) fn default_collect_segment_impl<TSegmentCollector: SegmentCollector>(
segment_collector: &mut TSegmentCollector,
weight: &dyn Weight,
reader: &SegmentReader,
reader: &dyn SegmentReader,
with_scoring: bool,
) -> crate::Result<()> {
match (reader.alive_bitset(), with_scoring) {
@@ -255,7 +255,7 @@ impl<TCollector: Collector> Collector for Option<TCollector> {
fn for_segment(
&self,
segment_local_id: SegmentOrdinal,
segment: &SegmentReader,
segment: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
Ok(if let Some(inner) = self {
let inner_segment_collector = inner.for_segment(segment_local_id, segment)?;
@@ -336,7 +336,7 @@ where
fn for_segment(
&self,
segment_local_id: u32,
segment: &SegmentReader,
segment: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let left = self.0.for_segment(segment_local_id, segment)?;
let right = self.1.for_segment(segment_local_id, segment)?;
@@ -407,7 +407,7 @@ where
fn for_segment(
&self,
segment_local_id: u32,
segment: &SegmentReader,
segment: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let one = self.0.for_segment(segment_local_id, segment)?;
let two = self.1.for_segment(segment_local_id, segment)?;
@@ -487,7 +487,7 @@ where
fn for_segment(
&self,
segment_local_id: u32,
segment: &SegmentReader,
segment: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let one = self.0.for_segment(segment_local_id, segment)?;
let two = self.1.for_segment(segment_local_id, segment)?;

View File

@@ -24,7 +24,7 @@ impl<TCollector: Collector> Collector for CollectorWrapper<TCollector> {
fn for_segment(
&self,
segment_local_id: u32,
reader: &SegmentReader,
reader: &dyn SegmentReader,
) -> crate::Result<Box<dyn BoxableSegmentCollector>> {
let child = self.0.for_segment(segment_local_id, reader)?;
Ok(Box::new(SegmentCollectorWrapper(child)))
@@ -209,7 +209,7 @@ impl Collector for MultiCollector<'_> {
fn for_segment(
&self,
segment_local_id: SegmentOrdinal,
segment: &SegmentReader,
segment: &dyn SegmentReader,
) -> crate::Result<MultiCollectorChild> {
let children = self
.collector_wrappers

View File

@@ -1,5 +1,4 @@
mod order;
mod sort_by_bytes;
mod sort_by_erased_type;
mod sort_by_score;
mod sort_by_static_fast_value;
@@ -7,7 +6,6 @@ mod sort_by_string;
mod sort_key_computer;
pub use order::*;
pub use sort_by_bytes::SortByBytes;
pub use sort_by_erased_type::SortByErasedType;
pub use sort_by_score::SortBySimilarityScore;
pub use sort_by_static_fast_value::SortByStaticFastValue;

View File

@@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::schema::{OwnedValue, Schema};
use crate::{DocId, Order, Score};
use crate::{DocId, Order, Score, SegmentReader};
fn compare_owned_value<const NULLS_FIRST: bool>(lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
match (lhs, rhs) {
@@ -430,7 +430,7 @@ where
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let child = self.0.segment_sort_key_computer(segment_reader)?;
Ok(SegmentSortKeyComputerWithComparator {
@@ -468,7 +468,7 @@ where
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let child = self.0.segment_sort_key_computer(segment_reader)?;
Ok(SegmentSortKeyComputerWithComparator {

View File

@@ -1,168 +0,0 @@
use columnar::BytesColumn;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::termdict::TermOrdinal;
use crate::{DocId, Score};
/// Sort by the first value of a bytes column.
///
/// If the field is multivalued, only the first value is considered.
///
/// Documents that do not have this value are still considered.
/// Their sort key will simply be `None`.
#[derive(Debug, Clone)]
pub struct SortByBytes {
column_name: String,
}
impl SortByBytes {
/// Creates a new sort by bytes sort key computer.
pub fn for_field(column_name: impl ToString) -> Self {
SortByBytes {
column_name: column_name.to_string(),
}
}
}
impl SortKeyComputer for SortByBytes {
type SortKey = Option<Vec<u8>>;
type Child = ByBytesColumnSegmentSortKeyComputer;
type Comparator = NaturalComparator;
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let bytes_column_opt = segment_reader.fast_fields().bytes(&self.column_name)?;
Ok(ByBytesColumnSegmentSortKeyComputer { bytes_column_opt })
}
}
/// Segment-level sort key computer for bytes columns.
pub struct ByBytesColumnSegmentSortKeyComputer {
bytes_column_opt: Option<BytesColumn>,
}
impl SegmentSortKeyComputer for ByBytesColumnSegmentSortKeyComputer {
type SortKey = Option<Vec<u8>>;
type SegmentSortKey = Option<TermOrdinal>;
type SegmentComparator = NaturalComparator;
#[inline(always)]
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
let bytes_column = self.bytes_column_opt.as_ref()?;
bytes_column.ords().first(doc)
}
fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<Vec<u8>> {
// TODO: Individual lookups to the dictionary like this are very likely to repeatedly
// decompress the same blocks. See https://github.com/quickwit-oss/tantivy/issues/2776
let term_ord = term_ord_opt?;
let bytes_column = self.bytes_column_opt.as_ref()?;
let mut bytes = Vec::new();
bytes_column
.dictionary()
.ord_to_term(term_ord, &mut bytes)
.ok()?;
Some(bytes)
}
}
#[cfg(test)]
mod tests {
use super::SortByBytes;
use crate::collector::TopDocs;
use crate::query::AllQuery;
use crate::schema::{BytesOptions, Schema, FAST, INDEXED};
use crate::{Index, IndexWriter, Order, TantivyDocument};
#[test]
fn test_sort_by_bytes_asc() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let bytes_field = schema_builder
.add_bytes_field("data", BytesOptions::default().set_fast().set_indexed());
let id_field = schema_builder.add_u64_field("id", FAST | INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// Insert documents with byte values in non-sorted order
let test_data: Vec<(u64, Vec<u8>)> = vec![
(1, vec![0x02, 0x00]),
(2, vec![0x00, 0x10]),
(3, vec![0x01, 0x00]),
(4, vec![0x00, 0x20]),
];
for (id, bytes) in &test_data {
let mut doc = TantivyDocument::new();
doc.add_u64(id_field, *id);
doc.add_bytes(bytes_field, bytes);
index_writer.add_document(doc)?;
}
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
// Sort ascending by bytes
let top_docs =
TopDocs::with_limit(10).order_by((SortByBytes::for_field("data"), Order::Asc));
let results: Vec<(Option<Vec<u8>>, _)> = searcher.search(&AllQuery, &top_docs)?;
// Expected order: [0x00,0x10], [0x00,0x20], [0x01,0x00], [0x02,0x00]
let sorted_bytes: Vec<Option<Vec<u8>>> = results.into_iter().map(|(b, _)| b).collect();
assert_eq!(
sorted_bytes,
vec![
Some(vec![0x00, 0x10]),
Some(vec![0x00, 0x20]),
Some(vec![0x01, 0x00]),
Some(vec![0x02, 0x00]),
]
);
Ok(())
}
#[test]
fn test_sort_by_bytes_desc() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let bytes_field = schema_builder
.add_bytes_field("data", BytesOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
let test_data: Vec<Vec<u8>> = vec![vec![0x00, 0x10], vec![0x02, 0x00], vec![0x01, 0x00]];
for bytes in &test_data {
let mut doc = TantivyDocument::new();
doc.add_bytes(bytes_field, bytes);
index_writer.add_document(doc)?;
}
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
// Sort descending by bytes
let top_docs =
TopDocs::with_limit(10).order_by((SortByBytes::for_field("data"), Order::Desc));
let results: Vec<(Option<Vec<u8>>, _)> = searcher.search(&AllQuery, &top_docs)?;
// Expected order (descending): [0x02,0x00], [0x01,0x00], [0x00,0x10]
let sorted_bytes: Vec<Option<Vec<u8>>> = results.into_iter().map(|(b, _)| b).collect();
assert_eq!(
sorted_bytes,
vec![
Some(vec![0x02, 0x00]),
Some(vec![0x01, 0x00]),
Some(vec![0x00, 0x10]),
]
);
Ok(())
}
}

View File

@@ -1,12 +1,12 @@
use columnar::{ColumnType, MonotonicallyMappableToU64};
use crate::collector::sort_key::{
NaturalComparator, SortByBytes, SortBySimilarityScore, SortByStaticFastValue, SortByString,
NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString,
};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::FastFieldNotAvailableError;
use crate::schema::OwnedValue;
use crate::{DateTime, DocId, Score};
use crate::{DateTime, DocId, Score, SegmentReader};
/// Sort by the boxed / OwnedValue representation of either a fast field, or of the score.
///
@@ -86,7 +86,7 @@ impl SortKeyComputer for SortByErasedType {
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let inner: Box<dyn ErasedSegmentSortKeyComputer> = match self {
Self::Field(column_name) => {
@@ -114,16 +114,6 @@ impl SortKeyComputer for SortByErasedType {
},
})
}
ColumnType::Bytes => {
let computer = SortByBytes::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
Box::new(ErasedSegmentSortKeyComputerWrapper {
inner,
converter: |val: Option<Vec<u8>>| {
val.map(OwnedValue::Bytes).unwrap_or(OwnedValue::Null)
},
})
}
ColumnType::U64 => {
let computer = SortByStaticFastValue::<u64>::for_field(column_name);
let inner = computer.segment_sort_key_computer(segment_reader)?;
@@ -291,65 +281,6 @@ mod tests {
);
}
#[test]
fn test_sort_by_owned_bytes() {
let mut schema_builder = Schema::builder();
let data_field = schema_builder.add_bytes_field("data", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer
.add_document(doc!(data_field => vec![0x03u8, 0x00]))
.unwrap();
writer
.add_document(doc!(data_field => vec![0x01u8, 0x00]))
.unwrap();
writer
.add_document(doc!(data_field => vec![0x02u8, 0x00]))
.unwrap();
writer.add_document(doc!()).unwrap();
writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
// Sort descending (Natural - highest first)
let collector = TopDocs::with_limit(10)
.order_by((SortByErasedType::for_field("data"), ComparatorEnum::Natural));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![
OwnedValue::Bytes(vec![0x03, 0x00]),
OwnedValue::Bytes(vec![0x02, 0x00]),
OwnedValue::Bytes(vec![0x01, 0x00]),
OwnedValue::Null
]
);
// Sort ascending (ReverseNoneLower - lowest first, nulls last)
let collector = TopDocs::with_limit(10).order_by((
SortByErasedType::for_field("data"),
ComparatorEnum::ReverseNoneLower,
));
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
assert_eq!(
values,
vec![
OwnedValue::Bytes(vec![0x01, 0x00]),
OwnedValue::Bytes(vec![0x02, 0x00]),
OwnedValue::Bytes(vec![0x03, 0x00]),
OwnedValue::Null
]
);
}
#[test]
fn test_sort_by_owned_reverse() {
let mut schema_builder = Schema::builder();

View File

@@ -1,6 +1,6 @@
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::{DocAddress, DocId, Score};
use crate::{DocAddress, DocId, Score, SegmentReader};
/// Sort by similarity score.
#[derive(Clone, Debug, Copy)]
@@ -19,7 +19,7 @@ impl SortKeyComputer for SortBySimilarityScore {
fn segment_sort_key_computer(
&self,
_segment_reader: &crate::SegmentReader,
_segment_reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
Ok(SortBySimilarityScore)
}
@@ -29,7 +29,7 @@ impl SortKeyComputer for SortBySimilarityScore {
&self,
k: usize,
weight: &dyn crate::query::Weight,
reader: &crate::SegmentReader,
reader: &dyn SegmentReader,
segment_ord: u32,
) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
let mut top_n: TopNComputer<Score, DocId, Self::Comparator> =

View File

@@ -61,7 +61,7 @@ impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
fn segment_sort_key_computer(
&self,
segment_reader: &SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let sort_column_opt = segment_reader.fast_fields().u64_lenient(&self.field)?;
let (sort_column, _sort_column_type) =

View File

@@ -3,7 +3,7 @@ use columnar::StrColumn;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::termdict::TermOrdinal;
use crate::{DocId, Score};
use crate::{DocId, Score, SegmentReader};
/// Sort by the first value of a string column.
///
@@ -35,7 +35,7 @@ impl SortKeyComputer for SortByString {
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let str_column_opt = segment_reader.fast_fields().str(&self.column_name)?;
Ok(ByStringColumnSegmentSortKeyComputer { str_column_opt })

View File

@@ -119,7 +119,7 @@ pub trait SortKeyComputer: Sync {
&self,
k: usize,
weight: &dyn crate::query::Weight,
reader: &crate::SegmentReader,
reader: &dyn SegmentReader,
segment_ord: u32,
) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
let with_scoring = self.requires_scoring();
@@ -135,7 +135,7 @@ pub trait SortKeyComputer: Sync {
}
/// Builds a child sort key computer for a specific segment.
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child>;
fn segment_sort_key_computer(&self, segment_reader: &dyn SegmentReader) -> Result<Self::Child>;
}
impl<HeadSortKeyComputer, TailSortKeyComputer> SortKeyComputer
@@ -156,7 +156,7 @@ where
(self.0.comparator(), self.1.comparator())
}
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
fn segment_sort_key_computer(&self, segment_reader: &dyn SegmentReader) -> Result<Self::Child> {
Ok((
self.0.segment_sort_key_computer(segment_reader)?,
self.1.segment_sort_key_computer(segment_reader)?,
@@ -357,7 +357,7 @@ where
)
}
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
fn segment_sort_key_computer(&self, segment_reader: &dyn SegmentReader) -> Result<Self::Child> {
let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?;
let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?;
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
@@ -420,7 +420,7 @@ where
SortKeyComputer4::Comparator,
);
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
fn segment_sort_key_computer(&self, segment_reader: &dyn SegmentReader) -> Result<Self::Child> {
let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?;
let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?;
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
@@ -454,7 +454,7 @@ where
impl<F, SegmentF, TSortKey> SortKeyComputer for F
where
F: 'static + Send + Sync + Fn(&SegmentReader) -> SegmentF,
F: 'static + Send + Sync + Fn(&dyn SegmentReader) -> SegmentF,
SegmentF: 'static + FnMut(DocId) -> TSortKey,
TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
{
@@ -462,7 +462,7 @@ where
type Child = SegmentF;
type Comparator = NaturalComparator;
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
fn segment_sort_key_computer(&self, segment_reader: &dyn SegmentReader) -> Result<Self::Child> {
Ok((self)(segment_reader))
}
}
@@ -509,10 +509,10 @@ mod tests {
#[test]
fn test_lazy_score_computer() {
let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32;
let score_computer_primary = |_segment_reader: &dyn SegmentReader| |_doc: DocId| 200u32;
let call_count = Arc::new(AtomicUsize::new(0));
let call_count_clone = call_count.clone();
let score_computer_secondary = move |_segment_reader: &SegmentReader| {
let score_computer_secondary = move |_segment_reader: &dyn SegmentReader| {
let call_count_new_clone = call_count_clone.clone();
move |_doc: DocId| {
call_count_new_clone.fetch_add(1, AtomicOrdering::SeqCst);
@@ -572,10 +572,10 @@ mod tests {
#[test]
fn test_lazy_score_computer_dynamic_ordering() {
let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32;
let score_computer_primary = |_segment_reader: &dyn SegmentReader| |_doc: DocId| 200u32;
let call_count = Arc::new(AtomicUsize::new(0));
let call_count_clone = call_count.clone();
let score_computer_secondary = move |_segment_reader: &SegmentReader| {
let score_computer_secondary = move |_segment_reader: &dyn SegmentReader| {
let call_count_new_clone = call_count_clone.clone();
move |_doc: DocId| {
call_count_new_clone.fetch_add(1, AtomicOrdering::SeqCst);

View File

@@ -32,7 +32,11 @@ where TSortKeyComputer: SortKeyComputer + Send + Sync + 'static
self.sort_key_computer.check_schema(schema)
}
fn for_segment(&self, segment_ord: u32, segment_reader: &SegmentReader) -> Result<Self::Child> {
fn for_segment(
&self,
segment_ord: u32,
segment_reader: &dyn SegmentReader,
) -> Result<Self::Child> {
let segment_sort_key_computer = self
.sort_key_computer
.segment_sort_key_computer(segment_reader)?;
@@ -63,7 +67,7 @@ where TSortKeyComputer: SortKeyComputer + Send + Sync + 'static
&self,
weight: &dyn Weight,
segment_ord: u32,
reader: &SegmentReader,
reader: &dyn SegmentReader,
) -> crate::Result<Vec<(TSortKeyComputer::SortKey, DocAddress)>> {
let k = self.doc_range.end;
let docs = self

View File

@@ -5,7 +5,7 @@ use crate::query::{AllQuery, QueryParser};
use crate::schema::{Schema, FAST, TEXT};
use crate::time::format_description::well_known::Rfc3339;
use crate::time::OffsetDateTime;
use crate::{DateTime, DocAddress, Index, Searcher, TantivyDocument};
use crate::{DateTime, DocAddress, Index, Searcher, SegmentReader, TantivyDocument};
pub const TEST_COLLECTOR_WITH_SCORE: TestCollector = TestCollector {
compute_score: true,
@@ -109,7 +109,7 @@ impl Collector for TestCollector {
fn for_segment(
&self,
segment_id: SegmentOrdinal,
_reader: &SegmentReader,
_reader: &dyn SegmentReader,
) -> crate::Result<TestSegmentCollector> {
Ok(TestSegmentCollector {
segment_id,
@@ -180,7 +180,7 @@ impl Collector for FastFieldTestCollector {
fn for_segment(
&self,
_: SegmentOrdinal,
segment_reader: &SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<FastFieldSegmentCollector> {
let reader = segment_reader
.fast_fields()
@@ -243,7 +243,7 @@ impl Collector for BytesFastFieldTestCollector {
fn for_segment(
&self,
_segment_local_id: u32,
segment_reader: &SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<BytesFastFieldSegmentCollector> {
let column_opt = segment_reader.fast_fields().bytes(&self.field)?;
Ok(BytesFastFieldSegmentCollector {

View File

@@ -393,7 +393,7 @@ impl TopDocs {
/// // This is where we build our collector with our custom score.
/// let top_docs_by_custom_score = TopDocs
/// ::with_limit(10)
/// .tweak_score(move |segment_reader: &SegmentReader| {
/// .tweak_score(move |segment_reader: &dyn SegmentReader| {
/// // The argument is a function that returns our scoring
/// // function.
/// //
@@ -442,7 +442,7 @@ pub struct TweakScoreFn<F>(F);
impl<F, TTweakScoreSortKeyFn, TSortKey> SortKeyComputer for TweakScoreFn<F>
where
F: 'static + Send + Sync + Fn(&SegmentReader) -> TTweakScoreSortKeyFn,
F: 'static + Send + Sync + Fn(&dyn SegmentReader) -> TTweakScoreSortKeyFn,
TTweakScoreSortKeyFn: 'static + Fn(DocId, Score) -> TSortKey,
TweakScoreSegmentSortKeyComputer<TTweakScoreSortKeyFn>:
SegmentSortKeyComputer<SortKey = TSortKey, SegmentSortKey = TSortKey>,
@@ -458,7 +458,7 @@ where
fn segment_sort_key_computer(
&self,
segment_reader: &SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
Ok({
TweakScoreSegmentSortKeyComputer {
@@ -1525,7 +1525,7 @@ mod tests {
let text_query = query_parser.parse_query("droopy tax")?;
let collector = TopDocs::with_limit(2)
.and_offset(1)
.order_by(move |_segment_reader: &SegmentReader| move |doc: DocId| doc);
.order_by(move |_segment_reader: &dyn SegmentReader| move |doc: DocId| doc);
let score_docs: Vec<(u32, DocAddress)> =
index.reader()?.searcher().search(&text_query, &collector)?;
assert_eq!(
@@ -1543,7 +1543,7 @@ mod tests {
let text_query = query_parser.parse_query("droopy tax").unwrap();
let collector = TopDocs::with_limit(2)
.and_offset(1)
.order_by(move |_segment_reader: &SegmentReader| move |doc: DocId| doc);
.order_by(move |_segment_reader: &dyn SegmentReader| move |doc: DocId| doc);
let score_docs: Vec<(u32, DocAddress)> = index
.reader()
.unwrap()

View File

@@ -4,7 +4,7 @@ use std::{fmt, io};
use crate::collector::Collector;
use crate::core::Executor;
use crate::index::{SegmentId, SegmentReader};
use crate::index::{ArcSegmentReader, SegmentId, SegmentReader};
use crate::query::{Bm25StatisticsProvider, EnableScoring, Query};
use crate::schema::document::DocumentDeserialize;
use crate::schema::{Schema, Term};
@@ -36,7 +36,7 @@ pub struct SearcherGeneration {
impl SearcherGeneration {
pub(crate) fn from_segment_readers(
segment_readers: &[SegmentReader],
segment_readers: &[ArcSegmentReader],
generation_id: u64,
) -> Self {
let mut segment_id_to_del_opstamp = BTreeMap::new();
@@ -133,7 +133,7 @@ impl Searcher {
pub fn doc_freq(&self, term: &Term) -> crate::Result<u64> {
let mut total_doc_freq = 0;
for segment_reader in &self.inner.segment_readers {
let inverted_index = segment_reader.inverted_index(term.field())?;
let inverted_index = segment_reader.as_ref().inverted_index(term.field())?;
let doc_freq = inverted_index.doc_freq(term)?;
total_doc_freq += u64::from(doc_freq);
}
@@ -146,7 +146,7 @@ impl Searcher {
pub async fn doc_freq_async(&self, term: &Term) -> crate::Result<u64> {
let mut total_doc_freq = 0;
for segment_reader in &self.inner.segment_readers {
let inverted_index = segment_reader.inverted_index(term.field())?;
let inverted_index = segment_reader.as_ref().inverted_index(term.field())?;
let doc_freq = inverted_index.doc_freq_async(term).await?;
total_doc_freq += u64::from(doc_freq);
}
@@ -154,13 +154,13 @@ impl Searcher {
}
/// Return the list of segment readers
pub fn segment_readers(&self) -> &[SegmentReader] {
pub fn segment_readers(&self) -> &[ArcSegmentReader] {
&self.inner.segment_readers
}
/// Returns the segment_reader associated with the given segment_ord
pub fn segment_reader(&self, segment_ord: u32) -> &SegmentReader {
&self.inner.segment_readers[segment_ord as usize]
pub fn segment_reader(&self, segment_ord: u32) -> &dyn SegmentReader {
self.inner.segment_readers[segment_ord as usize].as_ref()
}
/// Runs a query on the segment readers wrapped by the searcher.
@@ -229,7 +229,11 @@ impl Searcher {
let segment_readers = self.segment_readers();
let fruits = executor.map(
|(segment_ord, segment_reader)| {
collector.collect_segment(weight.as_ref(), segment_ord as u32, segment_reader)
collector.collect_segment(
weight.as_ref(),
segment_ord as u32,
segment_reader.as_ref(),
)
},
segment_readers.iter().enumerate(),
)?;
@@ -259,7 +263,7 @@ impl From<Arc<SearcherInner>> for Searcher {
pub(crate) struct SearcherInner {
schema: Schema,
index: Index,
segment_readers: Vec<SegmentReader>,
segment_readers: Vec<ArcSegmentReader>,
store_readers: Vec<StoreReader>,
generation: TrackedObject<SearcherGeneration>,
}
@@ -269,7 +273,7 @@ impl SearcherInner {
pub(crate) fn new(
schema: Schema,
index: Index,
segment_readers: Vec<SegmentReader>,
segment_readers: Vec<ArcSegmentReader>,
generation: TrackedObject<SearcherGeneration>,
doc_store_cache_num_blocks: usize,
) -> io::Result<SearcherInner> {
@@ -301,7 +305,7 @@ impl fmt::Debug for Searcher {
let segment_ids = self
.segment_readers()
.iter()
.map(SegmentReader::segment_id)
.map(|segment_reader| segment_reader.segment_id())
.collect::<Vec<_>>();
write!(f, "Searcher({segment_ids:?})")
}

View File

@@ -676,7 +676,7 @@ mod tests {
let num_segments = reader.searcher().segment_readers().len();
assert!(num_segments <= 4);
let num_components_except_deletes_and_tempstore =
crate::index::SegmentComponent::iterator().len() - 1;
crate::index::SegmentComponent::iterator().len() - 2;
let max_num_mmapped = num_components_except_deletes_and_tempstore * num_segments;
assert_eventually(|| {
let num_mmapped = mmap_directory.get_cache_info().mmapped.len();

View File

@@ -21,7 +21,7 @@ use std::path::PathBuf;
pub use common::file_slice::{FileHandle, FileSlice};
pub use common::{AntiCallToken, OwnedBytes, TerminatingWrite};
pub use self::composite_file::{CompositeFile, CompositeWrite};
pub(crate) use self::composite_file::{CompositeFile, CompositeWrite};
pub use self::directory::{Directory, DirectoryClone, DirectoryLock};
pub use self::directory_lock::{Lock, INDEX_WRITER_LOCK, META_LOCK};
pub use self::ram_directory::RamDirectory;
@@ -52,7 +52,7 @@ pub use self::mmap_directory::MmapDirectory;
///
/// `WritePtr` are required to implement both Write
/// and Seek.
pub type WritePtr = BufWriter<Box<dyn TerminatingWrite + Send + Sync>>;
pub type WritePtr = BufWriter<Box<dyn TerminatingWrite>>;
#[cfg(test)]
mod tests;

View File

@@ -65,8 +65,8 @@ pub trait DocSet: Send {
/// `seek_danger(..)` until it returns `Found`, and get back to a valid state.
///
/// `seek_lower_bound` can be any `DocId` (in the docset or not) as long as it is in
/// `(target .. seek_result] U {TERMINATED}` where `seek_result` is the first document in the
/// docset greater than to `target`.
/// `(target .. seek_result]` where `seek_result` is the first document in the docset greater
/// than to `target`.
///
/// `seek_danger` may return `SeekLowerBound(TERMINATED)`.
///
@@ -98,7 +98,7 @@ pub trait DocSet: Send {
if doc == target {
SeekDangerResult::Found
} else {
SeekDangerResult::SeekLowerBound(doc)
SeekDangerResult::SeekLowerBound(self.doc())
}
}

View File

@@ -96,7 +96,7 @@ mod tests {
};
use crate::time::OffsetDateTime;
use crate::tokenizer::{LowerCaser, RawTokenizer, TextAnalyzer, TokenizerManager};
use crate::{Index, IndexWriter, SegmentReader};
use crate::{Index, IndexWriter};
pub static SCHEMA: Lazy<Schema> = Lazy::new(|| {
let mut schema_builder = Schema::builder();
@@ -430,7 +430,7 @@ mod tests {
.searcher()
.segment_readers()
.iter()
.map(SegmentReader::segment_id)
.map(|segment_reader| segment_reader.segment_id())
.collect();
assert_eq!(segment_ids.len(), 2);
index_writer.merge(&segment_ids[..]).wait().unwrap();

View File

@@ -14,7 +14,7 @@ use crate::directory::error::OpenReadError;
use crate::directory::MmapDirectory;
use crate::directory::{Directory, ManagedDirectory, RamDirectory, INDEX_WRITER_LOCK};
use crate::error::{DataCorruption, TantivyError};
use crate::index::{IndexMeta, SegmentId, SegmentMeta, SegmentMetaInventory};
use crate::index::{IndexMeta, SegmentId, SegmentMeta, SegmentMetaInventory, SegmentReader};
use crate::indexer::index_writer::{
IndexWriterOptions, MAX_NUM_THREAD, MEMORY_BUDGET_NUM_BYTES_MIN,
};
@@ -24,7 +24,7 @@ use crate::reader::{IndexReader, IndexReaderBuilder};
use crate::schema::document::Document;
use crate::schema::{Field, FieldType, Schema};
use crate::tokenizer::{TextAnalyzer, TokenizerManager};
use crate::SegmentReader;
use crate::TantivySegmentReader;
fn load_metas(
directory: &dyn Directory,
@@ -492,7 +492,7 @@ impl Index {
let segments = self.searchable_segments()?;
let fields_metadata: Vec<Vec<FieldMetadata>> = segments
.into_iter()
.map(|segment| SegmentReader::open(&segment)?.fields_metadata())
.map(|segment| TantivySegmentReader::open(&segment)?.fields_metadata())
.collect::<Result<_, _>>()?;
Ok(merge_field_meta_data(fields_metadata))
}

View File

@@ -1,6 +1,8 @@
use std::collections::HashSet;
use std::fmt;
use std::path::PathBuf;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
@@ -35,6 +37,7 @@ impl SegmentMetaInventory {
let inner = InnerSegmentMeta {
segment_id,
max_doc,
include_temp_doc_store: Arc::new(AtomicBool::new(true)),
deletes: None,
};
SegmentMeta::from(self.inventory.track(inner))
@@ -82,6 +85,15 @@ impl SegmentMeta {
self.tracked.segment_id
}
/// Removes the Component::TempStore from the alive list and
/// therefore marks the temp docstore file to be deleted by
/// the garbage collection.
pub fn untrack_temp_docstore(&self) {
self.tracked
.include_temp_doc_store
.store(false, std::sync::atomic::Ordering::Relaxed);
}
/// Returns the number of deleted documents.
pub fn num_deleted_docs(&self) -> u32 {
self.tracked
@@ -99,9 +111,20 @@ impl SegmentMeta {
/// is by removing all files that have been created by tantivy
/// and are not used by any segment anymore.
pub fn list_files(&self) -> HashSet<PathBuf> {
SegmentComponent::iterator()
.map(|component| self.relative_path(*component))
.collect::<HashSet<PathBuf>>()
if self
.tracked
.include_temp_doc_store
.load(std::sync::atomic::Ordering::Relaxed)
{
SegmentComponent::iterator()
.map(|component| self.relative_path(*component))
.collect::<HashSet<PathBuf>>()
} else {
SegmentComponent::iterator()
.filter(|comp| *comp != &SegmentComponent::TempStore)
.map(|component| self.relative_path(*component))
.collect::<HashSet<PathBuf>>()
}
}
/// Returns the relative path of a component of our segment.
@@ -115,6 +138,7 @@ impl SegmentMeta {
SegmentComponent::Positions => ".pos".to_string(),
SegmentComponent::Terms => ".term".to_string(),
SegmentComponent::Store => ".store".to_string(),
SegmentComponent::TempStore => ".store.temp".to_string(),
SegmentComponent::FastFields => ".fast".to_string(),
SegmentComponent::FieldNorms => ".fieldnorm".to_string(),
SegmentComponent::Delete => format!(".{}.del", self.delete_opstamp().unwrap_or(0)),
@@ -159,6 +183,7 @@ impl SegmentMeta {
segment_id: inner_meta.segment_id,
max_doc,
deletes: None,
include_temp_doc_store: Arc::new(AtomicBool::new(true)),
});
SegmentMeta { tracked }
}
@@ -177,6 +202,7 @@ impl SegmentMeta {
let tracked = self.tracked.map(move |inner_meta| InnerSegmentMeta {
segment_id: inner_meta.segment_id,
max_doc: inner_meta.max_doc,
include_temp_doc_store: Arc::new(AtomicBool::new(true)),
deletes: Some(delete_meta),
});
SegmentMeta { tracked }
@@ -188,6 +214,14 @@ struct InnerSegmentMeta {
segment_id: SegmentId,
max_doc: u32,
pub deletes: Option<DeleteMeta>,
/// If you want to avoid the SegmentComponent::TempStore file to be covered by
/// garbage collection and deleted, set this to true. This is used during merge.
#[serde(skip)]
#[serde(default = "default_temp_store")]
pub(crate) include_temp_doc_store: Arc<AtomicBool>,
}
fn default_temp_store() -> Arc<AtomicBool> {
Arc::new(AtomicBool::new(false))
}
impl InnerSegmentMeta {

View File

@@ -1,4 +1,9 @@
#[cfg(feature = "quickwit")]
use std::future::Future;
use std::io;
#[cfg(feature = "quickwit")]
use std::pin::Pin;
use std::sync::Arc;
use common::json_path_writer::JSON_END_OF_PATH;
use common::{BinarySerializable, ByteCount};
@@ -27,7 +32,102 @@ use crate::termdict::TermDictionary;
///
/// `InvertedIndexReader` are created by calling
/// [`SegmentReader::inverted_index()`](crate::SegmentReader::inverted_index).
pub struct InvertedIndexReader {
pub trait InvertedIndexReader: Send + Sync {
/// Returns the term info associated with the term.
fn get_term_info(&self, term: &Term) -> io::Result<Option<TermInfo>>;
/// Return the term dictionary datastructure.
fn terms(&self) -> &TermDictionary;
/// Return the fields and types encoded in the dictionary in lexicographic order.
/// Only valid on JSON fields.
///
/// Notice: This requires a full scan and therefore **very expensive**.
/// TODO: Move to sstable to use the index.
#[doc(hidden)]
fn list_encoded_json_fields(&self) -> io::Result<Vec<InvertedIndexFieldSpace>>;
/// Returns a block postings given a `Term`.
/// This method is for an advanced usage only.
///
/// Most users should prefer using [`Self::read_postings()`] instead.
fn read_block_postings(
&self,
term: &Term,
option: IndexRecordOption,
) -> io::Result<Option<BlockSegmentPostings>>;
/// Returns a block postings given a `term_info`.
/// This method is for an advanced usage only.
///
/// Most users should prefer using [`Self::read_postings()`] instead.
fn read_block_postings_from_terminfo(
&self,
term_info: &TermInfo,
requested_option: IndexRecordOption,
) -> io::Result<BlockSegmentPostings>;
/// Returns a posting object given a `term_info`.
/// This method is for an advanced usage only.
///
/// Most users should prefer using [`Self::read_postings()`] instead.
fn read_postings_from_terminfo(
&self,
term_info: &TermInfo,
option: IndexRecordOption,
) -> io::Result<SegmentPostings>;
/// Returns the total number of tokens recorded for all documents
/// (including deleted documents).
fn total_num_tokens(&self) -> u64;
/// Returns the segment postings associated with the term, and with the given option,
/// or `None` if the term has never been encountered and indexed.
fn read_postings(
&self,
term: &Term,
option: IndexRecordOption,
) -> io::Result<Option<SegmentPostings>>;
/// Returns the number of documents containing the term.
fn doc_freq(&self, term: &Term) -> io::Result<u32>;
/// Returns the number of documents containing the term asynchronously.
#[cfg(feature = "quickwit")]
fn doc_freq_async<'a>(&'a self, term: &'a Term) -> BoxFuture<'a, io::Result<u32>>;
/// Warmup a block postings given a `Term`.
/// This method is for an advanced usage only.
///
/// returns a boolean, whether the term was found in the dictionary
#[cfg(feature = "quickwit")]
fn warm_postings<'a>(
&'a self,
term: &'a Term,
with_positions: bool,
) -> BoxFuture<'a, io::Result<bool>>;
/// Warmup the block postings for all terms.
/// This method is for an advanced usage only.
///
/// If you know which terms to pre-load, prefer using [`Self::warm_postings`] or
/// [`Self::warm_postings`] instead.
#[cfg(feature = "quickwit")]
fn warm_postings_full<'a>(&'a self, with_positions: bool) -> BoxFuture<'a, io::Result<()>>;
}
/// Convenient alias for an atomically reference counted inverted index reader handle.
pub type ArcInvertedIndexReader = Arc<dyn InvertedIndexReader>;
#[cfg(feature = "quickwit")]
/// Boxed future used by async inverted index reader methods.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// The tantivy inverted index reader is in charge of accessing
/// the inverted index associated with a specific field.
///
/// This is the default implementation of [`InvertedIndexReader`].
pub struct TantivyInvertedIndexReader {
termdict: TermDictionary,
postings_file_slice: FileSlice,
positions_file_slice: FileSlice,
@@ -36,11 +136,16 @@ pub struct InvertedIndexReader {
}
/// Object that records the amount of space used by a field in an inverted index.
pub(crate) struct InvertedIndexFieldSpace {
pub struct InvertedIndexFieldSpace {
/// The JSON field name (without the parent field).
pub field_name: String,
/// The field type encoded in the term dictionary.
pub field_type: Type,
/// Total postings size for this field.
pub postings_size: ByteCount,
/// Total positions size for this field.
pub positions_size: ByteCount,
/// Number of terms for this field.
pub num_terms: u64,
}
@@ -62,16 +167,16 @@ impl InvertedIndexFieldSpace {
}
}
impl InvertedIndexReader {
impl TantivyInvertedIndexReader {
pub(crate) fn new(
termdict: TermDictionary,
postings_file_slice: FileSlice,
positions_file_slice: FileSlice,
record_option: IndexRecordOption,
) -> io::Result<InvertedIndexReader> {
) -> io::Result<TantivyInvertedIndexReader> {
let (total_num_tokens_slice, postings_body) = postings_file_slice.split(8);
let total_num_tokens = u64::deserialize(&mut total_num_tokens_slice.read_bytes()?)?;
Ok(InvertedIndexReader {
Ok(TantivyInvertedIndexReader {
termdict,
postings_file_slice: postings_body,
positions_file_slice,
@@ -82,8 +187,8 @@ impl InvertedIndexReader {
/// Creates an empty `InvertedIndexReader` object, which
/// contains no terms at all.
pub fn empty(record_option: IndexRecordOption) -> InvertedIndexReader {
InvertedIndexReader {
pub fn empty(record_option: IndexRecordOption) -> TantivyInvertedIndexReader {
TantivyInvertedIndexReader {
termdict: TermDictionary::empty(),
postings_file_slice: FileSlice::empty(),
positions_file_slice: FileSlice::empty(),
@@ -160,29 +265,6 @@ impl InvertedIndexReader {
Ok(fields)
}
/// Resets the block segment to another position of the postings
/// file.
///
/// This is useful for enumerating through a list of terms,
/// and consuming the associated posting lists while avoiding
/// reallocating a [`BlockSegmentPostings`].
///
/// # Warning
///
/// This does not reset the positions list.
pub fn reset_block_postings_from_terminfo(
&self,
term_info: &TermInfo,
block_postings: &mut BlockSegmentPostings,
) -> io::Result<()> {
let postings_slice = self
.postings_file_slice
.slice(term_info.postings_range.clone());
let postings_bytes = postings_slice.read_bytes()?;
block_postings.reset(term_info.doc_freq, postings_bytes)?;
Ok(())
}
/// Returns a block postings given a `Term`.
/// This method is for an advanced usage only.
///
@@ -282,7 +364,7 @@ impl InvertedIndexReader {
}
#[cfg(feature = "quickwit")]
impl InvertedIndexReader {
impl TantivyInvertedIndexReader {
pub(crate) async fn get_term_info_async(&self, term: &Term) -> io::Result<Option<TermInfo>> {
self.termdict.get_async(term.serialized_value_bytes()).await
}
@@ -492,3 +574,84 @@ impl InvertedIndexReader {
.unwrap_or(0u32))
}
}
impl InvertedIndexReader for TantivyInvertedIndexReader {
fn get_term_info(&self, term: &Term) -> io::Result<Option<TermInfo>> {
TantivyInvertedIndexReader::get_term_info(self, term)
}
fn terms(&self) -> &TermDictionary {
TantivyInvertedIndexReader::terms(self)
}
fn list_encoded_json_fields(&self) -> io::Result<Vec<InvertedIndexFieldSpace>> {
TantivyInvertedIndexReader::list_encoded_json_fields(self)
}
fn read_block_postings(
&self,
term: &Term,
option: IndexRecordOption,
) -> io::Result<Option<BlockSegmentPostings>> {
TantivyInvertedIndexReader::read_block_postings(self, term, option)
}
fn read_block_postings_from_terminfo(
&self,
term_info: &TermInfo,
requested_option: IndexRecordOption,
) -> io::Result<BlockSegmentPostings> {
TantivyInvertedIndexReader::read_block_postings_from_terminfo(
self,
term_info,
requested_option,
)
}
fn read_postings_from_terminfo(
&self,
term_info: &TermInfo,
option: IndexRecordOption,
) -> io::Result<SegmentPostings> {
TantivyInvertedIndexReader::read_postings_from_terminfo(self, term_info, option)
}
fn total_num_tokens(&self) -> u64 {
TantivyInvertedIndexReader::total_num_tokens(self)
}
fn read_postings(
&self,
term: &Term,
option: IndexRecordOption,
) -> io::Result<Option<SegmentPostings>> {
TantivyInvertedIndexReader::read_postings(self, term, option)
}
fn doc_freq(&self, term: &Term) -> io::Result<u32> {
TantivyInvertedIndexReader::doc_freq(self, term)
}
#[cfg(feature = "quickwit")]
fn doc_freq_async<'a>(&'a self, term: &'a Term) -> BoxFuture<'a, io::Result<u32>> {
Box::pin(async move { TantivyInvertedIndexReader::doc_freq_async(self, term).await })
}
#[cfg(feature = "quickwit")]
fn warm_postings<'a>(
&'a self,
term: &'a Term,
with_positions: bool,
) -> BoxFuture<'a, io::Result<bool>> {
Box::pin(async move {
TantivyInvertedIndexReader::warm_postings(self, term, with_positions).await
})
}
#[cfg(feature = "quickwit")]
fn warm_postings_full<'a>(&'a self, with_positions: bool) -> BoxFuture<'a, io::Result<()>> {
Box::pin(async move {
TantivyInvertedIndexReader::warm_postings_full(self, with_positions).await
})
}
}

View File

@@ -13,8 +13,13 @@ mod segment_reader;
pub use self::index::{Index, IndexBuilder};
pub(crate) use self::index_meta::SegmentMetaInventory;
pub use self::index_meta::{IndexMeta, IndexSettings, Order, SegmentMeta};
pub use self::inverted_index_reader::InvertedIndexReader;
pub use self::inverted_index_reader::{
ArcInvertedIndexReader, InvertedIndexFieldSpace, InvertedIndexReader,
TantivyInvertedIndexReader,
};
pub use self::segment::Segment;
pub use self::segment_component::SegmentComponent;
pub use self::segment_id::SegmentId;
pub use self::segment_reader::{FieldMetadata, SegmentReader};
pub use self::segment_reader::{
ArcSegmentReader, FieldMetadata, SegmentReader, TantivySegmentReader,
};

View File

@@ -23,6 +23,8 @@ pub enum SegmentComponent {
/// Accessing a document from the store is relatively slow, as it
/// requires to decompress the entire block it belongs to.
Store,
/// Temporary storage of the documents, before streamed to `Store`.
TempStore,
/// Bitset describing which document of the segment is alive.
/// (It was representing deleted docs but changed to represent alive docs from v0.17)
Delete,
@@ -31,13 +33,14 @@ pub enum SegmentComponent {
impl SegmentComponent {
/// Iterates through the components.
pub fn iterator() -> slice::Iter<'static, SegmentComponent> {
static SEGMENT_COMPONENTS: [SegmentComponent; 7] = [
static SEGMENT_COMPONENTS: [SegmentComponent; 8] = [
SegmentComponent::Postings,
SegmentComponent::Positions,
SegmentComponent::FastFields,
SegmentComponent::FieldNorms,
SegmentComponent::Terms,
SegmentComponent::Store,
SegmentComponent::TempStore,
SegmentComponent::Delete,
];
SEGMENT_COMPONENTS.iter()

View File

@@ -9,8 +9,10 @@ use itertools::Itertools;
use crate::directory::{CompositeFile, FileSlice};
use crate::error::DataCorruption;
use crate::fastfield::{intersect_alive_bitsets, AliveBitSet, FacetReader, FastFieldReaders};
use crate::fieldnorm::{FieldNormReader, FieldNormReaders};
use crate::index::{InvertedIndexReader, Segment, SegmentComponent, SegmentId};
use crate::fieldnorm::FieldNormReaders;
use crate::index::{
ArcInvertedIndexReader, Segment, SegmentComponent, SegmentId, TantivyInvertedIndexReader,
};
use crate::json_utils::json_path_sep_to_dot;
use crate::schema::{Field, IndexRecordOption, Schema, Type};
use crate::space_usage::SegmentSpaceUsage;
@@ -18,6 +20,93 @@ use crate::store::StoreReader;
use crate::termdict::TermDictionary;
use crate::{DocId, Opstamp};
/// Abstraction over a segment reader for accessing all data structures of a segment.
///
/// This trait exists to decouple the query layer from the concrete on-disk layout. Alternative
/// codecs can implement it to expose their own segment representation.
pub trait SegmentReader: Send + Sync {
/// Highest document id ever attributed in this segment + 1.
fn max_doc(&self) -> DocId;
/// Number of alive documents. Deleted documents are not counted.
fn num_docs(&self) -> DocId;
/// Returns the schema of the index this segment belongs to.
fn schema(&self) -> &Schema;
/// Return the number of documents that have been deleted in the segment.
fn num_deleted_docs(&self) -> DocId {
self.max_doc() - self.num_docs()
}
/// Returns true if some of the documents of the segment have been deleted.
fn has_deletes(&self) -> bool {
self.num_deleted_docs() > 0
}
/// Accessor to a segment's fast field reader.
fn fast_fields(&self) -> &FastFieldReaders;
/// Accessor to the `FacetReader` associated with a given `Field`.
fn facet_reader(&self, field_name: &str) -> crate::Result<FacetReader> {
let schema = self.schema();
let field = schema.get_field(field_name)?;
let field_entry = schema.get_field_entry(field);
if field_entry.field_type().value_type() != Type::Facet {
return Err(crate::TantivyError::SchemaError(format!(
"`{field_name}` is not a facet field.`"
)));
}
let Some(facet_column) = self.fast_fields().str(field_name)? else {
panic!("Facet Field `{field_name}` is missing. This should not happen");
};
Ok(FacetReader::new(facet_column))
}
/// Accessor to the segment's field norms readers container.
fn fieldnorms_readers(&self) -> &FieldNormReaders;
/// Accessor to the segment's [`StoreReader`](crate::store::StoreReader).
fn get_store_reader(&self, cache_num_blocks: usize) -> io::Result<StoreReader>;
/// Returns a field reader associated with the field given in argument.
fn inverted_index(&self, field: Field) -> crate::Result<ArcInvertedIndexReader>;
/// Returns the list of fields that have been indexed in the segment.
fn fields_metadata(&self) -> crate::Result<Vec<FieldMetadata>>;
/// Returns the segment id
fn segment_id(&self) -> SegmentId;
/// Returns the delete opstamp
fn delete_opstamp(&self) -> Option<Opstamp>;
/// Returns the bitset representing the alive `DocId`s.
fn alive_bitset(&self) -> Option<&AliveBitSet>;
/// Returns true if the `doc` is marked as deleted.
fn is_deleted(&self, doc: DocId) -> bool {
self.alive_bitset()
.map(|alive_bitset| alive_bitset.is_deleted(doc))
.unwrap_or(false)
}
/// Returns an iterator that will iterate over the alive document ids
fn doc_ids_alive(&self) -> Box<dyn Iterator<Item = DocId> + Send + '_> {
if let Some(alive_bitset) = &self.alive_bitset() {
Box::new(alive_bitset.iter_alive())
} else {
Box::new(0u32..self.max_doc())
}
}
/// Summarize total space usage of this segment.
fn space_usage(&self) -> io::Result<SegmentSpaceUsage>;
}
/// Convenient alias for an atomically reference counted segment reader handle.
pub type ArcSegmentReader = Arc<dyn SegmentReader>;
/// Entry point to access all of the datastructures of the `Segment`
///
/// - term dictionary
@@ -29,8 +118,8 @@ use crate::{DocId, Opstamp};
/// The segment reader has a very low memory footprint,
/// as close to all of the memory data is mmapped.
#[derive(Clone)]
pub struct SegmentReader {
inv_idx_reader_cache: Arc<RwLock<HashMap<Field, Arc<InvertedIndexReader>>>>,
pub struct TantivySegmentReader {
inv_idx_reader_cache: Arc<RwLock<HashMap<Field, ArcInvertedIndexReader>>>,
segment_id: SegmentId,
delete_opstamp: Option<Opstamp>,
@@ -49,98 +138,9 @@ pub struct SegmentReader {
schema: Schema,
}
impl SegmentReader {
/// Returns the highest document id ever attributed in
/// this segment + 1.
pub fn max_doc(&self) -> DocId {
self.max_doc
}
/// Returns the number of alive documents.
/// Deleted documents are not counted.
pub fn num_docs(&self) -> DocId {
self.num_docs
}
/// Returns the schema of the index this segment belongs to.
pub fn schema(&self) -> &Schema {
&self.schema
}
/// Return the number of documents that have been
/// deleted in the segment.
pub fn num_deleted_docs(&self) -> DocId {
self.max_doc - self.num_docs
}
/// Returns true if some of the documents of the segment have been deleted.
pub fn has_deletes(&self) -> bool {
self.num_deleted_docs() > 0
}
/// Accessor to a segment's fast field reader given a field.
///
/// Returns the u64 fast value reader if the field
/// is a u64 field indexed as "fast".
///
/// Return a FastFieldNotAvailableError if the field is not
/// declared as a fast field in the schema.
///
/// # Panics
/// May panic if the index is corrupted.
pub fn fast_fields(&self) -> &FastFieldReaders {
&self.fast_fields_readers
}
/// Accessor to the `FacetReader` associated with a given `Field`.
pub fn facet_reader(&self, field_name: &str) -> crate::Result<FacetReader> {
let schema = self.schema();
let field = schema.get_field(field_name)?;
let field_entry = schema.get_field_entry(field);
if field_entry.field_type().value_type() != Type::Facet {
return Err(crate::TantivyError::SchemaError(format!(
"`{field_name}` is not a facet field.`"
)));
}
let Some(facet_column) = self.fast_fields().str(field_name)? else {
panic!("Facet Field `{field_name}` is missing. This should not happen");
};
Ok(FacetReader::new(facet_column))
}
/// Accessor to the segment's `Field norms`'s reader.
///
/// Field norms are the length (in tokens) of the fields.
/// It is used in the computation of the [TfIdf](https://fulmicoton.gitbooks.io/tantivy-doc/content/tfidf.html).
///
/// They are simply stored as a fast field, serialized in
/// the `.fieldnorm` file of the segment.
pub fn get_fieldnorms_reader(&self, field: Field) -> crate::Result<FieldNormReader> {
self.fieldnorm_readers.get_field(field)?.ok_or_else(|| {
let field_name = self.schema.get_field_name(field);
let err_msg = format!(
"Field norm not found for field {field_name:?}. Was the field set to record norm \
during indexing?"
);
crate::TantivyError::SchemaError(err_msg)
})
}
#[doc(hidden)]
pub fn fieldnorms_readers(&self) -> &FieldNormReaders {
&self.fieldnorm_readers
}
/// Accessor to the segment's [`StoreReader`](crate::store::StoreReader).
///
/// `cache_num_blocks` sets the number of decompressed blocks to be cached in an LRU.
/// The size of blocks is configurable, this should be reflexted in the
pub fn get_store_reader(&self, cache_num_blocks: usize) -> io::Result<StoreReader> {
StoreReader::open(self.store_file.clone(), cache_num_blocks)
}
impl TantivySegmentReader {
/// Open a new segment for reading.
pub fn open(segment: &Segment) -> crate::Result<SegmentReader> {
pub fn open(segment: &Segment) -> crate::Result<TantivySegmentReader> {
Self::open_with_custom_alive_set(segment, None)
}
@@ -148,7 +148,7 @@ impl SegmentReader {
pub fn open_with_custom_alive_set(
segment: &Segment,
custom_bitset: Option<AliveBitSet>,
) -> crate::Result<SegmentReader> {
) -> crate::Result<TantivySegmentReader> {
let termdict_file = segment.open_read(SegmentComponent::Terms)?;
let termdict_composite = CompositeFile::open(&termdict_file)?;
@@ -190,7 +190,7 @@ impl SegmentReader {
.map(|alive_bitset| alive_bitset.num_alive_docs() as u32)
.unwrap_or(max_doc);
Ok(SegmentReader {
Ok(TantivySegmentReader {
inv_idx_reader_cache: Default::default(),
num_docs,
max_doc,
@@ -206,6 +206,52 @@ impl SegmentReader {
schema,
})
}
}
impl SegmentReader for TantivySegmentReader {
/// Returns the highest document id ever attributed in
/// this segment + 1.
fn max_doc(&self) -> DocId {
self.max_doc
}
/// Returns the number of alive documents.
/// Deleted documents are not counted.
fn num_docs(&self) -> DocId {
self.num_docs
}
/// Returns the schema of the index this segment belongs to.
fn schema(&self) -> &Schema {
&self.schema
}
/// Accessor to a segment's fast field reader given a field.
///
/// Returns the u64 fast value reader if the field
/// is a u64 field indexed as "fast".
///
/// Return a FastFieldNotAvailableError if the field is not
/// declared as a fast field in the schema.
///
/// # Panics
/// May panic if the index is corrupted.
fn fast_fields(&self) -> &FastFieldReaders {
&self.fast_fields_readers
}
#[doc(hidden)]
fn fieldnorms_readers(&self) -> &FieldNormReaders {
&self.fieldnorm_readers
}
/// Accessor to the segment's [`StoreReader`](crate::store::StoreReader).
///
/// `cache_num_blocks` sets the number of decompressed blocks to be cached in an LRU.
/// The size of blocks is configurable, this should be reflexted in the
fn get_store_reader(&self, cache_num_blocks: usize) -> io::Result<StoreReader> {
StoreReader::open(self.store_file.clone(), cache_num_blocks)
}
/// Returns a field reader associated with the field given in argument.
/// If the field was not present in the index during indexing time,
@@ -219,7 +265,7 @@ impl SegmentReader {
/// is returned.
/// Similarly, if the field is marked as indexed but no term has been indexed for the given
/// index, an empty `InvertedIndexReader` is returned (but no warning is logged).
pub fn inverted_index(&self, field: Field) -> crate::Result<Arc<InvertedIndexReader>> {
fn inverted_index(&self, field: Field) -> crate::Result<ArcInvertedIndexReader> {
if let Some(inv_idx_reader) = self
.inv_idx_reader_cache
.read()
@@ -244,7 +290,7 @@ impl SegmentReader {
//
// Returns an empty inverted index.
let record_option = record_option_opt.unwrap_or(IndexRecordOption::Basic);
return Ok(Arc::new(InvertedIndexReader::empty(record_option)));
return Ok(Arc::new(TantivyInvertedIndexReader::empty(record_option)));
}
let record_option = record_option_opt.unwrap();
@@ -268,7 +314,7 @@ impl SegmentReader {
DataCorruption::comment_only(error_msg)
})?;
let inv_idx_reader = Arc::new(InvertedIndexReader::new(
let inv_idx_reader: ArcInvertedIndexReader = Arc::new(TantivyInvertedIndexReader::new(
TermDictionary::open(termdict_file)?,
postings_file,
positions_file,
@@ -298,7 +344,7 @@ impl SegmentReader {
/// Disclaimer: Some fields may not be listed here. For instance, if the schema contains a json
/// field that is not indexed nor a fast field but is stored, it is possible for the field
/// to not be listed.
pub fn fields_metadata(&self) -> crate::Result<Vec<FieldMetadata>> {
fn fields_metadata(&self) -> crate::Result<Vec<FieldMetadata>> {
let mut indexed_fields: Vec<FieldMetadata> = Vec::new();
let mut map_to_canonical = FnvHashMap::default();
for (field, field_entry) in self.schema().fields() {
@@ -420,39 +466,22 @@ impl SegmentReader {
}
/// Returns the segment id
pub fn segment_id(&self) -> SegmentId {
fn segment_id(&self) -> SegmentId {
self.segment_id
}
/// Returns the delete opstamp
pub fn delete_opstamp(&self) -> Option<Opstamp> {
fn delete_opstamp(&self) -> Option<Opstamp> {
self.delete_opstamp
}
/// Returns the bitset representing the alive `DocId`s.
pub fn alive_bitset(&self) -> Option<&AliveBitSet> {
fn alive_bitset(&self) -> Option<&AliveBitSet> {
self.alive_bitset_opt.as_ref()
}
/// Returns true if the `doc` is marked
/// as deleted.
pub fn is_deleted(&self, doc: DocId) -> bool {
self.alive_bitset()
.map(|alive_bitset| alive_bitset.is_deleted(doc))
.unwrap_or(false)
}
/// Returns an iterator that will iterate over the alive document ids
pub fn doc_ids_alive(&self) -> Box<dyn Iterator<Item = DocId> + Send + '_> {
if let Some(alive_bitset) = &self.alive_bitset_opt {
Box::new(alive_bitset.iter_alive())
} else {
Box::new(0u32..self.max_doc)
}
}
/// Summarize total space usage of this segment.
pub fn space_usage(&self) -> io::Result<SegmentSpaceUsage> {
fn space_usage(&self) -> io::Result<SegmentSpaceUsage> {
Ok(SegmentSpaceUsage::new(
self.num_docs(),
self.termdict_composite.space_usage(self.schema()),
@@ -576,7 +605,7 @@ fn intersect_alive_bitset(
}
}
impl fmt::Debug for SegmentReader {
impl fmt::Debug for TantivySegmentReader {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "SegmentReader({:?})", self.segment_id)
}

View File

@@ -250,11 +250,15 @@ mod tests {
struct DummyWeight;
impl Weight for DummyWeight {
fn scorer(&self, _reader: &SegmentReader, _boost: Score) -> crate::Result<Box<dyn Scorer>> {
fn scorer(
&self,
_reader: &dyn SegmentReader,
_boost: Score,
) -> crate::Result<Box<dyn Scorer>> {
Err(crate::TantivyError::InternalError("dummy impl".to_owned()))
}
fn explain(&self, _reader: &SegmentReader, _doc: DocId) -> crate::Result<Explanation> {
fn explain(&self, _reader: &dyn SegmentReader, _doc: DocId) -> crate::Result<Explanation> {
Err(crate::TantivyError::InternalError("dummy impl".to_owned()))
}
}

View File

@@ -12,7 +12,9 @@ use super::{AddBatch, AddBatchReceiver, AddBatchSender, PreparedCommit};
use crate::directory::{DirectoryLock, GarbageCollectionResult, TerminatingWrite};
use crate::error::TantivyError;
use crate::fastfield::write_alive_bitset;
use crate::index::{Index, Segment, SegmentComponent, SegmentId, SegmentMeta, SegmentReader};
use crate::index::{
Index, Segment, SegmentComponent, SegmentId, SegmentMeta, SegmentReader, TantivySegmentReader,
};
use crate::indexer::delete_queue::{DeleteCursor, DeleteQueue};
use crate::indexer::doc_opstamp_mapping::DocToOpstampMapping;
use crate::indexer::index_writer_status::IndexWriterStatus;
@@ -94,7 +96,7 @@ pub struct IndexWriter<D: Document = TantivyDocument> {
fn compute_deleted_bitset(
alive_bitset: &mut BitSet,
segment_reader: &SegmentReader,
segment_reader: &dyn SegmentReader,
delete_cursor: &mut DeleteCursor,
doc_opstamps: &DocToOpstampMapping,
target_opstamp: Opstamp,
@@ -143,7 +145,7 @@ pub fn advance_deletes(
return Ok(());
}
let segment_reader = SegmentReader::open(&segment)?;
let segment_reader = TantivySegmentReader::open(&segment)?;
let max_doc = segment_reader.max_doc();
let mut alive_bitset: BitSet = match segment_entry.alive_bitset() {
@@ -218,7 +220,7 @@ fn index_documents<D: Document>(
let alive_bitset_opt = apply_deletes(&segment_with_max_doc, &mut delete_cursor, &doc_opstamps)?;
let meta = segment_with_max_doc.meta().clone();
meta.untrack_temp_docstore();
// update segment_updater inventory to remove tempstore
let segment_entry = SegmentEntry::new(meta, delete_cursor, alive_bitset_opt);
segment_updater.schedule_add_segment(segment_entry).wait()?;
@@ -243,7 +245,7 @@ fn apply_deletes(
.max()
.expect("Empty DocOpstamp is forbidden");
let segment_reader = SegmentReader::open(segment)?;
let segment_reader = TantivySegmentReader::open(segment)?;
let doc_to_opstamps = DocToOpstampMapping::WithMap(doc_opstamps);
let max_doc = segment.meta().max_doc();

View File

@@ -94,7 +94,7 @@ impl MergePolicy for LogMergePolicy {
fn compute_merge_candidates(&self, segments: &[SegmentMeta]) -> Vec<MergeCandidate> {
let size_sorted_segments = segments
.iter()
.filter(|seg| (seg.num_docs() as usize) <= self.max_docs_before_merge)
.filter(|seg| seg.num_docs() <= (self.max_docs_before_merge as u32))
.sorted_by_key(|seg| std::cmp::Reverse(seg.max_doc()))
.collect::<Vec<&SegmentMeta>>();
@@ -372,21 +372,4 @@ mod tests {
assert_eq!(merge_candidates[0].0.len(), 1);
assert_eq!(merge_candidates[0].0[0], test_input[1].id());
}
#[test]
fn test_max_docs_before_merge_large_value() {
// Regression test: (max_docs_before_merge as u32) truncates values > u32::MAX.
// Casting num_docs() to usize instead avoids the truncation.
let mut policy = LogMergePolicy::default();
policy.set_min_num_segments(2);
policy.set_max_docs_before_merge(5_000_000_000usize);
let test_input = vec![
create_random_segment_meta(100_000),
create_random_segment_meta(100_000),
];
let result = policy.compute_merge_candidates(&test_input);
// Both segments should be eligible (100_000 < 5_000_000_000)
assert_eq!(result.len(), 1);
assert_eq!(result[0].0.len(), 2);
}
}

View File

@@ -1,5 +1,3 @@
use std::sync::Arc;
use columnar::{
ColumnType, ColumnarReader, MergeRowOrder, RowAddr, ShuffleMergeOrder, StackMergeOrder,
};
@@ -12,14 +10,14 @@ use crate::docset::{DocSet, TERMINATED};
use crate::error::DataCorruption;
use crate::fastfield::AliveBitSet;
use crate::fieldnorm::{FieldNormReader, FieldNormReaders, FieldNormsSerializer, FieldNormsWriter};
use crate::index::{Segment, SegmentComponent, SegmentReader};
use crate::index::{Segment, SegmentComponent, SegmentReader, TantivySegmentReader};
use crate::indexer::doc_id_mapping::{MappingType, SegmentDocIdMapping};
use crate::indexer::SegmentSerializer;
use crate::postings::{InvertedIndexSerializer, Postings, SegmentPostings};
use crate::schema::{value_type_to_column_type, Field, FieldType, Schema};
use crate::store::StoreWriter;
use crate::termdict::{TermMerger, TermOrdinal};
use crate::{DocAddress, DocId, InvertedIndexReader};
use crate::{ArcInvertedIndexReader, DocAddress, DocId};
/// Segment's max doc must be `< MAX_DOC_LIMIT`.
///
@@ -27,7 +25,7 @@ use crate::{DocAddress, DocId, InvertedIndexReader};
pub const MAX_DOC_LIMIT: u32 = 1 << 31;
fn estimate_total_num_tokens_in_single_segment(
reader: &SegmentReader,
reader: &dyn SegmentReader,
field: Field,
) -> crate::Result<u64> {
// There are no deletes. We can simply use the exact value saved into the posting list.
@@ -68,7 +66,7 @@ fn estimate_total_num_tokens_in_single_segment(
Ok((segment_num_tokens as f64 * ratio) as u64)
}
fn estimate_total_num_tokens(readers: &[SegmentReader], field: Field) -> crate::Result<u64> {
fn estimate_total_num_tokens(readers: &[TantivySegmentReader], field: Field) -> crate::Result<u64> {
let mut total_num_tokens: u64 = 0;
for reader in readers {
total_num_tokens += estimate_total_num_tokens_in_single_segment(reader, field)?;
@@ -78,7 +76,7 @@ fn estimate_total_num_tokens(readers: &[SegmentReader], field: Field) -> crate::
pub struct IndexMerger {
schema: Schema,
pub(crate) readers: Vec<SegmentReader>,
pub(crate) readers: Vec<TantivySegmentReader>,
max_doc: u32,
}
@@ -170,8 +168,10 @@ impl IndexMerger {
let mut readers = vec![];
for (segment, new_alive_bitset_opt) in segments.iter().zip(alive_bitset_opt) {
if segment.meta().num_docs() > 0 {
let reader =
SegmentReader::open_with_custom_alive_set(segment, new_alive_bitset_opt)?;
let reader = TantivySegmentReader::open_with_custom_alive_set(
segment,
new_alive_bitset_opt,
)?;
readers.push(reader);
}
}
@@ -204,8 +204,20 @@ impl IndexMerger {
let fieldnorms_readers: Vec<FieldNormReader> = self
.readers
.iter()
.map(|reader| reader.get_fieldnorms_reader(field))
.collect::<Result<_, _>>()?;
.map(|reader| {
reader
.fieldnorms_readers()
.get_field(field)?
.ok_or_else(|| {
let field_name = self.schema.get_field_name(field);
let err_msg = format!(
"Field norm not found for field {field_name:?}. Was the field set \
to record norm during indexing?"
);
crate::TantivyError::SchemaError(err_msg)
})
})
.collect::<crate::Result<_>>()?;
for old_doc_addr in doc_id_mapping.iter_old_doc_addrs() {
let fieldnorms_reader = &fieldnorms_readers[old_doc_addr.segment_ord as usize];
let fieldnorm_id = fieldnorms_reader.fieldnorm_id(old_doc_addr.doc_id);
@@ -262,7 +274,7 @@ impl IndexMerger {
}),
);
let has_deletes: bool = self.readers.iter().any(SegmentReader::has_deletes);
let has_deletes: bool = self.readers.iter().any(|reader| reader.has_deletes());
let mapping_type = if has_deletes {
MappingType::StackedWithDeletes
} else {
@@ -297,7 +309,7 @@ impl IndexMerger {
let mut max_term_ords: Vec<TermOrdinal> = Vec::new();
let field_readers: Vec<Arc<InvertedIndexReader>> = self
let field_readers: Vec<ArcInvertedIndexReader> = self
.readers
.iter()
.map(|reader| reader.inverted_index(indexed_field))
@@ -366,7 +378,7 @@ impl IndexMerger {
// Let's compute the list of non-empty posting lists
for (segment_ord, term_info) in merged_terms.current_segment_ords_and_term_infos() {
let segment_reader = &self.readers[segment_ord];
let inverted_index: &InvertedIndexReader = &field_readers[segment_ord];
let inverted_index = field_readers[segment_ord].as_ref();
let segment_postings = inverted_index
.read_postings_from_terminfo(&term_info, segment_postings_option)?;
let alive_bitset_opt = segment_reader.alive_bitset();
@@ -1534,7 +1546,7 @@ mod tests {
for segment_reader in searcher.segment_readers() {
let mut term_scorer = term_query
.specialized_weight(EnableScoring::enabled_from_searcher(&searcher))?
.term_scorer_for_test(segment_reader, 1.0)?
.term_scorer_for_test(segment_reader.as_ref(), 1.0)?
.unwrap();
// the difference compared to before is intrinsic to the bm25 formula. no worries
// there.

View File

@@ -403,8 +403,7 @@ impl SegmentUpdater {
// from the different drives.
//
// Segment 1 from disk 1, Segment 1 from disk 2, etc.
committed_segment_metas
.sort_by_key(|segment_meta| std::cmp::Reverse(segment_meta.max_doc()));
committed_segment_metas.sort_by_key(|segment_meta| -(segment_meta.max_doc() as i32));
let index_meta = IndexMeta {
index_settings: index.settings().clone(),
segments: committed_segment_metas,
@@ -706,29 +705,12 @@ mod tests {
use crate::collector::TopDocs;
use crate::directory::RamDirectory;
use crate::fastfield::AliveBitSet;
use crate::index::{SegmentId, SegmentMetaInventory};
use crate::indexer::merge_policy::tests::MergeWheneverPossible;
use crate::indexer::merger::IndexMerger;
use crate::indexer::segment_updater::merge_filtered_segments;
use crate::query::QueryParser;
use crate::schema::*;
use crate::{Directory, DocAddress, Index, Segment};
#[test]
fn test_segment_sort_large_max_doc() {
// Regression test: -(max_doc as i32) overflows for max_doc >= 2^31.
// Using std::cmp::Reverse avoids this.
let inventory = SegmentMetaInventory::default();
let mut metas = vec![
inventory.new_segment_meta(SegmentId::generate_random(), 100),
inventory.new_segment_meta(SegmentId::generate_random(), (1u32 << 31) - 1),
inventory.new_segment_meta(SegmentId::generate_random(), 50_000),
];
metas.sort_by_key(|m| std::cmp::Reverse(m.max_doc()));
assert_eq!(metas[0].max_doc(), (1u32 << 31) - 1);
assert_eq!(metas[1].max_doc(), 50_000);
assert_eq!(metas[2].max_doc(), 100);
}
use crate::{Directory, DocAddress, Index, Segment, SegmentReader};
#[test]
fn test_delete_during_merge() -> crate::Result<()> {

View File

@@ -871,7 +871,7 @@ mod tests {
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0u32);
fn assert_type(reader: &SegmentReader, field: &str, typ: ColumnType) {
fn assert_type(reader: &dyn SegmentReader, field: &str, typ: ColumnType) {
let cols = reader.fast_fields().dynamic_column_handles(field).unwrap();
assert_eq!(cols.len(), 1, "{field}");
assert_eq!(cols[0].column_type(), typ, "{field}");
@@ -890,7 +890,7 @@ mod tests {
assert_type(segment_reader, "json.my_arr", ColumnType::I64);
assert_type(segment_reader, "json.my_arr.my_key", ColumnType::Str);
fn assert_empty(reader: &SegmentReader, field: &str) {
fn assert_empty(reader: &dyn SegmentReader, field: &str) {
let cols = reader.fast_fields().dynamic_column_handles(field).unwrap();
assert_eq!(cols.len(), 0);
}

View File

@@ -169,10 +169,8 @@ mod macros;
mod future_result;
// Re-exports
pub use columnar;
pub use common::{ByteCount, DateTime};
pub use query_grammar;
pub use time;
pub use {columnar, query_grammar, time};
pub use crate::error::TantivyError;
pub use crate::future_result::FutureResult;
@@ -226,8 +224,9 @@ pub use self::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN, TERMINATED};
pub use crate::core::{json_utils, Executor, Searcher, SearcherGeneration};
pub use crate::directory::Directory;
pub use crate::index::{
Index, IndexBuilder, IndexMeta, IndexSettings, InvertedIndexReader, Order, Segment,
SegmentMeta, SegmentReader,
ArcInvertedIndexReader, ArcSegmentReader, Index, IndexBuilder, IndexMeta, IndexSettings,
InvertedIndexReader, Order, Segment, SegmentMeta, SegmentReader, TantivyInvertedIndexReader,
TantivySegmentReader,
};
pub use crate::indexer::{IndexWriter, SingleSegmentIndexWriter};
pub use crate::schema::{Document, TantivyDocument, Term};
@@ -525,11 +524,11 @@ pub mod tests {
let searcher = index_reader.searcher();
let reader = searcher.segment_reader(0);
{
let fieldnorm_reader = reader.get_fieldnorms_reader(text_field)?;
let fieldnorm_reader = reader.fieldnorms_readers().get_field(text_field)?.unwrap();
assert_eq!(fieldnorm_reader.fieldnorm(0), 3);
}
{
let fieldnorm_reader = reader.get_fieldnorms_reader(title_field)?;
let fieldnorm_reader = reader.fieldnorms_readers().get_field(title_field)?.unwrap();
assert_eq!(fieldnorm_reader.fieldnorm_id(0), 0);
}
Ok(())
@@ -547,15 +546,18 @@ pub mod tests {
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
let segment_reader: &SegmentReader = searcher.segment_reader(0);
let fieldnorms_reader = segment_reader.get_fieldnorms_reader(text_field)?;
let segment_reader: &dyn SegmentReader = searcher.segment_reader(0);
let fieldnorms_reader = segment_reader
.fieldnorms_readers()
.get_field(text_field)?
.unwrap();
assert_eq!(fieldnorms_reader.fieldnorm(0), 3);
assert_eq!(fieldnorms_reader.fieldnorm(1), 0);
assert_eq!(fieldnorms_reader.fieldnorm(2), 2);
Ok(())
}
fn advance_undeleted(docset: &mut dyn DocSet, reader: &SegmentReader) -> bool {
fn advance_undeleted(docset: &mut dyn DocSet, reader: &dyn SegmentReader) -> bool {
let mut doc = docset.advance();
while doc != TERMINATED {
if !reader.is_deleted(doc) {
@@ -1072,7 +1074,7 @@ pub mod tests {
}
let reader = index.reader()?;
let searcher = reader.searcher();
let segment_reader: &SegmentReader = searcher.segment_reader(0);
let segment_reader: &dyn SegmentReader = searcher.segment_reader(0);
{
let fast_field_reader_res = segment_reader.fast_fields().u64("text");
assert!(fast_field_reader_res.is_err());

View File

@@ -182,32 +182,6 @@ impl BlockSegmentPostings {
self.freq_reading_option
}
// Resets the block segment postings on another position
// in the postings file.
//
// This is useful for enumerating through a list of terms,
// and consuming the associated posting lists while avoiding
// reallocating a `BlockSegmentPostings`.
//
// # Warning
//
// This does not reset the positions list.
pub(crate) fn reset(&mut self, doc_freq: u32, postings_data: OwnedBytes) -> io::Result<()> {
let (skip_data_opt, postings_data) =
split_into_skips_and_postings(doc_freq, postings_data)?;
self.data = postings_data;
self.block_max_score_cache = None;
self.block_loaded = false;
if let Some(skip_data) = skip_data_opt {
self.skip_reader.reset(skip_data, doc_freq);
} else {
self.skip_reader.reset(OwnedBytes::empty(), doc_freq);
}
self.doc_freq = doc_freq;
self.load_block();
Ok(())
}
/// Returns the overall number of documents in the block postings.
/// It does not take in account whether documents are deleted or not.
///
@@ -303,10 +277,10 @@ impl BlockSegmentPostings {
}
pub(crate) fn load_block(&mut self) {
let offset = self.skip_reader.byte_offset();
if self.block_is_loaded() {
return;
}
let offset = self.skip_reader.byte_offset();
match self.skip_reader.block_info() {
BlockInfo::BitPacked {
doc_num_bits,
@@ -521,40 +495,4 @@ mod tests {
assert_eq!(block_postings.doc(COMPRESSION_BLOCK_SIZE - 1), TERMINATED);
Ok(())
}
#[test]
fn test_reset_block_segment_postings() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let int_field = schema_builder.add_u64_field("id", INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests()?;
// create two postings list, one containing even number,
// the other containing odd numbers.
for i in 0..6 {
let doc = doc!(int_field=> (i % 2) as u64);
index_writer.add_document(doc)?;
}
index_writer.commit()?;
let searcher = index.reader()?.searcher();
let segment_reader = searcher.segment_reader(0);
let mut block_segments;
{
let term = Term::from_field_u64(int_field, 0u64);
let inverted_index = segment_reader.inverted_index(int_field)?;
let term_info = inverted_index.get_term_info(&term)?.unwrap();
block_segments = inverted_index
.read_block_postings_from_terminfo(&term_info, IndexRecordOption::Basic)?;
}
assert_eq!(block_segments.docs(), &[0, 2, 4]);
{
let term = Term::from_field_u64(int_field, 1u64);
let inverted_index = segment_reader.inverted_index(int_field)?;
let term_info = inverted_index.get_term_info(&term)?.unwrap();
inverted_index.reset_block_postings_from_terminfo(&term_info, &mut block_segments)?;
}
assert_eq!(block_segments.docs(), &[1, 3, 5]);
Ok(())
}
}

View File

@@ -14,8 +14,7 @@ mod postings;
mod postings_writer;
mod recorder;
mod segment_postings;
/// Serializer module for the inverted index
pub mod serializer;
mod serializer;
mod skip;
mod term_info;
@@ -47,7 +46,7 @@ pub(crate) mod tests {
use super::{InvertedIndexSerializer, Postings};
use crate::docset::{DocSet, TERMINATED};
use crate::fieldnorm::FieldNormReader;
use crate::index::{Index, SegmentComponent, SegmentReader};
use crate::index::{Index, SegmentComponent, SegmentReader, TantivySegmentReader};
use crate::indexer::operation::AddOperation;
use crate::indexer::SegmentWriter;
use crate::query::Scorer;
@@ -259,9 +258,12 @@ pub(crate) mod tests {
segment_writer.finalize()?;
}
{
let segment_reader = SegmentReader::open(&segment)?;
let segment_reader = TantivySegmentReader::open(&segment)?;
{
let fieldnorm_reader = segment_reader.get_fieldnorms_reader(text_field)?;
let fieldnorm_reader = segment_reader
.fieldnorms_readers()
.get_field(text_field)?
.unwrap();
assert_eq!(fieldnorm_reader.fieldnorm(0), 8 + 5);
assert_eq!(fieldnorm_reader.fieldnorm(1), 2);
for i in 2..1000 {

View File

@@ -168,20 +168,12 @@ impl DocSet for SegmentPostings {
self.doc()
}
#[inline]
fn seek(&mut self, target: DocId) -> DocId {
debug_assert!(self.doc() <= target);
if self.doc() >= target {
return self.doc();
}
// As an optimization, if the block is already loaded, we can
// cheaply check the next doc.
self.cur = (self.cur + 1).min(COMPRESSION_BLOCK_SIZE - 1);
if self.doc() >= target {
return self.doc();
}
// Delegate block-local search to BlockSegmentPostings::seek, which returns
// the in-block index of the first doc >= target.
self.cur = self.block_cursor.seek(target);

View File

@@ -11,7 +11,7 @@ use crate::positions::PositionSerializer;
use crate::postings::compression::{BlockEncoder, VIntEncoder, COMPRESSION_BLOCK_SIZE};
use crate::postings::skip::SkipSerializer;
use crate::query::Bm25Weight;
use crate::schema::{Field, FieldEntry, IndexRecordOption, Schema};
use crate::schema::{Field, FieldEntry, FieldType, IndexRecordOption, Schema};
use crate::termdict::TermDictionaryBuilder;
use crate::{DocId, Score};
@@ -80,12 +80,9 @@ impl InvertedIndexSerializer {
let term_dictionary_write = self.terms_write.for_field(field);
let postings_write = self.postings_write.for_field(field);
let positions_write = self.positions_write.for_field(field);
let index_record_option = field_entry
.field_type()
.index_record_option()
.unwrap_or(IndexRecordOption::Basic);
let field_type: FieldType = (*field_entry.field_type()).clone();
FieldSerializer::create(
index_record_option,
&field_type,
total_num_tokens,
term_dictionary_write,
postings_write,
@@ -105,27 +102,29 @@ impl InvertedIndexSerializer {
/// The field serializer is in charge of
/// the serialization of a specific field.
pub struct FieldSerializer<'a, W: Write = WritePtr> {
term_dictionary_builder: TermDictionaryBuilder<&'a mut CountingWriter<W>>,
pub struct FieldSerializer<'a> {
term_dictionary_builder: TermDictionaryBuilder<&'a mut CountingWriter<WritePtr>>,
postings_serializer: PostingsSerializer,
positions_serializer_opt: Option<PositionSerializer<&'a mut CountingWriter<W>>>,
positions_serializer_opt: Option<PositionSerializer<&'a mut CountingWriter<WritePtr>>>,
current_term_info: TermInfo,
term_open: bool,
postings_write: &'a mut CountingWriter<W>,
postings_write: &'a mut CountingWriter<WritePtr>,
postings_start_offset: u64,
}
impl<'a, W: Write> FieldSerializer<'a, W> {
/// Creates a new `FieldSerializer` for the given field type.
pub fn create(
index_record_option: IndexRecordOption,
impl<'a> FieldSerializer<'a> {
fn create(
field_type: &FieldType,
total_num_tokens: u64,
term_dictionary_write: &'a mut CountingWriter<W>,
postings_write: &'a mut CountingWriter<W>,
positions_write: &'a mut CountingWriter<W>,
term_dictionary_write: &'a mut CountingWriter<WritePtr>,
postings_write: &'a mut CountingWriter<WritePtr>,
positions_write: &'a mut CountingWriter<WritePtr>,
fieldnorm_reader: Option<FieldNormReader>,
) -> io::Result<FieldSerializer<'a, W>> {
) -> io::Result<FieldSerializer<'a>> {
total_num_tokens.serialize(postings_write)?;
let index_record_option = field_type
.index_record_option()
.unwrap_or(IndexRecordOption::Basic);
let term_dictionary_builder = TermDictionaryBuilder::create(term_dictionary_write)?;
let average_fieldnorm = fieldnorm_reader
.as_ref()
@@ -193,11 +192,6 @@ impl<'a, W: Write> FieldSerializer<'a, W> {
Ok(())
}
/// Starts the postings for a new term without recording term frequencies.
pub fn new_term_without_freq(&mut self, term: &[u8]) -> io::Result<()> {
self.new_term(term, 0, false)
}
/// Serialize the information that a document contains for the current term:
/// its term frequency, and the position deltas.
///
@@ -303,7 +297,6 @@ impl Block {
}
}
/// Serializer for postings lists.
pub struct PostingsSerializer {
last_doc_id_encoded: u32,
@@ -323,9 +316,6 @@ pub struct PostingsSerializer {
}
impl PostingsSerializer {
/// Creates a new `PostingsSerializer`.
/// * avg_fieldnorm - average field norm for the field being serialized.
/// * mode - indexing options for the field being serialized.
pub fn new(
avg_fieldnorm: Score,
mode: IndexRecordOption,
@@ -348,8 +338,6 @@ impl PostingsSerializer {
}
}
/// Starts the serialization for a new term.
/// * term_doc_freq - the number of documents containing the term.
pub fn new_term(&mut self, term_doc_freq: u32, record_term_freq: bool) {
self.bm25_weight = None;
@@ -389,7 +377,6 @@ impl PostingsSerializer {
self.postings_write.extend(block_encoded);
}
if self.term_has_freq {
// encode the term frequencies
let (num_bits, block_encoded): (u8, &[u8]) = self
.block_encoder
.compress_block_unsorted(self.block.term_freqs(), true);
@@ -430,9 +417,6 @@ impl PostingsSerializer {
self.block.clear();
}
/// Register that the given document contains the current term.
/// * doc_id - the document id.
/// * term_freq - the term frequency within the document.
pub fn write_doc(&mut self, doc_id: DocId, term_freq: u32) {
self.block.append_doc(doc_id, term_freq);
if self.block.is_full() {
@@ -440,7 +424,6 @@ impl PostingsSerializer {
}
}
/// Finish the serialization for this term.
pub fn close_term(
&mut self,
doc_freq: u32,

View File

@@ -14,11 +14,7 @@ use crate::{DocId, Score, TERMINATED};
// (requiring a 6th bit), but the biggest doc_id we can want to encode is TERMINATED-1, which can
// be represented on 31b without delta encoding.
fn encode_bitwidth(bitwidth: u8, delta_1: bool) -> u8 {
assert!(
bitwidth < 32,
"bitwidth needs to be less than 32, but got {}",
bitwidth
);
assert!(bitwidth < 32);
bitwidth | ((delta_1 as u8) << 6)
}
@@ -146,23 +142,6 @@ impl SkipReader {
skip_reader
}
pub fn reset(&mut self, data: OwnedBytes, doc_freq: u32) {
self.last_doc_in_block = if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
0
} else {
TERMINATED
};
self.last_doc_in_previous_block = 0u32;
self.owned_read = data;
self.block_info = BlockInfo::VInt { num_docs: doc_freq };
self.byte_offset = 0;
self.remaining_docs = doc_freq;
self.position_offset = 0u64;
if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
self.read_block_info();
}
}
// Returns the block max score for this block if available.
//
// The block max score is available for all full bitpacked block,

View File

@@ -21,7 +21,7 @@ impl Query for AllQuery {
pub struct AllWeight;
impl Weight for AllWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
fn scorer(&self, reader: &dyn SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let all_scorer = AllScorer::new(reader.max_doc());
if boost != 1.0 {
Ok(Box::new(BoostScorer::new(all_scorer, boost)))
@@ -30,7 +30,7 @@ impl Weight for AllWeight {
}
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
fn explain(&self, reader: &dyn SegmentReader, doc: DocId) -> crate::Result<Explanation> {
if doc >= reader.max_doc() {
return Err(does_not_match(doc));
}

View File

@@ -67,7 +67,7 @@ where
}
/// Returns the term infos that match the automaton
pub fn get_match_term_infos(&self, reader: &SegmentReader) -> crate::Result<Vec<TermInfo>> {
pub fn get_match_term_infos(&self, reader: &dyn SegmentReader) -> crate::Result<Vec<TermInfo>> {
let inverted_index = reader.inverted_index(self.field)?;
let term_dict = inverted_index.terms();
let mut term_stream = self.automaton_stream(term_dict)?;
@@ -84,7 +84,7 @@ where
A: Automaton + Send + Sync + 'static,
A::State: Clone,
{
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
fn scorer(&self, reader: &dyn SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let max_doc = reader.max_doc();
let mut doc_bitset = BitSet::with_max_value(max_doc);
let inverted_index = reader.inverted_index(self.field)?;
@@ -110,7 +110,7 @@ where
Ok(Box::new(const_scorer))
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
fn explain(&self, reader: &dyn SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) == doc {
Ok(Explanation::new("AutomatonScorer", 1.0))

View File

@@ -205,7 +205,7 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
fn per_occur_scorers(
&self,
reader: &SegmentReader,
reader: &dyn SegmentReader,
boost: Score,
) -> crate::Result<HashMap<Occur, Vec<Box<dyn Scorer>>>> {
let mut per_occur_scorers: HashMap<Occur, Vec<Box<dyn Scorer>>> = HashMap::new();
@@ -221,7 +221,7 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
fn complex_scorer<TComplexScoreCombiner: ScoreCombiner>(
&self,
reader: &SegmentReader,
reader: &dyn SegmentReader,
boost: Score,
score_combiner_fn: impl Fn() -> TComplexScoreCombiner,
) -> crate::Result<SpecializedScorer> {
@@ -291,6 +291,18 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
}
};
let exclude_scorer_opt: Option<Box<dyn Scorer>> = if exclude_scorers.is_empty() {
None
} else {
let exclude_specialized_scorer: SpecializedScorer =
scorer_union(exclude_scorers, DoNothingCombiner::default, num_docs);
Some(into_box_scorer(
exclude_specialized_scorer,
DoNothingCombiner::default,
num_docs,
))
};
let include_scorer = match (should_scorers, must_scorers) {
(ShouldScorersCombinationMethod::Ignored, must_scorers) => {
// No SHOULD clauses (or they were absorbed into MUST).
@@ -368,23 +380,16 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
}
}
};
if exclude_scorers.is_empty() {
return Ok(include_scorer);
}
let include_scorer_boxed = into_box_scorer(include_scorer, &score_combiner_fn, num_docs);
let scorer: Box<dyn Scorer> = if exclude_scorers.len() == 1 {
let exclude_scorer = exclude_scorers.pop().unwrap();
match exclude_scorer.downcast::<TermScorer>() {
// Cast to TermScorer succeeded
Ok(exclude_scorer) => Box::new(Exclude::new(include_scorer_boxed, *exclude_scorer)),
// We get back the original Box<dyn Scorer>
Err(exclude_scorer) => Box::new(Exclude::new(include_scorer_boxed, exclude_scorer)),
}
if let Some(exclude_scorer) = exclude_scorer_opt {
let include_scorer_boxed =
into_box_scorer(include_scorer, &score_combiner_fn, num_docs);
Ok(SpecializedScorer::Other(Box::new(Exclude::new(
include_scorer_boxed,
exclude_scorer,
))))
} else {
Box::new(Exclude::new(include_scorer_boxed, exclude_scorers))
};
Ok(SpecializedScorer::Other(scorer))
Ok(include_scorer)
}
}
}
@@ -413,7 +418,7 @@ fn remove_and_count_all_and_empty_scorers(
}
impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombiner> {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
fn scorer(&self, reader: &dyn SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let num_docs = reader.num_docs();
if self.weights.is_empty() {
Ok(Box::new(EmptyScorer))
@@ -437,7 +442,7 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
}
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
fn explain(&self, reader: &dyn SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) != doc {
return Err(does_not_match(doc));
@@ -459,7 +464,7 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
fn for_each(
&self,
reader: &SegmentReader,
reader: &dyn SegmentReader,
callback: &mut dyn FnMut(DocId, Score),
) -> crate::Result<()> {
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
@@ -481,7 +486,7 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
fn for_each_no_score(
&self,
reader: &SegmentReader,
reader: &dyn SegmentReader,
callback: &mut dyn FnMut(&[DocId]),
) -> crate::Result<()> {
let scorer = self.complex_scorer(reader, 1.0, || DoNothingCombiner)?;
@@ -516,7 +521,7 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
fn for_each_pruning(
&self,
threshold: Score,
reader: &SegmentReader,
reader: &dyn SegmentReader,
callback: &mut dyn FnMut(DocId, Score) -> Score,
) -> crate::Result<()> {
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;

View File

@@ -67,11 +67,11 @@ impl BoostWeight {
}
impl Weight for BoostWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
fn scorer(&self, reader: &dyn SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
self.weight.scorer(reader, boost * self.boost)
}
fn explain(&self, reader: &SegmentReader, doc: u32) -> crate::Result<Explanation> {
fn explain(&self, reader: &dyn SegmentReader, doc: u32) -> crate::Result<Explanation> {
let underlying_explanation = self.weight.explain(reader, doc)?;
let score = underlying_explanation.value() * self.boost;
let mut explanation =
@@ -80,7 +80,7 @@ impl Weight for BoostWeight {
Ok(explanation)
}
fn count(&self, reader: &SegmentReader) -> crate::Result<u32> {
fn count(&self, reader: &dyn SegmentReader) -> crate::Result<u32> {
self.weight.count(reader)
}
}

View File

@@ -63,12 +63,12 @@ impl ConstWeight {
}
impl Weight for ConstWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
fn scorer(&self, reader: &dyn SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let inner_scorer = self.weight.scorer(reader, boost)?;
Ok(Box::new(ConstScorer::new(inner_scorer, boost * self.score)))
}
fn explain(&self, reader: &SegmentReader, doc: u32) -> crate::Result<Explanation> {
fn explain(&self, reader: &dyn SegmentReader, doc: u32) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) != doc {
return Err(TantivyError::InvalidArgument(format!(
@@ -81,7 +81,7 @@ impl Weight for ConstWeight {
Ok(explanation)
}
fn count(&self, reader: &SegmentReader) -> crate::Result<u32> {
fn count(&self, reader: &dyn SegmentReader) -> crate::Result<u32> {
self.weight.count(reader)
}
}

View File

@@ -26,11 +26,11 @@ impl Query for EmptyQuery {
/// It is useful for tests and handling edge cases.
pub struct EmptyWeight;
impl Weight for EmptyWeight {
fn scorer(&self, _reader: &SegmentReader, _boost: Score) -> crate::Result<Box<dyn Scorer>> {
fn scorer(&self, _reader: &dyn SegmentReader, _boost: Score) -> crate::Result<Box<dyn Scorer>> {
Ok(Box::new(EmptyScorer))
}
fn explain(&self, _reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
fn explain(&self, _reader: &dyn SegmentReader, doc: DocId) -> crate::Result<Explanation> {
Err(does_not_match(doc))
}
}

View File

@@ -1,71 +1,48 @@
use crate::docset::{DocSet, SeekDangerResult, TERMINATED};
use crate::docset::{DocSet, TERMINATED};
use crate::query::Scorer;
use crate::{DocId, Score};
/// An exclusion set is a set of documents
/// that should be excluded from a given DocSet.
#[inline]
fn is_within<TDocSetExclude: DocSet>(docset: &mut TDocSetExclude, doc: DocId) -> bool {
docset.doc() <= doc && docset.seek(doc) == doc
}
/// Filters a given `DocSet` by removing the docs from a given `DocSet`.
///
/// It can be a single DocSet, or a Vec of DocSets.
pub trait ExclusionSet: Send {
/// Returns `true` if the given `doc` is in the exclusion set.
fn contains(&mut self, doc: DocId) -> bool;
}
impl<TDocSet: DocSet> ExclusionSet for TDocSet {
#[inline]
fn contains(&mut self, doc: DocId) -> bool {
self.seek_danger(doc) == SeekDangerResult::Found
}
}
impl<TDocSet: DocSet> ExclusionSet for Vec<TDocSet> {
#[inline]
fn contains(&mut self, doc: DocId) -> bool {
for docset in self.iter_mut() {
if docset.seek_danger(doc) == SeekDangerResult::Found {
return true;
}
}
false
}
}
/// Filters a given `DocSet` by removing the docs from an exclusion set.
///
/// The excluding docsets have no impact on scoring.
pub struct Exclude<TDocSet, TExclusionSet> {
/// The excluding docset has no impact on scoring.
pub struct Exclude<TDocSet, TDocSetExclude> {
underlying_docset: TDocSet,
exclusion_set: TExclusionSet,
excluding_docset: TDocSetExclude,
}
impl<TDocSet, TExclusionSet> Exclude<TDocSet, TExclusionSet>
impl<TDocSet, TDocSetExclude> Exclude<TDocSet, TDocSetExclude>
where
TDocSet: DocSet,
TExclusionSet: ExclusionSet,
TDocSetExclude: DocSet,
{
/// Creates a new `ExcludeScorer`
pub fn new(
mut underlying_docset: TDocSet,
mut exclusion_set: TExclusionSet,
) -> Exclude<TDocSet, TExclusionSet> {
mut excluding_docset: TDocSetExclude,
) -> Exclude<TDocSet, TDocSetExclude> {
while underlying_docset.doc() != TERMINATED {
let target = underlying_docset.doc();
if !exclusion_set.contains(target) {
if !is_within(&mut excluding_docset, target) {
break;
}
underlying_docset.advance();
}
Exclude {
underlying_docset,
exclusion_set,
excluding_docset,
}
}
}
impl<TDocSet, TExclusionSet> DocSet for Exclude<TDocSet, TExclusionSet>
impl<TDocSet, TDocSetExclude> DocSet for Exclude<TDocSet, TDocSetExclude>
where
TDocSet: DocSet,
TExclusionSet: ExclusionSet,
TDocSetExclude: DocSet,
{
fn advance(&mut self) -> DocId {
loop {
@@ -73,7 +50,7 @@ where
if candidate == TERMINATED {
return TERMINATED;
}
if !self.exclusion_set.contains(candidate) {
if !is_within(&mut self.excluding_docset, candidate) {
return candidate;
}
}
@@ -84,7 +61,7 @@ where
if candidate == TERMINATED {
return TERMINATED;
}
if !self.exclusion_set.contains(candidate) {
if !is_within(&mut self.excluding_docset, candidate) {
return candidate;
}
self.advance()
@@ -102,10 +79,10 @@ where
}
}
impl<TScorer, TExclusionSet> Scorer for Exclude<TScorer, TExclusionSet>
impl<TScorer, TDocSetExclude> Scorer for Exclude<TScorer, TDocSetExclude>
where
TScorer: Scorer,
TExclusionSet: ExclusionSet + 'static,
TDocSetExclude: DocSet + 'static,
{
#[inline]
fn score(&mut self) -> Score {

View File

@@ -98,7 +98,7 @@ pub struct ExistsWeight {
}
impl Weight for ExistsWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
fn scorer(&self, reader: &dyn SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let fast_field_reader = reader.fast_fields();
let mut column_handles = fast_field_reader.dynamic_column_handles(&self.field_name)?;
if self.field_type == Type::Json && self.json_subpaths {
@@ -165,7 +165,7 @@ impl Weight for ExistsWeight {
Ok(Box::new(ConstScorer::new(docset, boost)))
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
fn explain(&self, reader: &dyn SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) != doc {
return Err(does_not_match(doc));

View File

@@ -84,14 +84,6 @@ impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
docsets.sort_by_key(|docset| docset.cost());
go_to_first_doc(&mut docsets);
let left = docsets.remove(0);
debug_assert!({
let doc = left.doc();
if doc == TERMINATED {
true
} else {
docsets.iter().all(|docset| docset.doc() == doc)
}
});
let right = docsets.remove(0);
Intersection {
left,
@@ -120,24 +112,30 @@ impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOt
// Invariant:
// - candidate is always <= to the next document in the intersection.
// - candidate strictly increases at every occurence of the loop.
let mut candidate = left.doc() + 1;
let mut candidate = 0;
// Termination: candidate strictly increases.
'outer: while candidate < TERMINATED {
// As we enter the loop, we should always have candidate < next_doc.
candidate = left.seek(candidate);
// This step always increases candidate.
//
// TODO: Think about which value would make sense here
// It depends on the DocSet implementation, when a seek would outweigh an advance.
candidate = if candidate > left.doc().wrapping_add(100) {
left.seek(candidate)
} else {
left.advance()
};
// Left is positionned on `candidate`.
debug_assert_eq!(left.doc(), candidate);
if let SeekDangerResult::SeekLowerBound(seek_lower_bound) = right.seek_danger(candidate)
{
debug_assert!(
seek_lower_bound == TERMINATED || seek_lower_bound > candidate,
"seek_lower_bound {seek_lower_bound} must be greater than candidate \
{candidate}"
);
// The max is technically useless but it makes the invariant
// easier to proofread.
debug_assert!(seek_lower_bound >= candidate);
candidate = seek_lower_bound;
continue;
}
@@ -150,11 +148,7 @@ impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOt
other.seek_danger(candidate)
{
// One of the scorer does not match, let's restart at the top of the loop.
debug_assert!(
seek_lower_bound == TERMINATED || seek_lower_bound > candidate,
"seek_lower_bound {seek_lower_bound} must be greater than candidate \
{candidate}"
);
debug_assert!(seek_lower_bound >= candidate);
candidate = seek_lower_bound;
continue 'outer;
}
@@ -244,12 +238,9 @@ mod tests {
use proptest::prelude::*;
use super::Intersection;
use crate::collector::Count;
use crate::docset::{DocSet, TERMINATED};
use crate::postings::tests::test_skip_against_unoptimized;
use crate::query::{QueryParser, VecDocSet};
use crate::schema::{Schema, TEXT};
use crate::Index;
use crate::query::VecDocSet;
#[test]
fn test_intersection() {
@@ -420,29 +411,4 @@ mod tests {
assert_eq!(intersection.doc(), TERMINATED);
}
}
#[test]
fn test_bug_2811_intersection_candidate_should_increase() {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer
.add_document(doc!(text_field=>"hello happy tax"))
.unwrap();
writer.add_document(doc!(text_field=>"hello")).unwrap();
writer.add_document(doc!(text_field=>"hello")).unwrap();
writer.add_document(doc!(text_field=>"happy tax")).unwrap();
writer.commit().unwrap();
let query_parser = QueryParser::for_index(&index, Vec::new());
let query = query_parser
.parse_query(r#"+text:hello +text:"happy tax""#)
.unwrap();
let searcher = index.reader().unwrap().searcher();
let c = searcher.search(&*query, &Count).unwrap();
assert_eq!(c, 1);
}
}

View File

@@ -43,7 +43,7 @@ pub use self::boost_query::{BoostQuery, BoostWeight};
pub use self::const_score_query::{ConstScoreQuery, ConstScorer};
pub use self::disjunction_max_query::DisjunctionMaxQuery;
pub use self::empty_query::{EmptyQuery, EmptyScorer, EmptyWeight};
pub use self::exclude::{Exclude, ExclusionSet};
pub use self::exclude::Exclude;
pub use self::exist_query::ExistsQuery;
pub use self::explanation::Explanation;
#[cfg(test)]

View File

@@ -32,7 +32,7 @@ impl PhrasePrefixWeight {
}
}
fn fieldnorm_reader(&self, reader: &SegmentReader) -> crate::Result<FieldNormReader> {
fn fieldnorm_reader(&self, reader: &dyn SegmentReader) -> crate::Result<FieldNormReader> {
let field = self.phrase_terms[0].1.field();
if self.similarity_weight_opt.is_some() {
if let Some(fieldnorm_reader) = reader.fieldnorms_readers().get_field(field)? {
@@ -44,7 +44,7 @@ impl PhrasePrefixWeight {
pub(crate) fn phrase_scorer(
&self,
reader: &SegmentReader,
reader: &dyn SegmentReader,
boost: Score,
) -> crate::Result<Option<PhrasePrefixScorer<SegmentPostings>>> {
let similarity_weight_opt = self
@@ -114,7 +114,7 @@ impl PhrasePrefixWeight {
}
impl Weight for PhrasePrefixWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
fn scorer(&self, reader: &dyn SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
if let Some(scorer) = self.phrase_scorer(reader, boost)? {
Ok(Box::new(scorer))
} else {
@@ -122,7 +122,7 @@ impl Weight for PhrasePrefixWeight {
}
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
fn explain(&self, reader: &dyn SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let scorer_opt = self.phrase_scorer(reader, 1.0)?;
if scorer_opt.is_none() {
return Err(does_not_match(doc));

View File

@@ -531,12 +531,7 @@ impl<TPostings: Postings> DocSet for PhraseScorer<TPostings> {
}
fn seek_danger(&mut self, target: DocId) -> SeekDangerResult {
debug_assert!(
target >= self.doc(),
"target ({}) should be greater than or equal to doc ({})",
target,
self.doc()
);
debug_assert!(target >= self.doc());
let seek_res = self.intersection_docset.seek_danger(target);
if seek_res != SeekDangerResult::Found {
return seek_res;

View File

@@ -29,7 +29,7 @@ impl PhraseWeight {
}
}
fn fieldnorm_reader(&self, reader: &SegmentReader) -> crate::Result<FieldNormReader> {
fn fieldnorm_reader(&self, reader: &dyn SegmentReader) -> crate::Result<FieldNormReader> {
let field = self.phrase_terms[0].1.field();
if self.similarity_weight_opt.is_some() {
if let Some(fieldnorm_reader) = reader.fieldnorms_readers().get_field(field)? {
@@ -41,7 +41,7 @@ impl PhraseWeight {
pub(crate) fn phrase_scorer(
&self,
reader: &SegmentReader,
reader: &dyn SegmentReader,
boost: Score,
) -> crate::Result<Option<PhraseScorer<SegmentPostings>>> {
let similarity_weight_opt = self
@@ -74,7 +74,7 @@ impl PhraseWeight {
}
impl Weight for PhraseWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
fn scorer(&self, reader: &dyn SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
if let Some(scorer) = self.phrase_scorer(reader, boost)? {
Ok(Box::new(scorer))
} else {
@@ -82,7 +82,7 @@ impl Weight for PhraseWeight {
}
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
fn explain(&self, reader: &dyn SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let scorer_opt = self.phrase_scorer(reader, 1.0)?;
if scorer_opt.is_none() {
return Err(does_not_match(doc));

View File

@@ -45,7 +45,7 @@ impl RegexPhraseWeight {
}
}
fn fieldnorm_reader(&self, reader: &SegmentReader) -> crate::Result<FieldNormReader> {
fn fieldnorm_reader(&self, reader: &dyn SegmentReader) -> crate::Result<FieldNormReader> {
if self.similarity_weight_opt.is_some() {
if let Some(fieldnorm_reader) = reader.fieldnorms_readers().get_field(self.field)? {
return Ok(fieldnorm_reader);
@@ -56,7 +56,7 @@ impl RegexPhraseWeight {
pub(crate) fn phrase_scorer(
&self,
reader: &SegmentReader,
reader: &dyn SegmentReader,
boost: Score,
) -> crate::Result<Option<PhraseScorer<UnionType>>> {
let similarity_weight_opt = self
@@ -84,7 +84,8 @@ impl RegexPhraseWeight {
"Phrase query exceeded max expansions {num_terms}"
)));
}
let union = Self::get_union_from_term_infos(&term_infos, reader, &inverted_index)?;
let union =
Self::get_union_from_term_infos(&term_infos, reader, inverted_index.as_ref())?;
posting_lists.push((offset, union));
}
@@ -99,7 +100,7 @@ impl RegexPhraseWeight {
/// Add all docs of the term to the docset
fn add_to_bitset(
inverted_index: &InvertedIndexReader,
inverted_index: &dyn InvertedIndexReader,
term_info: &TermInfo,
doc_bitset: &mut BitSet,
) -> crate::Result<()> {
@@ -174,8 +175,8 @@ impl RegexPhraseWeight {
/// Use Roaring Bitmaps for sparse terms. The full bitvec is main memory consumer currently.
pub(crate) fn get_union_from_term_infos(
term_infos: &[TermInfo],
reader: &SegmentReader,
inverted_index: &InvertedIndexReader,
reader: &dyn SegmentReader,
inverted_index: &dyn InvertedIndexReader,
) -> crate::Result<UnionType> {
let max_doc = reader.max_doc();
@@ -269,7 +270,7 @@ impl RegexPhraseWeight {
}
impl Weight for RegexPhraseWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
fn scorer(&self, reader: &dyn SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
if let Some(scorer) = self.phrase_scorer(reader, boost)? {
Ok(Box::new(scorer))
} else {
@@ -277,7 +278,7 @@ impl Weight for RegexPhraseWeight {
}
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
fn explain(&self, reader: &dyn SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let scorer_opt = self.phrase_scorer(reader, 1.0)?;
if scorer_opt.is_none() {
return Err(does_not_match(doc));

View File

@@ -146,7 +146,7 @@ pub trait Query: QueryClone + Send + Sync + downcast_rs::Downcast + fmt::Debug {
let weight = self.weight(EnableScoring::disabled_from_searcher(searcher))?;
let mut result = 0;
for reader in searcher.segment_readers() {
result += weight.count(reader)? as usize;
result += weight.count(reader.as_ref())? as usize;
}
Ok(result)
}

View File

@@ -2068,16 +2068,6 @@ mod test {
format!("Regex(Field(0), {:#?})", expected_regex).as_str(),
false,
);
let expected_regex2 = tantivy_fst::Regex::new(r".*a").unwrap();
test_parse_query_to_logical_ast_helper(
"title:(/.*b/ OR /.*a/)",
format!(
"(Regex(Field(0), {:#?}) Regex(Field(0), {:#?}))",
expected_regex, expected_regex2
)
.as_str(),
false,
);
// Invalid field
let err = parse_query_to_logical_ast("float:/.*b/", false).unwrap_err();

View File

@@ -19,8 +19,7 @@ pub(crate) fn is_type_valid_for_fastfield_range_query(typ: Type) -> bool {
| Type::Bool
| Type::Date
| Type::Json
| Type::IpAddr
| Type::Bytes => true,
Type::Facet => false,
| Type::IpAddr => true,
Type::Facet | Type::Bytes => false,
}
}

View File

@@ -212,7 +212,7 @@ impl InvertedIndexRangeWeight {
}
impl Weight for InvertedIndexRangeWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
fn scorer(&self, reader: &dyn SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let max_doc = reader.max_doc();
let mut doc_bitset = BitSet::with_max_value(max_doc);
@@ -245,7 +245,7 @@ impl Weight for InvertedIndexRangeWeight {
Ok(Box::new(ConstScorer::new(doc_bitset, boost)))
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
fn explain(&self, reader: &dyn SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) != doc {
return Err(does_not_match(doc));
@@ -686,7 +686,7 @@ mod tests {
.weight(EnableScoring::disabled_from_schema(&schema))
.unwrap();
let range_scorer = range_weight
.scorer(&searcher.segment_readers()[0], 1.0f32)
.scorer(searcher.segment_readers()[0].as_ref(), 1.0f32)
.unwrap();
range_scorer
};

View File

@@ -6,8 +6,8 @@ use std::net::Ipv6Addr;
use std::ops::{Bound, RangeInclusive};
use columnar::{
BytesColumn, Cardinality, Column, ColumnType, MonotonicallyMappableToU128,
MonotonicallyMappableToU64, NumericalType, StrColumn,
Cardinality, Column, ColumnType, MonotonicallyMappableToU128, MonotonicallyMappableToU64,
NumericalType, StrColumn,
};
use common::bounds::{BoundsRange, TransformBound};
@@ -52,7 +52,7 @@ impl FastFieldRangeWeight {
}
impl Weight for FastFieldRangeWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
fn scorer(&self, reader: &dyn SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
// Check if both bounds are Bound::Unbounded
if self.bounds.is_unbounded() {
return Ok(Box::new(AllScorer::new(reader.max_doc())));
@@ -163,25 +163,6 @@ impl Weight for FastFieldRangeWeight {
};
let dict = str_dict_column.dictionary();
let bounds = self.bounds.map_bound(get_value_bytes);
// Get term ids for terms
let (lower_bound, upper_bound) =
dict.term_bounds_to_ord(bounds.lower_bound, bounds.upper_bound)?;
let fast_field_reader = reader.fast_fields();
let Some((column, _col_type)) =
fast_field_reader.u64_lenient_for_type(None, &field_name)?
else {
return Ok(Box::new(EmptyScorer));
};
search_on_u64_ff(column, boost, BoundsRange::new(lower_bound, upper_bound))
} else if field_type.is_bytes() {
let Some(bytes_column): Option<BytesColumn> =
reader.fast_fields().bytes(&field_name)?
else {
return Ok(Box::new(EmptyScorer));
};
let dict = bytes_column.dictionary();
let bounds = self.bounds.map_bound(get_value_bytes);
// Get term ids for terms
let (lower_bound, upper_bound) =
@@ -238,7 +219,7 @@ impl Weight for FastFieldRangeWeight {
}
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
fn explain(&self, reader: &dyn SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) != doc {
return Err(TantivyError::InvalidArgument(format!(
@@ -255,7 +236,7 @@ impl Weight for FastFieldRangeWeight {
///
/// Convert into fast field value space and search.
fn search_on_json_numerical_field(
reader: &SegmentReader,
reader: &dyn SegmentReader,
field_name: &str,
typ: Type,
bounds: BoundsRange<ValueBytes<Vec<u8>>>,
@@ -1421,66 +1402,6 @@ mod tests {
Ok(())
}
#[test]
fn test_bytes_field_ff_range_query() -> crate::Result<()> {
use crate::schema::BytesOptions;
let mut schema_builder = Schema::builder();
let bytes_field = schema_builder
.add_bytes_field("data", BytesOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// Insert documents with lexicographically sortable byte values
// Using simple byte sequences that have clear ordering
let values: Vec<Vec<u8>> = vec![
vec![0x00, 0x10],
vec![0x00, 0x20],
vec![0x00, 0x30],
vec![0x01, 0x00],
vec![0x01, 0x10],
vec![0x02, 0x00],
];
for value in &values {
let mut doc = TantivyDocument::new();
doc.add_bytes(bytes_field, value);
index_writer.add_document(doc)?;
}
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
// Test: Range query [0x00, 0x20] to [0x01, 0x00] (inclusive)
// Should match: [0x00, 0x20], [0x00, 0x30], [0x01, 0x00]
let lower = Term::from_field_bytes(bytes_field, &[0x00, 0x20]);
let upper = Term::from_field_bytes(bytes_field, &[0x01, 0x00]);
let range_query = RangeQuery::new(Bound::Included(lower), Bound::Included(upper));
let count = searcher.search(&range_query, &Count)?;
assert_eq!(
count, 3,
"Expected 3 documents in range [0x00,0x20] to [0x01,0x00]"
);
// Test: Range query > [0x01, 0x00] (exclusive lower bound)
// Should match: [0x01, 0x10], [0x02, 0x00]
let lower = Term::from_field_bytes(bytes_field, &[0x01, 0x00]);
let range_query = RangeQuery::new(Bound::Excluded(lower), Bound::Unbounded);
let count = searcher.search(&range_query, &Count)?;
assert_eq!(count, 2, "Expected 2 documents > [0x01,0x00]");
// Test: Range query < [0x00, 0x30] (exclusive upper bound)
// Should match: [0x00, 0x10], [0x00, 0x20]
let upper = Term::from_field_bytes(bytes_field, &[0x00, 0x30]);
let range_query = RangeQuery::new(Bound::Unbounded, Bound::Excluded(upper));
let count = searcher.search(&range_query, &Count)?;
assert_eq!(count, 2, "Expected 2 documents < [0x00,0x30]");
Ok(())
}
}
#[cfg(test)]

View File

@@ -105,7 +105,6 @@ impl DocSet for TermScorer {
#[inline]
fn seek(&mut self, target: DocId) -> DocId {
debug_assert!(target >= self.doc());
self.postings.seek(target)
}
@@ -264,7 +263,9 @@ mod tests {
let mut block_max_scores_b = vec![];
let mut docs = vec![];
{
let mut term_scorer = term_weight.term_scorer_for_test(reader, 1.0)?.unwrap();
let mut term_scorer = term_weight
.term_scorer_for_test(reader.as_ref(), 1.0)?
.unwrap();
while term_scorer.doc() != TERMINATED {
let mut score = term_scorer.score();
docs.push(term_scorer.doc());
@@ -278,7 +279,9 @@ mod tests {
}
}
{
let mut term_scorer = term_weight.term_scorer_for_test(reader, 1.0)?.unwrap();
let mut term_scorer = term_weight
.term_scorer_for_test(reader.as_ref(), 1.0)?
.unwrap();
for d in docs {
term_scorer.seek_block(d);
block_max_scores_b.push(term_scorer.block_max_score());

View File

@@ -34,11 +34,11 @@ impl TermOrEmptyOrAllScorer {
}
impl Weight for TermWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
fn scorer(&self, reader: &dyn SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
Ok(self.specialized_scorer(reader, boost)?.into_boxed_scorer())
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
fn explain(&self, reader: &dyn SegmentReader, doc: DocId) -> crate::Result<Explanation> {
match self.specialized_scorer(reader, 1.0)? {
TermOrEmptyOrAllScorer::TermScorer(mut term_scorer) => {
if term_scorer.doc() > doc || term_scorer.seek(doc) != doc {
@@ -53,7 +53,7 @@ impl Weight for TermWeight {
}
}
fn count(&self, reader: &SegmentReader) -> crate::Result<u32> {
fn count(&self, reader: &dyn SegmentReader) -> crate::Result<u32> {
if let Some(alive_bitset) = reader.alive_bitset() {
Ok(self.scorer(reader, 1.0)?.count(alive_bitset))
} else {
@@ -68,7 +68,7 @@ impl Weight for TermWeight {
/// `DocSet` and push the scored documents to the collector.
fn for_each(
&self,
reader: &SegmentReader,
reader: &dyn SegmentReader,
callback: &mut dyn FnMut(DocId, Score),
) -> crate::Result<()> {
match self.specialized_scorer(reader, 1.0)? {
@@ -87,7 +87,7 @@ impl Weight for TermWeight {
/// `DocSet` and push the scored documents to the collector.
fn for_each_no_score(
&self,
reader: &SegmentReader,
reader: &dyn SegmentReader,
callback: &mut dyn FnMut(&[DocId]),
) -> crate::Result<()> {
match self.specialized_scorer(reader, 1.0)? {
@@ -118,7 +118,7 @@ impl Weight for TermWeight {
fn for_each_pruning(
&self,
threshold: Score,
reader: &SegmentReader,
reader: &dyn SegmentReader,
callback: &mut dyn FnMut(DocId, Score) -> Score,
) -> crate::Result<()> {
let specialized_scorer = self.specialized_scorer(reader, 1.0)?;
@@ -166,7 +166,7 @@ impl TermWeight {
#[cfg(test)]
pub(crate) fn term_scorer_for_test(
&self,
reader: &SegmentReader,
reader: &dyn SegmentReader,
boost: Score,
) -> crate::Result<Option<TermScorer>> {
let scorer = self.specialized_scorer(reader, boost)?;
@@ -178,7 +178,7 @@ impl TermWeight {
fn specialized_scorer(
&self,
reader: &SegmentReader,
reader: &dyn SegmentReader,
boost: Score,
) -> crate::Result<TermOrEmptyOrAllScorer> {
let field = self.term.field();
@@ -206,7 +206,10 @@ impl TermWeight {
)))
}
fn fieldnorm_reader(&self, segment_reader: &SegmentReader) -> crate::Result<FieldNormReader> {
fn fieldnorm_reader(
&self,
segment_reader: &dyn SegmentReader,
) -> crate::Result<FieldNormReader> {
if self.scoring_enabled {
if let Some(field_norm_reader) = segment_reader
.fieldnorms_readers()

View File

@@ -69,13 +69,13 @@ pub trait Weight: Send + Sync + 'static {
/// `boost` is a multiplier to apply to the score.
///
/// See [`Query`](crate::query::Query).
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>>;
fn scorer(&self, reader: &dyn SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>>;
/// Returns an [`Explanation`] for the given document.
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation>;
fn explain(&self, reader: &dyn SegmentReader, doc: DocId) -> crate::Result<Explanation>;
/// Returns the number documents within the given [`SegmentReader`].
fn count(&self, reader: &SegmentReader) -> crate::Result<u32> {
fn count(&self, reader: &dyn SegmentReader) -> crate::Result<u32> {
let mut scorer = self.scorer(reader, 1.0)?;
if let Some(alive_bitset) = reader.alive_bitset() {
Ok(scorer.count(alive_bitset))
@@ -88,7 +88,7 @@ pub trait Weight: Send + Sync + 'static {
/// `DocSet` and push the scored documents to the collector.
fn for_each(
&self,
reader: &SegmentReader,
reader: &dyn SegmentReader,
callback: &mut dyn FnMut(DocId, Score),
) -> crate::Result<()> {
let mut scorer = self.scorer(reader, 1.0)?;
@@ -100,7 +100,7 @@ pub trait Weight: Send + Sync + 'static {
/// `DocSet` and push the scored documents to the collector.
fn for_each_no_score(
&self,
reader: &SegmentReader,
reader: &dyn SegmentReader,
callback: &mut dyn FnMut(&[DocId]),
) -> crate::Result<()> {
let mut docset = self.scorer(reader, 1.0)?;
@@ -123,7 +123,7 @@ pub trait Weight: Send + Sync + 'static {
fn for_each_pruning(
&self,
threshold: Score,
reader: &SegmentReader,
reader: &dyn SegmentReader,
callback: &mut dyn FnMut(DocId, Score) -> Score,
) -> crate::Result<()> {
let mut scorer = self.scorer(reader, 1.0)?;

View File

@@ -10,7 +10,7 @@ use self::warming::WarmingState;
use crate::core::searcher::{SearcherGeneration, SearcherInner};
use crate::directory::{Directory, WatchCallback, WatchHandle, META_LOCK};
use crate::store::DOCSTORE_CACHE_CAPACITY;
use crate::{Index, Inventory, Searcher, SegmentReader, TrackedObject};
use crate::{ArcSegmentReader, Index, Inventory, Searcher, TantivySegmentReader, TrackedObject};
/// Defines when a new version of the index should be reloaded.
///
@@ -189,19 +189,22 @@ impl InnerIndexReader {
///
/// This function acquires a lock to prevent GC from removing files
/// as we are opening our index.
fn open_segment_readers(index: &Index) -> crate::Result<Vec<SegmentReader>> {
fn open_segment_readers(index: &Index) -> crate::Result<Vec<ArcSegmentReader>> {
// Prevents segment files from getting deleted while we are in the process of opening them
let _meta_lock = index.directory().acquire_lock(&META_LOCK)?;
let searchable_segments = index.searchable_segments()?;
let segment_readers = searchable_segments
.iter()
.map(SegmentReader::open)
.map(|segment| {
TantivySegmentReader::open(segment)
.map(|reader| Arc::new(reader) as ArcSegmentReader)
})
.collect::<crate::Result<_>>()?;
Ok(segment_readers)
}
fn track_segment_readers_in_inventory(
segment_readers: &[SegmentReader],
segment_readers: &[ArcSegmentReader],
searcher_generation_counter: &Arc<AtomicU64>,
searcher_generation_inventory: &Inventory<SearcherGeneration>,
) -> TrackedObject<SearcherGeneration> {

View File

@@ -210,8 +210,11 @@ mod tests {
index_writer.add_document(doc!(text=>"abc"))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
let err = searcher.segment_reader(0u32).get_fieldnorms_reader(text);
assert!(matches!(err, Err(crate::TantivyError::SchemaError(_))));
let fieldnorm_opt = searcher
.segment_reader(0u32)
.fieldnorms_readers()
.get_field(text)?;
assert!(fieldnorm_opt.is_none());
Ok(())
}
}

Some files were not shown because too many files have changed in this diff Show More