Mirror of https://github.com/quickwit-oss/tantivy.git (synced 2025-12-29 21:42:55 +00:00)

Compare commits: test_order...agg_format (11 commits)
| Author | SHA1 | Date |
|---|---|---|
| | b345c11786 | |
| | 7ebcc15b17 | |
| | 1b4076691f | |
| | eab660873a | |
| | 232f37126e | |
| | 13e9885dfd | |
| | 56d79cb203 | |
| | 0f4c2e27cf | |
| | f9ae295507 | |
| | d9db5302d9 | |
| | e453848134 | |
.github/workflows/coverage.yml (vendored, 4 changed lines)
@@ -15,11 +15,11 @@ jobs:
    steps:
      - uses: actions/checkout@v4
      - name: Install Rust
        run: rustup toolchain install nightly-2024-04-10 --profile minimal --component llvm-tools-preview
        run: rustup toolchain install nightly-2024-07-01 --profile minimal --component llvm-tools-preview
      - uses: Swatinem/rust-cache@v2
      - uses: taiki-e/install-action@cargo-llvm-cov
      - name: Generate code coverage
        run: cargo +nightly-2024-04-10 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
        run: cargo +nightly-2024-07-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        continue-on-error: true
@@ -11,7 +11,7 @@ repository = "https://github.com/quickwit-oss/tantivy"
readme = "README.md"
keywords = ["search", "information", "retrieval"]
edition = "2021"
rust-version = "1.63"
rust-version = "1.66"
exclude = ["benches/*.json", "benches/*.txt"]

[dependencies]
@@ -38,7 +38,7 @@ levenshtein_automata = "0.2.1"
uuid = { version = "1.0.0", features = ["v4", "serde"] }
crossbeam-channel = "0.5.4"
rust-stemmers = "1.2.0"
downcast-rs = "1.2.0"
downcast-rs = "1.2.1"
bitpacking = { version = "0.9.2", default-features = false, features = [
    "bitpacker4x",
] }
@@ -64,6 +64,7 @@ tantivy-bitpacker = { version = "0.6", path = "./bitpacker" }
common = { version = "0.7", path = "./common/", package = "tantivy-common" }
tokenizer-api = { version = "0.3", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
sketches-ddsketch = { version = "0.3.0", features = ["use_serde"] }
hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
futures-util = { version = "0.3.28", optional = true }
fnv = "1.0.7"
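The new `hyperloglogplus` dependency above backs the cardinality aggregation introduced in this change set. As a rough orientation, here is a minimal sketch of how the crate is typically used; the exact API surface shown (`HyperLogLogPlus::new`, `insert`, `count`) is my reading of hyperloglogplus 0.4 and should be treated as an assumption rather than something this diff defines.

```rust
use std::collections::hash_map::RandomState;

use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};

/// Estimates the number of distinct values with a HyperLogLog++ sketch.
fn approximate_distinct(values: &[u64]) -> f64 {
    // Precision 16 mirrors the value used by tantivy's CardinalityCollector below.
    let mut sketch: HyperLogLogPlus<u64, RandomState> =
        HyperLogLogPlus::new(16, RandomState::new()).expect("valid precision");
    for value in values {
        sketch.insert(value);
    }
    sketch.count()
}
```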
@@ -18,7 +18,7 @@ Tantivy is, in fact, strongly inspired by Lucene's design.

## Benchmark

The following [benchmark](https://tantivy-search.github.io/bench/) breakdowns
The following [benchmark](https://tantivy-search.github.io/bench/) breaks down the
performance for different types of queries/collections.

Your mileage WILL vary depending on the nature of queries and their load.
@@ -51,10 +51,15 @@ fn bench_agg(mut group: InputGroup<Index>) {
    register!(group, percentiles_f64);
    register!(group, terms_few);
    register!(group, terms_many);
    register!(group, terms_many_top_1000);
    register!(group, terms_many_order_by_term);
    register!(group, terms_many_with_top_hits);
    register!(group, terms_many_with_avg_sub_agg);
    register!(group, terms_many_json_mixed_type_with_sub_agg_card);
    register!(group, terms_many_json_mixed_type_with_avg_sub_agg);

    register!(group, cardinality_agg);
    register!(group, terms_few_with_cardinality_agg);

    register!(group, range_agg);
    register!(group, range_agg_with_avg_sub_agg);
    register!(group, range_agg_with_term_agg_few);
@@ -123,6 +128,33 @@ fn percentiles_f64(index: &Index) {
    });
    execute_agg(index, agg_req);
}

fn cardinality_agg(index: &Index) {
    let agg_req = json!({
        "cardinality": {
            "cardinality": {
                "field": "text_many_terms"
            },
        }
    });
    execute_agg(index, agg_req);
}
fn terms_few_with_cardinality_agg(index: &Index) {
    let agg_req = json!({
        "my_texts": {
            "terms": { "field": "text_few_terms" },
            "aggs": {
                "cardinality": {
                    "cardinality": {
                        "field": "text_many_terms"
                    },
                }
            }
        },
    });
    execute_agg(index, agg_req);
}

fn terms_few(index: &Index) {
    let agg_req = json!({
        "my_texts": { "terms": { "field": "text_few_terms" } },
@@ -135,6 +167,12 @@ fn terms_many(index: &Index) {
    });
    execute_agg(index, agg_req);
}
fn terms_many_top_1000(index: &Index) {
    let agg_req = json!({
        "my_texts": { "terms": { "field": "text_many_terms", "size": 1000 } },
    });
    execute_agg(index, agg_req);
}
fn terms_many_order_by_term(index: &Index) {
    let agg_req = json!({
        "my_texts": { "terms": { "field": "text_many_terms", "order": { "_key": "desc" } } },
@@ -171,7 +209,7 @@ fn terms_many_with_avg_sub_agg(index: &Index) {
    });
    execute_agg(index, agg_req);
}
fn terms_many_json_mixed_type_with_sub_agg_card(index: &Index) {
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
    let agg_req = json!({
        "my_texts": {
            "terms": { "field": "json.mixed_type" },
@@ -268,6 +306,7 @@ fn range_agg_with_term_agg_many(index: &Index) {
    });
    execute_agg(index, agg_req);
}

fn histogram(index: &Index) {
    let agg_req = json!({
        "rangef64": {
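The benches above only build the JSON request and hand it to a local `execute_agg` helper whose body is not shown in this excerpt. For context, a plausible equivalent of that helper is sketched below; the helper's real implementation is an assumption, while `AggregationCollector::from_aggs` and `AllQuery` are the standard tantivy way to run an aggregation request over a whole index.

```rust
use serde_json::json;
use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::AllQuery;
use tantivy::Index;

/// Runs a cardinality aggregation over every document of the index and
/// returns the aggregation results as JSON.
fn run_cardinality_agg(index: &Index) -> tantivy::Result<serde_json::Value> {
    let agg_req: Aggregations = serde_json::from_value(json!({
        "cardinality": { "cardinality": { "field": "text_many_terms" } }
    }))
    .expect("valid aggregation request");
    let collector = AggregationCollector::from_aggs(agg_req, Default::default());
    let searcher = index.reader()?.searcher();
    let agg_res = searcher.search(&AllQuery, &collector)?;
    Ok(serde_json::to_value(agg_res).expect("aggregation results serialize to JSON"))
}
```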
@@ -34,6 +34,7 @@ fn compute_stats(vals: impl Iterator<Item = u64>) -> ColumnStats {
|
||||
fn value_iter() -> impl Iterator<Item = u64> {
|
||||
0..20_000
|
||||
}
|
||||
|
||||
fn get_reader_for_bench<Codec: ColumnCodec>(data: &[u64]) -> Codec::ColumnValues {
|
||||
let mut bytes = Vec::new();
|
||||
let stats = compute_stats(data.iter().cloned());
|
||||
@@ -41,10 +42,13 @@ fn get_reader_for_bench<Codec: ColumnCodec>(data: &[u64]) -> Codec::ColumnValues
|
||||
for val in data {
|
||||
codec_serializer.collect(*val);
|
||||
}
|
||||
codec_serializer.serialize(&stats, Box::new(data.iter().copied()).as_mut(), &mut bytes);
|
||||
codec_serializer
|
||||
.serialize(&stats, Box::new(data.iter().copied()).as_mut(), &mut bytes)
|
||||
.unwrap();
|
||||
|
||||
Codec::load(OwnedBytes::new(bytes)).unwrap()
|
||||
}
|
||||
|
||||
fn bench_get<Codec: ColumnCodec>(b: &mut Bencher, data: &[u64]) {
|
||||
let col = get_reader_for_bench::<Codec>(data);
|
||||
b.iter(|| {
|
||||
|
||||
@@ -8,6 +8,7 @@ use std::net::Ipv6Addr;
|
||||
|
||||
use column_operation::ColumnOperation;
|
||||
pub(crate) use column_writers::CompatibleNumericalTypes;
|
||||
use common::json_path_writer::JSON_END_OF_PATH;
|
||||
use common::CountingWriter;
|
||||
pub(crate) use serializer::ColumnarSerializer;
|
||||
use stacker::{Addr, ArenaHashMap, MemoryArena};
|
||||
@@ -283,12 +284,17 @@ impl ColumnarWriter {
|
||||
.iter()
|
||||
.map(|(column_name, addr)| (column_name, ColumnType::DateTime, addr)),
|
||||
);
|
||||
// TODO: replace JSON_END_OF_PATH with b'0' in columns
|
||||
columns.sort_unstable_by_key(|(column_name, col_type, _)| (*column_name, *col_type));
|
||||
|
||||
let (arena, buffers, dictionaries) = (&self.arena, &mut self.buffers, &self.dictionaries);
|
||||
let mut symbol_byte_buffer: Vec<u8> = Vec::new();
|
||||
for (column_name, column_type, addr) in columns {
|
||||
if column_name.contains(&JSON_END_OF_PATH) {
|
||||
// Tantivy uses b'0' as a separator for nested fields in JSON.
|
||||
// Column names with a b'0' are not simply ignored by the columnar (and the inverted
|
||||
// index).
|
||||
continue;
|
||||
}
|
||||
match column_type {
|
||||
ColumnType::Bool => {
|
||||
let column_writer: ColumnWriter = self.bool_field_hash_map.read(addr);
|
||||
|
||||
@@ -93,18 +93,3 @@ impl<'a, W: io::Write> io::Write for ColumnSerializer<'a, W> {
|
||||
self.columnar_serializer.wrt.write_all(buf)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_prepare_key_bytes() {
|
||||
let mut buffer: Vec<u8> = b"somegarbage".to_vec();
|
||||
prepare_key(b"root\0child", ColumnType::Str, &mut buffer);
|
||||
assert_eq!(buffer.len(), 12);
|
||||
assert_eq!(&buffer[..10], b"root0child");
|
||||
assert_eq!(buffer[10], 0u8);
|
||||
assert_eq!(buffer[11], ColumnType::Str.to_code());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ documentation = "https://docs.rs/tantivy_common/"
|
||||
homepage = "https://github.com/quickwit-oss/tantivy"
|
||||
repository = "https://github.com/quickwit-oss/tantivy"
|
||||
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
@@ -20,8 +19,7 @@ time = { version = "0.3.10", features = ["serde-well-known"] }
|
||||
serde = { version = "1.0.136", features = ["derive"] }
|
||||
|
||||
[dev-dependencies]
|
||||
binggan = "0.8.1"
|
||||
proptest = "1.0.0"
|
||||
rand = "0.8.4"
|
||||
|
||||
[features]
|
||||
unstable = [] # useful for benches.
|
||||
|
||||
@@ -1,39 +1,64 @@
|
||||
#![feature(test)]
|
||||
use binggan::{black_box, BenchRunner};
|
||||
use rand::seq::IteratorRandom;
|
||||
use rand::thread_rng;
|
||||
use tantivy_common::{serialize_vint_u32, BitSet, TinySet};
|
||||
|
||||
extern crate test;
|
||||
fn bench_vint() {
|
||||
let mut runner = BenchRunner::new();
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::seq::IteratorRandom;
|
||||
use rand::thread_rng;
|
||||
use tantivy_common::serialize_vint_u32;
|
||||
use test::Bencher;
|
||||
let vals: Vec<u32> = (0..20_000).collect();
|
||||
runner.bench_function("bench_vint", move |_| {
|
||||
let mut out = 0u64;
|
||||
for val in vals.iter().cloned() {
|
||||
let mut buf = [0u8; 8];
|
||||
serialize_vint_u32(val, &mut buf);
|
||||
out += u64::from(buf[0]);
|
||||
}
|
||||
black_box(out);
|
||||
});
|
||||
|
||||
#[bench]
|
||||
fn bench_vint(b: &mut Bencher) {
|
||||
let vals: Vec<u32> = (0..20_000).collect();
|
||||
b.iter(|| {
|
||||
let mut out = 0u64;
|
||||
for val in vals.iter().cloned() {
|
||||
let mut buf = [0u8; 8];
|
||||
serialize_vint_u32(val, &mut buf);
|
||||
out += u64::from(buf[0]);
|
||||
}
|
||||
out
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_vint_rand(b: &mut Bencher) {
|
||||
let vals: Vec<u32> = (0..20_000).choose_multiple(&mut thread_rng(), 100_000);
|
||||
b.iter(|| {
|
||||
let mut out = 0u64;
|
||||
for val in vals.iter().cloned() {
|
||||
let mut buf = [0u8; 8];
|
||||
serialize_vint_u32(val, &mut buf);
|
||||
out += u64::from(buf[0]);
|
||||
}
|
||||
out
|
||||
});
|
||||
}
|
||||
let vals: Vec<u32> = (0..20_000).choose_multiple(&mut thread_rng(), 100_000);
|
||||
runner.bench_function("bench_vint_rand", move |_| {
|
||||
let mut out = 0u64;
|
||||
for val in vals.iter().cloned() {
|
||||
let mut buf = [0u8; 8];
|
||||
serialize_vint_u32(val, &mut buf);
|
||||
out += u64::from(buf[0]);
|
||||
}
|
||||
black_box(out);
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_bitset() {
|
||||
let mut runner = BenchRunner::new();
|
||||
|
||||
runner.bench_function("bench_tinyset_pop", move |_| {
|
||||
let mut tinyset = TinySet::singleton(black_box(31u32));
|
||||
tinyset.pop_lowest();
|
||||
tinyset.pop_lowest();
|
||||
tinyset.pop_lowest();
|
||||
tinyset.pop_lowest();
|
||||
tinyset.pop_lowest();
|
||||
tinyset.pop_lowest();
|
||||
black_box(tinyset);
|
||||
});
|
||||
|
||||
let tiny_set = TinySet::empty().insert(10u32).insert(14u32).insert(21u32);
|
||||
runner.bench_function("bench_tinyset_sum", move |_| {
|
||||
assert_eq!(black_box(tiny_set).into_iter().sum::<u32>(), 45u32);
|
||||
});
|
||||
|
||||
let v = [10u32, 14u32, 21u32];
|
||||
runner.bench_function("bench_tinyarr_sum", move |_| {
|
||||
black_box(v.iter().cloned().sum::<u32>());
|
||||
});
|
||||
|
||||
runner.bench_function("bench_bitset_initialize", move |_| {
|
||||
black_box(BitSet::with_max_value(1_000_000));
|
||||
});
|
||||
}
|
||||
|
||||
fn main() {
|
||||
bench_vint();
|
||||
bench_bitset();
|
||||
}
|
||||
|
||||
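The hunk above migrates these microbenchmarks from the nightly-only `#[bench]` harness (note the removed `#![feature(test)]`) to binggan, which uses a plain `main`. A stripped-down version of the same pattern, based only on the calls visible in the diff; treat the exact binggan 0.8 API as an assumption:

```rust
use binggan::{black_box, BenchRunner};

fn main() {
    let mut runner = BenchRunner::new();
    let vals: Vec<u32> = (0..20_000).collect();
    runner.bench_function("sum_u32", move |_| {
        // black_box keeps the optimizer from eliding the measured work.
        black_box(vals.iter().copied().sum::<u32>());
    });
}
```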
@@ -696,43 +696,3 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
mod bench {
|
||||
|
||||
use test;
|
||||
|
||||
use super::{BitSet, TinySet};
|
||||
|
||||
#[bench]
|
||||
fn bench_tinyset_pop(b: &mut test::Bencher) {
|
||||
b.iter(|| {
|
||||
let mut tinyset = TinySet::singleton(test::black_box(31u32));
|
||||
tinyset.pop_lowest();
|
||||
tinyset.pop_lowest();
|
||||
tinyset.pop_lowest();
|
||||
tinyset.pop_lowest();
|
||||
tinyset.pop_lowest();
|
||||
tinyset.pop_lowest();
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_tinyset_sum(b: &mut test::Bencher) {
|
||||
let tiny_set = TinySet::empty().insert(10u32).insert(14u32).insert(21u32);
|
||||
b.iter(|| {
|
||||
assert_eq!(test::black_box(tiny_set).into_iter().sum::<u32>(), 45u32);
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_tinyarr_sum(b: &mut test::Bencher) {
|
||||
let v = [10u32, 14u32, 21u32];
|
||||
b.iter(|| test::black_box(v).iter().cloned().sum::<u32>());
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_bitset_initialize(b: &mut test::Bencher) {
|
||||
b.iter(|| BitSet::with_max_value(1_000_000));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
use std::ops::Bound;

// # Searching a range on an indexed int field.
//
// Below is an example of creating an indexed integer field in your schema
@@ -5,7 +7,7 @@
use tantivy::collector::Count;
use tantivy::query::RangeQuery;
use tantivy::schema::{Schema, INDEXED};
use tantivy::{doc, Index, IndexWriter, Result};
use tantivy::{doc, Index, IndexWriter, Result, Term};

fn main() -> Result<()> {
    // For the sake of simplicity, this schema will only have 1 field
@@ -27,7 +29,10 @@ fn main() -> Result<()> {
    reader.reload()?;
    let searcher = reader.searcher();
    // The end is excluded i.e. here we are searching up to 1969
    let docs_in_the_sixties = RangeQuery::new_u64("year".to_string(), 1960..1970);
    let docs_in_the_sixties = RangeQuery::new(
        Bound::Included(Term::from_field_u64(year_field, 1960)),
        Bound::Excluded(Term::from_field_u64(year_field, 1970)),
    );
    // Uses a Count collector to sum the total number of docs in the range
    let num_60s_books = searcher.search(&docs_in_the_sixties, &Count)?;
    assert_eq!(num_60s_books, 10);
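Since the reworked constructor takes standard `std::ops::Bound` values on both ends, half-open ranges fall out naturally. A small hedged variant of the example above, reusing `year_field`, `searcher`, and `Count` from that example and assuming `RangeQuery::new` accepts `Bound::Unbounded` like any other bound:

```rust
// Everything strictly before 1970, with no lower bound.
let before_the_seventies = RangeQuery::new(
    Bound::Unbounded,
    Bound::Excluded(Term::from_field_u64(year_field, 1970)),
);
let num_books_before_1970 = searcher.search(&before_the_seventies, &Count)?;
```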
@@ -34,8 +34,9 @@ use super::bucket::{
|
||||
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
|
||||
};
|
||||
use super::metric::{
|
||||
AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation,
|
||||
PercentilesAggregationReq, StatsAggregation, SumAggregation, TopHitsAggregation,
|
||||
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
|
||||
MaxAggregation, MinAggregation, PercentilesAggregationReq, StatsAggregation, SumAggregation,
|
||||
TopHitsAggregationReq,
|
||||
};
|
||||
|
||||
/// The top-level aggregation request structure, which contains [`Aggregation`] and their user
|
||||
@@ -159,7 +160,10 @@ pub enum AggregationVariants {
|
||||
Percentiles(PercentilesAggregationReq),
|
||||
/// Finds the top k values matching some order
|
||||
#[serde(rename = "top_hits")]
|
||||
TopHits(TopHitsAggregation),
|
||||
TopHits(TopHitsAggregationReq),
|
||||
/// Computes an estimate of the number of unique values
|
||||
#[serde(rename = "cardinality")]
|
||||
Cardinality(CardinalityAggregationReq),
|
||||
}
|
||||
|
||||
impl AggregationVariants {
|
||||
@@ -179,6 +183,7 @@ impl AggregationVariants {
|
||||
AggregationVariants::Sum(sum) => vec![sum.field_name()],
|
||||
AggregationVariants::Percentiles(per) => vec![per.field_name()],
|
||||
AggregationVariants::TopHits(top_hits) => top_hits.field_names(),
|
||||
AggregationVariants::Cardinality(per) => vec![per.field_name()],
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,7 +208,7 @@ impl AggregationVariants {
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
pub(crate) fn as_top_hits(&self) -> Option<&TopHitsAggregation> {
|
||||
pub(crate) fn as_top_hits(&self) -> Option<&TopHitsAggregationReq> {
|
||||
match &self {
|
||||
AggregationVariants::TopHits(top_hits) => Some(top_hits),
|
||||
_ => None,
|
||||
|
||||
@@ -11,8 +11,8 @@ use super::bucket::{
|
||||
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
|
||||
};
|
||||
use super::metric::{
|
||||
AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation,
|
||||
StatsAggregation, SumAggregation,
|
||||
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
|
||||
MaxAggregation, MinAggregation, StatsAggregation, SumAggregation,
|
||||
};
|
||||
use super::segment_agg_result::AggregationLimits;
|
||||
use super::VecWithNames;
|
||||
@@ -162,6 +162,11 @@ impl AggregationWithAccessor {
|
||||
field: ref field_name,
|
||||
ref missing,
|
||||
..
|
||||
})
|
||||
| Cardinality(CardinalityAggregationReq {
|
||||
field: ref field_name,
|
||||
ref missing,
|
||||
..
|
||||
}) => {
|
||||
let str_dict_column = reader.fast_fields().str(field_name)?;
|
||||
let allowed_column_types = [
|
||||
|
||||
@@ -98,6 +98,8 @@ pub enum MetricResult {
|
||||
Percentiles(PercentilesMetricResult),
|
||||
/// Top hits metric result
|
||||
TopHits(TopHitsMetricResult),
|
||||
/// Cardinality metric result
|
||||
Cardinality(SingleMetricResult),
|
||||
}
|
||||
|
||||
impl MetricResult {
|
||||
@@ -116,6 +118,7 @@ impl MetricResult {
|
||||
MetricResult::TopHits(_) => Err(TantivyError::AggregationError(
|
||||
AggregationError::InvalidRequest("top_hits can't be used to order".to_string()),
|
||||
)),
|
||||
MetricResult::Cardinality(card) => Ok(card.value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,6 +110,16 @@ fn test_aggregation_flushing(
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"cardinality_string_id":{
|
||||
"cardinality": {
|
||||
"field": "string_id"
|
||||
}
|
||||
},
|
||||
"cardinality_score":{
|
||||
"cardinality": {
|
||||
"field": "score"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -212,6 +222,9 @@ fn test_aggregation_flushing(
|
||||
)
|
||||
);
|
||||
|
||||
assert_eq!(res["cardinality_string_id"]["value"], 2.0);
|
||||
assert_eq!(res["cardinality_score"]["value"], 80.0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -926,10 +939,10 @@ fn test_aggregation_on_json_object_mixed_types() {
|
||||
},
|
||||
"termagg": {
|
||||
"buckets": [
|
||||
{ "doc_count": 1, "key": 10.0, "min_price": { "value": 10.0 } },
|
||||
{ "doc_count": 1, "key": 10.0, "key_as_string": "10", "min_price": { "value": 10.0 } },
|
||||
{ "doc_count": 3, "key": "blue", "min_price": { "value": 5.0 } },
|
||||
{ "doc_count": 2, "key": "red", "min_price": { "value": 1.0 } },
|
||||
{ "doc_count": 1, "key": -20.5, "min_price": { "value": -20.5 } },
|
||||
{ "doc_count": 1, "key": -20.5, "key_as_string": "-20.5", "min_price": { "value": -20.5 } },
|
||||
{ "doc_count": 2, "key": 1.0, "key_as_string": "true", "min_price": { "value": null } },
|
||||
],
|
||||
"sum_other_doc_count": 0
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
use std::fmt::Debug;
|
||||
use std::io;
|
||||
use std::net::Ipv6Addr;
|
||||
|
||||
use columnar::column_values::CompactSpaceU64Accessor;
|
||||
use columnar::{
|
||||
BytesColumn, ColumnType, MonotonicallyMappableToU128, MonotonicallyMappableToU64, StrColumn,
|
||||
};
|
||||
use columnar::{ColumnType, Dictionary, MonotonicallyMappableToU128, MonotonicallyMappableToU64};
|
||||
use rustc_hash::FxHashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -466,49 +465,66 @@ impl SegmentTermCollector {
|
||||
};
|
||||
|
||||
if self.column_type == ColumnType::Str {
|
||||
let fallback_dict = Dictionary::empty();
|
||||
let term_dict = agg_with_accessor
|
||||
.str_dict_column
|
||||
.as_ref()
|
||||
.cloned()
|
||||
.unwrap_or_else(|| {
|
||||
StrColumn::wrap(BytesColumn::empty(agg_with_accessor.accessor.num_docs()))
|
||||
});
|
||||
let mut buffer = String::new();
|
||||
for (term_id, doc_count) in entries {
|
||||
let intermediate_entry = into_intermediate_bucket_entry(term_id, doc_count)?;
|
||||
// Special case for missing key
|
||||
if term_id == u64::MAX {
|
||||
let missing_key = self
|
||||
.req
|
||||
.missing
|
||||
.as_ref()
|
||||
.expect("Found placeholder term_id but `missing` is None");
|
||||
match missing_key {
|
||||
Key::Str(missing) => {
|
||||
buffer.clear();
|
||||
buffer.push_str(missing);
|
||||
dict.insert(
|
||||
IntermediateKey::Str(buffer.to_string()),
|
||||
intermediate_entry,
|
||||
);
|
||||
}
|
||||
Key::F64(val) => {
|
||||
buffer.push_str(&val.to_string());
|
||||
dict.insert(IntermediateKey::F64(*val), intermediate_entry);
|
||||
}
|
||||
.map(|el| el.dictionary())
|
||||
.unwrap_or_else(|| &fallback_dict);
|
||||
let mut buffer = Vec::new();
|
||||
|
||||
// special case for missing key
|
||||
if let Some(index) = entries.iter().position(|value| value.0 == u64::MAX) {
|
||||
let entry = entries[index];
|
||||
let intermediate_entry = into_intermediate_bucket_entry(entry.0, entry.1)?;
|
||||
let missing_key = self
|
||||
.req
|
||||
.missing
|
||||
.as_ref()
|
||||
.expect("Found placeholder term_id but `missing` is None");
|
||||
match missing_key {
|
||||
Key::Str(missing) => {
|
||||
buffer.clear();
|
||||
buffer.extend_from_slice(missing.as_bytes());
|
||||
dict.insert(
|
||||
IntermediateKey::Str(
|
||||
String::from_utf8(buffer.to_vec())
|
||||
.expect("could not convert to String"),
|
||||
),
|
||||
intermediate_entry,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
if !term_dict.ord_to_str(term_id, &mut buffer)? {
|
||||
return Err(TantivyError::InternalError(format!(
|
||||
"Couldn't find term_id {term_id} in dict"
|
||||
)));
|
||||
Key::F64(val) => {
|
||||
dict.insert(IntermediateKey::F64(*val), intermediate_entry);
|
||||
}
|
||||
dict.insert(IntermediateKey::Str(buffer.to_string()), intermediate_entry);
|
||||
}
|
||||
|
||||
entries.swap_remove(index);
|
||||
}
|
||||
|
||||
// Sort by term ord
|
||||
entries.sort_unstable_by_key(|bucket| bucket.0);
|
||||
let mut idx = 0;
|
||||
term_dict.sorted_ords_to_term_cb(
|
||||
entries.iter().map(|(term_id, _)| *term_id),
|
||||
|term| {
|
||||
let entry = entries[idx];
|
||||
let intermediate_entry = into_intermediate_bucket_entry(entry.0, entry.1)
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
|
||||
dict.insert(
|
||||
IntermediateKey::Str(
|
||||
String::from_utf8(term.to_vec()).expect("could not convert to String"),
|
||||
),
|
||||
intermediate_entry,
|
||||
);
|
||||
idx += 1;
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
|
||||
if self.req.min_doc_count == 0 {
|
||||
// TODO: Handle rev streaming for descending sorting by keys
|
||||
let mut stream = term_dict.dictionary().stream()?;
|
||||
let mut stream = term_dict.stream()?;
|
||||
let empty_sub_aggregation = IntermediateAggregationResults::empty_from_req(
|
||||
agg_with_accessor.agg.sub_aggregation(),
|
||||
);
|
||||
|
||||
@@ -26,6 +26,7 @@ use super::segment_agg_result::AggregationLimits;
|
||||
use super::{format_date, AggregationError, Key, SerializedKey};
|
||||
use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
|
||||
use crate::aggregation::bucket::TermsAggregationInternal;
|
||||
use crate::aggregation::metric::CardinalityCollector;
|
||||
use crate::TantivyError;
|
||||
|
||||
/// Contains the intermediate aggregation result, which is optimized to be merged with other
|
||||
@@ -227,6 +228,9 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult
|
||||
TopHits(ref req) => IntermediateAggregationResult::Metric(
|
||||
IntermediateMetricResult::TopHits(TopHitsTopNComputer::new(req)),
|
||||
),
|
||||
Cardinality(_) => IntermediateAggregationResult::Metric(
|
||||
IntermediateMetricResult::Cardinality(CardinalityCollector::default()),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -291,6 +295,8 @@ pub enum IntermediateMetricResult {
|
||||
Sum(IntermediateSum),
|
||||
/// Intermediate top_hits result
|
||||
TopHits(TopHitsTopNComputer),
|
||||
/// Intermediate cardinality result
|
||||
Cardinality(CardinalityCollector),
|
||||
}
|
||||
|
||||
impl IntermediateMetricResult {
|
||||
@@ -324,6 +330,9 @@ impl IntermediateMetricResult {
|
||||
IntermediateMetricResult::TopHits(top_hits) => {
|
||||
MetricResult::TopHits(top_hits.into_final_result())
|
||||
}
|
||||
IntermediateMetricResult::Cardinality(cardinality) => {
|
||||
MetricResult::Cardinality(cardinality.finalize().into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -372,6 +381,12 @@ impl IntermediateMetricResult {
|
||||
(IntermediateMetricResult::TopHits(left), IntermediateMetricResult::TopHits(right)) => {
|
||||
left.merge_fruits(right)?;
|
||||
}
|
||||
(
|
||||
IntermediateMetricResult::Cardinality(left),
|
||||
IntermediateMetricResult::Cardinality(right),
|
||||
) => {
|
||||
left.merge_fruits(right)?;
|
||||
}
|
||||
_ => {
|
||||
panic!("incompatible fruit types in tree or missing merge_fruits handler");
|
||||
}
|
||||
@@ -584,6 +599,7 @@ impl IntermediateTermBucketResult {
|
||||
let val = if key { "true" } else { "false" };
|
||||
Some(val.to_string())
|
||||
}
|
||||
IntermediateKey::F64(val) => Some(val.to_string()),
|
||||
_ => None,
|
||||
};
|
||||
Ok(BucketEntry {
|
||||
|
||||
src/aggregation/metric/cardinality.rs (new file, 466 lines)
@@ -0,0 +1,466 @@
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{BuildHasher, Hasher};
|
||||
|
||||
use columnar::column_values::CompactSpaceU64Accessor;
|
||||
use columnar::Dictionary;
|
||||
use common::f64_to_u64;
|
||||
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
|
||||
use rustc_hash::FxHashSet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::aggregation::agg_req_with_accessor::{
|
||||
AggregationWithAccessor, AggregationsWithAccessor,
|
||||
};
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
|
||||
};
|
||||
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::*;
|
||||
use crate::TantivyError;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
struct BuildSaltedHasher {
|
||||
salt: u8,
|
||||
}
|
||||
|
||||
impl BuildHasher for BuildSaltedHasher {
|
||||
type Hasher = DefaultHasher;
|
||||
|
||||
fn build_hasher(&self) -> Self::Hasher {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
hasher.write_u8(self.salt);
|
||||
|
||||
hasher
|
||||
}
|
||||
}
|
||||
|
||||
/// # Cardinality
|
||||
///
|
||||
/// The cardinality aggregation allows for computing an estimate
|
||||
/// of the number of different values in a data set based on the
|
||||
/// HyperLogLog++ algorithm. This is particularly useful for understanding the
|
||||
/// uniqueness of values in a large dataset where counting each unique value
|
||||
/// individually would be computationally expensive.
|
||||
///
|
||||
/// For example, you might use a cardinality aggregation to estimate the number
|
||||
/// of unique visitors to a website by aggregating on a field that contains
|
||||
/// user IDs or session IDs.
|
||||
///
|
||||
/// To use the cardinality aggregation, you'll need to provide a field to
|
||||
/// aggregate on. The following example demonstrates a request for the cardinality
|
||||
/// of the "user_id" field:
|
||||
///
|
||||
/// ```JSON
|
||||
/// {
|
||||
/// "cardinality": {
|
||||
/// "field": "user_id"
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// This request will return an estimate of the number of unique values in the
|
||||
/// "user_id" field.
|
||||
///
|
||||
/// ## Missing Values
|
||||
///
|
||||
/// The `missing` parameter defines how documents that are missing a value should be treated.
|
||||
/// By default, documents without a value for the specified field are ignored. However, you can
|
||||
/// specify a default value for these documents using the `missing` parameter. This can be useful
|
||||
/// when you want to include documents with missing values in the aggregation.
|
||||
///
|
||||
/// For example, the following request treats documents with missing values in the "user_id"
|
||||
/// field as if they had a value of "unknown":
|
||||
///
|
||||
/// ```JSON
|
||||
/// {
|
||||
/// "cardinality": {
|
||||
/// "field": "user_id",
|
||||
/// "missing": "unknown"
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// # Estimation Accuracy
|
||||
///
|
||||
/// The cardinality aggregation provides an approximate count, which is usually
|
||||
/// accurate within a small error range. This trade-off allows for efficient
|
||||
/// computation even on very large datasets.
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CardinalityAggregationReq {
|
||||
/// The field name to compute the percentiles on.
|
||||
pub field: String,
|
||||
/// The missing parameter defines how documents that are missing a value should be treated.
|
||||
/// By default they will be ignored but it is also possible to treat them as if they had a
|
||||
/// value. Examples in JSON format:
|
||||
/// { "field": "my_numbers", "missing": "10.0" }
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub missing: Option<Key>,
|
||||
}
|
||||
|
||||
impl CardinalityAggregationReq {
|
||||
/// Creates a new [`CardinalityAggregationReq`] instance from a field name.
|
||||
pub fn from_field_name(field_name: String) -> Self {
|
||||
Self {
|
||||
field: field_name,
|
||||
missing: None,
|
||||
}
|
||||
}
|
||||
/// Returns the field name the aggregation is computed on.
|
||||
pub fn field_name(&self) -> &str {
|
||||
&self.field
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub(crate) struct SegmentCardinalityCollector {
|
||||
cardinality: CardinalityCollector,
|
||||
entries: FxHashSet<u64>,
|
||||
column_type: ColumnType,
|
||||
accessor_idx: usize,
|
||||
missing: Option<Key>,
|
||||
}
|
||||
|
||||
impl SegmentCardinalityCollector {
|
||||
pub fn from_req(column_type: ColumnType, accessor_idx: usize, missing: &Option<Key>) -> Self {
|
||||
Self {
|
||||
cardinality: CardinalityCollector::new(column_type as u8),
|
||||
entries: Default::default(),
|
||||
column_type,
|
||||
accessor_idx,
|
||||
missing: missing.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn fetch_block_with_field(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_accessor: &mut AggregationWithAccessor,
|
||||
) {
|
||||
if let Some(missing) = agg_accessor.missing_value_for_accessor {
|
||||
agg_accessor.column_block_accessor.fetch_block_with_missing(
|
||||
docs,
|
||||
&agg_accessor.accessor,
|
||||
missing,
|
||||
);
|
||||
} else {
|
||||
agg_accessor
|
||||
.column_block_accessor
|
||||
.fetch_block(docs, &agg_accessor.accessor);
|
||||
}
|
||||
}
|
||||
|
||||
fn into_intermediate_metric_result(
|
||||
mut self,
|
||||
agg_with_accessor: &AggregationWithAccessor,
|
||||
) -> crate::Result<IntermediateMetricResult> {
|
||||
if self.column_type == ColumnType::Str {
|
||||
let fallback_dict = Dictionary::empty();
|
||||
let dict = agg_with_accessor
|
||||
.str_dict_column
|
||||
.as_ref()
|
||||
.map(|el| el.dictionary())
|
||||
.unwrap_or_else(|| &fallback_dict);
|
||||
let mut has_missing = false;
|
||||
|
||||
// TODO: replace FxHashSet with something that allows iterating in order
|
||||
// (e.g. sparse bitvec)
|
||||
let mut term_ids = Vec::new();
|
||||
for term_ord in self.entries.into_iter() {
|
||||
if term_ord == u64::MAX {
|
||||
has_missing = true;
|
||||
} else {
|
||||
// we can reasonably exclude values above u32::MAX
|
||||
term_ids.push(term_ord as u32);
|
||||
}
|
||||
}
|
||||
term_ids.sort_unstable();
|
||||
dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| {
|
||||
self.cardinality.sketch.insert_any(&term);
|
||||
Ok(())
|
||||
})?;
|
||||
if has_missing {
|
||||
let missing_key = self
|
||||
.missing
|
||||
.as_ref()
|
||||
.expect("Found placeholder term_ord but `missing` is None");
|
||||
match missing_key {
|
||||
Key::Str(missing) => {
|
||||
self.cardinality.sketch.insert_any(&missing);
|
||||
}
|
||||
Key::F64(val) => {
|
||||
let val = f64_to_u64(*val);
|
||||
self.cardinality.sketch.insert_any(&val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(IntermediateMetricResult::Cardinality(self.cardinality))
|
||||
}
|
||||
}
|
||||
|
||||
impl SegmentAggregationCollector for SegmentCardinalityCollector {
|
||||
fn add_intermediate_aggregation_result(
|
||||
self: Box<Self>,
|
||||
agg_with_accessor: &AggregationsWithAccessor,
|
||||
results: &mut IntermediateAggregationResults,
|
||||
) -> crate::Result<()> {
|
||||
let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string();
|
||||
let agg_with_accessor = &agg_with_accessor.aggs.values[self.accessor_idx];
|
||||
|
||||
let intermediate_result = self.into_intermediate_metric_result(agg_with_accessor)?;
|
||||
results.push(
|
||||
name,
|
||||
IntermediateAggregationResult::Metric(intermediate_result),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect(
|
||||
&mut self,
|
||||
doc: crate::DocId,
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
) -> crate::Result<()> {
|
||||
self.collect_block(&[doc], agg_with_accessor)
|
||||
}
|
||||
|
||||
fn collect_block(
|
||||
&mut self,
|
||||
docs: &[crate::DocId],
|
||||
agg_with_accessor: &mut AggregationsWithAccessor,
|
||||
) -> crate::Result<()> {
|
||||
let bucket_agg_accessor = &mut agg_with_accessor.aggs.values[self.accessor_idx];
|
||||
self.fetch_block_with_field(docs, bucket_agg_accessor);
|
||||
|
||||
let col_block_accessor = &bucket_agg_accessor.column_block_accessor;
|
||||
if self.column_type == ColumnType::Str {
|
||||
for term_ord in col_block_accessor.iter_vals() {
|
||||
self.entries.insert(term_ord);
|
||||
}
|
||||
} else if self.column_type == ColumnType::IpAddr {
|
||||
let compact_space_accessor = bucket_agg_accessor
|
||||
.accessor
|
||||
.values
|
||||
.clone()
|
||||
.downcast_arc::<CompactSpaceU64Accessor>()
|
||||
.map_err(|_| {
|
||||
TantivyError::AggregationError(
|
||||
crate::aggregation::AggregationError::InternalError(
|
||||
"Type mismatch: Could not downcast to CompactSpaceU64Accessor"
|
||||
.to_string(),
|
||||
),
|
||||
)
|
||||
})?;
|
||||
for val in col_block_accessor.iter_vals() {
|
||||
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
|
||||
self.cardinality.sketch.insert_any(&val);
|
||||
}
|
||||
} else {
|
||||
for val in col_block_accessor.iter_vals() {
|
||||
self.cardinality.sketch.insert_any(&val);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
/// The percentiles collector used during segment collection and for merging results.
|
||||
pub struct CardinalityCollector {
|
||||
sketch: HyperLogLogPlus<u64, BuildSaltedHasher>,
|
||||
}
|
||||
impl Default for CardinalityCollector {
|
||||
fn default() -> Self {
|
||||
Self::new(0)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for CardinalityCollector {
|
||||
fn eq(&self, _other: &Self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl CardinalityCollector {
|
||||
/// Compute the final cardinality estimate.
|
||||
pub fn finalize(self) -> Option<f64> {
|
||||
Some(self.sketch.clone().count().trunc())
|
||||
}
|
||||
|
||||
fn new(salt: u8) -> Self {
|
||||
Self {
|
||||
sketch: HyperLogLogPlus::new(16, BuildSaltedHasher { salt }).unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn merge_fruits(&mut self, right: CardinalityCollector) -> crate::Result<()> {
|
||||
self.sketch.merge(&right.sketch).map_err(|err| {
|
||||
TantivyError::AggregationError(AggregationError::InternalError(format!(
|
||||
"Error while merging cardinality {err:?}"
|
||||
)))
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
|
||||
use columnar::MonotonicallyMappableToU64;
|
||||
|
||||
use crate::aggregation::agg_req::Aggregations;
|
||||
use crate::aggregation::tests::{exec_request, get_test_index_from_terms};
|
||||
use crate::schema::{IntoIpv6Addr, Schema, FAST};
|
||||
use crate::Index;
|
||||
|
||||
#[test]
|
||||
fn cardinality_aggregation_test_empty_index() -> crate::Result<()> {
|
||||
let values = vec![];
|
||||
let index = get_test_index_from_terms(false, &values)?;
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"cardinality": {
|
||||
"cardinality": {
|
||||
"field": "string_id",
|
||||
}
|
||||
},
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
assert_eq!(res["cardinality"]["value"], 0.0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cardinality_aggregation_test_single_segment() -> crate::Result<()> {
|
||||
cardinality_aggregation_test_merge_segment(true)
|
||||
}
|
||||
#[test]
|
||||
fn cardinality_aggregation_test() -> crate::Result<()> {
|
||||
cardinality_aggregation_test_merge_segment(false)
|
||||
}
|
||||
fn cardinality_aggregation_test_merge_segment(merge_segments: bool) -> crate::Result<()> {
|
||||
let segment_and_terms = vec![
|
||||
vec!["terma"],
|
||||
vec!["termb"],
|
||||
vec!["termc"],
|
||||
vec!["terma"],
|
||||
vec!["terma"],
|
||||
vec!["terma"],
|
||||
vec!["termb"],
|
||||
vec!["terma"],
|
||||
];
|
||||
let index = get_test_index_from_terms(merge_segments, &segment_and_terms)?;
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"cardinality": {
|
||||
"cardinality": {
|
||||
"field": "string_id",
|
||||
}
|
||||
},
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
assert_eq!(res["cardinality"]["value"], 3.0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cardinality_aggregation_u64() -> crate::Result<()> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let id_field = schema_builder.add_u64_field("id", FAST);
|
||||
let index = Index::create_in_ram(schema_builder.build());
|
||||
{
|
||||
let mut writer = index.writer_for_tests()?;
|
||||
writer.add_document(doc!(id_field => 1u64))?;
|
||||
writer.add_document(doc!(id_field => 2u64))?;
|
||||
writer.add_document(doc!(id_field => 3u64))?;
|
||||
writer.add_document(doc!())?;
|
||||
writer.commit()?;
|
||||
}
|
||||
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"cardinality": {
|
||||
"cardinality": {
|
||||
"field": "id",
|
||||
"missing": 0u64
|
||||
},
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
assert_eq!(res["cardinality"]["value"], 4.0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cardinality_aggregation_ip_addr() -> crate::Result<()> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let field = schema_builder.add_ip_addr_field("ip_field", FAST);
|
||||
let index = Index::create_in_ram(schema_builder.build());
|
||||
{
|
||||
let mut writer = index.writer_for_tests()?;
|
||||
// IpV6 loopback
|
||||
writer.add_document(doc!(field=>IpAddr::from_str("::1").unwrap().into_ipv6_addr()))?;
|
||||
writer.add_document(doc!(field=>IpAddr::from_str("::1").unwrap().into_ipv6_addr()))?;
|
||||
// IpV4
|
||||
writer.add_document(
|
||||
doc!(field=>IpAddr::from_str("127.0.0.1").unwrap().into_ipv6_addr()),
|
||||
)?;
|
||||
writer.commit()?;
|
||||
}
|
||||
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"cardinality": {
|
||||
"cardinality": {
|
||||
"field": "ip_field"
|
||||
},
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
assert_eq!(res["cardinality"]["value"], 2.0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cardinality_aggregation_json() -> crate::Result<()> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let field = schema_builder.add_json_field("json", FAST);
|
||||
let index = Index::create_in_ram(schema_builder.build());
|
||||
{
|
||||
let mut writer = index.writer_for_tests()?;
|
||||
writer.add_document(doc!(field => json!({"value": false})))?;
|
||||
writer.add_document(doc!(field => json!({"value": true})))?;
|
||||
writer.add_document(doc!(field => json!({"value": i64::from_u64(0u64)})))?;
|
||||
writer.add_document(doc!(field => json!({"value": i64::from_u64(1u64)})))?;
|
||||
writer.commit()?;
|
||||
}
|
||||
|
||||
let agg_req: Aggregations = serde_json::from_value(json!({
|
||||
"cardinality": {
|
||||
"cardinality": {
|
||||
"field": "json.value"
|
||||
},
|
||||
}
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let res = exec_request(agg_req, &index)?;
|
||||
assert_eq!(res["cardinality"]["value"], 4.0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
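Beyond the JSON form exercised in the tests above, the new request type can also be constructed programmatically. A minimal sketch using only the constructors defined in the file above; the surrounding test harness is omitted:

```rust
// Build the request directly instead of deserializing it from JSON.
let req = CardinalityAggregationReq::from_field_name("user_id".to_string());
assert_eq!(req.field_name(), "user_id");
// `missing` defaults to None, so documents without a value are ignored.
assert!(req.missing.is_none());
```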
@@ -17,6 +17,7 @@
|
||||
//! - [Percentiles](PercentilesAggregationReq)
|
||||
|
||||
mod average;
|
||||
mod cardinality;
|
||||
mod count;
|
||||
mod extended_stats;
|
||||
mod max;
|
||||
@@ -29,6 +30,7 @@ mod top_hits;
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub use average::*;
|
||||
pub use cardinality::*;
|
||||
pub use count::*;
|
||||
pub use extended_stats::*;
|
||||
pub use max::*;
|
||||
|
||||
@@ -89,7 +89,7 @@ use crate::{DocAddress, DocId, SegmentOrdinal};
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
|
||||
pub struct TopHitsAggregation {
|
||||
pub struct TopHitsAggregationReq {
|
||||
sort: Vec<KeyOrder>,
|
||||
size: usize,
|
||||
from: Option<usize>,
|
||||
@@ -164,7 +164,7 @@ fn unsupported_err(parameter: &str) -> crate::Result<()> {
|
||||
))
|
||||
}
|
||||
|
||||
impl TopHitsAggregation {
|
||||
impl TopHitsAggregationReq {
|
||||
/// Validate and resolve field retrieval parameters
|
||||
pub fn validate_and_resolve_field_names(
|
||||
&mut self,
|
||||
@@ -431,7 +431,7 @@ impl Eq for DocSortValuesAndFields {}
|
||||
/// The TopHitsCollector used for collecting over segments and merging results.
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
pub struct TopHitsTopNComputer {
|
||||
req: TopHitsAggregation,
|
||||
req: TopHitsAggregationReq,
|
||||
top_n: TopNComputer<DocSortValuesAndFields, DocAddress, false>,
|
||||
}
|
||||
|
||||
@@ -443,7 +443,7 @@ impl std::cmp::PartialEq for TopHitsTopNComputer {
|
||||
|
||||
impl TopHitsTopNComputer {
|
||||
/// Create a new TopHitsCollector
|
||||
pub fn new(req: &TopHitsAggregation) -> Self {
|
||||
pub fn new(req: &TopHitsAggregationReq) -> Self {
|
||||
Self {
|
||||
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
|
||||
req: req.clone(),
|
||||
@@ -496,7 +496,7 @@ pub(crate) struct TopHitsSegmentCollector {
|
||||
|
||||
impl TopHitsSegmentCollector {
|
||||
pub fn from_req(
|
||||
req: &TopHitsAggregation,
|
||||
req: &TopHitsAggregationReq,
|
||||
accessor_idx: usize,
|
||||
segment_ordinal: SegmentOrdinal,
|
||||
) -> Self {
|
||||
@@ -509,7 +509,7 @@ impl TopHitsSegmentCollector {
|
||||
fn into_top_hits_collector(
|
||||
self,
|
||||
value_accessors: &HashMap<String, Vec<DynamicColumn>>,
|
||||
req: &TopHitsAggregation,
|
||||
req: &TopHitsAggregationReq,
|
||||
) -> TopHitsTopNComputer {
|
||||
let mut top_hits_computer = TopHitsTopNComputer::new(req);
|
||||
let top_results = self.top_n.into_vec();
|
||||
@@ -532,7 +532,7 @@ impl TopHitsSegmentCollector {
|
||||
fn collect_with(
|
||||
&mut self,
|
||||
doc_id: crate::DocId,
|
||||
req: &TopHitsAggregation,
|
||||
req: &TopHitsAggregationReq,
|
||||
accessors: &[(Column<u64>, ColumnType)],
|
||||
) -> crate::Result<()> {
|
||||
let sorts: Vec<DocValueAndOrder> = req
|
||||
|
||||
@@ -44,11 +44,14 @@
|
||||
//! - [Metric](metric)
|
||||
//! - [Average](metric::AverageAggregation)
|
||||
//! - [Stats](metric::StatsAggregation)
|
||||
//! - [ExtendedStats](metric::ExtendedStatsAggregation)
|
||||
//! - [Min](metric::MinAggregation)
|
||||
//! - [Max](metric::MaxAggregation)
|
||||
//! - [Sum](metric::SumAggregation)
|
||||
//! - [Count](metric::CountAggregation)
|
||||
//! - [Percentiles](metric::PercentilesAggregationReq)
|
||||
//! - [Cardinality](metric::CardinalityAggregationReq)
|
||||
//! - [TopHits](metric::TopHitsAggregationReq)
|
||||
//!
|
||||
//! # Example
|
||||
//! Compute the average metric, by building [`agg_req::Aggregations`], which is built from an
|
||||
|
||||
@@ -16,7 +16,10 @@ use super::metric::{
|
||||
SumAggregation,
|
||||
};
|
||||
use crate::aggregation::bucket::TermMissingAgg;
|
||||
use crate::aggregation::metric::{SegmentExtendedStatsCollector, TopHitsSegmentCollector};
|
||||
use crate::aggregation::metric::{
|
||||
CardinalityAggregationReq, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
|
||||
TopHitsSegmentCollector,
|
||||
};
|
||||
|
||||
pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug {
|
||||
fn add_intermediate_aggregation_result(
|
||||
@@ -169,6 +172,9 @@ pub(crate) fn build_single_agg_segment_collector(
|
||||
accessor_idx,
|
||||
req.segment_ordinal,
|
||||
))),
|
||||
Cardinality(CardinalityAggregationReq { missing, .. }) => Ok(Box::new(
|
||||
SegmentCardinalityCollector::from_req(req.field_type, accessor_idx, missing),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use common::json_path_writer::JSON_PATH_SEGMENT_SEP;
|
||||
use common::json_path_writer::{JSON_END_OF_PATH, JSON_PATH_SEGMENT_SEP};
|
||||
use common::{replace_in_place, JsonPathWriter};
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
@@ -83,6 +83,9 @@ fn index_json_object<'a, V: Value<'a>>(
|
||||
positions_per_path: &mut IndexingPositionsPerPath,
|
||||
) {
|
||||
for (json_path_segment, json_value_visitor) in json_visitor {
|
||||
if json_path_segment.as_bytes().contains(&JSON_END_OF_PATH) {
|
||||
continue;
|
||||
}
|
||||
json_path_writer.push(json_path_segment);
|
||||
index_json_value(
|
||||
doc,
|
||||
|
||||
@@ -815,8 +815,9 @@ mod tests {
|
||||
use crate::indexer::NoMergePolicy;
|
||||
use crate::query::{QueryParser, TermQuery};
|
||||
use crate::schema::{
|
||||
self, Facet, FacetOptions, IndexRecordOption, IpAddrOptions, NumericOptions,
|
||||
TextFieldIndexing, TextOptions, Value, FAST, INDEXED, STORED, STRING, TEXT,
|
||||
self, Facet, FacetOptions, IndexRecordOption, IpAddrOptions, JsonObjectOptions,
|
||||
NumericOptions, Schema, TextFieldIndexing, TextOptions, Value, FAST, INDEXED, STORED,
|
||||
STRING, TEXT,
|
||||
};
|
||||
use crate::store::DOCSTORE_CACHE_CAPACITY;
|
||||
use crate::{
|
||||
@@ -1573,11 +1574,11 @@ mod tests {
|
||||
deleted_ids.remove(id);
|
||||
}
|
||||
IndexingOp::DeleteDoc { id } => {
|
||||
existing_ids.remove(&id);
|
||||
existing_ids.remove(id);
|
||||
deleted_ids.insert(*id);
|
||||
}
|
||||
IndexingOp::DeleteDocQuery { id } => {
|
||||
existing_ids.remove(&id);
|
||||
existing_ids.remove(id);
|
||||
deleted_ids.insert(*id);
|
||||
}
|
||||
_ => {}
|
||||
@@ -2378,11 +2379,11 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_bug_1617_2() {
|
||||
assert!(test_operation_strategy(
|
||||
test_operation_strategy(
|
||||
&[
|
||||
IndexingOp::AddDoc {
|
||||
id: 13,
|
||||
value: Default::default()
|
||||
value: Default::default(),
|
||||
},
|
||||
IndexingOp::DeleteDoc { id: 13 },
|
||||
IndexingOp::Commit,
|
||||
@@ -2390,9 +2391,9 @@ mod tests {
|
||||
IndexingOp::Commit,
|
||||
IndexingOp::Merge,
|
||||
],
|
||||
true
|
||||
true,
|
||||
)
|
||||
.is_ok());
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -2492,9 +2493,9 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bug_2442() -> crate::Result<()> {
|
||||
fn test_bug_2442_reserved_character_fast_field() -> crate::Result<()> {
|
||||
let mut schema_builder = schema::Schema::builder();
|
||||
let json_field = schema_builder.add_json_field("json", TEXT | FAST);
|
||||
let json_field = schema_builder.add_json_field("json", FAST | TEXT);
|
||||
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::builder().schema(schema).create_in_ram()?;
|
||||
@@ -2515,4 +2516,21 @@ mod tests {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bug_2442_reserved_character_columnar() -> crate::Result<()> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let options = JsonObjectOptions::from(FAST).set_expand_dots_enabled();
|
||||
let field = schema_builder.add_json_field("json", options);
|
||||
let index = Index::create_in_ram(schema_builder.build());
|
||||
let mut index_writer = index.writer_for_tests().unwrap();
|
||||
index_writer
|
||||
.add_document(doc!(field=>json!({"\u{0000}": "A"})))
|
||||
.unwrap();
|
||||
index_writer
|
||||
.add_document(doc!(field=>json!({format!("\u{0000}\u{0000}"): "A"})))
|
||||
.unwrap();
|
||||
index_writer.commit().unwrap();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,15 +145,27 @@ mod tests_mmap {
|
||||
}
|
||||
}
|
||||
#[test]
|
||||
fn test_json_field_null_byte() {
|
||||
// Test when field name contains a zero byte, which has special meaning in tantivy.
|
||||
// As a workaround, we convert the zero byte to the ASCII character '0'.
|
||||
// https://github.com/quickwit-oss/tantivy/issues/2340
|
||||
// https://github.com/quickwit-oss/tantivy/issues/2193
|
||||
let field_name_in = "\u{0000}";
|
||||
let field_name_out = "0";
|
||||
test_json_field_name(field_name_in, field_name_out);
|
||||
fn test_json_field_null_byte_is_ignored() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let options = JsonObjectOptions::from(TEXT | FAST).set_expand_dots_enabled();
|
||||
let field = schema_builder.add_json_field("json", options);
|
||||
let index = Index::create_in_ram(schema_builder.build());
|
||||
let mut index_writer = index.writer_for_tests().unwrap();
|
||||
index_writer
|
||||
.add_document(doc!(field=>json!({"key": "test1", "invalidkey\u{0000}": "test2"})))
|
||||
.unwrap();
|
||||
index_writer.commit().unwrap();
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let segment_reader = searcher.segment_reader(0);
|
||||
let inv_indexer = segment_reader.inverted_index(field).unwrap();
|
||||
let term_dict = inv_indexer.terms();
|
||||
assert_eq!(term_dict.num_terms(), 1);
|
||||
let mut term_bytes = Vec::new();
|
||||
term_dict.ord_to_term(0, &mut term_bytes).unwrap();
|
||||
assert_eq!(term_bytes, b"key\0stest1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_field_1byte() {
|
||||
// Test when field name contains a '1' byte, which has special meaning in tantivy.
|
||||
@@ -291,7 +303,7 @@ mod tests_mmap {
|
||||
Type::Str,
|
||||
),
|
||||
(format!("{field_name_out_internal}a"), Type::Str),
|
||||
(format!("{field_name_out_internal}"), Type::Str),
|
||||
(field_name_out_internal.to_string(), Type::Str),
|
||||
(format!("num{field_name_out_internal}"), Type::I64),
|
||||
];
|
||||
expected_fields.sort();
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
use common::json_path_writer::JSON_END_OF_PATH;
|
||||
use common::replace_in_place;
|
||||
use fnv::FnvHashMap;
|
||||
|
||||
/// `Field` is represented by an unsigned 32-bit integer type.
|
||||
@@ -40,13 +38,7 @@ impl PathToUnorderedId {
|
||||
#[cold]
|
||||
fn insert_new_path(&mut self, path: &str) -> u32 {
|
||||
let next_id = self.map.len() as u32;
|
||||
let mut new_path = path.to_string();
|
||||
|
||||
// The unsafe below is safe as long as b'.' and JSON_PATH_SEGMENT_SEP are
|
||||
// valid single byte ut8 strings.
|
||||
// By utf-8 design, they cannot be part of another codepoint.
|
||||
unsafe { replace_in_place(JSON_END_OF_PATH, b'0', new_path.as_bytes_mut()) };
|
||||
|
||||
let new_path = path.to_string();
|
||||
self.map.insert(new_path, next_id);
|
||||
next_id
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
//! In "The beauty and the beast", the term "the" appears in position 0 and position 3.
|
||||
//! This information is useful to run phrase queries.
|
||||
//!
|
||||
//! The [position](crate::SegmentComponent::Positions) file contains all of the
|
||||
//! The [position](crate::index::SegmentComponent::Positions) file contains all of the
|
||||
//! bitpacked positions delta, for all terms of a given field, one term after the other.
|
||||
//!
|
||||
//! Each term is encoded independently.
|
||||
|
||||
@@ -59,7 +59,7 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
|
||||
/// The actual serialization format is handled by the `PostingsSerializer`.
|
||||
fn serialize(
|
||||
&self,
|
||||
term_addrs: &[(Field, OrderedPathId, &[u8], Addr)],
|
||||
ordered_term_addrs: &[(Field, OrderedPathId, &[u8], Addr)],
|
||||
ordered_id_to_path: &[&str],
|
||||
ctx: &IndexingContext,
|
||||
serializer: &mut FieldSerializer,
|
||||
@@ -69,7 +69,7 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> {
|
||||
term_buffer.clear_with_field_and_type(Type::Json, Field::from_field_id(0));
|
||||
let mut prev_term_id = u32::MAX;
|
||||
let mut term_path_len = 0; // this will be set in the first iteration
|
||||
for (_field, path_id, term, addr) in term_addrs {
|
||||
for (_field, path_id, term, addr) in ordered_term_addrs {
|
||||
if prev_term_id != path_id.path_id() {
|
||||
term_buffer.truncate_value_bytes(0);
|
||||
term_buffer.append_path(ordered_id_to_path[path_id.path_id() as usize].as_bytes());
|
||||
|
||||
@@ -15,6 +15,7 @@ pub trait Postings: DocSet + 'static {
|
||||
fn term_freq(&self) -> u32;
|
||||
|
||||
/// Returns the positions offsetted with a given value.
|
||||
/// It is not necessary to clear the `output` before calling this method.
|
||||
/// The output vector will be resized to the `term_freq`.
|
||||
fn positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>);
|
||||
|
||||
|
||||
@@ -22,10 +22,7 @@ pub struct AllWeight;
|
||||
|
||||
impl Weight for AllWeight {
|
||||
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
|
||||
let all_scorer = AllScorer {
|
||||
doc: 0u32,
|
||||
max_doc: reader.max_doc(),
|
||||
};
|
||||
let all_scorer = AllScorer::new(reader.max_doc());
|
||||
Ok(Box::new(BoostScorer::new(all_scorer, boost)))
|
||||
}
|
||||
|
||||
@@ -43,6 +40,13 @@ pub struct AllScorer {
|
||||
max_doc: DocId,
|
||||
}
|
||||
|
||||
impl AllScorer {
|
||||
/// Creates a new AllScorer with `max_doc` docs.
|
||||
pub fn new(max_doc: DocId) -> AllScorer {
|
||||
AllScorer { doc: 0u32, max_doc }
|
||||
}
|
||||
}
|
||||
|
||||
impl DocSet for AllScorer {
|
||||
#[inline(always)]
|
||||
fn advance(&mut self) -> DocId {
|
||||
|
||||
@@ -66,6 +66,10 @@ use crate::schema::{IndexRecordOption, Term};
|
||||
/// Term::from_field_text(title, "diary"),
|
||||
/// IndexRecordOption::Basic,
|
||||
/// ));
|
||||
/// let cow_term_query: Box<dyn Query> = Box::new(TermQuery::new(
|
||||
/// Term::from_field_text(title, "cow"),
|
||||
/// IndexRecordOption::Basic
|
||||
/// ));
|
||||
/// // A TermQuery with "found" in the body
|
||||
/// let body_term_query: Box<dyn Query> = Box::new(TermQuery::new(
|
||||
/// Term::from_field_text(body, "found"),
|
||||
@@ -74,7 +78,7 @@ use crate::schema::{IndexRecordOption, Term};
|
||||
/// // TermQuery "diary" must and "girl" must not be present
|
||||
/// let queries_with_occurs1 = vec![
|
||||
/// (Occur::Must, diary_term_query.box_clone()),
|
||||
/// (Occur::MustNot, girl_term_query),
|
||||
/// (Occur::MustNot, girl_term_query.box_clone()),
|
||||
/// ];
|
||||
/// // Make a BooleanQuery equivalent to
|
||||
/// // title:+diary title:-girl
|
||||
@@ -82,15 +86,10 @@ use crate::schema::{IndexRecordOption, Term};
|
||||
/// let count1 = searcher.search(&diary_must_and_girl_mustnot, &Count)?;
|
||||
/// assert_eq!(count1, 1);
|
||||
///
|
||||
/// // TermQuery for "cow" in the title
|
||||
/// let cow_term_query: Box<dyn Query> = Box::new(TermQuery::new(
|
||||
/// Term::from_field_text(title, "cow"),
|
||||
/// IndexRecordOption::Basic,
|
||||
/// ));
|
||||
/// // "title:diary OR title:cow"
|
||||
/// let title_diary_or_cow = BooleanQuery::new(vec![
|
||||
/// (Occur::Should, diary_term_query.box_clone()),
|
||||
/// (Occur::Should, cow_term_query),
|
||||
/// (Occur::Should, cow_term_query.box_clone()),
|
||||
/// ]);
|
||||
/// let count2 = searcher.search(&title_diary_or_cow, &Count)?;
|
||||
/// assert_eq!(count2, 4);
|
||||
@@ -118,21 +117,38 @@ use crate::schema::{IndexRecordOption, Term};
|
||||
/// ]);
|
||||
/// let count4 = searcher.search(&nested_query, &Count)?;
|
||||
/// assert_eq!(count4, 1);
|
||||
///
|
||||
/// // You may call `with_minimum_required_clauses` to
|
||||
/// // specify the number of should clauses the returned documents must match.
|
||||
/// let minimum_required_query = BooleanQuery::with_minimum_required_clauses(vec![
|
||||
/// (Occur::Should, cow_term_query.box_clone()),
|
||||
/// (Occur::Should, girl_term_query.box_clone()),
|
||||
/// (Occur::Should, diary_term_query.box_clone()),
|
||||
/// ], 2);
|
||||
/// // Returns documents containing "Diary Cow", "Diary Girl" or "Cow Girl".
|
||||
/// // Notice: "Diary" isn't "Dairy". ;-)
|
||||
/// let count5 = searcher.search(&minimum_required_query, &Count)?;
|
||||
/// assert_eq!(count5, 1);
|
||||
/// Ok(())
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug)]
|
||||
pub struct BooleanQuery {
|
||||
subqueries: Vec<(Occur, Box<dyn Query>)>,
|
||||
minimum_number_should_match: usize,
|
||||
}
|
||||
|
||||
impl Clone for BooleanQuery {
|
||||
fn clone(&self) -> Self {
|
||||
self.subqueries
|
||||
let subqueries = self
|
||||
.subqueries
|
||||
.iter()
|
||||
.map(|(occur, subquery)| (*occur, subquery.box_clone()))
|
||||
.collect::<Vec<_>>()
|
||||
.into()
|
||||
.collect::<Vec<_>>();
|
||||
Self {
|
||||
subqueries,
|
||||
minimum_number_should_match: self.minimum_number_should_match,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -149,8 +165,9 @@ impl Query for BooleanQuery {
|
||||
.iter()
|
||||
.map(|(occur, subquery)| Ok((*occur, subquery.weight(enable_scoring)?)))
|
||||
.collect::<crate::Result<_>>()?;
|
||||
Ok(Box::new(BooleanWeight::new(
|
||||
Ok(Box::new(BooleanWeight::with_minimum_number_should_match(
|
||||
sub_weights,
|
||||
self.minimum_number_should_match,
|
||||
enable_scoring.is_scoring_enabled(),
|
||||
Box::new(SumWithCoordsCombiner::default),
|
||||
)))
|
||||
@@ -166,7 +183,41 @@ impl Query for BooleanQuery {
|
||||
impl BooleanQuery {
|
||||
/// Creates a new boolean query.
|
||||
pub fn new(subqueries: Vec<(Occur, Box<dyn Query>)>) -> BooleanQuery {
|
||||
BooleanQuery { subqueries }
|
||||
// If the bool query includes at least one Should clause and no Must or MustNot
// clauses, the default value is 1. Otherwise, the default value is 0.
// This matches Elasticsearch's behavior.
|
||||
let mut minimum_required = 0;
|
||||
for (occur, _) in &subqueries {
|
||||
match occur {
|
||||
Occur::Should => minimum_required = 1,
|
||||
Occur::Must | Occur::MustNot => {
|
||||
minimum_required = 0;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Self::with_minimum_required_clauses(subqueries, minimum_required)
|
||||
}
|
||||
|
||||
/// Create a new boolean query with minimum number of required should clauses specified.
|
||||
pub fn with_minimum_required_clauses(
|
||||
subqueries: Vec<(Occur, Box<dyn Query>)>,
|
||||
minimum_number_should_match: usize,
|
||||
) -> BooleanQuery {
|
||||
BooleanQuery {
|
||||
subqueries,
|
||||
minimum_number_should_match,
|
||||
}
|
||||
}
|
||||
|
||||
/// Getter for `minimum_number_should_match`
|
||||
pub fn get_minimum_number_should_match(&self) -> usize {
|
||||
self.minimum_number_should_match
|
||||
}
|
||||
|
||||
/// Setter for `minimum_number_should_match`
|
||||
pub fn set_minimum_number_should_match(&mut self, minimum_number_should_match: usize) {
|
||||
self.minimum_number_should_match = minimum_number_should_match;
|
||||
}
|
||||
|
||||
/// Returns the intersection of the queries.
|
||||
@@ -181,6 +232,18 @@ impl BooleanQuery {
|
||||
BooleanQuery::new(subqueries)
|
||||
}
|
||||
|
||||
/// Returns the union of the queries with a minimum number of required should clauses.
|
||||
pub fn union_with_minimum_required_clauses(
|
||||
queries: Vec<Box<dyn Query>>,
|
||||
minimum_required_clauses: usize,
|
||||
) -> BooleanQuery {
|
||||
let subqueries = queries
|
||||
.into_iter()
|
||||
.map(|sub_query| (Occur::Should, sub_query))
|
||||
.collect();
|
||||
BooleanQuery::with_minimum_required_clauses(subqueries, minimum_required_clauses)
|
||||
}
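A minimal usage sketch of this constructor, assuming an in-RAM index with a single `text` field (the field name and the 15 MB writer budget are illustrative; the counts follow the same logic as the test further down):

use tantivy::collector::Count;
use tantivy::query::{BooleanQuery, Query, TermQuery};
use tantivy::schema::{IndexRecordOption, Schema, TEXT};
use tantivy::{doc, Index, IndexWriter, Term};

fn union_with_minimum_example() -> tantivy::Result<()> {
    let mut schema_builder = Schema::builder();
    let text = schema_builder.add_text_field("text", TEXT);
    let index = Index::create_in_ram(schema_builder.build());
    let mut writer: IndexWriter = index.writer(15_000_000)?;
    writer.add_document(doc!(text => "a b c"))?;
    writer.add_document(doc!(text => "a c e"))?;
    writer.add_document(doc!(text => "d f g"))?;
    writer.commit()?;

    let terms: Vec<Box<dyn Query>> = ["a", "c", "z"]
        .into_iter()
        .map(|t| -> Box<dyn Query> {
            Box::new(TermQuery::new(
                Term::from_field_text(text, t),
                IndexRecordOption::Basic,
            ))
        })
        .collect();
    // At least 2 of the 3 should clauses must match.
    let query = BooleanQuery::union_with_minimum_required_clauses(terms, 2);
    let count = index.reader()?.searcher().search(&query, &Count)?;
    assert_eq!(count, 2); // "a b c" and "a c e"
    Ok(())
}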
|
||||
|
||||
/// Helper method to create a boolean query matching a given list of terms.
|
||||
/// The resulting query is a disjunction of the terms.
|
||||
pub fn new_multiterms_query(terms: Vec<Term>) -> BooleanQuery {
|
||||
@@ -203,11 +266,13 @@ impl BooleanQuery {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashSet;
|
||||
|
||||
use super::BooleanQuery;
|
||||
use crate::collector::{Count, DocSetCollector};
|
||||
use crate::query::{QueryClone, QueryParser, TermQuery};
|
||||
use crate::schema::{IndexRecordOption, Schema, TEXT};
|
||||
use crate::{DocAddress, Index, Term};
|
||||
use crate::query::{Query, QueryClone, QueryParser, TermQuery};
|
||||
use crate::schema::{Field, IndexRecordOption, Schema, TEXT};
|
||||
use crate::{DocAddress, DocId, Index, Term};
|
||||
|
||||
fn create_test_index() -> crate::Result<Index> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
@@ -223,6 +288,73 @@ mod tests {
|
||||
Ok(index)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_minimum_required() -> crate::Result<()> {
|
||||
fn create_test_index_with<T: IntoIterator<Item = &'static str>>(
|
||||
docs: T,
|
||||
) -> crate::Result<Index> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let text = schema_builder.add_text_field("text", TEXT);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut writer = index.writer_for_tests()?;
|
||||
for doc in docs {
|
||||
writer.add_document(doc!(text => doc))?;
|
||||
}
|
||||
writer.commit()?;
|
||||
Ok(index)
|
||||
}
|
||||
fn create_boolean_query_with_mr<T: IntoIterator<Item = &'static str>>(
|
||||
queries: T,
|
||||
field: Field,
|
||||
mr: usize,
|
||||
) -> BooleanQuery {
|
||||
let terms = queries
|
||||
.into_iter()
|
||||
.map(|t| Term::from_field_text(field, t))
|
||||
.map(|t| TermQuery::new(t, IndexRecordOption::Basic))
|
||||
.map(|q| -> Box<dyn Query> { Box::new(q) })
|
||||
.collect();
|
||||
BooleanQuery::union_with_minimum_required_clauses(terms, mr)
|
||||
}
|
||||
fn check_doc_id<T: IntoIterator<Item = DocId>>(
|
||||
expected: T,
|
||||
actually: HashSet<DocAddress>,
|
||||
seg: u32,
|
||||
) {
|
||||
assert_eq!(
|
||||
actually,
|
||||
expected
|
||||
.into_iter()
|
||||
.map(|id| DocAddress::new(seg, id))
|
||||
.collect()
|
||||
);
|
||||
}
|
||||
let index = create_test_index_with(["a b c", "a c e", "d f g", "z z z", "c i b"])?;
|
||||
let searcher = index.reader()?.searcher();
|
||||
let text = index.schema().get_field("text").unwrap();
|
||||
// Documents containing 'a c', 'a z', 'a i', 'c z', 'c i' or 'z i' shall be returned.
|
||||
let q1 = create_boolean_query_with_mr(["a", "c", "z", "i"], text, 2);
|
||||
let docs = searcher.search(&q1, &DocSetCollector)?;
|
||||
check_doc_id([0, 1, 4], docs, 0);
|
||||
// Documents containing 'a b c', 'a b e', 'a c e' or 'b c e' shall be returned.
|
||||
let q2 = create_boolean_query_with_mr(["a", "b", "c", "e"], text, 3);
|
||||
let docs = searcher.search(&q2, &DocSetCollector)?;
|
||||
check_doc_id([0, 1], docs, 0);
|
||||
// Nothing is returned since minimum_required exceeds the number of should clauses.
|
||||
let q3 = create_boolean_query_with_mr(["a", "b"], text, 3);
|
||||
let docs = searcher.search(&q3, &DocSetCollector)?;
|
||||
assert!(docs.is_empty());
|
||||
// When mr is set to zero or one, there is no difference from a plain boolean union.
|
||||
let q4 = create_boolean_query_with_mr(["a", "z"], text, 1);
|
||||
let docs = searcher.search(&q4, &DocSetCollector)?;
|
||||
check_doc_id([0, 1, 3], docs, 0);
|
||||
let q5 = create_boolean_query_with_mr(["a", "b"], text, 0);
|
||||
let docs = searcher.search(&q5, &DocSetCollector)?;
|
||||
check_doc_id([0, 1, 4], docs, 0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_union() -> crate::Result<()> {
|
||||
let index = create_test_index()?;
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::collections::HashMap;
|
||||
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
|
||||
use crate::index::SegmentReader;
|
||||
use crate::postings::FreqReadingOption;
|
||||
use crate::query::disjunction::Disjunction;
|
||||
use crate::query::explanation::does_not_match;
|
||||
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
|
||||
use crate::query::term_query::TermScorer;
|
||||
@@ -18,6 +19,26 @@ enum SpecializedScorer {
|
||||
Other(Box<dyn Scorer>),
|
||||
}
|
||||
|
||||
fn scorer_disjunction<TScoreCombiner>(
|
||||
scorers: Vec<Box<dyn Scorer>>,
|
||||
score_combiner: TScoreCombiner,
|
||||
minimum_match_required: usize,
|
||||
) -> Box<dyn Scorer>
|
||||
where
|
||||
TScoreCombiner: ScoreCombiner,
|
||||
{
|
||||
debug_assert!(!scorers.is_empty());
|
||||
debug_assert!(minimum_match_required > 1);
|
||||
if scorers.len() == 1 {
|
||||
return scorers.into_iter().next().unwrap(); // Safe unwrap.
|
||||
}
|
||||
Box::new(Disjunction::new(
|
||||
scorers,
|
||||
score_combiner,
|
||||
minimum_match_required,
|
||||
))
|
||||
}
|
||||
|
||||
fn scorer_union<TScoreCombiner>(
|
||||
scorers: Vec<Box<dyn Scorer>>,
|
||||
score_combiner_fn: impl Fn() -> TScoreCombiner,
|
||||
@@ -70,6 +91,7 @@ fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
|
||||
/// Weight associated to the `BoolQuery`.
|
||||
pub struct BooleanWeight<TScoreCombiner: ScoreCombiner> {
|
||||
weights: Vec<(Occur, Box<dyn Weight>)>,
|
||||
minimum_number_should_match: usize,
|
||||
scoring_enabled: bool,
|
||||
score_combiner_fn: Box<dyn Fn() -> TScoreCombiner + Sync + Send>,
|
||||
}
|
||||
@@ -85,6 +107,22 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
|
||||
weights,
|
||||
scoring_enabled,
|
||||
score_combiner_fn,
|
||||
minimum_number_should_match: 1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new boolean weight with minimum number of required should clauses specified.
|
||||
pub fn with_minimum_number_should_match(
|
||||
weights: Vec<(Occur, Box<dyn Weight>)>,
|
||||
minimum_number_should_match: usize,
|
||||
scoring_enabled: bool,
|
||||
score_combiner_fn: Box<dyn Fn() -> TScoreCombiner + Sync + Send + 'static>,
|
||||
) -> BooleanWeight<TScoreCombiner> {
|
||||
BooleanWeight {
|
||||
weights,
|
||||
minimum_number_should_match,
|
||||
scoring_enabled,
|
||||
score_combiner_fn,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,43 +149,89 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
|
||||
score_combiner_fn: impl Fn() -> TComplexScoreCombiner,
|
||||
) -> crate::Result<SpecializedScorer> {
|
||||
let mut per_occur_scorers = self.per_occur_scorers(reader, boost)?;
|
||||
|
||||
let should_scorer_opt: Option<SpecializedScorer> = per_occur_scorers
|
||||
.remove(&Occur::Should)
|
||||
.map(|scorers| scorer_union(scorers, &score_combiner_fn));
|
||||
// Indicates how the should clauses are combined with the other clauses.
enum CombinationMethod {
Ignored,
// Only contributes to the final score.
Optional(SpecializedScorer),
// Must be satisfied.
|
||||
Required(Box<dyn Scorer>),
|
||||
}
|
||||
let mut must_scorers = per_occur_scorers.remove(&Occur::Must);
|
||||
let should_opt = if let Some(mut should_scorers) = per_occur_scorers.remove(&Occur::Should)
|
||||
{
|
||||
let num_of_should_scorers = should_scorers.len();
|
||||
if self.minimum_number_should_match > num_of_should_scorers {
|
||||
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
|
||||
}
|
||||
match self.minimum_number_should_match {
|
||||
0 => CombinationMethod::Optional(scorer_union(should_scorers, &score_combiner_fn)),
|
||||
1 => CombinationMethod::Required(into_box_scorer(
|
||||
scorer_union(should_scorers, &score_combiner_fn),
|
||||
&score_combiner_fn,
|
||||
)),
|
||||
n if num_of_should_scorers == n => {
|
||||
// When minimum_number_should_match equals the number of should scorers,
// every should clause must match, so they are no different from must clauses.
|
||||
must_scorers = match must_scorers.take() {
|
||||
Some(mut must_scorers) => {
|
||||
must_scorers.append(&mut should_scorers);
|
||||
Some(must_scorers)
|
||||
}
|
||||
None => Some(should_scorers),
|
||||
};
|
||||
CombinationMethod::Ignored
|
||||
}
|
||||
_ => CombinationMethod::Required(scorer_disjunction(
|
||||
should_scorers,
|
||||
score_combiner_fn(),
|
||||
self.minimum_number_should_match,
|
||||
)),
|
||||
}
|
||||
} else {
|
||||
// No should clauses were provided.
|
||||
if self.minimum_number_should_match > 0 {
|
||||
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
|
||||
} else {
|
||||
CombinationMethod::Ignored
|
||||
}
|
||||
};
|
||||
let exclude_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
|
||||
.remove(&Occur::MustNot)
|
||||
.map(|scorers| scorer_union(scorers, DoNothingCombiner::default))
|
||||
.map(|specialized_scorer| {
|
||||
.map(|specialized_scorer: SpecializedScorer| {
|
||||
into_box_scorer(specialized_scorer, DoNothingCombiner::default)
|
||||
});
|
||||
|
||||
let must_scorer_opt: Option<Box<dyn Scorer>> = per_occur_scorers
|
||||
.remove(&Occur::Must)
|
||||
.map(intersect_scorers);
|
||||
|
||||
let positive_scorer: SpecializedScorer = match (should_scorer_opt, must_scorer_opt) {
|
||||
(Some(should_scorer), Some(must_scorer)) => {
|
||||
let positive_scorer = match (should_opt, must_scorers) {
|
||||
(CombinationMethod::Ignored, Some(must_scorers)) => {
|
||||
SpecializedScorer::Other(intersect_scorers(must_scorers))
|
||||
}
|
||||
(CombinationMethod::Optional(should_scorer), Some(must_scorers)) => {
|
||||
let must_scorer = intersect_scorers(must_scorers);
|
||||
if self.scoring_enabled {
|
||||
SpecializedScorer::Other(Box::new(RequiredOptionalScorer::<
|
||||
Box<dyn Scorer>,
|
||||
Box<dyn Scorer>,
|
||||
TComplexScoreCombiner,
|
||||
>::new(
|
||||
must_scorer,
|
||||
into_box_scorer(should_scorer, &score_combiner_fn),
|
||||
)))
|
||||
SpecializedScorer::Other(Box::new(
|
||||
RequiredOptionalScorer::<_, _, TScoreCombiner>::new(
|
||||
must_scorer,
|
||||
into_box_scorer(should_scorer, &score_combiner_fn),
|
||||
),
|
||||
))
|
||||
} else {
|
||||
SpecializedScorer::Other(must_scorer)
|
||||
}
|
||||
}
|
||||
(None, Some(must_scorer)) => SpecializedScorer::Other(must_scorer),
|
||||
(Some(should_scorer), None) => should_scorer,
|
||||
(None, None) => {
|
||||
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)));
|
||||
(CombinationMethod::Required(should_scorer), Some(mut must_scorers)) => {
|
||||
must_scorers.push(should_scorer);
|
||||
SpecializedScorer::Other(intersect_scorers(must_scorers))
|
||||
}
|
||||
(CombinationMethod::Ignored, None) => {
|
||||
return Ok(SpecializedScorer::Other(Box::new(EmptyScorer)))
|
||||
}
|
||||
(CombinationMethod::Required(should_scorer), None) => {
|
||||
SpecializedScorer::Other(should_scorer)
|
||||
}
|
||||
// Optional should scorers are promoted to required if no must scorers exist.
|
||||
(CombinationMethod::Optional(should_scorer), None) => should_scorer,
|
||||
};
|
||||
|
||||
if let Some(exclude_scorer) = exclude_scorer_opt {
|
||||
let positive_scorer_boxed = into_box_scorer(positive_scorer, &score_combiner_fn);
|
||||
Ok(SpecializedScorer::Other(Box::new(Exclude::new(
|
||||
|
||||
327
src/query/disjunction.rs
Normal file
@@ -0,0 +1,327 @@
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::BinaryHeap;
|
||||
|
||||
use crate::query::score_combiner::DoNothingCombiner;
|
||||
use crate::query::{ScoreCombiner, Scorer};
|
||||
use crate::{DocId, DocSet, Score, TERMINATED};
|
||||
|
||||
/// `Disjunction` is responsible for merging `DocSet`s from multiple
/// sources. Specifically, it takes the union of two or more `DocSet`s,
/// then filters out elements that appear fewer times than a
/// specified threshold.
|
||||
pub struct Disjunction<TScorer, TScoreCombiner = DoNothingCombiner> {
|
||||
chains: BinaryHeap<ScorerWrapper<TScorer>>,
|
||||
minimum_matches_required: usize,
|
||||
score_combiner: TScoreCombiner,
|
||||
|
||||
current_doc: DocId,
|
||||
current_score: Score,
|
||||
}
|
||||
|
||||
/// A wrapper around a `Scorer` that caches the current `doc_id` and implements the `DocSet` trait.
|
||||
/// Also, the `Ord` trait and its relatives are implemented in reverse order, so that a
/// `std::BinaryHeap<ScorerWrapper<T>>` behaves as a min-heap keyed by the current doc id.
|
||||
struct ScorerWrapper<T> {
|
||||
scorer: T,
|
||||
current_doc: DocId,
|
||||
}
|
||||
|
||||
impl<T: Scorer> ScorerWrapper<T> {
|
||||
fn new(scorer: T) -> Self {
|
||||
let current_doc = scorer.doc();
|
||||
Self {
|
||||
scorer,
|
||||
current_doc,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Scorer> PartialEq for ScorerWrapper<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.doc() == other.doc()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Scorer> Eq for ScorerWrapper<T> {}
|
||||
|
||||
impl<T: Scorer> PartialOrd for ScorerWrapper<T> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Scorer> Ord for ScorerWrapper<T> {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
self.doc().cmp(&other.doc()).reverse()
|
||||
}
|
||||
}
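The reversed comparison is what turns `std::collections::BinaryHeap` (a max-heap) into a min-heap keyed by doc id; a minimal, self-contained sketch of the same trick:

use std::cmp::Ordering;
use std::collections::BinaryHeap;

// A wrapper whose ordering is reversed, so the heap pops the smallest value first.
#[derive(PartialEq, Eq)]
struct MinDoc(u32);

impl Ord for MinDoc {
    fn cmp(&self, other: &Self) -> Ordering {
        self.0.cmp(&other.0).reverse()
    }
}

impl PartialOrd for MinDoc {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

fn main() {
    let mut heap = BinaryHeap::from([MinDoc(7), MinDoc(2), MinDoc(5)]);
    assert_eq!(heap.pop().map(|d| d.0), Some(2)); // behaves as a min-heap
}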
|
||||
|
||||
impl<T: Scorer> DocSet for ScorerWrapper<T> {
|
||||
fn advance(&mut self) -> DocId {
|
||||
let doc_id = self.scorer.advance();
|
||||
self.current_doc = doc_id;
|
||||
doc_id
|
||||
}
|
||||
|
||||
fn doc(&self) -> DocId {
|
||||
self.current_doc
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> u32 {
|
||||
self.scorer.size_hint()
|
||||
}
|
||||
}
|
||||
|
||||
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Disjunction<TScorer, TScoreCombiner> {
|
||||
pub fn new<T: IntoIterator<Item = TScorer>>(
|
||||
docsets: T,
|
||||
score_combiner: TScoreCombiner,
|
||||
minimum_matches_required: usize,
|
||||
) -> Self {
|
||||
debug_assert!(
|
||||
minimum_matches_required > 1,
|
||||
"union scorer works better if just one matches required"
|
||||
);
|
||||
let chains = docsets
|
||||
.into_iter()
|
||||
.map(|doc| ScorerWrapper::new(doc))
|
||||
.collect();
|
||||
let mut disjunction = Self {
|
||||
chains,
|
||||
score_combiner,
|
||||
current_doc: TERMINATED,
|
||||
minimum_matches_required,
|
||||
current_score: 0.0,
|
||||
};
|
||||
if minimum_matches_required > disjunction.chains.len() {
|
||||
return disjunction;
|
||||
}
|
||||
disjunction.advance();
|
||||
disjunction
|
||||
}
|
||||
}
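`Disjunction` is crate-private, so a test-style sketch (in the spirit of `aux_test_conjunction` below) shows how it is driven; the doc ids are made up for illustration:

#[test]
fn disjunction_requires_two_of_three() {
    use super::Disjunction;
    use crate::query::score_combiner::DoNothingCombiner;
    use crate::query::{ConstScorer, VecDocSet};
    use crate::{DocSet, TERMINATED};

    let scorers = vec![
        ConstScorer::new(VecDocSet::from(vec![1u32, 3, 7]), 1.0),
        ConstScorer::new(VecDocSet::from(vec![1u32, 7, 9]), 1.0),
        ConstScorer::new(VecDocSet::from(vec![3u32, 9]), 1.0),
    ];
    // Keep only docs present in at least 2 of the 3 doc sets.
    let mut disjunction = Disjunction::new(scorers, DoNothingCombiner, 2);
    let mut matched = Vec::new();
    while disjunction.doc() != TERMINATED {
        matched.push(disjunction.doc());
        disjunction.advance();
    }
    assert_eq!(matched, vec![1, 3, 7, 9]);
}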
|
||||
|
||||
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> DocSet
|
||||
for Disjunction<TScorer, TScoreCombiner>
|
||||
{
|
||||
fn advance(&mut self) -> DocId {
|
||||
let mut current_num_matches = 0;
|
||||
while let Some(mut candidate) = self.chains.pop() {
|
||||
let next = candidate.doc();
|
||||
if next != TERMINATED {
|
||||
// Peek next doc.
|
||||
if self.current_doc != next {
|
||||
if current_num_matches >= self.minimum_matches_required {
|
||||
self.chains.push(candidate);
|
||||
self.current_score = self.score_combiner.score();
|
||||
return self.current_doc;
|
||||
}
|
||||
// Reset current_num_matches and scores.
|
||||
current_num_matches = 0;
|
||||
self.current_doc = next;
|
||||
self.score_combiner.clear();
|
||||
}
|
||||
current_num_matches += 1;
|
||||
self.score_combiner.update(&mut candidate.scorer);
|
||||
candidate.advance();
|
||||
self.chains.push(candidate);
|
||||
}
|
||||
}
|
||||
if current_num_matches < self.minimum_matches_required {
|
||||
self.current_doc = TERMINATED;
|
||||
}
|
||||
self.current_score = self.score_combiner.score();
|
||||
self.current_doc
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn doc(&self) -> DocId {
|
||||
self.current_doc
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> u32 {
|
||||
self.chains
|
||||
.iter()
|
||||
.map(|docset| docset.size_hint())
|
||||
.max()
|
||||
.unwrap_or(0u32)
|
||||
}
|
||||
}
|
||||
|
||||
impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Scorer
|
||||
for Disjunction<TScorer, TScoreCombiner>
|
||||
{
|
||||
fn score(&mut self) -> Score {
|
||||
self.current_score
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use super::Disjunction;
|
||||
use crate::query::score_combiner::DoNothingCombiner;
|
||||
use crate::query::{ConstScorer, Scorer, SumCombiner, VecDocSet};
|
||||
use crate::{DocId, DocSet, Score, TERMINATED};
|
||||
|
||||
fn conjunct<T: Ord + Copy>(arrays: &[Vec<T>], pass_line: usize) -> Vec<T> {
|
||||
let mut counts = BTreeMap::new();
|
||||
for array in arrays {
|
||||
for &element in array {
|
||||
*counts.entry(element).or_insert(0) += 1;
|
||||
}
|
||||
}
|
||||
counts
|
||||
.iter()
|
||||
.filter_map(|(&element, &count)| {
|
||||
if count >= pass_line {
|
||||
Some(element)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn aux_test_conjunction(vals: Vec<Vec<u32>>, min_match: usize) {
|
||||
let mut union_expected = VecDocSet::from(conjunct(&vals, min_match));
|
||||
let make_scorer = || {
|
||||
Disjunction::new(
|
||||
vals.iter()
|
||||
.cloned()
|
||||
.map(VecDocSet::from)
|
||||
.map(|d| ConstScorer::new(d, 1.0)),
|
||||
DoNothingCombiner,
|
||||
min_match,
|
||||
)
|
||||
};
|
||||
let mut scorer: Disjunction<_, DoNothingCombiner> = make_scorer();
|
||||
let mut count = 0;
|
||||
while scorer.doc() != TERMINATED {
|
||||
assert_eq!(union_expected.doc(), scorer.doc());
|
||||
assert_eq!(union_expected.advance(), scorer.advance());
|
||||
count += 1;
|
||||
}
|
||||
assert_eq!(union_expected.advance(), TERMINATED);
|
||||
assert_eq!(count, make_scorer().count_including_deleted());
|
||||
}
|
||||
|
||||
#[should_panic]
|
||||
#[test]
|
||||
fn test_arg_check1() {
|
||||
aux_test_conjunction(vec![], 0);
|
||||
}
|
||||
|
||||
#[should_panic]
|
||||
#[test]
|
||||
fn test_arg_check2() {
|
||||
aux_test_conjunction(vec![], 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_corner_case() {
|
||||
aux_test_conjunction(vec![], 2);
|
||||
aux_test_conjunction(vec![vec![]; 1000], 2);
|
||||
aux_test_conjunction(vec![vec![]; 100], usize::MAX);
|
||||
aux_test_conjunction(vec![vec![0xC0FFEE]; 10000], usize::MAX);
|
||||
aux_test_conjunction((1..10000u32).map(|i| vec![i]).collect::<Vec<_>>(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conjunction() {
|
||||
aux_test_conjunction(
|
||||
vec![
|
||||
vec![1, 3333, 100000000u32],
|
||||
vec![1, 2, 100000000u32],
|
||||
vec![1, 2, 100000000u32],
|
||||
],
|
||||
2,
|
||||
);
|
||||
aux_test_conjunction(
|
||||
vec![vec![8], vec![3, 4, 0xC0FFEEu32], vec![1, 2, 100000000u32]],
|
||||
2,
|
||||
);
|
||||
aux_test_conjunction(
|
||||
vec![
|
||||
vec![1, 3333, 100000000u32],
|
||||
vec![1, 2, 100000000u32],
|
||||
vec![1, 2, 100000000u32],
|
||||
],
|
||||
3,
|
||||
)
|
||||
}
|
||||
|
||||
// This dummy scorer simply yields doc ids in increasing order,
// each with a constant score of 1.0.
|
||||
#[derive(Clone)]
|
||||
struct DummyScorer {
|
||||
cursor: usize,
|
||||
foo: Vec<(DocId, f32)>,
|
||||
}
|
||||
|
||||
impl DummyScorer {
|
||||
fn new(doc_score: Vec<(DocId, f32)>) -> Self {
|
||||
Self {
|
||||
cursor: 0,
|
||||
foo: doc_score,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DocSet for DummyScorer {
|
||||
fn advance(&mut self) -> DocId {
|
||||
self.cursor += 1;
|
||||
self.doc()
|
||||
}
|
||||
|
||||
fn doc(&self) -> DocId {
|
||||
self.foo.get(self.cursor).map(|x| x.0).unwrap_or(TERMINATED)
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> u32 {
|
||||
self.foo.len() as u32
|
||||
}
|
||||
}
|
||||
|
||||
impl Scorer for DummyScorer {
|
||||
fn score(&mut self) -> Score {
|
||||
self.foo.get(self.cursor).map(|x| x.1).unwrap_or(0.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_score_calculate() {
|
||||
let mut scorer = Disjunction::new(
|
||||
vec![
|
||||
DummyScorer::new(vec![(1, 1f32), (2, 1f32)]),
|
||||
DummyScorer::new(vec![(1, 1f32), (3, 1f32)]),
|
||||
DummyScorer::new(vec![(1, 1f32), (4, 1f32)]),
|
||||
DummyScorer::new(vec![(1, 1f32), (2, 1f32)]),
|
||||
DummyScorer::new(vec![(1, 1f32), (2, 1f32)]),
|
||||
],
|
||||
SumCombiner::default(),
|
||||
3,
|
||||
);
|
||||
assert_eq!(scorer.score(), 5.0);
|
||||
assert_eq!(scorer.advance(), 2);
|
||||
assert_eq!(scorer.score(), 3.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_score_calculate_corner_case() {
|
||||
let mut scorer = Disjunction::new(
|
||||
vec![
|
||||
DummyScorer::new(vec![(1, 1f32), (2, 1f32)]),
|
||||
DummyScorer::new(vec![(1, 1f32), (3, 1f32)]),
|
||||
DummyScorer::new(vec![(1, 1f32), (3, 1f32)]),
|
||||
],
|
||||
SumCombiner::default(),
|
||||
2,
|
||||
);
|
||||
assert_eq!(scorer.doc(), 1);
|
||||
assert_eq!(scorer.score(), 3.0);
|
||||
assert_eq!(scorer.advance(), 3);
|
||||
assert_eq!(scorer.score(), 2.0);
|
||||
}
|
||||
}
|
||||
@@ -149,7 +149,7 @@ mod tests {
|
||||
use crate::query::exist_query::ExistsQuery;
|
||||
use crate::query::{BooleanQuery, RangeQuery};
|
||||
use crate::schema::{Facet, FacetOptions, Schema, FAST, INDEXED, STRING, TEXT};
|
||||
use crate::{Index, Searcher};
|
||||
use crate::{Index, Searcher, Term};
|
||||
|
||||
#[test]
|
||||
fn test_exists_query_simple() -> crate::Result<()> {
|
||||
@@ -188,9 +188,8 @@ mod tests {
|
||||
|
||||
// exercise seek
|
||||
let query = BooleanQuery::intersection(vec![
|
||||
Box::new(RangeQuery::new_u64_bounds(
|
||||
"all".to_string(),
|
||||
Bound::Included(50),
|
||||
Box::new(RangeQuery::new(
|
||||
Bound::Included(Term::from_field_u64(all_field, 50)),
|
||||
Bound::Unbounded,
|
||||
)),
|
||||
Box::new(ExistsQuery::new_exists_query("even".to_string())),
|
||||
@@ -198,10 +197,9 @@ mod tests {
|
||||
assert_eq!(searcher.search(&query, &Count)?, 25);
|
||||
|
||||
let query = BooleanQuery::intersection(vec![
|
||||
Box::new(RangeQuery::new_u64_bounds(
|
||||
"all".to_string(),
|
||||
Bound::Included(0),
|
||||
Bound::Excluded(50),
|
||||
Box::new(RangeQuery::new(
|
||||
Bound::Included(Term::from_field_u64(all_field, 0)),
|
||||
Bound::Included(Term::from_field_u64(all_field, 50)),
|
||||
)),
|
||||
Box::new(ExistsQuery::new_exists_query("odd".to_string())),
|
||||
]);
|
||||
|
||||
@@ -5,6 +5,7 @@ mod bm25;
|
||||
mod boolean_query;
|
||||
mod boost_query;
|
||||
mod const_score_query;
|
||||
mod disjunction;
|
||||
mod disjunction_max_query;
|
||||
mod empty_query;
|
||||
mod exclude;
|
||||
@@ -53,7 +54,7 @@ pub use self::phrase_prefix_query::PhrasePrefixQuery;
|
||||
pub use self::phrase_query::PhraseQuery;
|
||||
pub use self::query::{EnableScoring, Query, QueryClone};
|
||||
pub use self::query_parser::{QueryParser, QueryParserError};
|
||||
pub use self::range_query::{FastFieldRangeWeight, IPFastFieldRangeWeight, RangeQuery};
|
||||
pub use self::range_query::{FastFieldRangeWeight, RangeQuery};
|
||||
pub use self::regex_query::RegexQuery;
|
||||
pub use self::reqopt_scorer::RequiredOptionalScorer;
|
||||
pub use self::score_combiner::{
|
||||
|
||||
@@ -145,15 +145,7 @@ impl Query for PhrasePrefixQuery {
|
||||
Bound::Unbounded
|
||||
};
|
||||
|
||||
let mut range_query = RangeQuery::new_term_bounds(
|
||||
enable_scoring
|
||||
.schema()
|
||||
.get_field_name(self.field)
|
||||
.to_owned(),
|
||||
self.prefix.1.typ(),
|
||||
&Bound::Included(self.prefix.1.clone()),
|
||||
&end_term,
|
||||
);
|
||||
let mut range_query = RangeQuery::new(Bound::Included(self.prefix.1.clone()), end_term);
|
||||
range_query.limit(self.max_expansions as u64);
|
||||
range_query.weight(enable_scoring)
|
||||
}
|
||||
|
||||
@@ -97,6 +97,7 @@ pub struct PhrasePrefixScorer<TPostings: Postings> {
|
||||
suffixes: Vec<TPostings>,
|
||||
suffix_offset: u32,
|
||||
phrase_count: u32,
|
||||
suffix_position_buffer: Vec<u32>,
|
||||
}
|
||||
|
||||
impl<TPostings: Postings> PhrasePrefixScorer<TPostings> {
|
||||
@@ -140,6 +141,7 @@ impl<TPostings: Postings> PhrasePrefixScorer<TPostings> {
|
||||
suffixes,
|
||||
suffix_offset: (max_offset - suffix_pos) as u32,
|
||||
phrase_count: 0,
|
||||
suffix_position_buffer: Vec::with_capacity(100),
|
||||
};
|
||||
if phrase_prefix_scorer.doc() != TERMINATED && !phrase_prefix_scorer.matches_prefix() {
|
||||
phrase_prefix_scorer.advance();
|
||||
@@ -153,7 +155,6 @@ impl<TPostings: Postings> PhrasePrefixScorer<TPostings> {
|
||||
|
||||
fn matches_prefix(&mut self) -> bool {
|
||||
let mut count = 0;
|
||||
let mut positions = Vec::new();
|
||||
let current_doc = self.doc();
|
||||
let pos_matching = self.phrase_scorer.get_intersection();
|
||||
for suffix in &mut self.suffixes {
|
||||
@@ -162,8 +163,8 @@ impl<TPostings: Postings> PhrasePrefixScorer<TPostings> {
|
||||
}
|
||||
let doc = suffix.seek(current_doc);
|
||||
if doc == current_doc {
|
||||
suffix.positions_with_offset(self.suffix_offset, &mut positions);
|
||||
count += intersection_count(pos_matching, &positions);
|
||||
suffix.positions_with_offset(self.suffix_offset, &mut self.suffix_position_buffer);
|
||||
count += intersection_count(pos_matching, &self.suffix_position_buffer);
|
||||
}
|
||||
}
|
||||
self.phrase_count = count as u32;
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::fmt;
|
||||
use std::ops::Bound;
|
||||
|
||||
use crate::query::Occur;
|
||||
use crate::schema::{Term, Type};
|
||||
use crate::schema::Term;
|
||||
use crate::Score;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -14,8 +14,6 @@ pub enum LogicalLiteral {
|
||||
prefix: bool,
|
||||
},
|
||||
Range {
|
||||
field: String,
|
||||
value_type: Type,
|
||||
lower: Bound<Term>,
|
||||
upper: Bound<Term>,
|
||||
},
|
||||
|
||||
@@ -790,8 +790,6 @@ impl QueryParser {
|
||||
let (field, json_path) = try_tuple!(self
|
||||
.split_full_path(&full_path)
|
||||
.ok_or_else(|| QueryParserError::FieldDoesNotExist(full_path.clone())));
|
||||
let field_entry = self.schema.get_field_entry(field);
|
||||
let value_type = field_entry.field_type().value_type();
|
||||
let mut errors = Vec::new();
|
||||
let lower = match self.resolve_bound(field, json_path, &lower) {
|
||||
Ok(bound) => bound,
|
||||
@@ -812,12 +810,8 @@ impl QueryParser {
|
||||
// We failed to parse something. Either way, there is no point emitting it.
|
||||
return (None, errors);
|
||||
}
|
||||
let logical_ast = LogicalAst::Leaf(Box::new(LogicalLiteral::Range {
|
||||
field: self.schema.get_field_name(field).to_string(),
|
||||
value_type,
|
||||
lower,
|
||||
upper,
|
||||
}));
|
||||
let logical_ast =
|
||||
LogicalAst::Leaf(Box::new(LogicalLiteral::Range { lower, upper }));
|
||||
(Some(logical_ast), errors)
|
||||
}
|
||||
UserInputLeaf::Set {
|
||||
@@ -884,14 +878,7 @@ fn convert_literal_to_query(
|
||||
Box::new(PhraseQuery::new_with_offset_and_slop(terms, slop))
|
||||
}
|
||||
}
|
||||
LogicalLiteral::Range {
|
||||
field,
|
||||
value_type,
|
||||
lower,
|
||||
upper,
|
||||
} => Box::new(RangeQuery::new_term_bounds(
|
||||
field, value_type, &lower, &upper,
|
||||
)),
|
||||
LogicalLiteral::Range { lower, upper } => Box::new(RangeQuery::new(lower, upper)),
|
||||
LogicalLiteral::Set { elements, .. } => Box::new(TermSetQuery::new(elements)),
|
||||
LogicalLiteral::All => Box::new(AllQuery),
|
||||
}
|
||||
@@ -1136,8 +1123,8 @@ mod test {
|
||||
let query = make_query_parser().parse_query("title:[A TO B]").unwrap();
|
||||
assert_eq!(
|
||||
format!("{query:?}"),
|
||||
"RangeQuery { field: \"title\", value_type: Str, lower_bound: Included([97]), \
|
||||
upper_bound: Included([98]), limit: None }"
|
||||
"RangeQuery { lower_bound: Included(Term(field=0, type=Str, \"a\")), upper_bound: \
|
||||
Included(Term(field=0, type=Str, \"b\")), limit: None }"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1815,7 +1802,8 @@ mod test {
|
||||
\"bad\"))], prefix: (2, Term(field=0, type=Str, \"wo\")), max_expansions: 50 }), \
|
||||
(Should, PhrasePrefixQuery { field: Field(1), phrase_terms: [(0, Term(field=1, \
|
||||
type=Str, \"big\")), (1, Term(field=1, type=Str, \"bad\"))], prefix: (2, \
|
||||
Term(field=1, type=Str, \"wo\")), max_expansions: 50 })] }"
|
||||
Term(field=1, type=Str, \"wo\")), max_expansions: 50 })], \
|
||||
minimum_number_should_match: 1 }"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1880,7 +1868,8 @@ mod test {
|
||||
format!("{query:?}"),
|
||||
"BooleanQuery { subqueries: [(Should, FuzzyTermQuery { term: Term(field=0, \
|
||||
type=Str, \"abc\"), distance: 1, transposition_cost_one: true, prefix: false }), \
|
||||
(Should, TermQuery(Term(field=1, type=Str, \"abc\")))] }"
|
||||
(Should, TermQuery(Term(field=1, type=Str, \"abc\")))], \
|
||||
minimum_number_should_match: 1 }"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1897,7 +1886,8 @@ mod test {
|
||||
format!("{query:?}"),
|
||||
"BooleanQuery { subqueries: [(Should, TermQuery(Term(field=0, type=Str, \
|
||||
\"abc\"))), (Should, FuzzyTermQuery { term: Term(field=1, type=Str, \"abc\"), \
|
||||
distance: 2, transposition_cost_one: false, prefix: true })] }"
|
||||
distance: 2, transposition_cost_one: false, prefix: true })], \
|
||||
minimum_number_should_match: 1 }"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,10 +180,12 @@ impl<T: Send + Sync + PartialOrd + Copy + Debug + 'static> DocSet for RangeDocSe
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::ops::Bound;
|
||||
|
||||
use crate::collector::Count;
|
||||
use crate::directory::RamDirectory;
|
||||
use crate::query::RangeQuery;
|
||||
use crate::{schema, IndexBuilder, TantivyDocument};
|
||||
use crate::{schema, IndexBuilder, TantivyDocument, Term};
|
||||
|
||||
#[test]
|
||||
fn range_query_fast_optional_field_minimum() {
|
||||
@@ -218,10 +220,9 @@ mod tests {
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
|
||||
let query = RangeQuery::new_u64_bounds(
|
||||
"score".to_string(),
|
||||
std::ops::Bound::Included(70),
|
||||
std::ops::Bound::Unbounded,
|
||||
let query = RangeQuery::new(
|
||||
Bound::Included(Term::from_field_u64(score_field, 70)),
|
||||
Bound::Unbounded,
|
||||
);
|
||||
|
||||
let count = searcher.search(&query, &Count).unwrap();
|
||||
@@ -2,21 +2,19 @@ use std::ops::Bound;
|
||||
|
||||
use crate::schema::Type;
|
||||
|
||||
mod fast_field_range_query;
|
||||
mod fast_field_range_doc_set;
|
||||
mod range_query;
|
||||
mod range_query_ip_fastfield;
|
||||
mod range_query_u64_fastfield;
|
||||
|
||||
pub use self::range_query::RangeQuery;
|
||||
pub use self::range_query_ip_fastfield::IPFastFieldRangeWeight;
|
||||
pub use self::range_query_u64_fastfield::FastFieldRangeWeight;
|
||||
|
||||
// TODO is this correct?
|
||||
pub(crate) fn is_type_valid_for_fastfield_range_query(typ: Type) -> bool {
|
||||
match typ {
|
||||
Type::U64 | Type::I64 | Type::F64 | Type::Bool | Type::Date => true,
|
||||
Type::Str | Type::U64 | Type::I64 | Type::F64 | Type::Bool | Type::Date => true,
|
||||
Type::IpAddr => true,
|
||||
Type::Str | Type::Facet | Type::Bytes | Type::Json => false,
|
||||
Type::Facet | Type::Bytes | Type::Json => false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
use std::io;
|
||||
use std::net::Ipv6Addr;
|
||||
use std::ops::{Bound, Range};
|
||||
use std::ops::Bound;
|
||||
|
||||
use columnar::MonotonicallyMappableToU128;
|
||||
use common::{BinarySerializable, BitSet};
|
||||
use common::BitSet;
|
||||
|
||||
use super::map_bound;
|
||||
use super::range_query_u64_fastfield::FastFieldRangeWeight;
|
||||
use crate::error::TantivyError;
|
||||
use crate::index::SegmentReader;
|
||||
use crate::query::explanation::does_not_match;
|
||||
use crate::query::range_query::range_query_ip_fastfield::IPFastFieldRangeWeight;
|
||||
use crate::query::range_query::{is_type_valid_for_fastfield_range_query, map_bound_res};
|
||||
use crate::query::range_query::is_type_valid_for_fastfield_range_query;
|
||||
use crate::query::{BitSetDocSet, ConstScorer, EnableScoring, Explanation, Query, Scorer, Weight};
|
||||
use crate::schema::{Field, IndexRecordOption, Term, Type};
|
||||
use crate::termdict::{TermDictionary, TermStreamer};
|
||||
use crate::{DateTime, DocId, Score};
|
||||
use crate::{DocId, Score};
|
||||
|
||||
/// `RangeQuery` matches all documents that have at least one term within a defined range.
|
||||
///
|
||||
@@ -40,8 +36,10 @@ use crate::{DateTime, DocId, Score};
|
||||
/// ```rust
|
||||
/// use tantivy::collector::Count;
|
||||
/// use tantivy::query::RangeQuery;
|
||||
/// use tantivy::Term;
|
||||
/// use tantivy::schema::{Schema, INDEXED};
|
||||
/// use tantivy::{doc, Index, IndexWriter};
|
||||
/// use std::ops::Bound;
|
||||
/// # fn test() -> tantivy::Result<()> {
|
||||
/// let mut schema_builder = Schema::builder();
|
||||
/// let year_field = schema_builder.add_u64_field("year", INDEXED);
|
||||
@@ -59,7 +57,10 @@ use crate::{DateTime, DocId, Score};
|
||||
///
|
||||
/// let reader = index.reader()?;
|
||||
/// let searcher = reader.searcher();
|
||||
/// let docs_in_the_sixties = RangeQuery::new_u64("year".to_string(), 1960..1970);
|
||||
/// let docs_in_the_sixties = RangeQuery::new(
|
||||
/// Bound::Included(Term::from_field_u64(year_field, 1960)),
|
||||
/// Bound::Excluded(Term::from_field_u64(year_field, 1970)),
|
||||
/// );
|
||||
/// let num_60s_books = searcher.search(&docs_in_the_sixties, &Count)?;
|
||||
/// assert_eq!(num_60s_books, 2285);
|
||||
/// Ok(())
|
||||
@@ -68,246 +69,46 @@ use crate::{DateTime, DocId, Score};
|
||||
/// ```
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RangeQuery {
|
||||
field: String,
|
||||
value_type: Type,
|
||||
lower_bound: Bound<Vec<u8>>,
|
||||
upper_bound: Bound<Vec<u8>>,
|
||||
lower_bound: Bound<Term>,
|
||||
upper_bound: Bound<Term>,
|
||||
limit: Option<u64>,
|
||||
}
|
||||
|
||||
/// Returns the inner value of a `Bound`
|
||||
pub(crate) fn inner_bound(val: &Bound<Term>) -> Option<&Term> {
|
||||
match val {
|
||||
Bound::Included(term) | Bound::Excluded(term) => Some(term),
|
||||
Bound::Unbounded => None,
|
||||
}
|
||||
}
|
||||
|
||||
impl RangeQuery {
|
||||
/// Creates a new `RangeQuery` from bounded start and end terms.
|
||||
///
|
||||
/// If the value type is not correct, something may go terribly wrong when
|
||||
/// the `Weight` object is created.
|
||||
pub fn new_term_bounds(
|
||||
field: String,
|
||||
value_type: Type,
|
||||
lower_bound: &Bound<Term>,
|
||||
upper_bound: &Bound<Term>,
|
||||
) -> RangeQuery {
|
||||
let verify_and_unwrap_term = |val: &Term| val.serialized_value_bytes().to_owned();
|
||||
pub fn new(lower_bound: Bound<Term>, upper_bound: Bound<Term>) -> RangeQuery {
|
||||
RangeQuery {
|
||||
field,
|
||||
value_type,
|
||||
lower_bound: map_bound(lower_bound, verify_and_unwrap_term),
|
||||
upper_bound: map_bound(upper_bound, verify_and_unwrap_term),
|
||||
lower_bound,
|
||||
upper_bound,
|
||||
limit: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new `RangeQuery` over a `i64` field.
|
||||
///
|
||||
/// If the field is not of the type `i64`, tantivy
|
||||
/// will panic when the `Weight` object is created.
|
||||
pub fn new_i64(field: String, range: Range<i64>) -> RangeQuery {
|
||||
RangeQuery::new_i64_bounds(
|
||||
field,
|
||||
Bound::Included(range.start),
|
||||
Bound::Excluded(range.end),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a new `RangeQuery` over a `i64` field.
|
||||
///
|
||||
/// The two `Bound` arguments make it possible to create more complex
|
||||
/// ranges than semi-inclusive range.
|
||||
///
|
||||
/// If the field is not of the type `i64`, tantivy
|
||||
/// will panic when the `Weight` object is created.
|
||||
pub fn new_i64_bounds(
|
||||
field: String,
|
||||
lower_bound: Bound<i64>,
|
||||
upper_bound: Bound<i64>,
|
||||
) -> RangeQuery {
|
||||
let make_term_val = |val: &i64| {
|
||||
Term::from_field_i64(Field::from_field_id(0), *val)
|
||||
.serialized_value_bytes()
|
||||
.to_owned()
|
||||
};
|
||||
RangeQuery {
|
||||
field,
|
||||
value_type: Type::I64,
|
||||
lower_bound: map_bound(&lower_bound, make_term_val),
|
||||
upper_bound: map_bound(&upper_bound, make_term_val),
|
||||
limit: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new `RangeQuery` over a `f64` field.
|
||||
///
|
||||
/// If the field is not of the type `f64`, tantivy
|
||||
/// will panic when the `Weight` object is created.
|
||||
pub fn new_f64(field: String, range: Range<f64>) -> RangeQuery {
|
||||
RangeQuery::new_f64_bounds(
|
||||
field,
|
||||
Bound::Included(range.start),
|
||||
Bound::Excluded(range.end),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a new `RangeQuery` over a `f64` field.
|
||||
///
|
||||
/// The two `Bound` arguments make it possible to create more complex
|
||||
/// ranges than semi-inclusive range.
|
||||
///
|
||||
/// If the field is not of the type `f64`, tantivy
|
||||
/// will panic when the `Weight` object is created.
|
||||
pub fn new_f64_bounds(
|
||||
field: String,
|
||||
lower_bound: Bound<f64>,
|
||||
upper_bound: Bound<f64>,
|
||||
) -> RangeQuery {
|
||||
let make_term_val = |val: &f64| {
|
||||
Term::from_field_f64(Field::from_field_id(0), *val)
|
||||
.serialized_value_bytes()
|
||||
.to_owned()
|
||||
};
|
||||
RangeQuery {
|
||||
field,
|
||||
value_type: Type::F64,
|
||||
lower_bound: map_bound(&lower_bound, make_term_val),
|
||||
upper_bound: map_bound(&upper_bound, make_term_val),
|
||||
limit: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new `RangeQuery` over a `u64` field.
|
||||
///
|
||||
/// The two `Bound` arguments make it possible to create more complex
|
||||
/// ranges than semi-inclusive range.
|
||||
///
|
||||
/// If the field is not of the type `u64`, tantivy
|
||||
/// will panic when the `Weight` object is created.
|
||||
pub fn new_u64_bounds(
|
||||
field: String,
|
||||
lower_bound: Bound<u64>,
|
||||
upper_bound: Bound<u64>,
|
||||
) -> RangeQuery {
|
||||
let make_term_val = |val: &u64| {
|
||||
Term::from_field_u64(Field::from_field_id(0), *val)
|
||||
.serialized_value_bytes()
|
||||
.to_owned()
|
||||
};
|
||||
RangeQuery {
|
||||
field,
|
||||
value_type: Type::U64,
|
||||
lower_bound: map_bound(&lower_bound, make_term_val),
|
||||
upper_bound: map_bound(&upper_bound, make_term_val),
|
||||
limit: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new `RangeQuery` over a `ip` field.
|
||||
///
|
||||
/// If the field is not of the type `ip`, tantivy
|
||||
/// will panic when the `Weight` object is created.
|
||||
pub fn new_ip_bounds(
|
||||
field: String,
|
||||
lower_bound: Bound<Ipv6Addr>,
|
||||
upper_bound: Bound<Ipv6Addr>,
|
||||
) -> RangeQuery {
|
||||
let make_term_val = |val: &Ipv6Addr| {
|
||||
Term::from_field_ip_addr(Field::from_field_id(0), *val)
|
||||
.serialized_value_bytes()
|
||||
.to_owned()
|
||||
};
|
||||
RangeQuery {
|
||||
field,
|
||||
value_type: Type::IpAddr,
|
||||
lower_bound: map_bound(&lower_bound, make_term_val),
|
||||
upper_bound: map_bound(&upper_bound, make_term_val),
|
||||
limit: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new `RangeQuery` over a `u64` field.
|
||||
///
|
||||
/// If the field is not of the type `u64`, tantivy
|
||||
/// will panic when the `Weight` object is created.
|
||||
pub fn new_u64(field: String, range: Range<u64>) -> RangeQuery {
|
||||
RangeQuery::new_u64_bounds(
|
||||
field,
|
||||
Bound::Included(range.start),
|
||||
Bound::Excluded(range.end),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a new `RangeQuery` over a `date` field.
|
||||
///
|
||||
/// The two `Bound` arguments make it possible to create more complex
|
||||
/// ranges than semi-inclusive range.
|
||||
///
|
||||
/// If the field is not of the type `date`, tantivy
|
||||
/// will panic when the `Weight` object is created.
|
||||
pub fn new_date_bounds(
|
||||
field: String,
|
||||
lower_bound: Bound<DateTime>,
|
||||
upper_bound: Bound<DateTime>,
|
||||
) -> RangeQuery {
|
||||
let make_term_val = |val: &DateTime| {
|
||||
Term::from_field_date(Field::from_field_id(0), *val)
|
||||
.serialized_value_bytes()
|
||||
.to_owned()
|
||||
};
|
||||
RangeQuery {
|
||||
field,
|
||||
value_type: Type::Date,
|
||||
lower_bound: map_bound(&lower_bound, make_term_val),
|
||||
upper_bound: map_bound(&upper_bound, make_term_val),
|
||||
limit: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new `RangeQuery` over a `date` field.
|
||||
///
|
||||
/// If the field is not of the type `date`, tantivy
|
||||
/// will panic when the `Weight` object is created.
|
||||
pub fn new_date(field: String, range: Range<DateTime>) -> RangeQuery {
|
||||
RangeQuery::new_date_bounds(
|
||||
field,
|
||||
Bound::Included(range.start),
|
||||
Bound::Excluded(range.end),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a new `RangeQuery` over a `Str` field.
|
||||
///
|
||||
/// The two `Bound` arguments make it possible to create more complex
|
||||
/// ranges than semi-inclusive range.
|
||||
///
|
||||
/// If the field is not of the type `Str`, tantivy
|
||||
/// will panic when the `Weight` object is created.
|
||||
pub fn new_str_bounds(
|
||||
field: String,
|
||||
lower_bound: Bound<&str>,
|
||||
upper_bound: Bound<&str>,
|
||||
) -> RangeQuery {
|
||||
let make_term_val = |val: &&str| val.as_bytes().to_vec();
|
||||
RangeQuery {
|
||||
field,
|
||||
value_type: Type::Str,
|
||||
lower_bound: map_bound(&lower_bound, make_term_val),
|
||||
upper_bound: map_bound(&upper_bound, make_term_val),
|
||||
limit: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new `RangeQuery` over a `Str` field.
|
||||
///
|
||||
/// If the field is not of the type `Str`, tantivy
|
||||
/// will panic when the `Weight` object is created.
|
||||
pub fn new_str(field: String, range: Range<&str>) -> RangeQuery {
|
||||
RangeQuery::new_str_bounds(
|
||||
field,
|
||||
Bound::Included(range.start),
|
||||
Bound::Excluded(range.end),
|
||||
)
|
||||
}
|
||||
|
||||
/// Field to search over
|
||||
pub fn field(&self) -> &str {
|
||||
&self.field
|
||||
pub fn field(&self) -> Field {
|
||||
self.get_term().field()
|
||||
}
|
||||
|
||||
/// The value type of the field
|
||||
pub fn value_type(&self) -> Type {
|
||||
self.get_term().typ()
|
||||
}
|
||||
|
||||
pub(crate) fn get_term(&self) -> &Term {
|
||||
inner_bound(&self.lower_bound)
|
||||
.or(inner_bound(&self.upper_bound))
|
||||
.expect("At least one bound must be set")
|
||||
}
|
||||
|
||||
/// Limit the number of term the `RangeQuery` will go through.
|
||||
@@ -319,70 +120,23 @@ impl RangeQuery {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the type maps to a u64 fast field
|
||||
pub(crate) fn maps_to_u64_fastfield(typ: Type) -> bool {
|
||||
match typ {
|
||||
Type::U64 | Type::I64 | Type::F64 | Type::Bool | Type::Date => true,
|
||||
Type::IpAddr => false,
|
||||
Type::Str | Type::Facet | Type::Bytes | Type::Json => false,
|
||||
}
|
||||
}
|
||||
|
||||
impl Query for RangeQuery {
|
||||
fn weight(&self, enable_scoring: EnableScoring<'_>) -> crate::Result<Box<dyn Weight>> {
|
||||
let schema = enable_scoring.schema();
|
||||
let field_type = schema
|
||||
.get_field_entry(schema.get_field(&self.field)?)
|
||||
.field_type();
|
||||
let value_type = field_type.value_type();
|
||||
if value_type != self.value_type {
|
||||
let err_msg = format!(
|
||||
"Create a range query of the type {:?}, when the field given was of type \
|
||||
{value_type:?}",
|
||||
self.value_type
|
||||
);
|
||||
return Err(TantivyError::SchemaError(err_msg));
|
||||
}
|
||||
let field_type = schema.get_field_entry(self.field()).field_type();
|
||||
|
||||
if field_type.is_fast() && is_type_valid_for_fastfield_range_query(self.value_type) {
|
||||
if field_type.is_ip_addr() {
|
||||
let parse_ip_from_bytes = |data: &Vec<u8>| {
|
||||
let ip_u128_bytes: [u8; 16] = data.as_slice().try_into().map_err(|_| {
|
||||
crate::TantivyError::InvalidArgument(
|
||||
"Expected 8 bytes for ip address".to_string(),
|
||||
)
|
||||
})?;
|
||||
let ip_u128 = u128::from_be_bytes(ip_u128_bytes);
|
||||
crate::Result::<Ipv6Addr>::Ok(Ipv6Addr::from_u128(ip_u128))
|
||||
};
|
||||
let lower_bound = map_bound_res(&self.lower_bound, parse_ip_from_bytes)?;
|
||||
let upper_bound = map_bound_res(&self.upper_bound, parse_ip_from_bytes)?;
|
||||
Ok(Box::new(IPFastFieldRangeWeight::new(
|
||||
self.field.to_string(),
|
||||
lower_bound,
|
||||
upper_bound,
|
||||
)))
|
||||
} else {
|
||||
// We run the range query on the u64 value space for performance reasons and simplicity;
// assert that the type maps to u64.
|
||||
assert!(maps_to_u64_fastfield(self.value_type));
|
||||
let parse_from_bytes = |data: &Vec<u8>| {
|
||||
u64::from_be(BinarySerializable::deserialize(&mut &data[..]).unwrap())
|
||||
};
|
||||
|
||||
let lower_bound = map_bound(&self.lower_bound, parse_from_bytes);
|
||||
let upper_bound = map_bound(&self.upper_bound, parse_from_bytes);
|
||||
Ok(Box::new(FastFieldRangeWeight::new_u64_lenient(
|
||||
self.field.to_string(),
|
||||
lower_bound,
|
||||
upper_bound,
|
||||
)))
|
||||
}
|
||||
if field_type.is_fast() && is_type_valid_for_fastfield_range_query(self.value_type()) {
|
||||
Ok(Box::new(FastFieldRangeWeight::new(
|
||||
self.field(),
|
||||
self.lower_bound.clone(),
|
||||
self.upper_bound.clone(),
|
||||
)))
|
||||
} else {
|
||||
let verify_and_unwrap_term = |val: &Term| val.serialized_value_bytes().to_owned();
|
||||
Ok(Box::new(RangeWeight {
|
||||
field: self.field.to_string(),
|
||||
lower_bound: self.lower_bound.clone(),
|
||||
upper_bound: self.upper_bound.clone(),
|
||||
field: self.field(),
|
||||
lower_bound: map_bound(&self.lower_bound, verify_and_unwrap_term),
|
||||
upper_bound: map_bound(&self.upper_bound, verify_and_unwrap_term),
|
||||
limit: self.limit,
|
||||
}))
|
||||
}
|
||||
@@ -390,7 +144,7 @@ impl Query for RangeQuery {
|
||||
}
|
||||
|
||||
pub struct RangeWeight {
|
||||
field: String,
|
||||
field: Field,
|
||||
lower_bound: Bound<Vec<u8>>,
|
||||
upper_bound: Bound<Vec<u8>>,
|
||||
limit: Option<u64>,
|
||||
@@ -423,7 +177,7 @@ impl Weight for RangeWeight {
|
||||
let max_doc = reader.max_doc();
|
||||
let mut doc_bitset = BitSet::with_max_value(max_doc);
|
||||
|
||||
let inverted_index = reader.inverted_index(reader.schema().get_field(&self.field)?)?;
|
||||
let inverted_index = reader.inverted_index(self.field)?;
|
||||
let term_dict = inverted_index.terms();
|
||||
let mut term_range = self.term_range(term_dict)?;
|
||||
let mut processed_count = 0;
|
||||
@@ -477,7 +231,7 @@ mod tests {
|
||||
use crate::schema::{
|
||||
Field, IntoIpv6Addr, Schema, TantivyDocument, FAST, INDEXED, STORED, TEXT,
|
||||
};
|
||||
use crate::{Index, IndexWriter};
|
||||
use crate::{Index, IndexWriter, Term};
|
||||
|
||||
#[test]
|
||||
fn test_range_query_simple() -> crate::Result<()> {
|
||||
@@ -499,7 +253,10 @@ mod tests {
|
||||
let reader = index.reader()?;
|
||||
let searcher = reader.searcher();
|
||||
|
||||
let docs_in_the_sixties = RangeQuery::new_u64("year".to_string(), 1960u64..1970u64);
|
||||
let docs_in_the_sixties = RangeQuery::new(
|
||||
Bound::Included(Term::from_field_u64(year_field, 1960)),
|
||||
Bound::Excluded(Term::from_field_u64(year_field, 1970)),
|
||||
);
|
||||
|
||||
// ... or `1960..=1969` if inclusive range is enabled.
|
||||
let count = searcher.search(&docs_in_the_sixties, &Count)?;
|
||||
@@ -530,7 +287,10 @@ mod tests {
|
||||
let reader = index.reader()?;
|
||||
let searcher = reader.searcher();
|
||||
|
||||
let mut docs_in_the_sixties = RangeQuery::new_u64("year".to_string(), 1960u64..1970u64);
|
||||
let mut docs_in_the_sixties = RangeQuery::new(
|
||||
Bound::Included(Term::from_field_u64(year_field, 1960)),
|
||||
Bound::Excluded(Term::from_field_u64(year_field, 1970)),
|
||||
);
|
||||
docs_in_the_sixties.limit(5);
|
||||
|
||||
// due to the limit and no docs in 1963, it's really only 1960..=1965
|
||||
@@ -575,29 +335,29 @@ mod tests {
|
||||
|range_query: RangeQuery| searcher.search(&range_query, &Count).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
count_multiples(RangeQuery::new_i64("intfield".to_string(), 10..11)),
|
||||
count_multiples(RangeQuery::new(
|
||||
Bound::Included(Term::from_field_i64(int_field, 10)),
|
||||
Bound::Excluded(Term::from_field_i64(int_field, 11)),
|
||||
)),
|
||||
9
|
||||
);
|
||||
assert_eq!(
|
||||
count_multiples(RangeQuery::new_i64_bounds(
|
||||
"intfield".to_string(),
|
||||
Bound::Included(10),
|
||||
Bound::Included(11)
|
||||
count_multiples(RangeQuery::new(
|
||||
Bound::Included(Term::from_field_i64(int_field, 10)),
|
||||
Bound::Included(Term::from_field_i64(int_field, 11)),
|
||||
)),
|
||||
18
|
||||
);
|
||||
assert_eq!(
|
||||
count_multiples(RangeQuery::new_i64_bounds(
|
||||
"intfield".to_string(),
|
||||
Bound::Excluded(9),
|
||||
Bound::Included(10)
|
||||
count_multiples(RangeQuery::new(
|
||||
Bound::Excluded(Term::from_field_i64(int_field, 9)),
|
||||
Bound::Included(Term::from_field_i64(int_field, 10)),
|
||||
)),
|
||||
9
|
||||
);
|
||||
assert_eq!(
|
||||
count_multiples(RangeQuery::new_i64_bounds(
|
||||
"intfield".to_string(),
|
||||
Bound::Included(9),
|
||||
count_multiples(RangeQuery::new(
|
||||
Bound::Included(Term::from_field_i64(int_field, 9)),
|
||||
Bound::Unbounded
|
||||
)),
|
||||
91
|
||||
@@ -646,29 +406,29 @@ mod tests {
|
||||
|range_query: RangeQuery| searcher.search(&range_query, &Count).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
count_multiples(RangeQuery::new_f64("floatfield".to_string(), 10.0..11.0)),
|
||||
count_multiples(RangeQuery::new(
|
||||
Bound::Included(Term::from_field_f64(float_field, 10.0)),
|
||||
Bound::Excluded(Term::from_field_f64(float_field, 11.0)),
|
||||
)),
|
||||
9
|
||||
);
|
||||
assert_eq!(
|
||||
count_multiples(RangeQuery::new_f64_bounds(
|
||||
"floatfield".to_string(),
|
||||
Bound::Included(10.0),
|
||||
Bound::Included(11.0)
|
||||
count_multiples(RangeQuery::new(
|
||||
Bound::Included(Term::from_field_f64(float_field, 10.0)),
|
||||
Bound::Included(Term::from_field_f64(float_field, 11.0)),
|
||||
)),
|
||||
18
|
||||
);
|
||||
assert_eq!(
|
||||
count_multiples(RangeQuery::new_f64_bounds(
|
||||
"floatfield".to_string(),
|
||||
Bound::Excluded(9.0),
|
||||
Bound::Included(10.0)
|
||||
count_multiples(RangeQuery::new(
|
||||
Bound::Excluded(Term::from_field_f64(float_field, 9.0)),
|
||||
Bound::Included(Term::from_field_f64(float_field, 10.0)),
|
||||
)),
|
||||
9
|
||||
);
|
||||
assert_eq!(
|
||||
count_multiples(RangeQuery::new_f64_bounds(
|
||||
"floatfield".to_string(),
|
||||
Bound::Included(9.0),
|
||||
count_multiples(RangeQuery::new(
|
||||
Bound::Included(Term::from_field_f64(float_field, 9.0)),
|
||||
Bound::Unbounded
|
||||
)),
|
||||
91
|
||||
|
||||
@@ -1,512 +0,0 @@
|
||||
//! IP Fastfields support efficient scanning for range queries.
|
||||
//! We use this variant only if the fastfield exists, otherwise the default in `range_query` is
|
||||
//! used, which uses the term dictionary + postings.
|
||||
|
||||
use std::net::Ipv6Addr;
|
||||
use std::ops::{Bound, RangeInclusive};
|
||||
|
||||
use columnar::{Column, MonotonicallyMappableToU128};
|
||||
|
||||
use crate::query::range_query::fast_field_range_query::RangeDocSet;
|
||||
use crate::query::{ConstScorer, EmptyScorer, Explanation, Scorer, Weight};
|
||||
use crate::{DocId, DocSet, Score, SegmentReader, TantivyError};
|
||||
|
||||
/// `IPFastFieldRangeWeight` uses the ip address fast field to execute range queries.
|
||||
pub struct IPFastFieldRangeWeight {
|
||||
field: String,
|
||||
lower_bound: Bound<Ipv6Addr>,
|
||||
upper_bound: Bound<Ipv6Addr>,
|
||||
}
|
||||
|
||||
impl IPFastFieldRangeWeight {
|
||||
/// Creates a new IPFastFieldRangeWeight.
|
||||
pub fn new(field: String, lower_bound: Bound<Ipv6Addr>, upper_bound: Bound<Ipv6Addr>) -> Self {
|
||||
Self {
|
||||
field,
|
||||
lower_bound,
|
||||
upper_bound,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Weight for IPFastFieldRangeWeight {
|
||||
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
|
||||
let Some(ip_addr_column): Option<Column<Ipv6Addr>> =
|
||||
reader.fast_fields().column_opt(&self.field)?
|
||||
else {
|
||||
return Ok(Box::new(EmptyScorer));
|
||||
};
|
||||
let value_range = bound_to_value_range(
|
||||
&self.lower_bound,
|
||||
&self.upper_bound,
|
||||
ip_addr_column.min_value(),
|
||||
ip_addr_column.max_value(),
|
||||
);
|
||||
let docset = RangeDocSet::new(value_range, ip_addr_column);
|
||||
Ok(Box::new(ConstScorer::new(docset, boost)))
|
||||
}
|
||||
|
||||
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
|
||||
let mut scorer = self.scorer(reader, 1.0)?;
|
||||
if scorer.seek(doc) != doc {
|
||||
return Err(TantivyError::InvalidArgument(format!(
|
||||
"Document #({doc}) does not match"
|
||||
)));
|
||||
}
|
||||
let explanation = Explanation::new("Const", scorer.score());
|
||||
Ok(explanation)
|
||||
}
|
||||
}
|
||||
|
||||
fn bound_to_value_range(
|
||||
lower_bound: &Bound<Ipv6Addr>,
|
||||
upper_bound: &Bound<Ipv6Addr>,
|
||||
min_value: Ipv6Addr,
|
||||
max_value: Ipv6Addr,
|
||||
) -> RangeInclusive<Ipv6Addr> {
|
||||
let start_value = match lower_bound {
|
||||
Bound::Included(ip_addr) => *ip_addr,
|
||||
Bound::Excluded(ip_addr) => Ipv6Addr::from(ip_addr.to_u128() + 1),
|
||||
Bound::Unbounded => min_value,
|
||||
};
|
||||
|
||||
let end_value = match upper_bound {
|
||||
Bound::Included(ip_addr) => *ip_addr,
|
||||
Bound::Excluded(ip_addr) => Ipv6Addr::from(ip_addr.to_u128() - 1),
|
||||
Bound::Unbounded => max_value,
|
||||
};
|
||||
start_value..=end_value
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use proptest::prelude::ProptestConfig;
|
||||
use proptest::strategy::Strategy;
|
||||
use proptest::{prop_oneof, proptest};
|
||||
|
||||
use super::*;
|
||||
use crate::collector::Count;
|
||||
use crate::query::QueryParser;
|
||||
use crate::schema::{Schema, FAST, INDEXED, STORED, STRING};
|
||||
use crate::{Index, IndexWriter};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Doc {
|
||||
pub id: String,
|
||||
pub ip: Ipv6Addr,
|
||||
}
|
||||
|
||||
fn operation_strategy() -> impl Strategy<Value = Doc> {
|
||||
prop_oneof![
|
||||
(0u64..10_000u64).prop_map(doc_from_id_1),
|
||||
(1u64..10_000u64).prop_map(doc_from_id_2),
|
||||
]
|
||||
}
|
||||
|
||||
pub fn doc_from_id_1(id: u64) -> Doc {
|
||||
let id = id * 1000;
|
||||
Doc {
|
||||
// ip != id
|
||||
id: id.to_string(),
|
||||
ip: Ipv6Addr::from_u128(id as u128),
|
||||
}
|
||||
}
|
||||
fn doc_from_id_2(id: u64) -> Doc {
|
||||
let id = id * 1000;
|
||||
Doc {
|
||||
// ip != id
|
||||
id: (id - 1).to_string(),
|
||||
ip: Ipv6Addr::from_u128(id as u128),
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(10))]
|
||||
#[test]
|
||||
fn test_ip_range_for_docs_prop(ops in proptest::collection::vec(operation_strategy(), 1..1000)) {
|
||||
assert!(test_ip_range_for_docs(&ops).is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ip_range_regression1() {
|
||||
let ops = &[doc_from_id_1(0)];
|
||||
assert!(test_ip_range_for_docs(ops).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ip_range_regression2() {
|
||||
let ops = &[
|
||||
doc_from_id_1(52),
|
||||
doc_from_id_1(63),
|
||||
doc_from_id_1(12),
|
||||
doc_from_id_2(91),
|
||||
doc_from_id_2(33),
|
||||
];
|
||||
assert!(test_ip_range_for_docs(ops).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ip_range_regression3() {
|
||||
let ops = &[doc_from_id_1(1), doc_from_id_1(2), doc_from_id_1(3)];
|
||||
assert!(test_ip_range_for_docs(ops).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ip_range_regression3_simple() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let ips_field = schema_builder.add_ip_addr_field("ips", FAST | INDEXED);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut writer: IndexWriter = index.writer_for_tests().unwrap();
|
||||
let ip_addrs: Vec<Ipv6Addr> = [1000, 2000, 3000]
|
||||
.into_iter()
|
||||
.map(Ipv6Addr::from_u128)
|
||||
.collect();
|
||||
for &ip_addr in &ip_addrs {
|
||||
writer
|
||||
.add_document(doc!(ips_field=>ip_addr, ips_field=>ip_addr))
|
||||
.unwrap();
|
||||
}
|
||||
writer.commit().unwrap();
|
||||
let searcher = index.reader().unwrap().searcher();
|
||||
let range_weight = IPFastFieldRangeWeight {
|
||||
field: "ips".to_string(),
|
||||
lower_bound: Bound::Included(ip_addrs[1]),
|
||||
upper_bound: Bound::Included(ip_addrs[2]),
|
||||
};
|
||||
let count = range_weight.count(searcher.segment_reader(0)).unwrap();
|
||||
assert_eq!(count, 2);
|
||||
}
|
||||
|
||||
pub fn create_index_from_docs(docs: &[Doc]) -> Index {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let ip_field = schema_builder.add_ip_addr_field("ip", STORED | FAST);
|
||||
let ips_field = schema_builder.add_ip_addr_field("ips", FAST | INDEXED);
|
||||
let text_field = schema_builder.add_text_field("id", STRING | STORED);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
|
||||
{
|
||||
let mut index_writer = index.writer_with_num_threads(2, 60_000_000).unwrap();
|
||||
for doc in docs.iter() {
|
||||
index_writer
|
||||
.add_document(doc!(
|
||||
ips_field => doc.ip,
|
||||
ips_field => doc.ip,
|
||||
ip_field => doc.ip,
|
||||
text_field => doc.id.to_string(),
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
index_writer.commit().unwrap();
|
||||
}
|
||||
index
|
||||
}
|
||||
|
||||
fn test_ip_range_for_docs(docs: &[Doc]) -> crate::Result<()> {
|
||||
let index = create_index_from_docs(docs);
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
|
||||
let get_num_hits = |query| searcher.search(&query, &Count).unwrap();
|
||||
let query_from_text = |text: &str| {
|
||||
QueryParser::for_index(&index, vec![])
|
||||
.parse_query(text)
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
let gen_query_inclusive = |field: &str, ip_range: &RangeInclusive<Ipv6Addr>| {
|
||||
format!("{field}:[{} TO {}]", ip_range.start(), ip_range.end())
|
||||
};
|
||||
|
||||
let test_sample = |sample_docs: &[Doc]| {
|
||||
let mut ips: Vec<Ipv6Addr> = sample_docs.iter().map(|doc| doc.ip).collect();
|
||||
ips.sort();
|
||||
let ip_range = ips[0]..=ips[1];
|
||||
let expected_num_hits = docs
|
||||
.iter()
|
||||
.filter(|doc| (ips[0]..=ips[1]).contains(&doc.ip))
|
||||
.count();
|
||||
|
||||
let query = gen_query_inclusive("ip", &ip_range);
|
||||
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
|
||||
|
||||
let query = gen_query_inclusive("ips", &ip_range);
|
||||
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
|
||||
|
||||
// Intersection search
|
||||
let id_filter = sample_docs[0].id.to_string();
|
||||
let expected_num_hits = docs
|
||||
.iter()
|
||||
.filter(|doc| ip_range.contains(&doc.ip) && doc.id == id_filter)
|
||||
.count();
|
||||
let query = format!(
|
||||
"{} AND id:{}",
|
||||
gen_query_inclusive("ip", &ip_range),
|
||||
&id_filter
|
||||
);
|
||||
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
|
||||
|
||||
// Intersection search on multivalue ip field
|
||||
let id_filter = sample_docs[0].id.to_string();
|
||||
let query = format!(
|
||||
"{} AND id:{}",
|
||||
gen_query_inclusive("ips", &ip_range),
|
||||
&id_filter
|
||||
);
|
||||
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
|
||||
};
|
||||
|
||||
test_sample(&[docs[0].clone(), docs[0].clone()]);
|
||||
if docs.len() > 1 {
|
||||
test_sample(&[docs[0].clone(), docs[1].clone()]);
|
||||
test_sample(&[docs[1].clone(), docs[1].clone()]);
|
||||
}
|
||||
if docs.len() > 2 {
|
||||
test_sample(&[docs[1].clone(), docs[2].clone()]);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
mod bench {
|
||||
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use test::Bencher;
|
||||
|
||||
use super::tests::*;
|
||||
use super::*;
|
||||
use crate::collector::Count;
|
||||
use crate::query::QueryParser;
|
||||
use crate::Index;
|
||||
|
||||
fn get_index_0_to_100() -> Index {
|
||||
let mut rng = StdRng::from_seed([1u8; 32]);
|
||||
let num_vals = 100_000;
|
||||
let docs: Vec<_> = (0..num_vals)
|
||||
.map(|_i| {
|
||||
let id = if rng.gen_bool(0.01) {
|
||||
"veryfew".to_string() // 1%
|
||||
} else if rng.gen_bool(0.1) {
|
||||
"few".to_string() // 9%
|
||||
} else {
|
||||
"many".to_string() // 90%
|
||||
};
|
||||
Doc {
|
||||
id,
|
||||
// Multiply by 1000, so that we create many buckets in the compact space
|
||||
// The benches depend on this range to select n-percent of elements with the
|
||||
// methods below.
|
||||
ip: Ipv6Addr::from_u128(rng.gen_range(0..100) * 1000),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
create_index_from_docs(&docs)
|
||||
}
|
||||
|
||||
fn get_90_percent() -> RangeInclusive<Ipv6Addr> {
|
||||
let start = Ipv6Addr::from_u128(0);
|
||||
let end = Ipv6Addr::from_u128(90 * 1000);
|
||||
start..=end
|
||||
}
|
||||
|
||||
fn get_10_percent() -> RangeInclusive<Ipv6Addr> {
|
||||
let start = Ipv6Addr::from_u128(0);
|
||||
let end = Ipv6Addr::from_u128(10 * 1000);
|
||||
start..=end
|
||||
}
|
||||
|
||||
fn get_1_percent() -> RangeInclusive<Ipv6Addr> {
|
||||
let start = Ipv6Addr::from_u128(10 * 1000);
|
||||
let end = Ipv6Addr::from_u128(10 * 1000);
|
||||
start..=end
|
||||
}
|
||||
|
||||
fn excute_query(
|
||||
field: &str,
|
||||
ip_range: RangeInclusive<Ipv6Addr>,
|
||||
suffix: &str,
|
||||
index: &Index,
|
||||
) -> usize {
|
||||
let gen_query_inclusive = |from: &Ipv6Addr, to: &Ipv6Addr| {
|
||||
format!(
|
||||
"{}:[{} TO {}] {}",
|
||||
field,
|
||||
&from.to_string(),
|
||||
&to.to_string(),
|
||||
suffix
|
||||
)
|
||||
};
|
||||
|
||||
let query = gen_query_inclusive(ip_range.start(), ip_range.end());
|
||||
let query_from_text = |text: &str| {
|
||||
QueryParser::for_index(index, vec![])
|
||||
.parse_query(text)
|
||||
.unwrap()
|
||||
};
|
||||
let query = query_from_text(&query);
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
searcher.search(&query, &(Count)).unwrap()
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_90_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_10_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_10_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_1_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_10_percent_intersect_with_10_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_10_percent(), "AND id:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_intersect_with_10_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_1_percent(), "AND id:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_intersect_with_90_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_1_percent(), "AND id:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_intersect_with_1_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_1_percent(), "AND id:veryfew", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_10_percent_intersect_with_90_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_10_percent(), "AND id:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_intersect_with_90_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_90_percent(), "AND id:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_intersect_with_10_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_90_percent(), "AND id:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_intersect_with_1_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_90_percent(), "AND id:veryfew", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_90_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_10_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_10_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_1_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_10_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_10_percent(), "AND id:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_1_percent(), "AND id:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_1_percent(), "AND id:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_1_percent(), "AND id:veryfew", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_10_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_10_percent(), "AND id:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_90_percent(), "AND id:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_90_percent(), "AND id:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_90_percent(), "AND id:veryfew", &index));
|
||||
}
|
||||
}
|
||||
@@ -2,54 +2,34 @@
|
||||
//! We use this variant only if the fastfield exists, otherwise the default in `range_query` is
//! used, which uses the term dictionary + postings.
|
||||
|
||||
use std::net::Ipv6Addr;
|
||||
use std::ops::{Bound, RangeInclusive};
|
||||
|
||||
use columnar::{ColumnType, HasAssociatedColumnType, MonotonicallyMappableToU64};
|
||||
use columnar::{Column, MonotonicallyMappableToU128, MonotonicallyMappableToU64, StrColumn};
|
||||
use common::BinarySerializable;
|
||||
|
||||
use super::fast_field_range_query::RangeDocSet;
|
||||
use super::map_bound;
|
||||
use crate::query::{ConstScorer, EmptyScorer, Explanation, Query, Scorer, Weight};
|
||||
use crate::{DocId, DocSet, Score, SegmentReader, TantivyError};
|
||||
use super::fast_field_range_doc_set::RangeDocSet;
|
||||
use super::{map_bound, map_bound_res};
|
||||
use crate::query::range_query::range_query::inner_bound;
|
||||
use crate::query::{AllScorer, ConstScorer, EmptyScorer, Explanation, Query, Scorer, Weight};
|
||||
use crate::schema::{Field, Type};
|
||||
use crate::{DocId, DocSet, Score, SegmentReader, TantivyError, Term};
|
||||
|
||||
/// `FastFieldRangeWeight` uses the fast field to execute range queries.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct FastFieldRangeWeight {
|
||||
field: String,
|
||||
lower_bound: Bound<u64>,
|
||||
upper_bound: Bound<u64>,
|
||||
column_type_opt: Option<ColumnType>,
|
||||
lower_bound: Bound<Term>,
|
||||
upper_bound: Bound<Term>,
|
||||
field: Field,
|
||||
}
|
||||
|
||||
impl FastFieldRangeWeight {
|
||||
/// Create a new FastFieldRangeWeight, using the u64 representation of any fast field.
|
||||
pub(crate) fn new_u64_lenient(
|
||||
field: String,
|
||||
lower_bound: Bound<u64>,
|
||||
upper_bound: Bound<u64>,
|
||||
) -> Self {
|
||||
let lower_bound = map_bound(&lower_bound, |val| *val);
|
||||
let upper_bound = map_bound(&upper_bound, |val| *val);
|
||||
/// Create a new FastFieldRangeWeight
|
||||
pub(crate) fn new(field: Field, lower_bound: Bound<Term>, upper_bound: Bound<Term>) -> Self {
|
||||
Self {
|
||||
field,
|
||||
lower_bound,
|
||||
upper_bound,
|
||||
column_type_opt: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new `FastFieldRangeWeight` for a range of a u64-mappable type .
|
||||
pub fn new<T: HasAssociatedColumnType + MonotonicallyMappableToU64>(
|
||||
field: String,
|
||||
lower_bound: Bound<T>,
|
||||
upper_bound: Bound<T>,
|
||||
) -> Self {
|
||||
let lower_bound = map_bound(&lower_bound, |val| val.to_u64());
|
||||
let upper_bound = map_bound(&upper_bound, |val| val.to_u64());
|
||||
Self {
|
||||
field,
|
||||
lower_bound,
|
||||
upper_bound,
|
||||
column_type_opt: Some(T::column_type()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -65,30 +45,101 @@ impl Query for FastFieldRangeWeight {
|
||||
|
||||
impl Weight for FastFieldRangeWeight {
|
||||
fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
|
||||
let fast_field_reader = reader.fast_fields();
|
||||
let column_type_opt: Option<[ColumnType; 1]> =
|
||||
self.column_type_opt.map(|column_type| [column_type]);
|
||||
let column_type_opt_ref: Option<&[ColumnType]> = column_type_opt
|
||||
.as_ref()
|
||||
.map(|column_types| column_types.as_slice());
|
||||
let Some((column, _)) =
|
||||
fast_field_reader.u64_lenient_for_type(column_type_opt_ref, &self.field)?
|
||||
else {
|
||||
return Ok(Box::new(EmptyScorer));
|
||||
};
|
||||
#[allow(clippy::reversed_empty_ranges)]
|
||||
let value_range = bound_to_value_range(
|
||||
&self.lower_bound,
|
||||
&self.upper_bound,
|
||||
column.min_value(),
|
||||
column.max_value(),
|
||||
)
|
||||
.unwrap_or(1..=0); // empty range
|
||||
if value_range.is_empty() {
|
||||
return Ok(Box::new(EmptyScorer));
|
||||
// Check if both bounds are Bound::Unbounded
|
||||
if self.lower_bound == Bound::Unbounded && self.upper_bound == Bound::Unbounded {
|
||||
return Ok(Box::new(AllScorer::new(reader.max_doc())));
|
||||
}
|
||||
let field_name = reader.schema().get_field_name(self.field);
|
||||
let field_type = reader.schema().get_field_entry(self.field).field_type();
|
||||
|
||||
let term = inner_bound(&self.lower_bound)
|
||||
.or(inner_bound(&self.upper_bound))
|
||||
.expect("At least one bound must be set");
|
||||
assert_eq!(
|
||||
term.typ(),
|
||||
field_type.value_type(),
|
||||
"Field is of type {:?}, but got term of type {:?}",
|
||||
field_type,
|
||||
term.typ()
|
||||
);
|
||||
if field_type.is_ip_addr() {
|
||||
let parse_ip_from_bytes = |term: &Term| {
|
||||
term.value().as_ip_addr().ok_or_else(|| {
|
||||
crate::TantivyError::InvalidArgument("Expected ip address".to_string())
|
||||
})
|
||||
};
|
||||
let lower_bound = map_bound_res(&self.lower_bound, parse_ip_from_bytes)?;
|
||||
let upper_bound = map_bound_res(&self.upper_bound, parse_ip_from_bytes)?;
|
||||
|
||||
let Some(ip_addr_column): Option<Column<Ipv6Addr>> =
|
||||
reader.fast_fields().column_opt(field_name)?
|
||||
else {
|
||||
return Ok(Box::new(EmptyScorer));
|
||||
};
|
||||
let value_range = bound_to_value_range_ip(
|
||||
&lower_bound,
|
||||
&upper_bound,
|
||||
ip_addr_column.min_value(),
|
||||
ip_addr_column.max_value(),
|
||||
);
|
||||
let docset = RangeDocSet::new(value_range, ip_addr_column);
|
||||
Ok(Box::new(ConstScorer::new(docset, boost)))
|
||||
} else {
|
||||
let (lower_bound, upper_bound) = if field_type.is_str() {
|
||||
let Some(str_dict_column): Option<StrColumn> =
|
||||
reader.fast_fields().str(field_name)?
|
||||
else {
|
||||
return Ok(Box::new(EmptyScorer));
|
||||
};
|
||||
let dict = str_dict_column.dictionary();
|
||||
|
||||
let lower_bound = map_bound(&self.lower_bound, |term| {
|
||||
term.serialized_value_bytes().to_vec()
|
||||
});
|
||||
let upper_bound = map_bound(&self.upper_bound, |term| {
|
||||
term.serialized_value_bytes().to_vec()
|
||||
});
|
||||
// Get term ids for terms
|
||||
let (lower_bound, upper_bound) =
|
||||
dict.term_bounds_to_ord(lower_bound, upper_bound)?;
|
||||
(lower_bound, upper_bound)
|
||||
} else {
|
||||
assert!(
|
||||
maps_to_u64_fastfield(field_type.value_type()),
|
||||
"{:?}",
|
||||
field_type
|
||||
);
|
||||
let parse_from_bytes = |term: &Term| {
|
||||
u64::from_be(
|
||||
BinarySerializable::deserialize(&mut &term.serialized_value_bytes()[..])
|
||||
.unwrap(),
|
||||
)
|
||||
};
|
||||
|
||||
let lower_bound = map_bound(&self.lower_bound, parse_from_bytes);
|
||||
let upper_bound = map_bound(&self.upper_bound, parse_from_bytes);
|
||||
(lower_bound, upper_bound)
|
||||
};
|
||||
|
||||
let fast_field_reader = reader.fast_fields();
|
||||
let Some((column, _)) = fast_field_reader.u64_lenient_for_type(None, field_name)?
|
||||
else {
|
||||
return Ok(Box::new(EmptyScorer));
|
||||
};
|
||||
#[allow(clippy::reversed_empty_ranges)]
|
||||
let value_range = bound_to_value_range(
|
||||
&lower_bound,
|
||||
&upper_bound,
|
||||
column.min_value(),
|
||||
column.max_value(),
|
||||
)
|
||||
.unwrap_or(1..=0); // empty range
|
||||
if value_range.is_empty() {
|
||||
return Ok(Box::new(EmptyScorer));
|
||||
}
|
||||
let docset = RangeDocSet::new(value_range, column);
|
||||
Ok(Box::new(ConstScorer::new(docset, boost)))
|
||||
}
|
||||
let docset = RangeDocSet::new(value_range, column);
|
||||
Ok(Box::new(ConstScorer::new(docset, boost)))
|
||||
}
|
||||
|
||||
fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
|
||||
@@ -104,6 +155,35 @@ impl Weight for FastFieldRangeWeight {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the type maps to a u64 fast field
pub(crate) fn maps_to_u64_fastfield(typ: Type) -> bool {
    match typ {
        Type::U64 | Type::I64 | Type::F64 | Type::Bool | Type::Date => true,
        Type::IpAddr => false,
        Type::Str | Type::Facet | Type::Bytes | Type::Json => false,
    }
}
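// A minimal sketch, not part of this diff: the types above can share one u64 range
// implementation because their fast-field representation is an order-preserving u64
// mapping, so bounds can be compared in u64 space. Assumes the
// `columnar::MonotonicallyMappableToU64` trait imported at the top of this file.
#[cfg(test)]
mod u64_mapping_sketch {
    use columnar::MonotonicallyMappableToU64;

    #[test]
    fn to_u64_preserves_order() {
        // i64: negative values map below positive ones.
        assert!((-3i64).to_u64() < 2i64.to_u64());
        // f64: ordering of positive values is preserved as well.
        assert!(1.5f64.to_u64() < 2.5f64.to_u64());
    }
}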
|
||||
|
||||
fn bound_to_value_range_ip(
|
||||
lower_bound: &Bound<Ipv6Addr>,
|
||||
upper_bound: &Bound<Ipv6Addr>,
|
||||
min_value: Ipv6Addr,
|
||||
max_value: Ipv6Addr,
|
||||
) -> RangeInclusive<Ipv6Addr> {
|
||||
let start_value = match lower_bound {
|
||||
Bound::Included(ip_addr) => *ip_addr,
|
||||
Bound::Excluded(ip_addr) => Ipv6Addr::from(ip_addr.to_u128() + 1),
|
||||
Bound::Unbounded => min_value,
|
||||
};
|
||||
|
||||
let end_value = match upper_bound {
|
||||
Bound::Included(ip_addr) => *ip_addr,
|
||||
Bound::Excluded(ip_addr) => Ipv6Addr::from(ip_addr.to_u128() - 1),
|
||||
Bound::Unbounded => max_value,
|
||||
};
|
||||
start_value..=end_value
|
||||
}
|
||||
|
||||
// Returns None if the range cannot be converted to an inclusive range (which equals an empty
// range).
|
||||
fn bound_to_value_range<T: MonotonicallyMappableToU64>(
|
||||
@@ -137,11 +217,72 @@ pub mod tests {
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::SeedableRng;
|
||||
|
||||
use crate::collector::Count;
|
||||
use crate::collector::{Count, TopDocs};
|
||||
use crate::query::range_query::range_query_u64_fastfield::FastFieldRangeWeight;
|
||||
use crate::query::{QueryParser, Weight};
|
||||
use crate::schema::{NumericOptions, Schema, SchemaBuilder, FAST, INDEXED, STORED, STRING};
|
||||
use crate::{Index, IndexWriter, TERMINATED};
|
||||
use crate::schema::{
|
||||
NumericOptions, Schema, SchemaBuilder, FAST, INDEXED, STORED, STRING, TEXT,
|
||||
};
|
||||
use crate::{Index, IndexWriter, Term, TERMINATED};
|
||||
|
||||
#[test]
|
||||
fn test_text_field_ff_range_query() -> crate::Result<()> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
schema_builder.add_text_field("title", TEXT | FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema.clone());
|
||||
let mut index_writer = index.writer_for_tests()?;
|
||||
let title = schema.get_field("title").unwrap();
|
||||
index_writer.add_document(doc!(
|
||||
title => "bbb"
|
||||
))?;
|
||||
index_writer.add_document(doc!(
|
||||
title => "ddd"
|
||||
))?;
|
||||
index_writer.commit()?;
|
||||
let reader = index.reader()?;
|
||||
let searcher = reader.searcher();
|
||||
let query_parser = QueryParser::for_index(&index, vec![title]);
|
||||
|
||||
let test_query = |query, num_hits| {
|
||||
let query = query_parser.parse_query(query).unwrap();
|
||||
let top_docs = searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
|
||||
assert_eq!(top_docs.len(), num_hits);
|
||||
};
|
||||
|
||||
test_query("title:[aaa TO ccc]", 1);
|
||||
test_query("title:[aaa TO bbb]", 1);
|
||||
test_query("title:[bbb TO bbb]", 1);
|
||||
test_query("title:[bbb TO ddd]", 2);
|
||||
test_query("title:[bbb TO eee]", 2);
|
||||
test_query("title:[bb TO eee]", 2);
|
||||
test_query("title:[ccc TO ccc]", 0);
|
||||
test_query("title:[ccc TO ddd]", 1);
|
||||
test_query("title:[ccc TO eee]", 1);
|
||||
|
||||
test_query("title:[aaa TO *}", 2);
|
||||
test_query("title:[bbb TO *]", 2);
|
||||
test_query("title:[bb TO *]", 2);
|
||||
test_query("title:[ccc TO *]", 1);
|
||||
test_query("title:[ddd TO *]", 1);
|
||||
test_query("title:[dddd TO *]", 0);
|
||||
|
||||
test_query("title:{aaa TO *}", 2);
|
||||
test_query("title:{bbb TO *]", 1);
|
||||
test_query("title:{bb TO *]", 2);
|
||||
test_query("title:{ccc TO *]", 1);
|
||||
test_query("title:{ddd TO *]", 0);
|
||||
test_query("title:{dddd TO *]", 0);
|
||||
|
||||
test_query("title:[* TO bb]", 0);
|
||||
test_query("title:[* TO bbb]", 1);
|
||||
test_query("title:[* TO ccc]", 1);
|
||||
test_query("title:[* TO ddd]", 2);
|
||||
test_query("title:[* TO ddd}", 1);
|
||||
test_query("title:[* TO eee]", 2);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Doc {
|
||||
@@ -159,14 +300,14 @@ pub mod tests {
|
||||
fn doc_from_id_1(id: u64) -> Doc {
|
||||
let id = id * 1000;
|
||||
Doc {
|
||||
id_name: id.to_string(),
|
||||
id_name: format!("id_name{:010}", id),
|
||||
id,
|
||||
}
|
||||
}
|
||||
fn doc_from_id_2(id: u64) -> Doc {
|
||||
let id = id * 1000;
|
||||
Doc {
|
||||
id_name: (id - 1).to_string(),
|
||||
id_name: format!("id_name{:010}", id - 1),
|
||||
id,
|
||||
}
|
||||
}
|
||||
@@ -213,10 +354,10 @@ pub mod tests {
|
||||
writer.add_document(doc!(field=>52_000u64)).unwrap();
|
||||
writer.commit().unwrap();
|
||||
let searcher = index.reader().unwrap().searcher();
|
||||
let range_query = FastFieldRangeWeight::new_u64_lenient(
|
||||
"test_field".to_string(),
|
||||
Bound::Included(50_000),
|
||||
Bound::Included(50_002),
|
||||
let range_query = FastFieldRangeWeight::new(
|
||||
field,
|
||||
Bound::Included(Term::from_field_u64(field, 50_000)),
|
||||
Bound::Included(Term::from_field_u64(field, 50_002)),
|
||||
);
|
||||
let scorer = range_query
|
||||
.scorer(searcher.segment_reader(0), 1.0f32)
|
||||
@@ -254,7 +395,8 @@ pub mod tests {
|
||||
NumericOptions::default().set_fast().set_indexed(),
|
||||
);
|
||||
|
||||
let text_field = schema_builder.add_text_field("id_name", STRING | STORED);
|
||||
let text_field = schema_builder.add_text_field("id_name", STRING | STORED | FAST);
|
||||
let text_field2 = schema_builder.add_text_field("id_name_fast", STRING | STORED | FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
|
||||
@@ -273,6 +415,7 @@ pub mod tests {
|
||||
id_f64_field => doc.id as f64,
|
||||
id_i64_field => doc.id as i64,
|
||||
text_field => doc.id_name.to_string(),
|
||||
text_field2 => doc.id_name.to_string(),
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
@@ -317,6 +460,24 @@ pub mod tests {
|
||||
let query = gen_query_inclusive("ids", ids[0]..=ids[1]);
|
||||
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
|
||||
|
||||
// Text query
|
||||
{
|
||||
let test_text_query = |field_name: &str| {
|
||||
let mut id_names: Vec<&str> =
|
||||
sample_docs.iter().map(|doc| doc.id_name.as_str()).collect();
|
||||
id_names.sort();
|
||||
let expected_num_hits = docs
|
||||
.iter()
|
||||
.filter(|doc| (id_names[0]..=id_names[1]).contains(&doc.id_name.as_str()))
|
||||
.count();
|
||||
let query = format!("{}:[{} TO {}]", field_name, id_names[0], id_names[1]);
|
||||
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
|
||||
};
|
||||
|
||||
test_text_query("id_name");
|
||||
test_text_query("id_name_fast");
|
||||
}
|
||||
|
||||
// Exclusive range
|
||||
let expected_num_hits = docs
|
||||
.iter()
|
||||
@@ -394,6 +555,202 @@ pub mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod ip_range_tests {
|
||||
use proptest::prelude::ProptestConfig;
|
||||
use proptest::strategy::Strategy;
|
||||
use proptest::{prop_oneof, proptest};
|
||||
|
||||
use super::*;
|
||||
use crate::collector::Count;
|
||||
use crate::query::QueryParser;
|
||||
use crate::schema::{Schema, FAST, INDEXED, STORED, STRING};
|
||||
use crate::{Index, IndexWriter};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Doc {
|
||||
pub id: String,
|
||||
pub ip: Ipv6Addr,
|
||||
}
|
||||
|
||||
fn operation_strategy() -> impl Strategy<Value = Doc> {
|
||||
prop_oneof![
|
||||
(0u64..10_000u64).prop_map(doc_from_id_1),
|
||||
(1u64..10_000u64).prop_map(doc_from_id_2),
|
||||
]
|
||||
}
|
||||
|
||||
pub fn doc_from_id_1(id: u64) -> Doc {
|
||||
let id = id * 1000;
|
||||
Doc {
|
||||
// ip != id
|
||||
id: id.to_string(),
|
||||
ip: Ipv6Addr::from_u128(id as u128),
|
||||
}
|
||||
}
|
||||
fn doc_from_id_2(id: u64) -> Doc {
|
||||
let id = id * 1000;
|
||||
Doc {
|
||||
// ip != id
|
||||
id: (id - 1).to_string(),
|
||||
ip: Ipv6Addr::from_u128(id as u128),
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(10))]
|
||||
#[test]
|
||||
fn test_ip_range_for_docs_prop(ops in proptest::collection::vec(operation_strategy(), 1..1000)) {
|
||||
assert!(test_ip_range_for_docs(&ops).is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ip_range_regression1() {
|
||||
let ops = &[doc_from_id_1(0)];
|
||||
assert!(test_ip_range_for_docs(ops).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ip_range_regression2() {
|
||||
let ops = &[
|
||||
doc_from_id_1(52),
|
||||
doc_from_id_1(63),
|
||||
doc_from_id_1(12),
|
||||
doc_from_id_2(91),
|
||||
doc_from_id_2(33),
|
||||
];
|
||||
assert!(test_ip_range_for_docs(ops).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ip_range_regression3() {
|
||||
let ops = &[doc_from_id_1(1), doc_from_id_1(2), doc_from_id_1(3)];
|
||||
assert!(test_ip_range_for_docs(ops).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ip_range_regression3_simple() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let ips_field = schema_builder.add_ip_addr_field("ips", FAST | INDEXED);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut writer: IndexWriter = index.writer_for_tests().unwrap();
|
||||
let ip_addrs: Vec<Ipv6Addr> = [1000, 2000, 3000]
|
||||
.into_iter()
|
||||
.map(Ipv6Addr::from_u128)
|
||||
.collect();
|
||||
for &ip_addr in &ip_addrs {
|
||||
writer
|
||||
.add_document(doc!(ips_field=>ip_addr, ips_field=>ip_addr))
|
||||
.unwrap();
|
||||
}
|
||||
writer.commit().unwrap();
|
||||
let searcher = index.reader().unwrap().searcher();
|
||||
let range_weight = FastFieldRangeWeight::new(
|
||||
ips_field,
|
||||
Bound::Included(Term::from_field_ip_addr(ips_field, ip_addrs[1])),
|
||||
Bound::Included(Term::from_field_ip_addr(ips_field, ip_addrs[2])),
|
||||
);
|
||||
|
||||
let count =
|
||||
crate::query::weight::Weight::count(&range_weight, searcher.segment_reader(0)).unwrap();
|
||||
assert_eq!(count, 2);
|
||||
}
|
||||
|
||||
pub fn create_index_from_ip_docs(docs: &[Doc]) -> Index {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let ip_field = schema_builder.add_ip_addr_field("ip", STORED | FAST);
|
||||
let ips_field = schema_builder.add_ip_addr_field("ips", FAST | INDEXED);
|
||||
let text_field = schema_builder.add_text_field("id", STRING | STORED);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
|
||||
{
|
||||
let mut index_writer = index.writer_with_num_threads(2, 60_000_000).unwrap();
|
||||
for doc in docs.iter() {
|
||||
index_writer
|
||||
.add_document(doc!(
|
||||
ips_field => doc.ip,
|
||||
ips_field => doc.ip,
|
||||
ip_field => doc.ip,
|
||||
text_field => doc.id.to_string(),
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
index_writer.commit().unwrap();
|
||||
}
|
||||
index
|
||||
}
|
||||
|
||||
fn test_ip_range_for_docs(docs: &[Doc]) -> crate::Result<()> {
|
||||
let index = create_index_from_ip_docs(docs);
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
|
||||
let get_num_hits = |query| searcher.search(&query, &Count).unwrap();
|
||||
let query_from_text = |text: &str| {
|
||||
QueryParser::for_index(&index, vec![])
|
||||
.parse_query(text)
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
let gen_query_inclusive = |field: &str, ip_range: &RangeInclusive<Ipv6Addr>| {
|
||||
format!("{field}:[{} TO {}]", ip_range.start(), ip_range.end())
|
||||
};
|
||||
|
||||
let test_sample = |sample_docs: &[Doc]| {
|
||||
let mut ips: Vec<Ipv6Addr> = sample_docs.iter().map(|doc| doc.ip).collect();
|
||||
ips.sort();
|
||||
let ip_range = ips[0]..=ips[1];
|
||||
let expected_num_hits = docs
|
||||
.iter()
|
||||
.filter(|doc| (ips[0]..=ips[1]).contains(&doc.ip))
|
||||
.count();
|
||||
|
||||
let query = gen_query_inclusive("ip", &ip_range);
|
||||
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
|
||||
|
||||
let query = gen_query_inclusive("ips", &ip_range);
|
||||
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
|
||||
|
||||
// Intersection search
|
||||
let id_filter = sample_docs[0].id.to_string();
|
||||
let expected_num_hits = docs
|
||||
.iter()
|
||||
.filter(|doc| ip_range.contains(&doc.ip) && doc.id == id_filter)
|
||||
.count();
|
||||
let query = format!(
|
||||
"{} AND id:{}",
|
||||
gen_query_inclusive("ip", &ip_range),
|
||||
&id_filter
|
||||
);
|
||||
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
|
||||
|
||||
// Intersection search on multivalue ip field
|
||||
let id_filter = sample_docs[0].id.to_string();
|
||||
let query = format!(
|
||||
"{} AND id:{}",
|
||||
gen_query_inclusive("ips", &ip_range),
|
||||
&id_filter
|
||||
);
|
||||
assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits);
|
||||
};
|
||||
|
||||
test_sample(&[docs[0].clone(), docs[0].clone()]);
|
||||
if docs.len() > 1 {
|
||||
test_sample(&[docs[0].clone(), docs[1].clone()]);
|
||||
test_sample(&[docs[1].clone(), docs[1].clone()]);
|
||||
}
|
||||
if docs.len() > 2 {
|
||||
test_sample(&[docs[1].clone(), docs[2].clone()]);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
mod bench {
|
||||
|
||||
@@ -601,3 +958,242 @@ mod bench {
|
||||
bench.iter(|| execute_query("ids", get_90_percent(), "AND id_name:veryfew", &index));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "unstable"))]
|
||||
mod bench_ip {
|
||||
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use test::Bencher;
|
||||
|
||||
use super::ip_range_tests::*;
|
||||
use super::*;
|
||||
use crate::collector::Count;
|
||||
use crate::query::QueryParser;
|
||||
use crate::Index;
|
||||
|
||||
fn get_index_0_to_100() -> Index {
|
||||
let mut rng = StdRng::from_seed([1u8; 32]);
|
||||
let num_vals = 100_000;
|
||||
let docs: Vec<_> = (0..num_vals)
|
||||
.map(|_i| {
|
||||
let id = if rng.gen_bool(0.01) {
|
||||
"veryfew".to_string() // 1%
|
||||
} else if rng.gen_bool(0.1) {
|
||||
"few".to_string() // 9%
|
||||
} else {
|
||||
"many".to_string() // 90%
|
||||
};
|
||||
Doc {
|
||||
id,
|
||||
// Multiply by 1000, so that we create many buckets in the compact space
|
||||
// The benches depend on this range to select n-percent of elements with the
|
||||
// methods below.
|
||||
ip: Ipv6Addr::from_u128(rng.gen_range(0..100) * 1000),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
create_index_from_ip_docs(&docs)
|
||||
}
|
||||
|
||||
fn get_90_percent() -> RangeInclusive<Ipv6Addr> {
|
||||
let start = Ipv6Addr::from_u128(0);
|
||||
let end = Ipv6Addr::from_u128(90 * 1000);
|
||||
start..=end
|
||||
}
|
||||
|
||||
fn get_10_percent() -> RangeInclusive<Ipv6Addr> {
|
||||
let start = Ipv6Addr::from_u128(0);
|
||||
let end = Ipv6Addr::from_u128(10 * 1000);
|
||||
start..=end
|
||||
}
|
||||
|
||||
fn get_1_percent() -> RangeInclusive<Ipv6Addr> {
|
||||
let start = Ipv6Addr::from_u128(10 * 1000);
|
||||
let end = Ipv6Addr::from_u128(10 * 1000);
|
||||
start..=end
|
||||
}
|
||||
|
||||
fn excute_query(
|
||||
field: &str,
|
||||
ip_range: RangeInclusive<Ipv6Addr>,
|
||||
suffix: &str,
|
||||
index: &Index,
|
||||
) -> usize {
|
||||
let gen_query_inclusive = |from: &Ipv6Addr, to: &Ipv6Addr| {
|
||||
format!(
|
||||
"{}:[{} TO {}] {}",
|
||||
field,
|
||||
&from.to_string(),
|
||||
&to.to_string(),
|
||||
suffix
|
||||
)
|
||||
};
|
||||
|
||||
let query = gen_query_inclusive(ip_range.start(), ip_range.end());
|
||||
let query_from_text = |text: &str| {
|
||||
QueryParser::for_index(index, vec![])
|
||||
.parse_query(text)
|
||||
.unwrap()
|
||||
};
|
||||
let query = query_from_text(&query);
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
searcher.search(&query, &(Count)).unwrap()
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_90_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_10_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_10_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_1_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_10_percent_intersect_with_10_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_10_percent(), "AND id:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_intersect_with_10_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_1_percent(), "AND id:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_intersect_with_90_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_1_percent(), "AND id:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_intersect_with_1_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_1_percent(), "AND id:veryfew", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_10_percent_intersect_with_90_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_10_percent(), "AND id:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_intersect_with_90_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_90_percent(), "AND id:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_intersect_with_10_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_90_percent(), "AND id:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_intersect_with_1_percent(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ip", get_90_percent(), "AND id:veryfew", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_90_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_10_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_10_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_1_percent(), "", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_10_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_10_percent(), "AND id:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_1_percent(), "AND id:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_1_percent(), "AND id:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_1_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_1_percent(), "AND id:veryfew", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_10_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_10_percent(), "AND id:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_intersect_with_90_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_90_percent(), "AND id:many", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_intersect_with_10_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_90_percent(), "AND id:few", &index));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_ip_range_hit_90_percent_intersect_with_1_percent_multi(bench: &mut Bencher) {
|
||||
let index = get_index_0_to_100();
|
||||
|
||||
bench.iter(|| excute_query("ips", get_90_percent(), "AND id:veryfew", &index));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -508,7 +508,7 @@ impl std::fmt::Debug for ValueAddr {
|
||||
|
||||
/// An enum representing a value for tantivy to index.
|
||||
///
|
||||
/// Any changes need to be reflected in `BinarySerializable` for `ValueType`
|
||||
/// ** Any changes need to be reflected in `BinarySerializable` for `ValueType` **
|
||||
///
|
||||
/// We can't use [schema::Type] or [columnar::ColumnType] here, because they are missing
|
||||
/// some items like Array and PreTokStr.
|
||||
@@ -553,7 +553,7 @@ impl BinarySerializable for ValueType {
|
||||
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
|
||||
let num = u8::deserialize(reader)?;
|
||||
let type_id = if (0..=12).contains(&num) {
|
||||
unsafe { std::mem::transmute(num) }
|
||||
unsafe { std::mem::transmute::<u8, ValueType>(num) }
|
||||
} else {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
|
||||
@@ -201,6 +201,11 @@ impl FieldType {
|
||||
matches!(self, FieldType::IpAddr(_))
|
||||
}
|
||||
|
||||
/// returns true if this is an str field
|
||||
pub fn is_str(&self) -> bool {
|
||||
matches!(self, FieldType::Str(_))
|
||||
}
|
||||
|
||||
/// returns true if this is a date field
|
||||
pub fn is_date(&self) -> bool {
|
||||
matches!(self, FieldType::Date(_))
|
||||
|
||||
@@ -249,15 +249,8 @@ impl Term {
|
||||
#[inline]
|
||||
pub fn append_path(&mut self, bytes: &[u8]) -> &mut [u8] {
|
||||
let len_before = self.0.len();
|
||||
if bytes.contains(&JSON_END_OF_PATH) {
|
||||
self.0.extend(
|
||||
bytes
|
||||
.iter()
|
||||
.map(|&b| if b == JSON_END_OF_PATH { b'0' } else { b }),
|
||||
);
|
||||
} else {
|
||||
self.0.extend_from_slice(bytes);
|
||||
}
|
||||
assert!(!bytes.contains(&JSON_END_OF_PATH));
|
||||
self.0.extend_from_slice(bytes);
|
||||
&mut self.0[len_before..]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,9 +16,7 @@ fn make_test_sstable(suffix: &str) -> FileSlice {
|
||||
|
||||
let table = builder.finish().unwrap();
|
||||
let table = Arc::new(OwnedBytes::new(table));
|
||||
let slice = common::file_slice::FileSlice::new(table.clone());
|
||||
|
||||
slice
|
||||
common::file_slice::FileSlice::new(table.clone())
|
||||
}
|
||||
|
||||
pub fn criterion_benchmark(c: &mut Criterion) {
|
||||
|
||||
@@ -7,7 +7,7 @@ use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use tantivy_sstable::{Dictionary, MonotonicU64SSTable};
|
||||
|
||||
const CHARSET: &'static [u8] = b"abcdefghij";
|
||||
const CHARSET: &[u8] = b"abcdefghij";
|
||||
|
||||
fn generate_key(rng: &mut impl Rng) -> String {
|
||||
let len = rng.gen_range(3..12);
|
||||
|
||||
@@ -56,6 +56,53 @@ impl Dictionary<VoidSSTable> {
|
||||
}
|
||||
}
|
||||
|
||||
fn map_bound<TFrom, TTo>(bound: &Bound<TFrom>, transform: impl Fn(&TFrom) -> TTo) -> Bound<TTo> {
|
||||
use self::Bound::*;
|
||||
match bound {
|
||||
Excluded(ref from_val) => Bound::Excluded(transform(from_val)),
|
||||
Included(ref from_val) => Bound::Included(transform(from_val)),
|
||||
Unbounded => Unbounded,
|
||||
}
|
||||
}
|
||||
|
||||
/// Takes a bound and transforms the inner value into a new bound via a closure.
/// The bound variant may change depending on the value returned by the closure.
|
||||
fn transform_bound_inner<TFrom, TTo>(
|
||||
bound: &Bound<TFrom>,
|
||||
transform: impl Fn(&TFrom) -> io::Result<Bound<TTo>>,
|
||||
) -> io::Result<Bound<TTo>> {
|
||||
use self::Bound::*;
|
||||
Ok(match bound {
|
||||
Excluded(ref from_val) => transform(from_val)?,
|
||||
Included(ref from_val) => transform(from_val)?,
|
||||
Unbounded => Unbounded,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum TermOrdHit {
|
||||
/// Exact term ord hit
|
||||
Exact(TermOrdinal),
|
||||
/// Next best term ordinal
|
||||
Next(TermOrdinal),
|
||||
}
|
||||
|
||||
impl TermOrdHit {
|
||||
fn into_exact(self) -> Option<TermOrdinal> {
|
||||
match self {
|
||||
TermOrdHit::Exact(ord) => Some(ord),
|
||||
TermOrdHit::Next(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn map<F: FnOnce(TermOrdinal) -> TermOrdinal>(self, f: F) -> Self {
|
||||
match self {
|
||||
TermOrdHit::Exact(ord) => TermOrdHit::Exact(f(ord)),
|
||||
TermOrdHit::Next(ord) => TermOrdHit::Next(f(ord)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TSSTable: SSTable> Dictionary<TSSTable> {
|
||||
pub fn builder<W: io::Write>(wrt: W) -> io::Result<crate::Writer<W, TSSTable::ValueWriter>> {
|
||||
Ok(TSSTable::writer(wrt))
|
||||
@@ -257,6 +304,17 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
|
||||
key: K,
|
||||
sstable_delta_reader: &mut DeltaReader<TSSTable::ValueReader>,
|
||||
) -> io::Result<Option<TermOrdinal>> {
|
||||
self.decode_up_to_or_next(key, sstable_delta_reader)
|
||||
.map(|hit| hit.into_exact())
|
||||
}
|
||||
/// Decode a DeltaReader up to key, returning the number of terms traversed
///
/// If the key was not found, it returns the next term id.
|
||||
fn decode_up_to_or_next<K: AsRef<[u8]>>(
|
||||
&self,
|
||||
key: K,
|
||||
sstable_delta_reader: &mut DeltaReader<TSSTable::ValueReader>,
|
||||
) -> io::Result<TermOrdHit> {
|
||||
let mut term_ord = 0;
|
||||
let key_bytes = key.as_ref();
|
||||
let mut ok_bytes = 0;
|
||||
@@ -265,7 +323,7 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
|
||||
let suffix = sstable_delta_reader.suffix();
|
||||
|
||||
match prefix_len.cmp(&ok_bytes) {
|
||||
Ordering::Less => return Ok(None), // popped bytes already matched => too far
|
||||
Ordering::Less => return Ok(TermOrdHit::Next(term_ord)), /* popped bytes already matched => too far */
|
||||
Ordering::Equal => (),
|
||||
Ordering::Greater => {
|
||||
// the ok prefix is less than current entry prefix => continue to next elem
|
||||
@@ -277,25 +335,26 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
|
||||
// we have ok_bytes byte of common prefix, check if this key adds more
|
||||
for (key_byte, suffix_byte) in key_bytes[ok_bytes..].iter().zip(suffix) {
|
||||
match suffix_byte.cmp(key_byte) {
|
||||
Ordering::Less => break, // byte too small
|
||||
Ordering::Equal => ok_bytes += 1, // new matching byte
|
||||
Ordering::Greater => return Ok(None), // too far
|
||||
Ordering::Less => break, // byte too small
|
||||
Ordering::Equal => ok_bytes += 1, // new matching
|
||||
// byte
|
||||
Ordering::Greater => return Ok(TermOrdHit::Next(term_ord)), // too far
|
||||
}
|
||||
}
|
||||
|
||||
if ok_bytes == key_bytes.len() {
|
||||
if prefix_len + suffix.len() == ok_bytes {
|
||||
return Ok(Some(term_ord));
|
||||
return Ok(TermOrdHit::Exact(term_ord));
|
||||
} else {
|
||||
// current key is a prefix of current element, not a match
|
||||
return Ok(None);
|
||||
return Ok(TermOrdHit::Next(term_ord));
|
||||
}
|
||||
}
|
||||
|
||||
term_ord += 1;
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
Ok(TermOrdHit::Next(term_ord))
|
||||
}
|
||||
|
||||
/// Returns the ordinal associated with a given term.
|
||||
@@ -312,6 +371,61 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
|
||||
.map(|opt| opt.map(|ord| ord + first_ordinal))
|
||||
}
|
||||
|
||||
/// Returns the ordinal associated with a given term or its closest next term_id
|
||||
/// The closest next term_id may not exist.
|
||||
pub fn term_ord_or_next<K: AsRef<[u8]>>(&self, key: K) -> io::Result<TermOrdHit> {
|
||||
let key_bytes = key.as_ref();
|
||||
|
||||
let Some(block_addr) = self.sstable_index.get_block_with_key(key_bytes) else {
|
||||
// TODO: Would be more consistent to return last_term id + 1
|
||||
return Ok(TermOrdHit::Next(u64::MAX));
|
||||
};
|
||||
|
||||
let first_ordinal = block_addr.first_ordinal;
|
||||
let mut sstable_delta_reader = self.sstable_delta_reader_block(block_addr)?;
|
||||
self.decode_up_to_or_next(key_bytes, &mut sstable_delta_reader)
|
||||
.map(|opt| opt.map(|ord| ord + first_ordinal))
|
||||
}
|
||||
|
||||
    /// Converts string bounds into a range of term ordinals.
    /// This handles several special cases when the term is not exactly in the dictionary.
    /// e.g. for the dictionary [bbb, ddd]:
    /// lower_bound: Bound::Included(aaa) => Included(0) // "Next" term id
    /// lower_bound: Bound::Excluded(aaa) => Included(0) // "Next" term id + change the bound
    /// lower_bound: Bound::Included(ccc) => Included(1) // "Next" term id
    /// lower_bound: Bound::Excluded(ccc) => Included(1) // "Next" term id + change the bound
    /// lower_bound: Bound::Included(zzz) => Included(2) // "Next" term id
    /// lower_bound: Bound::Excluded(zzz) => Included(2) // "Next" term id + change the bound
    /// For zzz some post-processing should return an empty query.
    ///
    /// upper_bound: Bound::Included(aaa) => Excluded(0) // "Next" term id + change the bound
    /// upper_bound: Bound::Excluded(aaa) => Excluded(0) // "Next" term id
    /// upper_bound: Bound::Included(ccc) => Excluded(1) // "Next" term id + change the bound
    /// upper_bound: Bound::Excluded(ccc) => Excluded(1) // "Next" term id
    /// upper_bound: Bound::Included(zzz) => Excluded(2) // "Next" term id + change the bound
    /// upper_bound: Bound::Excluded(zzz) => Excluded(2) // "Next" term id
    pub fn term_bounds_to_ord<K: AsRef<[u8]>>(
        &self,
        lower_bound: Bound<K>,
        upper_bound: Bound<K>,
    ) -> io::Result<(Bound<TermOrdinal>, Bound<TermOrdinal>)> {
        let lower_bound = transform_bound_inner(&lower_bound, |start_bound_bytes| {
            let ord = self.term_ord_or_next(start_bound_bytes)?;
            match ord {
                TermOrdHit::Exact(ord) => Ok(map_bound(&lower_bound, |_| ord)),
                TermOrdHit::Next(ord) => Ok(Bound::Included(ord)), // Change bounds to included
            }
        })?;
        let upper_bound = transform_bound_inner(&upper_bound, |end_bound_bytes| {
            let ord = self.term_ord_or_next(end_bound_bytes)?;
            match ord {
                TermOrdHit::Exact(ord) => Ok(map_bound(&upper_bound, |_| ord)),
                TermOrdHit::Next(ord) => Ok(Bound::Excluded(ord)), // Change bounds to excluded
            }
        })?;
        Ok((lower_bound, upper_bound))
    }
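    // A hedged usage sketch, not part of this diff: how the conversion above plays out for
    // the [bbb, ddd] dictionary from the doc comment. It would live in the tests module
    // further down in this file; `Dictionary` and `MonotonicU64SSTable` come from this
    // crate, `Bound` from `std::ops`.
    #[cfg(test)]
    fn term_bounds_to_ord_sketch(
        dict: &crate::Dictionary<crate::MonotonicU64SSTable>,
    ) -> std::io::Result<()> {
        use std::ops::Bound;
        // "ccc" is absent: the excluded lower bound falls back to the next ordinal
        // (ddd => 1) and becomes inclusive; the included upper bound on the exact term
        // "ddd" keeps its variant.
        let (lower, upper) = dict.term_bounds_to_ord(
            Bound::Excluded(b"ccc".as_slice()),
            Bound::Included(b"ddd".as_slice()),
        )?;
        assert_eq!(lower, Bound::Included(1u64));
        assert_eq!(upper, Bound::Included(1u64));
        Ok(())
    }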
|
||||
|
||||
/// Returns the term associated with a given term ordinal.
|
||||
///
|
||||
/// Term ordinals are defined as the position of the term in
|
||||
@@ -338,6 +452,45 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Returns the terms for a _sorted_ list of term ordinals.
|
||||
///
|
||||
/// Returns true if and only if all terms have been found.
|
||||
pub fn sorted_ords_to_term_cb<F: FnMut(&[u8]) -> io::Result<()>>(
|
||||
&self,
|
||||
ord: impl Iterator<Item = TermOrdinal>,
|
||||
mut cb: F,
|
||||
) -> io::Result<bool> {
|
||||
let mut bytes = Vec::new();
|
||||
let mut current_block_addr = self.sstable_index.get_block_with_ord(0);
|
||||
let mut current_sstable_delta_reader =
|
||||
self.sstable_delta_reader_block(current_block_addr.clone())?;
|
||||
let mut current_ordinal = 0;
|
||||
for ord in ord {
|
||||
assert!(ord >= current_ordinal);
|
||||
// check if block changed for new term_ord
|
||||
let new_block_addr = self.sstable_index.get_block_with_ord(ord);
|
||||
if new_block_addr != current_block_addr {
|
||||
current_block_addr = new_block_addr;
|
||||
current_ordinal = current_block_addr.first_ordinal;
|
||||
current_sstable_delta_reader =
|
||||
self.sstable_delta_reader_block(current_block_addr.clone())?;
|
||||
bytes.clear();
|
||||
}
|
||||
|
||||
// move to ord inside that block
|
||||
for _ in current_ordinal..=ord {
|
||||
if !current_sstable_delta_reader.advance()? {
|
||||
return Ok(false);
|
||||
}
|
||||
bytes.truncate(current_sstable_delta_reader.common_prefix_len());
|
||||
bytes.extend_from_slice(current_sstable_delta_reader.suffix());
|
||||
}
|
||||
current_ordinal = ord + 1;
|
||||
cb(&bytes)?;
|
||||
}
|
||||
Ok(true)
|
||||
}
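    // A hedged usage sketch, not part of this diff: collecting the terms behind a sorted
    // list of ordinals with the callback API above. It would live in the tests module
    // further down; `Dictionary` and `MonotonicU64SSTable` come from this crate.
    #[cfg(test)]
    fn sorted_ords_to_terms_sketch(
        dict: &crate::Dictionary<crate::MonotonicU64SSTable>,
    ) -> std::io::Result<Vec<Vec<u8>>> {
        let ords = [0u64, 1u64]; // term ordinals, sorted ascending
        let mut terms = Vec::new();
        let all_found = dict.sorted_ords_to_term_cb(ords.iter().copied(), |term_bytes| {
            terms.push(term_bytes.to_vec());
            Ok(())
        })?;
        assert!(all_found);
        Ok(terms)
    }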
|
||||
|
||||
/// Returns the number of terms in the dictionary.
|
||||
pub fn term_info_from_ord(&self, term_ord: TermOrdinal) -> io::Result<Option<TSSTable::Value>> {
|
||||
// find block in which the term would be
|
||||
@@ -416,12 +569,13 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::ops::Range;
|
||||
use std::ops::{Bound, Range};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use common::OwnedBytes;
|
||||
|
||||
use super::Dictionary;
|
||||
use crate::dictionary::TermOrdHit;
|
||||
use crate::MonotonicU64SSTable;
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -485,6 +639,140 @@ mod tests {
|
||||
(dictionary, table)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_term_to_ord_or_next() {
|
||||
let dict = {
|
||||
let mut builder = Dictionary::<MonotonicU64SSTable>::builder(Vec::new()).unwrap();
|
||||
|
||||
builder.insert(b"bbb", &1).unwrap();
|
||||
builder.insert(b"ddd", &2).unwrap();
|
||||
|
||||
let table = builder.finish().unwrap();
|
||||
let table = Arc::new(PermissionedHandle::new(table));
|
||||
let slice = common::file_slice::FileSlice::new(table.clone());
|
||||
|
||||
Dictionary::<MonotonicU64SSTable>::open(slice).unwrap()
|
||||
};
|
||||
|
||||
assert_eq!(dict.term_ord_or_next(b"aaa").unwrap(), TermOrdHit::Next(0));
|
||||
assert_eq!(dict.term_ord_or_next(b"bbb").unwrap(), TermOrdHit::Exact(0));
|
||||
assert_eq!(dict.term_ord_or_next(b"bb").unwrap(), TermOrdHit::Next(0));
|
||||
assert_eq!(dict.term_ord_or_next(b"bbbb").unwrap(), TermOrdHit::Next(1));
|
||||
assert_eq!(dict.term_ord_or_next(b"dd").unwrap(), TermOrdHit::Next(1));
|
||||
assert_eq!(dict.term_ord_or_next(b"ddd").unwrap(), TermOrdHit::Exact(1));
|
||||
assert_eq!(dict.term_ord_or_next(b"dddd").unwrap(), TermOrdHit::Next(2));
|
||||
|
||||
// This is not u64::MAX because for very small sstables (only one block),
// we don't store an index, and the pseudo-index always replies that the
// answer lies in block number 0.
|
||||
assert_eq!(
|
||||
dict.term_ord_or_next(b"zzzzzzz").unwrap(),
|
||||
TermOrdHit::Next(2)
|
||||
);
|
||||
}
|
||||
    #[test]
    fn test_term_to_ord_or_next_2() {
        let dict = {
            let mut builder = Dictionary::<MonotonicU64SSTable>::builder(Vec::new()).unwrap();

            let mut term_ord = 0;
            builder.insert(b"bbb", &term_ord).unwrap();

            // Fill blocks in between
            for elem in 0..50_000 {
                term_ord += 1;
                let key = format!("ccccc{elem:05X}").into_bytes();
                builder.insert(&key, &term_ord).unwrap();
            }

            term_ord += 1;
            builder.insert(b"eee", &term_ord).unwrap();

            let table = builder.finish().unwrap();
            let table = Arc::new(PermissionedHandle::new(table));
            let slice = common::file_slice::FileSlice::new(table.clone());

            Dictionary::<MonotonicU64SSTable>::open(slice).unwrap()
        };

        assert_eq!(dict.term_ord(b"bbb").unwrap(), Some(0));
        assert_eq!(dict.term_ord_or_next(b"bbb").unwrap(), TermOrdHit::Exact(0));
        assert_eq!(dict.term_ord_or_next(b"aaa").unwrap(), TermOrdHit::Next(0));
        assert_eq!(dict.term_ord_or_next(b"bb").unwrap(), TermOrdHit::Next(0));
        assert_eq!(dict.term_ord_or_next(b"bbbb").unwrap(), TermOrdHit::Next(1));
        assert_eq!(
            dict.term_ord_or_next(b"ee").unwrap(),
            TermOrdHit::Next(50001)
        );
        assert_eq!(
            dict.term_ord_or_next(b"eee").unwrap(),
            TermOrdHit::Exact(50001)
        );
        assert_eq!(
            dict.term_ord_or_next(b"eeee").unwrap(),
            TermOrdHit::Next(u64::MAX)
        );

        assert_eq!(
            dict.term_ord_or_next(b"zzzzzzz").unwrap(),
            TermOrdHit::Next(u64::MAX)
        );
    }

    #[test]
    fn test_term_bounds_to_ord() {
        let dict = {
            let mut builder = Dictionary::<MonotonicU64SSTable>::builder(Vec::new()).unwrap();

            builder.insert(b"bbb", &1).unwrap();
            builder.insert(b"ddd", &2).unwrap();

            let table = builder.finish().unwrap();
            let table = Arc::new(PermissionedHandle::new(table));
            let slice = common::file_slice::FileSlice::new(table.clone());

            Dictionary::<MonotonicU64SSTable>::open(slice).unwrap()
        };

        // Test cases for lower_bound
        let test_lower_bound = |bound, expected| {
            assert_eq!(
                dict.term_bounds_to_ord::<&[u8]>(bound, Bound::Included(b"ignored"))
                    .unwrap()
                    .0,
                expected
            );
        };

        test_lower_bound(Bound::Included(b"aaa".as_slice()), Bound::Included(0));
        test_lower_bound(Bound::Excluded(b"aaa".as_slice()), Bound::Included(0));

        test_lower_bound(Bound::Included(b"bbb".as_slice()), Bound::Included(0));
        test_lower_bound(Bound::Excluded(b"bbb".as_slice()), Bound::Excluded(0));

        test_lower_bound(Bound::Included(b"ccc".as_slice()), Bound::Included(1));
        test_lower_bound(Bound::Excluded(b"ccc".as_slice()), Bound::Included(1));

        test_lower_bound(Bound::Included(b"zzz".as_slice()), Bound::Included(2));
        test_lower_bound(Bound::Excluded(b"zzz".as_slice()), Bound::Included(2));

        // Test cases for upper_bound
        let test_upper_bound = |bound, expected| {
            assert_eq!(
                dict.term_bounds_to_ord::<&[u8]>(Bound::Included(b"ignored"), bound,)
                    .unwrap()
                    .1,
                expected
            );
        };
        test_upper_bound(Bound::Included(b"ccc".as_slice()), Bound::Excluded(1));
        test_upper_bound(Bound::Excluded(b"ccc".as_slice()), Bound::Excluded(1));
        test_upper_bound(Bound::Included(b"zzz".as_slice()), Bound::Excluded(2));
        test_upper_bound(Bound::Excluded(b"zzz".as_slice()), Bound::Excluded(2));
        test_upper_bound(Bound::Included(b"ddd".as_slice()), Bound::Included(1));
        test_upper_bound(Bound::Excluded(b"ddd".as_slice()), Bound::Excluded(1));
    }

    #[test]
    fn test_ord_term_conversion() {
        let (dic, slice) = make_test_sstable();
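
As a reading aid for the tests above: `TermOrdHit::Exact(ord)` means the queried key itself is present at ordinal `ord`, while `TermOrdHit::Next(ord)` points at the first ordinal whose term sorts after the key (or a past-the-end value when there is none). A rough sketch of how the bound mapping composes, reusing the two-term dictionary from `test_term_bounds_to_ord` ("bbb" at ordinal 0, "ddd" at ordinal 1); this is illustration only, not part of the diff:

    // Map byte-range bounds onto term-ordinal bounds; the resulting ordinal
    // bounds can then drive an ordinal-based scan such as sorted_ords_to_term_cb.
    let (lower, upper) = dict
        .term_bounds_to_ord::<&[u8]>(
            Bound::Included(b"bbb".as_slice()),
            Bound::Excluded(b"ddd".as_slice()),
        )
        .unwrap();
    assert_eq!(lower, Bound::Included(0));
    assert_eq!(upper, Bound::Excluded(1));
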
@@ -551,6 +839,61 @@ mod tests {
        assert!(dic.term_ord(b"1000").unwrap().is_none());
    }

    #[test]
    fn test_ords_term() {
        let (dic, _slice) = make_test_sstable();

        // Single term
        let mut terms = Vec::new();
        assert!(dic
            .sorted_ords_to_term_cb(100_000..100_001, |term| {
                terms.push(term.to_vec());
                Ok(())
            })
            .unwrap());
        assert_eq!(terms, vec![format!("{:05X}", 100_000).into_bytes(),]);
        // Single term
        let mut terms = Vec::new();
        assert!(dic
            .sorted_ords_to_term_cb(100_001..100_002, |term| {
                terms.push(term.to_vec());
                Ok(())
            })
            .unwrap());
        assert_eq!(terms, vec![format!("{:05X}", 100_001).into_bytes(),]);
        // both terms
        let mut terms = Vec::new();
        assert!(dic
            .sorted_ords_to_term_cb(100_000..100_002, |term| {
                terms.push(term.to_vec());
                Ok(())
            })
            .unwrap());
        assert_eq!(
            terms,
            vec![
                format!("{:05X}", 100_000).into_bytes(),
                format!("{:05X}", 100_001).into_bytes(),
            ]
        );
        // Test cross block
        let mut terms = Vec::new();
        assert!(dic
            .sorted_ords_to_term_cb(98653..=98655, |term| {
                terms.push(term.to_vec());
                Ok(())
            })
            .unwrap());
        assert_eq!(
            terms,
            vec![
                format!("{:05X}", 98653).into_bytes(),
                format!("{:05X}", 98654).into_bytes(),
                format!("{:05X}", 98655).into_bytes(),
            ]
        );
    }

    #[test]
    fn test_range() {
        let (dic, slice) = make_test_sstable();

@@ -78,6 +78,7 @@ impl ValueWriter for RangeValueWriter {
}

#[cfg(test)]
#[allow(clippy::single_range_in_vec_init)]
mod tests {
    use super::*;

@@ -39,8 +39,6 @@ pub fn deserialize_read(buf: &[u8]) -> (usize, u64) {

#[cfg(test)]
mod tests {
    use std::u64;

    use super::{deserialize_read, serialize};

    fn aux_test_int(val: u64, expect_len: usize) {

@@ -54,7 +54,7 @@ fn bench_hashmap_throughput(c: &mut Criterion) {
    );

    // numbers
    let input_bytes = 1_000_000 * 8 as u64;
    let input_bytes = 1_000_000 * 8;
    group.throughput(Throughput::Bytes(input_bytes));
    let numbers: Vec<[u8; 8]> = (0..1_000_000u64).map(|el| el.to_le_bytes()).collect();

@@ -82,7 +82,7 @@ fn bench_hashmap_throughput(c: &mut Criterion) {
    let mut rng = StdRng::from_seed([3u8; 32]);
    let zipf = zipf::ZipfDistribution::new(10_000, 1.03).unwrap();

    let input_bytes = 1_000_000 * 8 as u64;
    let input_bytes = 1_000_000 * 8;
    group.throughput(Throughput::Bytes(input_bytes));
    let zipf_numbers: Vec<[u8; 8]> = (0..1_000_000u64)
        .map(|_| zipf.sample(&mut rng).to_le_bytes())
@@ -110,7 +110,7 @@ impl DocIdRecorder {
    }
}

fn create_hash_map<'a, T: AsRef<[u8]>>(terms: impl Iterator<Item = T>) -> ArenaHashMap {
fn create_hash_map<T: AsRef<[u8]>>(terms: impl Iterator<Item = T>) -> ArenaHashMap {
    let mut map = ArenaHashMap::with_capacity(HASHMAP_SIZE);
    for term in terms {
        map.mutate_or_create(term.as_ref(), |val| {
@@ -126,7 +126,7 @@ fn create_hash_map<'a, T: AsRef<[u8]>>(terms: impl Iterator<Item = T>) -> ArenaH
    map
}

fn create_hash_map_with_expull<'a, T: AsRef<[u8]>>(
fn create_hash_map_with_expull<T: AsRef<[u8]>>(
    terms: impl Iterator<Item = (u32, T)>,
) -> ArenaHashMap {
    let mut memory_arena = MemoryArena::default();
@@ -145,7 +145,7 @@ fn create_hash_map_with_expull<'a, T: AsRef<[u8]>>(
    map
}

fn create_fx_hash_ref_map_with_expull<'a>(
fn create_fx_hash_ref_map_with_expull(
    terms: impl Iterator<Item = &'static [u8]>,
) -> FxHashMap<&'static [u8], Vec<u32>> {
    let terms = terms.enumerate();
@@ -158,7 +158,7 @@ fn create_fx_hash_ref_map_with_expull<'a>(
    map
}

fn create_fx_hash_owned_map_with_expull<'a>(
fn create_fx_hash_owned_map_with_expull(
    terms: impl Iterator<Item = &'static [u8]>,
) -> FxHashMap<Vec<u8>, Vec<u32>> {
    let terms = terms.enumerate();

@@ -10,7 +10,7 @@ fn main() {
    }
}

fn create_hash_map<'a, T: AsRef<str>>(terms: impl Iterator<Item = T>) -> ArenaHashMap {
fn create_hash_map<T: AsRef<str>>(terms: impl Iterator<Item = T>) -> ArenaHashMap {
    let mut map = ArenaHashMap::with_capacity(4);
    for term in terms {
        map.mutate_or_create(term.as_ref().as_bytes(), |val| {