Compare commits

...

16 Commits

Author SHA1 Message Date
PSeitz
923f0508f2 seek_exact + cost based intersection (#2538)
* seek_exact + cost based intersection

Adds `seek_exact` and `cost` to `DocSet` for a more efficient intersection.
Unlike `seek`, `seek_exact` does not require the DocSet to advance to the next hit if the target does not exist.

`cost` makes it possible to account for the cost models of the different DocSet
types and is used to determine which DocSet drives the intersection.
E.g. fast field range queries may do a full scan, and phrase queries load positions to check whether we have a hit.
Both have a higher cost than their `size_hint` would suggest.
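
A minimal sketch of the idea, with a simplified stand-in for the `DocSet` trait (not tantivy's actual definitions): the clause with the lowest `cost` drives the intersection, and the more expensive clauses are only probed with `seek_exact` (renamed to `seek_into_the_danger_zone` later in this PR).

```rust
// Editor's sketch; tantivy's real DocSet trait and intersection differ.
const TERMINATED: u32 = u32::MAX;

trait DocSet {
    /// Current doc id, or TERMINATED if exhausted.
    fn doc(&self) -> u32;
    /// Advance to the next hit and return it.
    fn advance(&mut self) -> u32;
    /// Return true if `target` is a hit, without having to position
    /// the DocSet on the next hit when it is not.
    fn seek_exact(&mut self, target: u32) -> bool;
    /// Estimated work of driving this DocSet (e.g. a full-scan range
    /// query reports a much higher cost than its `size_hint`).
    fn cost(&self) -> u64;
}

fn intersect(mut clauses: Vec<Box<dyn DocSet>>, mut on_hit: impl FnMut(u32)) {
    assert!(!clauses.is_empty());
    // The cheapest clause drives; the others only get probed.
    clauses.sort_by_key(|clause| clause.cost());
    let (driver, probed) = clauses.split_at_mut(1);
    let driver = &mut driver[0];
    let mut candidate = driver.doc();
    while candidate != TERMINATED {
        if probed.iter_mut().all(|docset| docset.seek_exact(candidate)) {
            on_hit(candidate);
        }
        candidate = driver.advance();
    }
}
```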

Improves the `size_hint` estimation for intersection and union by basing the
estimate on a random distribution with a co-location factor.
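
A hedged sketch of what such an estimate can look like (the exact formula and co-location factor used in tantivy may differ): under an independent random distribution, two clauses with `n1` and `n2` hits over `N` documents intersect in about `n1 * n2 / N` documents; a co-location factor greater than 1 accounts for clauses being correlated in practice.

```rust
// Editor's sketch of the estimation idea, not tantivy's exact formula.
/// Expected number of docs matching both clauses if hits were spread
/// independently at random over `max_doc` docs, scaled by a co-location
/// factor (> 1.0) because query clauses tend to be correlated.
fn estimate_intersection(left: u32, right: u32, max_doc: u32, co_location: f64) -> u32 {
    if max_doc == 0 {
        return 0;
    }
    let expected = left as f64 * right as f64 / max_doc as f64 * co_location;
    // The intersection can never exceed the smaller clause.
    expected.min(left.min(right) as f64) as u32
}

/// Union estimate via inclusion-exclusion, reusing the intersection estimate.
fn estimate_union(left: u32, right: u32, max_doc: u32, co_location: f64) -> u32 {
    let inter = estimate_intersection(left, right, max_doc, co_location);
    (left as u64 + right as u64 - inter as u64).min(max_doc as u64) as u32
}
```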

Refactor range query benchmark.

Closes #2531

*Future Work*

- Implement `seek_exact` for BufferedUnionScorer and RangeDocSet (fast field range queries)
- Evaluate replacing `seek` with `seek_exact` to reduce code complexity

* Apply suggestions from code review

Co-authored-by: Paul Masurel <paul@quickwit.io>

* add API contract verification

* impl seek_exact on union

* rename seek_exact

* add mixed AND OR test, fix buffered_union

* Add a proptest of BooleanQuery. (#2690)

* fix build

* Increase the document count.

* fix merge conflict

* fix debug assert

* Fix compilation errors after rebase

- Remove duplicate proptest_boolean_query module
- Remove duplicate cost() method implementations
- Fix TopDocs API usage (add .order_by_score())
- Remove duplicate imports
- Remove unused variable assignments

---------

Co-authored-by: Paul Masurel <paul@quickwit.io>
Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
Co-authored-by: Stu Hood <stuhood@gmail.com>
2025-12-30 14:43:25 +01:00
ChangRui-Ryan
e0b62e00ac optimize RangeDocSet for non-overlapping query ranges (#2783) 2025-12-29 16:55:28 +01:00
Stu Hood
ce97beb86f Add support for natural-order-with-none-highest in TopDocs::order_by (#2780)
* Add `ComparatorEnum::NaturalNoneHigher`.

* Fix comments.
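
A standalone sketch of the ordering this adds (plain Rust, not tantivy's actual comparator types): values compare naturally, but a missing value sorts above every present value, so a descending top-N puts documents without a value first.

```rust
// Editor's sketch of the "natural, None highest" comparison semantics.
use std::cmp::Ordering;

fn natural_none_higher(lhs: &Option<u64>, rhs: &Option<u64>) -> Ordering {
    match (lhs, rhs) {
        (None, None) => Ordering::Equal,
        (None, Some(_)) => Ordering::Greater,
        (Some(_), None) => Ordering::Less,
        (Some(l), Some(r)) => l.cmp(r),
    }
}

fn main() {
    let mut keys = vec![Some(10), None, Some(20)];
    // Sorting by the reversed comparator mirrors a descending top-N:
    // greatest values first, missing values at the very top.
    keys.sort_by(|l, r| natural_none_higher(l, r).reverse());
    assert_eq!(keys, vec![None, Some(20), Some(10)]);
}
```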
2025-12-23 09:22:20 +01:00
Stu Hood
c0f21a45ae Use a strict comparison in TopNComputer (#2777)
* Remove `(Partial)Ord` from `ComparableDoc`, and unify comparison between `TopNComputer` and `Comparator`.

* Doc cleanups.

* Require Ord for `ComparableDoc`.

* Semantics are actually _ascending_ DocId order.

* Adjust docs again for ascending DocId order.

* minor change

---------

Co-authored-by: Paul Masurel <paul.masurel@datadoghq.com>
2025-12-18 12:13:23 +01:00
Moe
73657dff77 fix: fixed integer overflow in ExpUnrolledLinkedList for large datasets (#2735)
* Fixed the overflow issue.

* Fixed lint issues.

* Applied PR fixes.

* Fixed a lint issue.
2025-12-16 22:57:12 +01:00
Moe
e3c9be1f92 fix: boolean query incorrectly dropping documents when AllScorer is present (#2760)
* Fixed the range issue.

* Fixed the second all scorer issue

* Improved docs + tests

* Improved code.

* Fixed lint issues.

* Improved tests + logic based on PR comments.

* Fixed lint issues.

* Increase the document count.

* Improved the prop-tests

* Expand the index size, and remove unused parameter.

---------

Co-authored-by: Stu Hood <stuhood@gmail.com>
2025-12-16 22:52:02 +01:00
Ming
ba61ed6ef3 fix: vint buffer can overflow (#2778)
* fix vint overflow

* comment
2025-12-16 22:50:41 +01:00
trinity-1686a
d0e1600135 fix bug with minimum_should_match and AllScorer (#2774) 2025-12-14 10:10:45 +01:00
PSeitz-dd
e9020d17d4 fix coverage (#2769) 2025-12-11 11:35:58 +01:00
PSeitz-dd
5ba0031f7d move rand_distr to dev_dep (#2772) 2025-12-11 18:23:50 +08:00
Philippe Noël
22dde8f9ae chore: Make some delete-related functions public (#46) (#2766)
Co-authored-by: Ming <ming.ying.nyc@gmail.com>
2025-12-11 01:22:15 +01:00
Philippe Noël
14cc24614e Make DeleteMeta pub (#2765)
Co-authored-by: Ming Ying <ming.ying.nyc@gmail.com>
2025-12-11 00:11:03 +01:00
Philippe Noël
8a1079b2dc expose AddOperation and with_max_doc (#7) (#2762)
Co-authored-by: Ming <ming.ying.nyc@gmail.com>
2025-12-11 00:10:42 +01:00
Philippe Noël
794ff1ffc9 chore: Make Language hashable (#79) (#2763)
Co-authored-by: Ming <ming.ying.nyc@gmail.com>
2025-12-10 15:38:43 +01:00
PSeitz-dd
c6912ce89a Handle JSON fields and columnar in space_usage (#2761)
return field names in space_usage instead of `Field`
more detailed info for columns
2025-12-10 20:33:33 +08:00
PSeitz
618e3bd11b Term and IndexingTerm cleanup (#2750)
* refactor term

* add deprecated functions

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-12-05 09:48:40 +08:00
51 changed files with 2435 additions and 875 deletions

View File

@@ -15,11 +15,11 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Install Rust
run: rustup toolchain install nightly-2024-07-01 --profile minimal --component llvm-tools-preview
run: rustup toolchain install nightly-2025-12-01 --profile minimal --component llvm-tools-preview
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@cargo-llvm-cov
- name: Generate code coverage
run: cargo +nightly-2024-07-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
run: cargo +nightly-2025-12-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
continue-on-error: true

View File

@@ -75,12 +75,12 @@ typetag = "0.2.21"
winapi = "0.3.9"
[dev-dependencies]
binggan = "0.14.0"
binggan = "0.14.2"
rand = "0.8.5"
maplit = "1.0.2"
matches = "0.1.9"
pretty_assertions = "1.2.1"
proptest = "1.0.0"
proptest = "1.7.0"
test-log = "0.2.10"
futures = "0.3.21"
paste = "1.0.11"
@@ -173,6 +173,14 @@ harness = false
name = "exists_json"
harness = false
[[bench]]
name = "range_query"
harness = false
[[bench]]
name = "and_or_queries"
harness = false
[[bench]]
name = "range_queries"
harness = false

benches/range_queries.rs (new file, 365 lines)
View File

@@ -0,0 +1,365 @@
use std::ops::Bound;
use binggan::{black_box, BenchGroup, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::{Count, DocSetCollector, TopDocs};
use tantivy::query::RangeQuery;
use tantivy::schema::{Schema, FAST, INDEXED};
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher, Term};
#[derive(Clone)]
struct BenchIndex {
#[allow(dead_code)]
index: Index,
searcher: Searcher,
}
fn build_shared_indices(num_docs: usize, distribution: &str) -> BenchIndex {
// Schema with fast fields only
let mut schema_builder = Schema::builder();
let f_num_rand_fast = schema_builder.add_u64_field("num_rand_fast", INDEXED | FAST);
let f_num_asc_fast = schema_builder.add_u64_field("num_asc_fast", INDEXED | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
// Populate index with stable RNG for reproducibility.
let mut rng = StdRng::from_seed([7u8; 32]);
{
let mut writer = index.writer_with_num_threads(1, 4_000_000_000).unwrap();
match distribution {
"dense" => {
for doc_id in 0..num_docs {
let num_rand = rng.gen_range(0u64..1000u64);
let num_asc = (doc_id / 10000) as u64;
writer
.add_document(doc!(
f_num_rand_fast=>num_rand,
f_num_asc_fast=>num_asc,
))
.unwrap();
}
}
"sparse" => {
for doc_id in 0..num_docs {
let num_rand = rng.gen_range(0u64..10000000u64);
let num_asc = doc_id as u64;
writer
.add_document(doc!(
f_num_rand_fast=>num_rand,
f_num_asc_fast=>num_asc,
))
.unwrap();
}
}
_ => {
panic!("Unsupported distribution type");
}
}
writer.commit().unwrap();
}
// Prepare reader/searcher once.
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.unwrap();
let searcher = reader.searcher();
BenchIndex { index, searcher }
}
fn main() {
// Prepare corpora with varying scenarios
let scenarios = vec![
// Dense distribution - random values in small range (0-999)
(
"dense_values_search_low_value_range".to_string(),
10_000_000,
"dense",
0,
9,
),
(
"dense_values_search_high_value_range".to_string(),
10_000_000,
"dense",
990,
999,
),
(
"dense_values_search_out_of_range".to_string(),
10_000_000,
"dense",
1000,
1002,
),
(
"sparse_values_search_low_value_range".to_string(),
10_000_000,
"sparse",
0,
9,
),
(
"sparse_values_search_high_value_range".to_string(),
10_000_000,
"sparse",
9_999_990,
9_999_999,
),
(
"sparse_values_search_out_of_range".to_string(),
10_000_000,
"sparse",
10_000_000,
10_000_002,
),
];
let mut runner = BenchRunner::new();
for (scenario_id, n, num_rand_distribution, range_low, range_high) in scenarios {
// Build index for this scenario
let bench_index = build_shared_indices(n, num_rand_distribution);
// Create benchmark group
let mut group = runner.new_group();
// Now set the name (this moves scenario_id)
group.set_name(scenario_id);
// Define fast field types
let field_names = ["num_rand_fast", "num_asc_fast"];
// Generate range queries for fast fields
for &field_name in &field_names {
// Create the range query
let field = bench_index.searcher.schema().get_field(field_name).unwrap();
let lower_term = Term::from_field_u64(field, range_low);
let upper_term = Term::from_field_u64(field, range_high);
let query = RangeQuery::new(Bound::Included(lower_term), Bound::Included(upper_term));
run_benchmark_tasks(
&mut group,
&bench_index,
query,
field_name,
range_low,
range_high,
);
}
group.run();
}
}
/// Run all benchmark tasks for a given range query and field name
fn run_benchmark_tasks(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
field_name: &str,
range_low: u64,
range_high: u64,
) {
// Test count
add_bench_task_count(
bench_group,
bench_index,
query.clone(),
"count",
field_name,
range_low,
range_high,
);
// Test top 100 by the field (ascending order)
{
let collector_name = format!("top100_by_{}_asc", field_name);
let field_name_owned = field_name.to_string();
add_bench_task_top100_asc(
bench_group,
bench_index,
query.clone(),
&collector_name,
field_name,
range_low,
range_high,
field_name_owned,
);
}
// Test top 100 by the field (descending order)
{
let collector_name = format!("top100_by_{}_desc", field_name);
let field_name_owned = field_name.to_string();
add_bench_task_top100_desc(
bench_group,
bench_index,
query,
&collector_name,
field_name,
range_low,
range_high,
field_name_owned,
);
}
}
fn add_bench_task_count(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = CountSearchTask {
searcher: bench_index.searcher.clone(),
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_docset(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = DocSetSearchTask {
searcher: bench_index.searcher.clone(),
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_top100_asc(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
field_name_owned: String,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = Top100AscSearchTask {
searcher: bench_index.searcher.clone(),
query,
field_name: field_name_owned,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
fn add_bench_task_top100_desc(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query: RangeQuery,
collector_name: &str,
field_name: &str,
range_low: u64,
range_high: u64,
field_name_owned: String,
) {
let task_name = format!(
"range_{}_[{} TO {}]_{}",
field_name, range_low, range_high, collector_name
);
let search_task = Top100DescSearchTask {
searcher: bench_index.searcher.clone(),
query,
field_name: field_name_owned,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
struct CountSearchTask {
searcher: Searcher,
query: RangeQuery,
}
impl CountSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
self.searcher.search(&self.query, &Count).unwrap()
}
}
struct DocSetSearchTask {
searcher: Searcher,
query: RangeQuery,
}
impl DocSetSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
let result = self.searcher.search(&self.query, &DocSetCollector).unwrap();
result.len()
}
}
struct Top100AscSearchTask {
searcher: Searcher,
query: RangeQuery,
field_name: String,
}
impl Top100AscSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
let collector =
TopDocs::with_limit(100).order_by_fast_field::<u64>(&self.field_name, Order::Asc);
let result = self.searcher.search(&self.query, &collector).unwrap();
for (_score, doc_address) in &result {
let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap();
}
result.len()
}
}
struct Top100DescSearchTask {
searcher: Searcher,
query: RangeQuery,
field_name: String,
}
impl Top100DescSearchTask {
#[inline(never)]
pub fn run(&self) -> usize {
let collector =
TopDocs::with_limit(100).order_by_fast_field::<u64>(&self.field_name, Order::Desc);
let result = self.searcher.search(&self.query, &collector).unwrap();
for (_score, doc_address) in &result {
let _doc: tantivy::TantivyDocument = self.searcher.doc(*doc_address).unwrap();
}
result.len()
}
}

benches/range_query.rs (new file, 260 lines)
View File

@@ -0,0 +1,260 @@
use std::fmt::Display;
use std::net::Ipv6Addr;
use std::ops::RangeInclusive;
use binggan::plugins::PeakMemAllocPlugin;
use binggan::{black_box, BenchRunner, OutputValue, PeakMemAlloc, INSTRUMENTED_SYSTEM};
use columnar::MonotonicallyMappableToU128;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use tantivy::collector::{Count, TopDocs};
use tantivy::query::QueryParser;
use tantivy::schema::*;
use tantivy::{doc, Index};
#[global_allocator]
pub static GLOBAL: &PeakMemAlloc<std::alloc::System> = &INSTRUMENTED_SYSTEM;
fn main() {
bench_range_query();
}
fn bench_range_query() {
let index = get_index_0_to_100();
let mut runner = BenchRunner::new();
runner.add_plugin(PeakMemAllocPlugin::new(GLOBAL));
runner.set_name("range_query on u64");
let field_name_and_descr: Vec<_> = vec![
("id", "Single Valued Range Field"),
("ids", "Multi Valued Range Field"),
];
let range_num_hits = vec![
("90_percent", get_90_percent()),
("10_percent", get_10_percent()),
("1_percent", get_1_percent()),
];
test_range(&mut runner, &index, &field_name_and_descr, range_num_hits);
runner.set_name("range_query on ip");
let field_name_and_descr: Vec<_> = vec![
("ip", "Single Valued Range Field"),
("ips", "Multi Valued Range Field"),
];
let range_num_hits = vec![
("90_percent", get_90_percent_ip()),
("10_percent", get_10_percent_ip()),
("1_percent", get_1_percent_ip()),
];
test_range(&mut runner, &index, &field_name_and_descr, range_num_hits);
}
fn test_range<T: Display>(
runner: &mut BenchRunner,
index: &Index,
field_name_and_descr: &[(&str, &str)],
range_num_hits: Vec<(&str, RangeInclusive<T>)>,
) {
for (field, suffix) in field_name_and_descr {
let term_num_hits = vec![
("", ""),
("1_percent", "veryfew"),
("10_percent", "few"),
("90_percent", "most"),
];
let mut group = runner.new_group();
group.set_name(suffix);
// all intersect combinations
for (range_name, range) in &range_num_hits {
for (term_name, term) in &term_num_hits {
let index = &index;
let test_name = if term_name.is_empty() {
format!("id_range_hit_{}", range_name)
} else {
format!(
"id_range_hit_{}_intersect_with_term_{}",
range_name, term_name
)
};
group.register(test_name, move |_| {
let query = if term_name.is_empty() {
"".to_string()
} else {
format!("AND id_name:{}", term)
};
black_box(execute_query(field, range, &query, index));
});
}
}
group.run();
}
}
fn get_index_0_to_100() -> Index {
let mut rng = StdRng::from_seed([1u8; 32]);
let num_vals = 100_000;
let docs: Vec<_> = (0..num_vals)
.map(|_i| {
let id_name = if rng.gen_bool(0.01) {
"veryfew".to_string() // 1%
} else if rng.gen_bool(0.1) {
"few".to_string() // 9%
} else {
"most".to_string() // 90%
};
Doc {
id_name,
id: rng.gen_range(0..100),
// Multiply by 1000, so that we create most buckets in the compact space
// The benches depend on this range to select n-percent of elements with the
// methods below.
ip: Ipv6Addr::from_u128(rng.gen_range(0..100) * 1000),
}
})
.collect();
create_index_from_docs(&docs)
}
#[derive(Clone, Debug)]
pub struct Doc {
pub id_name: String,
pub id: u64,
pub ip: Ipv6Addr,
}
pub fn create_index_from_docs(docs: &[Doc]) -> Index {
let mut schema_builder = Schema::builder();
let id_u64_field = schema_builder.add_u64_field("id", INDEXED | STORED | FAST);
let ids_u64_field =
schema_builder.add_u64_field("ids", NumericOptions::default().set_fast().set_indexed());
let id_f64_field = schema_builder.add_f64_field("id_f64", INDEXED | STORED | FAST);
let ids_f64_field = schema_builder.add_f64_field(
"ids_f64",
NumericOptions::default().set_fast().set_indexed(),
);
let id_i64_field = schema_builder.add_i64_field("id_i64", INDEXED | STORED | FAST);
let ids_i64_field = schema_builder.add_i64_field(
"ids_i64",
NumericOptions::default().set_fast().set_indexed(),
);
let text_field = schema_builder.add_text_field("id_name", STRING | STORED);
let text_field2 = schema_builder.add_text_field("id_name_fast", STRING | STORED | FAST);
let ip_field = schema_builder.add_ip_addr_field("ip", FAST);
let ips_field = schema_builder.add_ip_addr_field("ips", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(1, 50_000_000).unwrap();
for doc in docs.iter() {
index_writer
.add_document(doc!(
ids_i64_field => doc.id as i64,
ids_i64_field => doc.id as i64,
ids_f64_field => doc.id as f64,
ids_f64_field => doc.id as f64,
ids_u64_field => doc.id,
ids_u64_field => doc.id,
id_u64_field => doc.id,
id_f64_field => doc.id as f64,
id_i64_field => doc.id as i64,
text_field => doc.id_name.to_string(),
text_field2 => doc.id_name.to_string(),
ips_field => doc.ip,
ips_field => doc.ip,
ip_field => doc.ip,
))
.unwrap();
}
index_writer.commit().unwrap();
}
index
}
fn get_90_percent() -> RangeInclusive<u64> {
0..=90
}
fn get_10_percent() -> RangeInclusive<u64> {
0..=10
}
fn get_1_percent() -> RangeInclusive<u64> {
10..=10
}
fn get_90_percent_ip() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(90 * 1000);
start..=end
}
fn get_10_percent_ip() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
fn get_1_percent_ip() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(10 * 1000);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
struct NumHits {
count: usize,
}
impl OutputValue for NumHits {
fn column_title() -> &'static str {
"NumHits"
}
fn format(&self) -> Option<String> {
Some(self.count.to_string())
}
}
fn execute_query<T: Display>(
field: &str,
id_range: &RangeInclusive<T>,
suffix: &str,
index: &Index,
) -> NumHits {
let gen_query_inclusive = |from: &T, to: &T| {
format!(
"{}:[{} TO {}] {}",
field,
&from.to_string(),
&to.to_string(),
suffix
)
};
let query = gen_query_inclusive(id_range.start(), id_range.end());
execute_query_(&query, index)
}
fn execute_query_(query: &str, index: &Index) -> NumHits {
let query_from_text = |text: &str| {
QueryParser::for_index(index, vec![])
.parse_query(text)
.unwrap()
};
let query = query_from_text(query);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let num_hits = searcher
.search(&query, &(TopDocs::with_limit(10).order_by_score(), Count))
.unwrap()
.1;
NumHits { count: num_hits }
}

View File

@@ -41,12 +41,6 @@ fn transform_range_before_linear_transformation(
if range.is_empty() {
return None;
}
if stats.min_value > *range.end() {
return None;
}
if stats.max_value < *range.start() {
return None;
}
let shifted_range =
range.start().saturating_sub(stats.min_value)..=range.end().saturating_sub(stats.min_value);
let start_before_gcd_multiplication: u64 = div_ceil(*shifted_range.start(), stats.gcd);

View File

@@ -3,7 +3,8 @@ use std::sync::Arc;
use std::{fmt, io};
use common::file_slice::FileSlice;
use common::{ByteCount, DateTime, HasLen, OwnedBytes};
use common::{ByteCount, DateTime, OwnedBytes};
use serde::{Deserialize, Serialize};
use crate::column::{BytesColumn, Column, StrColumn};
use crate::column_values::{StrictlyMonotonicFn, monotonic_map_column};
@@ -317,10 +318,89 @@ impl DynamicColumnHandle {
}
pub fn num_bytes(&self) -> ByteCount {
self.file_slice.len().into()
self.file_slice.num_bytes()
}
/// Legacy helper returning the column space usage.
pub fn column_and_dictionary_num_bytes(&self) -> io::Result<ColumnSpaceUsage> {
self.space_usage()
}
/// Return the space usage of the column, optionally broken down by dictionary and column
/// values.
///
/// For dictionary encoded columns (strings and bytes), this splits the total footprint into
/// the dictionary and the remaining column data (including index and values).
/// For all other column types, the dictionary size is `None` and the column size
/// equals the total bytes.
pub fn space_usage(&self) -> io::Result<ColumnSpaceUsage> {
let total_num_bytes = self.num_bytes();
let dynamic_column = self.open()?;
let dictionary_num_bytes = match &dynamic_column {
DynamicColumn::Bytes(bytes_column) => bytes_column.dictionary().num_bytes(),
DynamicColumn::Str(str_column) => str_column.dictionary().num_bytes(),
_ => {
return Ok(ColumnSpaceUsage::new(self.num_bytes(), None));
}
};
assert!(dictionary_num_bytes <= total_num_bytes);
let column_num_bytes =
ByteCount::from(total_num_bytes.get_bytes() - dictionary_num_bytes.get_bytes());
Ok(ColumnSpaceUsage::new(
column_num_bytes,
Some(dictionary_num_bytes),
))
}
pub fn column_type(&self) -> ColumnType {
self.column_type
}
}
/// Represents space usage of a column.
///
/// `column_num_bytes` tracks the column payload (index, values and footer).
/// For dictionary encoded columns, `dictionary_num_bytes` captures the dictionary footprint.
/// [`ColumnSpaceUsage::total_num_bytes`] returns the sum of both parts.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ColumnSpaceUsage {
column_num_bytes: ByteCount,
dictionary_num_bytes: Option<ByteCount>,
}
impl ColumnSpaceUsage {
pub(crate) fn new(
column_num_bytes: ByteCount,
dictionary_num_bytes: Option<ByteCount>,
) -> Self {
ColumnSpaceUsage {
column_num_bytes,
dictionary_num_bytes,
}
}
pub fn column_num_bytes(&self) -> ByteCount {
self.column_num_bytes
}
pub fn dictionary_num_bytes(&self) -> Option<ByteCount> {
self.dictionary_num_bytes
}
pub fn total_num_bytes(&self) -> ByteCount {
self.column_num_bytes + self.dictionary_num_bytes.unwrap_or_default()
}
/// Merge two space usage values by summing their components.
pub fn merge(&self, other: &ColumnSpaceUsage) -> ColumnSpaceUsage {
let dictionary_num_bytes = match (self.dictionary_num_bytes, other.dictionary_num_bytes) {
(Some(lhs), Some(rhs)) => Some(lhs + rhs),
(Some(val), None) | (None, Some(val)) => Some(val),
(None, None) => None,
};
ColumnSpaceUsage {
column_num_bytes: self.column_num_bytes + other.column_num_bytes,
dictionary_num_bytes,
}
}
}

View File

@@ -48,7 +48,7 @@ pub use columnar::{
use sstable::VoidSSTable;
pub use value::{NumericalType, NumericalValue};
pub use self::dynamic_column::{DynamicColumn, DynamicColumnHandle};
pub use self::dynamic_column::{ColumnSpaceUsage, DynamicColumn, DynamicColumnHandle};
pub type RowId = u32;
pub type DocId = u32;

View File

@@ -11,7 +11,26 @@ pub use sort_by_string::SortByString;
pub use sort_key_computer::{SegmentSortKeyComputer, SortKeyComputer};
#[cfg(test)]
mod tests {
pub(crate) mod tests {
// By spec, regardless of whether ascending or descending order was requested, in presence of a
// tie, we sort by ascending doc id/doc address.
pub(crate) fn sort_hits<TSortKey: Ord, D: Ord>(
hits: &mut [ComparableDoc<TSortKey, D>],
order: Order,
) {
if order.is_asc() {
hits.sort_by(|l, r| l.sort_key.cmp(&r.sort_key).then(l.doc.cmp(&r.doc)));
} else {
hits.sort_by(|l, r| {
l.sort_key
.cmp(&r.sort_key)
.reverse() // This is descending
.then(l.doc.cmp(&r.doc))
});
}
}
use std::collections::HashMap;
use std::ops::Range;
@@ -372,15 +391,10 @@ mod tests {
// Using the TopDocs collector should always be equivalent to sorting, skipping the
// offset, and then taking the limit.
let sorted_docs: Vec<_> = if order.is_desc() {
let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
let sorted_docs: Vec<_> = {
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
} else {
let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
sort_hits(&mut comparable_docs, order);
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
};
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();

View File

@@ -12,8 +12,13 @@ pub trait Comparator<T>: Send + Sync + std::fmt::Debug + Default {
fn compare(&self, lhs: &T, rhs: &T) -> Ordering;
}
/// With the natural comparator, the top k collector will return
/// the top documents in decreasing order.
/// Compare values naturally (e.g. 1 < 2).
///
/// When used with `TopDocs`, which reverses the order, this results in a
/// "Descending" sort (Greatest values first).
///
/// `None` (or Null for `OwnedValue`) values are considered to be smaller than any other value,
/// and will therefore appear last in a descending sort (e.g. `[Some(20), Some(10), None]`).
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct NaturalComparator;
@@ -24,13 +29,18 @@ impl<T: PartialOrd> Comparator<T> for NaturalComparator {
}
}
/// Sorts document in reverse order.
/// Compare values in reverse (e.g. 2 < 1).
///
/// If the sort key is None, it will considered as the lowest value, and will therefore appear
/// first.
/// When used with `TopDocs`, which reverses the order, this results in an
/// "Ascending" sort (Smallest values first).
///
/// `None` is considered smaller than `Some` in the underlying comparator, but because the
/// comparison is reversed, `None` is effectively treated as the lowest value in the resulting
/// Ascending sort (e.g. `[None, Some(10), Some(20)]`).
///
/// The ReverseComparator does not necessarily imply that the sort order is reversed compared
/// to the NaturalComparator. In presence of a tie, both version will retain the higher doc ids.
/// to the NaturalComparator. In presence of a tie on the sort key, documents will always be
/// sorted by ascending `DocId`/`DocAddress` in TopN results, regardless of the sort key's order.
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct ReverseComparator;
@@ -43,11 +53,15 @@ where NaturalComparator: Comparator<T>
}
}
/// Sorts document in reverse order, but considers None as having the lowest value.
/// Compare values in reverse, but treating `None` as lower than `Some`.
///
/// When used with `TopDocs`, which reverses the order, this results in an
/// "Ascending" sort (Smallest values first), but with `None` values appearing last
/// (e.g. `[Some(10), Some(20), None]`).
///
/// This is usually what is wanted when sorting by a field in an ascending order.
/// For instance, in a e-commerce website, if I sort by price ascending, I most likely want the
/// cheapest items first, and the items without a price at last.
/// For instance, in an e-commerce website, if sorting by price ascending,
/// the cheapest items would appear first, and items without a price would appear last.
#[derive(Debug, Copy, Clone, Default)]
pub struct ReverseNoneIsLowerComparator;
@@ -107,6 +121,70 @@ impl Comparator<String> for ReverseNoneIsLowerComparator {
}
}
/// Compare values naturally, but treating `None` as higher than `Some`.
///
/// When used with `TopDocs`, which reverses the order, this results in a
/// "Descending" sort (Greatest values first), but with `None` values appearing first
/// (e.g. `[None, Some(20), Some(10)]`).
#[derive(Debug, Copy, Clone, Default, Serialize, Deserialize)]
pub struct NaturalNoneIsHigherComparator;
impl<T> Comparator<Option<T>> for NaturalNoneIsHigherComparator
where NaturalComparator: Comparator<T>
{
#[inline(always)]
fn compare(&self, lhs_opt: &Option<T>, rhs_opt: &Option<T>) -> Ordering {
match (lhs_opt, rhs_opt) {
(None, None) => Ordering::Equal,
(None, Some(_)) => Ordering::Greater,
(Some(_), None) => Ordering::Less,
(Some(lhs), Some(rhs)) => NaturalComparator.compare(lhs, rhs),
}
}
}
impl Comparator<u32> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &u32, rhs: &u32) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
}
impl Comparator<u64> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &u64, rhs: &u64) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
}
impl Comparator<f64> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &f64, rhs: &f64) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
}
impl Comparator<f32> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &f32, rhs: &f32) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
}
impl Comparator<i64> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &i64, rhs: &i64) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
}
impl Comparator<String> for NaturalNoneIsHigherComparator {
#[inline(always)]
fn compare(&self, lhs: &String, rhs: &String) -> Ordering {
NaturalComparator.compare(lhs, rhs)
}
}
/// An enum representing the different sort orders.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
pub enum ComparatorEnum {
@@ -115,8 +193,10 @@ pub enum ComparatorEnum {
Natural,
/// Reverse order (See [ReverseComparator])
Reverse,
/// Reverse order by treating None as the lowest value.(See [ReverseNoneLowerComparator])
/// Reverse order by treating None as the lowest value. (See [ReverseNoneLowerComparator])
ReverseNoneLower,
/// Natural order but treating None as the highest value. (See [NaturalNoneIsHigherComparator])
NaturalNoneHigher,
}
impl From<Order> for ComparatorEnum {
@@ -133,6 +213,7 @@ where
ReverseNoneIsLowerComparator: Comparator<T>,
NaturalComparator: Comparator<T>,
ReverseComparator: Comparator<T>,
NaturalNoneIsHigherComparator: Comparator<T>,
{
#[inline(always)]
fn compare(&self, lhs: &T, rhs: &T) -> Ordering {
@@ -140,6 +221,7 @@ where
ComparatorEnum::Natural => NaturalComparator.compare(lhs, rhs),
ComparatorEnum::Reverse => ReverseComparator.compare(lhs, rhs),
ComparatorEnum::ReverseNoneLower => ReverseNoneIsLowerComparator.compare(lhs, rhs),
ComparatorEnum::NaturalNoneHigher => NaturalNoneIsHigherComparator.compare(lhs, rhs),
}
}
}
@@ -346,3 +428,31 @@ where
.convert_segment_sort_key(sort_key)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_natural_none_is_higher() {
let comp = NaturalNoneIsHigherComparator;
let null = None;
let v1 = Some(1_u64);
let v2 = Some(2_u64);
// NaturalNoneIsGreaterComparator logic:
// 1. Delegates to NaturalComparator for non-nulls.
// NaturalComparator compare(2, 1) -> 2.cmp(1) -> Greater.
assert_eq!(comp.compare(&v2, &v1), Ordering::Greater);
// 2. Treats None (Null) as Greater than any value.
// compare(None, Some(2)) should be Greater.
assert_eq!(comp.compare(&null, &v2), Ordering::Greater);
// compare(Some(1), None) should be Less.
assert_eq!(comp.compare(&v1, &null), Ordering::Less);
// compare(None, None) should be Equal.
assert_eq!(comp.compare(&null, &null), Ordering::Equal);
}
}

View File

@@ -1,64 +1,22 @@
use std::cmp::Ordering;
use serde::{Deserialize, Serialize};
/// Contains a feature (field, score, etc.) of a document along with the document address.
///
/// It guarantees stable sorting: in case of a tie on the feature, the document
/// address is used.
///
/// The REVERSE_ORDER generic parameter controls whether the by-feature order
/// should be reversed, which is useful for achieving for example largest-first
/// semantics without having to wrap the feature in a `Reverse`.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> {
/// Used only by TopNComputer, which implements the actual comparison via a `Comparator`.
#[derive(Clone, Default, Eq, PartialEq, Serialize, Deserialize)]
pub struct ComparableDoc<T, D> {
/// The feature of the document. In practice, this is
/// is any type that implements `PartialOrd`.
/// is a type which can be compared with a `Comparator<T>`.
pub sort_key: T,
/// The document address. In practice, this is any
/// type that implements `PartialOrd`, and is guaranteed
/// to be unique for each document.
/// The document address. In practice, this is either a `DocId` or `DocAddress`.
pub doc: D,
}
impl<T: std::fmt::Debug, D: std::fmt::Debug, const R: bool> std::fmt::Debug
for ComparableDoc<T, D, R>
{
impl<T: std::fmt::Debug, D: std::fmt::Debug> std::fmt::Debug for ComparableDoc<T, D> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct(format!("ComparableDoc<_, _ {R}").as_str())
f.debug_struct("ComparableDoc")
.field("feature", &self.sort_key)
.field("doc", &self.doc)
.finish()
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialOrd for ComparableDoc<T, D, R> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> Ord for ComparableDoc<T, D, R> {
#[inline]
fn cmp(&self, other: &Self) -> Ordering {
let by_feature = self
.sort_key
.partial_cmp(&other.sort_key)
.map(|ord| if R { ord.reverse() } else { ord })
.unwrap_or(Ordering::Equal);
let lazy_by_doc_address = || self.doc.partial_cmp(&other.doc).unwrap_or(Ordering::Equal);
// In case of a tie on the feature, we sort by ascending
// `DocAddress` in order to ensure a stable sorting of the
// documents.
by_feature.then_with(lazy_by_doc_address)
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialEq for ComparableDoc<T, D, R> {
fn eq(&self, other: &Self) -> bool {
self.cmp(other) == Ordering::Equal
}
}
impl<T: PartialOrd, D: PartialOrd, const R: bool> Eq for ComparableDoc<T, D, R> {}

View File

@@ -23,10 +23,9 @@ use crate::{DocAddress, DocId, Order, Score, SegmentReader};
/// The theoretical complexity for collecting the top `K` out of `N` documents
/// is `O(N + K)`.
///
/// This collector does not guarantee a stable sorting in case of a tie on the
/// document score, for stable sorting `PartialOrd` needs to resolve on other fields
/// like docid in case of score equality.
/// Only then, it is suitable for pagination.
/// This collector guarantees a stable sorting in case of a tie on the
/// document score/sort key: The document address (`DocAddress`) is used as a tie breaker.
/// In case of a tie on the sort key, documents are always sorted by ascending `DocAddress`.
///
/// ```rust
/// use tantivy::collector::TopDocs;
@@ -500,8 +499,13 @@ where
///
/// For TopN == 0, it will be relative expensive.
///
/// When using the natural comparator, the top N computer returns the top N elements in
/// descending order, as expected for a top N.
/// The TopNComputer will tiebreak by using ascending `D` (DocId or DocAddress):
/// i.e., in case of a tie on the sort key, the `DocId|DocAddress` are always sorted in
/// ascending order, regardless of the `Comparator` used for the `Score` type.
///
/// NOTE: Items must be `push`ed to the TopNComputer in ascending `DocId|DocAddress` order, as the
/// threshold used to eliminate docs does not include the `DocId` or `DocAddress`: this provides
/// the ascending `DocId|DocAddress` tie-breaking behavior without additional comparisons.
#[derive(Serialize, Deserialize)]
#[serde(from = "TopNComputerDeser<Score, D, C>")]
pub struct TopNComputer<Score, D, C> {
@@ -580,6 +584,18 @@ where
}
}
#[inline(always)]
fn compare_for_top_k<TSortKey, D: Ord, C: Comparator<TSortKey>>(
c: &C,
lhs: &ComparableDoc<TSortKey, D>,
rhs: &ComparableDoc<TSortKey, D>,
) -> std::cmp::Ordering {
c.compare(&lhs.sort_key, &rhs.sort_key)
.reverse() // Reverse here because we want top K.
.then_with(|| lhs.doc.cmp(&rhs.doc)) // Regardless of asc/desc, in presence of a tie, we
// sort by doc id
}
impl<TSortKey, D, C> TopNComputer<TSortKey, D, C>
where
D: Ord,
@@ -600,10 +616,13 @@ where
/// Push a new document to the top n.
/// If the document is below the current threshold, it will be ignored.
///
/// NOTE: `push` must be called in ascending `DocId`/`DocAddress` order.
#[inline]
pub fn push(&mut self, sort_key: TSortKey, doc: D) {
if let Some(last_median) = &self.threshold {
if self.comparator.compare(&sort_key, last_median) == Ordering::Less {
// See the struct docs for an explanation of why this comparison is strict.
if self.comparator.compare(&sort_key, last_median) != Ordering::Greater {
return;
}
}
@@ -629,9 +648,7 @@ where
fn truncate_top_n(&mut self) -> TSortKey {
// Use select_nth_unstable to find the top nth score
let (_, median_el, _) = self.buffer.select_nth_unstable_by(self.top_n, |lhs, rhs| {
self.comparator
.compare(&rhs.sort_key, &lhs.sort_key)
.then_with(|| lhs.doc.cmp(&rhs.doc))
compare_for_top_k(&self.comparator, lhs, rhs)
});
let median_score = median_el.sort_key.clone();
@@ -646,11 +663,8 @@ where
if self.buffer.len() > self.top_n {
self.truncate_top_n();
}
self.buffer.sort_unstable_by(|left, right| {
self.comparator
.compare(&right.sort_key, &left.sort_key)
.then_with(|| left.doc.cmp(&right.doc))
});
self.buffer
.sort_unstable_by(|lhs, rhs| compare_for_top_k(&self.comparator, lhs, rhs));
self.buffer
}
@@ -755,6 +769,33 @@ mod tests {
);
}
#[test]
fn test_topn_computer_duplicates() {
let mut computer: TopNComputer<u32, u32, NaturalComparator> =
TopNComputer::new_with_comparator(2, NaturalComparator);
computer.push(1u32, 1u32);
computer.push(1u32, 2u32);
computer.push(1u32, 3u32);
computer.push(1u32, 4u32);
computer.push(1u32, 5u32);
// In the presence of duplicates, DocIds are always ascending order.
assert_eq!(
computer.into_sorted_vec(),
&[
ComparableDoc {
sort_key: 1u32,
doc: 1u32,
},
ComparableDoc {
sort_key: 1u32,
doc: 2u32,
}
]
);
}
#[test]
fn test_topn_computer_no_panic() {
for top_n in 0..10 {
@@ -772,14 +813,17 @@ mod tests {
#[test]
fn test_topn_computer_asc_prop(
limit in 0..10_usize,
docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize),
mut docs in proptest::collection::vec((0..100_u64, 0..100_u64), 0..100_usize),
) {
// NB: TopNComputer must receive inputs in ascending DocId order.
docs.sort_by_key(|(_, doc_id)| *doc_id);
let mut computer: TopNComputer<_, _, ReverseComparator> = TopNComputer::new_with_comparator(limit, ReverseComparator);
for (feature, doc) in &docs {
computer.push(*feature, *doc);
}
let mut comparable_docs: Vec<ComparableDoc<u64, u64>> = docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect::<Vec<_>>();
comparable_docs.sort();
let mut comparable_docs: Vec<ComparableDoc<u64, u64>> =
docs.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc }).collect();
crate::collector::sort_key::tests::sort_hits(&mut comparable_docs, Order::Asc);
comparable_docs.truncate(limit);
prop_assert_eq!(
computer.into_sorted_vec(),
@@ -1406,15 +1450,10 @@ mod tests {
// Using the TopDocs collector should always be equivalent to sorting, skipping the
// offset, and then taking the limit.
let sorted_docs: Vec<_> = if order.is_desc() {
let mut comparable_docs: Vec<ComparableDoc<_, _, true>> =
let sorted_docs: Vec<_> = {
let mut comparable_docs: Vec<ComparableDoc<_, _>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
} else {
let mut comparable_docs: Vec<ComparableDoc<_, _, false>> =
all_results.into_iter().map(|(sort_key, doc)| ComparableDoc { sort_key, doc}).collect();
comparable_docs.sort();
crate::collector::sort_key::tests::sort_hits(&mut comparable_docs, order);
comparable_docs.into_iter().map(|cd| (cd.sort_key, cd.doc)).collect()
};
let expected_docs = sorted_docs.into_iter().skip(offset).take(limit).collect::<Vec<_>>();

View File

@@ -406,7 +406,7 @@ mod tests {
let mut term = Term::from_field_json_path(field, "color", false);
term.append_type_and_str("red");
assert_eq!(term.serialized_term(), b"\x00\x00\x00\x01jcolor\x00sred")
assert_eq!(term.serialized_value_bytes(), b"color\x00sred".to_vec())
}
#[test]
@@ -416,8 +416,8 @@ mod tests {
term.append_type_and_fast_value(-4i64);
assert_eq!(
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00i\x7f\xff\xff\xff\xff\xff\xff\xfc"
term.serialized_value_bytes(),
b"color\x00i\x7f\xff\xff\xff\xff\xff\xff\xfc".to_vec()
)
}
@@ -428,8 +428,8 @@ mod tests {
term.append_type_and_fast_value(4u64);
assert_eq!(
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00u\x00\x00\x00\x00\x00\x00\x00\x04"
term.serialized_value_bytes(),
b"color\x00u\x00\x00\x00\x00\x00\x00\x00\x04".to_vec()
)
}
@@ -439,8 +439,8 @@ mod tests {
let mut term = Term::from_field_json_path(field, "color", false);
term.append_type_and_fast_value(4.0f64);
assert_eq!(
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00f\xc0\x10\x00\x00\x00\x00\x00\x00"
term.serialized_value_bytes(),
b"color\x00f\xc0\x10\x00\x00\x00\x00\x00\x00".to_vec()
)
}
@@ -450,8 +450,8 @@ mod tests {
let mut term = Term::from_field_json_path(field, "color", false);
term.append_type_and_fast_value(true);
assert_eq!(
term.serialized_term(),
b"\x00\x00\x00\x01jcolor\x00o\x00\x00\x00\x00\x00\x00\x00\x01"
term.serialized_value_bytes(),
b"color\x00o\x00\x00\x00\x00\x00\x00\x00\x01".to_vec()
)
}

View File

@@ -5,7 +5,7 @@ use std::ops::Range;
use common::{BinarySerializable, CountingWriter, HasLen, VInt};
use crate::directory::{FileSlice, TerminatingWrite, WritePtr};
use crate::schema::Field;
use crate::schema::{Field, Schema};
use crate::space_usage::{FieldUsage, PerFieldSpaceUsage};
#[derive(Eq, PartialEq, Hash, Copy, Ord, PartialOrd, Clone, Debug)]
@@ -167,10 +167,11 @@ impl CompositeFile {
.map(|byte_range| self.data.slice(byte_range.clone()))
}
pub fn space_usage(&self) -> PerFieldSpaceUsage {
pub fn space_usage(&self, schema: &Schema) -> PerFieldSpaceUsage {
let mut fields = Vec::new();
for (&field_addr, byte_range) in &self.offsets_index {
let mut field_usage = FieldUsage::empty(field_addr.field);
let field_name = schema.get_field_name(field_addr.field).to_string();
let mut field_usage = FieldUsage::empty(field_name);
field_usage.add_field_idx(field_addr.idx, byte_range.len().into());
fields.push(field_usage);
}

View File

@@ -40,6 +40,8 @@ pub trait DocSet: Send {
/// of `DocSet` should support it.
///
/// Calling `seek(TERMINATED)` is also legal and is the normal way to consume a `DocSet`.
///
/// `target` has to be larger or equal to `.doc()` when calling `seek`.
fn seek(&mut self, target: DocId) -> DocId {
let mut doc = self.doc();
debug_assert!(doc <= target);
@@ -49,6 +51,33 @@ pub trait DocSet: Send {
doc
}
/// Seeks to the target if possible and returns true if the target is in the DocSet.
///
/// DocSets that already have an efficient `seek` method don't need to implement
/// `seek_into_the_danger_zone`. All wrapper DocSets should forward
/// `seek_into_the_danger_zone` to the underlying DocSet.
///
/// ## API Behaviour
/// If `seek_into_the_danger_zone` is returning true, a call to `doc()` has to return target.
/// If `seek_into_the_danger_zone` is returning false, a call to `doc()` may return any doc
/// between the last doc that matched and target or a doc that is a valid next hit after
/// target. The DocSet is considered to be in an invalid state until
/// `seek_into_the_danger_zone` returns true again.
///
/// `target` needs to be equal or larger than `doc` when in a valid state.
///
/// Consecutive calls are not allowed to have decreasing `target` values.
///
/// # Warning
/// This is an advanced API used by intersection. The API contract is tricky, avoid using it.
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
let current_doc = self.doc();
if current_doc < target {
self.seek(target);
}
self.doc() == target
}
/// Fills a given mutable buffer with the next doc ids from the
/// `DocSet`
///
@@ -94,6 +123,15 @@ pub trait DocSet: Send {
/// which would be the number of documents in the DocSet.
///
/// By default this returns `size_hint()`.
///
/// DocSets may have vastly different cost depending on their type,
/// e.g. an intersection with 10 hits is much cheaper than
/// a phrase search with 10 hits, since it needs to load positions.
///
/// ### Future Work
/// We may want to differentiate `DocSet` costs more more granular, e.g.
/// creation_cost, advance_cost, seek_cost on to get a good estimation
/// what query types to choose.
fn cost(&self) -> u64 {
self.size_hint() as u64
}
@@ -137,6 +175,10 @@ impl DocSet for &mut dyn DocSet {
(**self).seek(target)
}
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
(**self).seek_into_the_danger_zone(target)
}
fn doc(&self) -> u32 {
(**self).doc()
}
@@ -169,6 +211,11 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
unboxed.seek(target)
}
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.seek_into_the_danger_zone(target)
}
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.fill_buffer(buffer)

View File

@@ -8,7 +8,7 @@ use columnar::{
};
use common::ByteCount;
use crate::core::json_utils::encode_column_name;
use crate::core::json_utils::{encode_column_name, json_path_sep_to_dot};
use crate::directory::FileSlice;
use crate::schema::{Field, FieldEntry, FieldType, Schema};
use crate::space_usage::{FieldUsage, PerFieldSpaceUsage};
@@ -39,19 +39,15 @@ impl FastFieldReaders {
self.resolve_column_name_given_default_field(column_name, default_field_opt)
}
pub(crate) fn space_usage(&self, schema: &Schema) -> io::Result<PerFieldSpaceUsage> {
pub(crate) fn space_usage(&self) -> io::Result<PerFieldSpaceUsage> {
let mut per_field_usages: Vec<FieldUsage> = Default::default();
for (field, field_entry) in schema.fields() {
let column_handles = self.columnar.read_columns(field_entry.name())?;
let num_bytes: ByteCount = column_handles
.iter()
.map(|column_handle| column_handle.num_bytes())
.sum();
let mut field_usage = FieldUsage::empty(field);
field_usage.add_field_idx(0, num_bytes);
for (mut field_name, column_handle) in self.columnar.iter_columns()? {
json_path_sep_to_dot(&mut field_name);
let space_usage = column_handle.space_usage()?;
let mut field_usage = FieldUsage::empty(field_name);
field_usage.set_column_usage(space_usage);
per_field_usages.push(field_usage);
}
// TODO fix space usage for JSON fields.
Ok(PerFieldSpaceUsage::new(per_field_usages))
}

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use super::{fieldnorm_to_id, id_to_fieldnorm};
use crate::directory::{CompositeFile, FileSlice, OwnedBytes};
use crate::schema::Field;
use crate::schema::{Field, Schema};
use crate::space_usage::PerFieldSpaceUsage;
use crate::DocId;
@@ -37,8 +37,8 @@ impl FieldNormReaders {
}
/// Return a break down of the space usage per field.
pub fn space_usage(&self) -> PerFieldSpaceUsage {
self.data.space_usage()
pub fn space_usage(&self, schema: &Schema) -> PerFieldSpaceUsage {
self.data.space_usage(schema)
}
/// Returns a handle to inner file

View File

@@ -13,9 +13,9 @@ use crate::store::Compressor;
use crate::{Inventory, Opstamp, TrackedObject};
#[derive(Clone, Debug, Serialize, Deserialize)]
struct DeleteMeta {
pub struct DeleteMeta {
num_deleted_docs: u32,
opstamp: Opstamp,
pub opstamp: Opstamp,
}
#[derive(Clone, Default)]
@@ -213,7 +213,7 @@ impl SegmentMeta {
struct InnerSegmentMeta {
segment_id: SegmentId,
max_doc: u32,
deletes: Option<DeleteMeta>,
pub deletes: Option<DeleteMeta>,
/// If you want to avoid the SegmentComponent::TempStore file to be covered by
/// garbage collection and deleted, set this to true. This is used during merge.
#[serde(skip)]

View File

@@ -46,7 +46,7 @@ impl Segment {
///
/// This method is only used when updating `max_doc` from 0
/// as we finalize a fresh new segment.
pub(crate) fn with_max_doc(self, max_doc: u32) -> Segment {
pub fn with_max_doc(self, max_doc: u32) -> Segment {
Segment {
index: self.index,
meta: self.meta.with_max_doc(max_doc),

View File

@@ -455,11 +455,11 @@ impl SegmentReader {
pub fn space_usage(&self) -> io::Result<SegmentSpaceUsage> {
Ok(SegmentSpaceUsage::new(
self.num_docs(),
self.termdict_composite.space_usage(),
self.postings_composite.space_usage(),
self.positions_composite.space_usage(),
self.fast_fields_readers.space_usage(self.schema())?,
self.fieldnorm_readers.space_usage(),
self.termdict_composite.space_usage(self.schema()),
self.postings_composite.space_usage(self.schema()),
self.positions_composite.space_usage(self.schema()),
self.fast_fields_readers.space_usage()?,
self.fieldnorm_readers.space_usage(self.schema()),
self.get_store_reader(0)?.space_usage(),
self.alive_bitset_opt
.as_ref()

View File

@@ -23,13 +23,18 @@ struct InnerDeleteQueue {
last_block: Weak<Block>,
}
/// The delete queue is a linked list storing delete operations.
///
/// Several consumers can hold a reference to it. Delete operations
/// get dropped/gc'ed when no more consumers are holding a reference
/// to them.
#[derive(Clone)]
pub struct DeleteQueue {
inner: Arc<RwLock<InnerDeleteQueue>>,
}
impl DeleteQueue {
// Creates a new delete queue.
/// Creates a new empty delete queue.
pub fn new() -> DeleteQueue {
DeleteQueue {
inner: Arc::default(),
@@ -58,10 +63,10 @@ impl DeleteQueue {
block
}
// Creates a new cursor that makes it possible to
// consume future delete operations.
//
// Past delete operations are not accessible.
/// Creates a new cursor that makes it possible to
/// consume future delete operations.
///
/// Past delete operations are not accessible.
pub fn cursor(&self) -> DeleteCursor {
let last_block = self.get_last_block();
let operations_len = last_block.operations.len();
@@ -71,7 +76,7 @@ impl DeleteQueue {
}
}
// Appends a new delete operations.
/// Appends a new delete operations.
pub fn push(&self, delete_operation: DeleteOperation) {
self.inner
.write()
@@ -169,6 +174,7 @@ struct Block {
next: NextBlock,
}
/// As we process delete operations, keeps track of our position.
#[derive(Clone)]
pub struct DeleteCursor {
block: Arc<Block>,

View File

@@ -128,7 +128,7 @@ fn compute_deleted_bitset(
/// is `==` target_opstamp.
/// For instance, there was no delete operation between the state of the `segment_entry` and
/// the `target_opstamp`, `segment_entry` is not updated.
pub(crate) fn advance_deletes(
pub fn advance_deletes(
mut segment: Segment,
segment_entry: &mut SegmentEntry,
target_opstamp: Opstamp,

View File

@@ -3,21 +3,21 @@ use std::net::Ipv6Addr;
use columnar::MonotonicallyMappableToU128;
use crate::fastfield::FastValue;
use crate::schema::{Field, Type};
use crate::schema::Field;
/// Term represents the value that the token can take.
/// It's a serialized representation over different types.
/// IndexingTerm is used to represent a term during indexing.
/// It's a serialized representation over field and value.
///
/// It actually wraps a `Vec<u8>`. The first 5 bytes are metadata.
/// 4 bytes are the field id, and the last byte is the type.
/// It actually wraps a `Vec<u8>`. The first 4 bytes are the field.
///
/// The serialized value `ValueBytes` is considered everything after the 4 first bytes (term id).
/// We serialize the field, because we index everything in a single
/// global term dictionary during indexing.
#[derive(Clone)]
pub(crate) struct IndexingTerm<B = Vec<u8>>(B)
where B: AsRef<[u8]>;
/// The number of bytes used as metadata by `Term`.
const TERM_METADATA_LENGTH: usize = 5;
const TERM_METADATA_LENGTH: usize = 4;
impl IndexingTerm {
/// Create a new Term with a buffer with a given capacity.
@@ -31,10 +31,9 @@ impl IndexingTerm {
/// Use `clear_with_field_and_type` in that case.
///
/// Sets field and the type.
pub(crate) fn set_field_and_type(&mut self, field: Field, typ: Type) {
pub(crate) fn set_field(&mut self, field: Field) {
assert!(self.is_empty());
self.0[0..4].clone_from_slice(field.field_id().to_be_bytes().as_ref());
self.0[4] = typ.to_code();
}
/// Is empty if there are no value bytes.
@@ -42,10 +41,10 @@ impl IndexingTerm {
self.0.len() == TERM_METADATA_LENGTH
}
/// Removes the value_bytes and set the field and type code.
pub(crate) fn clear_with_field_and_type(&mut self, typ: Type, field: Field) {
/// Removes the value_bytes and set the field
pub(crate) fn clear_with_field(&mut self, field: Field) {
self.truncate_value_bytes(0);
self.set_field_and_type(field, typ);
self.set_field(field);
}
/// Sets a u64 value in the term.
@@ -122,6 +121,23 @@ impl IndexingTerm {
impl<B> IndexingTerm<B>
where B: AsRef<[u8]>
{
/// Wraps serialized term bytes.
///
/// The input buffer is expected to be the concatenation of the big endian encoded field id
/// followed by the serialized value bytes (type tag + payload).
#[inline]
pub fn wrap(serialized_term: B) -> IndexingTerm<B> {
debug_assert!(serialized_term.as_ref().len() >= TERM_METADATA_LENGTH);
IndexingTerm(serialized_term)
}
/// Returns the field this term belongs to.
#[inline]
pub fn field(&self) -> Field {
let field_id_bytes: [u8; 4] = self.0.as_ref()[..4].try_into().unwrap();
Field::from_field_id(u32::from_be_bytes(field_id_bytes))
}
/// Returns the serialized representation of Term.
/// This includes field_id, value type and value.
///
@@ -136,6 +152,7 @@ where B: AsRef<[u8]>
#[cfg(test)]
mod tests {
use super::IndexingTerm;
use crate::schema::*;
#[test]
@@ -143,42 +160,55 @@ mod tests {
let mut schema_builder = Schema::builder();
schema_builder.add_text_field("text", STRING);
let title_field = schema_builder.add_text_field("title", STRING);
let term = Term::from_field_text(title_field, "test");
let mut term = IndexingTerm::with_capacity(0);
term.set_field(title_field);
term.set_bytes(b"test");
assert_eq!(term.field(), title_field);
assert_eq!(term.typ(), Type::Str);
assert_eq!(term.value().as_str(), Some("test"))
assert_eq!(term.serialized_term(), b"\x00\x00\x00\x01test".to_vec())
}
/// Size (in bytes) of the buffer of a fast value (u64, i64, f64, or date) term.
/// <field> + <type byte> + <value len>
///
/// - <field> is a big endian encoded u32 field id
/// - <type_byte>'s most significant bit expresses whether the term is a json term or not The
/// remaining 7 bits are used to encode the type of the value. If this is a JSON term, the
/// type is the type of the leaf of the json.
/// - <value> is, if this is not the json term, a binary representation specific to the type.
/// If it is a JSON Term, then it is prepended with the path that leads to this leaf value.
const FAST_VALUE_TERM_LEN: usize = 4 + 1 + 8;
const FAST_VALUE_TERM_LEN: usize = 4 + 8;
#[test]
pub fn test_term_u64() {
let mut schema_builder = Schema::builder();
let count_field = schema_builder.add_u64_field("count", INDEXED);
let term = Term::from_field_u64(count_field, 983u64);
let mut term = IndexingTerm::with_capacity(0);
term.set_field(count_field);
term.set_u64(983u64);
assert_eq!(term.field(), count_field);
assert_eq!(term.typ(), Type::U64);
assert_eq!(term.serialized_term().len(), FAST_VALUE_TERM_LEN);
assert_eq!(term.value().as_u64(), Some(983u64))
}
#[test]
pub fn test_term_bool() {
let mut schema_builder = Schema::builder();
let bool_field = schema_builder.add_bool_field("bool", INDEXED);
let term = Term::from_field_bool(bool_field, true);
let term = {
let mut term = IndexingTerm::with_capacity(0);
term.set_field(bool_field);
term.set_bool(true);
term
};
assert_eq!(term.field(), bool_field);
assert_eq!(term.typ(), Type::Bool);
assert_eq!(term.serialized_term().len(), FAST_VALUE_TERM_LEN);
assert_eq!(term.value().as_bool(), Some(true))
}
#[test]
pub fn indexing_term_wrap_extracts_field() {
let field = Field::from_field_id(7u32);
let mut term = IndexingTerm::with_capacity(0);
term.set_field(field);
term.append_bytes(b"abc");
let wrapped = IndexingTerm::wrap(term.serialized_term());
assert_eq!(wrapped.field(), field);
assert_eq!(wrapped.serialized_term(), term.serialized_term());
}
}

View File

@@ -4,7 +4,7 @@
//! `IndexWriter` is the main entry point for that, which created from
//! [`Index::writer`](crate::Index::writer).
pub(crate) mod delete_queue;
pub mod delete_queue;
pub(crate) mod path_to_unordered_id;
pub(crate) mod doc_id_mapping;
@@ -32,12 +32,11 @@ mod stamper;
use crossbeam_channel as channel;
use smallvec::SmallVec;
pub use self::index_writer::{IndexWriter, IndexWriterOptions};
pub use self::index_writer::{advance_deletes, IndexWriter, IndexWriterOptions};
pub use self::log_merge_policy::LogMergePolicy;
pub use self::merge_operation::MergeOperation;
pub use self::merge_policy::{MergeCandidate, MergePolicy, NoMergePolicy};
use self::operation::AddOperation;
pub use self::operation::UserOperation;
pub use self::operation::{AddOperation, DeleteOperation, UserOperation};
pub use self::prepared_commit::PreparedCommit;
pub use self::segment_entry::SegmentEntry;
pub(crate) use self::segment_serializer::SegmentSerializer;

View File

@@ -5,14 +5,20 @@ use crate::Opstamp;
/// Timestamped Delete operation.
pub struct DeleteOperation {
/// Operation stamp.
/// It is used to check whether the delete operation
/// applies to an added document operation.
pub opstamp: Opstamp,
/// Weight is used to define the set of documents to be deleted.
pub target: Box<dyn Weight>,
}
/// Timestamped Add operation.
#[derive(Eq, PartialEq, Debug)]
pub struct AddOperation<D: Document = TantivyDocument> {
/// Operation stamp.
pub opstamp: Opstamp,
/// Document to be added.
pub document: D,
}

View File

@@ -171,7 +171,7 @@ impl SegmentWriter {
let (term_buffer, ctx) = (&mut self.term_buffer, &mut self.ctx);
let postings_writer: &mut dyn PostingsWriter =
self.per_field_postings_writers.get_for_field_mut(field);
term_buffer.clear_with_field_and_type(field_entry.field_type().value_type(), field);
term_buffer.clear_with_field(field);
match field_entry.field_type() {
FieldType::Facet(_) => {

View File

@@ -216,9 +216,7 @@ use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
pub use self::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN, TERMINATED};
#[doc(hidden)]
pub use crate::core::json_utils;
pub use crate::core::{Executor, Searcher, SearcherGeneration};
pub use crate::core::{json_utils, Executor, Searcher, SearcherGeneration};
pub use crate::directory::Directory;
pub use crate::index::{
Index, IndexBuilder, IndexMeta, IndexSettings, InvertedIndexReader, Order, Segment,

View File

@@ -1,12 +1,15 @@
use bitpacking::{BitPacker, BitPacker4x};
use common::FixedSize;
pub const COMPRESSION_BLOCK_SIZE: usize = BitPacker4x::BLOCK_LEN;
const COMPRESSED_BLOCK_MAX_SIZE: usize = COMPRESSION_BLOCK_SIZE * u32::SIZE_IN_BYTES;
// in vint encoding, each byte stores 7 bits of data, so we need at most 32 / 7 = 4.57 bytes to
// store a u32 in the worst case, rounding up to 5 bytes total
const MAX_VINT_SIZE: usize = 5;
const COMPRESSED_BLOCK_MAX_SIZE: usize = COMPRESSION_BLOCK_SIZE * MAX_VINT_SIZE;
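// Illustrative sketch, not part of the original change: why MAX_VINT_SIZE is 5.
// Assuming the usual base-128 vint scheme (7 payload bits per byte plus a continuation bit),
// a u32 needs at most ceil(32 / 7) = 5 bytes.
#[allow(dead_code)]
fn illustrative_vint_len(mut value: u32) -> usize {
    let mut len = 1;
    while value >= 128 {
        // each additional byte covers 7 more bits of the value
        value >>= 7;
        len += 1;
    }
    // e.g. 127 -> 1 byte, 128 -> 2 bytes, u32::MAX -> 5 bytes
    len
}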
mod vint;
/// Returns the size in bytes of a compressed block, given `num_bits`.
#[inline]
pub fn compressed_block_size(num_bits: u8) -> usize {
(num_bits as usize) * COMPRESSION_BLOCK_SIZE / 8
}
@@ -267,7 +270,6 @@ impl VIntDecoder for BlockDecoder {
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::TERMINATED;
@@ -372,6 +374,13 @@ pub(crate) mod tests {
}
}
}
#[test]
fn test_compress_vint_unsorted_does_not_overflow() {
let mut encoder = BlockEncoder::new();
let input: Vec<u32> = vec![u32::MAX; COMPRESSION_BLOCK_SIZE];
encoder.compress_vint_unsorted(&input);
}
}
#[cfg(all(test, feature = "unstable"))]

View File

@@ -8,7 +8,7 @@ use crate::indexer::path_to_unordered_id::OrderedPathId;
use crate::postings::postings_writer::SpecializedPostingsWriter;
use crate::postings::recorder::{BufferLender, DocIdRecorder, Recorder};
use crate::postings::{FieldSerializer, IndexingContext, IndexingPosition, PostingsWriter};
use crate::schema::{Field, Type, ValueBytes};
use crate::schema::{Field, Type};
use crate::tokenizer::TokenStream;
use crate::DocId;
@@ -79,8 +79,7 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
term_buffer.truncate(term_path_len);
term_buffer.append_bytes(term);
let json_value = ValueBytes::wrap(term);
let typ = json_value.typ();
let typ = Type::from_code(term[0]).expect("Invalid type code in JSON term");
if typ == Type::Str {
SpecializedPostingsWriter::<Rec>::serialize_one_term(
term_buffer.as_bytes(),
@@ -107,6 +106,8 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
}
}
/// Helper to build the JSON term bytes that land in the term dictionary.
/// Format: `[json path utf8][JSON_END_OF_PATH][type tag][payload]`
struct JsonTermSerializer(Vec<u8>);
impl JsonTermSerializer {
/// Appends a JSON path to the Term.

View File

@@ -11,7 +11,7 @@ use crate::postings::recorder::{BufferLender, Recorder};
use crate::postings::{
FieldSerializer, IndexingContext, InvertedIndexSerializer, PerFieldPostingsWriter,
};
use crate::schema::{Field, Schema, Term, Type};
use crate::schema::{Field, Schema, Type};
use crate::tokenizer::{Token, TokenStream, MAX_TOKEN_LEN};
use crate::DocId;
@@ -59,14 +59,14 @@ pub(crate) fn serialize_postings(
let mut term_offsets: Vec<(Field, OrderedPathId, &[u8], Addr)> =
Vec::with_capacity(ctx.term_index.len());
term_offsets.extend(ctx.term_index.iter().map(|(key, addr)| {
let field = Term::wrap(key).field();
let field = IndexingTerm::wrap(key).field();
if schema.get_field_entry(field).field_type().value_type() == Type::Json {
let byte_range_path = 5..5 + 4;
let byte_range_path = 4..4 + 4;
let unordered_id = u32::from_be_bytes(key[byte_range_path.clone()].try_into().unwrap());
let path_id = unordered_id_to_ordered_id[unordered_id as usize];
(field, path_id, &key[byte_range_path.end..], addr)
} else {
(field, 0.into(), &key[5..], addr)
(field, 0.into(), &key[4..], addr)
}
}));
// Sort by field, path, and term

View File

@@ -23,7 +23,11 @@ pub struct AllWeight;
impl Weight for AllWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let all_scorer = AllScorer::new(reader.max_doc());
Ok(Box::new(BoostScorer::new(all_scorer, boost)))
if boost != 1.0 {
Ok(Box::new(BoostScorer::new(all_scorer, boost)))
} else {
Ok(Box::new(all_scorer))
}
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
@@ -58,6 +62,15 @@ impl DocSet for AllScorer {
self.doc
}
fn seek(&mut self, target: DocId) -> DocId {
debug_assert!(target >= self.doc);
self.doc = target;
if self.doc >= self.max_doc {
self.doc = TERMINATED;
}
self.doc
}
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {
if self.doc() == TERMINATED {
return 0;

View File

@@ -483,7 +483,7 @@ mod tests {
let checkpoints_for_each_pruning =
compute_checkpoints_for_each_pruning(term_scorers.clone(), top_k);
let checkpoints_manual =
compute_checkpoints_manual(term_scorers.clone(), top_k, 100_000);
compute_checkpoints_manual(term_scorers.clone(), top_k, max_doc as u32);
assert_eq!(checkpoints_for_each_pruning.len(), checkpoints_manual.len());
for (&(left_doc, left_score), &(right_doc, right_score)) in checkpoints_for_each_pruning
.iter()

View File

@@ -97,6 +97,65 @@ fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
}
}
/// Returns the effective MUST scorer, accounting for removed AllScorers.
///
/// When AllScorer instances are removed from must_scorers as an optimization,
/// we must restore the "match all" semantics if the list becomes empty.
fn effective_must_scorer(
must_scorers: Vec<Box<dyn Scorer>>,
removed_all_scorer_count: usize,
max_doc: DocId,
num_docs: u32,
) -> Option<Box<dyn Scorer>> {
if must_scorers.is_empty() {
if removed_all_scorer_count > 0 {
// Had AllScorer(s) only - all docs match
Some(Box::new(AllScorer::new(max_doc)))
} else {
// No MUST constraint at all
None
}
} else {
Some(intersect_scorers(must_scorers, num_docs))
}
}
/// Returns a SHOULD scorer with AllScorer union if any were removed.
///
/// For union semantics (OR): if any SHOULD clause was an AllScorer, the result
/// should include all documents. We restore this by unioning with AllScorer.
///
/// When `scoring_enabled` is false, we can just return AllScorer alone since
/// we don't need score contributions from the should_scorer.
fn effective_should_scorer_for_union<TScoreCombiner: ScoreCombiner>(
should_scorer: SpecializedScorer,
removed_all_scorer_count: usize,
max_doc: DocId,
num_docs: u32,
score_combiner_fn: impl Fn() -> TScoreCombiner,
scoring_enabled: bool,
) -> SpecializedScorer {
if removed_all_scorer_count > 0 {
if scoring_enabled {
// Need to union to get score contributions from both
let all_scorers: Vec<Box<dyn Scorer>> = vec![
into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
Box::new(AllScorer::new(max_doc)),
];
SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
all_scorers,
score_combiner_fn,
num_docs,
)))
} else {
// Scoring disabled - AllScorer alone is sufficient
SpecializedScorer::Other(Box::new(AllScorer::new(max_doc)))
}
} else {
should_scorer
}
}
enum ShouldScorersCombinationMethod {
// Should scorers are irrelevant.
Ignored,
@@ -193,18 +252,18 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
}
let minimum_number_should_match = self
let effective_minimum_number_should_match = self
.minimum_number_should_match
.saturating_sub(should_special_scorer_counts.num_all_scorers);
let should_scorers: ShouldScorersCombinationMethod = {
let num_of_should_scorers = should_scorers.len();
if minimum_number_should_match > num_of_should_scorers {
if effective_minimum_number_should_match > num_of_should_scorers {
// We don't have enough scorers to satisfy the minimum number of should matches.
// The request will match no documents.
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
}
match minimum_number_should_match {
match effective_minimum_number_should_match {
0 if num_of_should_scorers == 0 => ShouldScorersCombinationMethod::Ignored,
0 => ShouldScorersCombinationMethod::Optional(scorer_union(
should_scorers,
@@ -226,7 +285,7 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
scorer_disjunction(
should_scorers,
score_combiner_fn(),
self.minimum_number_should_match,
effective_minimum_number_should_match,
),
)),
}
@@ -246,53 +305,78 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
let include_scorer = match (should_scorers, must_scorers) {
(ShouldScorersCombinationMethod::Ignored, must_scorers) => {
let boxed_scorer: Box<dyn Scorer> = if must_scorers.is_empty() {
// We do not have any should scorers, nor all scorers.
// There are still two cases here.
//
// If this follows the removal of some AllScorers in the should/must clauses,
// then we match all documents.
//
// Otherwise, it is really just an EmptyScorer.
if must_special_scorer_counts.num_all_scorers
+ should_special_scorer_counts.num_all_scorers
> 0
{
Box::new(AllScorer::new(reader.max_doc()))
} else {
Box::new(EmptyScorer)
}
} else {
intersect_scorers(must_scorers, num_docs)
};
// No SHOULD clauses (or they were absorbed into MUST).
// Result depends entirely on MUST + any removed AllScorers.
let combined_all_scorer_count = must_special_scorer_counts.num_all_scorers
+ should_special_scorer_counts.num_all_scorers;
let boxed_scorer: Box<dyn Scorer> = effective_must_scorer(
must_scorers,
combined_all_scorer_count,
reader.max_doc(),
num_docs,
)
.unwrap_or_else(|| Box::new(EmptyScorer));
SpecializedScorer::Other(boxed_scorer)
}
(ShouldScorersCombinationMethod::Optional(should_scorer), must_scorers) => {
if must_scorers.is_empty() && must_special_scorer_counts.num_all_scorers == 0 {
// Optional options are promoted to required if no must scorers exists.
should_scorer
} else {
let must_scorer = intersect_scorers(must_scorers, num_docs);
if self.scoring_enabled {
SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
_,
_,
TScoreCombiner,
>::new(
must_scorer,
into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
)))
} else {
SpecializedScorer::Other(must_scorer)
// Optional SHOULD: contributes to scoring but not required for matching.
match effective_must_scorer(
must_scorers,
must_special_scorer_counts.num_all_scorers,
reader.max_doc(),
num_docs,
) {
None => {
// No MUST constraint: promote SHOULD to required.
// Must preserve any removed AllScorers from SHOULD via union.
effective_should_scorer_for_union(
should_scorer,
should_special_scorer_counts.num_all_scorers,
reader.max_doc(),
num_docs,
&score_combiner_fn,
self.scoring_enabled,
)
}
Some(must_scorer) => {
// Has MUST constraint: SHOULD only affects scoring.
if self.scoring_enabled {
SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
_,
_,
TScoreCombiner,
>::new(
must_scorer,
into_box_scorer(should_scorer, &score_combiner_fn, num_docs),
)))
} else {
SpecializedScorer::Other(must_scorer)
}
}
}
}
(ShouldScorersCombinationMethod::Required(should_scorer), mut must_scorers) => {
if must_scorers.is_empty() {
should_scorer
} else {
must_scorers.push(into_box_scorer(should_scorer, &score_combiner_fn, num_docs));
SpecializedScorer::Other(intersect_scorers(must_scorers, num_docs))
(ShouldScorersCombinationMethod::Required(should_scorer), must_scorers) => {
// Required SHOULD: at least `minimum_number_should_match` must match.
// Semantics: (MUST constraint) AND (SHOULD constraint)
match effective_must_scorer(
must_scorers,
must_special_scorer_counts.num_all_scorers,
reader.max_doc(),
num_docs,
) {
None => {
// No MUST constraint: SHOULD alone determines matching.
should_scorer
}
Some(must_scorer) => {
// Has MUST constraint: intersect MUST with SHOULD.
let should_boxed =
into_box_scorer(should_scorer, &score_combiner_fn, num_docs);
SpecializedScorer::Other(intersect_scorers(
vec![must_scorer, should_boxed],
num_docs,
))
}
}
}
};

View File

@@ -9,12 +9,14 @@ pub use self::boolean_weight::BooleanWeight;
#[cfg(test)]
mod tests {
use std::ops::Bound;
use super::*;
use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE;
use crate::collector::TopDocs;
use crate::collector::{Count, TopDocs};
use crate::query::term_query::TermScorer;
use crate::query::{
AllScorer, EmptyScorer, EnableScoring, Intersection, Occur, Query, QueryParser,
AllScorer, EmptyScorer, EnableScoring, Intersection, Occur, Query, QueryParser, RangeQuery,
RequiredOptionalScorer, Scorer, SumCombiner, TermQuery,
};
use crate::schema::*;
@@ -374,4 +376,466 @@ mod tests {
}
Ok(())
}
#[test]
pub fn test_min_should_match_with_all_query() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let num_field =
schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
index_writer.add_document(doc!(text_field => "apple", num_field => 10i64))?;
index_writer.add_document(doc!(text_field => "banana", num_field => 20i64))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
let effective_all_match_query: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(num_field, 0)),
Bound::Unbounded,
));
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "apple"),
IndexRecordOption::Basic,
));
// In some previous versions, we would remove the 2 all_match queries, but then require *4*
// matches out of the 3 remaining term queries, which matches nothing.
let mut bool_query = BooleanQuery::new(vec![
(Occur::Should, effective_all_match_query.box_clone()),
(Occur::Should, effective_all_match_query.box_clone()),
(Occur::Should, term_query.box_clone()),
(Occur::Should, term_query.box_clone()),
(Occur::Should, term_query.box_clone()),
]);
bool_query.set_minimum_number_should_match(4);
let count = searcher.search(&bool_query, &Count)?;
assert_eq!(count, 1);
Ok(())
}
// =========================================================================
// AllScorer Preservation Regression Tests
// =========================================================================
//
// These tests verify the fix for a bug where AllScorer instances (produced by
// queries matching all documents, such as range queries covering all values)
// were incorrectly removed from Boolean query processing, causing documents
// to be unexpectedly excluded from results.
//
// The bug manifested in several scenarios:
// 1. SHOULD + SHOULD where one clause is AllScorer
// 2. MUST (AllScorer) + SHOULD
// 3. Range queries in Boolean clauses when all documents match the range
/// Regression test: SHOULD clause with AllScorer combined with other SHOULD clauses.
///
/// When a SHOULD clause produces an AllScorer (e.g., from a range query matching
/// all documents), the Boolean query should still match all documents.
///
/// Bug before fix: AllScorer was removed during optimization, leaving only the
/// other SHOULD clauses, which incorrectly excluded documents.
#[test]
pub fn test_should_with_all_scorer_regression() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let num_field =
schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// All docs have num > 0, so range query will return AllScorer
index_writer.add_document(doc!(text_field => "hello", num_field => 10i64))?;
index_writer.add_document(doc!(text_field => "world", num_field => 20i64))?;
index_writer.add_document(doc!(text_field => "hello world", num_field => 30i64))?;
index_writer.add_document(doc!(text_field => "foo", num_field => 40i64))?;
index_writer.add_document(doc!(text_field => "bar", num_field => 50i64))?;
index_writer.add_document(doc!(text_field => "baz", num_field => 60i64))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
// Range query matching all docs (returns AllScorer)
let all_match_query: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(num_field, 0)),
Bound::Unbounded,
));
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "hello"),
IndexRecordOption::Basic,
));
// Verify range matches all 6 docs
assert_eq!(searcher.search(all_match_query.as_ref(), &Count)?, 6);
// RangeQuery(all) OR TermQuery should match all 6 docs
let bool_query = BooleanQuery::new(vec![
(Occur::Should, all_match_query.box_clone()),
(Occur::Should, term_query.box_clone()),
]);
let count = searcher.search(&bool_query, &Count)?;
assert_eq!(count, 6, "SHOULD with AllScorer should match all docs");
// Order should not matter
let bool_query_reversed = BooleanQuery::new(vec![
(Occur::Should, term_query.box_clone()),
(Occur::Should, all_match_query.box_clone()),
]);
let count_reversed = searcher.search(&bool_query_reversed, &Count)?;
assert_eq!(
count_reversed, 6,
"Order of SHOULD clauses should not matter"
);
Ok(())
}
/// Regression test: MUST clause with AllScorer combined with SHOULD clause.
///
/// When MUST contains an AllScorer, all documents satisfy the MUST constraint.
/// The SHOULD clause should only affect scoring, not filtering.
///
/// Bug before fix: AllScorer was removed, leaving an empty must_scorers vector.
/// intersect_scorers([]) incorrectly returned EmptyScorer, matching 0 documents.
#[test]
pub fn test_must_all_with_should_regression() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let num_field =
schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// All docs have num > 0, so range query will return AllScorer
index_writer.add_document(doc!(text_field => "apple", num_field => 10i64))?;
index_writer.add_document(doc!(text_field => "banana", num_field => 20i64))?;
index_writer.add_document(doc!(text_field => "cherry", num_field => 30i64))?;
index_writer.add_document(doc!(text_field => "date", num_field => 40i64))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
// Range query matching all docs (returns AllScorer)
let all_match_query: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(num_field, 0)),
Bound::Unbounded,
));
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "apple"),
IndexRecordOption::Basic,
));
// Verify range matches all 4 docs
assert_eq!(searcher.search(all_match_query.as_ref(), &Count)?, 4);
// MUST(range matching all) AND SHOULD(term) should match all 4 docs
let bool_query = BooleanQuery::new(vec![
(Occur::Must, all_match_query.box_clone()),
(Occur::Should, term_query.box_clone()),
]);
let count = searcher.search(&bool_query, &Count)?;
assert_eq!(count, 4, "MUST AllScorer + SHOULD should match all docs");
Ok(())
}
/// Regression test: Range queries in Boolean clauses when all documents match.
///
/// Range queries can return AllScorer as an optimization when all indexed values
/// fall within the range. This test ensures such queries work correctly in
/// Boolean combinations.
///
/// This is the most common real-world manifestation of the bug, occurring in
/// queries like: (age > 50 OR name = 'Alice') AND status = 'active'
/// when all documents have age > 50.
#[test]
pub fn test_range_query_all_match_in_boolean() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let name_field = schema_builder.add_text_field("name", TEXT);
let age_field =
schema_builder.add_i64_field("age", NumericOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// All documents have age > 50, so range query will return AllScorer
index_writer.add_document(doc!(name_field => "alice", age_field => 55_i64))?;
index_writer.add_document(doc!(name_field => "bob", age_field => 60_i64))?;
index_writer.add_document(doc!(name_field => "charlie", age_field => 70_i64))?;
index_writer.add_document(doc!(name_field => "diana", age_field => 80_i64))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
let range_query: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(age_field, 50)),
Bound::Unbounded,
));
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(name_field, "alice"),
IndexRecordOption::Basic,
));
// Verify preconditions
assert_eq!(searcher.search(range_query.as_ref(), &Count)?, 4);
assert_eq!(searcher.search(term_query.as_ref(), &Count)?, 1);
// SHOULD(range) OR SHOULD(term): range matches all, so result is 4
let should_query = BooleanQuery::new(vec![
(Occur::Should, range_query.box_clone()),
(Occur::Should, term_query.box_clone()),
]);
assert_eq!(
searcher.search(&should_query, &Count)?,
4,
"SHOULD range OR term should match all"
);
// MUST(range) AND SHOULD(term): range matches all, term is optional
let must_should_query = BooleanQuery::new(vec![
(Occur::Must, range_query.box_clone()),
(Occur::Should, term_query.box_clone()),
]);
assert_eq!(
searcher.search(&must_should_query, &Count)?,
4,
"MUST range + SHOULD term should match all"
);
Ok(())
}
/// Test multiple AllScorer instances in different clause types.
///
/// Verifies correct behavior when AllScorers appear in multiple positions.
#[test]
pub fn test_multiple_all_scorers() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let num_field =
schema_builder.add_i64_field("num", NumericOptions::default().set_fast().set_indexed());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
// All docs have num > 0, so range queries will return AllScorer
index_writer.add_document(doc!(text_field => "doc1", num_field => 10i64))?;
index_writer.add_document(doc!(text_field => "doc2", num_field => 20i64))?;
index_writer.add_document(doc!(text_field => "doc3", num_field => 30i64))?;
index_writer.commit()?;
let searcher = index.reader()?.searcher();
// Two different range queries that both match all docs (return AllScorer)
let all_query1: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(num_field, 0)),
Bound::Unbounded,
));
let all_query2: Box<dyn Query> = Box::new(RangeQuery::new(
Bound::Excluded(Term::from_field_i64(num_field, 5)),
Bound::Unbounded,
));
let term_query: Box<dyn Query> = Box::new(TermQuery::new(
Term::from_field_text(text_field, "doc1"),
IndexRecordOption::Basic,
));
// Multiple AllScorers in SHOULD
let multi_all_should = BooleanQuery::new(vec![
(Occur::Should, all_query1.box_clone()),
(Occur::Should, all_query2.box_clone()),
(Occur::Should, term_query.box_clone()),
]);
assert_eq!(
searcher.search(&multi_all_should, &Count)?,
3,
"Multiple AllScorers in SHOULD"
);
// AllScorer in both MUST and SHOULD
let all_must_and_should = BooleanQuery::new(vec![
(Occur::Must, all_query1.box_clone()),
(Occur::Should, all_query2.box_clone()),
]);
assert_eq!(
searcher.search(&all_must_and_should, &Count)?,
3,
"AllScorer in both MUST and SHOULD"
);
Ok(())
}
}
/// A proptest which generates arbitrary permutations of a simple boolean AST, and then matches
/// the result against an index which contains all permutations of documents with N fields.
#[cfg(test)]
mod proptest_boolean_query {
use std::collections::{BTreeMap, HashSet};
use std::ops::Bound;
use proptest::collection::vec;
use proptest::prelude::*;
use crate::collector::DocSetCollector;
use crate::query::{AllQuery, BooleanQuery, Occur, Query, RangeQuery, TermQuery};
use crate::schema::{Field, NumericOptions, OwnedValue, Schema, TEXT};
use crate::{DocId, Index, Term};
#[derive(Debug, Clone)]
enum BooleanQueryAST {
/// Matches all documents via AllQuery (wraps AllScorer in BoostScorer)
All,
/// Matches all documents via RangeQuery (returns bare AllScorer)
/// This is the actual trigger for the AllScorer preservation bug
RangeAll,
/// Matches documents where the field has value "true"
Leaf {
field_idx: usize,
},
Union(Vec<BooleanQueryAST>),
Intersection(Vec<BooleanQueryAST>),
}
impl BooleanQueryAST {
fn matches(&self, doc_id: DocId) -> bool {
match self {
BooleanQueryAST::All => true,
BooleanQueryAST::RangeAll => true,
BooleanQueryAST::Leaf { field_idx } => Self::matches_field(doc_id, *field_idx),
BooleanQueryAST::Union(children) => {
children.iter().any(|child| child.matches(doc_id))
}
BooleanQueryAST::Intersection(children) => {
children.iter().all(|child| child.matches(doc_id))
}
}
}
fn matches_field(doc_id: DocId, field_idx: usize) -> bool {
((doc_id as usize) >> field_idx) & 1 == 1
}
fn to_query(&self, fields: &[Field], range_field: Field) -> Box<dyn Query> {
match self {
BooleanQueryAST::All => Box::new(AllQuery),
BooleanQueryAST::RangeAll => {
// Range query that matches all docs (all have value >= 0)
// This returns bare AllScorer, triggering the bug we fixed
Box::new(RangeQuery::new(
Bound::Included(Term::from_field_i64(range_field, 0)),
Bound::Unbounded,
))
}
BooleanQueryAST::Leaf { field_idx } => Box::new(TermQuery::new(
Term::from_field_text(fields[*field_idx], "true"),
crate::schema::IndexRecordOption::Basic,
)),
BooleanQueryAST::Union(children) => {
let sub_queries = children
.iter()
.map(|child| (Occur::Should, child.to_query(fields, range_field)))
.collect();
Box::new(BooleanQuery::new(sub_queries))
}
BooleanQueryAST::Intersection(children) => {
let sub_queries = children
.iter()
.map(|child| (Occur::Must, child.to_query(fields, range_field)))
.collect();
Box::new(BooleanQuery::new(sub_queries))
}
}
}
}
fn doc_ids(num_docs: usize, num_fields: usize) -> impl Iterator<Item = DocId> {
let permutations = 1 << num_fields;
let copies = (num_docs as f32 / permutations as f32).ceil() as u32;
(0..(permutations * copies)).into_iter()
}
fn create_index_with_boolean_permutations(
num_docs: usize,
num_fields: usize,
) -> (Index, Vec<Field>, Field) {
let mut schema_builder = Schema::builder();
let fields: Vec<Field> = (0..num_fields)
.map(|i| schema_builder.add_text_field(&format!("field_{}", i), TEXT))
.collect();
// Add a numeric field for RangeQuery tests - all docs have value = doc_id
let range_field = schema_builder.add_i64_field(
"range_field",
NumericOptions::default().set_fast().set_indexed(),
);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
for doc_id in doc_ids(num_docs, num_fields) {
let mut doc: BTreeMap<_, OwnedValue> = BTreeMap::default();
for (field_idx, &field) in fields.iter().enumerate() {
if (doc_id >> field_idx) & 1 == 1 {
doc.insert(field, "true".into());
}
}
// All docs have non-negative values, so RangeQuery(>=0) matches all
doc.insert(range_field, (doc_id as i64).into());
writer.add_document(doc).unwrap();
}
writer.commit().unwrap();
(index, fields, range_field)
}
fn arb_boolean_query_ast(num_fields: usize) -> impl Strategy<Value = BooleanQueryAST> {
// Leaf strategies: term queries, AllQuery, and RangeQuery matching all docs
let leaf = prop_oneof![
(0..num_fields).prop_map(|field_idx| BooleanQueryAST::Leaf { field_idx }),
Just(BooleanQueryAST::All),
Just(BooleanQueryAST::RangeAll),
];
leaf.prop_recursive(
8, // 8 levels of recursion
256, // 256 nodes max
10, // 10 items per collection
|inner| {
prop_oneof![
vec(inner.clone(), 1..10).prop_map(BooleanQueryAST::Union),
vec(inner, 1..10).prop_map(BooleanQueryAST::Intersection),
]
},
)
}
#[test]
fn proptest_boolean_query() {
// In the presence of optimizations around buffering, it can take large numbers of
// documents to uncover some issues.
let num_fields = 8;
let num_docs = 1 << num_fields;
let (index, fields, range_field) =
create_index_with_boolean_permutations(num_docs, num_fields);
let searcher = index.reader().unwrap().searcher();
proptest!(|(ast in arb_boolean_query_ast(num_fields))| {
let query = ast.to_query(&fields, range_field);
let mut matching_docs = HashSet::new();
for doc_id in doc_ids(num_docs, num_fields) {
if ast.matches(doc_id as DocId) {
matching_docs.insert(doc_id as DocId);
}
}
let doc_addresses = searcher.search(&*query, &DocSetCollector).unwrap();
let result_docs: HashSet<DocId> =
doc_addresses.into_iter().map(|doc_address| doc_address.doc_id).collect();
prop_assert_eq!(result_docs, matching_docs);
});
}
}

View File

@@ -104,6 +104,9 @@ impl<S: Scorer> DocSet for BoostScorer<S> {
fn seek(&mut self, target: DocId) -> DocId {
self.underlying.seek(target)
}
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
self.underlying.seek_into_the_danger_zone(target)
}
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {
self.underlying.fill_buffer(buffer)

View File

@@ -62,6 +62,16 @@ impl<T: Scorer> DocSet for ScorerWrapper<T> {
self.current_doc = doc_id;
doc_id
}
fn seek(&mut self, target: DocId) -> DocId {
let doc_id = self.scorer.seek(target);
self.current_doc = doc_id;
doc_id
}
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
let found = self.scorer.seek_into_the_danger_zone(target);
self.current_doc = self.scorer.doc();
found
}
fn doc(&self) -> DocId {
self.current_doc

View File

@@ -1,5 +1,5 @@
use super::size_hint::estimate_intersection;
use crate::docset::{DocSet, TERMINATED};
use crate::query::size_hint::estimate_intersection;
use crate::query::term_query::TermScorer;
use crate::query::{EmptyScorer, Scorer};
use crate::{DocId, Score};
@@ -12,6 +12,9 @@ use crate::{DocId, Score};
/// For better performance, the function uses a
/// specialized implementation if the two
/// shortest scorers are `TermScorer`s.
///
/// num_docs_segment is the number of documents in the segment. It is used for estimating the
/// `size_hint` of the intersection.
pub fn intersect_scorers(
mut scorers: Vec<Box<dyn Scorer>>,
num_docs_segment: u32,
@@ -105,32 +108,44 @@ impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOt
fn advance(&mut self) -> DocId {
let (left, right) = (&mut self.left, &mut self.right);
let mut candidate = left.advance();
if candidate == TERMINATED {
return TERMINATED;
}
'outer: loop {
loop {
// In the first part we look for a document in the intersection
// of the two rarest `DocSet` in the intersection.
loop {
let right_doc = right.seek(candidate);
candidate = left.seek(right_doc);
if candidate == right_doc {
if right.seek_into_the_danger_zone(candidate) {
break;
}
let right_doc = right.doc();
// TODO: Think about which value would make sense here
// At which gap a seek outweighs repeated advances depends on the DocSet implementation.
if right_doc > candidate.wrapping_add(100) {
candidate = left.seek(right_doc);
} else {
candidate = left.advance();
}
if candidate == TERMINATED {
return TERMINATED;
}
}
debug_assert_eq!(left.doc(), right.doc());
// test the remaining scorers;
for docset in self.others.iter_mut() {
let seek_doc = docset.seek(candidate);
if seek_doc > candidate {
candidate = left.seek(seek_doc);
continue 'outer;
}
// test the remaining scorers
if self
.others
.iter_mut()
.all(|docset| docset.seek_into_the_danger_zone(candidate))
{
debug_assert_eq!(candidate, self.left.doc());
debug_assert_eq!(candidate, self.right.doc());
debug_assert!(self.others.iter().all(|docset| docset.doc() == candidate));
return candidate;
}
debug_assert_eq!(candidate, self.left.doc());
debug_assert_eq!(candidate, self.right.doc());
debug_assert!(self.others.iter().all(|docset| docset.doc() == candidate));
return candidate;
candidate = left.advance();
}
}
@@ -146,6 +161,19 @@ impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOt
doc
}
/// Seeks to the target if necessary and checks if the target is an exact match.
///
/// Some implementations may choose to advance past the target if beneficial for performance.
/// The return value is `true` if the target is in the docset, and `false` otherwise.
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
self.left.seek_into_the_danger_zone(target)
&& self.right.seek_into_the_danger_zone(target)
&& self
.others
.iter_mut()
.all(|docset| docset.seek_into_the_danger_zone(target))
}
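    // Usage sketch (hypothetical call site, not part of the original change): the contract is
    // that returning `true` means the docset is positioned exactly on `target`, while `false`
    // allows it to have advanced to an arbitrary doc at or beyond it.
    //
    //     let mut intersection = Intersection::new(vec![left, right], num_docs);
    //     if intersection.seek_into_the_danger_zone(42) {
    //         assert_eq!(intersection.doc(), 42);
    //     } // otherwise the intersection may now be positioned at or past 42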
fn doc(&self) -> DocId {
self.left.doc()
}
@@ -181,6 +209,8 @@ where
#[cfg(test)]
mod tests {
use proptest::prelude::*;
use super::Intersection;
use crate::docset::{DocSet, TERMINATED};
use crate::postings::tests::test_skip_against_unoptimized;
@@ -270,4 +300,38 @@ mod tests {
let intersection = Intersection::new(vec![a, b, c], 10);
assert_eq!(intersection.doc(), TERMINATED);
}
// Strategy to generate sorted and deduplicated vectors of u32 document IDs
fn sorted_deduped_vec(max_val: u32, max_size: usize) -> impl Strategy<Value = Vec<u32>> {
prop::collection::vec(0..max_val, 0..max_size).prop_map(|mut vec| {
vec.sort();
vec.dedup();
vec
})
}
proptest! {
#[test]
fn prop_test_intersection_consistency(
a in sorted_deduped_vec(100, 10),
b in sorted_deduped_vec(100, 10),
num_docs in 100u32..500u32
) {
let left = VecDocSet::from(a.clone());
let right = VecDocSet::from(b.clone());
let mut intersection = Intersection::new(vec![left, right], num_docs);
let expected: Vec<u32> = a.iter()
.cloned()
.filter(|doc| b.contains(doc))
.collect();
for expected_doc in expected {
assert_eq!(intersection.doc(), expected_doc);
intersection.advance();
}
assert_eq!(intersection.doc(), TERMINATED);
}
}
}

View File

@@ -70,9 +70,83 @@ pub use self::weight::Weight;
#[cfg(test)]
mod tests {
use crate::collector::TopDocs;
use crate::query::phrase_query::tests::create_index;
use crate::query::QueryParser;
use crate::schema::{Schema, TEXT};
use crate::{Index, Term};
use crate::{DocAddress, Index, Term};
#[test]
pub fn test_mixed_intersection_and_union() -> crate::Result<()> {
let index = create_index(&["a b", "a c", "a b c", "b"])?;
let schema = index.schema();
let text_field = schema.get_field("text").unwrap();
let searcher = index.reader()?.searcher();
let do_search = |term: &str| {
let query = QueryParser::for_index(&index, vec![text_field])
.parse_query(term)
.unwrap();
let top_docs: Vec<(f32, DocAddress)> = searcher
.search(&query, &TopDocs::with_limit(10).order_by_score())
.unwrap();
top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>()
};
assert_eq!(do_search("a AND b"), vec![0, 2]);
assert_eq!(do_search("(a OR b) AND C"), vec![2, 1]);
// The intersection code has a specialized path for intersections of more than 2 docsets
// (left, right + others).
// This places the union in the "others" part of the intersection so that
// seek_into_the_danger_zone is called.
assert_eq!(
do_search("(a OR b) AND (c OR a) AND (b OR c)"),
vec![2, 1, 0]
);
Ok(())
}
#[test]
pub fn test_mixed_intersection_and_union_with_skip() -> crate::Result<()> {
// Test 4096 skip in BufferedUnionScorer
let mut data: Vec<&str> = Vec::new();
data.push("a b");
let zz_data = vec!["z z"; 5000];
data.extend_from_slice(&zz_data);
data.extend_from_slice(&["a c"]);
data.extend_from_slice(&zz_data);
data.extend_from_slice(&["a b c", "b"]);
let index = create_index(&data)?;
let schema = index.schema();
let text_field = schema.get_field("text").unwrap();
let searcher = index.reader()?.searcher();
let do_search = |term: &str| {
let query = QueryParser::for_index(&index, vec![text_field])
.parse_query(term)
.unwrap();
let top_docs: Vec<(f32, DocAddress)> = searcher
.search(&query, &TopDocs::with_limit(10).order_by_score())
.unwrap();
top_docs.iter().map(|el| el.1.doc_id).collect::<Vec<_>>()
};
assert_eq!(do_search("a AND b"), vec![0, 10002]);
assert_eq!(do_search("(a OR b) AND C"), vec![10002, 5001]);
// The intersection code has a specialized path for intersections of more than 2 docsets
// (left, right + others).
// This places the union in the "others" part of the intersection so that
// seek_into_the_danger_zone is called.
assert_eq!(
do_search("(a OR b) AND (c OR a) AND (b OR c)"),
vec![10002, 5001, 0]
);
Ok(())
}
#[test]
fn test_query_terms() {

View File

@@ -193,6 +193,14 @@ impl<TPostings: Postings> DocSet for PhrasePrefixScorer<TPostings> {
self.advance()
}
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
if self.phrase_scorer.seek_into_the_danger_zone(target) {
self.matches_prefix()
} else {
false
}
}
fn doc(&self) -> DocId {
self.phrase_scorer.doc()
}

View File

@@ -382,8 +382,9 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
PostingsWithOffset::new(postings, (max_offset - offset) as u32)
})
.collect::<Vec<_>>();
let intersection_docset = Intersection::new(postings_with_offsets, num_docs);
let mut scorer = PhraseScorer {
intersection_docset: Intersection::new(postings_with_offsets, num_docs),
intersection_docset,
num_terms: num_docsets,
left_positions: Vec::with_capacity(100),
right_positions: Vec::with_capacity(100),
@@ -529,20 +530,34 @@ impl<TPostings: Postings> DocSet for PhraseScorer<TPostings> {
self.advance()
}
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
debug_assert!(target >= self.doc());
if self.intersection_docset.seek_into_the_danger_zone(target) && self.phrase_match() {
return true;
}
false
}
fn doc(&self) -> DocId {
self.intersection_docset.doc()
}
fn size_hint(&self) -> u32 {
self.intersection_docset.size_hint()
// We adjust the intersection estimate, since actual phrase hits are much rarer than the
// docs in which all the terms appear.
// The estimate should depend on the average field length, e.g. if the field is really short
// a phrase hit is more likely.
self.intersection_docset.size_hint() / (10 * self.num_terms as u32)
}
/// Returns a best-effort hint of the
/// cost to drive the docset.
fn cost(&self) -> u64 {
// Evaluating phrase matches is generally more expensive than simple term matches,
// as it requires loading and comparing positions. Use a conservative multiplier
// based on the number of terms.
// While determining a potential hit is cheap for phrases, evaluating an actual hit is
// expensive since it requires loading the positions for a doc and checking whether they
// are next to each other.
// So the cost estimate is the number of times we need to check whether a doc is a hit *
// 10 * self.num_terms.
self.intersection_docset.size_hint() as u64 * 10 * self.num_terms as u64
}
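    // Worked example, with assumed numbers: for a 3-term phrase whose term intersection has a
    // size_hint of 1_000 docs, size_hint() above yields 1_000 / (10 * 3) = 33 expected phrase
    // hits, while cost() yields 1_000 * 10 * 3 = 30_000, reflecting that positions must be
    // decoded and compared for every intersection candidate.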
}

View File

@@ -62,6 +62,17 @@ pub(crate) struct RangeDocSet<T> {
const DEFAULT_FETCH_HORIZON: u32 = 128;
impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> {
pub(crate) fn new(value_range: RangeInclusive<T>, column: Column<T>) -> Self {
if *value_range.start() > column.max_value() || *value_range.end() < column.min_value() {
return Self {
value_range,
column,
loaded_docs: VecCursor::new(),
next_fetch_start: TERMINATED,
fetch_horizon: DEFAULT_FETCH_HORIZON,
last_seek_pos_opt: None,
};
}
let mut range_docset = Self {
value_range,
column,
@@ -81,6 +92,9 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> {
/// Returns true if more data could be fetched
fn fetch_block(&mut self) {
if self.next_fetch_start >= self.column.num_docs() {
return;
}
const MAX_HORIZON: u32 = 100_000;
while self.loaded_docs.is_empty() {
let finished_to_end = self.fetch_horizon(self.fetch_horizon);
@@ -105,10 +119,10 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> {
fn fetch_horizon(&mut self, horizon: u32) -> bool {
let mut finished_to_end = false;
let limit = self.column.num_docs();
let mut end = self.next_fetch_start + horizon;
if end >= limit {
end = limit;
let num_docs = self.column.num_docs();
let mut fetch_end = self.next_fetch_start + horizon;
if fetch_end >= num_docs {
fetch_end = num_docs;
finished_to_end = true;
}
@@ -116,7 +130,7 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> {
let doc_buffer: &mut Vec<DocId> = self.loaded_docs.get_cleared_data();
self.column.get_docids_for_value_range(
self.value_range.clone(),
self.next_fetch_start..end,
self.next_fetch_start..fetch_end,
doc_buffer,
);
if let Some(last_doc) = last_doc {
@@ -124,7 +138,7 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> RangeDocSet<T> {
self.loaded_docs.next();
}
}
self.next_fetch_start = end;
self.next_fetch_start = fetch_end;
finished_to_end
}
@@ -136,9 +150,6 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> DocSet for RangeDocSe
if let Some(docid) = self.loaded_docs.next() {
return docid;
}
if self.next_fetch_start >= self.column.num_docs() {
return TERMINATED;
}
self.fetch_block();
self.loaded_docs.current().unwrap_or(TERMINATED)
}
@@ -174,15 +185,25 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> DocSet for RangeDocSe
}
fn size_hint(&self) -> u32 {
self.column.num_docs()
// TODO: Implement a better size hint
self.column.num_docs() / 10
}
/// Returns a best-effort hint of the
/// cost to drive the docset.
fn cost(&self) -> u64 {
// Advancing the docset is relatively expensive since it scans the column.
// Keep cost relative to a term query driver; use num_docs as baseline.
self.column.num_docs() as u64
// Advancing the docset is pretty expensive since it scans the whole column; there is no
// index currently (this will change with a kd-tree).
// Since we use SIMD to scan the fast field for the range query, we lower the cost a little
// bit, assuming that we hit 10% of the docs as in size_hint.
//
// If we returned a cost higher than num_docs, we would never choose the ff range query as
// the driver in a DocSet when intersecting a term query with a fast field. But
// it's the faster choice when the term query has a lot of docids and the range
// query does not.
//
// Ideally this would take the fast field codec into account.
(self.column.num_docs() as f64 * 0.8) as u64
}
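    // Worked example, with assumed numbers: for a segment with 1_000_000 documents the cost
    // above is 800_000 while size_hint() is 100_000. A term scorer matching, say, 50_000 docs
    // (assuming its cost tracks its doc frequency) is cheaper and therefore drives the
    // intersection, with the range docset only confirming the candidates it proposes.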
}
@@ -236,4 +257,52 @@ mod tests {
let count = searcher.search(&query, &Count).unwrap();
assert_eq!(count, 500);
}
#[test]
fn range_query_no_overlap_optimization() {
let mut schema_builder = schema::SchemaBuilder::new();
let id_field = schema_builder.add_text_field("id", schema::STRING);
let value_field = schema_builder.add_u64_field("value", schema::FAST | schema::INDEXED);
let dir = RamDirectory::default();
let index = IndexBuilder::new()
.schema(schema_builder.build())
.open_or_create(dir)
.unwrap();
{
let mut writer = index.writer(15_000_000).unwrap();
// Add documents with values in the range [10, 20]
for i in 0..100 {
let mut doc = TantivyDocument::new();
doc.add_text(id_field, format!("doc{i}"));
doc.add_u64(value_field, 10 + (i % 11) as u64); // values in range 10-20
writer.add_document(doc).unwrap();
}
writer.commit().unwrap();
}
let reader = index.reader().unwrap();
let searcher = reader.searcher();
// Test a range query [100, 200] that has no overlap with data range [10, 20]
let query = RangeQuery::new(
Bound::Included(Term::from_field_u64(value_field, 100)),
Bound::Included(Term::from_field_u64(value_field, 200)),
);
let count = searcher.search(&query, &Count).unwrap();
assert_eq!(count, 0); // should return 0 results since there's no overlap
// Test another non-overlapping range: [0, 5] while data range is [10, 20]
let query2 = RangeQuery::new(
Bound::Included(Term::from_field_u64(value_field, 0)),
Bound::Included(Term::from_field_u64(value_field, 5)),
);
let count2 = searcher.search(&query2, &Count).unwrap();
assert_eq!(count2, 0); // should return 0 results since there's no overlap
}
}

View File

@@ -1598,449 +1598,3 @@ pub(crate) mod ip_range_tests {
Ok(())
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use test::Bencher;
use super::tests::*;
use super::*;
use crate::collector::Count;
use crate::query::QueryParser;
use crate::Index;
fn get_index_0_to_100() -> Index {
let mut rng = StdRng::from_seed([1u8; 32]);
let num_vals = 100_000;
let docs: Vec<_> = (0..num_vals)
.map(|_i| {
let id_name = if rng.gen_bool(0.01) {
"veryfew".to_string() // 1%
} else if rng.gen_bool(0.1) {
"few".to_string() // 9%
} else {
"many".to_string() // 90%
};
Doc {
id_name,
id: rng.gen_range(0..100),
}
})
.collect();
create_index_from_docs(&docs, false)
}
fn get_90_percent() -> RangeInclusive<u64> {
0..=90
}
fn get_10_percent() -> RangeInclusive<u64> {
0..=10
}
fn get_1_percent() -> RangeInclusive<u64> {
10..=10
}
fn execute_query(
field: &str,
id_range: RangeInclusive<u64>,
suffix: &str,
index: &Index,
) -> usize {
let gen_query_inclusive = |from: &u64, to: &u64| {
format!(
"{}:[{} TO {}] {}",
field,
&from.to_string(),
&to.to_string(),
suffix
)
};
let query = gen_query_inclusive(id_range.start(), id_range.end());
let query_from_text = |text: &str| {
QueryParser::for_index(index, vec![])
.parse_query(text)
.unwrap()
};
let query = query_from_text(&query);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
searcher.search(&query, &(Count)).unwrap()
}
#[bench]
fn bench_id_range_hit_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_90_percent(), "", &index));
}
#[bench]
fn bench_id_range_hit_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_10_percent(), "", &index));
}
#[bench]
fn bench_id_range_hit_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_1_percent(), "", &index));
}
#[bench]
fn bench_id_range_hit_10_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_10_percent(), "AND id_name:few", &index));
}
#[bench]
fn bench_id_range_hit_1_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_1_percent(), "AND id_name:few", &index));
}
#[bench]
fn bench_id_range_hit_1_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_1_percent(), "AND id_name:many", &index));
}
#[bench]
fn bench_id_range_hit_1_percent_intersect_with_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_1_percent(), "AND id_name:veryfew", &index));
}
#[bench]
fn bench_id_range_hit_10_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_10_percent(), "AND id_name:many", &index));
}
#[bench]
fn bench_id_range_hit_90_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_90_percent(), "AND id_name:many", &index));
}
#[bench]
fn bench_id_range_hit_90_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_90_percent(), "AND id_name:few", &index));
}
#[bench]
fn bench_id_range_hit_90_percent_intersect_with_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("id", get_90_percent(), "AND id_name:veryfew", &index));
}
#[bench]
fn bench_id_range_hit_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_90_percent(), "", &index));
}
#[bench]
fn bench_id_range_hit_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_10_percent(), "", &index));
}
#[bench]
fn bench_id_range_hit_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_1_percent(), "", &index));
}
#[bench]
fn bench_id_range_hit_10_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_10_percent(), "AND id_name:few", &index));
}
#[bench]
fn bench_id_range_hit_1_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_1_percent(), "AND id_name:few", &index));
}
#[bench]
fn bench_id_range_hit_1_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_1_percent(), "AND id_name:many", &index));
}
#[bench]
fn bench_id_range_hit_1_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_1_percent(), "AND id_name:veryfew", &index));
}
#[bench]
fn bench_id_range_hit_10_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_10_percent(), "AND id_name:many", &index));
}
#[bench]
fn bench_id_range_hit_90_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:many", &index));
}
#[bench]
fn bench_id_range_hit_90_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:few", &index));
}
#[bench]
fn bench_id_range_hit_90_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:veryfew", &index));
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench_ip {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use test::Bencher;
use super::ip_range_tests::*;
use super::*;
use crate::collector::Count;
use crate::query::QueryParser;
use crate::Index;
fn get_index_0_to_100() -> Index {
let mut rng = StdRng::from_seed([1u8; 32]);
let num_vals = 100_000;
let docs: Vec<_> = (0..num_vals)
.map(|_i| {
let id = if rng.gen_bool(0.01) {
"veryfew".to_string() // 1%
} else if rng.gen_bool(0.1) {
"few".to_string() // 9%
} else {
"many".to_string() // 90%
};
Doc {
id,
// Multiply by 1000, so that we create many buckets in the compact space
// The benches depend on this range to select n-percent of elements with the
// methods below.
ip: Ipv6Addr::from_u128(rng.gen_range(0..100) * 1000),
}
})
.collect();
create_index_from_ip_docs(&docs)
}
fn get_90_percent() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(90 * 1000);
start..=end
}
fn get_10_percent() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
fn get_1_percent() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(10 * 1000);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
fn execute_query(
field: &str,
ip_range: RangeInclusive<Ipv6Addr>,
suffix: &str,
index: &Index,
) -> usize {
let gen_query_inclusive = |from: &Ipv6Addr, to: &Ipv6Addr| {
format!(
"{}:[{} TO {}] {}",
field,
&from.to_string(),
&to.to_string(),
suffix
)
};
let query = gen_query_inclusive(ip_range.start(), ip_range.end());
let query_from_text = |text: &str| {
QueryParser::for_index(index, vec![])
.parse_query(text)
.unwrap()
};
let query = query_from_text(&query);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
searcher.search(&query, &(Count)).unwrap()
}
#[bench]
fn bench_ip_range_hit_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_90_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_10_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_1_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_10_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_1_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_1_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_1_percent(), "AND id:veryfew", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_10_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_90_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_90_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ip", get_90_percent(), "AND id:veryfew", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_90_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_10_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_1_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_10_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_1_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_1_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_1_percent(), "AND id:veryfew", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_10_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_90_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_90_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| execute_query("ips", get_90_percent(), "AND id:veryfew", &index));
}
}

View File

@@ -56,6 +56,11 @@ where
self.req_scorer.seek(target)
}
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
self.score_cache = None;
self.req_scorer.seek_into_the_danger_zone(target)
}
fn doc(&self) -> DocId {
self.req_scorer.doc()
}

View File

@@ -98,14 +98,17 @@ impl TermScorer {
}
impl DocSet for TermScorer {
#[inline]
fn advance(&mut self) -> DocId {
self.postings.advance()
}
#[inline]
fn seek(&mut self, target: DocId) -> DocId {
self.postings.seek(target)
}
#[inline]
fn doc(&self) -> DocId {
self.postings.doc()
}

View File

@@ -15,7 +15,7 @@ const HORIZON: u32 = 64u32 * 64u32;
// This function is similar except that it is not unstable, and
// it does not keep the original vector ordering.
//
// Also, it does not "yield" any elements.
// Elements are dropped and not yielded.
fn unordered_drain_filter<T, P>(v: &mut Vec<T>, mut predicate: P)
where P: FnMut(&mut T) -> bool {
let mut i = 0;
@@ -143,6 +143,12 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer
}
false
}
fn is_in_horizon(&self, target: DocId) -> bool {
// wrapping_sub, because target may be < window_start_doc
let gap = target.wrapping_sub(self.window_start_doc);
gap < HORIZON
}
}
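Because the subtraction wraps, a target smaller than window_start_doc produces a huge gap and simply falls outside the horizon check instead of underflowing. A minimal standalone sketch of the arithmetic (not part of the change), assuming HORIZON = 64 * 64 as defined in this file:

const HORIZON: u32 = 64 * 64;

fn in_horizon(window_start_doc: u32, target: u32) -> bool {
    // Targets before window_start_doc wrap around to a value >= HORIZON.
    target.wrapping_sub(window_start_doc) < HORIZON
}

fn main() {
    assert!(in_horizon(10_000, 12_000)); // gap = 2_000 < 4_096
    assert!(!in_horizon(10_000, 20_000)); // gap = 10_000 >= 4_096
    assert!(!in_horizon(10_000, 9_000)); // wraps to u32::MAX - 999, far above 4_096
}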
impl<TScorer, TScoreCombiner> DocSet for BufferedUnionScorer<TScorer, TScoreCombiner>
@@ -217,7 +223,27 @@ where
}
}
// TODO Also implement `count` with deletes efficiently.
fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool {
if self.is_in_horizon(target) {
// Our value is within the buffered horizon and the docset may already have been
// processed and removed, so we need to use seek, which uses the regular advance.
self.seek(target) == target
} else {
// The target is not in the buffered range, so we can use seek_into_the_danger_zone
// of the underlying docsets
let is_hit = self
.docsets
.iter_mut()
.any(|docset| docset.seek_into_the_danger_zone(target));
// The API requires the DocSet to be in a valid state when `seek_into_the_danger_zone`
// returns true.
if is_hit {
self.seek(target);
}
is_hit
}
}
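The comments above pin down the contract: a true return means the union has re-positioned itself on target via seek, so doc() can be read; on false the position is unspecified until the next seek or advance. A hedged caller-side sketch (the helper name is illustrative, assuming the DocSet trait and DocId are in scope):

use tantivy::{DocId, DocSet};

fn seek_checked<D: DocSet>(docset: &mut D, target: DocId) -> bool {
    if docset.seek_into_the_danger_zone(target) {
        // A valid state is guaranteed only when `true` is returned.
        debug_assert_eq!(docset.doc(), target);
        true
    } else {
        // Position is unspecified here; re-seek before reading doc() again.
        false
    }
}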
fn doc(&self) -> DocId {
self.doc
@@ -231,6 +257,7 @@ where
self.docsets.iter().map(|docset| docset.cost()).sum()
}
// TODO Also implement `count` with deletes efficiently.
fn count_including_deleted(&mut self) -> u32 {
if self.doc == TERMINATED {
return 0;

View File

@@ -92,6 +92,7 @@ impl<TDocSet: DocSet> DocSet for SimpleUnion<TDocSet> {
}
fn size_hint(&self) -> u32 {
// TODO: use estimate_union
self.docsets
.iter()
.map(|docset| docset.size_hint())

View File

@@ -1,10 +1,11 @@
use std::hash::{Hash, Hasher};
use std::hash::Hash;
use std::net::Ipv6Addr;
use std::{fmt, str};
use columnar::MonotonicallyMappableToU128;
use common::json_path_writer::{JSON_END_OF_PATH, JSON_PATH_SEGMENT_SEP_STR};
use common::JsonPathWriter;
use serde::{Deserialize, Serialize};
use super::date_time_options::DATE_TIME_PRECISION_INDEXED;
use super::{Field, Schema};
@@ -16,23 +17,54 @@ use crate::DateTime;
/// Term represents the value that the token can take.
/// It's a serialized representation over different types.
///
/// It actually wraps a `Vec<u8>`. The first 5 bytes are metadata.
/// 4 bytes are the field id, and the last byte is the type.
///
/// The serialized value `ValueBytes` is considered everything after the first 4 bytes (the field id).
#[derive(Clone)]
pub struct Term<B = Vec<u8>>(B)
where B: AsRef<[u8]>;
/// A term is composed of Field and the serialized value bytes.
/// The serialized value bytes themselves start with a one byte type tag followed by the payload.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
pub struct Term {
field: Field,
serialized_value_bytes: Vec<u8>,
}
/// The number of bytes used as metadata by `Term`.
const TERM_METADATA_LENGTH: usize = 5;
/// The number of bytes used as metadata when serializing a term.
const TERM_TYPE_TAG_LEN: usize = 1;
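A minimal usage sketch of the new layout, mirroring the adjusted test_term_u64 further down (the "count" field is an assumption taken from that test):

use tantivy::schema::{Schema, Type, INDEXED};
use tantivy::Term;

fn main() {
    let mut schema_builder = Schema::builder();
    let count_field = schema_builder.add_u64_field("count", INDEXED);
    let _schema = schema_builder.build();

    let term = Term::from_field_u64(count_field, 983u64);
    assert_eq!(term.field(), count_field); // the field id lives next to the bytes, not inside them
    assert_eq!(term.typ(), Type::U64); // read from the one-byte type tag
    assert_eq!(term.serialized_value_bytes().len(), 8); // payload only, tag excluded
    assert_eq!(term.value().as_u64(), Some(983u64));
}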
impl Term {
/// Takes a serialized term and wraps it as a Term.
/// First 4 bytes are the field id
#[deprecated(
note = "we want to avoid working on the serialized representation directly, replace with \
typed API calls (add more if needed) or use serde to serialize/deserialize"
)]
pub fn wrap(serialized: &[u8]) -> Term {
let field_id_bytes: [u8; 4] = serialized[0..4].try_into().unwrap();
let field_id = u32::from_be_bytes(field_id_bytes);
Term {
field: Field::from_field_id(field_id),
serialized_value_bytes: serialized[4..].to_vec(),
}
}
/// Returns the serialized representation of the term.
/// First 4 bytes are the field id
#[deprecated(
note = "we want to avoid working on the serialized representation directly, replace with \
typed API calls (add more if needed) or use serde to serialize/deserialize"
)]
pub fn serialized_term(&self) -> Vec<u8> {
let mut serialized = Vec::with_capacity(4 + self.serialized_value_bytes.len());
serialized.extend(self.field.field_id().to_be_bytes().as_ref());
serialized.extend_from_slice(&self.serialized_value_bytes);
serialized
}
/// Create a new Term with a buffer with a given capacity.
pub fn with_capacity(capacity: usize) -> Term {
let mut data = Vec::with_capacity(TERM_METADATA_LENGTH + capacity);
data.resize(TERM_METADATA_LENGTH, 0u8);
Term(data)
let mut data = Vec::with_capacity(TERM_TYPE_TAG_LEN + capacity);
data.resize(TERM_TYPE_TAG_LEN, 0u8);
Term {
field: Field::from_field_id(0u32),
serialized_value_bytes: data,
}
}
/// Creates a term from a json path.
@@ -89,7 +121,7 @@ impl Term {
fn with_bytes_and_field_and_payload(typ: Type, field: Field, bytes: &[u8]) -> Term {
let mut term = Self::with_capacity(bytes.len());
term.set_field_and_type(field, typ);
term.0.extend_from_slice(bytes);
term.serialized_value_bytes.extend_from_slice(bytes);
term
}
@@ -105,13 +137,13 @@ impl Term {
/// Sets field and the type.
pub(crate) fn set_field_and_type(&mut self, field: Field, typ: Type) {
assert!(self.is_empty());
self.0[0..4].clone_from_slice(field.field_id().to_be_bytes().as_ref());
self.0[4] = typ.to_code();
self.field = field;
self.serialized_value_bytes[0] = typ.to_code();
}
/// Is empty if there are no value bytes.
pub fn is_empty(&self) -> bool {
self.0.len() == TERM_METADATA_LENGTH
self.serialized_value_bytes.len() == TERM_TYPE_TAG_LEN
}
/// Builds a term given a field, and a `Ipv6Addr`-value
@@ -177,7 +209,7 @@ impl Term {
/// Removes the value_bytes and set the type code.
pub fn clear_with_type(&mut self, typ: Type) {
self.truncate_value_bytes(0);
self.0[4] = typ.to_code();
self.serialized_value_bytes[0] = typ.to_code();
}
/// Append a type marker + fast value to a term.
@@ -185,9 +217,10 @@ impl Term {
///
/// It will not clear existing bytes.
pub fn append_type_and_fast_value<T: FastValue>(&mut self, val: T) {
self.0.push(T::to_type().to_code());
self.serialized_value_bytes.push(T::to_type().to_code());
let value = val.to_u64();
self.0.extend(value.to_be_bytes().as_ref());
self.serialized_value_bytes
.extend(value.to_be_bytes().as_ref());
}
/// Append a string type marker + string to a term.
@@ -195,24 +228,25 @@ impl Term {
///
/// It will not clear existing bytes.
pub fn append_type_and_str(&mut self, val: &str) {
self.0.push(Type::Str.to_code());
self.0.extend(val.as_bytes().as_ref());
self.serialized_value_bytes.push(Type::Str.to_code());
self.serialized_value_bytes.extend(val.as_bytes().as_ref());
}
/// Sets the value of a `Bytes` field.
pub fn set_bytes(&mut self, bytes: &[u8]) {
self.truncate_value_bytes(0);
self.0.extend(bytes);
self.serialized_value_bytes.extend(bytes);
}
/// Truncates the value bytes of the term. Value and field type stays the same.
pub fn truncate_value_bytes(&mut self, len: usize) {
self.0.truncate(len + TERM_METADATA_LENGTH);
self.serialized_value_bytes
.truncate(len + TERM_TYPE_TAG_LEN);
}
/// The length of the bytes.
pub fn len_bytes(&self) -> usize {
self.0.len() - TERM_METADATA_LENGTH
self.serialized_value_bytes.len() - TERM_TYPE_TAG_LEN
}
/// Appends value bytes to the Term.
@@ -220,18 +254,9 @@ impl Term {
/// This function returns the segment that has just been added.
#[inline]
pub fn append_bytes(&mut self, bytes: &[u8]) -> &mut [u8] {
let len_before = self.0.len();
self.0.extend_from_slice(bytes);
&mut self.0[len_before..]
}
}
impl<B> Term<B>
where B: AsRef<[u8]>
{
/// Wraps a object holding bytes
pub fn wrap(data: B) -> Term<B> {
Term(data)
let len_before = self.serialized_value_bytes.len();
self.serialized_value_bytes.extend_from_slice(bytes);
&mut self.serialized_value_bytes[len_before..]
}
/// Return the type of the term.
@@ -241,8 +266,7 @@ where B: AsRef<[u8]>
/// Returns the field.
pub fn field(&self) -> Field {
let field_id_bytes: [u8; 4] = (&self.0.as_ref()[..4]).try_into().unwrap();
Field::from_field_id(u32::from_be_bytes(field_id_bytes))
self.field
}
/// Returns the serialized representation of the value.
@@ -252,23 +276,13 @@ where B: AsRef<[u8]>
/// If the term is a u64, its value is encoded according
/// to `byteorder::BigEndian`.
pub fn serialized_value_bytes(&self) -> &[u8] {
&self.0.as_ref()[TERM_METADATA_LENGTH..]
&self.serialized_value_bytes[TERM_TYPE_TAG_LEN..]
}
/// Returns the value of the term.
/// address or JSON path + value. (this does not include the field.)
pub fn value(&self) -> ValueBytes<&[u8]> {
ValueBytes::wrap(&self.0.as_ref()[4..])
}
/// Returns the serialized representation of Term.
/// This includes field_id, value type and value.
///
/// Do NOT rely on this byte representation in the index.
/// This value is likely to change in the future.
#[inline]
pub fn serialized_term(&self) -> &[u8] {
self.0.as_ref()
ValueBytes::wrap(self.serialized_value_bytes.as_ref())
}
}
@@ -452,10 +466,7 @@ where B: AsRef<[u8]>
}
}
/// Returns the serialized representation of Term.
///
/// Do NOT rely on this byte representation in the index.
/// This value is likely to change in the future.
/// Returns the serialized representation of the value bytes including the type tag.
pub fn as_serialized(&self) -> &[u8] {
self.0.as_ref()
}
@@ -508,40 +519,6 @@ where B: AsRef<[u8]>
}
}
impl<B> Ord for Term<B>
where B: AsRef<[u8]>
{
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.serialized_term().cmp(other.serialized_term())
}
}
impl<B> PartialOrd for Term<B>
where B: AsRef<[u8]>
{
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl<B> PartialEq for Term<B>
where B: AsRef<[u8]>
{
fn eq(&self, other: &Self) -> bool {
self.serialized_term() == other.serialized_term()
}
}
impl<B> Eq for Term<B> where B: AsRef<[u8]> {}
impl<B> Hash for Term<B>
where B: AsRef<[u8]>
{
fn hash<H: Hasher>(&self, state: &mut H) {
self.0.as_ref().hash(state)
}
}
fn write_opt<T: std::fmt::Debug>(f: &mut fmt::Formatter, val_opt: Option<T>) -> fmt::Result {
if let Some(val) = val_opt {
write!(f, "{val:?}")?;
@@ -549,13 +526,11 @@ fn write_opt<T: std::fmt::Debug>(f: &mut fmt::Formatter, val_opt: Option<T>) ->
Ok(())
}
impl<B> fmt::Debug for Term<B>
where B: AsRef<[u8]>
{
impl fmt::Debug for Term {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let field_id = self.field().field_id();
let field_id = self.field.field_id();
write!(f, "Term(field={field_id}, ")?;
let value_bytes = ValueBytes::wrap(&self.0.as_ref()[4..]);
let value_bytes = ValueBytes::wrap(&self.serialized_value_bytes);
value_bytes.debug_value_bytes(f)?;
write!(f, ")",)?;
Ok(())
@@ -578,17 +553,6 @@ mod tests {
assert_eq!(term.value().as_str(), Some("test"))
}
/// Size (in bytes) of the buffer of a fast value (u64, i64, f64, or date) term.
/// <field> + <type byte> + <value len>
///
/// - <field> is a big endian encoded u32 field id
/// - <type_byte>'s most significant bit expresses whether the term is a json term or not. The
/// remaining 7 bits are used to encode the type of the value. If this is a JSON term, the
/// type is the type of the leaf of the json.
/// - <value> is, if this is not the json term, a binary representation specific to the type.
/// If it is a JSON Term, then it is prepended with the path that leads to this leaf value.
const FAST_VALUE_TERM_LEN: usize = 4 + 1 + 8;
#[test]
pub fn test_term_u64() {
let mut schema_builder = Schema::builder();
@@ -596,7 +560,7 @@ mod tests {
let term = Term::from_field_u64(count_field, 983u64);
assert_eq!(term.field(), count_field);
assert_eq!(term.typ(), Type::U64);
assert_eq!(term.serialized_term().len(), FAST_VALUE_TERM_LEN);
assert_eq!(term.serialized_value_bytes().len(), 8);
assert_eq!(term.value().as_u64(), Some(983u64))
}
@@ -607,7 +571,7 @@ mod tests {
let term = Term::from_field_bool(bool_field, true);
assert_eq!(term.field(), bool_field);
assert_eq!(term.typ(), Type::Bool);
assert_eq!(term.serialized_term().len(), FAST_VALUE_TERM_LEN);
assert_eq!(term.serialized_value_bytes().len(), 8);
assert_eq!(term.value().as_bool(), Some(true))
}
}

View File

@@ -7,13 +7,14 @@
//! storage-level details into consideration. For example, if your file system block size is 4096
//! bytes, we can under-count actual resultant space usage by up to 4095 bytes per file.
use std::collections::HashMap;
use std::collections::btree_map::Entry;
use std::collections::BTreeMap;
use columnar::ColumnSpaceUsage;
use common::ByteCount;
use serde::{Deserialize, Serialize};
use crate::index::SegmentComponent;
use crate::schema::Field;
/// Enum containing any of the possible space usage results for segment components.
pub enum ComponentSpaceUsage {
@@ -212,17 +213,26 @@ impl StoreSpaceUsage {
/// Multiple indexes are used to handle variable length things, where
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PerFieldSpaceUsage {
fields: HashMap<Field, FieldUsage>,
fields: BTreeMap<String, FieldUsage>,
total: ByteCount,
}
impl PerFieldSpaceUsage {
pub(crate) fn new(fields: Vec<FieldUsage>) -> PerFieldSpaceUsage {
let total = fields.iter().map(FieldUsage::total).sum();
let field_usage_map: HashMap<Field, FieldUsage> = fields
.into_iter()
.map(|field_usage| (field_usage.field(), field_usage))
.collect();
let mut total = ByteCount::default();
let mut field_usage_map: BTreeMap<String, FieldUsage> = BTreeMap::new();
for field_usage in fields {
total += field_usage.total();
let field_name = field_usage.field_name().to_string();
match field_usage_map.entry(field_name) {
Entry::Vacant(entry) => {
entry.insert(field_usage);
}
Entry::Occupied(mut entry) => {
entry.get_mut().merge(field_usage);
}
}
}
PerFieldSpaceUsage {
fields: field_usage_map,
total,
@@ -230,8 +240,8 @@ impl PerFieldSpaceUsage {
}
/// Per field space usage
pub fn fields(&self) -> impl Iterator<Item = (&Field, &FieldUsage)> {
self.fields.iter()
pub fn fields(&self) -> impl Iterator<Item = &FieldUsage> {
self.fields.values()
}
/// Bytes used by the represented file
@@ -246,20 +256,23 @@ impl PerFieldSpaceUsage {
/// See documentation for [`PerFieldSpaceUsage`] for slightly more information.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FieldUsage {
field: Field,
field_name: String,
num_bytes: ByteCount,
/// A field can be composed of more than one piece.
/// These pieces are indexed by arbitrary numbers starting at zero.
/// `self.num_bytes` includes all of `self.sub_num_bytes`.
sub_num_bytes: Vec<Option<ByteCount>>,
/// Space usage of the column for fast fields, if relevant.
column_space_usage: Option<ColumnSpaceUsage>,
}
impl FieldUsage {
pub(crate) fn empty(field: Field) -> FieldUsage {
pub(crate) fn empty(field_name: impl Into<String>) -> FieldUsage {
FieldUsage {
field,
field_name: field_name.into(),
num_bytes: Default::default(),
sub_num_bytes: Vec::new(),
column_space_usage: None,
}
}
@@ -272,9 +285,14 @@ impl FieldUsage {
self.num_bytes += size
}
pub(crate) fn set_column_usage(&mut self, column_space_usage: ColumnSpaceUsage) {
self.num_bytes += column_space_usage.total_num_bytes();
self.column_space_usage = Some(column_space_usage);
}
/// Field
pub fn field(&self) -> Field {
self.field
pub fn field_name(&self) -> &str {
&self.field_name
}
/// Space usage for each index
@@ -282,16 +300,64 @@ impl FieldUsage {
&self.sub_num_bytes[..]
}
/// Returns the number of bytes used by the column payload, if the field is columnar.
pub fn column_num_bytes(&self) -> Option<ByteCount> {
self.column_space_usage
.as_ref()
.map(ColumnSpaceUsage::column_num_bytes)
}
/// Returns the number of bytes used by the dictionary for dictionary-encoded columns.
pub fn dictionary_num_bytes(&self) -> Option<ByteCount> {
self.column_space_usage
.as_ref()
.and_then(ColumnSpaceUsage::dictionary_num_bytes)
}
/// Returns the space usage of the column, if any.
pub fn column_space_usage(&self) -> Option<&ColumnSpaceUsage> {
self.column_space_usage.as_ref()
}
/// Total bytes used for this field in this context
pub fn total(&self) -> ByteCount {
self.num_bytes
}
fn merge(&mut self, other: FieldUsage) {
assert_eq!(self.field_name, other.field_name);
self.num_bytes += other.num_bytes;
if other.sub_num_bytes.len() > self.sub_num_bytes.len() {
self.sub_num_bytes.resize(other.sub_num_bytes.len(), None);
}
for (idx, num_bytes_opt) in other.sub_num_bytes.into_iter().enumerate() {
if let Some(num_bytes) = num_bytes_opt {
match self.sub_num_bytes[idx] {
Some(existing) => self.sub_num_bytes[idx] = Some(existing + num_bytes),
None => self.sub_num_bytes[idx] = Some(num_bytes),
}
}
}
self.column_space_usage =
merge_column_space_usage(self.column_space_usage.take(), other.column_space_usage);
}
}
fn merge_column_space_usage(
left: Option<ColumnSpaceUsage>,
right: Option<ColumnSpaceUsage>,
) -> Option<ColumnSpaceUsage> {
match (left, right) {
(Some(lhs), Some(rhs)) => Some(lhs.merge(&rhs)),
(Some(space), None) | (None, Some(space)) => Some(space),
(None, None) => None,
}
}
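To make the sub_num_bytes merging concrete, a simplified standalone sketch that uses plain u64 counts in place of ByteCount (an assumption made for brevity):

fn merge_sub(mut a: Vec<Option<u64>>, b: Vec<Option<u64>>) -> Vec<Option<u64>> {
    // Grow to the longer of the two, then add entry-wise; None stays None
    // unless the other side contributes a value.
    if b.len() > a.len() {
        a.resize(b.len(), None);
    }
    for (idx, extra) in b.into_iter().enumerate() {
        if let Some(extra) = extra {
            a[idx] = Some(a[idx].unwrap_or(0) + extra);
        }
    }
    a
}

fn main() {
    let merged = merge_sub(vec![Some(2), None], vec![Some(1), Some(3), Some(4)]);
    assert_eq!(merged, vec![Some(3), Some(3), Some(4)]);
}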
#[cfg(test)]
mod test {
use crate::index::Index;
use crate::schema::{Field, Schema, FAST, INDEXED, STORED, TEXT};
use crate::schema::{Schema, FAST, INDEXED, STORED, TEXT};
use crate::space_usage::PerFieldSpaceUsage;
use crate::{IndexWriter, Term};
@@ -307,17 +373,17 @@ mod test {
fn expect_single_field(
field_space: &PerFieldSpaceUsage,
field: &Field,
field: &str,
min_size: u64,
max_size: u64,
) {
assert!(field_space.total() >= min_size);
assert!(field_space.total() <= max_size);
assert_eq!(
vec![(field, field_space.total())],
vec![(field.to_string(), field_space.total())],
field_space
.fields()
.map(|(x, y)| (x, y.total()))
.map(|usage| (usage.field_name().to_string(), usage.total()))
.collect::<Vec<_>>()
);
}
@@ -327,6 +393,7 @@ mod test {
let mut schema_builder = Schema::builder();
let name = schema_builder.add_u64_field("name", FAST | INDEXED);
let schema = schema_builder.build();
let field_name = schema.get_field_name(name).to_string();
let index = Index::create_in_ram(schema);
{
@@ -349,11 +416,11 @@ mod test {
assert_eq!(4, segment.num_docs());
expect_single_field(segment.termdict(), &name, 1, 512);
expect_single_field(segment.postings(), &name, 1, 512);
expect_single_field(segment.termdict(), &field_name, 1, 512);
expect_single_field(segment.postings(), &field_name, 1, 512);
assert_eq!(segment.positions().total(), 0);
expect_single_field(segment.fast_fields(), &name, 1, 512);
expect_single_field(segment.fieldnorms(), &name, 1, 512);
expect_single_field(segment.fast_fields(), &field_name, 1, 512);
expect_single_field(segment.fieldnorms(), &field_name, 1, 512);
// TODO: understand why the following fails
// assert_eq!(0, segment.store().total());
assert_eq!(segment.deletes(), 0);
@@ -365,6 +432,7 @@ mod test {
let mut schema_builder = Schema::builder();
let name = schema_builder.add_text_field("name", TEXT);
let schema = schema_builder.build();
let field_name = schema.get_field_name(name).to_string();
let index = Index::create_in_ram(schema);
{
@@ -389,11 +457,11 @@ mod test {
assert_eq!(4, segment.num_docs());
expect_single_field(segment.termdict(), &name, 1, 512);
expect_single_field(segment.postings(), &name, 1, 512);
expect_single_field(segment.positions(), &name, 1, 512);
expect_single_field(segment.termdict(), &field_name, 1, 512);
expect_single_field(segment.postings(), &field_name, 1, 512);
expect_single_field(segment.positions(), &field_name, 1, 512);
assert_eq!(segment.fast_fields().total(), 0);
expect_single_field(segment.fieldnorms(), &name, 1, 512);
expect_single_field(segment.fieldnorms(), &field_name, 1, 512);
// TODO: understand why the following fails
// assert_eq!(0, segment.store().total());
assert_eq!(segment.deletes(), 0);
@@ -429,10 +497,15 @@ mod test {
assert_eq!(4, segment.num_docs());
assert_eq!(segment.termdict().total(), 0);
assert!(segment.termdict().fields().next().is_none());
assert_eq!(segment.postings().total(), 0);
assert!(segment.postings().fields().next().is_none());
assert_eq!(segment.positions().total(), 0);
assert!(segment.positions().fields().next().is_none());
assert_eq!(segment.fast_fields().total(), 0);
assert!(segment.fast_fields().fields().next().is_none());
assert_eq!(segment.fieldnorms().total(), 0);
assert!(segment.fieldnorms().fields().next().is_none());
assert!(segment.store().total() > 0);
assert!(segment.store().total() < 512);
assert_eq!(segment.deletes(), 0);
@@ -444,6 +517,7 @@ mod test {
let mut schema_builder = Schema::builder();
let name = schema_builder.add_u64_field("name", INDEXED);
let schema = schema_builder.build();
let field_name = schema.get_field_name(name).to_string();
let index = Index::create_in_ram(schema);
{
@@ -474,11 +548,11 @@ mod test {
assert_eq!(2, segment_space_usage.num_docs());
expect_single_field(segment_space_usage.termdict(), &name, 1, 512);
expect_single_field(segment_space_usage.postings(), &name, 1, 512);
expect_single_field(segment_space_usage.termdict(), &field_name, 1, 512);
expect_single_field(segment_space_usage.postings(), &field_name, 1, 512);
assert_eq!(segment_space_usage.positions().total(), 0u64);
assert_eq!(segment_space_usage.fast_fields().total(), 0u64);
expect_single_field(segment_space_usage.fieldnorms(), &name, 1, 512);
expect_single_field(segment_space_usage.fieldnorms(), &field_name, 1, 512);
assert!(segment_space_usage.deletes() > 0);
Ok(())
}

View File

@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
use super::{Token, TokenFilter, TokenStream, Tokenizer};
/// Available stemmer languages.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Copy, Clone)]
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Copy, Clone, Hash)]
#[allow(missing_docs)]
pub enum Language {
Arabic,

View File

@@ -8,7 +8,7 @@ use std::sync::Arc;
use common::bounds::{TransformBound, transform_bound_inner_res};
use common::file_slice::FileSlice;
use common::{BinarySerializable, OwnedBytes};
use common::{BinarySerializable, ByteCount, OwnedBytes};
use futures_util::{StreamExt, TryStreamExt, stream};
use itertools::Itertools;
use tantivy_fst::Automaton;
@@ -43,6 +43,7 @@ use crate::{
pub struct Dictionary<TSSTable: SSTable = VoidSSTable> {
pub sstable_slice: FileSlice,
pub sstable_index: SSTableIndex,
num_bytes: ByteCount,
num_terms: u64,
phantom_data: PhantomData<TSSTable>,
}
@@ -278,6 +279,7 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
/// Opens a `TermDictionary`.
pub fn open(term_dictionary_file: FileSlice) -> io::Result<Self> {
let num_bytes = term_dictionary_file.num_bytes();
let (main_slice, footer_len_slice) = term_dictionary_file.split_from_end(20);
let mut footer_len_bytes: OwnedBytes = footer_len_slice.read_bytes()?;
let index_offset = u64::deserialize(&mut footer_len_bytes)?;
@@ -317,6 +319,7 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
Ok(Dictionary {
sstable_slice,
sstable_index,
num_bytes,
num_terms,
phantom_data: PhantomData,
})
@@ -343,6 +346,11 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
self.num_terms as usize
}
/// Returns the total number of bytes used by the dictionary on disk.
pub fn num_bytes(&self) -> ByteCount {
self.num_bytes
}
/// Decode a DeltaReader up to key, returning the number of terms traversed
///
/// If the key was not found, returns Ok(None).

View File

@@ -11,7 +11,6 @@ description = "term hashmap used for indexing"
murmurhash32 = "0.3"
common = { version = "0.10", path = "../common/", package = "tantivy-common" }
ahash = { version = "0.8.11", default-features = false, optional = true }
rand_distr = "0.4.3"
[[bench]]
@@ -29,6 +28,7 @@ zipf = "7.0.0"
rustc-hash = "2.1.0"
proptest = "1.2.0"
binggan = { version = "0.14.0" }
rand_distr = "0.4.3"
[features]
compare_hash_only = ["ahash"] # Compare hash only, not the key in the Hashmap

View File

@@ -5,7 +5,7 @@ use common::serialize_vint_u32;
use crate::fastcpy::fast_short_slice_copy;
use crate::{Addr, MemoryArena};
const FIRST_BLOCK_NUM: u16 = 2;
const FIRST_BLOCK_NUM: u32 = 2;
/// An exponential unrolled link.
///
@@ -33,8 +33,8 @@ pub struct ExpUnrolledLinkedList {
// u16, since the max size of each block is (1<<next_cap_pow_2)
// Limited to 15, so we don't overflow remaining_cap.
remaining_cap: u16,
// To get the current number of blocks: block_num - FIRST_BLOCK_NUM
block_num: u16,
// Tracks the number of blocks allocated: block_num - FIRST_BLOCK_NUM
block_num: u32,
head: Addr,
tail: Addr,
}
@@ -110,16 +110,27 @@ impl ExpUnrolledLinkedListWriter<'_> {
}
}
// The block size is 2^block_num + 2, but max 2^15= 32k
// Initial size is 8, for the first block => block_num == 1
// The block size is 2^block_num, but max 2^15 = 32KB
// Initial size is 8 bytes (2^3), for the first block => block_num == 2
// Block size caps at 32KB (2^15) regardless of how high block_num goes
#[inline]
fn get_block_size(block_num: u16) -> u16 {
1 << block_num.min(15)
fn get_block_size(block_num: u32) -> u16 {
// Cap at 15 to prevent block sizes > 32KB
// block_num can now be much larger than 15, but block size maxes out
let exp = block_num.min(15) as u32;
(1u32 << exp) as u16
}
impl ExpUnrolledLinkedList {
#[inline(always)]
pub fn increment_num_blocks(&mut self) {
self.block_num += 1;
// Add overflow check as a safety measure
// With u32, we can handle up to ~4 billion blocks before overflow
// At 32KB per block (max size), that's 128 TB of data
self.block_num = self
.block_num
.checked_add(1)
.expect("ExpUnrolledLinkedList block count overflow - exceeded 4 billion blocks");
}
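The headroom mentioned in the comment can be checked with a quick standalone calculation (a sketch using the 32 KiB cap from get_block_size):

fn main() {
    let max_blocks = u32::MAX as u64; // ~4.29 billion blocks before the checked_add fires
    let max_block_bytes = 1u64 << 15; // 32 KiB cap
    let total_bytes = max_blocks * max_block_bytes;
    // 2^32 blocks * 2^15 bytes ≈ 2^47 bytes ≈ 128 TiB.
    assert!(total_bytes > 140_000_000_000_000);
}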
#[inline]
@@ -132,9 +143,26 @@ impl ExpUnrolledLinkedList {
if addr.is_null() {
return;
}
let last_block_len = get_block_size(self.block_num) as usize - self.remaining_cap as usize;
// Full Blocks
// Calculate last block length with bounds checking to prevent underflow
let block_size = get_block_size(self.block_num) as usize;
let last_block_len = block_size.saturating_sub(self.remaining_cap as usize);
// Safety check: if remaining_cap > block_size, the metadata is corrupted
assert!(
self.remaining_cap as usize <= block_size,
"ExpUnrolledLinkedList metadata corruption detected: remaining_cap ({}) > block_size \
({}). This indicates a serious bug, please report! (block_num={}, head={:?}, \
tail={:?})",
self.remaining_cap,
block_size,
self.block_num,
self.head,
self.tail
);
// Full Blocks (iterate through all blocks except the last one)
// Note: Blocks are numbered starting from FIRST_BLOCK_NUM+1 (=3) after first allocation
for block_num in FIRST_BLOCK_NUM + 1..self.block_num {
let cap = get_block_size(block_num) as usize;
let data = arena.slice(addr, cap);
@@ -259,6 +287,177 @@ mod tests {
assert_eq!(&vec1[..], &res1[..]);
assert_eq!(&vec2[..], &res2[..]);
}
// Tests for u32 block_num fix (issue with large arrays)
#[test]
fn test_block_num_exceeds_u16_max() {
// Test that we can handle more than 65,535 blocks (old u16 limit)
let mut eull = ExpUnrolledLinkedList::default();
// Simulate allocating 70,000 blocks (exceeds u16::MAX of 65,535)
for _ in 0..70_000 {
eull.increment_num_blocks();
}
// Verify block_num is correct
assert_eq!(eull.block_num, FIRST_BLOCK_NUM + 70_000);
// Verify we can still get block size (should be capped at 32KB)
let block_size = get_block_size(eull.block_num);
assert_eq!(block_size, 1 << 15); // 32KB max
}
#[test]
fn test_large_dataset_simulation() {
// Simulate the scenario: large arrays requiring many blocks
// We write enough data to require thousands of blocks
let mut arena = MemoryArena::default();
let mut eull = ExpUnrolledLinkedList::default();
// Write 100 MB of data (this will require ~3,200 blocks at 32KB each)
// This is enough to validate the system works with large datasets
// but not so much that the test is slow
let bytes_per_write = 10_000;
let num_writes = 10_000; // 10k * 10k = 100 MB
let data: Vec<u8> = (0..bytes_per_write).map(|i| (i % 256) as u8).collect();
for _ in 0..num_writes {
eull.writer(&mut arena).extend_from_slice(&data);
}
// Verify we allocated many blocks (should be in the thousands)
assert!(
eull.block_num > 1000,
"block_num ({}) should be > 1000 for this much data",
eull.block_num
);
// Verify we can read back correctly
let mut buffer = Vec::new();
eull.read_to_end(&arena, &mut buffer);
assert_eq!(buffer.len(), bytes_per_write * num_writes);
// Verify data integrity on a sample
for i in 0..bytes_per_write {
assert_eq!(buffer[i], (i % 256) as u8);
}
}
#[test]
fn test_get_block_size_with_large_block_num() {
// Test that get_block_size handles large u32 values correctly
// Small block numbers (under 15)
assert_eq!(get_block_size(2), 4); // 2^2 = 4
assert_eq!(get_block_size(3), 8); // 2^3 = 8
assert_eq!(get_block_size(10), 1024); // 2^10 = 1KB
// At the cap (15)
assert_eq!(get_block_size(15), 32768); // 2^15 = 32KB
// Beyond the cap (should stay at 32KB)
assert_eq!(get_block_size(16), 32768);
assert_eq!(get_block_size(100), 32768);
assert_eq!(get_block_size(65_536), 32768); // Old u16::MAX + 1
assert_eq!(get_block_size(100_000), 32768);
assert_eq!(get_block_size(1_000_000), 32768);
}
#[test]
fn test_increment_blocks_near_u16_boundary() {
// Test incrementing around the old u16::MAX boundary
let mut eull = ExpUnrolledLinkedList::default();
// Set to just before old limit
for _ in 0..65_533 {
eull.increment_num_blocks();
}
assert_eq!(eull.block_num, FIRST_BLOCK_NUM + 65_533);
// Cross the old u16::MAX boundary (this would have overflowed before)
eull.increment_num_blocks(); // 65,534
eull.increment_num_blocks(); // 65,535 (old max)
eull.increment_num_blocks(); // 65,536 (would overflow u16)
eull.increment_num_blocks(); // 65,537
// Verify we're past the old limit
assert_eq!(eull.block_num, FIRST_BLOCK_NUM + 65_537);
}
#[test]
fn test_write_and_read_with_many_blocks() {
// Test that write/read works correctly with many blocks
let mut arena = MemoryArena::default();
let mut eull = ExpUnrolledLinkedList::default();
// Write data that will span many blocks
let test_data: Vec<u8> = (0..50_000).map(|i| (i % 256) as u8).collect();
eull.writer(&mut arena).extend_from_slice(&test_data);
// Read it back
let mut buffer = Vec::new();
eull.read_to_end(&arena, &mut buffer);
// Verify data integrity
assert_eq!(buffer.len(), test_data.len());
assert_eq!(&buffer[..], &test_data[..]);
}
#[test]
fn test_multiple_eull_with_large_block_counts() {
// Test multiple ExpUnrolledLinkedLists with high block counts
// (simulates parallel columnar writes)
let mut arena = MemoryArena::default();
let mut eull1 = ExpUnrolledLinkedList::default();
let mut eull2 = ExpUnrolledLinkedList::default();
// Write different data to each
for i in 0..10_000u32 {
eull1.writer(&mut arena).write_u32_vint(i);
eull2.writer(&mut arena).write_u32_vint(i * 2);
}
// Read back and verify
let mut buf1 = Vec::new();
let mut buf2 = Vec::new();
eull1.read_to_end(&arena, &mut buf1);
eull2.read_to_end(&arena, &mut buf2);
// Deserialize and check
let mut cursor1 = &buf1[..];
let mut cursor2 = &buf2[..];
for i in 0..10_000u32 {
assert_eq!(read_u32_vint(&mut cursor1), i);
assert_eq!(read_u32_vint(&mut cursor2), i * 2);
}
}
#[test]
fn test_block_size_stays_capped() {
// Verify that even with massive block numbers, size stays at 32KB
let mut eull = ExpUnrolledLinkedList::default();
// Increment to a very large number
for _ in 0..200_000 {
eull.increment_num_blocks();
}
let block_size = get_block_size(eull.block_num);
assert_eq!(block_size, 32768, "Block size should be capped at 32KB");
}
#[test]
#[should_panic(expected = "ExpUnrolledLinkedList block count overflow")]
fn test_increment_overflow_protection() {
// Test that we panic gracefully if we somehow hit u32::MAX
// This is extremely unlikely in practice (would require 128TB of data)
let mut eull = ExpUnrolledLinkedList::default();
eull.block_num = u32::MAX;
// This should panic with our custom error message
eull.increment_num_blocks();
}
}
#[cfg(all(test, feature = "unstable"))]