Compare commits


1 Commit

Author SHA1 Message Date
Pascal Seitz   e2dae2f433   fix out of order bug   2024-06-25 08:35:58 +08:00
54 changed files with 1221 additions and 2513 deletions

View File

@@ -15,11 +15,11 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Install Rust - name: Install Rust
run: rustup toolchain install nightly-2024-07-01 --profile minimal --component llvm-tools-preview run: rustup toolchain install nightly-2024-04-10 --profile minimal --component llvm-tools-preview
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@cargo-llvm-cov - uses: taiki-e/install-action@cargo-llvm-cov
- name: Generate code coverage - name: Generate code coverage
run: cargo +nightly-2024-07-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info run: cargo +nightly-2024-04-10 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
- name: Upload coverage to Codecov - name: Upload coverage to Codecov
uses: codecov/codecov-action@v3 uses: codecov/codecov-action@v3
continue-on-error: true continue-on-error: true

View File

@@ -11,7 +11,7 @@ repository = "https://github.com/quickwit-oss/tantivy"
readme = "README.md" readme = "README.md"
keywords = ["search", "information", "retrieval"] keywords = ["search", "information", "retrieval"]
edition = "2021" edition = "2021"
rust-version = "1.66" rust-version = "1.63"
exclude = ["benches/*.json", "benches/*.txt"] exclude = ["benches/*.json", "benches/*.txt"]
[dependencies] [dependencies]
@@ -38,7 +38,7 @@ levenshtein_automata = "0.2.1"
uuid = { version = "1.0.0", features = ["v4", "serde"] } uuid = { version = "1.0.0", features = ["v4", "serde"] }
crossbeam-channel = "0.5.4" crossbeam-channel = "0.5.4"
rust-stemmers = "1.2.0" rust-stemmers = "1.2.0"
downcast-rs = "1.2.1" downcast-rs = "1.2.0"
bitpacking = { version = "0.9.2", default-features = false, features = [ bitpacking = { version = "0.9.2", default-features = false, features = [
"bitpacker4x", "bitpacker4x",
] } ] }
@@ -64,7 +64,6 @@ tantivy-bitpacker = { version = "0.6", path = "./bitpacker" }
common = { version = "0.7", path = "./common/", package = "tantivy-common" } common = { version = "0.7", path = "./common/", package = "tantivy-common" }
tokenizer-api = { version = "0.3", path = "./tokenizer-api", package = "tantivy-tokenizer-api" } tokenizer-api = { version = "0.3", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
sketches-ddsketch = { version = "0.3.0", features = ["use_serde"] } sketches-ddsketch = { version = "0.3.0", features = ["use_serde"] }
hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
futures-util = { version = "0.3.28", optional = true } futures-util = { version = "0.3.28", optional = true }
fnv = "1.0.7" fnv = "1.0.7"

View File

@@ -18,7 +18,7 @@ Tantivy is, in fact, strongly inspired by Lucene's design.
## Benchmark ## Benchmark
The following [benchmark](https://tantivy-search.github.io/bench/) breaks down the The following [benchmark](https://tantivy-search.github.io/bench/) breakdowns
performance for different types of queries/collections. performance for different types of queries/collections.
Your mileage WILL vary depending on the nature of queries and their load. Your mileage WILL vary depending on the nature of queries and their load.

View File

@@ -51,15 +51,10 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, percentiles_f64); register!(group, percentiles_f64);
register!(group, terms_few); register!(group, terms_few);
register!(group, terms_many); register!(group, terms_many);
register!(group, terms_many_top_1000);
register!(group, terms_many_order_by_term); register!(group, terms_many_order_by_term);
register!(group, terms_many_with_top_hits); register!(group, terms_many_with_top_hits);
register!(group, terms_many_with_avg_sub_agg); register!(group, terms_many_with_avg_sub_agg);
register!(group, terms_many_json_mixed_type_with_avg_sub_agg); register!(group, terms_many_json_mixed_type_with_sub_agg_card);
register!(group, cardinality_agg);
register!(group, terms_few_with_cardinality_agg);
register!(group, range_agg); register!(group, range_agg);
register!(group, range_agg_with_avg_sub_agg); register!(group, range_agg_with_avg_sub_agg);
register!(group, range_agg_with_term_agg_few); register!(group, range_agg_with_term_agg_few);
@@ -128,33 +123,6 @@ fn percentiles_f64(index: &Index) {
}); });
execute_agg(index, agg_req); execute_agg(index, agg_req);
} }
fn cardinality_agg(index: &Index) {
let agg_req = json!({
"cardinality": {
"cardinality": {
"field": "text_many_terms"
},
}
});
execute_agg(index, agg_req);
}
fn terms_few_with_cardinality_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms" },
"aggs": {
"cardinality": {
"cardinality": {
"field": "text_many_terms"
},
}
}
},
});
execute_agg(index, agg_req);
}
fn terms_few(index: &Index) { fn terms_few(index: &Index) {
let agg_req = json!({ let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms" } }, "my_texts": { "terms": { "field": "text_few_terms" } },
@@ -167,12 +135,6 @@ fn terms_many(index: &Index) {
}); });
execute_agg(index, agg_req); execute_agg(index, agg_req);
} }
fn terms_many_top_1000(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_many_terms", "size": 1000 } },
});
execute_agg(index, agg_req);
}
fn terms_many_order_by_term(index: &Index) { fn terms_many_order_by_term(index: &Index) {
let agg_req = json!({ let agg_req = json!({
"my_texts": { "terms": { "field": "text_many_terms", "order": { "_key": "desc" } } }, "my_texts": { "terms": { "field": "text_many_terms", "order": { "_key": "desc" } } },
@@ -209,7 +171,7 @@ fn terms_many_with_avg_sub_agg(index: &Index) {
}); });
execute_agg(index, agg_req); execute_agg(index, agg_req);
} }
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) { fn terms_many_json_mixed_type_with_sub_agg_card(index: &Index) {
let agg_req = json!({ let agg_req = json!({
"my_texts": { "my_texts": {
"terms": { "field": "json.mixed_type" }, "terms": { "field": "json.mixed_type" },
@@ -306,7 +268,6 @@ fn range_agg_with_term_agg_many(index: &Index) {
}); });
execute_agg(index, agg_req); execute_agg(index, agg_req);
} }
fn histogram(index: &Index) { fn histogram(index: &Index) {
let agg_req = json!({ let agg_req = json!({
"rangef64": { "rangef64": {

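All of the bench functions above funnel their JSON requests through `execute_agg`, which is not part of this hunk. A minimal sketch of what executing such a request looks like through tantivy's public aggregation API (the helper name and the use of default limits are assumptions, not taken from the bench):

```rust
use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::AllQuery;
use tantivy::Index;

// Hypothetical helper mirroring what `execute_agg` presumably does in the bench.
fn execute_agg(index: &Index, agg_req_json: serde_json::Value) -> tantivy::Result<serde_json::Value> {
    // Deserialize the JSON request into the typed aggregation request.
    let agg_req: Aggregations = serde_json::from_value(agg_req_json).unwrap();
    // Default aggregation limits; a real bench may tune memory/bucket limits.
    let collector = AggregationCollector::from_aggs(agg_req, Default::default());
    let searcher = index.reader()?.searcher();
    // Run the aggregation over all documents and return the JSON result.
    let agg_res = searcher.search(&AllQuery, &collector)?;
    Ok(serde_json::to_value(agg_res).unwrap())
}
```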
View File

@@ -34,7 +34,6 @@ fn compute_stats(vals: impl Iterator<Item = u64>) -> ColumnStats {
fn value_iter() -> impl Iterator<Item = u64> { fn value_iter() -> impl Iterator<Item = u64> {
0..20_000 0..20_000
} }
fn get_reader_for_bench<Codec: ColumnCodec>(data: &[u64]) -> Codec::ColumnValues { fn get_reader_for_bench<Codec: ColumnCodec>(data: &[u64]) -> Codec::ColumnValues {
let mut bytes = Vec::new(); let mut bytes = Vec::new();
let stats = compute_stats(data.iter().cloned()); let stats = compute_stats(data.iter().cloned());
@@ -42,13 +41,10 @@ fn get_reader_for_bench<Codec: ColumnCodec>(data: &[u64]) -> Codec::ColumnValues
for val in data { for val in data {
codec_serializer.collect(*val); codec_serializer.collect(*val);
} }
codec_serializer codec_serializer.serialize(&stats, Box::new(data.iter().copied()).as_mut(), &mut bytes);
.serialize(&stats, Box::new(data.iter().copied()).as_mut(), &mut bytes)
.unwrap();
Codec::load(OwnedBytes::new(bytes)).unwrap() Codec::load(OwnedBytes::new(bytes)).unwrap()
} }
fn bench_get<Codec: ColumnCodec>(b: &mut Bencher, data: &[u64]) { fn bench_get<Codec: ColumnCodec>(b: &mut Bencher, data: &[u64]) {
let col = get_reader_for_bench::<Codec>(data); let col = get_reader_for_bench::<Codec>(data);
b.iter(|| { b.iter(|| {

View File

@@ -8,7 +8,6 @@ use std::net::Ipv6Addr;
use column_operation::ColumnOperation; use column_operation::ColumnOperation;
pub(crate) use column_writers::CompatibleNumericalTypes; pub(crate) use column_writers::CompatibleNumericalTypes;
use common::json_path_writer::JSON_END_OF_PATH;
use common::CountingWriter; use common::CountingWriter;
pub(crate) use serializer::ColumnarSerializer; pub(crate) use serializer::ColumnarSerializer;
use stacker::{Addr, ArenaHashMap, MemoryArena}; use stacker::{Addr, ArenaHashMap, MemoryArena};
@@ -284,17 +283,12 @@ impl ColumnarWriter {
.iter() .iter()
.map(|(column_name, addr)| (column_name, ColumnType::DateTime, addr)), .map(|(column_name, addr)| (column_name, ColumnType::DateTime, addr)),
); );
// TODO: replace JSON_END_OF_PATH with b'0' in columns
columns.sort_unstable_by_key(|(column_name, col_type, _)| (*column_name, *col_type)); columns.sort_unstable_by_key(|(column_name, col_type, _)| (*column_name, *col_type));
let (arena, buffers, dictionaries) = (&self.arena, &mut self.buffers, &self.dictionaries); let (arena, buffers, dictionaries) = (&self.arena, &mut self.buffers, &self.dictionaries);
let mut symbol_byte_buffer: Vec<u8> = Vec::new(); let mut symbol_byte_buffer: Vec<u8> = Vec::new();
for (column_name, column_type, addr) in columns { for (column_name, column_type, addr) in columns {
if column_name.contains(&JSON_END_OF_PATH) {
// Tantivy uses b'0' as a separator for nested fields in JSON.
// Column names with a b'0' are not simply ignored by the columnar (and the inverted
// index).
continue;
}
match column_type { match column_type {
ColumnType::Bool => { ColumnType::Bool => {
let column_writer: ColumnWriter = self.bool_field_hash_map.read(addr); let column_writer: ColumnWriter = self.bool_field_hash_map.read(addr);

View File

@@ -93,3 +93,18 @@ impl<'a, W: io::Write> io::Write for ColumnSerializer<'a, W> {
self.columnar_serializer.wrt.write_all(buf) self.columnar_serializer.wrt.write_all(buf)
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_prepare_key_bytes() {
let mut buffer: Vec<u8> = b"somegarbage".to_vec();
prepare_key(b"root\0child", ColumnType::Str, &mut buffer);
assert_eq!(buffer.len(), 12);
assert_eq!(&buffer[..10], b"root0child");
assert_eq!(buffer[10], 0u8);
assert_eq!(buffer[11], ColumnType::Str.to_code());
}
}
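The new test pins down the column key layout produced by `prepare_key`: the column name (with the JSON end-of-path byte 0x00 rewritten to an ASCII '0'), a single 0x00 separator, then the column type code. A small illustrative check of that layout, not tantivy API — the `b's'` stand-in for `ColumnType::Str.to_code()` is an assumption:

```rust
// Illustrative only: reproduces the byte layout asserted in the test above.
fn main() {
    let column_name = b"root\0child"; // the 0x00 inside the name gets rewritten to b'0'
    let mut key: Vec<u8> = column_name
        .iter()
        .map(|&b| if b == 0 { b'0' } else { b })
        .collect();
    key.push(0u8);  // end-of-column-name separator
    key.push(b's'); // stand-in for ColumnType::Str.to_code(); the real code byte may differ
    assert_eq!(&key[..10], b"root0child");
    assert_eq!(key[10], 0u8);
    assert_eq!(key.len(), 12);
}
```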

View File

@@ -9,6 +9,7 @@ documentation = "https://docs.rs/tantivy_common/"
homepage = "https://github.com/quickwit-oss/tantivy" homepage = "https://github.com/quickwit-oss/tantivy"
repository = "https://github.com/quickwit-oss/tantivy" repository = "https://github.com/quickwit-oss/tantivy"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
@@ -19,7 +20,8 @@ time = { version = "0.3.10", features = ["serde-well-known"] }
serde = { version = "1.0.136", features = ["derive"] } serde = { version = "1.0.136", features = ["derive"] }
[dev-dependencies] [dev-dependencies]
binggan = "0.8.1"
proptest = "1.0.0" proptest = "1.0.0"
rand = "0.8.4" rand = "0.8.4"
[features]
unstable = [] # useful for benches.

View File

@@ -1,64 +1,39 @@
use binggan::{black_box, BenchRunner}; #![feature(test)]
use rand::seq::IteratorRandom;
use rand::thread_rng;
use tantivy_common::{serialize_vint_u32, BitSet, TinySet};
fn bench_vint() { extern crate test;
let mut runner = BenchRunner::new();
let vals: Vec<u32> = (0..20_000).collect(); #[cfg(test)]
runner.bench_function("bench_vint", move |_| { mod tests {
let mut out = 0u64; use rand::seq::IteratorRandom;
for val in vals.iter().cloned() { use rand::thread_rng;
let mut buf = [0u8; 8]; use tantivy_common::serialize_vint_u32;
serialize_vint_u32(val, &mut buf); use test::Bencher;
out += u64::from(buf[0]);
}
black_box(out);
});
let vals: Vec<u32> = (0..20_000).choose_multiple(&mut thread_rng(), 100_000); #[bench]
runner.bench_function("bench_vint_rand", move |_| { fn bench_vint(b: &mut Bencher) {
let mut out = 0u64; let vals: Vec<u32> = (0..20_000).collect();
for val in vals.iter().cloned() { b.iter(|| {
let mut buf = [0u8; 8]; let mut out = 0u64;
serialize_vint_u32(val, &mut buf); for val in vals.iter().cloned() {
out += u64::from(buf[0]); let mut buf = [0u8; 8];
} serialize_vint_u32(val, &mut buf);
black_box(out); out += u64::from(buf[0]);
}); }
} out
});
fn bench_bitset() { }
let mut runner = BenchRunner::new();
#[bench]
runner.bench_function("bench_tinyset_pop", move |_| { fn bench_vint_rand(b: &mut Bencher) {
let mut tinyset = TinySet::singleton(black_box(31u32)); let vals: Vec<u32> = (0..20_000).choose_multiple(&mut thread_rng(), 100_000);
tinyset.pop_lowest(); b.iter(|| {
tinyset.pop_lowest(); let mut out = 0u64;
tinyset.pop_lowest(); for val in vals.iter().cloned() {
tinyset.pop_lowest(); let mut buf = [0u8; 8];
tinyset.pop_lowest(); serialize_vint_u32(val, &mut buf);
tinyset.pop_lowest(); out += u64::from(buf[0]);
black_box(tinyset); }
}); out
});
let tiny_set = TinySet::empty().insert(10u32).insert(14u32).insert(21u32); }
runner.bench_function("bench_tinyset_sum", move |_| {
assert_eq!(black_box(tiny_set).into_iter().sum::<u32>(), 45u32);
});
let v = [10u32, 14u32, 21u32];
runner.bench_function("bench_tinyarr_sum", move |_| {
black_box(v.iter().cloned().sum::<u32>());
});
runner.bench_function("bench_bitset_initialize", move |_| {
black_box(BitSet::with_max_value(1_000_000));
});
}
fn main() {
bench_vint();
bench_bitset();
} }
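Because the two panes of this hunk are interleaved line by line, here is the shape of the binggan-based bench as it appears on one side of the diff, reassembled for readability (a sketch restricted to the `bench_vint` case; the other functions follow the same pattern):

```rust
use binggan::{black_box, BenchRunner};
use tantivy_common::serialize_vint_u32;

fn bench_vint() {
    let mut runner = BenchRunner::new();
    let vals: Vec<u32> = (0..20_000).collect();
    // Each closure is registered by name and driven by the binggan runner,
    // replacing the nightly-only #[bench] functions on the other side.
    runner.bench_function("bench_vint", move |_| {
        let mut out = 0u64;
        for val in vals.iter().cloned() {
            let mut buf = [0u8; 8];
            serialize_vint_u32(val, &mut buf);
            out += u64::from(buf[0]);
        }
        black_box(out);
    });
}

fn main() {
    bench_vint();
}
```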

View File

@@ -696,3 +696,43 @@ mod tests {
} }
} }
} }
#[cfg(all(test, feature = "unstable"))]
mod bench {
use test;
use super::{BitSet, TinySet};
#[bench]
fn bench_tinyset_pop(b: &mut test::Bencher) {
b.iter(|| {
let mut tinyset = TinySet::singleton(test::black_box(31u32));
tinyset.pop_lowest();
tinyset.pop_lowest();
tinyset.pop_lowest();
tinyset.pop_lowest();
tinyset.pop_lowest();
tinyset.pop_lowest();
});
}
#[bench]
fn bench_tinyset_sum(b: &mut test::Bencher) {
let tiny_set = TinySet::empty().insert(10u32).insert(14u32).insert(21u32);
b.iter(|| {
assert_eq!(test::black_box(tiny_set).into_iter().sum::<u32>(), 45u32);
});
}
#[bench]
fn bench_tinyarr_sum(b: &mut test::Bencher) {
let v = [10u32, 14u32, 21u32];
b.iter(|| test::black_box(v).iter().cloned().sum::<u32>());
}
#[bench]
fn bench_bitset_initialize(b: &mut test::Bencher) {
b.iter(|| BitSet::with_max_value(1_000_000));
}
}

View File

@@ -1,5 +1,3 @@
use std::ops::Bound;
// # Searching a range on an indexed int field. // # Searching a range on an indexed int field.
// //
// Below is an example of creating an indexed integer field in your schema // Below is an example of creating an indexed integer field in your schema
@@ -7,7 +5,7 @@ use std::ops::Bound;
use tantivy::collector::Count; use tantivy::collector::Count;
use tantivy::query::RangeQuery; use tantivy::query::RangeQuery;
use tantivy::schema::{Schema, INDEXED}; use tantivy::schema::{Schema, INDEXED};
use tantivy::{doc, Index, IndexWriter, Result, Term}; use tantivy::{doc, Index, IndexWriter, Result};
fn main() -> Result<()> { fn main() -> Result<()> {
// For the sake of simplicity, this schema will only have 1 field // For the sake of simplicity, this schema will only have 1 field
@@ -29,10 +27,7 @@ fn main() -> Result<()> {
reader.reload()?; reader.reload()?;
let searcher = reader.searcher(); let searcher = reader.searcher();
// The end is excluded i.e. here we are searching up to 1969 // The end is excluded i.e. here we are searching up to 1969
let docs_in_the_sixties = RangeQuery::new( let docs_in_the_sixties = RangeQuery::new_u64("year".to_string(), 1960..1970);
Bound::Included(Term::from_field_u64(year_field, 1960)),
Bound::Excluded(Term::from_field_u64(year_field, 1970)),
);
// Uses a Count collector to sum the total number of docs in the range // Uses a Count collector to sum the total number of docs in the range
let num_60s_books = searcher.search(&docs_in_the_sixties, &Count)?; let num_60s_books = searcher.search(&docs_in_the_sixties, &Count)?;
assert_eq!(num_60s_books, 10); assert_eq!(num_60s_books, 10);

View File

@@ -34,9 +34,8 @@ use super::bucket::{
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation, DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
}; };
use super::metric::{ use super::metric::{
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation,
MaxAggregation, MinAggregation, PercentilesAggregationReq, StatsAggregation, SumAggregation, PercentilesAggregationReq, StatsAggregation, SumAggregation, TopHitsAggregation,
TopHitsAggregationReq,
}; };
/// The top-level aggregation request structure, which contains [`Aggregation`] and their user /// The top-level aggregation request structure, which contains [`Aggregation`] and their user
@@ -160,10 +159,7 @@ pub enum AggregationVariants {
Percentiles(PercentilesAggregationReq), Percentiles(PercentilesAggregationReq),
/// Finds the top k values matching some order /// Finds the top k values matching some order
#[serde(rename = "top_hits")] #[serde(rename = "top_hits")]
TopHits(TopHitsAggregationReq), TopHits(TopHitsAggregation),
/// Computes an estimate of the number of unique values
#[serde(rename = "cardinality")]
Cardinality(CardinalityAggregationReq),
} }
impl AggregationVariants { impl AggregationVariants {
@@ -183,7 +179,6 @@ impl AggregationVariants {
AggregationVariants::Sum(sum) => vec![sum.field_name()], AggregationVariants::Sum(sum) => vec![sum.field_name()],
AggregationVariants::Percentiles(per) => vec![per.field_name()], AggregationVariants::Percentiles(per) => vec![per.field_name()],
AggregationVariants::TopHits(top_hits) => top_hits.field_names(), AggregationVariants::TopHits(top_hits) => top_hits.field_names(),
AggregationVariants::Cardinality(per) => vec![per.field_name()],
} }
} }
@@ -208,7 +203,7 @@ impl AggregationVariants {
_ => None, _ => None,
} }
} }
pub(crate) fn as_top_hits(&self) -> Option<&TopHitsAggregationReq> { pub(crate) fn as_top_hits(&self) -> Option<&TopHitsAggregation> {
match &self { match &self {
AggregationVariants::TopHits(top_hits) => Some(top_hits), AggregationVariants::TopHits(top_hits) => Some(top_hits),
_ => None, _ => None,

View File

@@ -11,8 +11,8 @@ use super::bucket::{
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation, DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
}; };
use super::metric::{ use super::metric::{
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation,
MaxAggregation, MinAggregation, StatsAggregation, SumAggregation, StatsAggregation, SumAggregation,
}; };
use super::segment_agg_result::AggregationLimits; use super::segment_agg_result::AggregationLimits;
use super::VecWithNames; use super::VecWithNames;
@@ -162,11 +162,6 @@ impl AggregationWithAccessor {
field: ref field_name, field: ref field_name,
ref missing, ref missing,
.. ..
})
| Cardinality(CardinalityAggregationReq {
field: ref field_name,
ref missing,
..
}) => { }) => {
let str_dict_column = reader.fast_fields().str(field_name)?; let str_dict_column = reader.fast_fields().str(field_name)?;
let allowed_column_types = [ let allowed_column_types = [

View File

@@ -98,8 +98,6 @@ pub enum MetricResult {
Percentiles(PercentilesMetricResult), Percentiles(PercentilesMetricResult),
/// Top hits metric result /// Top hits metric result
TopHits(TopHitsMetricResult), TopHits(TopHitsMetricResult),
/// Cardinality metric result
Cardinality(SingleMetricResult),
} }
impl MetricResult { impl MetricResult {
@@ -118,7 +116,6 @@ impl MetricResult {
MetricResult::TopHits(_) => Err(TantivyError::AggregationError( MetricResult::TopHits(_) => Err(TantivyError::AggregationError(
AggregationError::InvalidRequest("top_hits can't be used to order".to_string()), AggregationError::InvalidRequest("top_hits can't be used to order".to_string()),
)), )),
MetricResult::Cardinality(card) => Ok(card.value),
} }
} }
} }

View File

@@ -110,16 +110,6 @@ fn test_aggregation_flushing(
} }
} }
} }
},
"cardinality_string_id":{
"cardinality": {
"field": "string_id"
}
},
"cardinality_score":{
"cardinality": {
"field": "score"
}
} }
}); });
@@ -222,9 +212,6 @@ fn test_aggregation_flushing(
) )
); );
assert_eq!(res["cardinality_string_id"]["value"], 2.0);
assert_eq!(res["cardinality_score"]["value"], 80.0);
Ok(()) Ok(())
} }
@@ -939,10 +926,10 @@ fn test_aggregation_on_json_object_mixed_types() {
}, },
"termagg": { "termagg": {
"buckets": [ "buckets": [
{ "doc_count": 1, "key": 10.0, "key_as_string": "10", "min_price": { "value": 10.0 } }, { "doc_count": 1, "key": 10.0, "min_price": { "value": 10.0 } },
{ "doc_count": 3, "key": "blue", "min_price": { "value": 5.0 } }, { "doc_count": 3, "key": "blue", "min_price": { "value": 5.0 } },
{ "doc_count": 2, "key": "red", "min_price": { "value": 1.0 } }, { "doc_count": 2, "key": "red", "min_price": { "value": 1.0 } },
{ "doc_count": 1, "key": -20.5, "key_as_string": "-20.5", "min_price": { "value": -20.5 } }, { "doc_count": 1, "key": -20.5, "min_price": { "value": -20.5 } },
{ "doc_count": 2, "key": 1.0, "key_as_string": "true", "min_price": { "value": null } }, { "doc_count": 2, "key": 1.0, "key_as_string": "true", "min_price": { "value": null } },
], ],
"sum_other_doc_count": 0 "sum_other_doc_count": 0

View File

@@ -1,9 +1,10 @@
use std::fmt::Debug; use std::fmt::Debug;
use std::io;
use std::net::Ipv6Addr; use std::net::Ipv6Addr;
use columnar::column_values::CompactSpaceU64Accessor; use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{ColumnType, Dictionary, MonotonicallyMappableToU128, MonotonicallyMappableToU64}; use columnar::{
BytesColumn, ColumnType, MonotonicallyMappableToU128, MonotonicallyMappableToU64, StrColumn,
};
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -465,66 +466,49 @@ impl SegmentTermCollector {
}; };
if self.column_type == ColumnType::Str { if self.column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let term_dict = agg_with_accessor let term_dict = agg_with_accessor
.str_dict_column .str_dict_column
.as_ref() .as_ref()
.map(|el| el.dictionary()) .cloned()
.unwrap_or_else(|| &fallback_dict); .unwrap_or_else(|| {
let mut buffer = Vec::new(); StrColumn::wrap(BytesColumn::empty(agg_with_accessor.accessor.num_docs()))
});
// special case for missing key let mut buffer = String::new();
if let Some(index) = entries.iter().position(|value| value.0 == u64::MAX) { for (term_id, doc_count) in entries {
let entry = entries[index]; let intermediate_entry = into_intermediate_bucket_entry(term_id, doc_count)?;
let intermediate_entry = into_intermediate_bucket_entry(entry.0, entry.1)?; // Special case for missing key
let missing_key = self if term_id == u64::MAX {
.req let missing_key = self
.missing .req
.as_ref() .missing
.expect("Found placeholder term_id but `missing` is None"); .as_ref()
match missing_key { .expect("Found placeholder term_id but `missing` is None");
Key::Str(missing) => { match missing_key {
buffer.clear(); Key::Str(missing) => {
buffer.extend_from_slice(missing.as_bytes()); buffer.clear();
dict.insert( buffer.push_str(missing);
IntermediateKey::Str( dict.insert(
String::from_utf8(buffer.to_vec()) IntermediateKey::Str(buffer.to_string()),
.expect("could not convert to String"), intermediate_entry,
), );
intermediate_entry, }
); Key::F64(val) => {
buffer.push_str(&val.to_string());
dict.insert(IntermediateKey::F64(*val), intermediate_entry);
}
} }
Key::F64(val) => { } else {
dict.insert(IntermediateKey::F64(*val), intermediate_entry); if !term_dict.ord_to_str(term_id, &mut buffer)? {
return Err(TantivyError::InternalError(format!(
"Couldn't find term_id {term_id} in dict"
)));
} }
dict.insert(IntermediateKey::Str(buffer.to_string()), intermediate_entry);
} }
entries.swap_remove(index);
} }
// Sort by term ord
entries.sort_unstable_by_key(|bucket| bucket.0);
let mut idx = 0;
term_dict.sorted_ords_to_term_cb(
entries.iter().map(|(term_id, _)| *term_id),
|term| {
let entry = entries[idx];
let intermediate_entry = into_intermediate_bucket_entry(entry.0, entry.1)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
dict.insert(
IntermediateKey::Str(
String::from_utf8(term.to_vec()).expect("could not convert to String"),
),
intermediate_entry,
);
idx += 1;
Ok(())
},
)?;
if self.req.min_doc_count == 0 { if self.req.min_doc_count == 0 {
// TODO: Handle rev streaming for descending sorting by keys // TODO: Handle rev streaming for descending sorting by keys
let mut stream = term_dict.stream()?; let mut stream = term_dict.dictionary().stream()?;
let empty_sub_aggregation = IntermediateAggregationResults::empty_from_req( let empty_sub_aggregation = IntermediateAggregationResults::empty_from_req(
agg_with_accessor.agg.sub_aggregation(), agg_with_accessor.agg.sub_aggregation(),
); );

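This hunk is the heart of the "fix out of order bug" change: on one side, each term ordinal is resolved individually with `ord_to_str`, while on the other side the collected entries are first sorted by term ordinal and then resolved in a single pass with `sorted_ords_to_term_cb`, which expects ascending ordinals. A minimal sketch of that pattern with generic stand-ins (this is not the tantivy dictionary API; `ord_to_term` is hypothetical):

```rust
// `entries` holds (term ordinal, doc count) pairs collected for the buckets.
// `ord_to_term` models a dictionary lookup that is only efficient — and, for a
// streaming dictionary, only correct — when ordinals are visited in ascending
// order, like `sorted_ords_to_term_cb` in the hunk above.
fn resolve_buckets(
    mut entries: Vec<(u64, u64)>,
    ord_to_term: impl Fn(u64) -> String,
) -> Vec<(String, u64)> {
    // The fix: sort by term ordinal before walking the dictionary.
    entries.sort_unstable_by_key(|(ord, _)| *ord);
    entries
        .into_iter()
        .map(|(ord, doc_count)| (ord_to_term(ord), doc_count))
        .collect()
}
```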
View File

@@ -26,7 +26,6 @@ use super::segment_agg_result::AggregationLimits;
use super::{format_date, AggregationError, Key, SerializedKey}; use super::{format_date, AggregationError, Key, SerializedKey};
use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry}; use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
use crate::aggregation::bucket::TermsAggregationInternal; use crate::aggregation::bucket::TermsAggregationInternal;
use crate::aggregation::metric::CardinalityCollector;
use crate::TantivyError; use crate::TantivyError;
/// Contains the intermediate aggregation result, which is optimized to be merged with other /// Contains the intermediate aggregation result, which is optimized to be merged with other
@@ -228,9 +227,6 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult
TopHits(ref req) => IntermediateAggregationResult::Metric( TopHits(ref req) => IntermediateAggregationResult::Metric(
IntermediateMetricResult::TopHits(TopHitsTopNComputer::new(req)), IntermediateMetricResult::TopHits(TopHitsTopNComputer::new(req)),
), ),
Cardinality(_) => IntermediateAggregationResult::Metric(
IntermediateMetricResult::Cardinality(CardinalityCollector::default()),
),
} }
} }
@@ -295,8 +291,6 @@ pub enum IntermediateMetricResult {
Sum(IntermediateSum), Sum(IntermediateSum),
/// Intermediate top_hits result /// Intermediate top_hits result
TopHits(TopHitsTopNComputer), TopHits(TopHitsTopNComputer),
/// Intermediate cardinality result
Cardinality(CardinalityCollector),
} }
impl IntermediateMetricResult { impl IntermediateMetricResult {
@@ -330,9 +324,6 @@ impl IntermediateMetricResult {
IntermediateMetricResult::TopHits(top_hits) => { IntermediateMetricResult::TopHits(top_hits) => {
MetricResult::TopHits(top_hits.into_final_result()) MetricResult::TopHits(top_hits.into_final_result())
} }
IntermediateMetricResult::Cardinality(cardinality) => {
MetricResult::Cardinality(cardinality.finalize().into())
}
} }
} }
@@ -381,12 +372,6 @@ impl IntermediateMetricResult {
(IntermediateMetricResult::TopHits(left), IntermediateMetricResult::TopHits(right)) => { (IntermediateMetricResult::TopHits(left), IntermediateMetricResult::TopHits(right)) => {
left.merge_fruits(right)?; left.merge_fruits(right)?;
} }
(
IntermediateMetricResult::Cardinality(left),
IntermediateMetricResult::Cardinality(right),
) => {
left.merge_fruits(right)?;
}
_ => { _ => {
panic!("incompatible fruit types in tree or missing merge_fruits handler"); panic!("incompatible fruit types in tree or missing merge_fruits handler");
} }
@@ -599,7 +584,6 @@ impl IntermediateTermBucketResult {
let val = if key { "true" } else { "false" }; let val = if key { "true" } else { "false" };
Some(val.to_string()) Some(val.to_string())
} }
IntermediateKey::F64(val) => Some(val.to_string()),
_ => None, _ => None,
}; };
Ok(BucketEntry { Ok(BucketEntry {

View File

@@ -1,466 +0,0 @@
use std::collections::hash_map::DefaultHasher;
use std::hash::{BuildHasher, Hasher};
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::Dictionary;
use common::f64_to_u64;
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
use rustc_hash::FxHashSet;
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_req_with_accessor::{
AggregationWithAccessor, AggregationsWithAccessor,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::TantivyError;
#[derive(Clone, Debug, Serialize, Deserialize)]
struct BuildSaltedHasher {
salt: u8,
}
impl BuildHasher for BuildSaltedHasher {
type Hasher = DefaultHasher;
fn build_hasher(&self) -> Self::Hasher {
let mut hasher = DefaultHasher::new();
hasher.write_u8(self.salt);
hasher
}
}
/// # Cardinality
///
/// The cardinality aggregation allows for computing an estimate
/// of the number of different values in a data set based on the
/// HyperLogLog++ algorithm. This is particularly useful for understanding the
/// uniqueness of values in a large dataset where counting each unique value
/// individually would be computationally expensive.
///
/// For example, you might use a cardinality aggregation to estimate the number
/// of unique visitors to a website by aggregating on a field that contains
/// user IDs or session IDs.
///
/// To use the cardinality aggregation, you'll need to provide a field to
/// aggregate on. The following example demonstrates a request for the cardinality
/// of the "user_id" field:
///
/// ```JSON
/// {
/// "cardinality": {
/// "field": "user_id"
/// }
/// }
/// ```
///
/// This request will return an estimate of the number of unique values in the
/// "user_id" field.
///
/// ## Missing Values
///
/// The `missing` parameter defines how documents that are missing a value should be treated.
/// By default, documents without a value for the specified field are ignored. However, you can
/// specify a default value for these documents using the `missing` parameter. This can be useful
/// when you want to include documents with missing values in the aggregation.
///
/// For example, the following request treats documents with missing values in the "user_id"
/// field as if they had a value of "unknown":
///
/// ```JSON
/// {
/// "cardinality": {
/// "field": "user_id",
/// "missing": "unknown"
/// }
/// }
/// ```
///
/// # Estimation Accuracy
///
/// The cardinality aggregation provides an approximate count, which is usually
/// accurate within a small error range. This trade-off allows for efficient
/// computation even on very large datasets.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CardinalityAggregationReq {
/// The field name to compute the percentiles on.
pub field: String,
/// The missing parameter defines how documents that are missing a value should be treated.
/// By default they will be ignored but it is also possible to treat them as if they had a
/// value. Examples in JSON format:
/// { "field": "my_numbers", "missing": "10.0" }
#[serde(skip_serializing_if = "Option::is_none", default)]
pub missing: Option<Key>,
}
impl CardinalityAggregationReq {
/// Creates a new [`CardinalityAggregationReq`] instance from a field name.
pub fn from_field_name(field_name: String) -> Self {
Self {
field: field_name,
missing: None,
}
}
/// Returns the field name the aggregation is computed on.
pub fn field_name(&self) -> &str {
&self.field
}
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentCardinalityCollector {
cardinality: CardinalityCollector,
entries: FxHashSet<u64>,
column_type: ColumnType,
accessor_idx: usize,
missing: Option<Key>,
}
impl SegmentCardinalityCollector {
pub fn from_req(column_type: ColumnType, accessor_idx: usize, missing: &Option<Key>) -> Self {
Self {
cardinality: CardinalityCollector::new(column_type as u8),
entries: Default::default(),
column_type,
accessor_idx,
missing: missing.clone(),
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_accessor: &mut AggregationWithAccessor,
) {
if let Some(missing) = agg_accessor.missing_value_for_accessor {
agg_accessor.column_block_accessor.fetch_block_with_missing(
docs,
&agg_accessor.accessor,
missing,
);
} else {
agg_accessor
.column_block_accessor
.fetch_block(docs, &agg_accessor.accessor);
}
}
fn into_intermediate_metric_result(
mut self,
agg_with_accessor: &AggregationWithAccessor,
) -> crate::Result<IntermediateMetricResult> {
if self.column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let dict = agg_with_accessor
.str_dict_column
.as_ref()
.map(|el| el.dictionary())
.unwrap_or_else(|| &fallback_dict);
let mut has_missing = false;
// TODO: replace FxHashSet with something that allows iterating in order
// (e.g. sparse bitvec)
let mut term_ids = Vec::new();
for term_ord in self.entries.into_iter() {
if term_ord == u64::MAX {
has_missing = true;
} else {
// we can reasonably exclude values above u32::MAX
term_ids.push(term_ord as u32);
}
}
term_ids.sort_unstable();
dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| {
self.cardinality.sketch.insert_any(&term);
Ok(())
})?;
if has_missing {
let missing_key = self
.missing
.as_ref()
.expect("Found placeholder term_ord but `missing` is None");
match missing_key {
Key::Str(missing) => {
self.cardinality.sketch.insert_any(&missing);
}
Key::F64(val) => {
let val = f64_to_u64(*val);
self.cardinality.sketch.insert_any(&val);
}
}
}
}
Ok(IntermediateMetricResult::Cardinality(self.cardinality))
}
}
impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
results: &mut IntermediateAggregationResults,
) -> crate::Result<()> {
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx];
let intermediate_result = self.into_intermediate_metric_result(agg_with_accessor)?;
results.push(
name,
IntermediateAggregationResult::Metric(intermediate_result),
)?;
Ok(())
}
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_with_accessor)
}
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
self.fetch_block_with_field(docs, bucket_agg_accessor);
let col_block_accessor = &bucket_agg_accessor.column_block_accessor;
if self.column_type == ColumnType::Str {
for term_ord in col_block_accessor.iter_vals() {
self.entries.insert(term_ord);
}
} else if self.column_type == ColumnType::IpAddr {
let compact_space_accessor = bucket_agg_accessor
.accessor
.values
.clone()
.downcast_arc::<CompactSpaceU64Accessor>()
.map_err(|_| {
TantivyError::AggregationError(
crate::aggregation::AggregationError::InternalError(
"Type mismatch: Could not downcast to CompactSpaceU64Accessor"
.to_string(),
),
)
})?;
for val in col_block_accessor.iter_vals() {
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
self.cardinality.sketch.insert_any(&val);
}
} else {
for val in col_block_accessor.iter_vals() {
self.cardinality.sketch.insert_any(&val);
}
}
Ok(())
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
/// The percentiles collector used during segment collection and for merging results.
pub struct CardinalityCollector {
sketch: HyperLogLogPlus<u64, BuildSaltedHasher>,
}
impl Default for CardinalityCollector {
fn default() -> Self {
Self::new(0)
}
}
impl PartialEq for CardinalityCollector {
fn eq(&self, _other: &Self) -> bool {
false
}
}
impl CardinalityCollector {
/// Compute the final cardinality estimate.
pub fn finalize(self) -> Option<f64> {
Some(self.sketch.clone().count().trunc())
}
fn new(salt: u8) -> Self {
Self {
sketch: HyperLogLogPlus::new(16, BuildSaltedHasher { salt }).unwrap(),
}
}
pub(crate) fn merge_fruits(&mut self, right: CardinalityCollector) -> crate::Result<()> {
self.sketch.merge(&right.sketch).map_err(|err| {
TantivyError::AggregationError(AggregationError::InternalError(format!(
"Error while merging cardinality {err:?}"
)))
})?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::net::IpAddr;
use std::str::FromStr;
use columnar::MonotonicallyMappableToU64;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::tests::{exec_request, get_test_index_from_terms};
use crate::schema::{IntoIpv6Addr, Schema, FAST};
use crate::Index;
#[test]
fn cardinality_aggregation_test_empty_index() -> crate::Result<()> {
let values = vec![];
let index = get_test_index_from_terms(false, &values)?;
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "string_id",
}
},
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["cardinality"]["value"], 0.0);
Ok(())
}
#[test]
fn cardinality_aggregation_test_single_segment() -> crate::Result<()> {
cardinality_aggregation_test_merge_segment(true)
}
#[test]
fn cardinality_aggregation_test() -> crate::Result<()> {
cardinality_aggregation_test_merge_segment(false)
}
fn cardinality_aggregation_test_merge_segment(merge_segments: bool) -> crate::Result<()> {
let segment_and_terms = vec![
vec!["terma"],
vec!["termb"],
vec!["termc"],
vec!["terma"],
vec!["terma"],
vec!["terma"],
vec!["termb"],
vec!["terma"],
];
let index = get_test_index_from_terms(merge_segments, &segment_and_terms)?;
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "string_id",
}
},
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["cardinality"]["value"], 3.0);
Ok(())
}
#[test]
fn cardinality_aggregation_u64() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let id_field = schema_builder.add_u64_field("id", FAST);
let index = Index::create_in_ram(schema_builder.build());
{
let mut writer = index.writer_for_tests()?;
writer.add_document(doc!(id_field => 1u64))?;
writer.add_document(doc!(id_field => 2u64))?;
writer.add_document(doc!(id_field => 3u64))?;
writer.add_document(doc!())?;
writer.commit()?;
}
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "id",
"missing": 0u64
},
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["cardinality"]["value"], 4.0);
Ok(())
}
#[test]
fn cardinality_aggregation_ip_addr() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let field = schema_builder.add_ip_addr_field("ip_field", FAST);
let index = Index::create_in_ram(schema_builder.build());
{
let mut writer = index.writer_for_tests()?;
// IpV6 loopback
writer.add_document(doc!(field=>IpAddr::from_str("::1").unwrap().into_ipv6_addr()))?;
writer.add_document(doc!(field=>IpAddr::from_str("::1").unwrap().into_ipv6_addr()))?;
// IpV4
writer.add_document(
doc!(field=>IpAddr::from_str("127.0.0.1").unwrap().into_ipv6_addr()),
)?;
writer.commit()?;
}
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "ip_field"
},
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["cardinality"]["value"], 2.0);
Ok(())
}
#[test]
fn cardinality_aggregation_json() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let field = schema_builder.add_json_field("json", FAST);
let index = Index::create_in_ram(schema_builder.build());
{
let mut writer = index.writer_for_tests()?;
writer.add_document(doc!(field => json!({"value": false})))?;
writer.add_document(doc!(field => json!({"value": true})))?;
writer.add_document(doc!(field => json!({"value": i64::from_u64(0u64)})))?;
writer.add_document(doc!(field => json!({"value": i64::from_u64(1u64)})))?;
writer.commit()?;
}
let agg_req: Aggregations = serde_json::from_value(json!({
"cardinality": {
"cardinality": {
"field": "json.value"
},
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["cardinality"]["value"], 4.0);
Ok(())
}
}
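The deleted `cardinality.rs` above builds its estimate on `HyperLogLogPlus` sketches that are merged across segments in `merge_fruits`. A self-contained sketch of that underlying behaviour using the `hyperloglogplus` crate directly — this is not tantivy API; it assumes `hyperloglogplus` 0.4 and a deterministic hasher so the two sketches hash identically and are mergeable, which is what `BuildSaltedHasher` provides in the deleted code:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::BuildHasherDefault;

use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};

type Sketch = HyperLogLogPlus<u64, BuildHasherDefault<DefaultHasher>>;

fn main() {
    // Precision 16, as in the deleted CardinalityCollector::new.
    let mut left: Sketch = HyperLogLogPlus::new(16, Default::default()).unwrap();
    let mut right: Sketch = HyperLogLogPlus::new(16, Default::default()).unwrap();
    // Two "segments" observe overlapping value ranges.
    for v in 0u64..1_000 {
        left.insert(&v);
    }
    for v in 500u64..1_500 {
        right.insert(&v);
    }
    // Merging the sketches approximates the cardinality of the union (~1500 here).
    left.merge(&right).unwrap();
    println!("estimated distinct values: {}", left.count());
}
```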

View File

@@ -17,7 +17,6 @@
//! - [Percentiles](PercentilesAggregationReq) //! - [Percentiles](PercentilesAggregationReq)
mod average; mod average;
mod cardinality;
mod count; mod count;
mod extended_stats; mod extended_stats;
mod max; mod max;
@@ -30,7 +29,6 @@ mod top_hits;
use std::collections::HashMap; use std::collections::HashMap;
pub use average::*; pub use average::*;
pub use cardinality::*;
pub use count::*; pub use count::*;
pub use extended_stats::*; pub use extended_stats::*;
pub use max::*; pub use max::*;

View File

@@ -89,7 +89,7 @@ use crate::{DocAddress, DocId, SegmentOrdinal};
/// } /// }
/// ``` /// ```
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct TopHitsAggregationReq { pub struct TopHitsAggregation {
sort: Vec<KeyOrder>, sort: Vec<KeyOrder>,
size: usize, size: usize,
from: Option<usize>, from: Option<usize>,
@@ -164,7 +164,7 @@ fn unsupported_err(parameter: &str) -> crate::Result<()> {
)) ))
} }
impl TopHitsAggregationReq { impl TopHitsAggregation {
/// Validate and resolve field retrieval parameters /// Validate and resolve field retrieval parameters
pub fn validate_and_resolve_field_names( pub fn validate_and_resolve_field_names(
&mut self, &mut self,
@@ -431,7 +431,7 @@ impl Eq for DocSortValuesAndFields {}
/// The TopHitsCollector used for collecting over segments and merging results. /// The TopHitsCollector used for collecting over segments and merging results.
#[derive(Clone, Serialize, Deserialize, Debug)] #[derive(Clone, Serialize, Deserialize, Debug)]
pub struct TopHitsTopNComputer { pub struct TopHitsTopNComputer {
req: TopHitsAggregationReq, req: TopHitsAggregation,
top_n: TopNComputer<DocSortValuesAndFields, DocAddress, false>, top_n: TopNComputer<DocSortValuesAndFields, DocAddress, false>,
} }
@@ -443,7 +443,7 @@ impl std::cmp::PartialEq for TopHitsTopNComputer {
impl TopHitsTopNComputer { impl TopHitsTopNComputer {
/// Create a new TopHitsCollector /// Create a new TopHitsCollector
pub fn new(req: &TopHitsAggregationReq) -> Self { pub fn new(req: &TopHitsAggregation) -> Self {
Self { Self {
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)), top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
req: req.clone(), req: req.clone(),
@@ -496,7 +496,7 @@ pub(crate) struct TopHitsSegmentCollector {
impl TopHitsSegmentCollector { impl TopHitsSegmentCollector {
pub fn from_req( pub fn from_req(
req: &TopHitsAggregationReq, req: &TopHitsAggregation,
accessor_idx: usize, accessor_idx: usize,
segment_ordinal: SegmentOrdinal, segment_ordinal: SegmentOrdinal,
) -> Self { ) -> Self {
@@ -509,7 +509,7 @@ impl TopHitsSegmentCollector {
fn into_top_hits_collector( fn into_top_hits_collector(
self, self,
value_accessors: &HashMap<String, Vec<DynamicColumn>>, value_accessors: &HashMap<String, Vec<DynamicColumn>>,
req: &TopHitsAggregationReq, req: &TopHitsAggregation,
) -> TopHitsTopNComputer { ) -> TopHitsTopNComputer {
let mut top_hits_computer = TopHitsTopNComputer::new(req); let mut top_hits_computer = TopHitsTopNComputer::new(req);
let top_results = self.top_n.into_vec(); let top_results = self.top_n.into_vec();
@@ -532,7 +532,7 @@ impl TopHitsSegmentCollector {
fn collect_with( fn collect_with(
&mut self, &mut self,
doc_id: crate::DocId, doc_id: crate::DocId,
req: &TopHitsAggregationReq, req: &TopHitsAggregation,
accessors: &[(Column<u64>, ColumnType)], accessors: &[(Column<u64>, ColumnType)],
) -> crate::Result<()> { ) -> crate::Result<()> {
let sorts: Vec<DocValueAndOrder> = req let sorts: Vec<DocValueAndOrder> = req

View File

@@ -44,14 +44,11 @@
//! - [Metric](metric) //! - [Metric](metric)
//! - [Average](metric::AverageAggregation) //! - [Average](metric::AverageAggregation)
//! - [Stats](metric::StatsAggregation) //! - [Stats](metric::StatsAggregation)
//! - [ExtendedStats](metric::ExtendedStatsAggregation)
//! - [Min](metric::MinAggregation) //! - [Min](metric::MinAggregation)
//! - [Max](metric::MaxAggregation) //! - [Max](metric::MaxAggregation)
//! - [Sum](metric::SumAggregation) //! - [Sum](metric::SumAggregation)
//! - [Count](metric::CountAggregation) //! - [Count](metric::CountAggregation)
//! - [Percentiles](metric::PercentilesAggregationReq) //! - [Percentiles](metric::PercentilesAggregationReq)
//! - [Cardinality](metric::CardinalityAggregationReq)
//! - [TopHits](metric::TopHitsAggregationReq)
//! //!
//! # Example //! # Example
//! Compute the average metric, by building [`agg_req::Aggregations`], which is built from an //! Compute the average metric, by building [`agg_req::Aggregations`], which is built from an

View File

@@ -16,10 +16,7 @@ use super::metric::{
SumAggregation, SumAggregation,
}; };
use crate::aggregation::bucket::TermMissingAgg; use crate::aggregation::bucket::TermMissingAgg;
use crate::aggregation::metric::{ use crate::aggregation::metric::{SegmentExtendedStatsCollector, TopHitsSegmentCollector};
CardinalityAggregationReq, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
TopHitsSegmentCollector,
};
pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug { pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug {
fn add_intermediate_aggregation_result( fn add_intermediate_aggregation_result(
@@ -172,9 +169,6 @@ pub(crate) fn build_single_agg_segment_collector(
accessor_idx, accessor_idx,
req.segment_ordinal, req.segment_ordinal,
))), ))),
Cardinality(CardinalityAggregationReq { missing, .. }) => Ok(Box::new(
SegmentCardinalityCollector::from_req(req.field_type, accessor_idx, missing),
)),
} }
} }

View File

@@ -1,4 +1,4 @@
use common::json_path_writer::{JSON_END_OF_PATH, JSON_PATH_SEGMENT_SEP}; use common::json_path_writer::JSON_PATH_SEGMENT_SEP;
use common::{replace_in_place, JsonPathWriter}; use common::{replace_in_place, JsonPathWriter};
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
@@ -83,9 +83,6 @@ fn index_json_object<'a, V: Value<'a>>(
positions_per_path: &mut IndexingPositionsPerPath, positions_per_path: &mut IndexingPositionsPerPath,
) { ) {
for (json_path_segment, json_value_visitor) in json_visitor { for (json_path_segment, json_value_visitor) in json_visitor {
if json_path_segment.as_bytes().contains(&JSON_END_OF_PATH) {
continue;
}
json_path_writer.push(json_path_segment); json_path_writer.push(json_path_segment);
index_json_value( index_json_value(
doc, doc,

View File

@@ -815,9 +815,8 @@ mod tests {
use crate::indexer::NoMergePolicy; use crate::indexer::NoMergePolicy;
use crate::query::{QueryParser, TermQuery}; use crate::query::{QueryParser, TermQuery};
use crate::schema::{ use crate::schema::{
self, Facet, FacetOptions, IndexRecordOption, IpAddrOptions, JsonObjectOptions, self, Facet, FacetOptions, IndexRecordOption, IpAddrOptions, NumericOptions,
NumericOptions, Schema, TextFieldIndexing, TextOptions, Value, FAST, INDEXED, STORED, TextFieldIndexing, TextOptions, Value, FAST, INDEXED, STORED, STRING, TEXT,
STRING, TEXT,
}; };
use crate::store::DOCSTORE_CACHE_CAPACITY; use crate::store::DOCSTORE_CACHE_CAPACITY;
use crate::{ use crate::{
@@ -1574,11 +1573,11 @@ mod tests {
deleted_ids.remove(id); deleted_ids.remove(id);
} }
IndexingOp::DeleteDoc { id } => { IndexingOp::DeleteDoc { id } => {
existing_ids.remove(id); existing_ids.remove(&id);
deleted_ids.insert(*id); deleted_ids.insert(*id);
} }
IndexingOp::DeleteDocQuery { id } => { IndexingOp::DeleteDocQuery { id } => {
existing_ids.remove(id); existing_ids.remove(&id);
deleted_ids.insert(*id); deleted_ids.insert(*id);
} }
_ => {} _ => {}
@@ -2379,11 +2378,11 @@ mod tests {
#[test] #[test]
fn test_bug_1617_2() { fn test_bug_1617_2() {
test_operation_strategy( assert!(test_operation_strategy(
&[ &[
IndexingOp::AddDoc { IndexingOp::AddDoc {
id: 13, id: 13,
value: Default::default(), value: Default::default()
}, },
IndexingOp::DeleteDoc { id: 13 }, IndexingOp::DeleteDoc { id: 13 },
IndexingOp::Commit, IndexingOp::Commit,
@@ -2391,9 +2390,9 @@ mod tests {
IndexingOp::Commit, IndexingOp::Commit,
IndexingOp::Merge, IndexingOp::Merge,
], ],
true, true
) )
.unwrap(); .is_ok());
} }
#[test] #[test]
@@ -2493,9 +2492,9 @@ mod tests {
} }
#[test] #[test]
fn test_bug_2442_reserved_character_fast_field() -> crate::Result<()> { fn test_bug_2442() -> crate::Result<()> {
let mut schema_builder = schema::Schema::builder(); let mut schema_builder = schema::Schema::builder();
let json_field = schema_builder.add_json_field("json", FAST | TEXT); let json_field = schema_builder.add_json_field("json", TEXT | FAST);
let schema = schema_builder.build(); let schema = schema_builder.build();
let index = Index::builder().schema(schema).create_in_ram()?; let index = Index::builder().schema(schema).create_in_ram()?;
@@ -2516,21 +2515,4 @@ mod tests {
Ok(()) Ok(())
} }
#[test]
fn test_bug_2442_reserved_character_columnar() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let options = JsonObjectOptions::from(FAST).set_expand_dots_enabled();
let field = schema_builder.add_json_field("json", options);
let index = Index::create_in_ram(schema_builder.build());
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(doc!(field=>json!({"\u{0000}": "A"})))
.unwrap();
index_writer
.add_document(doc!(field=>json!({format!("\u{0000}\u{0000}"): "A"})))
.unwrap();
index_writer.commit().unwrap();
Ok(())
}
} }

View File

@@ -145,27 +145,15 @@ mod tests_mmap {
} }
} }
#[test] #[test]
fn test_json_field_null_byte_is_ignored() { fn test_json_field_null_byte() {
let mut schema_builder = Schema::builder(); // Test when field name contains a zero byte, which has special meaning in tantivy.
let options = JsonObjectOptions::from(TEXT | FAST).set_expand_dots_enabled(); // As a workaround, we convert the zero byte to the ASCII character '0'.
let field = schema_builder.add_json_field("json", options); // https://github.com/quickwit-oss/tantivy/issues/2340
let index = Index::create_in_ram(schema_builder.build()); // https://github.com/quickwit-oss/tantivy/issues/2193
let mut index_writer = index.writer_for_tests().unwrap(); let field_name_in = "\u{0000}";
index_writer let field_name_out = "0";
.add_document(doc!(field=>json!({"key": "test1", "invalidkey\u{0000}": "test2"}))) test_json_field_name(field_name_in, field_name_out);
.unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
let inv_indexer = segment_reader.inverted_index(field).unwrap();
let term_dict = inv_indexer.terms();
assert_eq!(term_dict.num_terms(), 1);
let mut term_bytes = Vec::new();
term_dict.ord_to_term(0, &mut term_bytes).unwrap();
assert_eq!(term_bytes, b"key\0stest1");
} }
#[test] #[test]
fn test_json_field_1byte() { fn test_json_field_1byte() {
// Test when field name contains a '1' byte, which has special meaning in tantivy. // Test when field name contains a '1' byte, which has special meaning in tantivy.
@@ -303,7 +291,7 @@ mod tests_mmap {
Type::Str, Type::Str,
), ),
(format!("{field_name_out_internal}a"), Type::Str), (format!("{field_name_out_internal}a"), Type::Str),
(field_name_out_internal.to_string(), Type::Str), (format!("{field_name_out_internal}"), Type::Str),
(format!("num{field_name_out_internal}"), Type::I64), (format!("num{field_name_out_internal}"), Type::I64),
]; ];
expected_fields.sort(); expected_fields.sort();

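The removed test in this section asserts that the only surviving term for `{"key": "test1", "invalidkey\u{0000}": "test2"}` is `b"key\0stest1"`. That byte string is the JSON term layout inside the field's term dictionary: the JSON path, the end-of-path byte 0x00, a one-byte value type code (here `b's'` for a string), then the token bytes. A small illustrative check of that layout (not tantivy API):

```rust
fn main() {
    let mut term: Vec<u8> = Vec::new();
    term.extend_from_slice(b"key");   // JSON path within the field
    term.push(0u8);                   // JSON end-of-path separator
    term.push(b's');                  // value type code for a string
    term.extend_from_slice(b"test1"); // the indexed token
    assert_eq!(term, b"key\0stest1");
}
```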
View File

@@ -1,3 +1,5 @@
use common::json_path_writer::JSON_END_OF_PATH;
use common::replace_in_place;
use fnv::FnvHashMap; use fnv::FnvHashMap;
/// `Field` is represented by an unsigned 32-bit integer type. /// `Field` is represented by an unsigned 32-bit integer type.
@@ -38,7 +40,13 @@ impl PathToUnorderedId {
#[cold] #[cold]
fn insert_new_path(&mut self, path: &str) -> u32 { fn insert_new_path(&mut self, path: &str) -> u32 {
let next_id = self.map.len() as u32; let next_id = self.map.len() as u32;
let new_path = path.to_string(); let mut new_path = path.to_string();
// The unsafe below is safe as long as b'.' and JSON_PATH_SEGMENT_SEP are
// valid single byte ut8 strings.
// By utf-8 design, they cannot be part of another codepoint.
unsafe { replace_in_place(JSON_END_OF_PATH, b'0', new_path.as_bytes_mut()) };
self.map.insert(new_path, next_id); self.map.insert(new_path, next_id);
next_id next_id
} }
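The `unsafe` block above is sound because both bytes involved are ASCII: swapping one ASCII byte for another in place can never split or create a multi-byte UTF-8 sequence. A safe, illustrative equivalent of what `replace_in_place` does here (the helper name is hypothetical, not the tantivy-common function):

```rust
// Replace every occurrence of one ASCII byte with another, in place.
// Because ASCII bytes are never part of a multi-byte UTF-8 code point,
// doing this on a UTF-8 buffer keeps it valid UTF-8 — which is exactly
// what the unsafe call above relies on.
fn replace_ascii_in_place(from: u8, to: u8, bytes: &mut [u8]) {
    assert!(from.is_ascii() && to.is_ascii());
    for b in bytes.iter_mut() {
        if *b == from {
            *b = to;
        }
    }
}

fn main() {
    let mut path = String::from("a\u{0000}b").into_bytes();
    replace_ascii_in_place(0u8, b'0', &mut path);
    assert_eq!(String::from_utf8(path).unwrap(), "a0b");
}
```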

View File

@@ -3,7 +3,7 @@
//! In "The beauty and the beast", the term "the" appears in position 0 and position 3. //! In "The beauty and the beast", the term "the" appears in position 0 and position 3.
//! This information is useful to run phrase queries. //! This information is useful to run phrase queries.
//! //!
//! The [position](crate::index::SegmentComponent::Positions) file contains all of the //! The [position](crate::SegmentComponent::Positions) file contains all of the
//! bitpacked positions delta, for all terms of a given field, one term after the other. //! bitpacked positions delta, for all terms of a given field, one term after the other.
//! //!
//! Each term is encoded independently. //! Each term is encoded independently.

View File

@@ -59,7 +59,7 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
/// The actual serialization format is handled by the `PostingsSerializer`. /// The actual serialization format is handled by the `PostingsSerializer`.
fn serialize( fn serialize(
&self, &self,
ordered_term_addrs: &[(Field, OrderedPathId, &[u8], Addr)], term_addrs: &[(Field, OrderedPathId, &[u8], Addr)],
ordered_id_to_path: &[&str], ordered_id_to_path: &[&str],
ctx: &IndexingContext, ctx: &IndexingContext,
serializer: &mut FieldSerializer, serializer: &mut FieldSerializer,
@@ -69,7 +69,7 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
term_buffer.clear_with_field_and_type(Type::Json, Field::from_field_id(0)); term_buffer.clear_with_field_and_type(Type::Json, Field::from_field_id(0));
let mut prev_term_id = u32::MAX; let mut prev_term_id = u32::MAX;
let mut term_path_len = 0; // this will be set in the first iteration let mut term_path_len = 0; // this will be set in the first iteration
for (_field, path_id, term, addr) in ordered_term_addrs { for (_field, path_id, term, addr) in term_addrs {
if prev_term_id != path_id.path_id() { if prev_term_id != path_id.path_id() {
term_buffer.truncate_value_bytes(0); term_buffer.truncate_value_bytes(0);
term_buffer.append_path(ordered_id_to_path[path_id.path_id() as usize].as_bytes()); term_buffer.append_path(ordered_id_to_path[path_id.path_id() as usize].as_bytes());

View File

@@ -15,7 +15,6 @@ pub trait Postings: DocSet + 'static {
fn term_freq(&self) -> u32; fn term_freq(&self) -> u32;
/// Returns the positions offsetted with a given value. /// Returns the positions offsetted with a given value.
/// It is not necessary to clear the `output` before calling this method.
/// The output vector will be resized to the `term_freq`. /// The output vector will be resized to the `term_freq`.
fn positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>); fn positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>);

View File

@@ -22,7 +22,10 @@ pub struct AllWeight;
impl Weight for AllWeight { impl Weight for AllWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> { fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let all_scorer = AllScorer::new(reader.max_doc()); let all_scorer = AllScorer {
doc: 0u32,
max_doc: reader.max_doc(),
};
Ok(Box::new(BoostScorer::new(all_scorer, boost))) Ok(Box::new(BoostScorer::new(all_scorer, boost)))
} }
@@ -40,13 +43,6 @@ pub struct AllScorer {
max_doc: DocId, max_doc: DocId,
} }
impl AllScorer {
/// Creates a new AllScorer with `max_doc` docs.
pub fn new(max_doc: DocId) -> AllScorer {
AllScorer { doc: 0u32, max_doc }
}
}
impl DocSet for AllScorer { impl DocSet for AllScorer {
#[inline(always)] #[inline(always)]
fn advance(&mut self) -> DocId { fn advance(&mut self) -> DocId {

View File

@@ -66,10 +66,6 @@ use crate::schema::{IndexRecordOption, Term};
/// Term::from_field_text(title, "diary"), /// Term::from_field_text(title, "diary"),
/// IndexRecordOption::Basic, /// IndexRecordOption::Basic,
/// )); /// ));
/// let cow_term_query: Box<dyn Query> = Box::new(TermQuery::new(
/// Term::from_field_text(title, "cow"),
/// IndexRecordOption::Basic
/// ));
/// // A TermQuery with "found" in the body /// // A TermQuery with "found" in the body
/// let body_term_query: Box<dyn Query> = Box::new(TermQuery::new( /// let body_term_query: Box<dyn Query> = Box::new(TermQuery::new(
/// Term::from_field_text(body, "found"), /// Term::from_field_text(body, "found"),
@@ -78,7 +74,7 @@ use crate::schema::{IndexRecordOption, Term};
/// // TermQuery "diary" must and "girl" must not be present /// // TermQuery "diary" must and "girl" must not be present
/// let queries_with_occurs1 = vec![ /// let queries_with_occurs1 = vec![
/// (Occur::Must, diary_term_query.box_clone()), /// (Occur::Must, diary_term_query.box_clone()),
/// (Occur::MustNot, girl_term_query.box_clone()), /// (Occur::MustNot, girl_term_query),
/// ]; /// ];
/// // Make a BooleanQuery equivalent to /// // Make a BooleanQuery equivalent to
/// // title:+diary title:-girl /// // title:+diary title:-girl
@@ -86,10 +82,15 @@ use crate::schema::{IndexRecordOption, Term};
/// let count1 = searcher.search(&diary_must_and_girl_mustnot, &Count)?; /// let count1 = searcher.search(&diary_must_and_girl_mustnot, &Count)?;
/// assert_eq!(count1, 1); /// assert_eq!(count1, 1);
/// ///
/// // TermQuery for "cow" in the title
/// let cow_term_query: Box<dyn Query> = Box::new(TermQuery::new(
/// Term::from_field_text(title, "cow"),
/// IndexRecordOption::Basic,
/// ));
/// // "title:diary OR title:cow" /// // "title:diary OR title:cow"
/// let title_diary_or_cow = BooleanQuery::new(vec![ /// let title_diary_or_cow = BooleanQuery::new(vec![
/// (Occur::Should, diary_term_query.box_clone()), /// (Occur::Should, diary_term_query.box_clone()),
/// (Occur::Should, cow_term_query.box_clone()), /// (Occur::Should, cow_term_query),
/// ]); /// ]);
/// let count2 = searcher.search(&title_diary_or_cow, &Count)?; /// let count2 = searcher.search(&title_diary_or_cow, &Count)?;
/// assert_eq!(count2, 4); /// assert_eq!(count2, 4);
@@ -117,38 +118,21 @@ use crate::schema::{IndexRecordOption, Term};
/// ]); /// ]);
/// let count4 = searcher.search(&nested_query, &Count)?; /// let count4 = searcher.search(&nested_query, &Count)?;
/// assert_eq!(count4, 1); /// assert_eq!(count4, 1);
///
/// // You may call `with_minimum_required_clauses` to
/// // specify the number of should clauses the returned documents must match.
/// let minimum_required_query = BooleanQuery::with_minimum_required_clauses(vec![
/// (Occur::Should, cow_term_query.box_clone()),
/// (Occur::Should, girl_term_query.box_clone()),
/// (Occur::Should, diary_term_query.box_clone()),
/// ], 2);
/// // Return documents contains "Diary Cow", "Diary Girl" or "Cow Girl"
/// // Notice: "Diary" isn't "Dairy". ;-)
/// let count5 = searcher.search(&minimum_required_query, &Count)?;
/// assert_eq!(count5, 1);
/// Ok(()) /// Ok(())
/// } /// }
/// ``` /// ```
#[derive(Debug)] #[derive(Debug)]
pub struct BooleanQuery { pub struct BooleanQuery {
subqueries: Vec<(Occur, Box<dyn Query>)>, subqueries: Vec<(Occur, Box<dyn Query>)>,
minimum_number_should_match: usize,
} }
impl Clone for BooleanQuery { impl Clone for BooleanQuery {
fn clone(&self) -> Self { fn clone(&self) -> Self {
let subqueries = self self.subqueries
.subqueries
.iter() .iter()
.map(|(occur, subquery)| (*occur, subquery.box_clone())) .map(|(occur, subquery)| (*occur, subquery.box_clone()))
.collect::<Vec<_>>(); .collect::<Vec<_>>()
Self { .into()
subqueries,
minimum_number_should_match: self.minimum_number_should_match,
}
} }
} }
@@ -165,9 +149,8 @@ impl Query for BooleanQuery {
.iter() .iter()
.map(|(occur, subquery)| Ok((*occur, subquery.weight(enable_scoring)?))) .map(|(occur, subquery)| Ok((*occur, subquery.weight(enable_scoring)?)))
.collect::<crate::Result<_>>()?; .collect::<crate::Result<_>>()?;
Ok(Box::new(BooleanWeight::with_minimum_number_should_match( Ok(Box::new(BooleanWeight::new(
sub_weights, sub_weights,
self.minimum_number_should_match,
enable_scoring.is_scoring_enabled(), enable_scoring.is_scoring_enabled(),
Box::new(SumWithCoordsCombiner::default), Box::new(SumWithCoordsCombiner::default),
))) )))
@@ -183,41 +166,7 @@ impl Query for BooleanQuery {
impl BooleanQuery { impl BooleanQuery {
/// Creates a new boolean query. /// Creates a new boolean query.
pub fn new(subqueries: Vec<(Occur, Box<dyn Query>)>) -> BooleanQuery { pub fn new(subqueries: Vec<(Occur, Box<dyn Query>)>) -> BooleanQuery {
// If the bool query includes at least one should clause BooleanQuery { subqueries }
// and no Must or MustNot clauses, the default value is 1. Otherwise, the default value is
// 0. Keep pace with Elasticsearch.
let mut minimum_required = 0;
for (occur, _) in &subqueries {
match occur {
Occur::Should => minimum_required = 1,
Occur::Must | Occur::MustNot => {
minimum_required = 0;
break;
}
}
}
Self::with_minimum_required_clauses(subqueries, minimum_required)
}
/// Create a new boolean query with minimum number of required should clauses specified.
pub fn with_minimum_required_clauses(
subqueries: Vec<(Occur, Box<dyn Query>)>,
minimum_number_should_match: usize,
) -> BooleanQuery {
BooleanQuery {
subqueries,
minimum_number_should_match,
}
}
/// Getter for `minimum_number_should_match`
pub fn get_minimum_number_should_match(&self) -> usize {
self.minimum_number_should_match
}
/// Setter for `minimum_number_should_match`
pub fn set_minimum_number_should_match(&mut self, minimum_number_should_match: usize) {
self.minimum_number_should_match = minimum_number_should_match;
} }
/// Returns the intersection of the queries. /// Returns the intersection of the queries.
@@ -232,18 +181,6 @@ impl BooleanQuery {
BooleanQuery::new(subqueries) BooleanQuery::new(subqueries)
} }
/// Returns the union of the queries with minimum required clause.
pub fn union_with_minimum_required_clauses(
queries: Vec<Box<dyn Query>>,
minimum_required_clauses: usize,
) -> BooleanQuery {
let subqueries = queries
.into_iter()
.map(|sub_query| (Occur::Should, sub_query))
.collect();
BooleanQuery::with_minimum_required_clauses(subqueries, minimum_required_clauses)
}
/// Helper method to create a boolean query matching a given list of terms. /// Helper method to create a boolean query matching a given list of terms.
/// The resulting query is a disjunction of the terms. /// The resulting query is a disjunction of the terms.
pub fn new_multiterms_query(terms: Vec<Term>) -> BooleanQuery { pub fn new_multiterms_query(terms: Vec<Term>) -> BooleanQuery {
@@ -266,13 +203,11 @@ impl BooleanQuery {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::collections::HashSet;
use super::BooleanQuery; use super::BooleanQuery;
use crate::collector::{Count, DocSetCollector}; use crate::collector::{Count, DocSetCollector};
use crate::query::{Query, QueryClone, QueryParser, TermQuery}; use crate::query::{QueryClone, QueryParser, TermQuery};
use crate::schema::{Field, IndexRecordOption, Schema, TEXT}; use crate::schema::{IndexRecordOption, Schema, TEXT};
use crate::{DocAddress, DocId, Index, Term}; use crate::{DocAddress, Index, Term};
fn create_test_index() -> crate::Result<Index> { fn create_test_index() -> crate::Result<Index> {
let mut schema_builder = Schema::builder(); let mut schema_builder = Schema::builder();
@@ -288,73 +223,6 @@ mod tests {
Ok(index) Ok(index)
} }
#[test]
fn test_minimum_required() -> crate::Result<()> {
fn create_test_index_with<T: IntoIterator<Item = &'static str>>(
docs: T,
) -> crate::Result<Index> {
let mut schema_builder = Schema::builder();
let text = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests()?;
for doc in docs {
writer.add_document(doc!(text => doc))?;
}
writer.commit()?;
Ok(index)
}
fn create_boolean_query_with_mr<T: IntoIterator<Item = &'static str>>(
queries: T,
field: Field,
mr: usize,
) -> BooleanQuery {
let terms = queries
.into_iter()
.map(|t| Term::from_field_text(field, t))
.map(|t| TermQuery::new(t, IndexRecordOption::Basic))
.map(|q| -> Box<dyn Query> { Box::new(q) })
.collect();
BooleanQuery::union_with_minimum_required_clauses(terms, mr)
}
fn check_doc_id<T: IntoIterator<Item = DocId>>(
expected: T,
actually: HashSet<DocAddress>,
seg: u32,
) {
assert_eq!(
actually,
expected
.into_iter()
.map(|id| DocAddress::new(seg, id))
.collect()
);
}
let index = create_test_index_with(["a b c", "a c e", "d f g", "z z z", "c i b"])?;
let searcher = index.reader()?.searcher();
let text = index.schema().get_field("text").unwrap();
// Documents containing 'a c', 'a z', 'a i', 'c z', 'c i' or 'z i' shall be returned.
let q1 = create_boolean_query_with_mr(["a", "c", "z", "i"], text, 2);
let docs = searcher.search(&q1, &DocSetCollector)?;
check_doc_id([0, 1, 4], docs, 0);
// Documents containing 'a b c', 'a b e', 'a c e' or 'b c e' shall be returned.
let q2 = create_boolean_query_with_mr(["a", "b", "c", "e"], text, 3);
let docs = searcher.search(&q2, &DocSetCollector)?;
check_doc_id([0, 1], docs, 0);
// Nothing matches since minimum_required is too large.
let q3 = create_boolean_query_with_mr(["a", "b"], text, 3);
let docs = searcher.search(&q3, &DocSetCollector)?;
assert!(docs.is_empty());
// When mr is set to zero or one, there is no difference from `BooleanQuery::union`.
let q4 = create_boolean_query_with_mr(["a", "z"], text, 1);
let docs = searcher.search(&q4, &DocSetCollector)?;
check_doc_id([0, 1, 3], docs, 0);
let q5 = create_boolean_query_with_mr(["a", "b"], text, 0);
let docs = searcher.search(&q5, &DocSetCollector)?;
check_doc_id([0, 1, 4], docs, 0);
Ok(())
}
#[test] #[test]
fn test_union() -> crate::Result<()> { fn test_union() -> crate::Result<()> {
let index = create_test_index()?; let index = create_test_index()?;

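As a usage sketch of the constructor shared by both sides of this file: term queries wired into a plain disjunction through `BooleanQuery::new`. `field` is assumed to be a text field from the caller's schema; the API calls are the same ones used in the tests above.

use tantivy::query::{BooleanQuery, Occur, Query, TermQuery};
use tantivy::schema::{Field, IndexRecordOption};
use tantivy::Term;

/// Builds "w1 OR w2 OR ..." over a single text field.
fn union_over_terms(field: Field, words: &[&str]) -> BooleanQuery {
    let subqueries: Vec<(Occur, Box<dyn Query>)> = words
        .iter()
        .map(|w| Term::from_field_text(field, w))
        .map(|term| TermQuery::new(term, IndexRecordOption::Basic))
        .map(|query| (Occur::Should, Box::new(query) as Box<dyn Query>))
        .collect();
    BooleanQuery::new(subqueries)
}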
View File

@@ -3,7 +3,6 @@ use std::collections::HashMap;
use crate::docset::COLLECT_BLOCK_BUFFER_LEN; use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
use crate::index::SegmentReader; use crate::index::SegmentReader;
use crate::postings::FreqReadingOption; use crate::postings::FreqReadingOption;
use crate::query::disjunction::Disjunction;
use crate::query::explanation::does_not_match; use crate::query::explanation::does_not_match;
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner}; use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
use crate::query::term_query::TermScorer; use crate::query::term_query::TermScorer;
@@ -19,26 +18,6 @@ enum SpecializedScorer {
Other(Box<dyn Scorer>), Other(Box<dyn Scorer>),
} }
fn scorer_disjunction<TScoreCombiner>(
scorers: Vec<Box<dyn Scorer>>,
score_combiner: TScoreCombiner,
minimum_match_required: usize,
) -> Box<dyn Scorer>
where
TScoreCombiner: ScoreCombiner,
{
debug_assert!(!scorers.is_empty());
debug_assert!(minimum_match_required > 1);
if scorers.len() == 1 {
return scorers.into_iter().next().unwrap(); // Safe unwrap.
}
Box::new(Disjunction::new(
scorers,
score_combiner,
minimum_match_required,
))
}
fn scorer_union<TScoreCombiner>( fn scorer_union<TScoreCombiner>(
scorers: Vec<Box<dyn Scorer>>, scorers: Vec<Box<dyn Scorer>>,
score_combiner_fn: impl Fn() -> TScoreCombiner, score_combiner_fn: impl Fn() -> TScoreCombiner,
@@ -91,7 +70,6 @@ fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
/// Weight associated to the `BoolQuery`. /// Weight associated to the `BoolQuery`.
pub struct BooleanWeight<TScoreCombiner: ScoreCombiner> { pub struct BooleanWeight<TScoreCombiner: ScoreCombiner> {
weights: Vec<(Occur, Box<dyn Weight>)>, weights: Vec<(Occur, Box<dyn Weight>)>,
minimum_number_should_match: usize,
scoring_enabled: bool, scoring_enabled: bool,
score_combiner_fn: Box<dyn Fn() -> TScoreCombiner + Sync + Send>, score_combiner_fn: Box<dyn Fn() -> TScoreCombiner + Sync + Send>,
} }
@@ -107,22 +85,6 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
weights, weights,
scoring_enabled, scoring_enabled,
score_combiner_fn, score_combiner_fn,
minimum_number_should_match: 1,
}
}
/// Create a new boolean weight with minimum number of required should clauses specified.
pub fn with_minimum_number_should_match(
weights: Vec<(Occur, Box<dyn Weight>)>,
minimum_number_should_match: usize,
scoring_enabled: bool,
score_combiner_fn: Box<dyn Fn() -> TScoreCombiner + Sync + Send + 'static>,
) -> BooleanWeight<TScoreCombiner> {
BooleanWeight {
weights,
minimum_number_should_match,
scoring_enabled,
score_combiner_fn,
} }
} }
@@ -149,89 +111,43 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
score_combiner_fn: impl Fn() -> TComplexScoreCombiner, score_combiner_fn: impl Fn() -> TComplexScoreCombiner,
) -> crate::Result<SpecializedScorer> { ) -> crate::Result<SpecializedScorer> {
let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?; let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
// Indicate how should clauses are combined with other clauses.
enum CombinationMethod { let should_scorer_opt: Option<SpecializedScorer> = per_occur_scorers
Ignored, .remove(&Occur::Should)
// Only contributes to final score. .map(|scorers| scorer_union(scorers, &score_combiner_fn));
Optional(SpecializedScorer),
// Must be fitted.
Required(Box<dyn Scorer>),
}
let mut must_scorers = per_occur_scorers.remove(&Occur::Must);
let should_opt = if let Some(mut should_scorers) = per_occur_scorers.remove(&Occur::Should)
{
let num_of_should_scorers = should_scorers.len();
if self.minimum_number_should_match > num_of_should_scorers {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
}
match self.minimum_number_should_match {
0 => CombinationMethod::Optional(scorer_union(should_scorers, &score_combiner_fn)),
1 => CombinationMethod::Required(into_box_scorer(
scorer_union(should_scorers, &score_combiner_fn),
&score_combiner_fn,
)),
n if num_of_should_scorers == n => {
// When num_of_should_scorers equals the number of should clauses,
// they are no different from must clauses.
must_scorers = match must_scorers.take() {
Some(mut must_scorers) => {
must_scorers.append(&mut should_scorers);
Some(must_scorers)
}
None => Some(should_scorers),
};
CombinationMethod::Ignored
}
_ => CombinationMethod::Required(scorer_disjunction(
should_scorers,
score_combiner_fn(),
self.minimum_number_should_match,
)),
}
} else {
// None of should clauses are provided.
if self.minimum_number_should_match > 0 {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
} else {
CombinationMethod::Ignored
}
};
let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
.remove(&Occur::MustNot) .remove(&Occur::MustNot)
.map(|scorers| scorer_union(scorers, DoNothingCombiner::default)) .map(|scorers| scorer_union(scorers, DoNothingCombiner::default))
.map(|specialized_scorer: SpecializedScorer| { .map(|specialized_scorer| {
into_box_scorer(specialized_scorer, DoNothingCombiner::default) into_box_scorer(specialized_scorer, DoNothingCombiner::default)
}); });
let positive_scorer = match (should_opt, must_scorers) {
(CombinationMethod::Ignored, Some(must_scorers)) => { let must_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
SpecializedScorer::Other(intersect_scorers(must_scorers)) .remove(&Occur::Must)
} .map(intersect_scorers);
(CombinationMethod::Optional(should_scorer), Some(must_scorers)) => {
let must_scorer = intersect_scorers(must_scorers); let positive_scorer: SpecializedScorer = match (should_scorer_opt, must_scorer_opt) {
(Some(should_scorer), Some(must_scorer)) => {
if self.scoring_enabled { if self.scoring_enabled {
SpecializedScorer::Other(Box::new( SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
RequiredOptionalScorer::<_, _, TScoreCombiner>::new( Box<dyn Scorer>,
must_scorer, Box<dyn Scorer>,
into_box_scorer(should_scorer, &score_combiner_fn), TComplexScoreCombiner,
), >::new(
)) must_scorer,
into_box_scorer(should_scorer, &score_combiner_fn),
)))
} else { } else {
SpecializedScorer::Other(must_scorer) SpecializedScorer::Other(must_scorer)
} }
} }
(CombinationMethod::Required(should_scorer), Some(mut must_scorers)) => { (None, Some(must_scorer)) => SpecializedScorer::Other(must_scorer),
must_scorers.push(should_scorer); (Some(should_scorer), None) => should_scorer,
SpecializedScorer::Other(intersect_scorers(must_scorers)) (None, None) => {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
} }
(CombinationMethod::Ignored, None) => {
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)))
}
(CombinationMethod::Required(should_scorer), None) => {
SpecializedScorer::Other(should_scorer)
}
// Optional clauses are promoted to required if no must scorers exist.
(CombinationMethod::Optional(should_scorer), None) => should_scorer,
}; };
if let Some(exclude_scorer) = exclude_scorer_opt { if let Some(exclude_scorer) = exclude_scorer_opt {
let positive_scorer_boxed = into_box_scorer(positive_scorer, &score_combiner_fn); let positive_scorer_boxed = into_box_scorer(positive_scorer, &score_combiner_fn);
Ok(SpecializedScorer::Other(Box::new(Exclude::new( Ok(SpecializedScorer::Other(Box::new(Exclude::new(

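A plain-set sketch (not tantivy code) of the combination rules implemented above: Must clauses are intersected, the Should union becomes required only when no Must clause exists (otherwise it merely contributes to scoring), and MustNot matches are excluded at the end.

use std::collections::BTreeSet;

fn combine(
    must: &[BTreeSet<u32>],
    should: &[BTreeSet<u32>],
    must_not: &[BTreeSet<u32>],
) -> BTreeSet<u32> {
    let union = |sets: &[BTreeSet<u32>]| -> BTreeSet<u32> {
        sets.iter().flatten().copied().collect()
    };
    let positive: BTreeSet<u32> = if let Some((first, rest)) = must.split_first() {
        // Intersect all Must clauses; Should clauses do not restrict the result here.
        rest.iter().fold(first.clone(), |acc, s| &acc & s)
    } else {
        // No Must clause: the Should union is the matching set.
        union(should)
    };
    let excluded = union(must_not);
    positive.difference(&excluded).copied().collect()
}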
View File

@@ -1,327 +0,0 @@
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use crate::query::score_combiner::DoNothingCombiner;
use crate::query::{ScoreCombiner, Scorer};
use crate::{DocId, DocSet, Score, TERMINATED};
/// `Disjunction` is responsible for merging `DocSet` from multiple
/// source. Specifically, It takes the union of two or more `DocSet`s
/// then filtering out elements that appear fewer times than a
/// specified threshold.
pub struct Disjunction<TScorer, TScoreCombiner = DoNothingCombiner> {
chains: BinaryHeap<ScorerWrapper<TScorer>>,
minimum_matches_required: usize,
score_combiner: TScoreCombiner,
current_doc: DocId,
current_score: Score,
}
/// A wrapper around a `Scorer` that caches the current `doc_id` and implements the `DocSet` trait.
/// Also, the `Ord` trait and its relatives are implemented in reverse order, so that
/// `std::BinaryHeap<ScorerWrapper<T>>` acts as a min-heap keyed on the current doc id.
struct ScorerWrapper<T> {
scorer: T,
current_doc: DocId,
}
impl<T: Scorer> ScorerWrapper<T> {
fn new(scorer: T) -> Self {
let current_doc = scorer.doc();
Self {
scorer,
current_doc,
}
}
}
impl<T: Scorer> PartialEq for ScorerWrapper<T> {
fn eq(&self, other: &Self) -> bool {
self.doc() == other.doc()
}
}
impl<T: Scorer> Eq for ScorerWrapper<T> {}
impl<T: Scorer> PartialOrd for ScorerWrapper<T> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl<T: Scorer> Ord for ScorerWrapper<T> {
fn cmp(&self, other: &Self) -> Ordering {
self.doc().cmp(&other.doc()).reverse()
}
}
impl<T: Scorer> DocSet for ScorerWrapper<T> {
fn advance(&mut self) -> DocId {
let doc_id = self.scorer.advance();
self.current_doc = doc_id;
doc_id
}
fn doc(&self) -> DocId {
self.current_doc
}
fn size_hint(&self) -> u32 {
self.scorer.size_hint()
}
}
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Disjunction<TScorer, TScoreCombiner> {
pub fn new<T: IntoIterator<Item = TScorer>>(
docsets: T,
score_combiner: TScoreCombiner,
minimum_matches_required: usize,
) -> Self {
debug_assert!(
minimum_matches_required > 1,
"union scorer works better if just one matches required"
);
let chains = docsets
.into_iter()
.map(|doc| ScorerWrapper::new(doc))
.collect();
let mut disjunction = Self {
chains,
score_combiner,
current_doc: TERMINATED,
minimum_matches_required,
current_score: 0.0,
};
if minimum_matches_required > disjunction.chains.len() {
return disjunction;
}
disjunction.advance();
disjunction
}
}
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> DocSet
for Disjunction<TScorer, TScoreCombiner>
{
fn advance(&mut self) -> DocId {
let mut current_num_matches = 0;
while let Some(mut candidate) = self.chains.pop() {
let next = candidate.doc();
if next != TERMINATED {
// Peek next doc.
if self.current_doc != next {
if current_num_matches >= self.minimum_matches_required {
self.chains.push(candidate);
self.current_score = self.score_combiner.score();
return self.current_doc;
}
// Reset current_num_matches and scores.
current_num_matches = 0;
self.current_doc = next;
self.score_combiner.clear();
}
current_num_matches += 1;
self.score_combiner.update(&mut candidate.scorer);
candidate.advance();
self.chains.push(candidate);
}
}
if current_num_matches < self.minimum_matches_required {
self.current_doc = TERMINATED;
}
self.current_score = self.score_combiner.score();
self.current_doc
}
#[inline]
fn doc(&self) -> DocId {
self.current_doc
}
fn size_hint(&self) -> u32 {
self.chains
.iter()
.map(|docset| docset.size_hint())
.max()
.unwrap_or(0u32)
}
}
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Scorer
for Disjunction<TScorer, TScoreCombiner>
{
fn score(&mut self) -> Score {
self.current_score
}
}
#[cfg(test)]
mod tests {
use std::collections::BTreeMap;
use super::Disjunction;
use crate::query::score_combiner::DoNothingCombiner;
use crate::query::{ConstScorer, Scorer, SumCombiner, VecDocSet};
use crate::{DocId, DocSet, Score, TERMINATED};
fn conjunct<T: Ord + Copy>(arrays: &[Vec<T>], pass_line: usize) -> Vec<T> {
let mut counts = BTreeMap::new();
for array in arrays {
for &element in array {
*counts.entry(element).or_insert(0) += 1;
}
}
counts
.iter()
.filter_map(|(&element, &count)| {
if count >= pass_line {
Some(element)
} else {
None
}
})
.collect()
}
fn aux_test_conjunction(vals: Vec<Vec<u32>>, min_match: usize) {
let mut union_expected = VecDocSet::from(conjunct(&vals, min_match));
let make_scorer = || {
Disjunction::new(
vals.iter()
.cloned()
.map(VecDocSet::from)
.map(|d| ConstScorer::new(d, 1.0)),
DoNothingCombiner,
min_match,
)
};
let mut scorer: Disjunction<_, DoNothingCombiner> = make_scorer();
let mut count = 0;
while scorer.doc() != TERMINATED {
assert_eq!(union_expected.doc(), scorer.doc());
assert_eq!(union_expected.advance(), scorer.advance());
count += 1;
}
assert_eq!(union_expected.advance(), TERMINATED);
assert_eq!(count, make_scorer().count_including_deleted());
}
#[should_panic]
#[test]
fn test_arg_check1() {
aux_test_conjunction(vec![], 0);
}
#[should_panic]
#[test]
fn test_arg_check2() {
aux_test_conjunction(vec![], 1);
}
#[test]
fn test_corner_case() {
aux_test_conjunction(vec![], 2);
aux_test_conjunction(vec![vec![]; 1000], 2);
aux_test_conjunction(vec![vec![]; 100], usize::MAX);
aux_test_conjunction(vec![vec![0xC0FFEE]; 10000], usize::MAX);
aux_test_conjunction((1..10000u32).map(|i| vec![i]).collect::<Vec<_>>(), 2);
}
#[test]
fn test_conjunction() {
aux_test_conjunction(
vec![
vec![1, 3333, 100000000u32],
vec![1, 2, 100000000u32],
vec![1, 2, 100000000u32],
],
2,
);
aux_test_conjunction(
vec![vec![8], vec![3, 4, 0xC0FFEEu32], vec![1, 2, 100000000u32]],
2,
);
aux_test_conjunction(
vec![
vec![1, 3333, 100000000u32],
vec![1, 2, 100000000u32],
vec![1, 2, 100000000u32],
],
3,
)
}
// This dummy scorer does nothing but yield doc ids in increasing order,
// with a constant score of 1.0.
#[derive(Clone)]
struct DummyScorer {
cursor: usize,
foo: Vec<(DocId, f32)>,
}
impl DummyScorer {
fn new(doc_score: Vec<(DocId, f32)>) -> Self {
Self {
cursor: 0,
foo: doc_score,
}
}
}
impl DocSet for DummyScorer {
fn advance(&mut self) -> DocId {
self.cursor += 1;
self.doc()
}
fn doc(&self) -> DocId {
self.foo.get(self.cursor).map(|x| x.0).unwrap_or(TERMINATED)
}
fn size_hint(&self) -> u32 {
self.foo.len() as u32
}
}
impl Scorer for DummyScorer {
fn score(&mut self) -> Score {
self.foo.get(self.cursor).map(|x| x.1).unwrap_or(0.0)
}
}
#[test]
fn test_score_calculate() {
let mut scorer = Disjunction::new(
vec![
DummyScorer::new(vec![(1, 1f32), (2, 1f32)]),
DummyScorer::new(vec![(1, 1f32), (3, 1f32)]),
DummyScorer::new(vec![(1, 1f32), (4, 1f32)]),
DummyScorer::new(vec![(1, 1f32), (2, 1f32)]),
DummyScorer::new(vec![(1, 1f32), (2, 1f32)]),
],
SumCombiner::default(),
3,
);
assert_eq!(scorer.score(), 5.0);
assert_eq!(scorer.advance(), 2);
assert_eq!(scorer.score(), 3.0);
}
#[test]
fn test_score_calculate_corner_case() {
let mut scorer = Disjunction::new(
vec![
DummyScorer::new(vec![(1, 1f32), (2, 1f32)]),
DummyScorer::new(vec![(1, 1f32), (3, 1f32)]),
DummyScorer::new(vec![(1, 1f32), (3, 1f32)]),
],
SumCombiner::default(),
2,
);
assert_eq!(scorer.doc(), 1);
assert_eq!(scorer.score(), 3.0);
assert_eq!(scorer.advance(), 3);
assert_eq!(scorer.score(), 2.0);
}
}

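The deleted file above implements a heap-based k-of-n disjunction. The following self-contained sketch reproduces the same idea over sorted doc-id vectors instead of `DocSet`s; it is illustrative code, not the removed implementation.

use std::cmp::Reverse;
use std::collections::BinaryHeap;

/// Returns the doc ids that appear in at least `min_match` of the sorted input lists.
fn disjunction_with_min_match(lists: &[Vec<u32>], min_match: usize) -> Vec<u32> {
    // Each heap entry is (next doc id, list index, cursor into that list);
    // `Reverse` turns the max-heap into a min-heap keyed on the doc id.
    let mut heap: BinaryHeap<Reverse<(u32, usize, usize)>> = BinaryHeap::new();
    for (i, list) in lists.iter().enumerate() {
        if let Some(&doc) = list.first() {
            heap.push(Reverse((doc, i, 0)));
        }
    }
    let mut out = Vec::new();
    while let Some(Reverse((doc, list_idx, cursor))) = heap.pop() {
        // Gather every list currently positioned on `doc`.
        let mut on_doc = vec![(list_idx, cursor)];
        while let Some(&Reverse((next_doc, i, c))) = heap.peek() {
            if next_doc != doc {
                break;
            }
            heap.pop();
            on_doc.push((i, c));
        }
        if on_doc.len() >= min_match {
            out.push(doc);
        }
        // Advance all the lists that matched `doc`.
        for (i, c) in on_doc {
            if let Some(&next) = lists[i].get(c + 1) {
                heap.push(Reverse((next, i, c + 1)));
            }
        }
    }
    out
}

fn main() {
    let lists = vec![vec![1, 2, 5], vec![1, 3, 5], vec![2, 5]];
    // 1 appears in two lists, 2 in two, 5 in three, 3 only in one.
    assert_eq!(disjunction_with_min_match(&lists, 2), vec![1, 2, 5]);
}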
View File

@@ -149,7 +149,7 @@ mod tests {
use crate::query::exist_query::ExistsQuery; use crate::query::exist_query::ExistsQuery;
use crate::query::{BooleanQuery, RangeQuery}; use crate::query::{BooleanQuery, RangeQuery};
use crate::schema::{Facet, FacetOptions, Schema, FAST, INDEXED, STRING, TEXT}; use crate::schema::{Facet, FacetOptions, Schema, FAST, INDEXED, STRING, TEXT};
use crate::{Index, Searcher, Term}; use crate::{Index, Searcher};
#[test] #[test]
fn test_exists_query_simple() -> crate::Result<()> { fn test_exists_query_simple() -> crate::Result<()> {
@@ -188,8 +188,9 @@ mod tests {
// exercise seek // exercise seek
let query = BooleanQuery::intersection(vec![ let query = BooleanQuery::intersection(vec![
Box::new(RangeQuery::new( Box::new(RangeQuery::new_u64_bounds(
Bound::Included(Term::from_field_u64(all_field, 50)), "all".to_string(),
Bound::Included(50),
Bound::Unbounded, Bound::Unbounded,
)), )),
Box::new(ExistsQuery::new_exists_query("even".to_string())), Box::new(ExistsQuery::new_exists_query("even".to_string())),
@@ -197,9 +198,10 @@ mod tests {
assert_eq!(searcher.search(&query, &Count)?, 25); assert_eq!(searcher.search(&query, &Count)?, 25);
let query = BooleanQuery::intersection(vec![ let query = BooleanQuery::intersection(vec![
Box::new(RangeQuery::new( Box::new(RangeQuery::new_u64_bounds(
Bound::Included(Term::from_field_u64(all_field, 0)), "all".to_string(),
Bound::Included(Term::from_field_u64(all_field, 50)), Bound::Included(0),
Bound::Excluded(50),
)), )),
Box::new(ExistsQuery::new_exists_query("odd".to_string())), Box::new(ExistsQuery::new_exists_query("odd".to_string())),
]); ]);

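A small sketch of the intersection this updated test builds, written against the Term-bound `RangeQuery::new` from the left column; `all_field` is assumed to be the fast u64 field named "all" and "even" another fast field in the same schema.

use std::ops::Bound;

use tantivy::query::{BooleanQuery, ExistsQuery, Query, RangeQuery};
use tantivy::schema::Field;
use tantivy::Term;

/// "all:[50 TO *] AND even exists", mirroring the test above.
fn range_and_exists(all_field: Field) -> BooleanQuery {
    let range: Box<dyn Query> = Box::new(RangeQuery::new(
        Bound::Included(Term::from_field_u64(all_field, 50)),
        Bound::Unbounded,
    ));
    let exists: Box<dyn Query> = Box::new(ExistsQuery::new_exists_query("even".to_string()));
    BooleanQuery::intersection(vec![range, exists])
}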
View File

@@ -5,7 +5,6 @@ mod bm25;
mod boolean_query; mod boolean_query;
mod boost_query; mod boost_query;
mod const_score_query; mod const_score_query;
mod disjunction;
mod disjunction_max_query; mod disjunction_max_query;
mod empty_query; mod empty_query;
mod exclude; mod exclude;
@@ -54,7 +53,7 @@ pub use self::phrase_prefix_query::PhrasePrefixQuery;
pub use self::phrase_query::PhraseQuery; pub use self::phrase_query::PhraseQuery;
pub use self::query::{EnableScoring, Query, QueryClone}; pub use self::query::{EnableScoring, Query, QueryClone};
pub use self::query_parser::{QueryParser, QueryParserError}; pub use self::query_parser::{QueryParser, QueryParserError};
pub use self::range_query::{FastFieldRangeWeight, RangeQuery}; pub use self::range_query::{FastFieldRangeWeight, IPFastFieldRangeWeight, RangeQuery};
pub use self::regex_query::RegexQuery; pub use self::regex_query::RegexQuery;
pub use self::reqopt_scorer::RequiredOptionalScorer; pub use self::reqopt_scorer::RequiredOptionalScorer;
pub use self::score_combiner::{ pub use self::score_combiner::{

View File

@@ -145,7 +145,15 @@ impl Query for PhrasePrefixQuery {
Bound::Unbounded Bound::Unbounded
}; };
let mut range_query = RangeQuery::new(Bound::Included(self.prefix.1.clone()), end_term); let mut range_query = RangeQuery::new_term_bounds(
enable_scoring
.schema()
.get_field_name(self.field)
.to_owned(),
self.prefix.1.typ(),
&Bound::Included(self.prefix.1.clone()),
&end_term,
);
range_query.limit(self.max_expansions as u64); range_query.limit(self.max_expansions as u64);
range_query.weight(enable_scoring) range_query.weight(enable_scoring)
} }

View File

@@ -97,7 +97,6 @@ pub struct PhrasePrefixScorer<TPostings: Postings> {
suffixes: Vec<TPostings>, suffixes: Vec<TPostings>,
suffix_offset: u32, suffix_offset: u32,
phrase_count: u32, phrase_count: u32,
suffix_position_buffer: Vec<u32>,
} }
impl<TPostings: Postings> PhrasePrefixScorer<TPostings> { impl<TPostings: Postings> PhrasePrefixScorer<TPostings> {
@@ -141,7 +140,6 @@ impl<TPostings: Postings> PhrasePrefixScorer<TPostings> {
suffixes, suffixes,
suffix_offset: (max_offset - suffix_pos) as u32, suffix_offset: (max_offset - suffix_pos) as u32,
phrase_count: 0, phrase_count: 0,
suffix_position_buffer: Vec::with_capacity(100),
}; };
if phrase_prefix_scorer.doc() != TERMINATED && !phrase_prefix_scorer.matches_prefix() { if phrase_prefix_scorer.doc() != TERMINATED && !phrase_prefix_scorer.matches_prefix() {
phrase_prefix_scorer.advance(); phrase_prefix_scorer.advance();
@@ -155,6 +153,7 @@ impl<TPostings: Postings> PhrasePrefixScorer<TPostings> {
fn matches_prefix(&mut self) -> bool { fn matches_prefix(&mut self) -> bool {
let mut count = 0; let mut count = 0;
let mut positions = Vec::new();
let current_doc = self.doc(); let current_doc = self.doc();
let pos_matching = self.phrase_scorer.get_intersection(); let pos_matching = self.phrase_scorer.get_intersection();
for suffix in &mut self.suffixes { for suffix in &mut self.suffixes {
@@ -163,8 +162,8 @@ impl<TPostings: Postings> PhrasePrefixScorer<TPostings> {
} }
let doc = suffix.seek(current_doc); let doc = suffix.seek(current_doc);
if doc == current_doc { if doc == current_doc {
suffix.positions_with_offset(self.suffix_offset, &mut self.suffix_position_buffer); suffix.positions_with_offset(self.suffix_offset, &mut positions);
count += intersection_count(pos_matching, &self.suffix_position_buffer); count += intersection_count(pos_matching, &positions);
} }
} }
self.phrase_count = count as u32; self.phrase_count = count as u32;

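A tiny sketch of the buffer-reuse pattern the left column introduces: keep a single scratch `Vec` on the scorer and overwrite it per document, rather than allocating a fresh vector in every `matches_prefix` call. Types below are illustrative, not the real scorer.

struct PositionsScratch {
    buf: Vec<u32>,
}

impl PositionsScratch {
    fn new() -> Self {
        Self { buf: Vec::with_capacity(100) }
    }

    // Stand-in for `positions_with_offset`: it overwrites the buffer,
    // so callers never need to clear it between documents.
    fn fill(&mut self, positions: &[u32], offset: u32) -> &[u32] {
        self.buf.clear();
        self.buf.extend(positions.iter().map(|p| p + offset));
        &self.buf
    }
}

fn main() {
    let mut scratch = PositionsScratch::new();
    assert_eq!(scratch.fill(&[1, 4, 9], 2), &[3, 6, 11]);
    assert_eq!(scratch.fill(&[0], 5), &[5]);
}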
View File

@@ -2,7 +2,7 @@ use std::fmt;
use std::ops::Bound; use std::ops::Bound;
use crate::query::Occur; use crate::query::Occur;
use crate::schema::Term; use crate::schema::{Term, Type};
use crate::Score; use crate::Score;
#[derive(Clone)] #[derive(Clone)]
@@ -14,6 +14,8 @@ pub enum LogicalLiteral {
prefix: bool, prefix: bool,
}, },
Range { Range {
field: String,
value_type: Type,
lower: Bound<Term>, lower: Bound<Term>,
upper: Bound<Term>, upper: Bound<Term>,
}, },

View File

@@ -790,6 +790,8 @@ impl QueryParser {
let (field, json_path) = try_tuple!(self let (field, json_path) = try_tuple!(self
.split_full_path(&full_path) .split_full_path(&full_path)
.ok_or_else(|| QueryParserError::FieldDoesNotExist(full_path.clone()))); .ok_or_else(|| QueryParserError::FieldDoesNotExist(full_path.clone())));
let field_entry = self.schema.get_field_entry(field);
let value_type = field_entry.field_type().value_type();
let mut errors = Vec::new(); let mut errors = Vec::new();
let lower = match self.resolve_bound(field, json_path, &lower) { let lower = match self.resolve_bound(field, json_path, &lower) {
Ok(bound) => bound, Ok(bound) => bound,
@@ -810,8 +812,12 @@ impl QueryParser {
// we failed to parse something. Either way, there is no point emitting it // we failed to parse something. Either way, there is no point emitting it
return (None, errors); return (None, errors);
} }
let logical_ast = let logical_ast = LogicalAst::Leaf(Box::new(LogicalLiteral::Range {
LogicalAst::Leaf(Box::new(LogicalLiteral::Range { lower, upper })); field: self.schema.get_field_name(field).to_string(),
value_type,
lower,
upper,
}));
(Some(logical_ast), errors) (Some(logical_ast), errors)
} }
UserInputLeaf::Set { UserInputLeaf::Set {
@@ -878,7 +884,14 @@ fn convert_literal_to_query(
Box::new(PhraseQuery::new_with_offset_and_slop(terms, slop)) Box::new(PhraseQuery::new_with_offset_and_slop(terms, slop))
} }
} }
LogicalLiteral::Range { lower, upper } => Box::new(RangeQuery::new(lower, upper)), LogicalLiteral::Range {
field,
value_type,
lower,
upper,
} => Box::new(RangeQuery::new_term_bounds(
field, value_type, &lower, &upper,
)),
LogicalLiteral::Set { elements, .. } => Box::new(TermSetQuery::new(elements)), LogicalLiteral::Set { elements, .. } => Box::new(TermSetQuery::new(elements)),
LogicalLiteral::All => Box::new(AllQuery), LogicalLiteral::All => Box::new(AllQuery),
} }
@@ -1123,8 +1136,8 @@ mod test {
let query = make_query_parser().parse_query("title:[A TO B]").unwrap(); let query = make_query_parser().parse_query("title:[A TO B]").unwrap();
assert_eq!( assert_eq!(
format!("{query:?}"), format!("{query:?}"),
"RangeQuery { lower_bound: Included(Term(field=0, type=Str, \"a\")), upper_bound: \ "RangeQuery { field: \"title\", value_type: Str, lower_bound: Included([97]), \
Included(Term(field=0, type=Str, \"b\")), limit: None }" upper_bound: Included([98]), limit: None }"
); );
} }
@@ -1802,8 +1815,7 @@ mod test {
\"bad\"))], prefix: (2, Term(field=0, type=Str, \"wo\")), max_expansions: 50 }), \ \"bad\"))], prefix: (2, Term(field=0, type=Str, \"wo\")), max_expansions: 50 }), \
(Should, PhrasePrefixQuery { field: Field(1), phrase_terms: [(0, Term(field=1, \ (Should, PhrasePrefixQuery { field: Field(1), phrase_terms: [(0, Term(field=1, \
type=Str, \"big\")), (1, Term(field=1, type=Str, \"bad\"))], prefix: (2, \ type=Str, \"big\")), (1, Term(field=1, type=Str, \"bad\"))], prefix: (2, \
Term(field=1, type=Str, \"wo\")), max_expansions: 50 })], \ Term(field=1, type=Str, \"wo\")), max_expansions: 50 })] }"
minimum_number_should_match: 1 }"
); );
} }
@@ -1868,8 +1880,7 @@ mod test {
format!("{query:?}"), format!("{query:?}"),
"BooleanQuery { subqueries: [(Should, FuzzyTermQuery { term: Term(field=0, \ "BooleanQuery { subqueries: [(Should, FuzzyTermQuery { term: Term(field=0, \
type=Str, \"abc\"), distance: 1, transposition_cost_one: true, prefix: false }), \ type=Str, \"abc\"), distance: 1, transposition_cost_one: true, prefix: false }), \
(Should, TermQuery(Term(field=1, type=Str, \"abc\")))], \ (Should, TermQuery(Term(field=1, type=Str, \"abc\")))] }"
minimum_number_should_match: 1 }"
); );
} }
@@ -1886,8 +1897,7 @@ mod test {
format!("{query:?}"), format!("{query:?}"),
"BooleanQuery { subqueries: [(Should, TermQuery(Term(field=0, type=Str, \ "BooleanQuery { subqueries: [(Should, TermQuery(Term(field=0, type=Str, \
\"abc\"))), (Should, FuzzyTermQuery { term: Term(field=1, type=Str, \"abc\"), \ \"abc\"))), (Should, FuzzyTermQuery { term: Term(field=1, type=Str, \"abc\"), \
distance: 2, transposition_cost_one: false, prefix: true })], \ distance: 2, transposition_cost_one: false, prefix: true })] }"
minimum_number_should_match: 1 }"
); );
} }
} }

View File

@@ -180,12 +180,10 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> DocSet for RangeDocSe
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::ops::Bound;
use crate::collector::Count; use crate::collector::Count;
use crate::directory::RamDirectory; use crate::directory::RamDirectory;
use crate::query::RangeQuery; use crate::query::RangeQuery;
use crate::{schema, IndexBuilder, TantivyDocument, Term}; use crate::{schema, IndexBuilder, TantivyDocument};
#[test] #[test]
fn range_query_fast_optional_field_minimum() { fn range_query_fast_optional_field_minimum() {
@@ -220,9 +218,10 @@ mod tests {
let reader = index.reader().unwrap(); let reader = index.reader().unwrap();
let searcher = reader.searcher(); let searcher = reader.searcher();
let query = RangeQuery::new( let query = RangeQuery::new_u64_bounds(
Bound::Included(Term::from_field_u64(score_field, 70)), "score".to_string(),
Bound::Unbounded, std::ops::Bound::Included(70),
std::ops::Bound::Unbounded,
); );
let count = searcher.search(&query, &Count).unwrap(); let count = searcher.search(&query, &Count).unwrap();

View File

@@ -2,19 +2,21 @@ use std::ops::Bound;
use crate::schema::Type; use crate::schema::Type;
mod fast_field_range_doc_set; mod fast_field_range_query;
mod range_query; mod range_query;
mod range_query_ip_fastfield;
mod range_query_u64_fastfield; mod range_query_u64_fastfield;
pub use self::range_query::RangeQuery; pub use self::range_query::RangeQuery;
pub use self::range_query_ip_fastfield::IPFastFieldRangeWeight;
pub use self::range_query_u64_fastfield::FastFieldRangeWeight; pub use self::range_query_u64_fastfield::FastFieldRangeWeight;
// TODO is this correct? // TODO is this correct?
pub(crate) fn is_type_valid_for_fastfield_range_query(typ: Type) -> bool { pub(crate) fn is_type_valid_for_fastfield_range_query(typ: Type) -> bool {
match typ { match typ {
Type::Str | Type::U64 | Type::I64 | Type::F64 | Type::Bool | Type::Date => true, Type::U64 | Type::I64 | Type::F64 | Type::Bool | Type::Date => true,
Type::IpAddr => true, Type::IpAddr => true,
Type::Facet | Type::Bytes | Type::Json => false, Type::Str | Type::Facet | Type::Bytes | Type::Json => false,
} }
} }

View File

@@ -1,17 +1,21 @@
use std::io; use std::io;
use std::ops::Bound; use std::net::Ipv6Addr;
use std::ops::{Bound, Range};
use common::BitSet; use columnar::MonotonicallyMappableToU128;
use common::{BinarySerializable, BitSet};
use super::map_bound; use super::map_bound;
use super::range_query_u64_fastfield::FastFieldRangeWeight; use super::range_query_u64_fastfield::FastFieldRangeWeight;
use crate::error::TantivyError;
use crate::index::SegmentReader; use crate::index::SegmentReader;
use crate::query::explanation::does_not_match; use crate::query::explanation::does_not_match;
use crate::query::range_query::is_type_valid_for_fastfield_range_query; use crate::query::range_query::range_query_ip_fastfield::IPFastFieldRangeWeight;
use crate::query::range_query::{is_type_valid_for_fastfield_range_query, map_bound_res};
use crate::query::{BitSetDocSet, ConstScorer, EnableScoring, Explanation, Query, Scorer, Weight}; use crate::query::{BitSetDocSet, ConstScorer, EnableScoring, Explanation, Query, Scorer, Weight};
use crate::schema::{Field, IndexRecordOption, Term, Type}; use crate::schema::{Field, IndexRecordOption, Term, Type};
use crate::termdict::{TermDictionary, TermStreamer}; use crate::termdict::{TermDictionary, TermStreamer};
use crate::{DocId, Score}; use crate::{DateTime, DocId, Score};
/// `RangeQuery` matches all documents that have at least one term within a defined range. /// `RangeQuery` matches all documents that have at least one term within a defined range.
/// ///
@@ -36,10 +40,8 @@ use crate::{DocId, Score};
/// ```rust /// ```rust
/// use tantivy::collector::Count; /// use tantivy::collector::Count;
/// use tantivy::query::RangeQuery; /// use tantivy::query::RangeQuery;
/// use tantivy::Term;
/// use tantivy::schema::{Schema, INDEXED}; /// use tantivy::schema::{Schema, INDEXED};
/// use tantivy::{doc, Index, IndexWriter}; /// use tantivy::{doc, Index, IndexWriter};
/// use std::ops::Bound;
/// # fn test() -> tantivy::Result<()> { /// # fn test() -> tantivy::Result<()> {
/// let mut schema_builder = Schema::builder(); /// let mut schema_builder = Schema::builder();
/// let year_field = schema_builder.add_u64_field("year", INDEXED); /// let year_field = schema_builder.add_u64_field("year", INDEXED);
@@ -57,10 +59,7 @@ use crate::{DocId, Score};
/// ///
/// let reader = index.reader()?; /// let reader = index.reader()?;
/// let searcher = reader.searcher(); /// let searcher = reader.searcher();
/// let docs_in_the_sixties = RangeQuery::new( /// let docs_in_the_sixties = RangeQuery::new_u64("year".to_string(), 1960..1970);
/// Bound::Included(Term::from_field_u64(year_field, 1960)),
/// Bound::Excluded(Term::from_field_u64(year_field, 1970)),
/// );
/// let num_60s_books = searcher.search(&docs_in_the_sixties, &Count)?; /// let num_60s_books = searcher.search(&docs_in_the_sixties, &Count)?;
/// assert_eq!(num_60s_books, 2285); /// assert_eq!(num_60s_books, 2285);
/// Ok(()) /// Ok(())
@@ -69,46 +68,246 @@ use crate::{DocId, Score};
/// ``` /// ```
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct RangeQuery { pub struct RangeQuery {
lower_bound: Bound<Term>, field: String,
upper_bound: Bound<Term>, value_type: Type,
lower_bound: Bound<Vec<u8>>,
upper_bound: Bound<Vec<u8>>,
limit: Option<u64>, limit: Option<u64>,
} }
/// Returns the inner value of a `Bound`
pub(crate) fn inner_bound(val: &Bound<Term>) -> Option<&Term> {
match val {
Bound::Included(term) | Bound::Excluded(term) => Some(term),
Bound::Unbounded => None,
}
}
impl RangeQuery { impl RangeQuery {
/// Creates a new `RangeQuery` from bounded start and end terms. /// Creates a new `RangeQuery` from bounded start and end terms.
/// ///
/// If the value type is not correct, something may go terribly wrong when /// If the value type is not correct, something may go terribly wrong when
/// the `Weight` object is created. /// the `Weight` object is created.
pub fn new(lower_bound: Bound<Term>, upper_bound: Bound<Term>) -> RangeQuery { pub fn new_term_bounds(
field: String,
value_type: Type,
lower_bound: &Bound<Term>,
upper_bound: &Bound<Term>,
) -> RangeQuery {
let verify_and_unwrap_term = |val: &Term| val.serialized_value_bytes().to_owned();
RangeQuery { RangeQuery {
lower_bound, field,
upper_bound, value_type,
lower_bound: map_bound(lower_bound, verify_and_unwrap_term),
upper_bound: map_bound(upper_bound, verify_and_unwrap_term),
limit: None, limit: None,
} }
} }
/// Creates a new `RangeQuery` over a `i64` field.
///
/// If the field is not of the type `i64`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_i64(field: String, range: Range<i64>) -> RangeQuery {
RangeQuery::new_i64_bounds(
field,
Bound::Included(range.start),
Bound::Excluded(range.end),
)
}
/// Create a new `RangeQuery` over a `i64` field.
///
/// The two `Bound` arguments make it possible to create more complex
/// ranges than semi-inclusive range.
///
/// If the field is not of the type `i64`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_i64_bounds(
field: String,
lower_bound: Bound<i64>,
upper_bound: Bound<i64>,
) -> RangeQuery {
let make_term_val = |val: &i64| {
Term::from_field_i64(Field::from_field_id(0), *val)
.serialized_value_bytes()
.to_owned()
};
RangeQuery {
field,
value_type: Type::I64,
lower_bound: map_bound(&lower_bound, make_term_val),
upper_bound: map_bound(&upper_bound, make_term_val),
limit: None,
}
}
/// Creates a new `RangeQuery` over a `f64` field.
///
/// If the field is not of the type `f64`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_f64(field: String, range: Range<f64>) -> RangeQuery {
RangeQuery::new_f64_bounds(
field,
Bound::Included(range.start),
Bound::Excluded(range.end),
)
}
/// Create a new `RangeQuery` over a `f64` field.
///
/// The two `Bound` arguments make it possible to create more complex
/// ranges than semi-inclusive range.
///
/// If the field is not of the type `f64`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_f64_bounds(
field: String,
lower_bound: Bound<f64>,
upper_bound: Bound<f64>,
) -> RangeQuery {
let make_term_val = |val: &f64| {
Term::from_field_f64(Field::from_field_id(0), *val)
.serialized_value_bytes()
.to_owned()
};
RangeQuery {
field,
value_type: Type::F64,
lower_bound: map_bound(&lower_bound, make_term_val),
upper_bound: map_bound(&upper_bound, make_term_val),
limit: None,
}
}
/// Create a new `RangeQuery` over a `u64` field.
///
/// The two `Bound` arguments make it possible to create more complex
/// ranges than semi-inclusive range.
///
/// If the field is not of the type `u64`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_u64_bounds(
field: String,
lower_bound: Bound<u64>,
upper_bound: Bound<u64>,
) -> RangeQuery {
let make_term_val = |val: &u64| {
Term::from_field_u64(Field::from_field_id(0), *val)
.serialized_value_bytes()
.to_owned()
};
RangeQuery {
field,
value_type: Type::U64,
lower_bound: map_bound(&lower_bound, make_term_val),
upper_bound: map_bound(&upper_bound, make_term_val),
limit: None,
}
}
/// Create a new `RangeQuery` over a `ip` field.
///
/// If the field is not of the type `ip`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_ip_bounds(
field: String,
lower_bound: Bound<Ipv6Addr>,
upper_bound: Bound<Ipv6Addr>,
) -> RangeQuery {
let make_term_val = |val: &Ipv6Addr| {
Term::from_field_ip_addr(Field::from_field_id(0), *val)
.serialized_value_bytes()
.to_owned()
};
RangeQuery {
field,
value_type: Type::IpAddr,
lower_bound: map_bound(&lower_bound, make_term_val),
upper_bound: map_bound(&upper_bound, make_term_val),
limit: None,
}
}
/// Create a new `RangeQuery` over a `u64` field.
///
/// If the field is not of the type `u64`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_u64(field: String, range: Range<u64>) -> RangeQuery {
RangeQuery::new_u64_bounds(
field,
Bound::Included(range.start),
Bound::Excluded(range.end),
)
}
/// Create a new `RangeQuery` over a `date` field.
///
/// The two `Bound` arguments make it possible to create more complex
/// ranges than semi-inclusive range.
///
/// If the field is not of the type `date`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_date_bounds(
field: String,
lower_bound: Bound<DateTime>,
upper_bound: Bound<DateTime>,
) -> RangeQuery {
let make_term_val = |val: &DateTime| {
Term::from_field_date(Field::from_field_id(0), *val)
.serialized_value_bytes()
.to_owned()
};
RangeQuery {
field,
value_type: Type::Date,
lower_bound: map_bound(&lower_bound, make_term_val),
upper_bound: map_bound(&upper_bound, make_term_val),
limit: None,
}
}
/// Create a new `RangeQuery` over a `date` field.
///
/// If the field is not of the type `date`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_date(field: String, range: Range<DateTime>) -> RangeQuery {
RangeQuery::new_date_bounds(
field,
Bound::Included(range.start),
Bound::Excluded(range.end),
)
}
/// Create a new `RangeQuery` over a `Str` field.
///
/// The two `Bound` arguments make it possible to create more complex
/// ranges than semi-inclusive range.
///
/// If the field is not of the type `Str`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_str_bounds(
field: String,
lower_bound: Bound<&str>,
upper_bound: Bound<&str>,
) -> RangeQuery {
let make_term_val = |val: &&str| val.as_bytes().to_vec();
RangeQuery {
field,
value_type: Type::Str,
lower_bound: map_bound(&lower_bound, make_term_val),
upper_bound: map_bound(&upper_bound, make_term_val),
limit: None,
}
}
/// Create a new `RangeQuery` over a `Str` field.
///
/// If the field is not of the type `Str`, tantivy
/// will panic when the `Weight` object is created.
pub fn new_str(field: String, range: Range<&str>) -> RangeQuery {
RangeQuery::new_str_bounds(
field,
Bound::Included(range.start),
Bound::Excluded(range.end),
)
}
/// Field to search over /// Field to search over
pub fn field(&self) -> Field { pub fn field(&self) -> &str {
self.get_term().field() &self.field
}
/// The value type of the field
pub fn value_type(&self) -> Type {
self.get_term().typ()
}
pub(crate) fn get_term(&self) -> &Term {
inner_bound(&self.lower_bound)
.or(inner_bound(&self.upper_bound))
.expect("At least one bound must be set")
} }
/// Limit the number of term the `RangeQuery` will go through. /// Limit the number of term the `RangeQuery` will go through.
@@ -120,23 +319,70 @@ impl RangeQuery {
} }
} }
/// Returns true if the type maps to a u64 fast field
pub(crate) fn maps_to_u64_fastfield(typ: Type) -> bool {
match typ {
Type::U64 | Type::I64 | Type::F64 | Type::Bool | Type::Date => true,
Type::IpAddr => false,
Type::Str | Type::Facet | Type::Bytes | Type::Json => false,
}
}
impl Query for RangeQuery { impl Query for RangeQuery {
fn weight(&self, enable_scoring: EnableScoring<'_>) -> crate::Result<Box<dyn Weight>> { fn weight(&self, enable_scoring: EnableScoring<'_>) -> crate::Result<Box<dyn Weight>> {
let schema = enable_scoring.schema(); let schema = enable_scoring.schema();
let field_type = schema.get_field_entry(self.field()).field_type(); let field_type = schema
.get_field_entry(schema.get_field(&self.field)?)
.field_type();
let value_type = field_type.value_type();
if value_type != self.value_type {
let err_msg = format!(
"Create a range query of the type {:?}, when the field given was of type \
{value_type:?}",
self.value_type
);
return Err(TantivyError::SchemaError(err_msg));
}
if field_type.is_fast() && is_type_valid_for_fastfield_range_query(self.value_type()) { if field_type.is_fast() && is_type_valid_for_fastfield_range_query(self.value_type) {
Ok(Box::new(FastFieldRangeWeight::new( if field_type.is_ip_addr() {
self.field(), let parse_ip_from_bytes = |data: &Vec<u8>| {
self.lower_bound.clone(), let ip_u128_bytes: [u8; 16] = data.as_slice().try_into().map_err(|_| {
self.upper_bound.clone(), crate::TantivyError::InvalidArgument(
))) "Expected 8 bytes for ip address".to_string(),
)
})?;
let ip_u128 = u128::from_be_bytes(ip_u128_bytes);
crate::Result::<Ipv6Addr>::Ok(Ipv6Addr::from_u128(ip_u128))
};
let lower_bound = map_bound_res(&self.lower_bound, parse_ip_from_bytes)?;
let upper_bound = map_bound_res(&self.upper_bound, parse_ip_from_bytes)?;
Ok(Box::new(IPFastFieldRangeWeight::new(
self.field.to_string(),
lower_bound,
upper_bound,
)))
} else {
// We run the range query on the u64 value space for performance reasons and simplicity;
// assert the type maps to u64
assert!(maps_to_u64_fastfield(self.value_type));
let parse_from_bytes = |data: &Vec<u8>| {
u64::from_be(BinarySerializable::deserialize(&mut &data[..]).unwrap())
};
let lower_bound = map_bound(&self.lower_bound, parse_from_bytes);
let upper_bound = map_bound(&self.upper_bound, parse_from_bytes);
Ok(Box::new(FastFieldRangeWeight::new_u64_lenient(
self.field.to_string(),
lower_bound,
upper_bound,
)))
}
} else { } else {
let verify_and_unwrap_term = |val: &Term| val.serialized_value_bytes().to_owned();
Ok(Box::new(RangeWeight { Ok(Box::new(RangeWeight {
field: self.field(), field: self.field.to_string(),
lower_bound: map_bound(&self.lower_bound, verify_and_unwrap_term), lower_bound: self.lower_bound.clone(),
upper_bound: map_bound(&self.upper_bound, verify_and_unwrap_term), upper_bound: self.upper_bound.clone(),
limit: self.limit, limit: self.limit,
})) }))
} }
@@ -144,7 +390,7 @@ impl Query for RangeQuery {
} }
pub struct RangeWeight { pub struct RangeWeight {
field: Field, field: String,
lower_bound: Bound<Vec<u8>>, lower_bound: Bound<Vec<u8>>,
upper_bound: Bound<Vec<u8>>, upper_bound: Bound<Vec<u8>>,
limit: Option<u64>, limit: Option<u64>,
@@ -177,7 +423,7 @@ impl Weight for RangeWeight {
let max_doc = reader.max_doc(); let max_doc = reader.max_doc();
let mut doc_bitset = BitSet::with_max_value(max_doc); let mut doc_bitset = BitSet::with_max_value(max_doc);
let inverted_index = reader.inverted_index(self.field)?; let inverted_index = reader.inverted_index(reader.schema().get_field(&self.field)?)?;
let term_dict = inverted_index.terms(); let term_dict = inverted_index.terms();
let mut term_range = self.term_range(term_dict)?; let mut term_range = self.term_range(term_dict)?;
let mut processed_count = 0; let mut processed_count = 0;
@@ -231,7 +477,7 @@ mod tests {
use crate::schema::{ use crate::schema::{
Field, IntoIpv6Addr, Schema, TantivyDocument, FAST, INDEXED, STORED, TEXT, Field, IntoIpv6Addr, Schema, TantivyDocument, FAST, INDEXED, STORED, TEXT,
}; };
use crate::{Index, IndexWriter, Term}; use crate::{Index, IndexWriter};
#[test] #[test]
fn test_range_query_simple() -> crate::Result<()> { fn test_range_query_simple() -> crate::Result<()> {
@@ -253,10 +499,7 @@ mod tests {
let reader = index.reader()?; let reader = index.reader()?;
let searcher = reader.searcher(); let searcher = reader.searcher();
let docs_in_the_sixties = RangeQuery::new( let docs_in_the_sixties = RangeQuery::new_u64("year".to_string(), 1960u64..1970u64);
Bound::Included(Term::from_field_u64(year_field, 1960)),
Bound::Excluded(Term::from_field_u64(year_field, 1970)),
);
// ... or `1960..=1969` if inclusive range is enabled. // ... or `1960..=1969` if inclusive range is enabled.
let count = searcher.search(&docs_in_the_sixties, &Count)?; let count = searcher.search(&docs_in_the_sixties, &Count)?;
@@ -287,10 +530,7 @@ mod tests {
let reader = index.reader()?; let reader = index.reader()?;
let searcher = reader.searcher(); let searcher = reader.searcher();
let mut docs_in_the_sixties = RangeQuery::new( let mut docs_in_the_sixties = RangeQuery::new_u64("year".to_string(), 1960u64..1970u64);
Bound::Included(Term::from_field_u64(year_field, 1960)),
Bound::Excluded(Term::from_field_u64(year_field, 1970)),
);
docs_in_the_sixties.limit(5); docs_in_the_sixties.limit(5);
// due to the limit and no docs in 1963, it's really only 1960..=1965 // due to the limit and no docs in 1963, it's really only 1960..=1965
@@ -335,29 +575,29 @@ mod tests {
|range_query: RangeQuery| searcher.search(&range_query, &Count).unwrap(); |range_query: RangeQuery| searcher.search(&range_query, &Count).unwrap();
assert_eq!( assert_eq!(
count_multiples(RangeQuery::new( count_multiples(RangeQuery::new_i64("intfield".to_string(), 10..11)),
Bound::Included(Term::from_field_i64(int_field, 10)),
Bound::Excluded(Term::from_field_i64(int_field, 11)),
)),
9 9
); );
assert_eq!( assert_eq!(
count_multiples(RangeQuery::new( count_multiples(RangeQuery::new_i64_bounds(
Bound::Included(Term::from_field_i64(int_field, 10)), "intfield".to_string(),
Bound::Included(Term::from_field_i64(int_field, 11)), Bound::Included(10),
Bound::Included(11)
)), )),
18 18
); );
assert_eq!( assert_eq!(
count_multiples(RangeQuery::new( count_multiples(RangeQuery::new_i64_bounds(
Bound::Excluded(Term::from_field_i64(int_field, 9)), "intfield".to_string(),
Bound::Included(Term::from_field_i64(int_field, 10)), Bound::Excluded(9),
Bound::Included(10)
)), )),
9 9
); );
assert_eq!( assert_eq!(
count_multiples(RangeQuery::new( count_multiples(RangeQuery::new_i64_bounds(
Bound::Included(Term::from_field_i64(int_field, 9)), "intfield".to_string(),
Bound::Included(9),
Bound::Unbounded Bound::Unbounded
)), )),
91 91
@@ -406,29 +646,29 @@ mod tests {
|range_query: RangeQuery| searcher.search(&range_query, &Count).unwrap(); |range_query: RangeQuery| searcher.search(&range_query, &Count).unwrap();
assert_eq!( assert_eq!(
count_multiples(RangeQuery::new( count_multiples(RangeQuery::new_f64("floatfield".to_string(), 10.0..11.0)),
Bound::Included(Term::from_field_f64(float_field, 10.0)),
Bound::Excluded(Term::from_field_f64(float_field, 11.0)),
)),
9 9
); );
assert_eq!( assert_eq!(
count_multiples(RangeQuery::new( count_multiples(RangeQuery::new_f64_bounds(
Bound::Included(Term::from_field_f64(float_field, 10.0)), "floatfield".to_string(),
Bound::Included(Term::from_field_f64(float_field, 11.0)), Bound::Included(10.0),
Bound::Included(11.0)
)), )),
18 18
); );
assert_eq!( assert_eq!(
count_multiples(RangeQuery::new( count_multiples(RangeQuery::new_f64_bounds(
Bound::Excluded(Term::from_field_f64(float_field, 9.0)), "floatfield".to_string(),
Bound::Included(Term::from_field_f64(float_field, 10.0)), Bound::Excluded(9.0),
Bound::Included(10.0)
)), )),
9 9
); );
assert_eq!( assert_eq!(
count_multiples(RangeQuery::new( count_multiples(RangeQuery::new_f64_bounds(
Bound::Included(Term::from_field_f64(float_field, 9.0)), "floatfield".to_string(),
Bound::Included(9.0),
Bound::Unbounded Bound::Unbounded
)), )),
91 91

View File
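For reference, a compact end-to-end sketch assuming the Term-bound `RangeQuery::new` from the left column of this hunk, an in-RAM index, and a single `INDEXED` u64 field.

use std::ops::Bound;

use tantivy::collector::Count;
use tantivy::query::RangeQuery;
use tantivy::schema::{Schema, INDEXED};
use tantivy::{doc, Index, IndexWriter, Term};

fn main() -> tantivy::Result<()> {
    let mut schema_builder = Schema::builder();
    let year_field = schema_builder.add_u64_field("year", INDEXED);
    let schema = schema_builder.build();
    let index = Index::create_in_ram(schema);

    let mut writer: IndexWriter = index.writer(50_000_000)?;
    for year in 1950u64..1980u64 {
        writer.add_document(doc!(year_field => year))?;
    }
    writer.commit()?;

    let searcher = index.reader()?.searcher();
    // Matches the ten documents with 1960 <= year < 1970.
    let docs_in_the_sixties = RangeQuery::new(
        Bound::Included(Term::from_field_u64(year_field, 1960)),
        Bound::Excluded(Term::from_field_u64(year_field, 1970)),
    );
    let count = searcher.search(&docs_in_the_sixties, &Count)?;
    assert_eq!(count, 10);
    Ok(())
}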

@@ -0,0 +1,512 @@
//! IP Fastfields support efficient scanning for range queries.
//! We use this variant only if the fastfield exists, otherwise the default in `range_query` is
//! used, which uses the term dictionary + postings.
use std::net::Ipv6Addr;
use std::ops::{Bound, RangeInclusive};
use columnar::{Column, MonotonicallyMappableToU128};
use crate::query::range_query::fast_field_range_query::RangeDocSet;
use crate::query::{ConstScorer, EmptyScorer, Explanation, Scorer, Weight};
use crate::{DocId, DocSet, Score, SegmentReader, TantivyError};
/// `IPFastFieldRangeWeight` uses the ip address fast field to execute range queries.
pub struct IPFastFieldRangeWeight {
field: String,
lower_bound: Bound<Ipv6Addr>,
upper_bound: Bound<Ipv6Addr>,
}
impl IPFastFieldRangeWeight {
/// Creates a new IPFastFieldRangeWeight.
pub fn new(field: String, lower_bound: Bound<Ipv6Addr>, upper_bound: Bound<Ipv6Addr>) -> Self {
Self {
field,
lower_bound,
upper_bound,
}
}
}
impl Weight for IPFastFieldRangeWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
let Some(ip_addr_column): Option<Column<Ipv6Addr>> =
reader.fast_fields().column_opt(&self.field)?
else {
return Ok(Box::new(EmptyScorer));
};
let value_range = bound_to_value_range(
&self.lower_bound,
&self.upper_bound,
ip_addr_column.min_value(),
ip_addr_column.max_value(),
);
let docset = RangeDocSet::new(value_range, ip_addr_column);
Ok(Box::new(ConstScorer::new(docset, boost)))
}
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) != doc {
return Err(TantivyError::InvalidArgument(format!(
"Document #({doc}) does not match"
)));
}
let explanation = Explanation::new("Const", scorer.score());
Ok(explanation)
}
}
fn bound_to_value_range(
lower_bound: &Bound<Ipv6Addr>,
upper_bound: &Bound<Ipv6Addr>,
min_value: Ipv6Addr,
max_value: Ipv6Addr,
) -> RangeInclusive<Ipv6Addr> {
let start_value = match lower_bound {
Bound::Included(ip_addr) => *ip_addr,
Bound::Excluded(ip_addr) => Ipv6Addr::from(ip_addr.to_u128() + 1),
Bound::Unbounded => min_value,
};
let end_value = match upper_bound {
Bound::Included(ip_addr) => *ip_addr,
Bound::Excluded(ip_addr) => Ipv6Addr::from(ip_addr.to_u128() - 1),
Bound::Unbounded => max_value,
};
start_value..=end_value
}
#[cfg(test)]
pub mod tests {
use proptest::prelude::ProptestConfig;
use proptest::strategy::Strategy;
use proptest::{prop_oneof, proptest};
use super::*;
use crate::collector::Count;
use crate::query::QueryParser;
use crate::schema::{Schema, FAST, INDEXED, STORED, STRING};
use crate::{Index, IndexWriter};
#[derive(Clone, Debug)]
pub struct Doc {
pub id: String,
pub ip: Ipv6Addr,
}
fn operation_strategy() -> impl Strategy<Value = Doc> {
prop_oneof![
(0u64..10_000u64).prop_map(doc_from_id_1),
(1u64..10_000u64).prop_map(doc_from_id_2),
]
}
pub fn doc_from_id_1(id: u64) -> Doc {
let id = id * 1000;
Doc {
// ip != id
id: id.to_string(),
ip: Ipv6Addr::from_u128(id as u128),
}
}
fn doc_from_id_2(id: u64) -> Doc {
let id = id * 1000;
Doc {
// ip != id
id: (id - 1).to_string(),
ip: Ipv6Addr::from_u128(id as u128),
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
fn test_ip_range_for_docs_prop(ops in proptest::collection::vec(operation_strategy(), 1..1000)) {
assert!(test_ip_range_for_docs(&ops).is_ok());
}
}
#[test]
fn test_ip_range_regression1() {
let ops = &[doc_from_id_1(0)];
assert!(test_ip_range_for_docs(ops).is_ok());
}
#[test]
fn test_ip_range_regression2() {
let ops = &[
doc_from_id_1(52),
doc_from_id_1(63),
doc_from_id_1(12),
doc_from_id_2(91),
doc_from_id_2(33),
];
assert!(test_ip_range_for_docs(ops).is_ok());
}
#[test]
fn test_ip_range_regression3() {
let ops = &[doc_from_id_1(1), doc_from_id_1(2), doc_from_id_1(3)];
assert!(test_ip_range_for_docs(ops).is_ok());
}
#[test]
fn test_ip_range_regression3_simple() {
let mut schema_builder = Schema::builder();
let ips_field = schema_builder.add_ip_addr_field("ips", FAST | INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer: IndexWriter = index.writer_for_tests().unwrap();
let ip_addrs: Vec<Ipv6Addr> = [1000, 2000, 3000]
.into_iter()
.map(Ipv6Addr::from_u128)
.collect();
for &ip_addr in &ip_addrs {
writer
.add_document(doc!(ips_field=>ip_addr, ips_field=>ip_addr))
.unwrap();
}
writer.commit().unwrap();
let searcher = index.reader().unwrap().searcher();
let range_weight = IPFastFieldRangeWeight {
field: "ips".to_string(),
lower_bound: Bound::Included(ip_addrs[1]),
upper_bound: Bound::Included(ip_addrs[2]),
};
let count = range_weight.count(searcher.segment_reader(0)).unwrap();
assert_eq!(count, 2);
}
pub fn create_index_from_docs(docs: &[Doc]) -> Index {
let mut schema_builder = Schema::builder();
let ip_field = schema_builder.add_ip_addr_field("ip", STORED | FAST);
let ips_field = schema_builder.add_ip_addr_field("ips", FAST | INDEXED);
let text_field = schema_builder.add_text_field("id", STRING | STORED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(2, 60_000_000).unwrap();
for doc in docs.iter() {
index_writer
.add_document(doc!(
ips_field => doc.ip,
ips_field => doc.ip,
ip_field => doc.ip,
text_field => doc.id.to_string(),
))
.unwrap();
}
index_writer.commit().unwrap();
}
index
}
fn test_ip_range_for_docs(docs: &[Doc]) -> crate::Result<()> {
let index = create_index_from_docs(docs);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let get_num_hits = |query| searcher.search(&query, &Count).unwrap();
let query_from_text = |text: &str| {
QueryParser::for_index(&index, vec![])
.parse_query(text)
.unwrap()
};
let gen_query_inclusive = |field: &str, ip_range: &RangeInclusive<Ipv6Addr>| {
format!("{field}:[{} TO {}]", ip_range.start(), ip_range.end())
};
let test_sample = |sample_docs: &[Doc]| {
let mut ips: Vec<Ipv6Addr> = sample_docs.iter().map(|doc| doc.ip).collect();
ips.sort();
let ip_range = ips[0]..=ips[1];
let expected_num_hits = docs
.iter()
.filter(|doc| (ips[0]..=ips[1]).contains(&doc.ip))
.count();
let query = gen_query_inclusive("ip", &ip_range);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
let query = gen_query_inclusive("ips", &ip_range);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
// Intersection search
let id_filter = sample_docs[0].id.to_string();
let expected_num_hits = docs
.iter()
.filter(|doc| ip_range.contains(&doc.ip) && doc.id == id_filter)
.count();
let query = format!(
"{} AND id:{}",
gen_query_inclusive("ip", &ip_range),
&id_filter
);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
// Intersection search on multivalue ip field
let id_filter = sample_docs[0].id.to_string();
let query = format!(
"{} AND id:{}",
gen_query_inclusive("ips", &ip_range),
&id_filter
);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
};
test_sample(&[docs[0].clone(), docs[0].clone()]);
if docs.len() > 1 {
test_sample(&[docs[0].clone(), docs[1].clone()]);
test_sample(&[docs[1].clone(), docs[1].clone()]);
}
if docs.len() > 2 {
test_sample(&[docs[1].clone(), docs[2].clone()]);
}
Ok(())
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use test::Bencher;
use super::tests::*;
use super::*;
use crate::collector::Count;
use crate::query::QueryParser;
use crate::Index;
fn get_index_0_to_100() -> Index {
let mut rng = StdRng::from_seed([1u8; 32]);
let num_vals = 100_000;
let docs: Vec<_> = (0..num_vals)
.map(|_i| {
let id = if rng.gen_bool(0.01) {
"veryfew".to_string() // 1%
} else if rng.gen_bool(0.1) {
"few".to_string() // 9%
} else {
"many".to_string() // 90%
};
Doc {
id,
// Multiply by 1000, so that we create many buckets in the compact space
// The benches depend on this range to select n-percent of elements with the
// methods below.
ip: Ipv6Addr::from_u128(rng.gen_range(0..100) * 1000),
}
})
.collect();
create_index_from_docs(&docs)
}
fn get_90_percent() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(90 * 1000);
start..=end
}
fn get_10_percent() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
fn get_1_percent() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(10 * 1000);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
fn excute_query(
field: &str,
ip_range: RangeInclusive<Ipv6Addr>,
suffix: &str,
index: &Index,
) -> usize {
let gen_query_inclusive = |from: &Ipv6Addr, to: &Ipv6Addr| {
format!(
"{}:[{} TO {}] {}",
field,
&from.to_string(),
&to.to_string(),
suffix
)
};
let query = gen_query_inclusive(ip_range.start(), ip_range.end());
let query_from_text = |text: &str| {
QueryParser::for_index(index, vec![])
.parse_query(text)
.unwrap()
};
let query = query_from_text(&query);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
searcher.search(&query, &(Count)).unwrap()
}
#[bench]
fn bench_ip_range_hit_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_90_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_10_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_1_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_10_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_1_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_1_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_1_percent(), "AND id:veryfew", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_10_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_90_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_90_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_90_percent(), "AND id:veryfew", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_90_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_10_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_1_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_10_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_1_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_1_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_1_percent(), "AND id:veryfew", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_10_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_90_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_90_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_90_percent(), "AND id:veryfew", &index));
}
}
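Stepping back from the listing above: this scorer is only reached when the ip fast field column actually exists. A minimal end-to-end sketch of exercising it through the public API, reusing the same index/query calls as the tests above (the literal addresses and memory budget are illustrative):

use std::net::Ipv6Addr;
use tantivy::collector::Count;
use tantivy::query::QueryParser;
use tantivy::schema::{Schema, FAST, INDEXED};
use tantivy::{doc, Index, IndexWriter};

fn main() -> tantivy::Result<()> {
    let mut schema_builder = Schema::builder();
    let ip_field = schema_builder.add_ip_addr_field("ip", FAST | INDEXED);
    let index = Index::create_in_ram(schema_builder.build());
    let mut writer: IndexWriter = index.writer_with_num_threads(1, 50_000_000)?;
    for val in [1000u128, 2000, 3000] {
        writer.add_document(doc!(ip_field => Ipv6Addr::from(val)))?;
    }
    writer.commit()?;
    let searcher = index.reader()?.searcher();
    // The query parser routes this range to the fast-field scorer when the
    // column is present, and to the term-dictionary variant otherwise.
    let query_str = format!(
        "ip:[{} TO {}]",
        Ipv6Addr::from(2000u128),
        Ipv6Addr::from(3000u128)
    );
    let query = QueryParser::for_index(&index, vec![]).parse_query(&query_str)?;
    assert_eq!(searcher.search(&query, &Count)?, 2);
    Ok(())
}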

View File

@@ -2,34 +2,54 @@
//! We use this variant only if the fastfield exists, otherwise the default in `range_query` is //! We use this variant only if the fastfield exists, otherwise the default in `range_query` is
//! used, which uses the term dictionary + postings. //! used, which uses the term dictionary + postings.
use std::net::Ipv6Addr;
use std::ops::{Bound, RangeInclusive}; use std::ops::{Bound, RangeInclusive};
use columnar::{Column, MonotonicallyMappableToU128, MonotonicallyMappableToU64, StrColumn}; use columnar::{ColumnType, HasAssociatedColumnType, MonotonicallyMappableToU64};
use common::BinarySerializable;
use super::fast_field_range_doc_set::RangeDocSet; use super::fast_field_range_query::RangeDocSet;
use super::{map_bound, map_bound_res}; use super::map_bound;
use crate::query::range_query::range_query::inner_bound; use crate::query::{ConstScorer, EmptyScorer, Explanation, Query, Scorer, Weight};
use crate::query::{AllScorer, ConstScorer, EmptyScorer, Explanation, Query, Scorer, Weight}; use crate::{DocId, DocSet, Score, SegmentReader, TantivyError};
use crate::schema::{Field, Type};
use crate::{DocId, DocSet, Score, SegmentReader, TantivyError, Term};
/// `FastFieldRangeWeight` uses the fast field to execute range queries. /// `FastFieldRangeWeight` uses the fast field to execute range queries.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct FastFieldRangeWeight { pub struct FastFieldRangeWeight {
lower_bound: Bound<Term>, field: String,
upper_bound: Bound<Term>, lower_bound: Bound<u64>,
field: Field, upper_bound: Bound<u64>,
column_type_opt: Option<ColumnType>,
} }
impl FastFieldRangeWeight { impl FastFieldRangeWeight {
/// Create a new FastFieldRangeWeight /// Create a new FastFieldRangeWeight, using the u64 representation of any fast field.
pub(crate) fn new(field: Field, lower_bound: Bound<Term>, upper_bound: Bound<Term>) -> Self { pub(crate) fn new_u64_lenient(
field: String,
lower_bound: Bound<u64>,
upper_bound: Bound<u64>,
) -> Self {
let lower_bound = map_bound(&lower_bound, |val| *val);
let upper_bound = map_bound(&upper_bound, |val| *val);
Self { Self {
field,
lower_bound, lower_bound,
upper_bound, upper_bound,
column_type_opt: None,
}
}
/// Create a new `FastFieldRangeWeight` for a range of a u64-mappable type.
pub fn new<T: HasAssociatedColumnType + MonotonicallyMappableToU64>(
field: String,
lower_bound: Bound<T>,
upper_bound: Bound<T>,
) -> Self {
let lower_bound = map_bound(&lower_bound, |val| val.to_u64());
let upper_bound = map_bound(&upper_bound, |val| val.to_u64());
Self {
field, field,
lower_bound,
upper_bound,
column_type_opt: Some(T::column_type()),
} }
} }
} }
@@ -45,101 +65,30 @@ impl Query for FastFieldRangeWeight {
impl Weight for FastFieldRangeWeight { impl Weight for FastFieldRangeWeight {
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> { fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
// Check if both bounds are Bound::Unbounded let fast_field_reader = reader.fast_fields();
if self.lower_bound == Bound::Unbounded && self.upper_bound == Bound::Unbounded { let column_type_opt: Option<[ColumnType; 1]> =
return Ok(Box::new(AllScorer::new(reader.max_doc()))); self.column_type_opt.map(|column_type| [column_type]);
} let column_type_opt_ref: Option<&[ColumnType]> = column_type_opt
let field_name = reader.schema().get_field_name(self.field); .as_ref()
let field_type = reader.schema().get_field_entry(self.field).field_type(); .map(|column_types| column_types.as_slice());
let Some((column, _)) =
let term = inner_bound(&self.lower_bound) fast_field_reader.u64_lenient_for_type(column_type_opt_ref, &self.field)?
.or(inner_bound(&self.upper_bound)) else {
.expect("At least one bound must be set"); return Ok(Box::new(EmptyScorer));
assert_eq!( };
term.typ(), #[allow(clippy::reversed_empty_ranges)]
field_type.value_type(), let value_range = bound_to_value_range(
"Field is of type {:?}, but got term of type {:?}", &self.lower_bound,
field_type, &self.upper_bound,
term.typ() column.min_value(),
); column.max_value(),
if field_type.is_ip_addr() { )
let parse_ip_from_bytes = |term: &Term| { .unwrap_or(1..=0); // empty range
term.value().as_ip_addr().ok_or_else(|| { if value_range.is_empty() {
crate::TantivyError::InvalidArgument("Expected ip address".to_string()) return Ok(Box::new(EmptyScorer));
})
};
let lower_bound = map_bound_res(&self.lower_bound, parse_ip_from_bytes)?;
let upper_bound = map_bound_res(&self.upper_bound, parse_ip_from_bytes)?;
let Some(ip_addr_column): Option<Column<Ipv6Addr>> =
reader.fast_fields().column_opt(field_name)?
else {
return Ok(Box::new(EmptyScorer));
};
let value_range = bound_to_value_range_ip(
&lower_bound,
&upper_bound,
ip_addr_column.min_value(),
ip_addr_column.max_value(),
);
let docset = RangeDocSet::new(value_range, ip_addr_column);
Ok(Box::new(ConstScorer::new(docset, boost)))
} else {
let (lower_bound, upper_bound) = if field_type.is_str() {
let Some(str_dict_column): Option<StrColumn> =
reader.fast_fields().str(field_name)?
else {
return Ok(Box::new(EmptyScorer));
};
let dict = str_dict_column.dictionary();
let lower_bound = map_bound(&self.lower_bound, |term| {
term.serialized_value_bytes().to_vec()
});
let upper_bound = map_bound(&self.upper_bound, |term| {
term.serialized_value_bytes().to_vec()
});
// Get term ids for terms
let (lower_bound, upper_bound) =
dict.term_bounds_to_ord(lower_bound, upper_bound)?;
(lower_bound, upper_bound)
} else {
assert!(
maps_to_u64_fastfield(field_type.value_type()),
"{:?}",
field_type
);
let parse_from_bytes = |term: &Term| {
u64::from_be(
BinarySerializable::deserialize(&mut &term.serialized_value_bytes()[..])
.unwrap(),
)
};
let lower_bound = map_bound(&self.lower_bound, parse_from_bytes);
let upper_bound = map_bound(&self.upper_bound, parse_from_bytes);
(lower_bound, upper_bound)
};
let fast_field_reader = reader.fast_fields();
let Some((column, _)) = fast_field_reader.u64_lenient_for_type(None, field_name)?
else {
return Ok(Box::new(EmptyScorer));
};
#[allow(clippy::reversed_empty_ranges)]
let value_range = bound_to_value_range(
&lower_bound,
&upper_bound,
column.min_value(),
column.max_value(),
)
.unwrap_or(1..=0); // empty range
if value_range.is_empty() {
return Ok(Box::new(EmptyScorer));
}
let docset = RangeDocSet::new(value_range, column);
Ok(Box::new(ConstScorer::new(docset, boost)))
} }
let docset = RangeDocSet::new(value_range, column);
Ok(Box::new(ConstScorer::new(docset, boost)))
} }
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> { fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
@@ -155,35 +104,6 @@ impl Weight for FastFieldRangeWeight {
} }
} }
/// Returns true if the type maps to a u64 fast field
pub(crate) fn maps_to_u64_fastfield(typ: Type) -> bool {
match typ {
Type::U64 | Type::I64 | Type::F64 | Type::Bool | Type::Date => true,
Type::IpAddr => false,
Type::Str | Type::Facet | Type::Bytes | Type::Json => false,
}
}
fn bound_to_value_range_ip(
lower_bound: &Bound<Ipv6Addr>,
upper_bound: &Bound<Ipv6Addr>,
min_value: Ipv6Addr,
max_value: Ipv6Addr,
) -> RangeInclusive<Ipv6Addr> {
let start_value = match lower_bound {
Bound::Included(ip_addr) => *ip_addr,
Bound::Excluded(ip_addr) => Ipv6Addr::from(ip_addr.to_u128() + 1),
Bound::Unbounded => min_value,
};
let end_value = match upper_bound {
Bound::Included(ip_addr) => *ip_addr,
Bound::Excluded(ip_addr) => Ipv6Addr::from(ip_addr.to_u128() - 1),
Bound::Unbounded => max_value,
};
start_value..=end_value
}
// Returns None if the range cannot be converted to an inclusive range (which equals an empty // Returns None if the range cannot be converted to an inclusive range (which equals an empty
// range). // range).
fn bound_to_value_range<T: MonotonicallyMappableToU64>( fn bound_to_value_range<T: MonotonicallyMappableToU64>(
@@ -217,72 +137,11 @@ pub mod tests {
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use rand::SeedableRng; use rand::SeedableRng;
use crate::collector::{Count, TopDocs}; use crate::collector::Count;
use crate::query::range_query::range_query_u64_fastfield::FastFieldRangeWeight; use crate::query::range_query::range_query_u64_fastfield::FastFieldRangeWeight;
use crate::query::{QueryParser, Weight}; use crate::query::{QueryParser, Weight};
use crate::schema::{ use crate::schema::{NumericOptions, Schema, SchemaBuilder, FAST, INDEXED, STORED, STRING};
NumericOptions, Schema, SchemaBuilder, FAST, INDEXED, STORED, STRING, TEXT, use crate::{Index, IndexWriter, TERMINATED};
};
use crate::{Index, IndexWriter, Term, TERMINATED};
#[test]
fn test_text_field_ff_range_query() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
schema_builder.add_text_field("title", TEXT | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
let mut index_writer = index.writer_for_tests()?;
let title = schema.get_field("title").unwrap();
index_writer.add_document(doc!(
title => "bbb"
))?;
index_writer.add_document(doc!(
title => "ddd"
))?;
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
let query_parser = QueryParser::for_index(&index, vec![title]);
let test_query = |query, num_hits| {
let query = query_parser.parse_query(query).unwrap();
let top_docs = searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
assert_eq!(top_docs.len(), num_hits);
};
test_query("title:[aaa TO ccc]", 1);
test_query("title:[aaa TO bbb]", 1);
test_query("title:[bbb TO bbb]", 1);
test_query("title:[bbb TO ddd]", 2);
test_query("title:[bbb TO eee]", 2);
test_query("title:[bb TO eee]", 2);
test_query("title:[ccc TO ccc]", 0);
test_query("title:[ccc TO ddd]", 1);
test_query("title:[ccc TO eee]", 1);
test_query("title:[aaa TO *}", 2);
test_query("title:[bbb TO *]", 2);
test_query("title:[bb TO *]", 2);
test_query("title:[ccc TO *]", 1);
test_query("title:[ddd TO *]", 1);
test_query("title:[dddd TO *]", 0);
test_query("title:{aaa TO *}", 2);
test_query("title:{bbb TO *]", 1);
test_query("title:{bb TO *]", 2);
test_query("title:{ccc TO *]", 1);
test_query("title:{ddd TO *]", 0);
test_query("title:{dddd TO *]", 0);
test_query("title:[* TO bb]", 0);
test_query("title:[* TO bbb]", 1);
test_query("title:[* TO ccc]", 1);
test_query("title:[* TO ddd]", 2);
test_query("title:[* TO ddd}", 1);
test_query("title:[* TO eee]", 2);
Ok(())
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Doc { pub struct Doc {
@@ -300,14 +159,14 @@ pub mod tests {
fn doc_from_id_1(id: u64) -> Doc { fn doc_from_id_1(id: u64) -> Doc {
let id = id * 1000; let id = id * 1000;
Doc { Doc {
id_name: format!("id_name{:010}", id), id_name: id.to_string(),
id, id,
} }
} }
fn doc_from_id_2(id: u64) -> Doc { fn doc_from_id_2(id: u64) -> Doc {
let id = id * 1000; let id = id * 1000;
Doc { Doc {
id_name: format!("id_name{:010}", id - 1), id_name: (id - 1).to_string(),
id, id,
} }
} }
@@ -354,10 +213,10 @@ pub mod tests {
writer.add_document(doc!(field=>52_000u64)).unwrap(); writer.add_document(doc!(field=>52_000u64)).unwrap();
writer.commit().unwrap(); writer.commit().unwrap();
let searcher = index.reader().unwrap().searcher(); let searcher = index.reader().unwrap().searcher();
let range_query = FastFieldRangeWeight::new( let range_query = FastFieldRangeWeight::new_u64_lenient(
field, "test_field".to_string(),
Bound::Included(Term::from_field_u64(field, 50_000)), Bound::Included(50_000),
Bound::Included(Term::from_field_u64(field, 50_002)), Bound::Included(50_002),
); );
let scorer = range_query let scorer = range_query
.scorer(searcher.segment_reader(0), 1.0f32) .scorer(searcher.segment_reader(0), 1.0f32)
@@ -395,8 +254,7 @@ pub mod tests {
NumericOptions::default().set_fast().set_indexed(), NumericOptions::default().set_fast().set_indexed(),
); );
let text_field = schema_builder.add_text_field("id_name", STRING | STORED | FAST); let text_field = schema_builder.add_text_field("id_name", STRING | STORED);
let text_field2 = schema_builder.add_text_field("id_name_fast", STRING | STORED | FAST);
let schema = schema_builder.build(); let schema = schema_builder.build();
let index = Index::create_in_ram(schema); let index = Index::create_in_ram(schema);
@@ -415,7 +273,6 @@ pub mod tests {
id_f64_field => doc.id as f64, id_f64_field => doc.id as f64,
id_i64_field => doc.id as i64, id_i64_field => doc.id as i64,
text_field => doc.id_name.to_string(), text_field => doc.id_name.to_string(),
text_field2 => doc.id_name.to_string(),
)) ))
.unwrap(); .unwrap();
} }
@@ -460,24 +317,6 @@ pub mod tests {
let query = gen_query_inclusive("ids", ids[0]..=ids[1]); let query = gen_query_inclusive("ids", ids[0]..=ids[1]);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits); assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
// Text query
{
let test_text_query = |field_name: &str| {
let mut id_names: Vec<&str> =
sample_docs.iter().map(|doc| doc.id_name.as_str()).collect();
id_names.sort();
let expected_num_hits = docs
.iter()
.filter(|doc| (id_names[0]..=id_names[1]).contains(&doc.id_name.as_str()))
.count();
let query = format!("{}:[{} TO {}]", field_name, id_names[0], id_names[1]);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
};
test_text_query("id_name");
test_text_query("id_name_fast");
}
// Exclusive range // Exclusive range
let expected_num_hits = docs let expected_num_hits = docs
.iter() .iter()
@@ -555,202 +394,6 @@ pub mod tests {
} }
} }
#[cfg(test)]
pub mod ip_range_tests {
use proptest::prelude::ProptestConfig;
use proptest::strategy::Strategy;
use proptest::{prop_oneof, proptest};
use super::*;
use crate::collector::Count;
use crate::query::QueryParser;
use crate::schema::{Schema, FAST, INDEXED, STORED, STRING};
use crate::{Index, IndexWriter};
#[derive(Clone, Debug)]
pub struct Doc {
pub id: String,
pub ip: Ipv6Addr,
}
fn operation_strategy() -> impl Strategy<Value = Doc> {
prop_oneof![
(0u64..10_000u64).prop_map(doc_from_id_1),
(1u64..10_000u64).prop_map(doc_from_id_2),
]
}
pub fn doc_from_id_1(id: u64) -> Doc {
let id = id * 1000;
Doc {
// ip != id
id: id.to_string(),
ip: Ipv6Addr::from_u128(id as u128),
}
}
fn doc_from_id_2(id: u64) -> Doc {
let id = id * 1000;
Doc {
// ip != id
id: (id - 1).to_string(),
ip: Ipv6Addr::from_u128(id as u128),
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(10))]
#[test]
fn test_ip_range_for_docs_prop(ops in proptest::collection::vec(operation_strategy(), 1..1000)) {
assert!(test_ip_range_for_docs(&ops).is_ok());
}
}
#[test]
fn test_ip_range_regression1() {
let ops = &[doc_from_id_1(0)];
assert!(test_ip_range_for_docs(ops).is_ok());
}
#[test]
fn test_ip_range_regression2() {
let ops = &[
doc_from_id_1(52),
doc_from_id_1(63),
doc_from_id_1(12),
doc_from_id_2(91),
doc_from_id_2(33),
];
assert!(test_ip_range_for_docs(ops).is_ok());
}
#[test]
fn test_ip_range_regression3() {
let ops = &[doc_from_id_1(1), doc_from_id_1(2), doc_from_id_1(3)];
assert!(test_ip_range_for_docs(ops).is_ok());
}
#[test]
fn test_ip_range_regression3_simple() {
let mut schema_builder = Schema::builder();
let ips_field = schema_builder.add_ip_addr_field("ips", FAST | INDEXED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer: IndexWriter = index.writer_for_tests().unwrap();
let ip_addrs: Vec<Ipv6Addr> = [1000, 2000, 3000]
.into_iter()
.map(Ipv6Addr::from_u128)
.collect();
for &ip_addr in &ip_addrs {
writer
.add_document(doc!(ips_field=>ip_addr, ips_field=>ip_addr))
.unwrap();
}
writer.commit().unwrap();
let searcher = index.reader().unwrap().searcher();
let range_weight = FastFieldRangeWeight::new(
ips_field,
Bound::Included(Term::from_field_ip_addr(ips_field, ip_addrs[1])),
Bound::Included(Term::from_field_ip_addr(ips_field, ip_addrs[2])),
);
let count =
crate::query::weight::Weight::count(&range_weight, searcher.segment_reader(0)).unwrap();
assert_eq!(count, 2);
}
pub fn create_index_from_ip_docs(docs: &[Doc]) -> Index {
let mut schema_builder = Schema::builder();
let ip_field = schema_builder.add_ip_addr_field("ip", STORED | FAST);
let ips_field = schema_builder.add_ip_addr_field("ips", FAST | INDEXED);
let text_field = schema_builder.add_text_field("id", STRING | STORED);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
let mut index_writer = index.writer_with_num_threads(2, 60_000_000).unwrap();
for doc in docs.iter() {
index_writer
.add_document(doc!(
ips_field => doc.ip,
ips_field => doc.ip,
ip_field => doc.ip,
text_field => doc.id.to_string(),
))
.unwrap();
}
index_writer.commit().unwrap();
}
index
}
fn test_ip_range_for_docs(docs: &[Doc]) -> crate::Result<()> {
let index = create_index_from_ip_docs(docs);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let get_num_hits = |query| searcher.search(&query, &Count).unwrap();
let query_from_text = |text: &str| {
QueryParser::for_index(&index, vec![])
.parse_query(text)
.unwrap()
};
let gen_query_inclusive = |field: &str, ip_range: &RangeInclusive<Ipv6Addr>| {
format!("{field}:[{} TO {}]", ip_range.start(), ip_range.end())
};
let test_sample = |sample_docs: &[Doc]| {
let mut ips: Vec<Ipv6Addr> = sample_docs.iter().map(|doc| doc.ip).collect();
ips.sort();
let ip_range = ips[0]..=ips[1];
let expected_num_hits = docs
.iter()
.filter(|doc| (ips[0]..=ips[1]).contains(&doc.ip))
.count();
let query = gen_query_inclusive("ip", &ip_range);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
let query = gen_query_inclusive("ips", &ip_range);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
// Intersection search
let id_filter = sample_docs[0].id.to_string();
let expected_num_hits = docs
.iter()
.filter(|doc| ip_range.contains(&doc.ip) && doc.id == id_filter)
.count();
let query = format!(
"{} AND id:{}",
gen_query_inclusive("ip", &ip_range),
&id_filter
);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
// Intersection search on multivalue ip field
let id_filter = sample_docs[0].id.to_string();
let query = format!(
"{} AND id:{}",
gen_query_inclusive("ips", &ip_range),
&id_filter
);
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
};
test_sample(&[docs[0].clone(), docs[0].clone()]);
if docs.len() > 1 {
test_sample(&[docs[0].clone(), docs[1].clone()]);
test_sample(&[docs[1].clone(), docs[1].clone()]);
}
if docs.len() > 2 {
test_sample(&[docs[1].clone(), docs[2].clone()]);
}
Ok(())
}
}
#[cfg(all(test, feature = "unstable"))] #[cfg(all(test, feature = "unstable"))]
mod bench { mod bench {
@@ -958,242 +601,3 @@ mod bench {
bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:veryfew", &index)); bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:veryfew", &index));
} }
} }
#[cfg(all(test, feature = "unstable"))]
mod bench_ip {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use test::Bencher;
use super::ip_range_tests::*;
use super::*;
use crate::collector::Count;
use crate::query::QueryParser;
use crate::Index;
fn get_index_0_to_100() -> Index {
let mut rng = StdRng::from_seed([1u8; 32]);
let num_vals = 100_000;
let docs: Vec<_> = (0..num_vals)
.map(|_i| {
let id = if rng.gen_bool(0.01) {
"veryfew".to_string() // 1%
} else if rng.gen_bool(0.1) {
"few".to_string() // 9%
} else {
"many".to_string() // 90%
};
Doc {
id,
// Multiply by 1000, so that we create many buckets in the compact space
// The benches depend on this range to select n-percent of elements with the
// methods below.
ip: Ipv6Addr::from_u128(rng.gen_range(0..100) * 1000),
}
})
.collect();
create_index_from_ip_docs(&docs)
}
fn get_90_percent() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(90 * 1000);
start..=end
}
fn get_10_percent() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(0);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
fn get_1_percent() -> RangeInclusive<Ipv6Addr> {
let start = Ipv6Addr::from_u128(10 * 1000);
let end = Ipv6Addr::from_u128(10 * 1000);
start..=end
}
fn excute_query(
field: &str,
ip_range: RangeInclusive<Ipv6Addr>,
suffix: &str,
index: &Index,
) -> usize {
let gen_query_inclusive = |from: &Ipv6Addr, to: &Ipv6Addr| {
format!(
"{}:[{} TO {}] {}",
field,
&from.to_string(),
&to.to_string(),
suffix
)
};
let query = gen_query_inclusive(ip_range.start(), ip_range.end());
let query_from_text = |text: &str| {
QueryParser::for_index(index, vec![])
.parse_query(text)
.unwrap()
};
let query = query_from_text(&query);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
searcher.search(&query, &(Count)).unwrap()
}
#[bench]
fn bench_ip_range_hit_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_90_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_10_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_1_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_10_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_1_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_1_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_1_percent(), "AND id:veryfew", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_10_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_90_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_90_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_10_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_90_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_1_percent(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ip", get_90_percent(), "AND id:veryfew", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_90_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_10_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_1_percent(), "", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_10_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_1_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_1_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_1_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_1_percent(), "AND id:veryfew", &index));
}
#[bench]
fn bench_ip_range_hit_10_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_10_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_90_percent(), "AND id:many", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_90_percent(), "AND id:few", &index));
}
#[bench]
fn bench_ip_range_hit_90_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
let index = get_index_0_to_100();
bench.iter(|| excute_query("ips", get_90_percent(), "AND id:veryfew", &index));
}
}
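Both variants above funnel the user-supplied bounds into an inclusive value range before scanning the column. A self-contained sketch of that conversion, with the reversed-range trick standing in for the impossible case (names are illustrative, not the crate's internals):

use std::ops::{Bound, RangeInclusive};

// Excluded bounds are shifted by one so the column scan can operate on an
// inclusive range; overflow means the range cannot match anything.
fn to_inclusive(
    lower: Bound<u64>,
    upper: Bound<u64>,
    min: u64,
    max: u64,
) -> Option<RangeInclusive<u64>> {
    let start = match lower {
        Bound::Included(v) => v,
        Bound::Excluded(v) => v.checked_add(1)?,
        Bound::Unbounded => min,
    };
    let end = match upper {
        Bound::Included(v) => v,
        Bound::Excluded(v) => v.checked_sub(1)?,
        Bound::Unbounded => max,
    };
    Some(start..=end)
}

fn main() {
    assert_eq!(
        to_inclusive(Bound::Excluded(9), Bound::Included(20), 0, 100),
        Some(10..=20)
    );
    // Impossible bounds collapse to an empty scan rather than an error,
    // which is what the `unwrap_or(1..=0)` in the scorer expresses.
    assert!(to_inclusive(Bound::Excluded(5), Bound::Excluded(6), 0, 100)
        .unwrap()
        .is_empty());
}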

View File

@@ -508,7 +508,7 @@ impl std::fmt::Debug for ValueAddr {
/// An enum representing a value for tantivy to index. /// An enum representing a value for tantivy to index.
/// ///
/// ** Any changes need to be reflected in `BinarySerializable` for `ValueType` ** /// Any changes need to be reflected in `BinarySerializable` for `ValueType`
/// ///
/// We can't use [schema::Type] or [columnar::ColumnType] here, because they are missing /// We can't use [schema::Type] or [columnar::ColumnType] here, because they are missing
/// some items like Array and PreTokStr. /// some items like Array and PreTokStr.
@@ -553,7 +553,7 @@ impl BinarySerializable for ValueType {
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> { fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
let num = u8::deserialize(reader)?; let num = u8::deserialize(reader)?;
let type_id = if (0..=12).contains(&num) { let type_id = if (0..=12).contains(&num) {
unsafe { std::mem::transmute::<u8, ValueType>(num) } unsafe { std::mem::transmute(num) }
} else { } else {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,

View File

@@ -201,11 +201,6 @@ impl FieldType {
matches!(self, FieldType::IpAddr(_)) matches!(self, FieldType::IpAddr(_))
} }
/// returns true if this is an str field
pub fn is_str(&self) -> bool {
matches!(self, FieldType::Str(_))
}
/// returns true if this is an date field /// returns true if this is an date field
pub fn is_date(&self) -> bool { pub fn is_date(&self) -> bool {
matches!(self, FieldType::Date(_)) matches!(self, FieldType::Date(_))

View File

@@ -249,8 +249,15 @@ impl Term {
#[inline] #[inline]
pub fn append_path(&mut self, bytes: &[u8]) -> &mut [u8] { pub fn append_path(&mut self, bytes: &[u8]) -> &mut [u8] {
let len_before = self.0.len(); let len_before = self.0.len();
assert!(!bytes.contains(&JSON_END_OF_PATH)); if bytes.contains(&JSON_END_OF_PATH) {
self.0.extend_from_slice(bytes); self.0.extend(
bytes
.iter()
.map(|&b| if b == JSON_END_OF_PATH { b'0' } else { b }),
);
} else {
self.0.extend_from_slice(bytes);
}
&mut self.0[len_before..] &mut self.0[len_before..]
} }
} }
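This hunk changes how a JSON path segment containing the end-of-path marker is handled: one side asserts the byte cannot occur, the other rewrites it to b'0' so the path cannot be terminated early. A standalone sketch of the rewriting behaviour; the concrete marker value is an assumption here, not the crate's constant:

// Assumed marker value for illustration only; tantivy's JSON_END_OF_PATH may differ.
const END_OF_PATH: u8 = 0;

fn sanitize_path_segment(bytes: &[u8]) -> Vec<u8> {
    // Replace any embedded marker byte instead of panicking on it.
    bytes
        .iter()
        .map(|&b| if b == END_OF_PATH { b'0' } else { b })
        .collect()
}

fn main() {
    assert_eq!(sanitize_path_segment(b"a\0b"), b"a0b".to_vec());
}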

View File

@@ -16,7 +16,9 @@ fn make_test_sstable(suffix: &str) -> FileSlice {
let table = builder.finish().unwrap(); let table = builder.finish().unwrap();
let table = Arc::new(OwnedBytes::new(table)); let table = Arc::new(OwnedBytes::new(table));
common::file_slice::FileSlice::new(table.clone()) let slice = common::file_slice::FileSlice::new(table.clone());
slice
} }
pub fn criterion_benchmark(c: &mut Criterion) { pub fn criterion_benchmark(c: &mut Criterion) {

View File

@@ -7,7 +7,7 @@ use rand::rngs::StdRng;
use rand::{Rng, SeedableRng}; use rand::{Rng, SeedableRng};
use tantivy_sstable::{Dictionary, MonotonicU64SSTable}; use tantivy_sstable::{Dictionary, MonotonicU64SSTable};
const CHARSET: &[u8] = b"abcdefghij"; const CHARSET: &'static [u8] = b"abcdefghij";
fn generate_key(rng: &mut impl Rng) -> String { fn generate_key(rng: &mut impl Rng) -> String {
let len = rng.gen_range(3..12); let len = rng.gen_range(3..12);

View File

@@ -56,53 +56,6 @@ impl Dictionary<VoidSSTable> {
} }
} }
fn map_bound<TFrom, TTo>(bound: &Bound<TFrom>, transform: impl Fn(&TFrom) -> TTo) -> Bound<TTo> {
use self::Bound::*;
match bound {
Excluded(ref from_val) => Bound::Excluded(transform(from_val)),
Included(ref from_val) => Bound::Included(transform(from_val)),
Unbounded => Unbounded,
}
}
/// Takes a bound and transforms the inner value into a new bound via a closure.
/// The bound variant may change depending on the value returned from the closure.
fn transform_bound_inner<TFrom, TTo>(
bound: &Bound<TFrom>,
transform: impl Fn(&TFrom) -> io::Result<Bound<TTo>>,
) -> io::Result<Bound<TTo>> {
use self::Bound::*;
Ok(match bound {
Excluded(ref from_val) => transform(from_val)?,
Included(ref from_val) => transform(from_val)?,
Unbounded => Unbounded,
})
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TermOrdHit {
/// Exact term ord hit
Exact(TermOrdinal),
/// Next best term ordinal
Next(TermOrdinal),
}
impl TermOrdHit {
fn into_exact(self) -> Option<TermOrdinal> {
match self {
TermOrdHit::Exact(ord) => Some(ord),
TermOrdHit::Next(_) => None,
}
}
fn map<F: FnOnce(TermOrdinal) -> TermOrdinal>(self, f: F) -> Self {
match self {
TermOrdHit::Exact(ord) => TermOrdHit::Exact(f(ord)),
TermOrdHit::Next(ord) => TermOrdHit::Next(f(ord)),
}
}
}
impl<TSSTable: SSTable> Dictionary<TSSTable> { impl<TSSTable: SSTable> Dictionary<TSSTable> {
pub fn builder<W: io::Write>(wrt: W) -> io::Result<crate::Writer<W, TSSTable::ValueWriter>> { pub fn builder<W: io::Write>(wrt: W) -> io::Result<crate::Writer<W, TSSTable::ValueWriter>> {
Ok(TSSTable::writer(wrt)) Ok(TSSTable::writer(wrt))
@@ -304,17 +257,6 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
key: K, key: K,
sstable_delta_reader: &mut DeltaReader<TSSTable::ValueReader>, sstable_delta_reader: &mut DeltaReader<TSSTable::ValueReader>,
) -> io::Result<Option<TermOrdinal>> { ) -> io::Result<Option<TermOrdinal>> {
self.decode_up_to_or_next(key, sstable_delta_reader)
.map(|hit| hit.into_exact())
}
/// Decode a DeltaReader up to key, returning the number of terms traversed
///
/// If the key was not found, it returns the next term id.
fn decode_up_to_or_next<K: AsRef<[u8]>>(
&self,
key: K,
sstable_delta_reader: &mut DeltaReader<TSSTable::ValueReader>,
) -> io::Result<TermOrdHit> {
let mut term_ord = 0; let mut term_ord = 0;
let key_bytes = key.as_ref(); let key_bytes = key.as_ref();
let mut ok_bytes = 0; let mut ok_bytes = 0;
@@ -323,7 +265,7 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
let suffix = sstable_delta_reader.suffix(); let suffix = sstable_delta_reader.suffix();
match prefix_len.cmp(&ok_bytes) { match prefix_len.cmp(&ok_bytes) {
Ordering::Less => return Ok(TermOrdHit::Next(term_ord)), /* popped bytes already matched => too far */ Ordering::Less => return Ok(None), // popped bytes already matched => too far
Ordering::Equal => (), Ordering::Equal => (),
Ordering::Greater => { Ordering::Greater => {
// the ok prefix is less than current entry prefix => continue to next elem // the ok prefix is less than current entry prefix => continue to next elem
@@ -335,26 +277,25 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
// we have ok_bytes byte of common prefix, check if this key adds more // we have ok_bytes byte of common prefix, check if this key adds more
for (key_byte, suffix_byte) in key_bytes[ok_bytes..].iter().zip(suffix) { for (key_byte, suffix_byte) in key_bytes[ok_bytes..].iter().zip(suffix) {
match suffix_byte.cmp(key_byte) { match suffix_byte.cmp(key_byte) {
Ordering::Less => break, // byte too small Ordering::Less => break, // byte too small
Ordering::Equal => ok_bytes += 1, // new matching Ordering::Equal => ok_bytes += 1, // new matching byte
// byte Ordering::Greater => return Ok(None), // too far
Ordering::Greater => return Ok(TermOrdHit::Next(term_ord)), // too far
} }
} }
if ok_bytes == key_bytes.len() { if ok_bytes == key_bytes.len() {
if prefix_len + suffix.len() == ok_bytes { if prefix_len + suffix.len() == ok_bytes {
return Ok(TermOrdHit::Exact(term_ord)); return Ok(Some(term_ord));
} else { } else {
// current key is a prefix of current element, not a match // current key is a prefix of current element, not a match
return Ok(TermOrdHit::Next(term_ord)); return Ok(None);
} }
} }
term_ord += 1; term_ord += 1;
} }
Ok(TermOrdHit::Next(term_ord)) Ok(None)
} }
/// Returns the ordinal associated with a given term. /// Returns the ordinal associated with a given term.
@@ -371,61 +312,6 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
.map(|opt| opt.map(|ord| ord + first_ordinal)) .map(|opt| opt.map(|ord| ord + first_ordinal))
} }
/// Returns the ordinal associated with a given term or its closest next term_id
/// The closest next term_id may not exist.
pub fn term_ord_or_next<K: AsRef<[u8]>>(&self, key: K) -> io::Result<TermOrdHit> {
let key_bytes = key.as_ref();
let Some(block_addr) = self.sstable_index.get_block_with_key(key_bytes) else {
// TODO: Would be more consistent to return last_term id + 1
return Ok(TermOrdHit::Next(u64::MAX));
};
let first_ordinal = block_addr.first_ordinal;
let mut sstable_delta_reader = self.sstable_delta_reader_block(block_addr)?;
self.decode_up_to_or_next(key_bytes, &mut sstable_delta_reader)
.map(|opt| opt.map(|ord| ord + first_ordinal))
}
/// Converts strings into a Bound range.
/// This does handle several special cases if the term is not exactly in the dictionary.
/// e.g. [bbb, ddd]
/// lower_bound: Bound::Included(aaa) => Included(0) // "Next" term id
/// lower_bound: Bound::Excluded(aaa) => Included(0) // "Next" term id + Change the Bounds
/// lower_bound: Bound::Included(ccc) => Included(1) // "Next" term id
/// lower_bound: Bound::Excluded(ccc) => Included(1) // "Next" term id + Change the Bounds
/// lower_bound: Bound::Included(zzz) => Included(2) // "Next" term id
/// lower_bound: Bound::Excluded(zzz) => Included(2) // "Next" term id + Change the Bounds
/// For zzz we should have some post-processing to return an empty query.
///
/// upper_bound: Bound::Included(aaa) => Excluded(0) // "Next" term id + Change the bounds
/// upper_bound: Bound::Excluded(aaa) => Excluded(0) // "Next" term id
/// upper_bound: Bound::Included(ccc) => Excluded(1) // Next term id + Change the bounds
/// upper_bound: Bound::Excluded(ccc) => Excluded(1) // Next term id
/// upper_bound: Bound::Included(zzz) => Excluded(2) // Next term id + Change the bounds
/// upper_bound: Bound::Excluded(zzz) => Excluded(2) // Next term id
pub fn term_bounds_to_ord<K: AsRef<[u8]>>(
&self,
lower_bound: Bound<K>,
upper_bound: Bound<K>,
) -> io::Result<(Bound<TermOrdinal>, Bound<TermOrdinal>)> {
let lower_bound = transform_bound_inner(&lower_bound, |start_bound_bytes| {
let ord = self.term_ord_or_next(start_bound_bytes)?;
match ord {
TermOrdHit::Exact(ord) => Ok(map_bound(&lower_bound, |_| ord)),
TermOrdHit::Next(ord) => Ok(Bound::Included(ord)), // Change bounds to included
}
})?;
let upper_bound = transform_bound_inner(&upper_bound, |end_bound_bytes| {
let ord = self.term_ord_or_next(end_bound_bytes)?;
match ord {
TermOrdHit::Exact(ord) => Ok(map_bound(&upper_bound, |_| ord)),
TermOrdHit::Next(ord) => Ok(Bound::Excluded(ord)), // Change bounds to excluded
}
})?;
Ok((lower_bound, upper_bound))
}
/// Returns the term associated with a given term ordinal. /// Returns the term associated with a given term ordinal.
/// ///
/// Term ordinals are defined as the position of the term in /// Term ordinals are defined as the position of the term in
@@ -452,45 +338,6 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
Ok(true) Ok(true)
} }
/// Returns the terms for a _sorted_ list of term ordinals.
///
/// Returns true if and only if all terms have been found.
pub fn sorted_ords_to_term_cb<F: FnMut(&[u8]) -> io::Result<()>>(
&self,
ord: impl Iterator<Item = TermOrdinal>,
mut cb: F,
) -> io::Result<bool> {
let mut bytes = Vec::new();
let mut current_block_addr = self.sstable_index.get_block_with_ord(0);
let mut current_sstable_delta_reader =
self.sstable_delta_reader_block(current_block_addr.clone())?;
let mut current_ordinal = 0;
for ord in ord {
assert!(ord >= current_ordinal);
// check if block changed for new term_ord
let new_block_addr = self.sstable_index.get_block_with_ord(ord);
if new_block_addr != current_block_addr {
current_block_addr = new_block_addr;
current_ordinal = current_block_addr.first_ordinal;
current_sstable_delta_reader =
self.sstable_delta_reader_block(current_block_addr.clone())?;
bytes.clear();
}
// move to ord inside that block
for _ in current_ordinal..=ord {
if !current_sstable_delta_reader.advance()? {
return Ok(false);
}
bytes.truncate(current_sstable_delta_reader.common_prefix_len());
bytes.extend_from_slice(current_sstable_delta_reader.suffix());
}
current_ordinal = ord + 1;
cb(&bytes)?;
}
Ok(true)
}
/// Returns the number of terms in the dictionary. /// Returns the number of terms in the dictionary.
pub fn term_info_from_ord(&self, term_ord: TermOrdinal) -> io::Result<Option<TSSTable::Value>> { pub fn term_info_from_ord(&self, term_ord: TermOrdinal) -> io::Result<Option<TSSTable::Value>> {
// find block in which the term would be // find block in which the term would be
@@ -569,13 +416,12 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::ops::{Bound, Range}; use std::ops::Range;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use common::OwnedBytes; use common::OwnedBytes;
use super::Dictionary; use super::Dictionary;
use crate::dictionary::TermOrdHit;
use crate::MonotonicU64SSTable; use crate::MonotonicU64SSTable;
#[derive(Debug)] #[derive(Debug)]
@@ -639,140 +485,6 @@ mod tests {
(dictionary, table) (dictionary, table)
} }
#[test]
fn test_term_to_ord_or_next() {
let dict = {
let mut builder = Dictionary::<MonotonicU64SSTable>::builder(Vec::new()).unwrap();
builder.insert(b"bbb", &1).unwrap();
builder.insert(b"ddd", &2).unwrap();
let table = builder.finish().unwrap();
let table = Arc::new(PermissionedHandle::new(table));
let slice = common::file_slice::FileSlice::new(table.clone());
Dictionary::<MonotonicU64SSTable>::open(slice).unwrap()
};
assert_eq!(dict.term_ord_or_next(b"aaa").unwrap(), TermOrdHit::Next(0));
assert_eq!(dict.term_ord_or_next(b"bbb").unwrap(), TermOrdHit::Exact(0));
assert_eq!(dict.term_ord_or_next(b"bb").unwrap(), TermOrdHit::Next(0));
assert_eq!(dict.term_ord_or_next(b"bbbb").unwrap(), TermOrdHit::Next(1));
assert_eq!(dict.term_ord_or_next(b"dd").unwrap(), TermOrdHit::Next(1));
assert_eq!(dict.term_ord_or_next(b"ddd").unwrap(), TermOrdHit::Exact(1));
assert_eq!(dict.term_ord_or_next(b"dddd").unwrap(), TermOrdHit::Next(2));
// This is not u64::MAX because for very small sstables (only one block),
// we don't store an index, and the pseudo-index always replies that the
// answer lies in block number 0
assert_eq!(
dict.term_ord_or_next(b"zzzzzzz").unwrap(),
TermOrdHit::Next(2)
);
}
#[test]
fn test_term_to_ord_or_next_2() {
let dict = {
let mut builder = Dictionary::<MonotonicU64SSTable>::builder(Vec::new()).unwrap();
let mut term_ord = 0;
builder.insert(b"bbb", &term_ord).unwrap();
// Fill blocks in between
for elem in 0..50_000 {
term_ord += 1;
let key = format!("ccccc{elem:05X}").into_bytes();
builder.insert(&key, &term_ord).unwrap();
}
term_ord += 1;
builder.insert(b"eee", &term_ord).unwrap();
let table = builder.finish().unwrap();
let table = Arc::new(PermissionedHandle::new(table));
let slice = common::file_slice::FileSlice::new(table.clone());
Dictionary::<MonotonicU64SSTable>::open(slice).unwrap()
};
assert_eq!(dict.term_ord(b"bbb").unwrap(), Some(0));
assert_eq!(dict.term_ord_or_next(b"bbb").unwrap(), TermOrdHit::Exact(0));
assert_eq!(dict.term_ord_or_next(b"aaa").unwrap(), TermOrdHit::Next(0));
assert_eq!(dict.term_ord_or_next(b"bb").unwrap(), TermOrdHit::Next(0));
assert_eq!(dict.term_ord_or_next(b"bbbb").unwrap(), TermOrdHit::Next(1));
assert_eq!(
dict.term_ord_or_next(b"ee").unwrap(),
TermOrdHit::Next(50001)
);
assert_eq!(
dict.term_ord_or_next(b"eee").unwrap(),
TermOrdHit::Exact(50001)
);
assert_eq!(
dict.term_ord_or_next(b"eeee").unwrap(),
TermOrdHit::Next(u64::MAX)
);
assert_eq!(
dict.term_ord_or_next(b"zzzzzzz").unwrap(),
TermOrdHit::Next(u64::MAX)
);
}
#[test]
fn test_term_bounds_to_ord() {
let dict = {
let mut builder = Dictionary::<MonotonicU64SSTable>::builder(Vec::new()).unwrap();
builder.insert(b"bbb", &1).unwrap();
builder.insert(b"ddd", &2).unwrap();
let table = builder.finish().unwrap();
let table = Arc::new(PermissionedHandle::new(table));
let slice = common::file_slice::FileSlice::new(table.clone());
Dictionary::<MonotonicU64SSTable>::open(slice).unwrap()
};
// Test cases for lower_bound
let test_lower_bound = |bound, expected| {
assert_eq!(
dict.term_bounds_to_ord::<&[u8]>(bound, Bound::Included(b"ignored"))
.unwrap()
.0,
expected
);
};
test_lower_bound(Bound::Included(b"aaa".as_slice()), Bound::Included(0));
test_lower_bound(Bound::Excluded(b"aaa".as_slice()), Bound::Included(0));
test_lower_bound(Bound::Included(b"bbb".as_slice()), Bound::Included(0));
test_lower_bound(Bound::Excluded(b"bbb".as_slice()), Bound::Excluded(0));
test_lower_bound(Bound::Included(b"ccc".as_slice()), Bound::Included(1));
test_lower_bound(Bound::Excluded(b"ccc".as_slice()), Bound::Included(1));
test_lower_bound(Bound::Included(b"zzz".as_slice()), Bound::Included(2));
test_lower_bound(Bound::Excluded(b"zzz".as_slice()), Bound::Included(2));
// Test cases for upper_bound
let test_upper_bound = |bound, expected| {
assert_eq!(
dict.term_bounds_to_ord::<&[u8]>(Bound::Included(b"ignored"), bound,)
.unwrap()
.1,
expected
);
};
test_upper_bound(Bound::Included(b"ccc".as_slice()), Bound::Excluded(1));
test_upper_bound(Bound::Excluded(b"ccc".as_slice()), Bound::Excluded(1));
test_upper_bound(Bound::Included(b"zzz".as_slice()), Bound::Excluded(2));
test_upper_bound(Bound::Excluded(b"zzz".as_slice()), Bound::Excluded(2));
test_upper_bound(Bound::Included(b"ddd".as_slice()), Bound::Included(1));
test_upper_bound(Bound::Excluded(b"ddd".as_slice()), Bound::Excluded(1));
}
#[test] #[test]
fn test_ord_term_conversion() { fn test_ord_term_conversion() {
let (dic, slice) = make_test_sstable(); let (dic, slice) = make_test_sstable();
@@ -839,61 +551,6 @@ mod tests {
assert!(dic.term_ord(b"1000").unwrap().is_none()); assert!(dic.term_ord(b"1000").unwrap().is_none());
} }
#[test]
fn test_ords_term() {
let (dic, _slice) = make_test_sstable();
// Single term
let mut terms = Vec::new();
assert!(dic
.sorted_ords_to_term_cb(100_000..100_001, |term| {
terms.push(term.to_vec());
Ok(())
})
.unwrap());
assert_eq!(terms, vec![format!("{:05X}", 100_000).into_bytes(),]);
// Single term
let mut terms = Vec::new();
assert!(dic
.sorted_ords_to_term_cb(100_001..100_002, |term| {
terms.push(term.to_vec());
Ok(())
})
.unwrap());
assert_eq!(terms, vec![format!("{:05X}", 100_001).into_bytes(),]);
// both terms
let mut terms = Vec::new();
assert!(dic
.sorted_ords_to_term_cb(100_000..100_002, |term| {
terms.push(term.to_vec());
Ok(())
})
.unwrap());
assert_eq!(
terms,
vec![
format!("{:05X}", 100_000).into_bytes(),
format!("{:05X}", 100_001).into_bytes(),
]
);
// Test cross block
let mut terms = Vec::new();
assert!(dic
.sorted_ords_to_term_cb(98653..=98655, |term| {
terms.push(term.to_vec());
Ok(())
})
.unwrap());
assert_eq!(
terms,
vec![
format!("{:05X}", 98653).into_bytes(),
format!("{:05X}", 98654).into_bytes(),
format!("{:05X}", 98655).into_bytes(),
]
);
}
#[test] #[test]
fn test_range() { fn test_range() {
let (dic, slice) = make_test_sstable(); let (dic, slice) = make_test_sstable();
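The `term_bounds_to_ord` documentation above spells out how a bound on a term that is missing from the dictionary snaps to the next ordinal. A small self-contained sketch of that snapping rule for the lower bound (illustrative types, not the crate's `TermOrdHit`):

use std::ops::Bound;

// Stand-in for the dictionary's exact-vs-next lookup result.
#[derive(Debug, PartialEq)]
enum Hit {
    Exact(u64),
    Next(u64),
}

fn lower_bound_to_ord(user: Bound<&[u8]>, hit: Hit) -> Bound<u64> {
    match hit {
        // The term exists: keep the caller's Included/Excluded flavour.
        Hit::Exact(ord) => match user {
            Bound::Included(_) => Bound::Included(ord),
            Bound::Excluded(_) => Bound::Excluded(ord),
            Bound::Unbounded => Bound::Unbounded,
        },
        // The term is absent: the next ordinal is the first candidate either
        // way, so the bound becomes inclusive.
        Hit::Next(ord) => Bound::Included(ord),
    }
}

fn main() {
    // "aaa" sorts before the smallest term "bbb" (ordinal 0) in the example above.
    assert_eq!(
        lower_bound_to_ord(Bound::Excluded(b"aaa".as_slice()), Hit::Next(0)),
        Bound::Included(0)
    );
    assert_eq!(
        lower_bound_to_ord(Bound::Excluded(b"bbb".as_slice()), Hit::Exact(0)),
        Bound::Excluded(0)
    );
}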

View File

@@ -78,7 +78,6 @@ impl ValueWriter for RangeValueWriter {
} }
#[cfg(test)] #[cfg(test)]
#[allow(clippy::single_range_in_vec_init)]
mod tests { mod tests {
use super::*; use super::*;

View File

@@ -39,6 +39,8 @@ pub fn deserialize_read(buf: &[u8]) -> (usize, u64) {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::u64;
use super::{deserialize_read, serialize}; use super::{deserialize_read, serialize};
fn aux_test_int(val: u64, expect_len: usize) { fn aux_test_int(val: u64, expect_len: usize) {

View File

@@ -54,7 +54,7 @@ fn bench_hashmap_throughput(c: &mut Criterion) {
); );
// numbers // numbers
let input_bytes = 1_000_000 * 8; let input_bytes = 1_000_000 * 8 as u64;
group.throughput(Throughput::Bytes(input_bytes)); group.throughput(Throughput::Bytes(input_bytes));
let numbers: Vec<[u8; 8]> = (0..1_000_000u64).map(|el| el.to_le_bytes()).collect(); let numbers: Vec<[u8; 8]> = (0..1_000_000u64).map(|el| el.to_le_bytes()).collect();
@@ -82,7 +82,7 @@ fn bench_hashmap_throughput(c: &mut Criterion) {
let mut rng = StdRng::from_seed([3u8; 32]); let mut rng = StdRng::from_seed([3u8; 32]);
let zipf = zipf::ZipfDistribution::new(10_000, 1.03).unwrap(); let zipf = zipf::ZipfDistribution::new(10_000, 1.03).unwrap();
let input_bytes = 1_000_000 * 8; let input_bytes = 1_000_000 * 8 as u64;
group.throughput(Throughput::Bytes(input_bytes)); group.throughput(Throughput::Bytes(input_bytes));
let zipf_numbers: Vec<[u8; 8]> = (0..1_000_000u64) let zipf_numbers: Vec<[u8; 8]> = (0..1_000_000u64)
.map(|_| zipf.sample(&mut rng).to_le_bytes()) .map(|_| zipf.sample(&mut rng).to_le_bytes())
@@ -110,7 +110,7 @@ impl DocIdRecorder {
} }
} }
fn create_hash_map<T: AsRef<[u8]>>(terms: impl Iterator<Item = T>) -> ArenaHashMap { fn create_hash_map<'a, T: AsRef<[u8]>>(terms: impl Iterator<Item = T>) -> ArenaHashMap {
let mut map = ArenaHashMap::with_capacity(HASHMAP_SIZE); let mut map = ArenaHashMap::with_capacity(HASHMAP_SIZE);
for term in terms { for term in terms {
map.mutate_or_create(term.as_ref(), |val| { map.mutate_or_create(term.as_ref(), |val| {
@@ -126,7 +126,7 @@ fn create_hash_map<T: AsRef<[u8]>>(terms: impl Iterator<Item = T>) -> ArenaHashM
map map
} }
fn create_hash_map_with_expull<T: AsRef<[u8]>>( fn create_hash_map_with_expull<'a, T: AsRef<[u8]>>(
terms: impl Iterator<Item = (u32, T)>, terms: impl Iterator<Item = (u32, T)>,
) -> ArenaHashMap { ) -> ArenaHashMap {
let mut memory_arena = MemoryArena::default(); let mut memory_arena = MemoryArena::default();
@@ -145,7 +145,7 @@ fn create_hash_map_with_expull<T: AsRef<[u8]>>(
map map
} }
fn create_fx_hash_ref_map_with_expull( fn create_fx_hash_ref_map_with_expull<'a>(
terms: impl Iterator<Item = &'static [u8]>, terms: impl Iterator<Item = &'static [u8]>,
) -> FxHashMap<&'static [u8], Vec<u32>> { ) -> FxHashMap<&'static [u8], Vec<u32>> {
let terms = terms.enumerate(); let terms = terms.enumerate();
@@ -158,7 +158,7 @@ fn create_fx_hash_ref_map_with_expull(
map map
} }
fn create_fx_hash_owned_map_with_expull( fn create_fx_hash_owned_map_with_expull<'a>(
terms: impl Iterator<Item = &'static [u8]>, terms: impl Iterator<Item = &'static [u8]>,
) -> FxHashMap<Vec<u8>, Vec<u32>> { ) -> FxHashMap<Vec<u8>, Vec<u32>> {
let terms = terms.enumerate(); let terms = terms.enumerate();

View File

@@ -10,7 +10,7 @@ fn main() {
} }
} }
fn create_hash_map<T: AsRef<str>>(terms: impl Iterator<Item = T>) -> ArenaHashMap { fn create_hash_map<'a, T: AsRef<str>>(terms: impl Iterator<Item = T>) -> ArenaHashMap {
let mut map = ArenaHashMap::with_capacity(4); let mut map = ArenaHashMap::with_capacity(4);
for term in terms { for term in terms {
map.mutate_or_create(term.as_ref().as_bytes(), |val| { map.mutate_or_create(term.as_ref().as_bytes(), |val| {