mirror of
https://github.com/quickwit-oss/tantivy.git
synced 2026-02-12 11:00:36 +00:00
Compare commits
12 Commits
congxie/su
...
congxie/re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
698f073f88 | ||
|
|
cdd24b7ee5 | ||
|
|
5562ce6037 | ||
|
|
09b6ececa7 | ||
|
|
8018016e46 | ||
|
|
6bf185dc3f | ||
|
|
bb141abe22 | ||
|
|
f1c29ba972 | ||
|
|
ae0554a6a5 | ||
|
|
0d7abe5d23 | ||
|
|
a55e4069e4 | ||
|
|
1fd30c62be |
@@ -65,7 +65,7 @@ tantivy-bitpacker = { version = "0.9", path = "./bitpacker" }
|
||||
common = { version = "0.10", path = "./common/", package = "tantivy-common" }
|
||||
tokenizer-api = { version = "0.6", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
|
||||
sketches-ddsketch = { version = "0.3.0", features = ["use_serde"] }
|
||||
hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
|
||||
datasketches = "0.2.0"
|
||||
futures-util = { version = "0.3.28", optional = true }
|
||||
futures-channel = { version = "0.3.28", optional = true }
|
||||
fnv = "1.0.7"
|
||||
|
||||
@@ -704,7 +704,11 @@ fn regex(inp: &str) -> IResult<&str, UserInputLeaf> {
|
||||
many1(alt((preceded(char('\\'), char('/')), none_of("/")))),
|
||||
char('/'),
|
||||
),
|
||||
peek(alt((multispace1, eof))),
|
||||
peek(alt((
|
||||
value((), multispace1),
|
||||
value((), char(')')),
|
||||
value((), eof),
|
||||
))),
|
||||
),
|
||||
|elements| UserInputLeaf::Regex {
|
||||
field: None,
|
||||
@@ -721,8 +725,12 @@ fn regex_infallible(inp: &str) -> JResult<&str, UserInputLeaf> {
|
||||
opt_i_err(char('/'), "missing delimiter /"),
|
||||
),
|
||||
opt_i_err(
|
||||
peek(alt((multispace1, eof))),
|
||||
"expected whitespace or end of input",
|
||||
peek(alt((
|
||||
value((), multispace1),
|
||||
value((), char(')')),
|
||||
value((), eof),
|
||||
))),
|
||||
"expected whitespace, closing parenthesis, or end of input",
|
||||
),
|
||||
)(inp)
|
||||
{
|
||||
@@ -1707,6 +1715,10 @@ mod test {
|
||||
test_parse_query_to_ast_helper("foo:(A OR B)", "(?\"foo\":A ?\"foo\":B)");
|
||||
test_parse_query_to_ast_helper("foo:(A* OR B*)", "(?\"foo\":A* ?\"foo\":B*)");
|
||||
test_parse_query_to_ast_helper("foo:(*A OR *B)", "(?\"foo\":*A ?\"foo\":*B)");
|
||||
|
||||
// Regexes between parentheses
|
||||
test_parse_query_to_ast_helper("foo:(/A.*/)", "\"foo\":/A.*/");
|
||||
test_parse_query_to_ast_helper("foo:(/A.*/ OR /B.*/)", "(?\"foo\":/A.*/ ?\"foo\":/B.*/)");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -66,6 +66,7 @@ impl UserInputLeaf {
|
||||
}
|
||||
UserInputLeaf::Range { field, .. } if field.is_none() => *field = Some(default_field),
|
||||
UserInputLeaf::Set { field, .. } if field.is_none() => *field = Some(default_field),
|
||||
UserInputLeaf::Regex { field, .. } if field.is_none() => *field = Some(default_field),
|
||||
_ => (), // field was already set, do nothing
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,8 +10,7 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::bucket::GetDocCount;
|
||||
use super::metric::{
|
||||
AverageMetricResult, CardinalityMetricResult, ExtendedStats, PercentilesMetricResult,
|
||||
SingleMetricResult, Stats, TopHitsMetricResult,
|
||||
ExtendedStats, PercentilesMetricResult, SingleMetricResult, Stats, TopHitsMetricResult,
|
||||
};
|
||||
use super::{AggregationError, Key};
|
||||
use crate::TantivyError;
|
||||
@@ -82,8 +81,8 @@ impl AggregationResult {
|
||||
#[serde(untagged)]
|
||||
/// MetricResult
|
||||
pub enum MetricResult {
|
||||
/// Average metric result with sum and count for multi-step merging.
|
||||
Average(AverageMetricResult),
|
||||
/// Average metric result.
|
||||
Average(SingleMetricResult),
|
||||
/// Count metric result.
|
||||
Count(SingleMetricResult),
|
||||
/// Max metric result.
|
||||
@@ -100,8 +99,8 @@ pub enum MetricResult {
|
||||
Percentiles(PercentilesMetricResult),
|
||||
/// Top hits metric result
|
||||
TopHits(TopHitsMetricResult),
|
||||
/// Cardinality metric result with HLL sketch for multi-step merging.
|
||||
Cardinality(CardinalityMetricResult),
|
||||
/// Cardinality metric result
|
||||
Cardinality(SingleMetricResult),
|
||||
}
|
||||
|
||||
impl MetricResult {
|
||||
@@ -120,7 +119,7 @@ impl MetricResult {
|
||||
MetricResult::TopHits(_) => Err(TantivyError::AggregationError(
|
||||
AggregationError::InvalidRequest("top_hits can't be used to order".to_string()),
|
||||
)),
|
||||
MetricResult::Cardinality(card) => Ok(card.value), // CardinalityMetricResult.value
|
||||
MetricResult::Cardinality(card) => Ok(card.value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1359,10 +1359,10 @@ fn test_aggregation_on_json_object_mixed_types() {
|
||||
&serde_json::json!({
|
||||
"rangeagg": {
|
||||
"buckets": [
|
||||
{ "average_in_range": { "value": -20.5, "sum": -20.5, "count": 1 }, "doc_count": 1, "key": "*-3", "to": 3.0 },
|
||||
{ "average_in_range": { "value": 10.0, "sum": 10.0, "count": 1 }, "doc_count": 1, "from": 3.0, "key": "3-19", "to": 19.0 },
|
||||
{ "average_in_range": { "value": null, "sum": 0.0, "count": 0 }, "doc_count": 0, "from": 19.0, "key": "19-20", "to": 20.0 },
|
||||
{ "average_in_range": { "value": null, "sum": 0.0, "count": 0 }, "doc_count": 0, "from": 20.0, "key": "20-*" }
|
||||
{ "average_in_range": { "value": -20.5 }, "doc_count": 1, "key": "*-3", "to": 3.0 },
|
||||
{ "average_in_range": { "value": 10.0 }, "doc_count": 1, "from": 3.0, "key": "3-19", "to": 19.0 },
|
||||
{ "average_in_range": { "value": null }, "doc_count": 0, "from": 19.0, "key": "19-20", "to": 20.0 },
|
||||
{ "average_in_range": { "value": null }, "doc_count": 0, "from": 20.0, "key": "20-*" }
|
||||
]
|
||||
},
|
||||
"termagg": {
|
||||
|
||||
@@ -838,7 +838,7 @@ mod tests {
|
||||
let expected = json!({
|
||||
"electronics": {
|
||||
"doc_count": 2,
|
||||
"avg_price": { "value": 899.0, "sum": 1798.0, "count": 2 } // (999 + 799) / 2
|
||||
"avg_price": { "value": 899.0 } // (999 + 799) / 2
|
||||
}
|
||||
});
|
||||
|
||||
@@ -868,7 +868,7 @@ mod tests {
|
||||
let expected = json!({
|
||||
"furniture": {
|
||||
"doc_count": 0,
|
||||
"avg_price": { "value": null, "sum": 0.0, "count": 0 }
|
||||
"avg_price": { "value": null }
|
||||
}
|
||||
});
|
||||
|
||||
@@ -904,7 +904,7 @@ mod tests {
|
||||
let expected = json!({
|
||||
"electronics": {
|
||||
"doc_count": 2,
|
||||
"avg_price": { "value": 899.0, "sum": 1798.0, "count": 2 }
|
||||
"avg_price": { "value": 899.0 }
|
||||
},
|
||||
"in_stock": {
|
||||
"doc_count": 3, // apple, samsung, penguin
|
||||
@@ -1000,7 +1000,7 @@ mod tests {
|
||||
let expected = json!({
|
||||
"premium_electronics": {
|
||||
"doc_count": 1, // Only apple (999) is >= 800 in tantivy's range semantics
|
||||
"avg_rating": { "value": 4.5, "sum": 4.5, "count": 1 }
|
||||
"avg_rating": { "value": 4.5 }
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1032,7 +1032,7 @@ mod tests {
|
||||
let expected = json!({
|
||||
"in_stock": {
|
||||
"doc_count": 3, // apple, samsung, penguin
|
||||
"avg_price": { "value": 607.67, "sum": 1823.0, "count": 3 } // (999 + 799 + 25) / 3 ≈ 607.67
|
||||
"avg_price": { "value": 607.67 } // (999 + 799 + 25) / 3 ≈ 607.67
|
||||
},
|
||||
"out_of_stock": {
|
||||
"doc_count": 1, // nike
|
||||
@@ -1183,7 +1183,7 @@ mod tests {
|
||||
"doc_count": 4,
|
||||
"electronics_branch": {
|
||||
"doc_count": 2,
|
||||
"avg_price": { "value": 899.0, "sum": 1798.0, "count": 2 }
|
||||
"avg_price": { "value": 899.0 }
|
||||
},
|
||||
"in_stock_branch": {
|
||||
"doc_count": 3,
|
||||
@@ -1259,7 +1259,7 @@ mod tests {
|
||||
"doc_count": 2, // apple (999), samsung (799)
|
||||
"electronics": {
|
||||
"doc_count": 2, // both are electronics
|
||||
"avg_rating": { "value": 4.35, "sum": 8.7, "count": 2 } // (4.5 + 4.2) / 2
|
||||
"avg_rating": { "value": 4.35 } // (4.5 + 4.2) / 2
|
||||
},
|
||||
"in_stock": {
|
||||
"doc_count": 2, // both are in stock
|
||||
@@ -1321,12 +1321,12 @@ mod tests {
|
||||
{
|
||||
"key": "samsung",
|
||||
"doc_count": 1,
|
||||
"avg_price": { "value": 799.0, "sum": 799.0, "count": 1 }
|
||||
"avg_price": { "value": 799.0 }
|
||||
},
|
||||
{
|
||||
"key": "apple",
|
||||
"doc_count": 1,
|
||||
"avg_price": { "value": 999.0, "sum": 999.0, "count": 1 }
|
||||
"avg_price": { "value": 999.0 }
|
||||
}
|
||||
],
|
||||
"sum_other_doc_count": 0,
|
||||
@@ -1370,7 +1370,7 @@ mod tests {
|
||||
"sum": 1798.0,
|
||||
"avg": 899.0
|
||||
},
|
||||
"rating_avg": { "value": 4.35, "sum": 8.7, "count": 2 },
|
||||
"rating_avg": { "value": 4.35 },
|
||||
"count": { "value": 2.0 }
|
||||
}
|
||||
});
|
||||
@@ -1411,7 +1411,7 @@ mod tests {
|
||||
let expected = json!({
|
||||
"electronics": {
|
||||
"doc_count": 0,
|
||||
"avg_price": { "value": null, "sum": 0.0, "count": 0 }
|
||||
"avg_price": { "value": null }
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1698,15 +1698,13 @@ mod tests {
|
||||
let filter_expected = json!({
|
||||
"electronics": {
|
||||
"doc_count": 2,
|
||||
"avg_price": { "value": 899.0, "sum": 1798.0, "count": 2 }
|
||||
"avg_price": { "value": 899.0 }
|
||||
}
|
||||
});
|
||||
|
||||
let separate_expected = json!({
|
||||
"result": {
|
||||
"value": 899.0,
|
||||
"sum": 1798.0,
|
||||
"count": 2
|
||||
"value": 899.0
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -1222,9 +1222,7 @@ mod tests {
|
||||
res["histogram"]["buckets"][0],
|
||||
json!({
|
||||
"avg": {
|
||||
"value": Value::Null,
|
||||
"sum": 0.0,
|
||||
"count": 0
|
||||
"value": Value::Null
|
||||
},
|
||||
"doc_count": 0,
|
||||
"key": 2.0,
|
||||
|
||||
@@ -19,9 +19,8 @@ use super::bucket::{
|
||||
GetDocCount, Order, OrderTarget, RangeAggregation, TermsAggregation,
|
||||
};
|
||||
use super::metric::{
|
||||
AverageMetricResult, CardinalityMetricResult, IntermediateAverage, IntermediateCount,
|
||||
IntermediateExtendedStats, IntermediateMax, IntermediateMin, IntermediateStats,
|
||||
IntermediateSum, PercentilesCollector, TopHitsTopNComputer,
|
||||
IntermediateAverage, IntermediateCount, IntermediateExtendedStats, IntermediateMax,
|
||||
IntermediateMin, IntermediateStats, IntermediateSum, PercentilesCollector, TopHitsTopNComputer,
|
||||
};
|
||||
use super::segment_agg_result::AggregationLimitsGuard;
|
||||
use super::{format_date, AggregationError, Key, SerializedKey};
|
||||
@@ -91,6 +90,19 @@ impl From<IntermediateKey> for Key {
|
||||
|
||||
impl Eq for IntermediateKey {}
|
||||
|
||||
impl std::fmt::Display for IntermediateKey {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
IntermediateKey::Str(val) => f.write_str(val),
|
||||
IntermediateKey::F64(val) => f.write_str(&val.to_string()),
|
||||
IntermediateKey::U64(val) => f.write_str(&val.to_string()),
|
||||
IntermediateKey::I64(val) => f.write_str(&val.to_string()),
|
||||
IntermediateKey::Bool(val) => f.write_str(&val.to_string()),
|
||||
IntermediateKey::IpAddr(val) => f.write_str(&val.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::hash::Hash for IntermediateKey {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
core::mem::discriminant(self).hash(state);
|
||||
@@ -106,6 +118,21 @@ impl std::hash::Hash for IntermediateKey {
|
||||
}
|
||||
|
||||
impl IntermediateAggregationResults {
|
||||
/// Returns a reference to the intermediate aggregation result for the given key.
|
||||
pub fn get(&self, key: &str) -> Option<&IntermediateAggregationResult> {
|
||||
self.aggs_res.get(key)
|
||||
}
|
||||
|
||||
/// Removes and returns the intermediate aggregation result for the given key.
|
||||
pub fn remove(&mut self, key: &str) -> Option<IntermediateAggregationResult> {
|
||||
self.aggs_res.remove(key)
|
||||
}
|
||||
|
||||
/// Returns an iterator over the keys in the intermediate aggregation results.
|
||||
pub fn keys(&self) -> impl Iterator<Item = &String> {
|
||||
self.aggs_res.keys()
|
||||
}
|
||||
|
||||
/// Add a result
|
||||
pub fn push(&mut self, key: String, value: IntermediateAggregationResult) -> crate::Result<()> {
|
||||
let entry = self.aggs_res.entry(key);
|
||||
@@ -326,11 +353,7 @@ impl IntermediateMetricResult {
|
||||
fn into_final_metric_result(self, req: &Aggregation) -> MetricResult {
|
||||
match self {
|
||||
IntermediateMetricResult::Average(intermediate_avg) => {
|
||||
MetricResult::Average(AverageMetricResult {
|
||||
value: intermediate_avg.finalize(),
|
||||
sum: intermediate_avg.sum(),
|
||||
count: intermediate_avg.count(),
|
||||
})
|
||||
MetricResult::Average(intermediate_avg.finalize().into())
|
||||
}
|
||||
IntermediateMetricResult::Count(intermediate_count) => {
|
||||
MetricResult::Count(intermediate_count.finalize().into())
|
||||
@@ -358,11 +381,7 @@ impl IntermediateMetricResult {
|
||||
MetricResult::TopHits(top_hits.into_final_result())
|
||||
}
|
||||
IntermediateMetricResult::Cardinality(cardinality) => {
|
||||
let value = cardinality.finalize();
|
||||
MetricResult::Cardinality(CardinalityMetricResult {
|
||||
value,
|
||||
sketch: Some(cardinality),
|
||||
})
|
||||
MetricResult::Cardinality(cardinality.finalize().into())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -648,6 +667,21 @@ pub struct IntermediateTermBucketResult {
|
||||
}
|
||||
|
||||
impl IntermediateTermBucketResult {
|
||||
/// Returns a reference to the map of bucket entries keyed by [`IntermediateKey`].
|
||||
pub fn entries(&self) -> &FxHashMap<IntermediateKey, IntermediateTermBucketEntry> {
|
||||
&self.entries
|
||||
}
|
||||
|
||||
/// Returns the count of documents not included in the returned buckets.
|
||||
pub fn sum_other_doc_count(&self) -> u64 {
|
||||
self.sum_other_doc_count
|
||||
}
|
||||
|
||||
/// Returns the upper bound of the error on document counts in the returned buckets.
|
||||
pub fn doc_count_error_upper_bound(&self) -> u64 {
|
||||
self.doc_count_error_upper_bound
|
||||
}
|
||||
|
||||
pub(crate) fn into_final_result(
|
||||
self,
|
||||
req: &TermsAggregation,
|
||||
|
||||
@@ -55,6 +55,12 @@ impl IntermediateAverage {
|
||||
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
|
||||
Self { stats }
|
||||
}
|
||||
|
||||
/// Returns a reference to the underlying [`IntermediateStats`].
|
||||
pub fn stats(&self) -> &IntermediateStats {
|
||||
&self.stats
|
||||
}
|
||||
|
||||
/// Merges the other intermediate result into self.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateAverage) {
|
||||
self.stats.merge_fruits(other.stats);
|
||||
@@ -63,16 +69,6 @@ impl IntermediateAverage {
|
||||
pub fn finalize(&self) -> Option<f64> {
|
||||
self.stats.finalize().avg
|
||||
}
|
||||
|
||||
/// Returns the sum of all collected values.
|
||||
pub fn sum(&self) -> f64 {
|
||||
self.stats.sum
|
||||
}
|
||||
|
||||
/// Returns the count of all collected values.
|
||||
pub fn count(&self) -> u64 {
|
||||
self.stats.count
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{BuildHasher, Hasher};
|
||||
use std::hash::Hash;
|
||||
|
||||
use columnar::column_values::CompactSpaceU64Accessor;
|
||||
use columnar::{Column, ColumnType, Dictionary, StrColumn};
|
||||
use common::f64_to_u64;
|
||||
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
|
||||
use datasketches::hll::{HllSketch, HllType, HllUnion};
|
||||
use rustc_hash::FxHashSet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
use crate::aggregation::agg_data::AggregationsSegmentCtx;
|
||||
use crate::aggregation::intermediate_agg_result::{
|
||||
@@ -16,29 +15,17 @@ use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
|
||||
use crate::aggregation::*;
|
||||
use crate::TantivyError;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
struct BuildSaltedHasher {
|
||||
salt: u8,
|
||||
}
|
||||
|
||||
impl BuildHasher for BuildSaltedHasher {
|
||||
type Hasher = DefaultHasher;
|
||||
|
||||
fn build_hasher(&self) -> Self::Hasher {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
hasher.write_u8(self.salt);
|
||||
|
||||
hasher
|
||||
}
|
||||
}
|
||||
/// Log2 of the number of registers. Must match the Java `Union(LOG2M)` where LOG2M=11.
|
||||
/// 2^11 = 2048 registers.
|
||||
const LG_K: u8 = 11;
|
||||
|
||||
/// # Cardinality
|
||||
///
|
||||
/// The cardinality aggregation allows for computing an estimate
|
||||
/// of the number of different values in a data set based on the
|
||||
/// HyperLogLog++ algorithm. This is particularly useful for understanding the
|
||||
/// uniqueness of values in a large dataset where counting each unique value
|
||||
/// individually would be computationally expensive.
|
||||
/// Apache DataSketches HyperLogLog algorithm. This is particularly useful for
|
||||
/// understanding the uniqueness of values in a large dataset where counting
|
||||
/// each unique value individually would be computationally expensive.
|
||||
///
|
||||
/// For example, you might use a cardinality aggregation to estimate the number
|
||||
/// of unique visitors to a website by aggregating on a field that contains
|
||||
@@ -184,7 +171,7 @@ impl SegmentCardinalityCollectorBucket {
|
||||
|
||||
term_ids.sort_unstable();
|
||||
dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| {
|
||||
self.cardinality.sketch.insert_any(&term);
|
||||
self.cardinality.insert(term);
|
||||
Ok(())
|
||||
})?;
|
||||
if has_missing {
|
||||
@@ -195,17 +182,17 @@ impl SegmentCardinalityCollectorBucket {
|
||||
);
|
||||
match missing_key {
|
||||
Key::Str(missing) => {
|
||||
self.cardinality.sketch.insert_any(&missing);
|
||||
self.cardinality.insert(missing.as_str());
|
||||
}
|
||||
Key::F64(val) => {
|
||||
let val = f64_to_u64(*val);
|
||||
self.cardinality.sketch.insert_any(&val);
|
||||
self.cardinality.insert(val);
|
||||
}
|
||||
Key::U64(val) => {
|
||||
self.cardinality.sketch.insert_any(&val);
|
||||
self.cardinality.insert(*val);
|
||||
}
|
||||
Key::I64(val) => {
|
||||
self.cardinality.sketch.insert_any(&val);
|
||||
self.cardinality.insert(*val);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -296,11 +283,11 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
|
||||
})?;
|
||||
for val in col_block_accessor.iter_vals() {
|
||||
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
|
||||
bucket.cardinality.sketch.insert_any(&val);
|
||||
bucket.cardinality.insert(val);
|
||||
}
|
||||
} else {
|
||||
for val in col_block_accessor.iter_vals() {
|
||||
bucket.cardinality.sketch.insert_any(&val);
|
||||
bucket.cardinality.insert(val);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -321,11 +308,17 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
/// The percentiles collector used during segment collection and for merging results.
|
||||
#[derive(Clone, Debug)]
|
||||
/// The cardinality collector used during segment collection and for merging results.
|
||||
/// Uses Apache DataSketches HLL (lg_k=11) for compatibility with Datadog's event query.
|
||||
pub struct CardinalityCollector {
|
||||
sketch: HyperLogLogPlus<u64, BuildSaltedHasher>,
|
||||
sketch: HllSketch,
|
||||
/// Salt derived from `ColumnType`, used to differentiate values of different column types
|
||||
/// that map to the same u64 (e.g. bool `false` = 0 vs i64 `0`).
|
||||
/// Not serialized — only needed during insertion, not after sketch registers are populated.
|
||||
salt: u8,
|
||||
}
|
||||
|
||||
impl Default for CardinalityCollector {
|
||||
fn default() -> Self {
|
||||
Self::new(0)
|
||||
@@ -338,25 +331,52 @@ impl PartialEq for CardinalityCollector {
|
||||
}
|
||||
}
|
||||
|
||||
impl CardinalityCollector {
|
||||
/// Compute the final cardinality estimate.
|
||||
pub fn finalize(&self) -> Option<f64> {
|
||||
Some(self.sketch.clone().count().trunc())
|
||||
impl Serialize for CardinalityCollector {
|
||||
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
let bytes = self.sketch.serialize();
|
||||
serializer.serialize_bytes(&bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for CardinalityCollector {
|
||||
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
|
||||
let bytes: Vec<u8> = Deserialize::deserialize(deserializer)?;
|
||||
let sketch = HllSketch::deserialize(&bytes).map_err(serde::de::Error::custom)?;
|
||||
Ok(Self { sketch, salt: 0 })
|
||||
}
|
||||
}
|
||||
|
||||
impl CardinalityCollector {
|
||||
fn new(salt: u8) -> Self {
|
||||
Self {
|
||||
sketch: HyperLogLogPlus::new(16, BuildSaltedHasher { salt }).unwrap(),
|
||||
sketch: HllSketch::new(LG_K, HllType::Hll4),
|
||||
salt,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn merge_fruits(&mut self, right: CardinalityCollector) -> crate::Result<()> {
|
||||
self.sketch.merge(&right.sketch).map_err(|err| {
|
||||
TantivyError::AggregationError(AggregationError::InternalError(format!(
|
||||
"Error while merging cardinality {err:?}"
|
||||
)))
|
||||
})?;
|
||||
/// Insert a value into the HLL sketch, salted by the column type.
|
||||
/// The salt ensures that identical u64 values from different column types
|
||||
/// (e.g. bool `false` vs i64 `0`) are counted as distinct.
|
||||
pub(crate) fn insert<T: Hash>(&mut self, value: T) {
|
||||
self.sketch.update((self.salt, value));
|
||||
}
|
||||
|
||||
/// Compute the final cardinality estimate.
|
||||
pub fn finalize(self) -> Option<f64> {
|
||||
Some(self.sketch.estimate().trunc())
|
||||
}
|
||||
|
||||
/// Serialize the HLL sketch to its compact binary representation.
|
||||
/// This format is compatible with Apache DataSketches Java (`HllSketch.heapify()`).
|
||||
pub fn to_sketch_bytes(&self) -> Vec<u8> {
|
||||
self.sketch.serialize()
|
||||
}
|
||||
|
||||
pub(crate) fn merge_fruits(&mut self, right: CardinalityCollector) -> crate::Result<()> {
|
||||
let mut union = HllUnion::new(LG_K);
|
||||
union.update(&self.sketch);
|
||||
union.update(&right.sketch);
|
||||
self.sketch = union.get_result(HllType::Hll4);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -518,4 +538,75 @@ mod tests {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cardinality_collector_serde_roundtrip() {
|
||||
use super::CardinalityCollector;
|
||||
|
||||
let mut collector = CardinalityCollector::default();
|
||||
collector.insert("hello");
|
||||
collector.insert("world");
|
||||
collector.insert("hello"); // duplicate
|
||||
|
||||
let serialized = serde_json::to_vec(&collector).unwrap();
|
||||
let deserialized: CardinalityCollector = serde_json::from_slice(&serialized).unwrap();
|
||||
|
||||
let original_estimate = collector.finalize().unwrap();
|
||||
let roundtrip_estimate = deserialized.finalize().unwrap();
|
||||
assert_eq!(original_estimate, roundtrip_estimate);
|
||||
assert_eq!(original_estimate, 2.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cardinality_collector_merge() {
|
||||
use super::CardinalityCollector;
|
||||
|
||||
let mut left = CardinalityCollector::default();
|
||||
left.insert("a");
|
||||
left.insert("b");
|
||||
|
||||
let mut right = CardinalityCollector::default();
|
||||
right.insert("b");
|
||||
right.insert("c");
|
||||
|
||||
left.merge_fruits(right).unwrap();
|
||||
let estimate = left.finalize().unwrap();
|
||||
assert_eq!(estimate, 3.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cardinality_collector_serialize_deserialize_binary() {
|
||||
use datasketches::hll::HllSketch;
|
||||
|
||||
use super::CardinalityCollector;
|
||||
|
||||
let mut collector = CardinalityCollector::default();
|
||||
collector.insert("apple");
|
||||
collector.insert("banana");
|
||||
collector.insert("cherry");
|
||||
|
||||
let bytes = collector.to_sketch_bytes();
|
||||
let deserialized = HllSketch::deserialize(&bytes).unwrap();
|
||||
assert!((deserialized.estimate() - 3.0).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cardinality_collector_salt_differentiates_types() {
|
||||
use super::CardinalityCollector;
|
||||
|
||||
// Without salt, same u64 value from different column types would collide
|
||||
let mut collector_bool = CardinalityCollector::new(5); // e.g. ColumnType::Bool
|
||||
collector_bool.insert(0u64); // false
|
||||
collector_bool.insert(1u64); // true
|
||||
|
||||
let mut collector_i64 = CardinalityCollector::new(2); // e.g. ColumnType::I64
|
||||
collector_i64.insert(0u64);
|
||||
collector_i64.insert(1u64);
|
||||
|
||||
// Merge them
|
||||
collector_bool.merge_fruits(collector_i64).unwrap();
|
||||
let estimate = collector_bool.finalize().unwrap();
|
||||
// Should be 4 because salt makes (5, 0) != (2, 0) and (5, 1) != (2, 1)
|
||||
assert_eq!(estimate, 4.0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,41 +93,6 @@ impl From<Option<f64>> for SingleMetricResult {
|
||||
}
|
||||
}
|
||||
|
||||
/// Average metric result with intermediate data for merging.
|
||||
///
|
||||
/// Unlike [`SingleMetricResult`], this struct includes the raw `sum` and `count`
|
||||
/// values that can be used for multi-step query merging.
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct AverageMetricResult {
|
||||
/// The computed average value. None if no documents matched.
|
||||
pub value: Option<f64>,
|
||||
/// The sum of all values (for multi-step merging).
|
||||
pub sum: f64,
|
||||
/// The count of all values (for multi-step merging).
|
||||
pub count: u64,
|
||||
}
|
||||
|
||||
/// Cardinality metric result with computed value and raw HLL sketch for multi-step merging.
|
||||
///
|
||||
/// The `value` field contains the computed cardinality estimate.
|
||||
/// The `sketch` field contains the serialized HyperLogLog++ sketch that can be used
|
||||
/// for merging results across multiple query steps.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct CardinalityMetricResult {
|
||||
/// The computed cardinality estimate.
|
||||
pub value: Option<f64>,
|
||||
/// The serialized HyperLogLog++ sketch for multi-step merging.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sketch: Option<CardinalityCollector>,
|
||||
}
|
||||
|
||||
impl PartialEq for CardinalityMetricResult {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
// Only compare values, not sketch (sketch comparison is complex)
|
||||
self.value == other.value
|
||||
}
|
||||
}
|
||||
|
||||
/// This is the wrapper of percentile entries, which can be vector or hashmap
|
||||
/// depending on if it's keyed or not.
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
@@ -142,30 +107,20 @@ pub enum PercentileValues {
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
/// The entry when requesting percentiles with keyed: false
|
||||
pub struct PercentileValuesVecEntry {
|
||||
key: f64,
|
||||
value: f64,
|
||||
/// Percentile
|
||||
pub key: f64,
|
||||
|
||||
/// Value at the percentile
|
||||
pub value: f64,
|
||||
}
|
||||
|
||||
/// Percentiles metric result with computed values and raw sketch for multi-step merging.
|
||||
/// Single-metric aggregations use this common result structure.
|
||||
///
|
||||
/// The `values` field contains the computed percentile values.
|
||||
/// The `sketch` field contains the serialized DDSketch that can be used for merging
|
||||
/// results across multiple query steps.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
/// Main reason to wrap it in value is to match elasticsearch output structure.
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct PercentilesMetricResult {
|
||||
/// The computed percentile values.
|
||||
/// The result of the percentile metric.
|
||||
pub values: PercentileValues,
|
||||
/// The serialized DDSketch for multi-step merging.
|
||||
/// This is the raw sketch data that can be deserialized and merged with other sketches.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sketch: Option<PercentilesCollector>,
|
||||
}
|
||||
|
||||
impl PartialEq for PercentilesMetricResult {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
// Only compare values, not sketch (sketch comparison is complex)
|
||||
self.values == other.values
|
||||
}
|
||||
}
|
||||
|
||||
/// The top_hits metric results entry
|
||||
@@ -246,105 +201,4 @@ mod tests {
|
||||
assert_eq!(aggregations_res_json["price_min"]["value"], 0.0);
|
||||
assert_eq!(aggregations_res_json["price_sum"]["value"], 15.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_average_returns_sum_and_count() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let field_options = NumericOptions::default().set_fast();
|
||||
let field = schema_builder.add_f64_field("price", field_options);
|
||||
let index = Index::create_in_ram(schema_builder.build());
|
||||
let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
|
||||
|
||||
// Add documents with values 0, 1, 2, 3, 4, 5
|
||||
// sum = 15, count = 6, avg = 2.5
|
||||
for i in 0..6 {
|
||||
index_writer
|
||||
.add_document(doc!(
|
||||
field => i as f64,
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
index_writer.commit().unwrap();
|
||||
|
||||
let aggregations_json = r#"{ "price_avg": { "avg": { "field": "price" } } }"#;
|
||||
let aggregations: Aggregations = serde_json::from_str(aggregations_json).unwrap();
|
||||
let collector = AggregationCollector::from_aggs(aggregations, Default::default());
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let aggregations_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
|
||||
let aggregations_res_json = serde_json::to_value(aggregations_res).unwrap();
|
||||
|
||||
// Verify all three fields are present and correct
|
||||
assert_eq!(aggregations_res_json["price_avg"]["value"], 2.5);
|
||||
assert_eq!(aggregations_res_json["price_avg"]["sum"], 15.0);
|
||||
assert_eq!(aggregations_res_json["price_avg"]["count"], 6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_percentiles_returns_sketch() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let field_options = NumericOptions::default().set_fast();
|
||||
let field = schema_builder.add_f64_field("latency", field_options);
|
||||
let index = Index::create_in_ram(schema_builder.build());
|
||||
let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
|
||||
|
||||
// Add documents with latency values
|
||||
for i in 0..100 {
|
||||
index_writer
|
||||
.add_document(doc!(
|
||||
field => i as f64,
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
index_writer.commit().unwrap();
|
||||
|
||||
let aggregations_json =
|
||||
r#"{ "latency_percentiles": { "percentiles": { "field": "latency" } } }"#;
|
||||
let aggregations: Aggregations = serde_json::from_str(aggregations_json).unwrap();
|
||||
let collector = AggregationCollector::from_aggs(aggregations, Default::default());
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let aggregations_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
|
||||
let aggregations_res_json = serde_json::to_value(aggregations_res).unwrap();
|
||||
|
||||
// Verify percentile values are present
|
||||
assert!(aggregations_res_json["latency_percentiles"]["values"].is_object());
|
||||
// Verify sketch is present (serialized DDSketch)
|
||||
assert!(aggregations_res_json["latency_percentiles"]["sketch"].is_object());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cardinality_returns_sketch() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let field_options = NumericOptions::default().set_fast();
|
||||
let field = schema_builder.add_u64_field("user_id", field_options);
|
||||
let index = Index::create_in_ram(schema_builder.build());
|
||||
let mut index_writer: IndexWriter = index.writer_for_tests().unwrap();
|
||||
|
||||
// Add documents with some duplicate user_ids
|
||||
for i in 0..50 {
|
||||
index_writer
|
||||
.add_document(doc!(
|
||||
field => (i % 10) as u64, // 10 unique values
|
||||
))
|
||||
.unwrap();
|
||||
}
|
||||
index_writer.commit().unwrap();
|
||||
|
||||
let aggregations_json = r#"{ "unique_users": { "cardinality": { "field": "user_id" } } }"#;
|
||||
let aggregations: Aggregations = serde_json::from_str(aggregations_json).unwrap();
|
||||
let collector = AggregationCollector::from_aggs(aggregations, Default::default());
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
let aggregations_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
|
||||
let aggregations_res_json = serde_json::to_value(aggregations_res).unwrap();
|
||||
|
||||
// Verify cardinality value is present and approximately correct
|
||||
let cardinality = aggregations_res_json["unique_users"]["value"]
|
||||
.as_f64()
|
||||
.unwrap();
|
||||
assert!(cardinality >= 9.0 && cardinality <= 11.0); // HLL is approximate
|
||||
// Verify sketch is present (serialized HyperLogLog++)
|
||||
assert!(aggregations_res_json["unique_users"]["sketch"].is_object());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -178,9 +178,6 @@ fn format_percentile(percentile: f64) -> String {
|
||||
impl PercentilesCollector {
|
||||
/// Convert result into final result. This will query the quantils from the underlying quantil
|
||||
/// collector.
|
||||
///
|
||||
/// The result includes both the computed percentile values and the raw DDSketch
|
||||
/// for multi-step query merging.
|
||||
pub fn into_final_result(self, req: &PercentilesAggregationReq) -> PercentilesMetricResult {
|
||||
let percentiles: &[f64] = req
|
||||
.percents
|
||||
@@ -213,15 +210,7 @@ impl PercentilesCollector {
|
||||
.collect(),
|
||||
)
|
||||
};
|
||||
PercentilesMetricResult {
|
||||
values,
|
||||
sketch: Some(self),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a reference to the underlying DDSketch.
|
||||
pub fn sketch(&self) -> &sketches_ddsketch::DDSketch {
|
||||
&self.sketch
|
||||
PercentilesMetricResult { values }
|
||||
}
|
||||
|
||||
fn new() -> Self {
|
||||
|
||||
@@ -110,6 +110,16 @@ impl Default for IntermediateStats {
|
||||
}
|
||||
|
||||
impl IntermediateStats {
|
||||
/// Returns the number of values collected.
|
||||
pub fn count(&self) -> u64 {
|
||||
self.count
|
||||
}
|
||||
|
||||
/// Returns the sum of all values collected.
|
||||
pub fn sum(&self) -> f64 {
|
||||
self.sum
|
||||
}
|
||||
|
||||
/// Merges the other stats intermediate result into self.
|
||||
pub fn merge_fruits(&mut self, other: IntermediateStats) {
|
||||
self.count += other.count;
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
mod order;
|
||||
mod sort_by_bytes;
|
||||
mod sort_by_erased_type;
|
||||
mod sort_by_score;
|
||||
mod sort_by_static_fast_value;
|
||||
@@ -6,6 +7,7 @@ mod sort_by_string;
|
||||
mod sort_key_computer;
|
||||
|
||||
pub use order::*;
|
||||
pub use sort_by_bytes::SortByBytes;
|
||||
pub use sort_by_erased_type::SortByErasedType;
|
||||
pub use sort_by_score::SortBySimilarityScore;
|
||||
pub use sort_by_static_fast_value::SortByStaticFastValue;
|
||||
|
||||
168
src/collector/sort_key/sort_by_bytes.rs
Normal file
168
src/collector/sort_key/sort_by_bytes.rs
Normal file
@@ -0,0 +1,168 @@
|
||||
use columnar::BytesColumn;
|
||||
|
||||
use crate::collector::sort_key::NaturalComparator;
|
||||
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::termdict::TermOrdinal;
|
||||
use crate::{DocId, Score};
|
||||
|
||||
/// Sort by the first value of a bytes column.
|
||||
///
|
||||
/// If the field is multivalued, only the first value is considered.
|
||||
///
|
||||
/// Documents that do not have this value are still considered.
|
||||
/// Their sort key will simply be `None`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SortByBytes {
|
||||
column_name: String,
|
||||
}
|
||||
|
||||
impl SortByBytes {
|
||||
/// Creates a new sort by bytes sort key computer.
|
||||
pub fn for_field(column_name: impl ToString) -> Self {
|
||||
SortByBytes {
|
||||
column_name: column_name.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SortKeyComputer for SortByBytes {
|
||||
type SortKey = Option<Vec<u8>>;
|
||||
type Child = ByBytesColumnSegmentSortKeyComputer;
|
||||
type Comparator = NaturalComparator;
|
||||
|
||||
fn segment_sort_key_computer(
|
||||
&self,
|
||||
segment_reader: &crate::SegmentReader,
|
||||
) -> crate::Result<Self::Child> {
|
||||
let bytes_column_opt = segment_reader.fast_fields().bytes(&self.column_name)?;
|
||||
Ok(ByBytesColumnSegmentSortKeyComputer { bytes_column_opt })
|
||||
}
|
||||
}
|
||||
|
||||
/// Segment-level sort key computer for bytes columns.
|
||||
pub struct ByBytesColumnSegmentSortKeyComputer {
|
||||
bytes_column_opt: Option<BytesColumn>,
|
||||
}
|
||||
|
||||
impl SegmentSortKeyComputer for ByBytesColumnSegmentSortKeyComputer {
|
||||
type SortKey = Option<Vec<u8>>;
|
||||
type SegmentSortKey = Option<TermOrdinal>;
|
||||
type SegmentComparator = NaturalComparator;
|
||||
|
||||
#[inline(always)]
|
||||
fn segment_sort_key(&mut self, doc: DocId, _score: Score) -> Option<TermOrdinal> {
|
||||
let bytes_column = self.bytes_column_opt.as_ref()?;
|
||||
bytes_column.ords().first(doc)
|
||||
}
|
||||
|
||||
fn convert_segment_sort_key(&self, term_ord_opt: Option<TermOrdinal>) -> Option<Vec<u8>> {
|
||||
// TODO: Individual lookups to the dictionary like this are very likely to repeatedly
|
||||
// decompress the same blocks. See https://github.com/quickwit-oss/tantivy/issues/2776
|
||||
let term_ord = term_ord_opt?;
|
||||
let bytes_column = self.bytes_column_opt.as_ref()?;
|
||||
let mut bytes = Vec::new();
|
||||
bytes_column
|
||||
.dictionary()
|
||||
.ord_to_term(term_ord, &mut bytes)
|
||||
.ok()?;
|
||||
Some(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::SortByBytes;
|
||||
use crate::collector::TopDocs;
|
||||
use crate::query::AllQuery;
|
||||
use crate::schema::{BytesOptions, Schema, FAST, INDEXED};
|
||||
use crate::{Index, IndexWriter, Order, TantivyDocument};
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_bytes_asc() -> crate::Result<()> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let bytes_field = schema_builder
|
||||
.add_bytes_field("data", BytesOptions::default().set_fast().set_indexed());
|
||||
let id_field = schema_builder.add_u64_field("id", FAST | INDEXED);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer: IndexWriter = index.writer_for_tests()?;
|
||||
|
||||
// Insert documents with byte values in non-sorted order
|
||||
let test_data: Vec<(u64, Vec<u8>)> = vec![
|
||||
(1, vec![0x02, 0x00]),
|
||||
(2, vec![0x00, 0x10]),
|
||||
(3, vec![0x01, 0x00]),
|
||||
(4, vec![0x00, 0x20]),
|
||||
];
|
||||
|
||||
for (id, bytes) in &test_data {
|
||||
let mut doc = TantivyDocument::new();
|
||||
doc.add_u64(id_field, *id);
|
||||
doc.add_bytes(bytes_field, bytes);
|
||||
index_writer.add_document(doc)?;
|
||||
}
|
||||
index_writer.commit()?;
|
||||
|
||||
let reader = index.reader()?;
|
||||
let searcher = reader.searcher();
|
||||
|
||||
// Sort ascending by bytes
|
||||
let top_docs =
|
||||
TopDocs::with_limit(10).order_by((SortByBytes::for_field("data"), Order::Asc));
|
||||
let results: Vec<(Option<Vec<u8>>, _)> = searcher.search(&AllQuery, &top_docs)?;
|
||||
|
||||
// Expected order: [0x00,0x10], [0x00,0x20], [0x01,0x00], [0x02,0x00]
|
||||
let sorted_bytes: Vec<Option<Vec<u8>>> = results.into_iter().map(|(b, _)| b).collect();
|
||||
assert_eq!(
|
||||
sorted_bytes,
|
||||
vec![
|
||||
Some(vec![0x00, 0x10]),
|
||||
Some(vec![0x00, 0x20]),
|
||||
Some(vec![0x01, 0x00]),
|
||||
Some(vec![0x02, 0x00]),
|
||||
]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_bytes_desc() -> crate::Result<()> {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let bytes_field = schema_builder
|
||||
.add_bytes_field("data", BytesOptions::default().set_fast().set_indexed());
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut index_writer: IndexWriter = index.writer_for_tests()?;
|
||||
|
||||
let test_data: Vec<Vec<u8>> = vec![vec![0x00, 0x10], vec![0x02, 0x00], vec![0x01, 0x00]];
|
||||
|
||||
for bytes in &test_data {
|
||||
let mut doc = TantivyDocument::new();
|
||||
doc.add_bytes(bytes_field, bytes);
|
||||
index_writer.add_document(doc)?;
|
||||
}
|
||||
index_writer.commit()?;
|
||||
|
||||
let reader = index.reader()?;
|
||||
let searcher = reader.searcher();
|
||||
|
||||
// Sort descending by bytes
|
||||
let top_docs =
|
||||
TopDocs::with_limit(10).order_by((SortByBytes::for_field("data"), Order::Desc));
|
||||
let results: Vec<(Option<Vec<u8>>, _)> = searcher.search(&AllQuery, &top_docs)?;
|
||||
|
||||
// Expected order (descending): [0x02,0x00], [0x01,0x00], [0x00,0x10]
|
||||
let sorted_bytes: Vec<Option<Vec<u8>>> = results.into_iter().map(|(b, _)| b).collect();
|
||||
assert_eq!(
|
||||
sorted_bytes,
|
||||
vec![
|
||||
Some(vec![0x02, 0x00]),
|
||||
Some(vec![0x01, 0x00]),
|
||||
Some(vec![0x00, 0x10]),
|
||||
]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
use columnar::{ColumnType, MonotonicallyMappableToU64};
|
||||
|
||||
use crate::collector::sort_key::{
|
||||
NaturalComparator, SortBySimilarityScore, SortByStaticFastValue, SortByString,
|
||||
NaturalComparator, SortByBytes, SortBySimilarityScore, SortByStaticFastValue, SortByString,
|
||||
};
|
||||
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
|
||||
use crate::fastfield::FastFieldNotAvailableError;
|
||||
@@ -114,6 +114,16 @@ impl SortKeyComputer for SortByErasedType {
|
||||
},
|
||||
})
|
||||
}
|
||||
ColumnType::Bytes => {
|
||||
let computer = SortByBytes::for_field(column_name);
|
||||
let inner = computer.segment_sort_key_computer(segment_reader)?;
|
||||
Box::new(ErasedSegmentSortKeyComputerWrapper {
|
||||
inner,
|
||||
converter: |val: Option<Vec<u8>>| {
|
||||
val.map(OwnedValue::Bytes).unwrap_or(OwnedValue::Null)
|
||||
},
|
||||
})
|
||||
}
|
||||
ColumnType::U64 => {
|
||||
let computer = SortByStaticFastValue::<u64>::for_field(column_name);
|
||||
let inner = computer.segment_sort_key_computer(segment_reader)?;
|
||||
@@ -281,6 +291,65 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_owned_bytes() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
let data_field = schema_builder.add_bytes_field("data", FAST);
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema);
|
||||
let mut writer = index.writer_for_tests().unwrap();
|
||||
writer
|
||||
.add_document(doc!(data_field => vec![0x03u8, 0x00]))
|
||||
.unwrap();
|
||||
writer
|
||||
.add_document(doc!(data_field => vec![0x01u8, 0x00]))
|
||||
.unwrap();
|
||||
writer
|
||||
.add_document(doc!(data_field => vec![0x02u8, 0x00]))
|
||||
.unwrap();
|
||||
writer.add_document(doc!()).unwrap();
|
||||
writer.commit().unwrap();
|
||||
|
||||
let reader = index.reader().unwrap();
|
||||
let searcher = reader.searcher();
|
||||
|
||||
// Sort descending (Natural - highest first)
|
||||
let collector = TopDocs::with_limit(10)
|
||||
.order_by((SortByErasedType::for_field("data"), ComparatorEnum::Natural));
|
||||
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
|
||||
|
||||
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
|
||||
|
||||
assert_eq!(
|
||||
values,
|
||||
vec![
|
||||
OwnedValue::Bytes(vec![0x03, 0x00]),
|
||||
OwnedValue::Bytes(vec![0x02, 0x00]),
|
||||
OwnedValue::Bytes(vec![0x01, 0x00]),
|
||||
OwnedValue::Null
|
||||
]
|
||||
);
|
||||
|
||||
// Sort ascending (ReverseNoneLower - lowest first, nulls last)
|
||||
let collector = TopDocs::with_limit(10).order_by((
|
||||
SortByErasedType::for_field("data"),
|
||||
ComparatorEnum::ReverseNoneLower,
|
||||
));
|
||||
let top_docs = searcher.search(&AllQuery, &collector).unwrap();
|
||||
|
||||
let values: Vec<OwnedValue> = top_docs.into_iter().map(|(key, _)| key).collect();
|
||||
|
||||
assert_eq!(
|
||||
values,
|
||||
vec![
|
||||
OwnedValue::Bytes(vec![0x01, 0x00]),
|
||||
OwnedValue::Bytes(vec![0x02, 0x00]),
|
||||
OwnedValue::Bytes(vec![0x03, 0x00]),
|
||||
OwnedValue::Null
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_by_owned_reverse() {
|
||||
let mut schema_builder = Schema::builder();
|
||||
|
||||
@@ -2068,6 +2068,16 @@ mod test {
|
||||
format!("Regex(Field(0), {:#?})", expected_regex).as_str(),
|
||||
false,
|
||||
);
|
||||
let expected_regex2 = tantivy_fst::Regex::new(r".*a").unwrap();
|
||||
test_parse_query_to_logical_ast_helper(
|
||||
"title:(/.*b/ OR /.*a/)",
|
||||
format!(
|
||||
"(Regex(Field(0), {:#?}) Regex(Field(0), {:#?}))",
|
||||
expected_regex, expected_regex2
|
||||
)
|
||||
.as_str(),
|
||||
false,
|
||||
);
|
||||
|
||||
// Invalid field
|
||||
let err = parse_query_to_logical_ast("float:/.*b/", false).unwrap_err();
|
||||
|
||||
@@ -19,7 +19,8 @@ pub(crate) fn is_type_valid_for_fastfield_range_query(typ: Type) -> bool {
|
||||
| Type::Bool
|
||||
| Type::Date
|
||||
| Type::Json
|
||||
| Type::IpAddr => true,
|
||||
Type::Facet | Type::Bytes => false,
|
||||
| Type::IpAddr
|
||||
| Type::Bytes => true,
|
||||
Type::Facet => false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,8 +6,8 @@ use std::net::Ipv6Addr;
|
||||
use std::ops::{Bound, RangeInclusive};
|
||||
|
||||
use columnar::{
|
||||
Cardinality, Column, ColumnType, MonotonicallyMappableToU128, MonotonicallyMappableToU64,
|
||||
NumericalType, StrColumn,
|
||||
BytesColumn, Cardinality, Column, ColumnType, MonotonicallyMappableToU128,
|
||||
MonotonicallyMappableToU64, NumericalType, StrColumn,
|
||||
};
|
||||
use common::bounds::{BoundsRange, TransformBound};
|
||||
|
||||
@@ -163,6 +163,25 @@ impl Weight for FastFieldRangeWeight {
|
||||
};
|
||||
let dict = str_dict_column.dictionary();
|
||||
|
||||
let bounds = self.bounds.map_bound(get_value_bytes);
|
||||
// Get term ids for terms
|
||||
let (lower_bound, upper_bound) =
|
||||
dict.term_bounds_to_ord(bounds.lower_bound, bounds.upper_bound)?;
|
||||
let fast_field_reader = reader.fast_fields();
|
||||
let Some((column, _col_type)) =
|
||||
fast_field_reader.u64_lenient_for_type(None, &field_name)?
|
||||
else {
|
||||
return Ok(Box::new(EmptyScorer));
|
||||
};
|
||||
search_on_u64_ff(column, boost, BoundsRange::new(lower_bound, upper_bound))
|
||||
} else if field_type.is_bytes() {
|
||||
let Some(bytes_column): Option<BytesColumn> =
|
||||
reader.fast_fields().bytes(&field_name)?
|
||||
else {
|
||||
return Ok(Box::new(EmptyScorer));
|
||||
};
|
||||
let dict = bytes_column.dictionary();
|
||||
|
||||
let bounds = self.bounds.map_bound(get_value_bytes);
|
||||
// Get term ids for terms
|
||||
let (lower_bound, upper_bound) =
|
||||
@@ -1402,6 +1421,66 @@ mod tests {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bytes_field_ff_range_query() -> crate::Result<()> {
|
||||
use crate::schema::BytesOptions;
|
||||
|
||||
let mut schema_builder = Schema::builder();
|
||||
let bytes_field = schema_builder
|
||||
.add_bytes_field("data", BytesOptions::default().set_fast().set_indexed());
|
||||
let schema = schema_builder.build();
|
||||
let index = Index::create_in_ram(schema.clone());
|
||||
let mut index_writer: IndexWriter = index.writer_for_tests()?;
|
||||
|
||||
// Insert documents with lexicographically sortable byte values
|
||||
// Using simple byte sequences that have clear ordering
|
||||
let values: Vec<Vec<u8>> = vec![
|
||||
vec![0x00, 0x10],
|
||||
vec![0x00, 0x20],
|
||||
vec![0x00, 0x30],
|
||||
vec![0x01, 0x00],
|
||||
vec![0x01, 0x10],
|
||||
vec![0x02, 0x00],
|
||||
];
|
||||
|
||||
for value in &values {
|
||||
let mut doc = TantivyDocument::new();
|
||||
doc.add_bytes(bytes_field, value);
|
||||
index_writer.add_document(doc)?;
|
||||
}
|
||||
index_writer.commit()?;
|
||||
|
||||
let reader = index.reader()?;
|
||||
let searcher = reader.searcher();
|
||||
|
||||
// Test: Range query [0x00, 0x20] to [0x01, 0x00] (inclusive)
|
||||
// Should match: [0x00, 0x20], [0x00, 0x30], [0x01, 0x00]
|
||||
let lower = Term::from_field_bytes(bytes_field, &[0x00, 0x20]);
|
||||
let upper = Term::from_field_bytes(bytes_field, &[0x01, 0x00]);
|
||||
let range_query = RangeQuery::new(Bound::Included(lower), Bound::Included(upper));
|
||||
let count = searcher.search(&range_query, &Count)?;
|
||||
assert_eq!(
|
||||
count, 3,
|
||||
"Expected 3 documents in range [0x00,0x20] to [0x01,0x00]"
|
||||
);
|
||||
|
||||
// Test: Range query > [0x01, 0x00] (exclusive lower bound)
|
||||
// Should match: [0x01, 0x10], [0x02, 0x00]
|
||||
let lower = Term::from_field_bytes(bytes_field, &[0x01, 0x00]);
|
||||
let range_query = RangeQuery::new(Bound::Excluded(lower), Bound::Unbounded);
|
||||
let count = searcher.search(&range_query, &Count)?;
|
||||
assert_eq!(count, 2, "Expected 2 documents > [0x01,0x00]");
|
||||
|
||||
// Test: Range query < [0x00, 0x30] (exclusive upper bound)
|
||||
// Should match: [0x00, 0x10], [0x00, 0x20]
|
||||
let upper = Term::from_field_bytes(bytes_field, &[0x00, 0x30]);
|
||||
let range_query = RangeQuery::new(Bound::Unbounded, Bound::Excluded(upper));
|
||||
let count = searcher.search(&range_query, &Count)?;
|
||||
assert_eq!(count, 2, "Expected 2 documents < [0x00,0x30]");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -223,6 +223,11 @@ impl FieldType {
|
||||
matches!(self, FieldType::Str(_))
|
||||
}
|
||||
|
||||
/// returns true if this is a bytes field
|
||||
pub fn is_bytes(&self) -> bool {
|
||||
matches!(self, FieldType::Bytes(_))
|
||||
}
|
||||
|
||||
/// returns true if this is an date field
|
||||
pub fn is_date(&self) -> bool {
|
||||
matches!(self, FieldType::Date(_))
|
||||
|
||||
Reference in New Issue
Block a user