Compare commits

..

14 Commits

Author SHA1 Message Date
Pascal Seitz
3ddca31292 simplify fetch block in column_block_accessor 2025-12-29 10:56:50 +01:00
Pascal Seitz
87fe3a311f share column block accessor 2025-12-12 16:54:44 +08:00
Pascal Seitz
71dc08424c add comment 2025-12-12 16:19:06 +08:00
Pascal Seitz
15913446b8 cleanup
remove clone
move data in term req, single doc opt for stats
2025-12-10 14:34:43 +08:00
Pascal Seitz
78bd3826dc remove stacktrace bloat, use &mut helper
increase cache to 2048
2025-12-08 10:35:23 +08:00
Pascal Seitz
1b56487307 specialize columntype in stats 2025-12-08 10:20:09 +08:00
Pascal Seitz
030554d544 use radix map, fix prepare_max_bucket
use paged term map in term agg
use special no sub agg term map impl
2025-12-08 10:20:09 +08:00
Pascal Seitz
c852bac532 reduce dynamic dispatch, faster term agg 2025-12-08 10:20:09 +08:00
Pascal Seitz
2ce4da8b66 one collector per agg request instead per bucket
In this refactoring a collector knows in which bucket of the parent
their data is in. This allows to convert the previous approach of one
collector per bucket to one collector per request.

low card bucket optimization
2025-12-08 10:20:09 +08:00
Pascal Seitz
0dd6a958f8 add more tests for new collection type 2025-12-08 10:20:09 +08:00
Pascal Seitz
254314a4a3 improve bench 2025-12-07 10:05:30 +08:00
PSeitz
b2f99c6217 add term->histogram benchmark (#2758)
* add term->histogram benchmark

* add more term aggs

---------

Co-authored-by: Pascal Seitz <pascal.seitz@datadoghq.com>
2025-12-04 02:29:37 +01:00
PSeitz
76de5bab6f fix unsafe warnings (#2757) 2025-12-03 20:15:21 +08:00
rustmailer
b7eb31162b docs: add usage example to README (#2743) 2025-12-02 21:56:57 +01:00
73 changed files with 2236 additions and 4603 deletions

View File

@@ -56,7 +56,6 @@ itertools = "0.14.0"
measure_time = "0.9.0"
arc-swap = "1.5.0"
bon = "3.3.1"
i_triangle = "0.38.0"
columnar = { version = "0.6", path = "./columnar", package = "tantivy-columnar" }
sstable = { version = "0.6", path = "./sstable", package = "tantivy-sstable", optional = true }
@@ -71,7 +70,6 @@ futures-util = { version = "0.3.28", optional = true }
futures-channel = { version = "0.3.28", optional = true }
fnv = "1.0.7"
typetag = "0.2.21"
geo-types = "0.7.17"
[target.'cfg(windows)'.dependencies]
winapi = "0.3.9"

View File

@@ -123,6 +123,7 @@ You can also find other bindings on [GitHub](https://github.com/search?q=tantivy
- [seshat](https://github.com/matrix-org/seshat/): A matrix message database/indexer
- [tantiny](https://github.com/baygeldin/tantiny): Tiny full-text search for Ruby
- [lnx](https://github.com/lnx-search/lnx): adaptable, typo tolerant search engine with a REST API
- [Bichon](https://github.com/rustmailer/bichon): A lightweight, high-performance Rust email archiver with WebUI
- and [more](https://github.com/search?q=tantivy)!
### On average, how much faster is Tantivy compared to Lucene?

View File

@@ -1,5 +1,6 @@
use binggan::plugins::PeakMemAllocPlugin;
use binggan::{black_box, InputGroup, PeakMemAlloc, INSTRUMENTED_SYSTEM};
use rand::distributions::WeightedIndex;
use rand::prelude::SliceRandom;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
@@ -53,27 +54,33 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, stats_f64);
register!(group, extendedstats_f64);
register!(group, percentiles_f64);
register!(group, terms_few);
register!(group, terms_many);
register!(group, terms_7);
register!(group, terms_all_unique);
register!(group, terms_150_000);
register!(group, terms_many_top_1000);
register!(group, terms_many_order_by_term);
register!(group, terms_many_with_top_hits);
register!(group, terms_all_unique_with_avg_sub_agg);
register!(group, terms_many_with_avg_sub_agg);
register!(group, terms_few_with_avg_sub_agg);
register!(group, terms_status_with_avg_sub_agg);
register!(group, terms_status_with_histogram);
register!(group, terms_zipf_1000);
register!(group, terms_zipf_1000_with_histogram);
register!(group, terms_zipf_1000_with_avg_sub_agg);
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
register!(group, cardinality_agg);
register!(group, terms_few_with_cardinality_agg);
register!(group, terms_status_with_cardinality_agg);
register!(group, range_agg);
register!(group, range_agg_with_avg_sub_agg);
register!(group, range_agg_with_term_agg_few);
register!(group, range_agg_with_term_agg_status);
register!(group, range_agg_with_term_agg_many);
register!(group, histogram);
register!(group, histogram_hard_bounds);
register!(group, histogram_with_avg_sub_agg);
register!(group, histogram_with_term_agg_few);
register!(group, histogram_with_term_agg_status);
register!(group, avg_and_range_with_avg_sub_agg);
// Filter aggregation benchmarks
@@ -132,12 +139,12 @@ fn extendedstats_f64(index: &Index) {
}
fn percentiles_f64(index: &Index) {
let agg_req = json!({
"mypercentiles": {
"percentiles": {
"field": "score_f64",
"percents": [ 95, 99, 99.9 ]
"mypercentiles": {
"percentiles": {
"field": "score_f64",
"percents": [ 95, 99, 99.9 ]
}
}
}
});
execute_agg(index, agg_req);
}
@@ -152,10 +159,10 @@ fn cardinality_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_few_with_cardinality_agg(index: &Index) {
fn terms_status_with_cardinality_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms" },
"terms": { "field": "text_few_terms_status" },
"aggs": {
"cardinality": {
"cardinality": {
@@ -168,13 +175,20 @@ fn terms_few_with_cardinality_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_few(index: &Index) {
fn terms_7(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms" } },
"my_texts": { "terms": { "field": "text_few_terms_status" } },
});
execute_agg(index, agg_req);
}
fn terms_many(index: &Index) {
fn terms_all_unique(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_all_unique_terms" } },
});
execute_agg(index, agg_req);
}
fn terms_150_000(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_many_terms" } },
});
@@ -222,11 +236,10 @@ fn terms_many_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_few_with_avg_sub_agg(index: &Index) {
fn terms_all_unique_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms" },
"terms": { "field": "text_all_unique_terms" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
@@ -234,6 +247,60 @@ fn terms_few_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_status_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
}
}
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf" },
"aggs": {
"histo": {"histogram": { "field": "score_f64", "interval": 10 }}
}
}
});
execute_agg(index, agg_req);
}
fn terms_status_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf" },
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_1000_terms_zipf" } },
});
execute_agg(index, agg_req);
}
fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
@@ -290,7 +357,7 @@ fn range_agg_with_avg_sub_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn range_agg_with_term_agg_few(index: &Index) {
fn range_agg_with_term_agg_status(index: &Index) {
let agg_req = json!({
"rangef64": {
"range": {
@@ -305,7 +372,7 @@ fn range_agg_with_term_agg_few(index: &Index) {
]
},
"aggs": {
"my_texts": { "terms": { "field": "text_few_terms" } },
"my_texts": { "terms": { "field": "text_few_terms_status" } },
}
},
});
@@ -361,12 +428,12 @@ fn histogram_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn histogram_with_term_agg_few(index: &Index) {
fn histogram_with_term_agg_status(index: &Index) {
let agg_req = json!({
"rangef64": {
"histogram": { "field": "score_f64", "interval": 10 },
"aggs": {
"my_texts": { "terms": { "field": "text_few_terms" } }
"my_texts": { "terms": { "field": "text_few_terms_status" } }
}
}
});
@@ -411,6 +478,13 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
}
fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
// Flag to use existing index
let reuse_index = std::env::var("REUSE_AGG_BENCH_INDEX").is_ok();
if reuse_index && std::path::Path::new("agg_bench").exists() {
return Index::open_in_dir("agg_bench");
}
// create dir
std::fs::create_dir_all("agg_bench")?;
let mut schema_builder = Schema::builder();
let text_fieldtype = tantivy::schema::TextOptions::default()
.set_indexing_options(
@@ -419,20 +493,47 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
.set_stored();
let text_field = schema_builder.add_text_field("text", text_fieldtype);
let json_field = schema_builder.add_json_field("json", FAST);
let text_field_all_unique_terms =
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST);
let text_field_few_terms_status =
schema_builder.add_text_field("text_few_terms_status", STRING | FAST);
let text_field_1000_terms_zipf =
schema_builder.add_text_field("text_1000_terms_zipf", STRING | FAST);
let score_fieldtype = tantivy::schema::NumericOptions::default().set_fast();
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
let index = Index::create_from_tempdir(schema_builder.build())?;
let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"];
// use tmp dir
let index = if reuse_index {
Index::create_in_dir("agg_bench", schema_builder.build())?
} else {
Index::create_from_tempdir(schema_builder.build())?
};
// Approximate log proportions
let status_field_data = [
("INFO", 8000),
("ERROR", 300),
("WARN", 1200),
("DEBUG", 500),
("OK", 500),
("CRITICAL", 20),
("EMERGENCY", 1),
];
let log_level_distribution =
WeightedIndex::new(status_field_data.iter().map(|item| item.1)).unwrap();
let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap();
let many_terms_data = (0..150_000)
.map(|num| format!("author{num}"))
.collect::<Vec<_>>();
// Prepare 1000 unique terms sampled using a Zipf distribution.
// Exponent ~1.1 approximates top-20 terms covering around ~20%.
let terms_1000: Vec<String> = (1..=1000).map(|i| format!("term_{i}")).collect();
let zipf_1000 = rand_distr::Zipf::new(1000, 1.1f64).unwrap();
{
let mut rng = StdRng::from_seed([1u8; 32]);
let mut index_writer = index.writer_with_num_threads(1, 200_000_000)?;
@@ -442,15 +543,25 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!())?;
}
if cardinality == Cardinality::Multivalued {
let log_level_sample_a = status_field_data[log_level_distribution.sample(&mut rng)].0;
let log_level_sample_b = status_field_data[log_level_distribution.sample(&mut rng)].0;
let idx_a = zipf_1000.sample(&mut rng) as usize - 1;
let idx_b = zipf_1000.sample(&mut rng) as usize - 1;
let term_1000_a = &terms_1000[idx_a];
let term_1000_b = &terms_1000[idx_b];
index_writer.add_document(doc!(
json_field => json!({"mixed_type": 10.0}),
json_field => json!({"mixed_type": 10.0}),
text_field => "cool",
text_field => "cool",
text_field_all_unique_terms => "cool",
text_field_all_unique_terms => "coolo",
text_field_many_terms => "cool",
text_field_many_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms_status => log_level_sample_a,
text_field_few_terms_status => log_level_sample_b,
text_field_1000_terms_zipf => term_1000_a.as_str(),
text_field_1000_terms_zipf => term_1000_b.as_str(),
score_field => 1u64,
score_field => 1u64,
score_field_f64 => lg_norm.sample(&mut rng),
@@ -475,8 +586,10 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!(
text_field => "cool",
json_field => json,
text_field_all_unique_terms => format!("unique_term_{}", rng.gen::<u64>()),
text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms_status => status_field_data[log_level_distribution.sample(&mut rng)].0,
text_field_1000_terms_zipf => terms_1000[zipf_1000.sample(&mut rng) as usize - 1].as_str(),
score_field => val as u64,
score_field_f64 => lg_norm.sample(&mut rng),
score_field_i64 => val as i64,
@@ -528,7 +641,7 @@ fn filter_agg_all_query_with_sub_aggs(index: &Index) {
"avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } },
"terms_text": {
"terms": { "field": "text_few_terms" }
"terms": { "field": "text_few_terms_status" }
}
}
}
@@ -544,7 +657,7 @@ fn filter_agg_term_query_with_sub_aggs(index: &Index) {
"avg_score": { "avg": { "field": "score" } },
"stats_score": { "stats": { "field": "score_f64" } },
"terms_text": {
"terms": { "field": "text_few_terms" }
"terms": { "field": "text_few_terms_status" }
}
}
}

View File

@@ -19,7 +19,7 @@ fn u32_to_i32(val: u32) -> i32 {
#[inline]
unsafe fn u32_to_i32_avx2(vals_u32x8s: DataType) -> DataType {
const HIGHEST_BIT_MASK: DataType = from_u32x8([HIGHEST_BIT; NUM_LANES]);
op_xor(vals_u32x8s, HIGHEST_BIT_MASK)
unsafe { op_xor(vals_u32x8s, HIGHEST_BIT_MASK) }
}
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
@@ -66,17 +66,19 @@ unsafe fn filter_vec_avx2_aux(
]);
const SHIFT: __m256i = from_u32x8([NUM_LANES as u32; NUM_LANES]);
for _ in 0..num_words {
let word = load_unaligned(input);
let word = u32_to_i32_avx2(word);
let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
let added_len = keeper_bitset.count_ones();
let filtered_doc_ids = compact(ids, keeper_bitset);
store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
output_tail = output_tail.offset(added_len as isize);
ids = op_add(ids, SHIFT);
input = input.offset(1);
unsafe {
let word = load_unaligned(input);
let word = u32_to_i32_avx2(word);
let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
let added_len = keeper_bitset.count_ones();
let filtered_doc_ids = compact(ids, keeper_bitset);
store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
output_tail = output_tail.offset(added_len as isize);
ids = op_add(ids, SHIFT);
input = input.offset(1);
}
}
output_tail.offset_from(output) as usize
unsafe { output_tail.offset_from(output) as usize }
}
#[inline]
@@ -92,8 +94,7 @@ unsafe fn compute_filter_bitset(val: __m256i, range: std::ops::RangeInclusive<__
let too_low = op_greater(*range.start(), val);
let too_high = op_greater(val, *range.end());
let inside = op_or(too_low, too_high);
255 - std::arch::x86_64::_mm256_movemask_ps(std::mem::transmute::<DataType, __m256>(inside))
as u8
255 - std::arch::x86_64::_mm256_movemask_ps(_mm256_castsi256_ps(inside)) as u8
}
union U8x32 {

View File

@@ -29,12 +29,20 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
}
}
#[inline]
pub fn fetch_block_with_missing(&mut self, docs: &[u32], accessor: &Column<T>, missing: T) {
pub fn fetch_block_with_missing(
&mut self,
docs: &[u32],
accessor: &Column<T>,
missing: Option<T>,
) {
self.fetch_block(docs, accessor);
// no missing values
if accessor.index.get_cardinality().is_full() {
return;
}
let Some(missing) = missing else {
return;
};
// We can compare docid_cache length with docs to find missing docs
// For multi value columns we can't rely on the length and always need to scan

View File

@@ -181,6 +181,14 @@ pub struct BitSet {
len: u64,
max_value: u32,
}
impl std::fmt::Debug for BitSet {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BitSet")
.field("len", &self.len)
.field("max_value", &self.max_value)
.finish()
}
}
fn num_buckets(max_val: u32) -> u32 {
max_val.div_ceil(64u32)

View File

@@ -1,66 +0,0 @@
use geo_types::Point;
use tantivy::collector::TopDocs;
use tantivy::query::SpatialQuery;
use tantivy::schema::{Schema, Value, SPATIAL, STORED, TEXT};
use tantivy::spatial::point::GeoPoint;
use tantivy::{Index, IndexWriter, TantivyDocument};
fn main() -> tantivy::Result<()> {
let mut schema_builder = Schema::builder();
schema_builder.add_json_field("properties", STORED | TEXT);
schema_builder.add_spatial_field("geometry", STORED | SPATIAL);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
let mut index_writer: IndexWriter = index.writer(50_000_000)?;
let doc = TantivyDocument::parse_json(
&schema,
r#"{
"type":"Feature",
"geometry":{
"type":"Polygon",
"coordinates":[[[-99.483911,45.577697],[-99.483869,45.571457],[-99.481739,45.571461],[-99.474881,45.571584],[-99.473167,45.571615],[-99.463394,45.57168],[-99.463391,45.57883],[-99.463368,45.586076],[-99.48177,45.585926],[-99.48384,45.585953],[-99.483885,45.57873],[-99.483911,45.577697]]]
},
"properties":{
"admin_level":"8",
"border_type":"city",
"boundary":"administrative",
"gnis:feature_id":"1267426",
"name":"Hosmer",
"place":"city",
"source":"TIGER/Line® 2008 Place Shapefiles (http://www.census.gov/geo/www/tiger/)",
"wikidata":"Q2442118",
"wikipedia":"en:Hosmer, South Dakota"
}
}"#,
)?;
index_writer.add_document(doc)?;
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
let field = schema.get_field("geometry").unwrap();
let query = SpatialQuery::new(
field,
[
GeoPoint {
lon: -99.49,
lat: 45.56,
},
GeoPoint {
lon: -99.45,
lat: 45.59,
},
],
tantivy::query::SpatialQueryType::Intersects,
);
let hits = searcher.search(&query, &TopDocs::with_limit(10).order_by_score())?;
for (_score, doc_address) in &hits {
let retrieved_doc: TantivyDocument = searcher.doc(*doc_address)?;
if let Some(field_value) = retrieved_doc.get_first(field) {
if let Some(geometry_box) = field_value.as_value().into_geometry() {
println!("Retrieved geometry: {:?}", geometry_box);
}
}
}
assert_eq!(hits.len(), 1);
Ok(())
}

View File

@@ -1,4 +1,4 @@
use columnar::{Column, ColumnType, StrColumn};
use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn};
use common::BitSet;
use rustc_hash::FxHashSet;
use serde::Serialize;
@@ -10,16 +10,16 @@ use crate::aggregation::accessor_helpers::{
};
use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
use crate::aggregation::bucket::{
FilterAggReqData, HistogramAggReqData, HistogramBounds, IncludeExcludeParam,
MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector, SegmentHistogramCollector,
SegmentRangeCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
build_segment_range_collector, FilterAggReqData, HistogramAggReqData, HistogramBounds,
IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData, SegmentFilterCollector,
SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
TermsAggregationInternal,
};
use crate::aggregation::metric::{
AverageAggregation, CardinalityAggReqData, CardinalityAggregationReq, CountAggregation,
ExtendedStatsAggregation, MaxAggregation, MetricAggReqData, MinAggregation,
SegmentCardinalityCollector, SegmentExtendedStatsCollector, SegmentPercentilesCollector,
SegmentStatsCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation,
MetricAggReqData, MinAggregation, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
TopHitsSegmentCollector,
};
use crate::aggregation::segment_agg_result::{
@@ -35,6 +35,7 @@ pub struct AggregationsSegmentCtx {
/// Request data for each aggregation type.
pub per_request: PerRequestAggSegCtx,
pub context: AggContextParams,
pub column_block_accessor: ColumnBlockAccessor<u64>,
}
impl AggregationsSegmentCtx {
@@ -107,21 +108,14 @@ impl AggregationsSegmentCtx {
.as_deref()
.expect("range_req_data slot is empty (taken)")
}
#[inline]
pub(crate) fn get_filter_req_data(&self, idx: usize) -> &FilterAggReqData {
self.per_request.filter_req_data[idx]
.as_deref()
.expect("filter_req_data slot is empty (taken)")
}
// ---------- mutable getters ----------
#[inline]
pub(crate) fn get_term_req_data_mut(&mut self, idx: usize) -> &mut TermsAggReqData {
self.per_request.term_req_data[idx]
.as_deref_mut()
.expect("term_req_data slot is empty (taken)")
pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
&mut self.per_request.stats_metric_req_data[idx]
}
#[inline]
pub(crate) fn get_cardinality_req_data_mut(
&mut self,
@@ -129,10 +123,7 @@ impl AggregationsSegmentCtx {
) -> &mut CardinalityAggReqData {
&mut self.per_request.cardinality_req_data[idx]
}
#[inline]
pub(crate) fn get_metric_req_data_mut(&mut self, idx: usize) -> &mut MetricAggReqData {
&mut self.per_request.stats_metric_req_data[idx]
}
#[inline]
pub(crate) fn get_histogram_req_data_mut(&mut self, idx: usize) -> &mut HistogramAggReqData {
self.per_request.histogram_req_data[idx]
@@ -142,21 +133,6 @@ impl AggregationsSegmentCtx {
// ---------- take / put (terms, histogram, range) ----------
/// Move out the boxed Terms request at `idx`, leaving `None`.
#[inline]
pub(crate) fn take_term_req_data(&mut self, idx: usize) -> Box<TermsAggReqData> {
self.per_request.term_req_data[idx]
.take()
.expect("term_req_data slot is empty (taken)")
}
/// Put back a Terms request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_term_req_data(&mut self, idx: usize, value: Box<TermsAggReqData>) {
debug_assert!(self.per_request.term_req_data[idx].is_none());
self.per_request.term_req_data[idx] = Some(value);
}
/// Move out the boxed Histogram request at `idx`, leaving `None`.
#[inline]
pub(crate) fn take_histogram_req_data(&mut self, idx: usize) -> Box<HistogramAggReqData> {
@@ -320,6 +296,7 @@ impl PerRequestAggSegCtx {
/// Convert the aggregation tree into a serializable struct representation.
/// Each node contains: { name, kind, children }.
#[allow(dead_code)]
pub fn get_view_tree(&self) -> Vec<AggTreeViewNode> {
fn node_to_view(node: &AggRefNode, pr: &PerRequestAggSegCtx) -> AggTreeViewNode {
let mut children: Vec<AggTreeViewNode> =
@@ -345,12 +322,19 @@ impl PerRequestAggSegCtx {
pub(crate) fn build_segment_agg_collectors_root(
req: &mut AggregationsSegmentCtx,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
build_segment_agg_collectors(req, &req.per_request.agg_tree.clone())
build_segment_agg_collectors_generic(req, &req.per_request.agg_tree.clone())
}
pub(crate) fn build_segment_agg_collectors(
req: &mut AggregationsSegmentCtx,
nodes: &[AggRefNode],
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
build_segment_agg_collectors_generic(req, nodes)
}
fn build_segment_agg_collectors_generic(
req: &mut AggregationsSegmentCtx,
nodes: &[AggRefNode],
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let mut collectors = Vec::new();
for node in nodes.iter() {
@@ -388,6 +372,8 @@ pub(crate) fn build_segment_agg_collector(
Ok(Box::new(SegmentCardinalityCollector::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
)))
}
AggKind::StatsKind(stats_type) => {
@@ -398,20 +384,21 @@ pub(crate) fn build_segment_agg_collector(
| StatsType::Count
| StatsType::Max
| StatsType::Min
| StatsType::Stats => Ok(Box::new(SegmentStatsCollector::from_req(
node.idx_in_req_data,
))),
StatsType::ExtendedStats(sigma) => {
Ok(Box::new(SegmentExtendedStatsCollector::from_req(
req_data.field_type,
sigma,
node.idx_in_req_data,
req_data.missing,
)))
}
StatsType::Percentiles => Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(node.idx_in_req_data)?,
| StatsType::Stats => build_segment_stats_collector(req_data),
StatsType::ExtendedStats(sigma) => Ok(Box::new(
SegmentExtendedStatsCollector::from_req(req_data, sigma),
)),
StatsType::Percentiles => {
let req_data = req.get_metric_req_data_mut(node.idx_in_req_data);
Ok(Box::new(
SegmentPercentilesCollector::from_req_and_validate(
req_data.field_type,
req_data.missing_u64,
req_data.accessor.clone(),
node.idx_in_req_data,
),
))
}
}
}
AggKind::TopHits => {
@@ -428,9 +415,7 @@ pub(crate) fn build_segment_agg_collector(
AggKind::DateHistogram => Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
req, node,
)?)),
AggKind::Range => Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
req, node,
)?)),
AggKind::Range => Ok(build_segment_range_collector(req, node)?),
AggKind::Filter => Ok(Box::new(SegmentFilterCollector::from_req_and_validate(
req, node,
)?)),
@@ -493,6 +478,7 @@ pub(crate) fn build_aggregations_data_from_req(
let mut data = AggregationsSegmentCtx {
per_request: Default::default(),
context,
column_block_accessor: ColumnBlockAccessor::default(),
};
for (name, agg) in aggs.iter() {
@@ -521,9 +507,9 @@ fn build_nodes(
let idx_in_req_data = data.push_range_req_data(RangeAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
req: range_req.clone(),
is_top_level,
});
let children = build_children(&req.sub_aggregation, reader, segment_ordinal, data)?;
Ok(vec![AggRefNode {
@@ -541,9 +527,7 @@ fn build_nodes(
let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
sub_aggregation_blueprint: None,
req: histo_req.clone(),
is_date_histogram: false,
bounds: HistogramBounds {
@@ -568,9 +552,7 @@ fn build_nodes(
let idx_in_req_data = data.push_histogram_req_data(HistogramAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
sub_aggregation_blueprint: None,
req: histo_req,
is_date_histogram: true,
bounds: HistogramBounds {
@@ -650,7 +632,6 @@ fn build_nodes(
let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
collecting_for,
missing: *missing,
@@ -678,7 +659,6 @@ fn build_nodes(
let idx_in_req_data = data.push_metric_req_data(MetricAggReqData {
accessor,
field_type,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
collecting_for: StatsType::Percentiles,
missing: percentiles_req.missing,
@@ -895,7 +875,7 @@ fn build_terms_or_cardinality_nodes(
});
}
// Add one node per accessor to mirror previous behavior and allow per-type missing handling.
// Add one node per accessor
for (accessor, column_type) in column_and_types {
let missing_value_for_accessor = if use_special_missing_agg {
None
@@ -926,11 +906,8 @@ fn build_terms_or_cardinality_nodes(
column_type,
str_dict_column: str_dict_column.clone(),
missing_value_for_accessor,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
req: TermsAggregationInternal::from_req(req),
// Will be filled later when building collectors
sub_aggregation_blueprint: None,
sug_aggregations: sub_aggs.clone(),
allowed_term_ids,
is_top_level,
@@ -943,7 +920,6 @@ fn build_terms_or_cardinality_nodes(
column_type,
str_dict_column: str_dict_column.clone(),
missing_value_for_accessor,
column_block_accessor: Default::default(),
name: agg_name.to_string(),
req: req.clone(),
});

View File

@@ -2,15 +2,441 @@ use serde_json::Value;
use crate::aggregation::agg_req::{Aggregation, Aggregations};
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
use crate::aggregation::collector::AggregationCollector;
use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
use crate::aggregation::DistributedAggregationCollector;
use crate::docset::COLLECT_BLOCK_BUFFER_LEN;
use crate::query::{AllQuery, TermQuery};
use crate::schema::{IndexRecordOption, Schema, FAST};
use crate::{Index, IndexWriter, Term};
// The following tests ensure that each bucket aggregation type correctly functions as a
// sub-aggregation of another bucket aggregation in two scenarios:
// 1) The parent has more buckets than the child sub-aggregation
// 2) The child sub-aggregation has more buckets than the parent
//
// These scenarios exercise the bucket id mapping and sub-aggregation routing logic.
#[test]
fn test_terms_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
// Case A: parent has more buckets than child
// Parent: range with 4 buckets
// Child: terms on text -> 2 buckets
let agg_parent_more: Aggregations = serde_json::from_value(json!({
"parent_range": {
"range": {
"field": "score",
"ranges": [
{"to": 3.0},
{"from": 3.0, "to": 7.0},
{"from": 7.0, "to": 20.0},
{"from": 20.0}
]
},
"aggs": {
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
// Exact expected structure and counts
assert_eq!(
res["parent_range"]["buckets"],
json!([
{
"key": "*-3",
"doc_count": 1,
"to": 3.0,
"child_terms": {
"buckets": [
{"doc_count": 1, "key": "cool"}
],
"sum_other_doc_count": 0
}
},
{
"key": "3-7",
"doc_count": 3,
"from": 3.0,
"to": 7.0,
"child_terms": {
"buckets": [
{"doc_count": 2, "key": "cool"},
{"doc_count": 1, "key": "nohit"}
],
"sum_other_doc_count": 0
}
},
{
"key": "7-20",
"doc_count": 3,
"from": 7.0,
"to": 20.0,
"child_terms": {
"buckets": [
{"doc_count": 3, "key": "cool"}
],
"sum_other_doc_count": 0
}
},
{
"key": "20-*",
"doc_count": 2,
"from": 20.0,
"child_terms": {
"buckets": [
{"doc_count": 1, "key": "cool"},
{"doc_count": 1, "key": "nohit"}
],
"sum_other_doc_count": 0
}
}
])
);
// Case B: child has more buckets than parent
// Parent: histogram on score with large interval -> 1 bucket
// Child: terms on text -> 2 buckets (cool/nohit)
let agg_child_more: Aggregations = serde_json::from_value(json!({
"parent_hist": {
"histogram": {"field": "score", "interval": 100.0},
"aggs": {
"child_terms": {"terms": {"field": "text", "order": {"_key": "asc"}}}
}
}
}))
.unwrap();
let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
assert_eq!(
res["parent_hist"],
json!({
"buckets": [
{
"key": 0.0,
"doc_count": 9,
"child_terms": {
"buckets": [
{"doc_count": 7, "key": "cool"},
{"doc_count": 2, "key": "nohit"}
],
"sum_other_doc_count": 0
}
}
]
})
);
Ok(())
}
// Exercises `range` used as a sub-aggregation in both cardinality directions:
// Case A nests a coarse child range under a fine-grained parent range, and
// Case B nests a child range under a `terms` parent. Exact bucket keys,
// bounds, and doc counts are pinned in both cases.
#[test]
fn test_range_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
    let index = get_test_index_2_segments(false)?;

    // Case A: parent has more buckets than child
    // Parent: range with 5 buckets
    // Child: coarse range with 3 buckets
    let agg_parent_more: Aggregations = serde_json::from_value(json!({
        "parent_range": {
            "range": {
                "field": "score",
                "ranges": [
                    {"to": 3.0},
                    {"from": 3.0, "to": 7.0},
                    {"from": 7.0, "to": 11.0},
                    {"from": 11.0, "to": 20.0},
                    {"from": 20.0}
                ]
            },
            "aggs": {
                "child_range": {
                    "range": {
                        "field": "score",
                        "ranges": [
                            {"to": 3.0},
                            {"from": 3.0, "to": 20.0}
                        ]
                    }
                }
            }
        }
    }))
    .unwrap();
    let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
    // Note: the child request lists only 2 ranges, but the expected output has
    // 3 buckets — the open-ended "20-*" bucket is materialized implicitly, and
    // empty child buckets (doc_count 0) are emitted for every parent bucket.
    assert_eq!(
        res["parent_range"]["buckets"],
        json!([
            {"key": "*-3", "doc_count": 1, "to": 3.0,
                "child_range": {"buckets": [
                    {"key": "*-3", "doc_count": 1, "to": 3.0},
                    {"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
                    {"key": "20-*", "doc_count": 0, "from": 20.0}
                ]}
            },
            {"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
                "child_range": {"buckets": [
                    {"key": "*-3", "doc_count": 0, "to": 3.0},
                    {"key": "3-20", "doc_count": 3, "from": 3.0, "to": 20.0},
                    {"key": "20-*", "doc_count": 0, "from": 20.0}
                ]}
            },
            {"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
                "child_range": {"buckets": [
                    {"key": "*-3", "doc_count": 0, "to": 3.0},
                    {"key": "3-20", "doc_count": 1, "from": 3.0, "to": 20.0},
                    {"key": "20-*", "doc_count": 0, "from": 20.0}
                ]}
            },
            {"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
                "child_range": {"buckets": [
                    {"key": "*-3", "doc_count": 0, "to": 3.0},
                    {"key": "3-20", "doc_count": 2, "from": 3.0, "to": 20.0},
                    {"key": "20-*", "doc_count": 0, "from": 20.0}
                ]}
            },
            {"key": "20-*", "doc_count": 2, "from": 20.0,
                "child_range": {"buckets": [
                    {"key": "*-3", "doc_count": 0, "to": 3.0},
                    {"key": "3-20", "doc_count": 0, "from": 3.0, "to": 20.0},
                    {"key": "20-*", "doc_count": 2, "from": 20.0}
                ]}
            }
        ])
    );

    // Case B: child has more buckets than parent
    // Parent: terms on text (2 buckets)
    // Child: range with 4 buckets
    let agg_child_more: Aggregations = serde_json::from_value(json!({
        "parent_terms": {
            "terms": {"field": "text"},
            "aggs": {
                "child_range": {
                    "range": {
                        "field": "score",
                        "ranges": [
                            {"to": 3.0},
                            {"from": 3.0, "to": 7.0},
                            {"from": 7.0, "to": 20.0}
                        ]
                    }
                }
            }
        }
    }))
    .unwrap();
    let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
    // Each term bucket carries the full set of child range buckets, again
    // including the implicit "20-*" bucket and zero-count buckets.
    assert_eq!(
        res["parent_terms"],
        json!({
            "buckets": [
                {
                    "key": "cool",
                    "doc_count": 7,
                    "child_range": {
                        "buckets": [
                            {"key": "*-3", "doc_count": 1, "to": 3.0},
                            {"key": "3-7", "doc_count": 2, "from": 3.0, "to": 7.0},
                            {"key": "7-20", "doc_count": 3, "from": 7.0, "to": 20.0},
                            {"key": "20-*", "doc_count": 1, "from": 20.0}
                        ]
                    }
                },
                {
                    "key": "nohit",
                    "doc_count": 2,
                    "child_range": {
                        "buckets": [
                            {"key": "*-3", "doc_count": 0, "to": 3.0},
                            {"key": "3-7", "doc_count": 1, "from": 3.0, "to": 7.0},
                            {"key": "7-20", "doc_count": 0, "from": 7.0, "to": 20.0},
                            {"key": "20-*", "doc_count": 1, "from": 20.0}
                        ]
                    }
                }
            ],
            "doc_count_error_upper_bound": 0,
            "sum_other_doc_count": 0
        })
    );
    Ok(())
}
// Exercises `histogram` used as a sub-aggregation in both cardinality
// directions: Case A nests a single-bucket child histogram under a multi-bucket
// parent range; Case B nests a multi-bucket child histogram (with empty gap
// buckets) under a `terms` parent.
#[test]
fn test_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
    let index = get_test_index_2_segments(false)?;

    // Case A: parent has more buckets than child
    // Parent: range with several ranges
    // Child: histogram with large interval (single bucket per parent)
    let agg_parent_more: Aggregations = serde_json::from_value(json!({
        "parent_range": {
            "range": {
                "field": "score",
                "ranges": [
                    {"to": 3.0},
                    {"from": 3.0, "to": 7.0},
                    {"from": 7.0, "to": 11.0},
                    {"from": 11.0, "to": 20.0},
                    {"from": 20.0}
                ]
            },
            "aggs": {
                "child_hist": {"histogram": {"field": "score", "interval": 100.0}}
            }
        }
    }))
    .unwrap();
    let res = crate::aggregation::tests::exec_request(agg_parent_more, &index)?;
    // With interval 100.0 every score falls into the key-0.0 bucket, so each
    // parent bucket's single child bucket mirrors the parent doc_count.
    assert_eq!(
        res["parent_range"]["buckets"],
        json!([
            {"key": "*-3", "doc_count": 1, "to": 3.0,
                "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
            },
            {"key": "3-7", "doc_count": 3, "from": 3.0, "to": 7.0,
                "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 3} ]}
            },
            {"key": "7-11", "doc_count": 1, "from": 7.0, "to": 11.0,
                "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 1} ]}
            },
            {"key": "11-20", "doc_count": 2, "from": 11.0, "to": 20.0,
                "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
            },
            {"key": "20-*", "doc_count": 2, "from": 20.0,
                "child_hist": {"buckets": [ {"key": 0.0, "doc_count": 2} ]}
            }
        ])
    );

    // Case B: child has more buckets than parent
    // Parent: terms on text -> 2 buckets
    // Child: histogram with small interval -> multiple buckets including empties
    let agg_child_more: Aggregations = serde_json::from_value(json!({
        "parent_terms": {
            "terms": {"field": "text"},
            "aggs": {
                "child_hist": {"histogram": {"field": "score", "interval": 10.0}}
            }
        }
    }))
    .unwrap();
    let res = crate::aggregation::tests::exec_request(agg_child_more, &index)?;
    // Gap buckets (keys 20.0 and 30.0) between the min and max observed values
    // are emitted with doc_count 0 for each term bucket.
    assert_eq!(
        res["parent_terms"],
        json!({
            "buckets": [
                {
                    "key": "cool",
                    "doc_count": 7,
                    "child_hist": {
                        "buckets": [
                            {"key": 0.0, "doc_count": 4},
                            {"key": 10.0, "doc_count": 2},
                            {"key": 20.0, "doc_count": 0},
                            {"key": 30.0, "doc_count": 0},
                            {"key": 40.0, "doc_count": 1}
                        ]
                    }
                },
                {
                    "key": "nohit",
                    "doc_count": 2,
                    "child_hist": {
                        "buckets": [
                            {"key": 0.0, "doc_count": 1},
                            {"key": 10.0, "doc_count": 0},
                            {"key": 20.0, "doc_count": 0},
                            {"key": 30.0, "doc_count": 0},
                            {"key": 40.0, "doc_count": 1}
                        ]
                    }
                }
            ],
            "doc_count_error_upper_bound": 0,
            "sum_other_doc_count": 0
        })
    );
    Ok(())
}
// Exercises `date_histogram` as a sub-aggregation in both cardinality
// directions: a parent producing more buckets than its child (Case A) and a
// child producing more buckets than its parent (Case B).
#[test]
fn test_date_histogram_as_subagg_parent_more_vs_child_more() -> crate::Result<()> {
    let index = get_test_index_2_segments(false)?;

    // Case A: the parent produces more buckets than the child.
    // Parent: range aggregation with several buckets.
    // Child: date_histogram with a 30d fixed interval -> one bucket per parent.
    let parent_heavy_req: Aggregations = serde_json::from_value(json!({
        "parent_range": {
            "range": {
                "field": "score",
                "ranges": [
                    {"to": 3.0},
                    {"from": 3.0, "to": 7.0},
                    {"from": 7.0, "to": 11.0},
                    {"from": 11.0, "to": 20.0},
                    {"from": 20.0}
                ]
            },
            "aggs": {
                "child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "30d"}}
            }
        }
    }))
    .unwrap();
    let response = crate::aggregation::tests::exec_request(parent_heavy_req, &index)?;
    // Every parent bucket must carry exactly one child date bucket whose
    // doc_count mirrors the parent's.
    for parent_bucket in response["parent_range"]["buckets"].as_array().unwrap() {
        let expected_count = parent_bucket["doc_count"].as_u64().unwrap();
        let child = parent_bucket["child_date_hist"]["buckets"].as_array().unwrap();
        assert_eq!(child.len(), 1);
        assert_eq!(child[0]["doc_count"], expected_count);
    }

    // Case B: the child produces more buckets than the parent.
    // Parent: terms on text (2 buckets).
    // Child: date_histogram with a 1d fixed interval -> several buckets.
    let child_heavy_req: Aggregations = serde_json::from_value(json!({
        "parent_terms": {
            "terms": {"field": "text"},
            "aggs": {
                "child_date_hist": {"date_histogram": {"field": "date", "fixed_interval": "1d"}}
            }
        }
    }))
    .unwrap();
    let response = crate::aggregation::tests::exec_request(child_heavy_req, &index)?;
    let term_buckets = response["parent_terms"]["buckets"].as_array().unwrap();

    // "cool" term bucket: three consecutive days with counts 1, 4, 2.
    assert_eq!(term_buckets[0]["key"], "cool");
    let day_counts: Vec<_> = term_buckets[0]["child_date_hist"]["buckets"]
        .as_array()
        .unwrap()
        .iter()
        .map(|bucket| bucket["doc_count"].clone())
        .collect();
    assert_eq!(day_counts, vec![json!(1), json!(4), json!(2)]);

    // "nohit" term bucket: two days with one document each.
    assert_eq!(term_buckets[1]["key"], "nohit");
    let day_counts: Vec<_> = term_buckets[1]["child_date_hist"]["buckets"]
        .as_array()
        .unwrap()
        .iter()
        .map(|bucket| bucket["doc_count"].clone())
        .collect();
    assert_eq!(day_counts, vec![json!(1), json!(1)]);

    Ok(())
}
fn get_avg_req(field_name: &str) -> Aggregation {
serde_json::from_value(json!({
"avg": {
@@ -25,6 +451,10 @@ fn get_collector(agg_req: Aggregations) -> AggregationCollector {
}
// *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE ***
// Note: The flushing part of these tests is outdated, since the buffering change after converting
// the collection into one collector per request instead of per bucket.
//
// However, they are useful as they test complex aggregation requests.
fn test_aggregation_flushing(
merge_segments: bool,
use_distributed_collector: bool,
@@ -37,8 +467,9 @@ fn test_aggregation_flushing(
let reader = index.reader()?;
assert_eq!(DOC_BLOCK_SIZE, 64);
// In the tree we cache Documents of DOC_BLOCK_SIZE, before passing them down as one block.
assert_eq!(COLLECT_BLOCK_BUFFER_LEN, 64);
// In the tree we cache documents of COLLECT_BLOCK_BUFFER_LEN before passing them down as one
// block.
//
// Build a request so that on the first level we have one full cache, which is then flushed.
// The same cache should have some residue docs at the end, which are flushed (Range 0-70)

View File

@@ -6,10 +6,12 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
};
use crate::aggregation::segment_agg_result::{CollectorClone, SegmentAggregationCollector};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
use crate::docset::DocSet;
use crate::query::{AllQuery, EnableScoring, Query, QueryParser};
use crate::schema::Schema;
@@ -410,9 +412,9 @@ impl FilterAggReqData {
pub(crate) fn get_memory_consumption(&self) -> usize {
// Estimate: name + segment reader reference + bitset + buffer capacity
self.name.len()
+ std::mem::size_of::<SegmentReader>()
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
+ std::mem::size_of::<SegmentReader>()
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
}
}
@@ -489,12 +491,19 @@ impl Debug for DocumentQueryEvaluator {
}
}
/// Doc-count bookkeeping for one parent bucket of the filter aggregation.
#[derive(Debug, Clone, PartialEq, Copy)]
struct DocCount {
    /// Number of documents that matched the filter in this parent bucket.
    doc_count: u64,
    /// Bucket id under which sub-aggregations store this bucket's data.
    bucket_id: BucketId,
}
/// Segment collector for filter aggregation
pub struct SegmentFilterCollector {
/// Document count in this bucket
doc_count: u64,
/// Document counts per parent bucket
parent_buckets: Vec<DocCount>,
/// Sub-aggregation collectors
sub_aggregations: Option<Box<dyn SegmentAggregationCollector>>,
sub_aggregations: Option<CachedSubAggs<true>>,
bucket_id_provider: BucketIdProvider,
/// Accessor index for this filter aggregation (to access FilterAggReqData)
accessor_idx: usize,
}
@@ -511,11 +520,13 @@ impl SegmentFilterCollector {
} else {
None
};
let sub_agg_collector = sub_agg_collector.map(CachedSubAggs::new);
Ok(SegmentFilterCollector {
doc_count: 0,
parent_buckets: Vec::new(),
sub_aggregations: sub_agg_collector,
accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
})
}
}
@@ -523,35 +534,41 @@ impl SegmentFilterCollector {
impl Debug for SegmentFilterCollector {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentFilterCollector")
.field("doc_count", &self.doc_count)
.field("buckets", &self.parent_buckets)
.field("has_sub_aggs", &self.sub_aggregations.is_some())
.field("accessor_idx", &self.accessor_idx)
.finish()
}
}
impl CollectorClone for SegmentFilterCollector {
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
// For now, panic - this needs proper implementation with weight recreation
panic!("SegmentFilterCollector cloning not yet implemented - requires weight recreation")
}
}
impl SegmentAggregationCollector for SegmentFilterCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let mut sub_results = IntermediateAggregationResults::default();
let bucket_opt = self.parent_buckets.get(parent_bucket_id as usize);
if let Some(sub_aggs) = self.sub_aggregations {
sub_aggs.add_intermediate_aggregation_result(agg_data, &mut sub_results)?;
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_results,
// Here we create a new bucket ID for sub-aggregations if the bucket doesn't
// exist, so that sub-aggregations can still produce results (e.g., zero doc
// count)
bucket_opt
.map(|bucket| bucket.bucket_id)
.unwrap_or(self.bucket_id_provider.next_bucket_id()),
)?;
}
// Create the filter bucket result
let filter_bucket_result = IntermediateBucketResult::Filter {
doc_count: self.doc_count,
doc_count: bucket_opt.map(|b| b.doc_count).unwrap_or(0),
sub_aggregations: sub_results,
};
@@ -570,32 +587,17 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
Ok(())
}
fn collect(&mut self, doc: DocId, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
// Access the evaluator from FilterAggReqData
let req_data = agg_data.get_filter_req_data(self.accessor_idx);
// O(1) BitSet lookup to check if document matches filter
if req_data.evaluator.matches_document(doc) {
self.doc_count += 1;
// If we have sub-aggregations, collect on them for this filtered document
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.collect(doc, agg_data)?;
}
}
Ok(())
}
#[inline]
fn collect_block(
fn collect(
&mut self,
docs: &[DocId],
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if docs.is_empty() {
return Ok(());
}
let mut bucket = self.parent_buckets[parent_bucket_id as usize];
// Take the request data to avoid borrow checker issues with sub-aggregations
let mut req = agg_data.take_filter_req_data(self.accessor_idx);
@@ -604,18 +606,24 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
req.evaluator
.filter_batch(docs, &mut req.matching_docs_buffer);
self.doc_count += req.matching_docs_buffer.len() as u64;
bucket.doc_count += req.matching_docs_buffer.len() as u64;
// Batch process sub-aggregations if we have matches
if !req.matching_docs_buffer.is_empty() {
if let Some(sub_aggs) = &mut self.sub_aggregations {
// Use collect_block for better sub-aggregation performance
sub_aggs.collect_block(&req.matching_docs_buffer, agg_data)?;
for &doc_id in &req.matching_docs_buffer {
sub_aggs.push(bucket.bucket_id, doc_id);
}
}
}
// Put the request data back
agg_data.put_back_filter_req_data(self.accessor_idx, req);
if let Some(sub_aggs) = &mut self.sub_aggregations {
sub_aggs.check_flush_local(agg_data)?;
}
// put back bucket
self.parent_buckets[parent_bucket_id as usize] = bucket;
Ok(())
}
@@ -626,6 +634,21 @@ impl SegmentAggregationCollector for SegmentFilterCollector {
}
Ok(())
}
/// Grows `parent_buckets` so that `max_bucket` is a valid index.
///
/// Every newly created slot starts at doc_count 0 and receives a fresh id
/// from the provider, so sub-aggregations can address the bucket.
fn prepare_max_bucket(
    &mut self,
    max_bucket: BucketId,
    _agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
    while self.parent_buckets.len() <= max_bucket as usize {
        let bucket_id = self.bucket_id_provider.next_bucket_id();
        self.parent_buckets.push(DocCount {
            doc_count: 0,
            bucket_id,
        });
    }
    Ok(())
}
}
/// Intermediate result for filter aggregation
@@ -1519,9 +1542,9 @@ mod tests {
let searcher = reader.searcher();
let agg = json!({
"test": {
"filter": deserialized,
"aggs": { "count": { "value_count": { "field": "brand" } } }
"test": {
"filter": deserialized,
"aggs": { "count": { "value_count": { "field": "brand" } } }
}
});

View File

@@ -1,6 +1,6 @@
use std::cmp::Ordering;
use columnar::{Column, ColumnBlockAccessor, ColumnType};
use columnar::{Column, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use tantivy_bitpacker::minmax;
@@ -8,14 +8,14 @@ use tantivy_bitpacker::minmax;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::BucketEntry;
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateHistogramBucketEntry,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::*;
use crate::TantivyError;
@@ -26,13 +26,8 @@ pub struct HistogramAggReqData {
pub accessor: Column<u64>,
/// The field type of the fast field.
pub field_type: ColumnType,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The name of the aggregation.
pub name: String,
/// The sub aggregation blueprint, used to create sub aggregations for each bucket.
/// Will be filled during initialization of the collector.
pub sub_aggregation_blueprint: Option<Box<dyn SegmentAggregationCollector>>,
/// The histogram aggregation request.
pub req: HistogramAggregation,
/// True if this is a date_histogram aggregation.
@@ -257,18 +252,24 @@ impl HistogramBounds {
pub(crate) struct SegmentHistogramBucketEntry {
pub key: f64,
pub doc_count: u64,
pub bucket_id: BucketId,
}
impl SegmentHistogramBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
sub_aggregation: &mut Option<CachedSubAggs>,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateHistogramBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = sub_aggregation {
sub_aggregation
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?;
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
self.bucket_id,
)?;
}
Ok(IntermediateHistogramBucketEntry {
key: self.key,
@@ -278,27 +279,38 @@ impl SegmentHistogramBucketEntry {
}
}
/// Histogram buckets belonging to a single parent bucket.
#[derive(Clone, Debug, Default)]
struct HistogramBuckets {
    // Keyed by the histogram bucket position computed from the value,
    // interval, and offset.
    pub buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct SegmentHistogramCollector {
/// The buckets containing the aggregation data.
buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
sub_aggregations: FxHashMap<i64, Box<dyn SegmentAggregationCollector>>,
/// One Histogram bucket per parent bucket id.
parent_buckets: Vec<HistogramBuckets>,
sub_agg: Option<CachedSubAggs>,
accessor_idx: usize,
bucket_id_provider: BucketIdProvider,
}
impl SegmentAggregationCollector for SegmentHistogramCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data
.get_histogram_req_data(self.accessor_idx)
.name
.clone();
let bucket = self.into_intermediate_bucket_result(agg_data)?;
// TODO: avoid prepare_max_bucket here and handle empty buckets.
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let histogram = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
let bucket = self.add_intermediate_bucket_result(agg_data, histogram)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
@@ -307,44 +319,40 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
#[inline]
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mut req = agg_data.take_histogram_req_data(self.accessor_idx);
let req = agg_data.take_histogram_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption();
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;
let bounds = req.bounds;
let interval = req.req.interval;
let offset = req.offset;
let get_bucket_pos = |val| get_bucket_pos_f64(val, interval, offset) as i64;
req.column_block_accessor.fetch_block(docs, &req.accessor);
for (doc, val) in req
agg_data
.column_block_accessor
.fetch_block(docs, &req.accessor);
for (doc, val) in agg_data
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
{
let val = f64_from_fastfield_u64(val, &req.field_type);
let val = f64_from_fastfield_u64(val, req.field_type);
let bucket_pos = get_bucket_pos(val);
if bounds.contains(val) {
let bucket = self.buckets.entry(bucket_pos).or_insert_with(|| {
let bucket = buckets.entry(bucket_pos).or_insert_with(|| {
let key = get_bucket_key_from_pos(bucket_pos as f64, interval, offset);
SegmentHistogramBucketEntry { key, doc_count: 0 }
SegmentHistogramBucketEntry {
key,
doc_count: 0,
bucket_id: self.bucket_id_provider.next_bucket_id(),
}
});
bucket.doc_count += 1;
if let Some(sub_aggregation_blueprint) = req.sub_aggregation_blueprint.as_ref() {
self.sub_aggregations
.entry(bucket_pos)
.or_insert_with(|| sub_aggregation_blueprint.clone())
.collect(doc, agg_data)?;
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.push(bucket.bucket_id, doc);
}
}
}
@@ -358,14 +366,30 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
.add_memory_consumed(mem_delta as u64)?;
}
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for sub_aggregation in self.sub_aggregations.values_mut() {
if let Some(sub_aggregation) = &mut self.sub_agg {
sub_aggregation.flush(agg_data)?;
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
self.parent_buckets.push(HistogramBuckets {
buckets: FxHashMap::default(),
});
}
Ok(())
}
}
@@ -373,22 +397,19 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
impl SegmentHistogramCollector {
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let sub_aggs_mem = self.sub_aggregations.memory_consumption();
let buckets_mem = self.buckets.memory_consumption();
self_mem + sub_aggs_mem + buckets_mem
let buckets_mem = self.parent_buckets.len() * std::mem::size_of::<HistogramBuckets>();
self_mem + buckets_mem
}
/// Converts the collector result into a intermediate bucket result.
pub fn into_intermediate_bucket_result(
self,
fn add_intermediate_bucket_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
histogram: HistogramBuckets,
) -> crate::Result<IntermediateBucketResult> {
let mut buckets = Vec::with_capacity(self.buckets.len());
let mut buckets = Vec::with_capacity(histogram.buckets.len());
for (bucket_pos, bucket) in self.buckets {
let bucket_res = bucket.into_intermediate_bucket_entry(
self.sub_aggregations.get(&bucket_pos).cloned(),
agg_data,
);
for bucket in histogram.buckets.into_values() {
let bucket_res = bucket.into_intermediate_bucket_entry(&mut self.sub_agg, agg_data);
buckets.push(bucket_res?);
}
@@ -408,7 +429,7 @@ impl SegmentHistogramCollector {
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let blueprint = if !node.children.is_empty() {
let sub_agg = if !node.children.is_empty() {
Some(build_segment_agg_collectors(agg_data, &node.children)?)
} else {
None
@@ -423,13 +444,13 @@ impl SegmentHistogramCollector {
max: f64::MAX,
});
req_data.offset = req_data.req.offset.unwrap_or(0.0);
req_data.sub_aggregation_blueprint = blueprint;
let sub_agg = sub_agg.map(CachedSubAggs::new);
Ok(Self {
buckets: Default::default(),
sub_aggregations: Default::default(),
parent_buckets: Default::default(),
sub_agg,
accessor_idx: node.idx_in_req_data,
bucket_id_provider: BucketIdProvider::default(),
})
}
}

View File

@@ -1,18 +1,20 @@
use std::fmt::Debug;
use std::ops::Range;
use columnar::{Column, ColumnBlockAccessor, ColumnType};
use columnar::{Column, ColumnType};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::AggregationLimitsGuard;
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::*;
use crate::TantivyError;
@@ -23,12 +25,12 @@ pub struct RangeAggReqData {
pub accessor: Column<u64>,
/// The type of the fast field.
pub field_type: ColumnType,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The range aggregation request.
pub req: RangeAggregation,
/// The name of the aggregation.
pub name: String,
/// Whether this is a top-level aggregation.
pub is_top_level: bool,
}
impl RangeAggReqData {
@@ -151,19 +153,47 @@ pub(crate) struct SegmentRangeAndBucketEntry {
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Clone, Debug)]
pub struct SegmentRangeCollector {
pub struct SegmentRangeCollector<const LOWCARD: bool = false> {
/// The buckets containing the aggregation data.
buckets: Vec<SegmentRangeAndBucketEntry>,
/// One for each ParentBucketId
parent_buckets: Vec<Vec<SegmentRangeAndBucketEntry>>,
column_type: ColumnType,
pub(crate) accessor_idx: usize,
sub_agg: Option<CachedSubAggs<LOWCARD>>,
/// Here things get a bit weird. We need to assign unique bucket ids across all
/// parent buckets. So we keep track of the next available bucket id here.
/// This allows a kind of flattening of the bucket ids across all parent buckets.
/// E.g. in nested aggregations:
/// Term Agg -> Range aggregation -> Stats aggregation
/// E.g. the Term Agg creates 3 buckets ["INFO", "ERROR", "WARN"], each of these has a Range
/// aggregation with 4 buckets. The Range aggregation will create buckets with ids:
/// - INFO: 0,1,2,3
/// - ERROR: 4,5,6,7
/// - WARN: 8,9,10,11
///
/// This allows the Stats aggregation to have unique bucket ids to refer to.
bucket_id_provider: BucketIdProvider,
limits: AggregationLimitsGuard,
}
// Hand-written Debug: summarizes the collector (bucket count only) rather than
// dumping every bucket, and reports the sub-aggregation only by presence.
impl<const LOWCARD: bool> Debug for SegmentRangeCollector<LOWCARD> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SegmentRangeCollector")
            .field("parent_buckets_len", &self.parent_buckets.len())
            .field("column_type", &self.column_type)
            .field("accessor_idx", &self.accessor_idx)
            .field("has_sub_agg", &self.sub_agg.is_some())
            .finish()
    }
}
/// TODO: Bad naming, there's also SegmentRangeAndBucketEntry
#[derive(Clone)]
pub(crate) struct SegmentRangeBucketEntry {
pub key: Key,
pub doc_count: u64,
pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
// pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
pub bucket_id: BucketId,
/// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not
@@ -184,48 +214,50 @@ impl Debug for SegmentRangeBucketEntry {
impl SegmentRangeBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateRangeBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_aggregation) = self.sub_aggregation {
sub_aggregation
.add_intermediate_aggregation_result(agg_data, &mut sub_aggregation_res)?
} else {
Default::default()
};
let sub_aggregation = IntermediateAggregationResults::default();
Ok(IntermediateRangeBucketEntry {
key: self.key.into(),
doc_count: self.doc_count,
sub_aggregation: sub_aggregation_res,
sub_aggregation_res: sub_aggregation,
from: self.from,
to: self.to,
})
}
}
impl SegmentAggregationCollector for SegmentRangeCollector {
impl<const LOWCARD: bool> SegmentAggregationCollector for SegmentRangeCollector<LOWCARD> {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let field_type = self.column_type;
let name = agg_data
.get_range_req_data(self.accessor_idx)
.name
.to_string();
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = self
.buckets
let buckets = std::mem::take(&mut self.parent_buckets[parent_bucket_id as usize]);
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = buckets
.into_iter()
.map(move |range_bucket| {
Ok((
range_to_string(&range_bucket.range, &field_type)?,
range_bucket
.bucket
.into_intermediate_bucket_entry(agg_data)?,
))
.map(|range_bucket| {
let bucket_id = range_bucket.bucket.bucket_id;
let mut agg = range_bucket.bucket.into_intermediate_bucket_entry()?;
if let Some(sub_aggregation) = &mut self.sub_agg {
sub_aggregation
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut agg.sub_aggregation_res,
bucket_id,
)?;
}
Ok((range_to_string(&range_bucket.range, &field_type)?, agg))
})
.collect::<crate::Result<_>>()?;
@@ -242,73 +274,114 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
#[inline]
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
// Take request data to avoid borrow conflicts during sub-aggregation
let mut req = agg_data.take_range_req_data(self.accessor_idx);
let req = agg_data.take_range_req_data(self.accessor_idx);
req.column_block_accessor.fetch_block(docs, &req.accessor);
agg_data
.column_block_accessor
.fetch_block(docs, &req.accessor);
for (doc, val) in req
let buckets = &mut self.parent_buckets[parent_bucket_id as usize];
for (doc, val) in agg_data
.column_block_accessor
.iter_docid_vals(docs, &req.accessor)
{
let bucket_pos = self.get_bucket_pos(val);
let bucket = &mut self.buckets[bucket_pos];
let bucket_pos = get_bucket_pos(val, buckets);
let bucket = &mut buckets[bucket_pos];
bucket.bucket.doc_count += 1;
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.collect(doc, agg_data)?;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.push(bucket.bucket.bucket_id, doc);
}
}
agg_data.put_back_range_req_data(self.accessor_idx, req);
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
for bucket in self.buckets.iter_mut() {
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.flush(agg_data)?;
}
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.flush(agg_data)?;
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
while self.parent_buckets.len() <= max_bucket as usize {
let new_buckets = self.create_new_buckets(agg_data)?;
self.parent_buckets.push(new_buckets);
}
Ok(())
}
}
/// Builds a concrete `SegmentRangeCollector`, choosing the `LOWCARD` const
/// parameter of its `CachedSubAggs` based on the expected bucket cardinality:
/// top-level range aggregations with at most 64 ranges get the
/// low-cardinality variant.
pub(crate) fn build_segment_range_collector(
    agg_data: &mut AggregationsSegmentCtx,
    node: &AggRefNode,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
    let accessor_idx = node.idx_in_req_data;
    let req_data = agg_data.get_range_req_data(node.idx_in_req_data);
    let field_type = req_data.field_type;
    // TODO: A better metric instead of is_top_level would be the number of buckets expected.
    // E.g. If range agg is not top level, but the parent is a bucket agg with less than 10 buckets,
    // we are still in low cardinality territory.
    let is_low_card = req_data.is_top_level && req_data.req.ranges.len() <= 64;
    // Sub-aggregation collectors are only built when the request has children.
    let sub_agg = if !node.children.is_empty() {
        Some(build_segment_agg_collectors(agg_data, &node.children)?)
    } else {
        None
    };
    // The two arms are identical apart from the LOWCARD const parameter, which
    // changes the concrete collector type behind the boxed trait object.
    if is_low_card {
        Ok(Box::new(SegmentRangeCollector {
            sub_agg: sub_agg.map(CachedSubAggs::<true>::new),
            column_type: field_type,
            accessor_idx,
            parent_buckets: Vec::new(),
            bucket_id_provider: BucketIdProvider::default(),
            limits: agg_data.context.limits.clone(),
        }))
    } else {
        Ok(Box::new(SegmentRangeCollector {
            sub_agg: sub_agg.map(CachedSubAggs::<false>::new),
            column_type: field_type,
            accessor_idx,
            parent_buckets: Vec::new(),
            bucket_id_provider: BucketIdProvider::default(),
            limits: agg_data.context.limits.clone(),
        }))
    }
}
impl SegmentRangeCollector {
pub(crate) fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let accessor_idx = node.idx_in_req_data;
let (field_type, ranges) = {
let req_view = req_data.get_range_req_data(node.idx_in_req_data);
(req_view.field_type, req_view.req.ranges.clone())
};
impl<const LOWCARD: bool> SegmentRangeCollector<LOWCARD> {
pub(crate) fn create_new_buckets(
&mut self,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<Vec<SegmentRangeAndBucketEntry>> {
let field_type = self.column_type;
let req_data = agg_data.get_range_req_data(self.accessor_idx);
// The range input on the request is f64.
// We need to convert to u64 ranges, because we read the values as u64.
// The mapping from the conversion is monotonic so ordering is preserved.
let sub_agg_prototype = if !node.children.is_empty() {
Some(build_segment_agg_collectors(req_data, &node.children)?)
} else {
None
};
let buckets: Vec<_> = extend_validate_ranges(&ranges, &field_type)?
let buckets: Vec<_> = extend_validate_ranges(&req_data.req.ranges, &field_type)?
.iter()
.map(|range| {
let bucket_id = self.bucket_id_provider.next_bucket_id();
let key = range
.key
.clone()
@@ -317,20 +390,20 @@ impl SegmentRangeCollector {
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.range.end, &field_type))
Some(f64_from_fastfield_u64(range.range.end, field_type))
};
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.range.start, &field_type))
Some(f64_from_fastfield_u64(range.range.start, field_type))
};
let sub_aggregation = sub_agg_prototype.clone();
// let sub_aggregation = sub_agg_prototype.clone();
Ok(SegmentRangeAndBucketEntry {
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
doc_count: 0,
sub_aggregation,
bucket_id,
key,
from,
to,
@@ -339,27 +412,20 @@ impl SegmentRangeCollector {
})
.collect::<crate::Result<_>>()?;
req_data.context.limits.add_memory_consumed(
self.limits.add_memory_consumed(
buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
)?;
Ok(SegmentRangeCollector {
buckets,
column_type: field_type,
accessor_idx,
})
}
#[inline]
fn get_bucket_pos(&self, val: u64) -> usize {
let pos = self
.buckets
.binary_search_by_key(&val, |probe| probe.range.start)
.unwrap_or_else(|pos| pos - 1);
debug_assert!(self.buckets[pos].range.contains(&val));
pos
Ok(buckets)
}
}
#[inline]
fn get_bucket_pos(val: u64, buckets: &[SegmentRangeAndBucketEntry]) -> usize {
    // Binary search on the sorted, contiguous range starts. An exact hit is the
    // bucket itself; otherwise the insertion point is one past the covering bucket.
    let idx = match buckets.binary_search_by_key(&val, |entry| entry.range.start) {
        Ok(exact) => exact,
        Err(insertion_point) => insertion_point - 1,
    };
    debug_assert!(buckets[idx].range.contains(&val));
    idx
}
/// Converts the user provided f64 range value to fast field value space.
///
@@ -456,7 +522,7 @@ pub(crate) fn range_to_string(
let val = i64::from_u64(val);
format_date(val)
} else {
Ok(f64_from_fastfield_u64(val, field_type).to_string())
Ok(f64_from_fastfield_u64(val, *field_type).to_string())
}
};
@@ -506,30 +572,33 @@ mod tests {
let to = if range.range.end == u64::MAX {
None
} else {
Some(f64_from_fastfield_u64(range.range.end, &field_type))
Some(f64_from_fastfield_u64(range.range.end, field_type))
};
let from = if range.range.start == u64::MIN {
None
} else {
Some(f64_from_fastfield_u64(range.range.start, &field_type))
Some(f64_from_fastfield_u64(range.range.start, field_type))
};
SegmentRangeAndBucketEntry {
range: range.range.clone(),
bucket: SegmentRangeBucketEntry {
doc_count: 0,
sub_aggregation: None,
key,
from,
to,
bucket_id: 0,
},
}
})
.collect();
SegmentRangeCollector {
buckets,
parent_buckets: vec![buckets],
column_type: field_type,
accessor_idx: 0,
sub_agg: None,
bucket_id_provider: Default::default(),
limits: AggregationLimitsGuard::default(),
}
}
@@ -776,7 +845,7 @@ mod tests {
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
let buckets = collector.parent_buckets[0].clone();
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -799,7 +868,7 @@ mod tests {
];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
let buckets = collector.parent_buckets[0].clone();
assert_eq!(buckets[0].range.start, u64::MIN);
assert_eq!(buckets[0].range.end, 10f64.to_u64());
assert_eq!(buckets[1].range.start, 10f64.to_u64());
@@ -814,7 +883,7 @@ mod tests {
let buckets = vec![(-10f64..-1f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
let buckets = collector.parent_buckets[0].clone();
assert_eq!(&buckets[0].bucket.key.to_string(), "*--10");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "-1-*");
}
@@ -823,7 +892,7 @@ mod tests {
let buckets = vec![(0f64..10f64).into()];
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
let buckets = collector.parent_buckets[0].clone();
assert_eq!(&buckets[0].bucket.key.to_string(), "*-0");
assert_eq!(&buckets[buckets.len() - 1].bucket.key.to_string(), "10-*");
}
@@ -832,7 +901,7 @@ mod tests {
fn range_binary_search_test_u64() {
let check_ranges = |ranges: Vec<RangeAggregationRange>| {
let collector = get_collector_from_ranges(ranges, ColumnType::U64);
let search = |val: u64| collector.get_bucket_pos(val);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9), 0);
@@ -878,7 +947,7 @@ mod tests {
let ranges = vec![(10.0..100.0).into()];
let collector = get_collector_from_ranges(ranges, ColumnType::F64);
let search = |val: u64| collector.get_bucket_pos(val);
let search = |val: u64| get_bucket_pos(val, &collector.parent_buckets[0]);
assert_eq!(search(u64::MIN), 0);
assert_eq!(search(9f64.to_u64()), 0);
@@ -890,63 +959,3 @@ mod tests {
// the max value
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use itertools::Itertools;
use rand::seq::SliceRandom;
use rand::thread_rng;
use super::*;
use crate::aggregation::bucket::range::tests::get_collector_from_ranges;
const TOTAL_DOCS: u64 = 1_000_000u64;
const NUM_DOCS: u64 = 50_000u64;
fn get_collector_with_buckets(num_buckets: u64, num_docs: u64) -> SegmentRangeCollector {
let bucket_size = num_docs / num_buckets;
let mut buckets: Vec<RangeAggregationRange> = vec![];
for i in 0..num_buckets {
let bucket_start = (i * bucket_size) as f64;
buckets.push((bucket_start..bucket_start + bucket_size as f64).into())
}
get_collector_from_ranges(buckets, ColumnType::U64)
}
fn get_rand_docs(total_docs: u64, num_docs_returned: u64) -> Vec<u64> {
let mut rng = thread_rng();
let all_docs = (0..total_docs - 1).collect_vec();
let mut vals = all_docs
.as_slice()
.choose_multiple(&mut rng, num_docs_returned as usize)
.cloned()
.collect_vec();
vals.sort();
vals
}
fn bench_range_binary_search(b: &mut test::Bencher, num_buckets: u64) {
let collector = get_collector_with_buckets(num_buckets, TOTAL_DOCS);
let vals = get_rand_docs(TOTAL_DOCS, NUM_DOCS);
b.iter(|| {
let mut bucket_pos = 0;
for val in &vals {
bucket_pos = collector.get_bucket_pos(*val);
}
bucket_pos
})
}
#[bench]
fn bench_range_100_buckets(b: &mut test::Bencher) {
bench_range_binary_search(b, 100)
}
#[bench]
fn bench_range_10_buckets(b: &mut test::Bencher) {
bench_range_binary_search(b, 10)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -5,11 +5,13 @@ use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::term_agg::TermsAggregation;
use crate::aggregation::cached_sub_aggs::CachedSubAggs;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
/// Special aggregation to handle missing values for term aggregations.
/// This missing aggregation will check multiple columns for existence.
@@ -35,41 +37,55 @@ impl MissingTermAggReqData {
}
}
/// The specialized missing term aggregation.
#[derive(Default, Debug, Clone)]
pub struct TermMissingAgg {
struct MissingCount {
missing_count: u32,
bucket_id: BucketId,
}
/// The specialized missing term aggregation.
#[derive(Default, Debug)]
pub struct TermMissingAgg {
accessor_idx: usize,
sub_agg: Option<Box<dyn SegmentAggregationCollector>>,
sub_agg: Option<CachedSubAggs>,
/// Idx = parent bucket id, Value = missing count for that bucket
missing_count_per_bucket: Vec<MissingCount>,
bucket_id_provider: BucketIdProvider,
}
impl TermMissingAgg {
pub(crate) fn new(
req_data: &mut AggregationsSegmentCtx,
agg_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
let has_sub_aggregations = !node.children.is_empty();
let accessor_idx = node.idx_in_req_data;
let sub_agg = if has_sub_aggregations {
let sub_aggregation = build_segment_agg_collectors(req_data, &node.children)?;
let sub_aggregation = build_segment_agg_collectors(agg_data, &node.children)?;
Some(sub_aggregation)
} else {
None
};
let sub_agg = sub_agg.map(CachedSubAggs::new);
let bucket_id_provider = BucketIdProvider::default();
Ok(Self {
accessor_idx,
sub_agg,
..Default::default()
missing_count_per_bucket: Vec::new(),
bucket_id_provider,
})
}
}
impl SegmentAggregationCollector for TermMissingAgg {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
let term_agg = &req_data.req;
let missing = term_agg
@@ -80,13 +96,16 @@ impl SegmentAggregationCollector for TermMissingAgg {
let mut entries: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> =
Default::default();
let missing_count = &self.missing_count_per_bucket[parent_bucket_id as usize];
let mut missing_entry = IntermediateTermBucketEntry {
doc_count: self.missing_count,
doc_count: missing_count.missing_count,
sub_aggregation: Default::default(),
};
if let Some(sub_agg) = self.sub_agg {
if let Some(sub_agg) = &mut self.sub_agg {
let mut res = IntermediateAggregationResults::default();
sub_agg.add_intermediate_aggregation_result(agg_data, &mut res)?;
sub_agg
.get_sub_agg_collector()
.add_intermediate_aggregation_result(agg_data, &mut res, missing_count.bucket_id)?;
missing_entry.sub_aggregation = res;
}
entries.insert(missing.into(), missing_entry);
@@ -109,30 +128,52 @@ impl SegmentAggregationCollector for TermMissingAgg {
fn collect(
&mut self,
doc: crate::DocId,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let bucket = &mut self.missing_count_per_bucket[parent_bucket_id as usize];
let req_data = agg_data.get_missing_term_req_data(self.accessor_idx);
let has_value = req_data
.accessors
.iter()
.any(|(acc, _)| acc.index.has_value(doc));
if !has_value {
self.missing_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.collect(doc, agg_data)?;
for doc in docs {
let doc = *doc;
let has_value = req_data
.accessors
.iter()
.any(|(acc, _)| acc.index.has_value(doc));
if !has_value {
bucket.missing_count += 1;
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.push(bucket.bucket_id, doc);
}
}
}
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.check_flush_local(agg_data)?;
}
Ok(())
}
fn collect_block(
fn prepare_max_bucket(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
for doc in docs {
self.collect(*doc, agg_data)?;
while self.missing_count_per_bucket.len() <= max_bucket as usize {
let bucket_id = self.bucket_id_provider.next_bucket_id();
self.missing_count_per_bucket.push(MissingCount {
missing_count: 0,
bucket_id,
});
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_agg) = self.sub_agg.as_mut() {
sub_agg.flush(agg_data)?;
}
Ok(())
}

View File

@@ -1,87 +0,0 @@
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::DocId;
#[cfg(test)]
pub(crate) const DOC_BLOCK_SIZE: usize = 64;
#[cfg(not(test))]
pub(crate) const DOC_BLOCK_SIZE: usize = 256;
pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
/// BufAggregationCollector buffers documents before calling collect_block().
///
/// Doc ids are staged in a fixed-size array and handed to the wrapped collector
/// in blocks of `DOC_BLOCK_SIZE`, amortizing per-doc virtual dispatch overhead.
#[derive(Clone)]
pub(crate) struct BufAggregationCollector {
    /// The wrapped collector that receives the buffered doc id blocks.
    pub(crate) collector: Box<dyn SegmentAggregationCollector>,
    // Fixed-size staging buffer; only the first `num_staged_docs` entries are valid.
    staged_docs: DocBlock,
    // Number of valid entries currently staged in `staged_docs`.
    num_staged_docs: usize,
}
impl std::fmt::Debug for BufAggregationCollector {
    /// Debug output shows only the staged docs; the boxed collector is not
    /// required to be `Debug` and is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Fix: use the actual type name. The previous label
        // ("SegmentAggregationResultsCollector") was stale and misleading in logs.
        f.debug_struct("BufAggregationCollector")
            .field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
            .field("num_staged_docs", &self.num_staged_docs)
            .finish()
    }
}
impl BufAggregationCollector {
pub fn new(collector: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
collector,
num_staged_docs: 0,
staged_docs: [0; DOC_BLOCK_SIZE],
}
}
}
impl SegmentAggregationCollector for BufAggregationCollector {
    #[inline]
    fn add_intermediate_aggregation_result(
        self: Box<Self>,
        agg_data: &AggregationsSegmentCtx,
        results: &mut IntermediateAggregationResults,
    ) -> crate::Result<()> {
        // Delegates directly to the wrapped collector. NOTE(review): docs still in
        // the staging buffer are not flushed here — callers are presumably expected
        // to call `flush` first; confirm at call sites.
        Box::new(self.collector).add_intermediate_aggregation_result(agg_data, results)
    }
    #[inline]
    fn collect(
        &mut self,
        doc: crate::DocId,
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        // Stage the doc id; forward a full block once the buffer fills up.
        self.staged_docs[self.num_staged_docs] = doc;
        self.num_staged_docs += 1;
        if self.num_staged_docs == self.staged_docs.len() {
            self.collector
                .collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
            self.num_staged_docs = 0;
        }
        Ok(())
    }
    #[inline]
    fn collect_block(
        &mut self,
        docs: &[crate::DocId],
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        // Input is already block-shaped: bypass the staging buffer and forward as-is.
        // Note: docs staged via `collect` are not drained first, so ordering across
        // the two entry points is not preserved.
        self.collector.collect_block(docs, agg_data)?;
        Ok(())
    }
    #[inline]
    fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
        // Drain the partially filled block (possibly empty), then flush downstream.
        self.collector
            .collect_block(&self.staged_docs[..self.num_staged_docs], agg_data)?;
        self.num_staged_docs = 0;
        self.collector.flush(agg_data)?;
        Ok(())
    }
}

View File

@@ -0,0 +1,185 @@
use super::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::bucket::MAX_NUM_TERMS_FOR_VEC;
use crate::aggregation::BucketId;
use crate::DocId;
#[derive(Debug)]
/// A cache for sub-aggregations, storing doc ids per bucket id.
/// Depending on the cardinality of the parent aggregation, we use different
/// storage strategies.
///
/// ## Low Cardinality
/// Cardinality here refers to the number of unique flattened buckets that can be created
/// by the parent aggregation.
/// Flattened buckets are the result of combining all buckets per collector
/// into a single list of buckets, where each bucket is identified by its BucketId.
///
/// ## Usage
/// Since this is caching for sub-aggregations, it is only used by bucket
/// aggregations.
///
/// TODO: consider using a more advanced data structure for high cardinality
/// aggregations.
/// What this datastructure does in general is to group docs by bucket id.
pub(crate) struct CachedSubAggs<const LOWCARD: bool = false> {
    /// Only used when LOWCARD is true.
    /// Cache doc ids per bucket for sub-aggregations.
    ///
    /// The outer Vec is indexed by BucketId.
    per_bucket_docs: Vec<Vec<DocId>>,
    /// Only used when LOWCARD is false.
    ///
    /// This weird partitioning is used to do some cheap grouping on the bucket ids.
    /// Bucket ids are dense, e.g. when we don't detect the cardinality as low cardinality,
    /// but there are just 16 bucket ids, each bucket id will go to its own partition.
    ///
    /// We want to keep this cheap, because high cardinality aggregations can have a lot of
    /// buckets, and there may be nothing to group.
    partitions: [PartitionEntry; NUM_PARTITIONS],
    /// The wrapped sub-aggregation collector that receives the grouped doc ids.
    pub(crate) sub_agg_collector: Box<dyn SegmentAggregationCollector>,
    /// Total number of doc ids currently cached (across all buckets/partitions).
    num_docs: usize,
}
/// Number of cached docs at which `check_flush_local` flushes downstream.
const FLUSH_THRESHOLD: usize = 2048;
/// Number of partitions for the high-cardinality (`LOWCARD == false`) grouping.
const NUM_PARTITIONS: usize = 16;
impl<const LOWCARD: bool> CachedSubAggs<LOWCARD> {
    /// Returns the wrapped sub-aggregation collector.
    pub fn get_sub_agg_collector(&mut self) -> &mut Box<dyn SegmentAggregationCollector> {
        &mut self.sub_agg_collector
    }

    /// Wraps `sub_agg` with an (initially empty) doc id cache.
    pub fn new(sub_agg: Box<dyn SegmentAggregationCollector>) -> Self {
        Self {
            per_bucket_docs: Vec::new(),
            num_docs: 0,
            sub_agg_collector: sub_agg,
            partitions: core::array::from_fn(|_| PartitionEntry::default()),
        }
    }

    /// Clears all cached doc ids, keeping the allocations for reuse.
    #[inline]
    pub fn clear(&mut self) {
        for v in &mut self.per_bucket_docs {
            v.clear();
        }
        for partition in &mut self.partitions {
            partition.clear();
        }
        self.num_docs = 0;
    }

    /// Caches `doc_id` under `bucket_id`.
    #[inline]
    pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
        if LOWCARD {
            // TODO: We could flush single buckets here
            let idx = bucket_id as usize;
            if self.per_bucket_docs.len() <= idx {
                self.per_bucket_docs.resize_with(idx + 1, Vec::new);
            }
            self.per_bucket_docs[idx].push(doc_id);
        } else {
            // Cheap grouping: bucket ids are dense, so modulo spreads them evenly.
            let idx = bucket_id % NUM_PARTITIONS as u32;
            let slot = &mut self.partitions[idx as usize];
            slot.bucket_ids.push(bucket_id);
            slot.docs.push(doc_id);
        }
        self.num_docs += 1;
    }

    /// Check if we need to flush based on the number of documents cached.
    /// If so, flushes the cache to the provided aggregation collector.
    pub fn check_flush_local(
        &mut self,
        agg_data: &mut AggregationsSegmentCtx,
    ) -> crate::Result<()> {
        if self.num_docs >= FLUSH_THRESHOLD {
            self.flush_local(agg_data, false)?;
        }
        Ok(())
    }

    /// Pushes the cached doc ids down to the sub-aggregation collector.
    ///
    /// Note: this does _not_ flush the sub aggregations themselves.
    fn flush_local(
        &mut self,
        agg_data: &mut AggregationsSegmentCtx,
        force: bool,
    ) -> crate::Result<()> {
        if LOWCARD {
            // Pre-aggregated: call collect per bucket.
            let max_bucket = (self.per_bucket_docs.len() as BucketId).saturating_sub(1);
            self.sub_agg_collector
                .prepare_max_bucket(max_bucket, agg_data)?;
            // The threshold above which we flush buckets individually.
            // Note: We need to make sure that we don't lock ourselves into a situation where
            // we hit the FLUSH_THRESHOLD, but never flush any buckets (except the final flush).
            let mut bucket_threshold = FLUSH_THRESHOLD / (self.per_bucket_docs.len().max(1) * 2);
            const _: () = {
                // MAX_NUM_TERMS_FOR_VEC == LOWCARD threshold
                let bucket_threshold = FLUSH_THRESHOLD / (MAX_NUM_TERMS_FOR_VEC as usize * 2);
                assert!(
                    bucket_threshold > 0,
                    "Bucket threshold must be greater than 0"
                );
            };
            if force {
                bucket_threshold = 0;
            }
            // BUGFIX: clear only the buckets that were actually flushed. Previously the
            // whole cache was cleared after a partial flush, silently dropping the doc
            // ids of buckets at or below the threshold.
            for (bucket_id, docs) in self
                .per_bucket_docs
                .iter_mut()
                .enumerate()
                .filter(|(_, docs)| docs.len() > bucket_threshold)
            {
                self.sub_agg_collector
                    .collect(bucket_id as BucketId, docs, agg_data)?;
                self.num_docs -= docs.len();
                docs.clear();
            }
            // Retained docs are bounded by num_buckets * bucket_threshold
            // <= FLUSH_THRESHOLD / 2, so `check_flush_local` always makes progress.
        } else {
            // Determine the highest bucket id so the sub collector can size its state.
            let mut max_bucket = 0u32;
            for partition in &self.partitions {
                if let Some(&local_max) = partition.bucket_ids.iter().max() {
                    max_bucket = max_bucket.max(local_max);
                }
            }
            self.sub_agg_collector
                .prepare_max_bucket(max_bucket, agg_data)?;
            for slot in &self.partitions {
                if !slot.bucket_ids.is_empty() {
                    // Reduce dynamic dispatch overhead by collecting a full partition in one call.
                    self.sub_agg_collector.collect_multiple(
                        &slot.bucket_ids,
                        &slot.docs,
                        agg_data,
                    )?;
                }
            }
            // Every partition was fully handed downstream, so a full clear is safe here.
            self.clear();
        }
        Ok(())
    }

    /// Drains the cache and flushes the sub-aggregation collectors.
    ///
    /// Note: this _does_ flush the sub aggregations.
    pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
        if self.num_docs != 0 {
            self.flush_local(agg_data, true)?;
        }
        self.sub_agg_collector.flush(agg_data)?;
        Ok(())
    }
}
/// One partition of the high-cardinality cache.
///
/// `bucket_ids` and `docs` are parallel vectors: `docs[i]` belongs to `bucket_ids[i]`.
#[derive(Debug, Clone, Default)]
struct PartitionEntry {
    bucket_ids: Vec<BucketId>,
    docs: Vec<DocId>,
}
impl PartitionEntry {
    /// Empties both parallel vectors, keeping their allocations for reuse.
    #[inline]
    fn clear(&mut self) {
        self.bucket_ids.clear();
        self.docs.clear();
    }
}

View File

@@ -1,9 +1,9 @@
use super::agg_req::Aggregations;
use super::agg_result::AggregationResults;
use super::buf_collector::BufAggregationCollector;
use super::cached_sub_aggs::CachedSubAggs;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use super::AggContextParams;
// group buffering strategy is chosen explicitly by callers; no need to hash-group on the fly.
use crate::aggregation::agg_data::{
build_aggregations_data_from_req, build_segment_agg_collectors_root, AggregationsSegmentCtx,
};
@@ -136,7 +136,7 @@ fn merge_fruits(
/// `AggregationSegmentCollector` does the aggregation collection on a segment.
pub struct AggregationSegmentCollector {
aggs_with_accessor: AggregationsSegmentCtx,
agg_collector: BufAggregationCollector,
agg_collector: CachedSubAggs<true>,
error: Option<TantivyError>,
}
@@ -151,8 +151,10 @@ impl AggregationSegmentCollector {
) -> crate::Result<Self> {
let mut agg_data =
build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
let result =
BufAggregationCollector::new(build_segment_agg_collectors_root(&mut agg_data)?);
let mut result = CachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?);
result
.get_sub_agg_collector()
.prepare_max_bucket(0, &agg_data)?; // prepare for bucket zero
Ok(AggregationSegmentCollector {
aggs_with_accessor: agg_data,
@@ -170,26 +172,31 @@ impl SegmentCollector for AggregationSegmentCollector {
if self.error.is_some() {
return;
}
if let Err(err) = self
self.agg_collector.push(0, doc);
match self
.agg_collector
.collect(doc, &mut self.aggs_with_accessor)
.check_flush_local(&mut self.aggs_with_accessor)
{
self.error = Some(err);
Ok(_) => {}
Err(e) => {
self.error = Some(e);
}
}
}
/// The query pushes the documents to the collector via this method.
///
/// Only valid for Collectors that ignore docs
fn collect_block(&mut self, docs: &[DocId]) {
if self.error.is_some() {
return;
}
if let Err(err) = self
.agg_collector
.collect_block(docs, &mut self.aggs_with_accessor)
{
self.error = Some(err);
match self.agg_collector.get_sub_agg_collector().collect(
0,
docs,
&mut self.aggs_with_accessor,
) {
Ok(_) => {}
Err(e) => {
self.error = Some(e);
}
}
}
@@ -200,10 +207,13 @@ impl SegmentCollector for AggregationSegmentCollector {
self.agg_collector.flush(&mut self.aggs_with_accessor)?;
let mut sub_aggregation_res = IntermediateAggregationResults::default();
Box::new(self.agg_collector).add_intermediate_aggregation_result(
&self.aggs_with_accessor,
&mut sub_aggregation_res,
)?;
self.agg_collector
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
&self.aggs_with_accessor,
&mut sub_aggregation_res,
0,
)?;
Ok(sub_aggregation_res)
}

View File

@@ -792,7 +792,7 @@ pub struct IntermediateRangeBucketEntry {
/// The number of documents in the bucket.
pub doc_count: u64,
/// The sub_aggregation in this bucket.
pub sub_aggregation: IntermediateAggregationResults,
pub sub_aggregation_res: IntermediateAggregationResults,
/// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`.
@@ -811,7 +811,7 @@ impl IntermediateRangeBucketEntry {
key: self.key.into(),
doc_count: self.doc_count,
sub_aggregation: self
.sub_aggregation
.sub_aggregation_res
.into_final_result_internal(req, limits)?,
to: self.to,
from: self.from,
@@ -857,7 +857,8 @@ impl MergeFruits for IntermediateTermBucketEntry {
impl MergeFruits for IntermediateRangeBucketEntry {
fn merge_fruits(&mut self, other: IntermediateRangeBucketEntry) -> crate::Result<()> {
self.doc_count += other.doc_count;
self.sub_aggregation.merge_fruits(other.sub_aggregation)?;
self.sub_aggregation_res
.merge_fruits(other.sub_aggregation_res)?;
Ok(())
}
}
@@ -887,7 +888,7 @@ mod tests {
IntermediateRangeBucketEntry {
key: IntermediateKey::Str(key.to_string()),
doc_count: *doc_count,
sub_aggregation: Default::default(),
sub_aggregation_res: Default::default(),
from: None,
to: None,
},
@@ -920,7 +921,7 @@ mod tests {
doc_count: *doc_count,
from: None,
to: None,
sub_aggregation: get_sub_test_tree(&[(
sub_aggregation_res: get_sub_test_tree(&[(
sub_aggregation_key.to_string(),
*sub_aggregation_count,
)]),

View File

@@ -52,10 +52,8 @@ pub struct IntermediateAverage {
impl IntermediateAverage {
/// Creates a new [`IntermediateAverage`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateAverage) {

View File

@@ -2,7 +2,7 @@ use std::collections::hash_map::DefaultHasher;
use std::hash::{BuildHasher, Hasher};
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{Column, ColumnBlockAccessor, ColumnType, Dictionary, StrColumn};
use columnar::{Column, ColumnType, Dictionary, StrColumn};
use common::f64_to_u64;
use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
use rustc_hash::FxHashSet;
@@ -106,8 +106,6 @@ pub struct CardinalityAggReqData {
pub str_dict_column: Option<StrColumn>,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_value_for_accessor: Option<u64>,
/// The column block accessor to access the fast field values.
pub(crate) column_block_accessor: ColumnBlockAccessor<u64>,
/// The name of the aggregation.
pub name: String,
/// The aggregation request.
@@ -135,45 +133,34 @@ impl CardinalityAggregationReq {
}
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentCardinalityCollector {
cardinality: CardinalityCollector,
entries: FxHashSet<u64>,
buckets: Vec<SegmentCardinalityCollectorBucket>,
accessor_idx: usize,
/// The column accessor to access the fast field values.
accessor: Column<u64>,
/// The column_type of the field.
column_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
missing_value_for_accessor: Option<u64>,
}
impl SegmentCardinalityCollector {
pub fn from_req(column_type: ColumnType, accessor_idx: usize) -> Self {
#[derive(Clone, Debug, PartialEq, Default)]
pub(crate) struct SegmentCardinalityCollectorBucket {
cardinality: CardinalityCollector,
entries: FxHashSet<u64>,
}
impl SegmentCardinalityCollectorBucket {
pub fn new(column_type: ColumnType) -> Self {
Self {
cardinality: CardinalityCollector::new(column_type as u8),
entries: Default::default(),
accessor_idx,
entries: FxHashSet::default(),
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_data: &mut CardinalityAggReqData,
) {
if let Some(missing) = agg_data.missing_value_for_accessor {
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&agg_data.accessor,
missing,
);
} else {
agg_data
.column_block_accessor
.fetch_block(docs, &agg_data.accessor);
}
}
fn into_intermediate_metric_result(
mut self,
agg_data: &AggregationsSegmentCtx,
req_data: &CardinalityAggReqData,
) -> crate::Result<IntermediateMetricResult> {
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
if req_data.column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let dict = req_data
@@ -194,6 +181,7 @@ impl SegmentCardinalityCollector {
term_ids.push(term_ord as u32);
}
}
term_ids.sort_unstable();
dict.sorted_ords_to_term_cb(term_ids.iter().map(|term| *term as u64), |term| {
self.cardinality.sketch.insert_any(&term);
@@ -227,16 +215,49 @@ impl SegmentCardinalityCollector {
}
}
impl SegmentCardinalityCollector {
pub fn from_req(
column_type: ColumnType,
accessor_idx: usize,
accessor: Column<u64>,
missing_value_for_accessor: Option<u64>,
) -> Self {
Self {
buckets: vec![SegmentCardinalityCollectorBucket::new(column_type); 1],
column_type,
accessor_idx,
accessor,
missing_value_for_accessor,
}
}
fn fetch_block_with_field(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) {
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_value_for_accessor,
);
}
}
impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let req_data = &agg_data.get_cardinality_req_data(self.accessor_idx);
let name = req_data.name.to_string();
// take the bucket in buckets and replace it with a new empty one
let bucket = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let intermediate_result = self.into_intermediate_metric_result(agg_data)?;
let intermediate_result = bucket.into_intermediate_metric_result(req_data)?;
results.push(
name,
IntermediateAggregationResult::Metric(intermediate_result),
@@ -247,27 +268,20 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)
}
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_cardinality_req_data_mut(self.accessor_idx);
self.fetch_block_with_field(docs, req_data);
self.fetch_block_with_field(docs, agg_data);
let bucket = &mut self.buckets[parent_bucket_id as usize];
let col_block_accessor = &req_data.column_block_accessor;
if req_data.column_type == ColumnType::Str {
let col_block_accessor = &agg_data.column_block_accessor;
if self.column_type == ColumnType::Str {
for term_ord in col_block_accessor.iter_vals() {
self.entries.insert(term_ord);
bucket.entries.insert(term_ord);
}
} else if req_data.column_type == ColumnType::IpAddr {
let compact_space_accessor = req_data
} else if self.column_type == ColumnType::IpAddr {
let compact_space_accessor = self
.accessor
.values
.clone()
@@ -282,16 +296,29 @@ impl SegmentAggregationCollector for SegmentCardinalityCollector {
})?;
for val in col_block_accessor.iter_vals() {
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
self.cardinality.sketch.insert_any(&val);
bucket.cardinality.sketch.insert_any(&val);
}
} else {
for val in col_block_accessor.iter_vals() {
self.cardinality.sketch.insert_any(&val);
bucket.cardinality.sketch.insert_any(&val);
}
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
if max_bucket as usize >= self.buckets.len() {
self.buckets.resize_with(max_bucket as usize + 1, || {
SegmentCardinalityCollectorBucket::new(self.column_type)
});
}
Ok(())
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]

View File

@@ -52,10 +52,8 @@ pub struct IntermediateCount {
impl IntermediateCount {
/// Creates a new [`IntermediateCount`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateCount) {

View File

@@ -8,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::{DocId, TantivyError};
use crate::TantivyError;
/// A multi-value metric aggregation that computes a collection of extended statistics
/// on numeric values that are extracted
@@ -318,51 +317,28 @@ impl IntermediateExtendedStats {
}
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentExtendedStatsCollector {
name: String,
missing: Option<u64>,
field_type: ColumnType,
pub(crate) extended_stats: IntermediateExtendedStats,
pub(crate) accessor_idx: usize,
val_cache: Vec<u64>,
accessor: columnar::Column<u64>,
buckets: Vec<IntermediateExtendedStats>,
sigma: Option<f64>,
}
impl SegmentExtendedStatsCollector {
pub fn from_req(
field_type: ColumnType,
sigma: Option<f64>,
accessor_idx: usize,
missing: Option<f64>,
) -> Self {
let missing = missing.and_then(|val| f64_to_fastfield_u64(val, &field_type));
pub fn from_req(req: &MetricAggReqData, sigma: Option<f64>) -> Self {
let missing = req
.missing
.and_then(|val| f64_to_fastfield_u64(val, &req.field_type));
Self {
field_type,
extended_stats: IntermediateExtendedStats::with_sigma(sigma),
accessor_idx,
name: req.name.clone(),
field_type: req.field_type,
accessor: req.accessor.clone(),
missing,
val_cache: Default::default(),
}
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = self.missing.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
*missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
buckets: vec![IntermediateExtendedStats::with_sigma(sigma); 16],
sigma,
}
}
}
@@ -370,15 +346,18 @@ impl SegmentExtendedStatsCollector {
impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
let name = self.name.clone();
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let extended_stats = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
results.push(
name,
IntermediateAggregationResult::Metric(IntermediateMetricResult::ExtendedStats(
self.extended_stats,
extended_stats,
)),
)?;
@@ -388,39 +367,36 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
if let Some(missing) = self.missing {
let mut has_val = false;
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
has_val = true;
}
if !has_val {
self.extended_stats
.collect(f64_from_fastfield_u64(missing, &self.field_type));
}
} else {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.extended_stats.collect(val1);
}
let mut extended_stats = self.buckets[parent_bucket_id as usize].clone();
agg_data
.column_block_accessor
.fetch_block_with_missing(docs, &self.accessor, self.missing);
for val in agg_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, self.field_type);
extended_stats.collect(val1);
}
// store back
self.buckets[parent_bucket_id as usize] = extended_stats;
Ok(())
}
#[inline]
fn collect_block(
fn prepare_max_bucket(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
self.collect_block_with_field(docs, req_data);
if self.buckets.len() <= max_bucket as usize {
self.buckets.resize_with(max_bucket as usize + 1, || {
IntermediateExtendedStats::with_sigma(self.sigma)
});
}
Ok(())
}
}

View File

@@ -52,10 +52,8 @@ pub struct IntermediateMax {
impl IntermediateMax {
/// Creates a new [`IntermediateMax`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateMax) {

View File

@@ -52,10 +52,8 @@ pub struct IntermediateMin {
impl IntermediateMin {
/// Creates a new [`IntermediateMin`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateMin) {

View File

@@ -31,7 +31,7 @@ use std::collections::HashMap;
pub use average::*;
pub use cardinality::*;
use columnar::{Column, ColumnBlockAccessor, ColumnType};
use columnar::{Column, ColumnType};
pub use count::*;
pub use extended_stats::*;
pub use max::*;
@@ -55,8 +55,6 @@ pub struct MetricAggReqData {
pub field_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_u64: Option<u64>,
/// The column block accessor to access the fast field values.
pub column_block_accessor: ColumnBlockAccessor<u64>,
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
/// Used when converting to intermediate result

View File

@@ -7,10 +7,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::{DocId, TantivyError};
use crate::TantivyError;
/// # Percentiles
///
@@ -131,10 +130,16 @@ impl PercentilesAggregationReq {
}
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentPercentilesCollector {
pub(crate) percentiles: PercentilesCollector,
pub(crate) buckets: Vec<PercentilesCollector>,
pub(crate) accessor_idx: usize,
/// The type of the field.
pub field_type: ColumnType,
/// The missing value normalized to the internal u64 representation of the field type.
pub missing_u64: Option<u64>,
/// The column accessor to access the fast field values.
pub accessor: Column<u64>,
}
#[derive(Clone, Serialize, Deserialize)]
@@ -229,33 +234,18 @@ impl PercentilesCollector {
}
impl SegmentPercentilesCollector {
pub fn from_req_and_validate(accessor_idx: usize) -> crate::Result<Self> {
Ok(Self {
percentiles: PercentilesCollector::new(),
pub fn from_req_and_validate(
field_type: ColumnType,
missing_u64: Option<u64>,
accessor: Column<u64>,
accessor_idx: usize,
) -> Self {
Self {
buckets: Vec::with_capacity(64),
field_type,
missing_u64,
accessor,
accessor_idx,
})
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = req_data.missing_u64.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
*missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
}
}
}
@@ -263,12 +253,18 @@ impl SegmentPercentilesCollector {
impl SegmentAggregationCollector for SegmentPercentilesCollector {
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data.get_metric_req_data(self.accessor_idx).name.clone();
let intermediate_metric_result = IntermediateMetricResult::Percentiles(self.percentiles);
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
// Swap collector with an empty one to avoid cloning
let percentiles_collector = std::mem::take(&mut self.buckets[parent_bucket_id as usize]);
let intermediate_metric_result =
IntermediateMetricResult::Percentiles(percentiles_collector);
results.push(
name,
@@ -281,40 +277,33 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
let percentiles = &mut self.buckets[parent_bucket_id as usize];
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_u64,
);
if let Some(missing) = req_data.missing_u64 {
let mut has_val = false;
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
has_val = true;
}
if !has_val {
self.percentiles
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
}
} else {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.percentiles.collect(val1);
}
for val in agg_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, self.field_type);
percentiles.collect(val1);
}
Ok(())
}
#[inline]
fn collect_block(
fn prepare_max_bucket(
&mut self,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
self.collect_block_with_field(docs, req_data);
while self.buckets.len() <= max_bucket as usize {
self.buckets.push(PercentilesCollector::new());
}
Ok(())
}
}

View File

@@ -1,5 +1,6 @@
use std::fmt::Debug;
use columnar::{Column, ColumnType};
use serde::{Deserialize, Serialize};
use super::*;
@@ -7,10 +8,9 @@ use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::metric::MetricAggReqData;
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::*;
use crate::{DocId, TantivyError};
use crate::TantivyError;
/// A multi-value metric aggregation that computes a collection of statistics on numeric values that
/// are extracted from the aggregated documents.
@@ -83,7 +83,7 @@ impl Stats {
/// Intermediate result of the stats aggregation that can be combined with other intermediate
/// results.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateStats {
/// The number of extracted values.
pub(crate) count: u64,
@@ -187,75 +187,75 @@ pub enum StatsType {
Percentiles,
}
fn create_collector<const TYPE_ID: u8>(
req: &MetricAggReqData,
) -> Box<dyn SegmentAggregationCollector> {
Box::new(SegmentStatsCollector::<TYPE_ID> {
name: req.name.clone(),
collecting_for: req.collecting_for,
is_number_or_date_type: req.is_number_or_date_type,
missing_u64: req.missing_u64,
accessor: req.accessor.clone(),
buckets: vec![IntermediateStats::default()],
})
}
/// Build a concrete `SegmentStatsCollector` depending on the column type.
pub(crate) fn build_segment_stats_collector(
req: &MetricAggReqData,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
match req.field_type {
ColumnType::I64 => Ok(create_collector::<{ ColumnType::I64 as u8 }>(req)),
ColumnType::U64 => Ok(create_collector::<{ ColumnType::U64 as u8 }>(req)),
ColumnType::F64 => Ok(create_collector::<{ ColumnType::F64 as u8 }>(req)),
ColumnType::Bool => Ok(create_collector::<{ ColumnType::Bool as u8 }>(req)),
ColumnType::DateTime => Ok(create_collector::<{ ColumnType::DateTime as u8 }>(req)),
ColumnType::Bytes => Ok(create_collector::<{ ColumnType::Bytes as u8 }>(req)),
ColumnType::Str => Ok(create_collector::<{ ColumnType::Str as u8 }>(req)),
ColumnType::IpAddr => Ok(create_collector::<{ ColumnType::IpAddr as u8 }>(req)),
}
}
#[repr(C)]
#[derive(Clone, Debug)]
pub(crate) struct SegmentStatsCollector {
pub(crate) stats: IntermediateStats,
pub(crate) accessor_idx: usize,
pub(crate) struct SegmentStatsCollector<const COLUMN_TYPE_ID: u8> {
pub(crate) missing_u64: Option<u64>,
pub(crate) accessor: Column<u64>,
pub(crate) is_number_or_date_type: bool,
pub(crate) buckets: Vec<IntermediateStats>,
pub(crate) name: String,
pub(crate) collecting_for: StatsType,
}
impl SegmentStatsCollector {
pub fn from_req(accessor_idx: usize) -> Self {
Self {
stats: IntermediateStats::default(),
accessor_idx,
}
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
req_data: &mut MetricAggReqData,
) {
if let Some(missing) = req_data.missing_u64.as_ref() {
req_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
*missing,
);
} else {
req_data
.column_block_accessor
.fetch_block(docs, &req_data.accessor);
}
if req_data.is_number_or_date_type {
for val in req_data.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
}
} else {
for _val in req_data.column_block_accessor.iter_vals() {
// we ignore the value and simply record that we got something
self.stats.collect(0.0);
}
}
}
}
impl SegmentAggregationCollector for SegmentStatsCollector {
impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
for SegmentStatsCollector<COLUMN_TYPE_ID>
{
#[inline]
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let req = agg_data.get_metric_req_data(self.accessor_idx);
let name = req.name.clone();
let name = self.name.clone();
let intermediate_metric_result = match req.collecting_for {
self.prepare_max_bucket(parent_bucket_id, agg_data)?;
let stats = self.buckets[parent_bucket_id as usize];
let intermediate_metric_result = match self.collecting_for {
StatsType::Average => {
IntermediateMetricResult::Average(IntermediateAverage::from_collector(*self))
IntermediateMetricResult::Average(IntermediateAverage::from_stats(stats))
}
StatsType::Count => {
IntermediateMetricResult::Count(IntermediateCount::from_collector(*self))
IntermediateMetricResult::Count(IntermediateCount::from_stats(stats))
}
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_collector(*self)),
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_collector(*self)),
StatsType::Stats => IntermediateMetricResult::Stats(self.stats),
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_collector(*self)),
StatsType::Max => IntermediateMetricResult::Max(IntermediateMax::from_stats(stats)),
StatsType::Min => IntermediateMetricResult::Min(IntermediateMin::from_stats(stats)),
StatsType::Stats => IntermediateMetricResult::Stats(stats),
StatsType::Sum => IntermediateMetricResult::Sum(IntermediateSum::from_stats(stats)),
_ => {
return Err(TantivyError::InvalidArgument(format!(
"Unsupported stats type for stats aggregation: {:?}",
req.collecting_for
self.collecting_for
)))
}
};
@@ -271,41 +271,67 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data(self.accessor_idx);
if let Some(missing) = req_data.missing_u64 {
let mut has_val = false;
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
has_val = true;
}
if !has_val {
self.stats
.collect(f64_from_fastfield_u64(missing, &req_data.field_type));
}
} else {
for val in req_data.accessor.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &req_data.field_type);
self.stats.collect(val1);
}
}
Ok(())
}
#[inline]
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_metric_req_data_mut(self.accessor_idx);
self.collect_block_with_field(docs, req_data);
// TODO: remove once we fetch all values for all bucket ids in one go
if docs.len() == 1 && self.missing_u64.is_none() {
collect_stats::<COLUMN_TYPE_ID>(
&mut self.buckets[parent_bucket_id as usize],
self.accessor.values_for_doc(docs[0]),
self.is_number_or_date_type,
)?;
return Ok(());
}
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&self.accessor,
self.missing_u64,
);
collect_stats::<COLUMN_TYPE_ID>(
&mut self.buckets[parent_bucket_id as usize],
agg_data.column_block_accessor.iter_vals(),
self.is_number_or_date_type,
)?;
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let required_buckets = (max_bucket as usize) + 1;
if self.buckets.len() < required_buckets {
self.buckets
.resize_with(required_buckets, IntermediateStats::default);
}
Ok(())
}
}
#[inline]
fn collect_stats<const COLUMN_TYPE_ID: u8>(
stats: &mut IntermediateStats,
vals: impl Iterator<Item = u64>,
is_number_or_date_type: bool,
) -> crate::Result<()> {
if is_number_or_date_type {
for val in vals {
let val1 = convert_to_f64::<COLUMN_TYPE_ID>(val);
stats.collect(val1);
}
} else {
for _val in vals {
// we ignore the value and simply record that we got something
stats.collect(0.0);
}
}
Ok(())
}
#[cfg(test)]

View File

@@ -52,10 +52,8 @@ pub struct IntermediateSum {
impl IntermediateSum {
/// Creates a new [`IntermediateSum`] instance from a [`SegmentStatsCollector`].
pub(crate) fn from_collector(collector: SegmentStatsCollector) -> Self {
Self {
stats: collector.stats,
}
pub(crate) fn from_stats(stats: IntermediateStats) -> Self {
Self { stats }
}
/// Merges the other intermediate result into self.
pub fn merge_fruits(&mut self, other: IntermediateSum) {

View File

@@ -15,12 +15,11 @@ use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::AggregationError;
use crate::aggregation::{AggregationError, BucketId};
use crate::collector::sort_key::ReverseComparator;
use crate::collector::TopNComputer;
use crate::schema::OwnedValue;
use crate::{DocAddress, DocId, SegmentOrdinal};
// duplicate import removed; already imported above
/// Contains all information required by the TopHitsSegmentCollector to perform the
/// top_hits aggregation on a segment.
@@ -472,7 +471,10 @@ impl TopHitsTopNComputer {
/// Create a new TopHitsCollector
pub fn new(req: &TopHitsAggregationReq) -> Self {
Self {
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
top_n: TopNComputer::new_with_comparator(
req.size + req.from.unwrap_or(0),
ReverseComparator,
),
req: req.clone(),
}
}
@@ -518,7 +520,8 @@ impl TopHitsTopNComputer {
pub(crate) struct TopHitsSegmentCollector {
segment_ordinal: SegmentOrdinal,
accessor_idx: usize,
top_n: TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>,
buckets: Vec<TopNComputer<Vec<DocValueAndOrder>, DocAddress, ReverseComparator>>,
num_hits: usize,
}
impl TopHitsSegmentCollector {
@@ -527,19 +530,29 @@ impl TopHitsSegmentCollector {
accessor_idx: usize,
segment_ordinal: SegmentOrdinal,
) -> Self {
let num_hits = req.size + req.from.unwrap_or(0);
Self {
top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)),
num_hits,
segment_ordinal,
accessor_idx,
buckets: vec![TopNComputer::new_with_comparator(num_hits, ReverseComparator); 1],
}
}
fn into_top_hits_collector(
self,
fn get_top_hits_computer(
&mut self,
parent_bucket_id: BucketId,
value_accessors: &HashMap<String, Vec<DynamicColumn>>,
req: &TopHitsAggregationReq,
) -> TopHitsTopNComputer {
if parent_bucket_id as usize >= self.buckets.len() {
return TopHitsTopNComputer::new(req);
}
let top_n = std::mem::replace(
&mut self.buckets[parent_bucket_id as usize],
TopNComputer::new(0),
);
let mut top_hits_computer = TopHitsTopNComputer::new(req);
let top_results = self.top_n.into_vec();
let top_results = top_n.into_vec();
for res in top_results {
let doc_value_fields = req.get_document_field_data(value_accessors, res.doc.doc_id);
@@ -554,54 +567,24 @@ impl TopHitsSegmentCollector {
top_hits_computer
}
/// TODO add a specialized variant for a single sort field
fn collect_with(
&mut self,
doc_id: crate::DocId,
req: &TopHitsAggregationReq,
accessors: &[(Column<u64>, ColumnType)],
) -> crate::Result<()> {
let sorts: Vec<DocValueAndOrder> = req
.sort
.iter()
.enumerate()
.map(|(idx, KeyOrder { order, .. })| {
let order = *order;
let value = accessors
.get(idx)
.expect("could not find field in accessors")
.0
.values_for_doc(doc_id)
.next();
DocValueAndOrder { value, order }
})
.collect();
self.top_n.push(
sorts,
DocAddress {
segment_ord: self.segment_ordinal,
doc_id,
},
);
Ok(())
}
}
impl SegmentAggregationCollector for TopHitsSegmentCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
let value_accessors = &req_data.value_accessors;
let intermediate_result = IntermediateMetricResult::TopHits(
self.into_top_hits_collector(value_accessors, &req_data.req),
);
let intermediate_result = IntermediateMetricResult::TopHits(self.get_top_hits_computer(
parent_bucket_id,
value_accessors,
&req_data.req,
));
results.push(
req_data.name.to_string(),
IntermediateAggregationResult::Metric(intermediate_result),
@@ -611,26 +594,57 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
/// TODO: Consider a caching layer to reduce the call overhead
fn collect(
&mut self,
doc_id: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
self.collect_with(doc_id, &req_data.req, &req_data.accessors)?;
Ok(())
}
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let top_n = &mut self.buckets[parent_bucket_id as usize];
let req_data = agg_data.get_top_hits_req_data(self.accessor_idx);
// TODO: Consider getting fields with the column block accessor.
for doc in docs {
self.collect_with(*doc, &req_data.req, &req_data.accessors)?;
let req = &req_data.req;
let accessors = &req_data.accessors;
for doc_id in docs {
let doc_id = *doc_id;
// TODO: this is terrible, a new vec is allocated for every doc
// We can fetch blocks instead
// We don't need to store the order for every value
let sorts: Vec<DocValueAndOrder> = req
.sort
.iter()
.enumerate()
.map(|(idx, KeyOrder { order, .. })| {
let order = *order;
let value = accessors
.get(idx)
.expect("could not find field in accessors")
.0
.values_for_doc(doc_id)
.next();
DocValueAndOrder { value, order }
})
.collect();
top_n.push(
sorts,
DocAddress {
segment_ord: self.segment_ordinal,
doc_id,
},
);
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
self.buckets.resize(
(max_bucket as usize) + 1,
TopNComputer::new_with_comparator(self.num_hits, ReverseComparator),
);
Ok(())
}
}
#[cfg(test)]
@@ -746,7 +760,7 @@ mod tests {
],
"from": 0,
}
}
}
}))
.unwrap();
@@ -875,7 +889,7 @@ mod tests {
"mixed.*",
],
}
}
}
}))?;
let collector = AggregationCollector::from_aggs(d, Default::default());

View File

@@ -133,7 +133,7 @@ mod agg_limits;
pub mod agg_req;
pub mod agg_result;
pub mod bucket;
mod buf_collector;
pub(crate) mod cached_sub_aggs;
mod collector;
mod date;
mod error;
@@ -162,6 +162,19 @@ use serde::{Deserialize, Deserializer, Serialize};
use crate::tokenizer::TokenizerManager;
/// A bucket id is a dense identifier for a bucket within an aggregation.
/// It is used to index into a Vec that hold per-bucket data.
///
/// For example, in a terms aggregation, each unique term will be assigned a incremental BucketId.
/// This BucketId will be forwarded to sub-aggregations to identify the parent bucket.
///
/// This allows to have a single AggregationCollector instance per aggregation,
/// that can handle multiple buckets efficiently.
///
/// The API to call sub-aggregations is therefore a &[(BucketId, &[DocId])].
/// For that we'll need a buffer. One Vec per bucket aggregation is needed.
pub type BucketId = u32;
/// Context parameters for aggregation execution
///
/// This struct holds shared resources needed during aggregation execution:
@@ -335,19 +348,37 @@ impl Display for Key {
}
}
pub(crate) fn convert_to_f64<const COLUMN_TYPE_ID: u8>(val: u64) -> f64 {
if COLUMN_TYPE_ID == ColumnType::U64 as u8 {
val as f64
} else if COLUMN_TYPE_ID == ColumnType::I64 as u8
|| COLUMN_TYPE_ID == ColumnType::DateTime as u8
{
i64::from_u64(val) as f64
} else if COLUMN_TYPE_ID == ColumnType::F64 as u8 {
f64::from_u64(val)
} else if COLUMN_TYPE_ID == ColumnType::Bool as u8 {
val as f64
} else {
panic!(
"ColumnType ID {} cannot be converted to f64 metric",
COLUMN_TYPE_ID
)
}
}
/// Inverse of `to_fastfield_u64`. Used to convert to `f64` for metrics.
///
/// # Panics
/// Only `u64`, `f64`, `date`, and `i64` are supported.
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &ColumnType) -> f64 {
pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: ColumnType) -> f64 {
match field_type {
ColumnType::U64 => val as f64,
ColumnType::I64 | ColumnType::DateTime => i64::from_u64(val) as f64,
ColumnType::F64 => f64::from_u64(val),
ColumnType::Bool => val as f64,
_ => {
panic!("unexpected type {field_type:?}. This should not happen")
}
ColumnType::U64 => convert_to_f64::<{ ColumnType::U64 as u8 }>(val),
ColumnType::I64 => convert_to_f64::<{ ColumnType::I64 as u8 }>(val),
ColumnType::F64 => convert_to_f64::<{ ColumnType::F64 as u8 }>(val),
ColumnType::Bool => convert_to_f64::<{ ColumnType::Bool as u8 }>(val),
ColumnType::DateTime => convert_to_f64::<{ ColumnType::DateTime as u8 }>(val),
_ => panic!("unexpected type {field_type:?}. This should not happen"),
}
}

View File

@@ -8,25 +8,67 @@ use std::fmt::Debug;
pub(crate) use super::agg_limits::AggregationLimitsGuard;
use super::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::agg_data::AggregationsSegmentCtx;
use crate::aggregation::BucketId;
/// Monotonically increasing provider of BucketIds.
#[derive(Debug, Clone, Default)]
pub struct BucketIdProvider(u32);
impl BucketIdProvider {
/// Get the next BucketId.
pub fn next_bucket_id(&mut self) -> BucketId {
let bucket_id = self.0;
self.0 += 1;
bucket_id
}
}
/// A SegmentAggregationCollector is used to collect aggregation results.
pub trait SegmentAggregationCollector: CollectorClone + Debug {
pub trait SegmentAggregationCollector: Debug {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()>;
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
fn collect(
&mut self,
doc: crate::DocId,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()>;
fn collect_block(
/// Collect docs for multiple buckets in one call.
/// Minimizes dynamic dispatch overhead when collecting many buckets.
///
/// Note: The caller needs to call `prepare_max_bucket` before calling `collect`.
fn collect_multiple(
&mut self,
bucket_ids: &[BucketId],
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
debug_assert_eq!(bucket_ids.len(), docs.len());
let mut start = 0;
while start < bucket_ids.len() {
let bucket_id = bucket_ids[start];
let mut end = start + 1;
while end < bucket_ids.len() && bucket_ids[end] == bucket_id {
end += 1;
}
self.collect(bucket_id, &docs[start..end], agg_data)?;
start = end;
}
Ok(())
}
/// Prepare the collector for collecting up to BucketId `max_bucket`.
/// This is useful so we can split allocation ahead of time of collecting.
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()>;
/// Finalize method. Some Aggregator collect blocks of docs before calling `collect_block`.
@@ -36,26 +78,7 @@ pub trait SegmentAggregationCollector: CollectorClone + Debug {
}
}
/// A helper trait to enable cloning of Box<dyn SegmentAggregationCollector>
pub trait CollectorClone {
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector>;
}
impl<T> CollectorClone for T
where T: 'static + SegmentAggregationCollector + Clone
{
fn clone_box(&self) -> Box<dyn SegmentAggregationCollector> {
Box::new(self.clone())
}
}
impl Clone for Box<dyn SegmentAggregationCollector> {
fn clone(&self) -> Box<dyn SegmentAggregationCollector> {
self.clone_box()
}
}
#[derive(Clone, Default)]
#[derive(Default)]
/// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which
/// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one
/// and can provide specialized versions instead, that remove some of its overhead.
@@ -73,12 +96,13 @@ impl Debug for GenericSegmentAggregationResultsCollector {
impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn add_intermediate_aggregation_result(
self: Box<Self>,
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
for agg in self.aggs {
agg.add_intermediate_aggregation_result(agg_data, results)?;
for agg in &mut self.aggs {
agg.add_intermediate_aggregation_result(agg_data, results, parent_bucket_id)?;
}
Ok(())
@@ -86,23 +110,13 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn collect(
&mut self,
doc: crate::DocId,
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_data)?;
Ok(())
}
fn collect_block(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
for collector in &mut self.aggs {
collector.collect_block(docs, agg_data)?;
collector.collect(parent_bucket_id, docs, agg_data)?;
}
Ok(())
}
@@ -112,4 +126,15 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
for collector in &mut self.aggs {
collector.prepare_max_bucket(max_bucket, agg_data)?;
}
Ok(())
}
}

View File

@@ -48,7 +48,15 @@ impl Executor {
F: Sized + Sync + Fn(A) -> crate::Result<R>,
{
match self {
Executor::SingleThread => args.map(f).collect::<crate::Result<_>>(),
Executor::SingleThread => {
// Avoid `collect`, since the stacktrace is blown up by it, which makes profiling
// harder.
let mut result = Vec::with_capacity(args.size_hint().0);
for arg in args {
result.push(f(arg)?);
}
Ok(result)
}
Executor::ThreadPool(pool) => {
let args: Vec<A> = args.collect();
let num_fruits = args.len();

View File

@@ -227,9 +227,6 @@ pub(crate) fn index_json_value<'a, V: Value<'a>>(
ReferenceValueLeaf::IpAddr(_) => {
unimplemented!("IP address support in dynamic fields is not yet implemented")
}
ReferenceValueLeaf::Geometry(_) => {
unimplemented!("Geometry support in dynamic fields is not implemented")
}
},
ReferenceValue::Array(elements) => {
for val in elements {

View File

@@ -683,7 +683,7 @@ mod tests {
}
#[test]
fn test_datefastfield() {
fn test_datefastfield() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let date_field = schema_builder.add_date_field(
"date",
@@ -697,28 +697,22 @@ mod tests {
);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
let mut index_writer = index.writer_for_tests()?;
index_writer.set_merge_policy(Box::new(NoMergePolicy));
index_writer
.add_document(doc!(
date_field => DateTime::from_u64(1i64.to_u64()),
multi_date_field => DateTime::from_u64(2i64.to_u64()),
multi_date_field => DateTime::from_u64(3i64.to_u64())
))
.unwrap();
index_writer
.add_document(doc!(
date_field => DateTime::from_u64(4i64.to_u64())
))
.unwrap();
index_writer
.add_document(doc!(
multi_date_field => DateTime::from_u64(5i64.to_u64()),
multi_date_field => DateTime::from_u64(6i64.to_u64())
))
.unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
index_writer.add_document(doc!(
date_field => DateTime::from_u64(1i64.to_u64()),
multi_date_field => DateTime::from_u64(2i64.to_u64()),
multi_date_field => DateTime::from_u64(3i64.to_u64())
))?;
index_writer.add_document(doc!(
date_field => DateTime::from_u64(4i64.to_u64())
))?;
index_writer.add_document(doc!(
multi_date_field => DateTime::from_u64(5i64.to_u64()),
multi_date_field => DateTime::from_u64(6i64.to_u64())
))?;
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
let segment_reader = searcher.segment_reader(0);
@@ -752,6 +746,7 @@ mod tests {
assert_eq!(dates[0].into_timestamp_nanos(), 5i64);
assert_eq!(dates[1].into_timestamp_nanos(), 6i64);
}
Ok(())
}
#[test]

View File

@@ -189,9 +189,6 @@ impl FastFieldsWriter {
.record_str(doc_id, field_name, &token.text);
}
}
ReferenceValueLeaf::Geometry(_) => {
panic!("Geometry fields should not be routed to fast field writer")
}
},
ReferenceValue::Array(val) => {
// TODO: Check this is the correct behaviour we want.
@@ -323,9 +320,6 @@ fn record_json_value_to_columnar_writer<'a, V: Value<'a>>(
"Pre-tokenized string support in dynamic fields is not yet implemented"
)
}
ReferenceValueLeaf::Geometry(_) => {
unimplemented!("Geometry support in dynamic fields is not yet implemented")
}
},
ReferenceValue::Array(elements) => {
for el in elements {

View File

@@ -142,7 +142,6 @@ impl SegmentMeta {
SegmentComponent::FastFields => ".fast".to_string(),
SegmentComponent::FieldNorms => ".fieldnorm".to_string(),
SegmentComponent::Delete => format!(".{}.del", self.delete_opstamp().unwrap_or(0)),
SegmentComponent::Spatial => ".spatial".to_string(),
});
PathBuf::from(path)
}

View File

@@ -28,14 +28,12 @@ pub enum SegmentComponent {
/// Bitset describing which document of the segment is alive.
/// (It was representing deleted docs but changed to represent alive docs from v0.17)
Delete,
/// HUSH
Spatial,
}
impl SegmentComponent {
/// Iterates through the components.
pub fn iterator() -> slice::Iter<'static, SegmentComponent> {
static SEGMENT_COMPONENTS: [SegmentComponent; 9] = [
static SEGMENT_COMPONENTS: [SegmentComponent; 8] = [
SegmentComponent::Postings,
SegmentComponent::Positions,
SegmentComponent::FastFields,
@@ -44,7 +42,6 @@ impl SegmentComponent {
SegmentComponent::Store,
SegmentComponent::TempStore,
SegmentComponent::Delete,
SegmentComponent::Spatial,
];
SEGMENT_COMPONENTS.iter()
}

View File

@@ -14,7 +14,6 @@ use crate::index::{InvertedIndexReader, Segment, SegmentComponent, SegmentId};
use crate::json_utils::json_path_sep_to_dot;
use crate::schema::{Field, IndexRecordOption, Schema, Type};
use crate::space_usage::SegmentSpaceUsage;
use crate::spatial::reader::SpatialReaders;
use crate::store::StoreReader;
use crate::termdict::TermDictionary;
use crate::{DocId, Opstamp};
@@ -44,7 +43,6 @@ pub struct SegmentReader {
positions_composite: CompositeFile,
fast_fields_readers: FastFieldReaders,
fieldnorm_readers: FieldNormReaders,
spatial_readers: SpatialReaders,
store_file: FileSlice,
alive_bitset_opt: Option<AliveBitSet>,
@@ -94,11 +92,6 @@ impl SegmentReader {
&self.fast_fields_readers
}
/// HUSH
pub fn spatial_fields(&self) -> &SpatialReaders {
&self.spatial_readers
}
/// Accessor to the `FacetReader` associated with a given `Field`.
pub fn facet_reader(&self, field_name: &str) -> crate::Result<FacetReader> {
let schema = self.schema();
@@ -180,12 +173,6 @@ impl SegmentReader {
let fast_fields_readers = FastFieldReaders::open(fast_fields_data, schema.clone())?;
let fieldnorm_data = segment.open_read(SegmentComponent::FieldNorms)?;
let fieldnorm_readers = FieldNormReaders::open(fieldnorm_data)?;
let spatial_readers = if schema.contains_spatial_field() {
let spatial_data = segment.open_read(SegmentComponent::Spatial)?;
SpatialReaders::open(spatial_data)?
} else {
SpatialReaders::empty()
};
let original_bitset = if segment.meta().has_deletes() {
let alive_doc_file_slice = segment.open_read(SegmentComponent::Delete)?;
@@ -211,7 +198,6 @@ impl SegmentReader {
postings_composite,
fast_fields_readers,
fieldnorm_readers,
spatial_readers,
segment_id: segment.id(),
delete_opstamp: segment.meta().delete_opstamp(),
store_file,
@@ -474,7 +460,6 @@ impl SegmentReader {
self.positions_composite.space_usage(),
self.fast_fields_readers.space_usage(self.schema())?,
self.fieldnorm_readers.space_usage(),
self.spatial_readers.space_usage(),
self.get_store_reader(0)?.space_usage(),
self.alive_bitset_opt
.as_ref()

View File

@@ -1,5 +1,3 @@
use std::collections::HashMap;
use std::io::{BufWriter, Write};
use std::sync::Arc;
use columnar::{
@@ -8,7 +6,6 @@ use columnar::{
use common::ReadOnlyBitSet;
use itertools::Itertools;
use measure_time::debug_time;
use tempfile::NamedTempFile;
use crate::directory::WritePtr;
use crate::docset::{DocSet, TERMINATED};
@@ -20,8 +17,6 @@ use crate::indexer::doc_id_mapping::{MappingType, SegmentDocIdMapping};
use crate::indexer::SegmentSerializer;
use crate::postings::{InvertedIndexSerializer, Postings, SegmentPostings};
use crate::schema::{value_type_to_column_type, Field, FieldType, Schema};
use crate::spatial::bkd::LeafPageIterator;
use crate::spatial::triangle::Triangle;
use crate::store::StoreWriter;
use crate::termdict::{TermMerger, TermOrdinal};
use crate::{DocAddress, DocId, InvertedIndexReader};
@@ -175,7 +170,6 @@ impl IndexMerger {
let mut readers = vec![];
for (segment, new_alive_bitset_opt) in segments.iter().zip(alive_bitset_opt) {
if segment.meta().num_docs() > 0 {
dbg!("segment");
let reader =
SegmentReader::open_with_custom_alive_set(segment, new_alive_bitset_opt)?;
readers.push(reader);
@@ -526,89 +520,6 @@ impl IndexMerger {
Ok(())
}
fn write_spatial_fields(
&self,
serializer: &mut SegmentSerializer,
doc_id_mapping: &SegmentDocIdMapping,
) -> crate::Result<()> {
/// We need to rebuild a BKD-tree based off the list of triangles.
///
/// Because the data can be large, we do this by writing the sequence of triangles to
/// disk, and mmapping it as mutable slice, and calling the same code as what
/// is done for the segment serialization.
///
/// The OS is in charge of deciding how to handle its page cache.
/// This is the same as what would have happened with swapping,
/// except by explicitly mapping the file, the OS is more likely to
/// swap, the memory will not be accounted as anonymous memory,
/// swap space is reserved etc.
use crate::spatial::bkd::Segment;
let Some(mut spatial_serializer) = serializer.extract_spatial_serializer() else {
// The schema does not contain any spatial field.
return Ok(());
};
let mut segment_mappings: Vec<Vec<Option<DocId>>> = Vec::new();
for reader in &self.readers {
let max_doc = reader.max_doc();
segment_mappings.push(vec![None; max_doc as usize]);
}
for (new_doc_id, old_doc_addr) in doc_id_mapping.iter_old_doc_addrs().enumerate() {
segment_mappings[old_doc_addr.segment_ord as usize][old_doc_addr.doc_id as usize] =
Some(new_doc_id as DocId);
}
let mut temp_files: HashMap<Field, NamedTempFile> = HashMap::new();
for (field, field_entry) in self.schema.fields() {
if matches!(field_entry.field_type(), FieldType::Spatial(_)) {
temp_files.insert(field, NamedTempFile::new()?);
}
}
for (segment_ord, reader) in self.readers.iter().enumerate() {
for (field, temp_file) in &mut temp_files {
let mut buf_temp_file = BufWriter::new(temp_file);
let spatial_readers = reader.spatial_fields();
let Some(spatial_reader) = spatial_readers.get_field(*field)? else {
continue;
};
let segment = Segment::new(spatial_reader.get_bytes());
for triangle_result in LeafPageIterator::new(&segment) {
let triangles = triangle_result?;
for triangle in triangles {
if let Some(new_doc_id) =
segment_mappings[segment_ord][triangle.doc_id as usize]
{
// This is really just a temporary file, not meant to be portable, so we
// use native endianness here.
for &word in &triangle.words {
buf_temp_file.write_all(&word.to_ne_bytes())?;
}
buf_temp_file.write_all(&new_doc_id.to_ne_bytes())?;
}
}
}
buf_temp_file.flush()?;
// No need to fsync here. This file is not here for persistency.
}
}
for (field, temp_file) in temp_files {
// Memory map the triangle file.
use memmap2::MmapOptions;
let mmap = unsafe { MmapOptions::new().map_mut(temp_file.as_file())? };
// Cast to &[Triangle] slice
let triangle_count = mmap.len() / std::mem::size_of::<Triangle>();
let triangles = unsafe {
std::slice::from_raw_parts_mut(mmap.as_ptr() as *mut Triangle, triangle_count)
};
// Get spatial writer and rebuild block kd-tree.
spatial_serializer.serialize_field(field, triangles)?;
}
spatial_serializer.close()?;
Ok(())
}
/// Writes the merged segment by pushing information
/// to the `SegmentSerializer`.
///
@@ -633,10 +544,9 @@ impl IndexMerger {
debug!("write-storagefields");
self.write_storable_fields(serializer.get_store_writer())?;
debug!("write-spatialfields");
self.write_spatial_fields(&mut serializer, &doc_id_mapping)?;
debug!("write-fastfields");
self.write_fast_fields(serializer.get_fast_field_write(), doc_id_mapping)?;
debug!("close-serializer");
serializer.close()?;
Ok(self.max_doc)

View File

@@ -4,7 +4,6 @@ use crate::directory::WritePtr;
use crate::fieldnorm::FieldNormsSerializer;
use crate::index::{Segment, SegmentComponent};
use crate::postings::InvertedIndexSerializer;
use crate::spatial::serializer::SpatialSerializer;
use crate::store::StoreWriter;
/// Segment serializer is in charge of laying out on disk
@@ -13,7 +12,6 @@ pub struct SegmentSerializer {
segment: Segment,
pub(crate) store_writer: StoreWriter,
fast_field_write: WritePtr,
spatial_serializer: Option<SpatialSerializer>,
fieldnorms_serializer: Option<FieldNormsSerializer>,
postings_serializer: InvertedIndexSerializer,
}
@@ -37,20 +35,11 @@ impl SegmentSerializer {
let fieldnorms_write = segment.open_write(SegmentComponent::FieldNorms)?;
let fieldnorms_serializer = FieldNormsSerializer::from_write(fieldnorms_write)?;
let spatial_serializer: Option<SpatialSerializer> =
if segment.schema().contains_spatial_field() {
let spatial_write = segment.open_write(SegmentComponent::Spatial)?;
Some(SpatialSerializer::from_write(spatial_write)?)
} else {
None
};
let postings_serializer = InvertedIndexSerializer::open(&mut segment)?;
Ok(SegmentSerializer {
segment,
store_writer,
fast_field_write,
spatial_serializer,
fieldnorms_serializer: Some(fieldnorms_serializer),
postings_serializer,
})
@@ -75,11 +64,6 @@ impl SegmentSerializer {
&mut self.fast_field_write
}
/// Accessor to the `SpatialSerializer`
pub fn extract_spatial_serializer(&mut self) -> Option<SpatialSerializer> {
self.spatial_serializer.take()
}
/// Extract the field norm serializer.
///
/// Note the fieldnorms serializer can only be extracted once.
@@ -97,9 +81,6 @@ impl SegmentSerializer {
if let Some(fieldnorms_serializer) = self.extract_fieldnorms_serializer() {
fieldnorms_serializer.close()?;
}
if let Some(spatial_serializer) = self.extract_spatial_serializer() {
spatial_serializer.close()?;
}
self.fast_field_write.terminate()?;
self.postings_serializer.close()?;
self.store_writer.close()?;

View File

@@ -16,7 +16,6 @@ use crate::postings::{
};
use crate::schema::document::{Document, Value};
use crate::schema::{FieldEntry, FieldType, Schema, DATE_TIME_PRECISION_INDEXED};
use crate::spatial::writer::SpatialWriter;
use crate::tokenizer::{FacetTokenizer, PreTokenizedStream, TextAnalyzer, Tokenizer};
use crate::{DocId, Opstamp, TantivyError};
@@ -53,7 +52,6 @@ pub struct SegmentWriter {
pub(crate) segment_serializer: SegmentSerializer,
pub(crate) fast_field_writers: FastFieldsWriter,
pub(crate) fieldnorms_writer: FieldNormsWriter,
pub(crate) spatial_writer: SpatialWriter,
pub(crate) json_path_writer: JsonPathWriter,
pub(crate) json_positions_per_path: IndexingPositionsPerPath,
pub(crate) doc_opstamps: Vec<Opstamp>,
@@ -106,7 +104,6 @@ impl SegmentWriter {
ctx: IndexingContext::new(table_size),
per_field_postings_writers,
fieldnorms_writer: FieldNormsWriter::for_schema(&schema),
spatial_writer: SpatialWriter::default(),
json_path_writer: JsonPathWriter::default(),
json_positions_per_path: IndexingPositionsPerPath::default(),
segment_serializer,
@@ -133,7 +130,6 @@ impl SegmentWriter {
self.ctx,
self.fast_field_writers,
&self.fieldnorms_writer,
&mut self.spatial_writer,
self.segment_serializer,
)?;
Ok(self.doc_opstamps)
@@ -146,7 +142,6 @@ impl SegmentWriter {
+ self.fieldnorms_writer.mem_usage()
+ self.fast_field_writers.mem_usage()
+ self.segment_serializer.mem_usage()
+ self.spatial_writer.mem_usage()
}
fn index_document<D: Document>(&mut self, doc: &D) -> crate::Result<()> {
@@ -343,13 +338,6 @@ impl SegmentWriter {
self.fieldnorms_writer.record(doc_id, field, num_vals);
}
}
FieldType::Spatial(_) => {
for value in values {
if let Some(geometry) = value.as_geometry() {
self.spatial_writer.add_geometry(doc_id, field, *geometry);
}
}
}
}
}
Ok(())
@@ -404,16 +392,12 @@ fn remap_and_write(
ctx: IndexingContext,
fast_field_writers: FastFieldsWriter,
fieldnorms_writer: &FieldNormsWriter,
spatial_writer: &mut SpatialWriter,
mut serializer: SegmentSerializer,
) -> crate::Result<()> {
debug!("remap-and-write");
if let Some(fieldnorms_serializer) = serializer.extract_fieldnorms_serializer() {
fieldnorms_writer.serialize(fieldnorms_serializer)?;
}
if let Some(spatial_serializer) = serializer.extract_spatial_serializer() {
spatial_writer.serialize(spatial_serializer)?;
}
let fieldnorm_data = serializer
.segment()
.open_read(SegmentComponent::FieldNorms)?;

View File

@@ -191,7 +191,6 @@ pub mod fieldnorm;
pub mod index;
pub mod positions;
pub mod postings;
pub mod spatial;
/// Module containing the different query implementations.
pub mod query;

View File

@@ -51,7 +51,6 @@ fn posting_writer_from_field_entry(field_entry: &FieldEntry) -> Box<dyn Postings
| FieldType::Date(_)
| FieldType::Bytes(_)
| FieldType::IpAddr(_)
| FieldType::Spatial(_)
| FieldType::Facet(_) => Box::<SpecializedPostingsWriter<DocIdRecorder>>::default(),
FieldType::JsonObject(ref json_object_options) => {
if let Some(text_indexing_option) = json_object_options.get_text_indexing_options() {

View File

@@ -24,7 +24,6 @@ mod reqopt_scorer;
mod scorer;
mod set_query;
mod size_hint;
mod spatial_query;
mod term_query;
mod union;
mod weight;
@@ -63,7 +62,6 @@ pub use self::reqopt_scorer::RequiredOptionalScorer;
pub use self::score_combiner::{DisjunctionMaxCombiner, ScoreCombiner, SumCombiner};
pub use self::scorer::Scorer;
pub use self::set_query::TermSetQuery;
pub use self::spatial_query::{SpatialQuery, SpatialQueryType};
pub use self::term_query::TermQuery;
pub use self::union::BufferedUnionScorer;
#[cfg(test)]

View File

@@ -524,9 +524,6 @@ impl QueryParser {
let ip_v6 = IpAddr::from_str(phrase)?.into_ipv6_addr();
Ok(Term::from_field_ip_addr(field, ip_v6))
}
FieldType::Spatial(_) => Err(QueryParserError::UnsupportedQuery(
"Spatial queries are not yet supported in text query parser".to_string(),
)),
}
}
@@ -627,10 +624,6 @@ impl QueryParser {
let term = Term::from_field_ip_addr(field, ip_v6);
Ok(vec![LogicalLiteral::Term(term)])
}
FieldType::Spatial(_) => Err(QueryParserError::UnsupportedQuery(format!(
"Spatial queries are not yet supported for field '{}'",
field_name
))),
}
}

View File

@@ -20,6 +20,6 @@ pub(crate) fn is_type_valid_for_fastfield_range_query(typ: Type) -> bool {
| Type::Date
| Type::Json
| Type::IpAddr => true,
Type::Facet | Type::Bytes | Type::Spatial => false,
Type::Facet | Type::Bytes => false,
}
}

View File

@@ -128,15 +128,12 @@ impl Weight for FastFieldRangeWeight {
BoundsRange::new(bounds.lower_bound, bounds.upper_bound),
)
}
Type::Bool
| Type::Facet
| Type::Bytes
| Type::Json
| Type::IpAddr
| Type::Spatial => Err(crate::TantivyError::InvalidArgument(format!(
"unsupported value bytes type in json term value_bytes {:?}",
term_value.typ()
))),
Type::Bool | Type::Facet | Type::Bytes | Type::Json | Type::IpAddr => {
Err(crate::TantivyError::InvalidArgument(format!(
"unsupported value bytes type in json term value_bytes {:?}",
term_value.typ()
)))
}
}
} else if field_type.is_ip_addr() {
let parse_ip_from_bytes = |term: &Term| {
@@ -438,7 +435,7 @@ pub(crate) fn maps_to_u64_fastfield(typ: Type) -> bool {
match typ {
Type::U64 | Type::I64 | Type::F64 | Type::Bool | Type::Date => true,
Type::IpAddr => false,
Type::Str | Type::Facet | Type::Bytes | Type::Json | Type::Spatial => false,
Type::Str | Type::Facet | Type::Bytes | Type::Json => false,
}
}

View File

@@ -1,186 +0,0 @@
//! HUSH
use common::BitSet;
use crate::query::explanation::does_not_match;
use crate::query::{BitSetDocSet, Explanation, Query, Scorer, Weight};
use crate::schema::Field;
use crate::spatial::bkd::{search_intersects, Segment};
use crate::spatial::point::GeoPoint;
use crate::spatial::writer::as_point_i32;
use crate::{DocId, DocSet, Score, TantivyError, TERMINATED};
#[derive(Clone, Copy, Debug)]
/// HUSH
pub enum SpatialQueryType {
/// HUSH
Intersects,
// Within,
// Contains,
}
#[derive(Clone, Copy, Debug)]
/// HUSH
pub struct SpatialQuery {
field: Field,
bounds: [(i32, i32); 2],
query_type: SpatialQueryType,
}
impl SpatialQuery {
/// HUSH
pub fn new(field: Field, bounds: [GeoPoint; 2], query_type: SpatialQueryType) -> Self {
SpatialQuery {
field,
bounds: [as_point_i32(bounds[0]), as_point_i32(bounds[1])],
query_type,
}
}
}
impl Query for SpatialQuery {
fn weight(
&self,
_enable_scoring: super::EnableScoring<'_>,
) -> crate::Result<Box<dyn super::Weight>> {
Ok(Box::new(SpatialWeight::new(
self.field,
self.bounds,
self.query_type,
)))
}
}
pub struct SpatialWeight {
field: Field,
bounds: [(i32, i32); 2],
query_type: SpatialQueryType,
}
impl SpatialWeight {
fn new(field: Field, bounds: [(i32, i32); 2], query_type: SpatialQueryType) -> Self {
SpatialWeight {
field,
bounds,
query_type,
}
}
}
impl Weight for SpatialWeight {
fn scorer(
&self,
reader: &crate::SegmentReader,
boost: crate::Score,
) -> crate::Result<Box<dyn super::Scorer>> {
let spatial_reader = reader
.spatial_fields()
.get_field(self.field)?
.ok_or_else(|| TantivyError::SchemaError(format!("No spatial data for field")))?;
let block_kd_tree = Segment::new(spatial_reader.get_bytes());
match self.query_type {
SpatialQueryType::Intersects => {
let mut include = BitSet::with_max_value(reader.max_doc());
search_intersects(
&block_kd_tree,
block_kd_tree.root_offset,
&[
self.bounds[0].1,
self.bounds[0].0,
self.bounds[1].1,
self.bounds[1].0,
],
&mut include,
)?;
Ok(Box::new(SpatialScorer::new(boost, include, None)))
}
}
}
fn explain(
&self,
reader: &crate::SegmentReader,
doc: DocId,
) -> crate::Result<super::Explanation> {
let mut scorer = self.scorer(reader, 1.0)?;
if scorer.seek(doc) != doc {
return Err(does_not_match(doc));
}
let query_type_desc = match self.query_type {
SpatialQueryType::Intersects => "SpatialQuery::Intersects",
};
let score = scorer.score();
let mut explanation = Explanation::new(query_type_desc, score);
explanation.add_context(format!(
"bounds: [({}, {}), ({}, {})]",
self.bounds[0].0, self.bounds[0].1, self.bounds[1].0, self.bounds[1].1,
));
explanation.add_context(format!("field: {:?}", self.field));
Ok(explanation)
}
}
struct SpatialScorer {
include: BitSetDocSet,
exclude: Option<BitSet>,
doc_id: DocId,
score: Score,
}
impl SpatialScorer {
pub fn new(score: Score, include: BitSet, exclude: Option<BitSet>) -> Self {
let mut scorer = SpatialScorer {
include: BitSetDocSet::from(include),
exclude,
doc_id: 0,
score,
};
scorer.prime();
scorer
}
fn prime(&mut self) {
self.doc_id = self.include.doc();
while self.exclude() {
self.doc_id = self.include.advance();
}
}
fn exclude(&self) -> bool {
if self.doc_id == TERMINATED {
return false;
}
match &self.exclude {
Some(exclude) => exclude.contains(self.doc_id),
None => false,
}
}
}
impl Scorer for SpatialScorer {
fn score(&mut self) -> Score {
self.score
}
}
impl DocSet for SpatialScorer {
fn advance(&mut self) -> DocId {
if self.doc_id == TERMINATED {
return TERMINATED;
}
self.doc_id = self.include.advance();
while self.exclude() {
self.doc_id = self.include.advance();
}
self.doc_id
}
fn size_hint(&self) -> u32 {
match &self.exclude {
Some(exclude) => self.include.size_hint() - exclude.len() as u32,
None => self.include.size_hint(),
}
}
fn doc(&self) -> DocId {
self.doc_id
}
}

View File

@@ -22,7 +22,6 @@ use super::se::BinaryObjectSerializer;
use super::{OwnedValue, Value};
use crate::schema::document::type_codes;
use crate::schema::{Facet, Field};
use crate::spatial::geometry::Geometry;
use crate::store::DocStoreVersion;
use crate::tokenizer::PreTokenizedString;
@@ -130,9 +129,6 @@ pub trait ValueDeserializer<'de> {
/// Attempts to deserialize a pre-tokenized string value from the deserializer.
fn deserialize_pre_tokenized_string(self) -> Result<PreTokenizedString, DeserializeError>;
/// HUSH
fn deserialize_geometry(self) -> Result<Geometry, DeserializeError>;
/// Attempts to deserialize the value using a given visitor.
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, DeserializeError>
where V: ValueVisitor;
@@ -170,8 +166,6 @@ pub enum ValueType {
/// A JSON object value. Deprecated.
#[deprecated(note = "We keep this for backwards compatibility, use Object instead")]
JSONObject,
/// HUSH
Geometry,
}
/// A value visitor for deserializing a document value.
@@ -252,12 +246,6 @@ pub trait ValueVisitor {
Err(DeserializeError::UnsupportedType(ValueType::PreTokStr))
}
#[inline]
/// Called when the deserializer visits a geometry value.
fn visit_geometry(&self, _val: Geometry) -> Result<Self::Value, DeserializeError> {
Err(DeserializeError::UnsupportedType(ValueType::Geometry))
}
#[inline]
/// Called when the deserializer visits an array.
fn visit_array<'de, A>(&self, _access: A) -> Result<Self::Value, DeserializeError>
@@ -392,7 +380,6 @@ where R: Read
match ext_type_code {
type_codes::TOK_STR_EXT_CODE => ValueType::PreTokStr,
type_codes::GEO_EXT_CODE => ValueType::Geometry,
_ => {
return Err(DeserializeError::from(io::Error::new(
io::ErrorKind::InvalidData,
@@ -508,11 +495,6 @@ where R: Read
.map_err(DeserializeError::from)
}
fn deserialize_geometry(self) -> Result<Geometry, DeserializeError> {
self.validate_type(ValueType::Geometry)?;
<Geometry as BinarySerializable>::deserialize(self.reader).map_err(DeserializeError::from)
}
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, DeserializeError>
where V: ValueVisitor {
match self.value_type {
@@ -557,10 +539,6 @@ where R: Read
let val = self.deserialize_pre_tokenized_string()?;
visitor.visit_pre_tokenized_string(val)
}
ValueType::Geometry => {
let val = self.deserialize_geometry()?;
visitor.visit_geometry(val)
}
ValueType::Array => {
let access =
BinaryArrayDeserializer::from_reader(self.reader, self.doc_store_version)?;

View File

@@ -13,7 +13,6 @@ use crate::schema::document::{
};
use crate::schema::field_type::ValueParsingError;
use crate::schema::{Facet, Field, NamedFieldDocument, OwnedValue, Schema};
use crate::spatial::geometry::Geometry;
use crate::tokenizer::PreTokenizedString;
#[repr(C, packed)]
@@ -255,7 +254,6 @@ impl CompactDoc {
}
ReferenceValueLeaf::IpAddr(num) => write_into(&mut self.node_data, num.to_u128()),
ReferenceValueLeaf::PreTokStr(pre_tok) => write_into(&mut self.node_data, *pre_tok),
ReferenceValueLeaf::Geometry(geometry) => write_into(&mut self.node_data, *geometry),
};
ValueAddr { type_id, val_addr }
}
@@ -466,12 +464,6 @@ impl<'a> CompactDocValue<'a> {
.map(Into::into)
.map(ReferenceValueLeaf::PreTokStr)
.map(Into::into),
ValueType::Geometry => self
.container
.read_from::<Geometry>(addr)
.map(Into::into)
.map(ReferenceValueLeaf::Geometry)
.map(Into::into),
ValueType::Object => Ok(ReferenceValue::Object(CompactDocObjectIter::new(
self.container,
addr,
@@ -550,8 +542,6 @@ pub enum ValueType {
Object = 11,
/// Pre-tokenized str type,
Array = 12,
/// HUSH
Geometry = 13,
}
impl BinarySerializable for ValueType {
@@ -597,7 +587,6 @@ impl<'a> From<&ReferenceValueLeaf<'a>> for ValueType {
ReferenceValueLeaf::PreTokStr(_) => ValueType::PreTokStr,
ReferenceValueLeaf::Facet(_) => ValueType::Facet,
ReferenceValueLeaf::Bytes(_) => ValueType::Bytes,
ReferenceValueLeaf::Geometry(_) => ValueType::Geometry,
}
}
}

View File

@@ -273,5 +273,4 @@ pub(crate) mod type_codes {
// Extended type codes
pub const TOK_STR_EXT_CODE: u8 = 0;
pub const GEO_EXT_CODE: u8 = 1;
}

View File

@@ -15,7 +15,6 @@ use crate::schema::document::{
ValueDeserializer, ValueVisitor,
};
use crate::schema::Facet;
use crate::spatial::geometry::Geometry;
use crate::tokenizer::PreTokenizedString;
use crate::DateTime;
@@ -50,8 +49,6 @@ pub enum OwnedValue {
Object(Vec<(String, Self)>),
/// IpV6 Address. Internally there is no IpV4, it needs to be converted to `Ipv6Addr`.
IpAddr(Ipv6Addr),
/// A GeoRust multi-polygon.
Geometry(Geometry),
}
impl AsRef<OwnedValue> for OwnedValue {
@@ -80,9 +77,6 @@ impl<'a> Value<'a> for &'a OwnedValue {
OwnedValue::IpAddr(val) => ReferenceValueLeaf::IpAddr(*val).into(),
OwnedValue::Array(array) => ReferenceValue::Array(array.iter()),
OwnedValue::Object(object) => ReferenceValue::Object(ObjectMapIter(object.iter())),
OwnedValue::Geometry(geometry) => {
ReferenceValueLeaf::Geometry(Box::new(geometry.clone())).into()
}
}
}
}
@@ -142,10 +136,6 @@ impl ValueDeserialize for OwnedValue {
Ok(OwnedValue::PreTokStr(val))
}
fn visit_geometry(&self, val: Geometry) -> Result<Self::Value, DeserializeError> {
Ok(OwnedValue::Geometry(val))
}
fn visit_array<'de, A>(&self, mut access: A) -> Result<Self::Value, DeserializeError>
where A: ArrayAccess<'de> {
let mut elements = Vec::with_capacity(access.size_hint());
@@ -208,7 +198,6 @@ impl serde::Serialize for OwnedValue {
}
}
OwnedValue::Array(ref array) => array.serialize(serializer),
OwnedValue::Geometry(ref geometry) => geometry.to_geojson().serialize(serializer),
}
}
}
@@ -296,7 +285,6 @@ impl<'a, V: Value<'a>> From<ReferenceValue<'a, V>> for OwnedValue {
ReferenceValueLeaf::IpAddr(val) => OwnedValue::IpAddr(val),
ReferenceValueLeaf::Bool(val) => OwnedValue::Bool(val),
ReferenceValueLeaf::PreTokStr(val) => OwnedValue::PreTokStr(*val.clone()),
ReferenceValueLeaf::Geometry(val) => OwnedValue::Geometry(*val.clone()),
},
ReferenceValue::Array(val) => {
OwnedValue::Array(val.map(|v| v.as_value().into()).collect())

View File

@@ -133,10 +133,6 @@ where W: Write
self.write_type_code(type_codes::EXT_CODE)?;
self.serialize_with_type_code(type_codes::TOK_STR_EXT_CODE, &*val)
}
ReferenceValueLeaf::Geometry(val) => {
self.write_type_code(type_codes::EXT_CODE)?;
self.serialize_with_type_code(type_codes::GEO_EXT_CODE, &*val)
}
},
ReferenceValue::Array(elements) => {
self.write_type_code(type_codes::ARRAY_CODE)?;

View File

@@ -3,7 +3,6 @@ use std::net::Ipv6Addr;
use common::DateTime;
use crate::spatial::geometry::Geometry;
use crate::tokenizer::PreTokenizedString;
/// A single field value.
@@ -109,12 +108,6 @@ pub trait Value<'a>: Send + Sync + Debug {
None
}
}
#[inline]
/// HUSH
fn as_geometry(&self) -> Option<Box<Geometry>> {
self.as_leaf().and_then(|leaf| leaf.into_geometry())
}
}
/// A enum representing a leaf value for tantivy to index.
@@ -143,8 +136,6 @@ pub enum ReferenceValueLeaf<'a> {
Bool(bool),
/// Pre-tokenized str type,
PreTokStr(Box<PreTokenizedString>),
/// HUSH
Geometry(Box<Geometry>),
}
impl From<u64> for ReferenceValueLeaf<'_> {
@@ -229,9 +220,6 @@ impl<'a, T: Value<'a> + ?Sized> From<ReferenceValueLeaf<'a>> for ReferenceValue<
ReferenceValueLeaf::PreTokStr(val) => {
ReferenceValue::Leaf(ReferenceValueLeaf::PreTokStr(val))
}
ReferenceValueLeaf::Geometry(val) => {
ReferenceValue::Leaf(ReferenceValueLeaf::Geometry(val))
}
}
}
}
@@ -343,16 +331,6 @@ impl<'a> ReferenceValueLeaf<'a> {
None
}
}
#[inline]
/// HUSH
pub fn into_geometry(self) -> Option<Box<Geometry>> {
if let Self::Geometry(val) = self {
Some(val)
} else {
None
}
}
}
/// A enum representing a value for tantivy to index.
@@ -470,10 +448,4 @@ where V: Value<'a>
pub fn is_object(&self) -> bool {
matches!(self, Self::Object(_))
}
#[inline]
/// HUSH
pub fn into_geometry(self) -> Option<Box<Geometry>> {
self.into_leaf().and_then(|leaf| leaf.into_geometry())
}
}

View File

@@ -1,7 +1,6 @@
use serde::{Deserialize, Serialize};
use super::ip_options::IpAddrOptions;
use super::spatial_options::SpatialOptions;
use crate::schema::bytes_options::BytesOptions;
use crate::schema::{
is_valid_field_name, DateOptions, FacetOptions, FieldType, JsonObjectOptions, NumericOptions,
@@ -81,11 +80,6 @@ impl FieldEntry {
Self::new(field_name, FieldType::JsonObject(json_object_options))
}
/// Creates a field entry for a spatial field
pub fn new_spatial(field_name: String, spatial_options: SpatialOptions) -> FieldEntry {
Self::new(field_name, FieldType::Spatial(spatial_options))
}
/// Returns the name of the field
pub fn name(&self) -> &str {
&self.name
@@ -135,7 +129,6 @@ impl FieldEntry {
FieldType::Bytes(ref options) => options.is_stored(),
FieldType::JsonObject(ref options) => options.is_stored(),
FieldType::IpAddr(ref options) => options.is_stored(),
FieldType::Spatial(ref options) => options.is_stored(),
}
}
}

View File

@@ -9,7 +9,6 @@ use serde_json::Value as JsonValue;
use thiserror::Error;
use super::ip_options::IpAddrOptions;
use super::spatial_options::SpatialOptions;
use super::IntoIpv6Addr;
use crate::schema::bytes_options::BytesOptions;
use crate::schema::facet_options::FacetOptions;
@@ -17,7 +16,6 @@ use crate::schema::{
DateOptions, Facet, IndexRecordOption, JsonObjectOptions, NumericOptions, OwnedValue,
TextFieldIndexing, TextOptions,
};
use crate::spatial::geometry::Geometry;
use crate::time::format_description::well_known::Rfc3339;
use crate::time::OffsetDateTime;
use crate::tokenizer::PreTokenizedString;
@@ -73,8 +71,6 @@ pub enum Type {
Json = b'j',
/// IpAddr
IpAddr = b'p',
/// Spatial
Spatial = b't',
}
impl From<ColumnType> for Type {
@@ -143,7 +139,6 @@ impl Type {
Type::Bytes => "Bytes",
Type::Json => "Json",
Type::IpAddr => "IpAddr",
Type::Spatial => "Spatial",
}
}
@@ -194,8 +189,6 @@ pub enum FieldType {
JsonObject(JsonObjectOptions),
/// IpAddr field
IpAddr(IpAddrOptions),
/// Spatial field
Spatial(SpatialOptions),
}
impl FieldType {
@@ -212,7 +205,6 @@ impl FieldType {
FieldType::Bytes(_) => Type::Bytes,
FieldType::JsonObject(_) => Type::Json,
FieldType::IpAddr(_) => Type::IpAddr,
FieldType::Spatial(_) => Type::Spatial,
}
}
@@ -249,7 +241,6 @@ impl FieldType {
FieldType::Bytes(ref bytes_options) => bytes_options.is_indexed(),
FieldType::JsonObject(ref json_object_options) => json_object_options.is_indexed(),
FieldType::IpAddr(ref ip_addr_options) => ip_addr_options.is_indexed(),
FieldType::Spatial(ref _spatial_options) => true,
}
}
@@ -287,7 +278,6 @@ impl FieldType {
FieldType::IpAddr(ref ip_addr_options) => ip_addr_options.is_fast(),
FieldType::Facet(_) => true,
FieldType::JsonObject(ref json_object_options) => json_object_options.is_fast(),
FieldType::Spatial(_) => false,
}
}
@@ -307,7 +297,6 @@ impl FieldType {
FieldType::Bytes(ref bytes_options) => bytes_options.fieldnorms(),
FieldType::JsonObject(ref _json_object_options) => false,
FieldType::IpAddr(ref ip_addr_options) => ip_addr_options.fieldnorms(),
FieldType::Spatial(_) => false,
}
}
@@ -359,8 +348,6 @@ impl FieldType {
None
}
}
FieldType::Spatial(_) => None, /* Geometry types cannot be indexed in the inverted
* index. */
}
}
@@ -462,10 +449,6 @@ impl FieldType {
Ok(OwnedValue::IpAddr(ip_addr.into_ipv6_addr()))
}
FieldType::Spatial(_) => Err(ValueParsingError::TypeError {
expected: "spatial field parsing not implemented",
json: JsonValue::String(field_text),
}),
}
}
JsonValue::Number(field_val_num) => match self {
@@ -525,10 +508,6 @@ impl FieldType {
expected: "a string with an ip addr",
json: JsonValue::Number(field_val_num),
}),
FieldType::Spatial(_) => Err(ValueParsingError::TypeError {
expected: "spatial field parsing not implemented",
json: JsonValue::Number(field_val_num),
}),
},
JsonValue::Object(json_map) => match self {
FieldType::Str(_) => {
@@ -544,14 +523,6 @@ impl FieldType {
}
}
FieldType::JsonObject(_) => Ok(OwnedValue::from(json_map)),
FieldType::Spatial(_) => Ok(OwnedValue::Geometry(
Geometry::from_geojson(&json_map).map_err(|e| {
ValueParsingError::ParseError {
error: format!("{:?}", e),
json: JsonValue::Object(json_map),
}
})?,
)),
_ => Err(ValueParsingError::TypeError {
expected: self.value_type().name(),
json: JsonValue::Object(json_map),

View File

@@ -1,6 +1,6 @@
use std::ops::BitOr;
use crate::schema::{DateOptions, NumericOptions, SpatialOptions, TextOptions};
use crate::schema::{DateOptions, NumericOptions, TextOptions};
#[derive(Clone)]
pub struct StoredFlag;
@@ -95,14 +95,6 @@ impl<T: Clone + Into<TextOptions>> BitOr<TextOptions> for SchemaFlagList<T, ()>
}
}
impl<T: Clone + Into<SpatialOptions>> BitOr<SpatialOptions> for SchemaFlagList<T, ()> {
type Output = SpatialOptions;
fn bitor(self, rhs: SpatialOptions) -> Self::Output {
self.head.into() | rhs
}
}
#[derive(Clone)]
pub struct SchemaFlagList<Head: Clone, Tail: Clone> {
pub head: Head,

View File

@@ -124,7 +124,6 @@ mod ip_options;
mod json_object_options;
mod named_field_document;
mod numeric_options;
mod spatial_options;
mod text_options;
use columnar::ColumnType;
@@ -145,7 +144,6 @@ pub use self::json_object_options::JsonObjectOptions;
pub use self::named_field_document::NamedFieldDocument;
pub use self::numeric_options::NumericOptions;
pub use self::schema::{Schema, SchemaBuilder};
pub use self::spatial_options::{SpatialOptions, SPATIAL};
pub use self::term::{Term, ValueBytes};
pub use self::text_options::{TextFieldIndexing, TextOptions, STRING, TEXT};
@@ -170,7 +168,6 @@ pub(crate) fn value_type_to_column_type(typ: Type) -> Option<ColumnType> {
Type::Bytes => Some(ColumnType::Bytes),
Type::IpAddr => Some(ColumnType::IpAddr),
Type::Json => None,
Type::Spatial => None,
}
}

View File

@@ -194,16 +194,6 @@ impl SchemaBuilder {
self.add_field(field_entry)
}
/// Adds a spatial entry to the schema in build.
pub fn add_spatial_field<T: Into<SpatialOptions>>(
&mut self,
field_name: &str,
field_options: T,
) -> Field {
let field_entry = FieldEntry::new_spatial(field_name.to_string(), field_options.into());
self.add_field(field_entry)
}
/// Adds a field entry to the schema in build.
pub fn add_field(&mut self, field_entry: FieldEntry) -> Field {
let field = Field::from_field_id(self.fields.len() as u32);
@@ -218,14 +208,9 @@ impl SchemaBuilder {
/// Finalize the creation of a `Schema`
/// This will consume your `SchemaBuilder`
pub fn build(self) -> Schema {
let contains_spatial_field = self
.fields
.iter()
.any(|field_entry| field_entry.field_type().value_type() == Type::Spatial);
Schema(Arc::new(InnerSchema {
fields: self.fields,
fields_map: self.fields_map,
contains_spatial_field,
}))
}
}
@@ -233,7 +218,6 @@ impl SchemaBuilder {
struct InnerSchema {
fields: Vec<FieldEntry>,
fields_map: HashMap<String, Field>, // transient
contains_spatial_field: bool,
}
impl PartialEq for InnerSchema {
@@ -384,11 +368,6 @@ impl Schema {
}
Some((field, json_path))
}
/// Returns true if the schema contains a spatial field.
pub(crate) fn contains_spatial_field(&self) -> bool {
self.0.contains_spatial_field
}
}
impl Serialize for Schema {
@@ -416,16 +395,16 @@ impl<'de> Deserialize<'de> for Schema {
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where A: SeqAccess<'de> {
let mut schema_builder = SchemaBuilder {
let mut schema = SchemaBuilder {
fields: Vec::with_capacity(seq.size_hint().unwrap_or(0)),
fields_map: HashMap::with_capacity(seq.size_hint().unwrap_or(0)),
};
while let Some(value) = seq.next_element()? {
schema_builder.add_field(value);
schema.add_field(value);
}
Ok(schema_builder.build())
Ok(schema.build())
}
}
@@ -1041,33 +1020,4 @@ mod tests {
Some((default, "foobar"))
);
}
#[test]
fn test_contains_spatial_field() {
// No spatial field
{
let mut schema_builder = Schema::builder();
schema_builder.add_text_field("title", TEXT);
let schema = schema_builder.build();
assert!(!schema.contains_spatial_field());
// Serialization check
let schema_json = serde_json::to_string(&schema).unwrap();
let schema_deserialized: Schema = serde_json::from_str(&schema_json).unwrap();
assert!(!schema_deserialized.contains_spatial_field());
}
// With spatial field
{
let mut schema_builder = Schema::builder();
schema_builder.add_text_field("title", TEXT);
schema_builder.add_spatial_field("location", SPATIAL);
let schema = schema_builder.build();
assert!(schema.contains_spatial_field());
// Serialization check
let schema_json = serde_json::to_string(&schema).unwrap();
let schema_deserialized: Schema = serde_json::from_str(&schema_json).unwrap();
assert!(schema_deserialized.contains_spatial_field());
}
}
}

View File

@@ -1,53 +0,0 @@
use std::ops::BitOr;
use serde::{Deserialize, Serialize};
use crate::schema::flags::StoredFlag;
/// Define how a spatial field should be handled by tantivy.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Default)]
pub struct SpatialOptions {
#[serde(default)]
stored: bool,
}
/// The field will be untokenized and indexed.
pub const SPATIAL: SpatialOptions = SpatialOptions { stored: false };
impl SpatialOptions {
/// Returns true if the geometry is to be stored.
#[inline]
pub fn is_stored(&self) -> bool {
self.stored
}
}
impl<T: Into<SpatialOptions>> BitOr<T> for SpatialOptions {
type Output = SpatialOptions;
fn bitor(self, other: T) -> SpatialOptions {
let other = other.into();
SpatialOptions {
stored: self.stored | other.stored,
}
}
}
impl From<StoredFlag> for SpatialOptions {
fn from(_: StoredFlag) -> SpatialOptions {
SpatialOptions { stored: true }
}
}
// #[cfg(test)]
// mod tests {
// use crate::schema::*;
//
// #[test]
// fn test_field_options() {
// let field_options = STORED | SPATIAL;
// assert!(field_options.is_stored());
// let mut schema_builder = Schema::builder();
// schema_builder.add_spatial_index("where", SPATIAL | STORED);
// }
// }

View File

@@ -503,9 +503,6 @@ where B: AsRef<[u8]>
Type::IpAddr => {
write_opt(f, self.as_ip_addr())?;
}
Type::Spatial => {
write!(f, "<spatial term formatting not yet implemented>")?;
}
}
Ok(())
}

View File

@@ -69,7 +69,6 @@ pub struct SegmentSpaceUsage {
positions: PerFieldSpaceUsage,
fast_fields: PerFieldSpaceUsage,
fieldnorms: PerFieldSpaceUsage,
spatial: PerFieldSpaceUsage,
store: StoreSpaceUsage,
@@ -87,7 +86,6 @@ impl SegmentSpaceUsage {
positions: PerFieldSpaceUsage,
fast_fields: PerFieldSpaceUsage,
fieldnorms: PerFieldSpaceUsage,
spatial: PerFieldSpaceUsage,
store: StoreSpaceUsage,
deletes: ByteCount,
) -> SegmentSpaceUsage {
@@ -96,7 +94,6 @@ impl SegmentSpaceUsage {
+ positions.total()
+ fast_fields.total()
+ fieldnorms.total()
+ spatial.total()
+ store.total()
+ deletes;
SegmentSpaceUsage {
@@ -106,7 +103,6 @@ impl SegmentSpaceUsage {
positions,
fast_fields,
fieldnorms,
spatial,
store,
deletes,
total,
@@ -125,7 +121,6 @@ impl SegmentSpaceUsage {
Positions => PerField(self.positions().clone()),
FastFields => PerField(self.fast_fields().clone()),
FieldNorms => PerField(self.fieldnorms().clone()),
Spatial => PerField(self.spatial().clone()),
Terms => PerField(self.termdict().clone()),
SegmentComponent::Store => ComponentSpaceUsage::Store(self.store().clone()),
SegmentComponent::TempStore => ComponentSpaceUsage::Store(self.store().clone()),
@@ -163,11 +158,6 @@ impl SegmentSpaceUsage {
&self.fieldnorms
}
/// Space usage for field norms
pub fn spatial(&self) -> &PerFieldSpaceUsage {
&self.spatial
}
/// Space usage for stored documents
pub fn store(&self) -> &StoreSpaceUsage {
&self.store

View File

@@ -1,853 +0,0 @@
//! Block kd-tree spatial indexing for triangulated polygons.
//!
//! Implements an immutable bulk-loaded spatial index using recursive median partitioning on
//! bounding box dimensions. Each leaf stores up to 512 triangles with delta-compressed coordinates
//! and doc IDs. The tree provides three query types (intersects, within, contains) that use exact
//! integer arithmetic for geometric predicates and accumulate results in bit sets for efficient
//! deduplication across leaves.
//!
//! The serialized format stores compressed leaf pages followed by the tree structure (leaf and
//! branch nodes), enabling zero-copy access through memory-mapped segments without upfront
//! decompression.
use std::io;
use std::io::Write;
use common::{BitSet, CountingWriter};
use crate::directory::WritePtr;
use crate::spatial::delta::{compress, decompress, Compressible};
use crate::spatial::triangle::Triangle;
#[derive(Clone, Copy)]
struct SpreadSurvey {
    min: i32,
    max: i32,
}

impl SpreadSurvey {
    /// Folds one observed value into the running min/max bounds.
    fn survey(&mut self, value: i32) {
        if value < self.min {
            self.min = value;
        }
        if value > self.max {
            self.max = value;
        }
    }

    /// Width of the observed range (`max - min`).
    fn spread(&self) -> i32 {
        self.max - self.min
    }
}

impl Default for SpreadSurvey {
    fn default() -> Self {
        // Bounds start inverted so the first surveyed value sets both.
        SpreadSurvey {
            min: i32::MAX,
            max: i32::MIN,
        }
    }
}
/// Accumulates the bounding box covering every surveyed triangle.
#[derive(Clone, Copy)]
struct BoundingBoxSurvey {
    // Running box as [min_y, min_x, max_y, max_x] (same word order used by
    // the query rectangles elsewhere in this file).
    bbox: [i32; 4],
}

impl BoundingBoxSurvey {
    /// Expands the running box to cover the triangle's bbox words
    /// (words[0..2] are minimums, words[2..4] are maximums).
    fn survey(&mut self, triangle: &Triangle) {
        self.bbox[0] = triangle.words[0].min(self.bbox[0]);
        self.bbox[1] = triangle.words[1].min(self.bbox[1]);
        self.bbox[2] = triangle.words[2].max(self.bbox[2]);
        self.bbox[3] = triangle.words[3].max(self.bbox[3]);
    }

    /// Returns the accumulated bounding box.
    fn bbox(&self) -> [i32; 4] {
        // [i32; 4] is Copy — the previous `.clone()` was redundant
        // (clippy::clone_on_copy).
        self.bbox
    }
}

impl Default for BoundingBoxSurvey {
    fn default() -> Self {
        // Inverted extremes so the first survey initializes all four bounds.
        BoundingBoxSurvey {
            bbox: [i32::MAX, i32::MAX, i32::MIN, i32::MIN],
        }
    }
}
/// Intermediate in-memory tree built while serializing leaf pages; it is later
/// flattened into on-disk records by `write_leaf_nodes` and `write_branch_nodes`.
enum BuildNode {
    /// Inner node; its bbox covers both children.
    Branch {
        // Bounding box as [min_y, min_x, max_y, max_x].
        bbox: [i32; 4],
        left: Box<BuildNode>,
        right: Box<BuildNode>,
    },
    /// Leaf referencing a compressed page already written to the output.
    Leaf {
        // Bounding box as [min_y, min_x, max_y, max_x].
        bbox: [i32; 4],
        // Absolute byte position of the page in the output stream.
        pos: u64,
        // Byte length of the page; u16 is enough since a page holds <= 512 triangles.
        len: u16,
    },
}
/// Adapter exposing a single coordinate word of each triangle in a slice as an
/// `i32` sequence for delta compression.
struct CompressibleTriangleI32<'a> {
    triangles: &'a [Triangle],
    // Index into `Triangle::words` selecting which stream to compress.
    dimension: usize,
}

impl<'a> CompressibleTriangleI32<'a> {
    fn new(triangles: &'a [Triangle], dimension: usize) -> Self {
        CompressibleTriangleI32 {
            triangles,
            dimension,
        }
    }
}

impl<'a> Compressible for CompressibleTriangleI32<'a> {
    type Value = i32;

    fn len(&self) -> usize {
        self.triangles.len()
    }

    // Returns the selected word of the i-th triangle.
    fn get(&self, i: usize) -> i32 {
        self.triangles[i].words[self.dimension]
    }
}
/// Adapter exposing the doc id of each triangle in a slice as a `u32` sequence
/// for delta compression.
struct CompressibleTriangleDocID<'a> {
    triangles: &'a [Triangle],
}

impl<'a> CompressibleTriangleDocID<'a> {
    fn new(triangles: &'a [Triangle]) -> Self {
        CompressibleTriangleDocID { triangles }
    }
}

impl<'a> Compressible for CompressibleTriangleDocID<'a> {
    type Value = u32;

    fn len(&self) -> usize {
        self.triangles.len()
    }

    // Returns the doc id of the i-th triangle.
    fn get(&self, i: usize) -> u32 {
        self.triangles[i].doc_id
    }
}
// Leaf pages are first the count of triangles, followed by delta encoded doc_ids, followed by
// the delta encoded words in order. We will then have the length of the page. We build a tree
// after the pages with leaf nodes and branch nodes. Leaf nodes will contain the bounding box
// of the leaf followed position and length of the page. The leaf node is a level of direction
// to store the position and length of the page in a format that is easy to read directly from
// the mapping.
// We do not compress the tree nodes. We read them directly from the mapping.
//
/// Surveys `triangles` and returns the bbox-word dimension (0..4) with the
/// largest spread — used as the sort/partition axis — together with the
/// bounding box covering every triangle.
fn survey_split_dimension(triangles: &[Triangle]) -> (usize, [i32; 4]) {
    let mut spreads = [SpreadSurvey::default(); 4];
    let mut bounding_box = BoundingBoxSurvey::default();
    for triangle in triangles {
        for i in 0..4 {
            spreads[i].survey(triangle.words[i]);
        }
        bounding_box.survey(triangle);
    }
    // Pick the dimension with the maximum spread; ties keep the lower index.
    let mut dimension = 0;
    let mut max_spread = spreads[0].spread();
    for (i, spread_survey) in spreads.iter().enumerate().skip(1) {
        let current_spread = spread_survey.spread();
        if current_spread > max_spread {
            dimension = i;
            max_spread = current_spread;
        }
    }
    (dimension, bounding_box.bbox())
}

/// Recursively partitions `triangles` and writes compressed leaf pages,
/// returning the in-memory build tree describing what was written.
///
/// Previously both arms duplicated the spread/bbox survey; it now lives in
/// `survey_split_dimension`. The partition logic is unchanged.
fn write_leaf_pages(
    triangles: &mut [Triangle],
    write: &mut CountingWriter<WritePtr>,
) -> io::Result<BuildNode> {
    // If less than 512 triangles we are at a leaf, otherwise we still in the inner nodes.
    if triangles.len() <= 512 {
        let pos = write.written_bytes();
        let (dimension, bbox) = survey_split_dimension(triangles);
        // Page layout: u16 count, compressed doc ids, then the seven
        // compressed coordinate word streams in order.
        write.write_all(&(triangles.len() as u16).to_le_bytes())?;
        // Sort on the widest-spread dimension before delta compression.
        triangles.sort_by_key(|t| t.words[dimension]);
        compress(&CompressibleTriangleDocID::new(triangles), write)?;
        for dim in 0..7 {
            compress(&CompressibleTriangleI32::new(triangles, dim), write)?;
        }
        let len = write.written_bytes() - pos;
        Ok(BuildNode::Leaf {
            bbox,
            pos,
            len: len as u16,
        })
    } else {
        let (dimension, bbox) = survey_split_dimension(triangles);
        // Partition the triangles around the median of the chosen dimension.
        let mid = triangles.len() / 2;
        triangles.select_nth_unstable_by_key(mid, |t| t.words[dimension]);
        let partition = triangles[mid].words[dimension];
        let mut split_point = mid + 1;
        while split_point < triangles.len() && triangles[split_point].words[dimension] == partition
        {
            split_point += 1;
        }
        // If we reached the end of triangles then all of the triangles share the partition value
        // for the dimension. We handle this degeneracy by splitting at the midpoint so that we
        // won't have a leaf with zero triangles.
        if split_point == triangles.len() {
            split_point = mid; // Force split at midpoint index
        } else {
            // Our partition does not sort the triangles, it only partitions. We have to scan our
            // right partition to find all the partition values and move them to the left
            // partition.
            let mut reverse = triangles.len() - 1;
            loop {
                // Scan backwards looking for the partition value.
                while triangles[reverse].words[dimension] != partition {
                    reverse -= 1;
                }
                // If we have reached the split point then we are done.
                if reverse <= split_point {
                    break;
                }
                // Swap the partition value with our current split point.
                triangles.swap(split_point, reverse);
                // Move the split point up one.
                split_point += 1;
                // We know that what was at the split point was not the partition value.
                reverse -= 1;
            }
        }
        // Split into left and right partitions and create child nodes.
        let (left, right) = triangles.split_at_mut(split_point);
        let left_node = write_leaf_pages(left, write)?;
        let right_node = write_leaf_pages(right, write)?;
        // Return an inner node.
        Ok(BuildNode::Branch {
            bbox,
            left: Box::new(left_node),
            right: Box::new(right_node),
        })
    }
}
/// Writes every leaf of the build tree as a fixed 32-byte record
/// (bbox: 4 x i32, pos: u64, len: u16, 6 bytes padding).
///
/// Leaves are emitted right-subtree-first: `write_branch_nodes` assigns leaf
/// offsets -1, -2, ... in left-to-right traversal order, relative to the start
/// of the branch area. Emitting in reverse puts the leftmost leaf immediately
/// before the branch records, i.e. at offset -1 * 32 bytes.
fn write_leaf_nodes(node: &BuildNode, write: &mut CountingWriter<WritePtr>) -> io::Result<()> {
    match node {
        BuildNode::Branch {
            bbox: _,
            left,
            right,
        } => {
            // Right before left — see the ordering note above.
            write_leaf_nodes(right, write)?;
            write_leaf_nodes(left, write)?;
        }
        BuildNode::Leaf { bbox, pos, len } => {
            for &dimension in bbox.iter() {
                write.write_all(&dimension.to_le_bytes())?;
            }
            write.write_all(&pos.to_le_bytes())?;
            write.write_all(&len.to_le_bytes())?;
            // Pad the record to 32 bytes.
            write.write_all(&[0u8; 6])?;
        }
    }
    Ok(())
}
/// Writes branch records post-order and returns the byte offset of `node`
/// relative to the start of the branch area.
///
/// Leaves consume negative slots (-1, -2, ... scaled by the 32-byte `LeafNode`
/// size), matching the reversed emission order of `write_leaf_nodes`; branches
/// get non-negative slots (0, 1, ... scaled by the `BranchNode` size) in the
/// order their records are written.
fn write_branch_nodes(
    node: &BuildNode,
    branch_offset: &mut i32,
    leaf_offset: &mut i32,
    write: &mut CountingWriter<WritePtr>,
) -> io::Result<i32> {
    match node {
        BuildNode::Leaf { .. } => {
            // Leaves write nothing here; their records were emitted by
            // `write_leaf_nodes`. Just hand out the next negative slot.
            let pos = *leaf_offset;
            *leaf_offset -= 1;
            Ok(pos * size_of::<LeafNode>() as i32)
        }
        BuildNode::Branch { bbox, left, right } => {
            // Children first so their offsets are known before this record.
            let left = write_branch_nodes(left, branch_offset, leaf_offset, write)?;
            let right = write_branch_nodes(right, branch_offset, leaf_offset, write)?;
            for &val in bbox {
                write.write_all(&val.to_le_bytes())?;
            }
            write.write_all(&left.to_le_bytes())?;
            write.write_all(&right.to_le_bytes())?;
            // Pad the record to 32 bytes.
            write.write_all(&[0u8; 8])?;
            let pos = *branch_offset;
            *branch_offset += 1;
            Ok(pos * size_of::<BranchNode>() as i32)
        }
    }
}
const VERSION: u16 = 1u16;
/// Builds and serializes a block kd-tree for spatial indexing of triangles.
///
/// Takes a collection of triangles and constructs a complete block kd-tree, writing both the
/// compressed leaf pages and tree structure to the output. The tree uses recursive median
/// partitioning on the dimension with maximum spread, storing up to 512 triangles per leaf.
///
/// The output format consists of:
/// - Version header (u16)
/// - Compressed leaf pages (delta-encoded doc_ids and triangle coordinates)
/// - 32-byte aligned tree structure (leaf nodes, then branch nodes)
/// - Footer with triangle count, root offset, and branch position
///
/// The `triangles` slice will be reordered during tree construction as partitioning sorts by the
/// selected dimension at each level.
pub fn write_block_kd_tree(
    triangles: &mut [Triangle],
    write: &mut CountingWriter<WritePtr>,
) -> io::Result<()> {
    write.write_all(&VERSION.to_le_bytes())?;
    let tree = write_leaf_pages(triangles, write)?;
    // Pad so the tree nodes start on a 32-byte boundary (records are 32 bytes).
    // Padding is always < 32, so a fixed zero buffer suffices (no allocation).
    let current = write.written_bytes();
    let padding = (current.next_multiple_of(32) - current) as usize;
    write.write_all(&[0u8; 32][..padding])?;
    write_leaf_nodes(&tree, write)?;
    let branch_position = write.written_bytes();
    let mut branch_offset: i32 = 0;
    let mut leaf_offset: i32 = -1;
    let root = write_branch_nodes(&tree, &mut branch_offset, &mut leaf_offset, write)?;
    // Footer: 12 zero bytes, triangle count, root offset, branch position.
    write.write_all(&[0u8; 12])?;
    // Write the count as u64 so the on-disk format does not depend on the
    // platform's `usize` width (previously `usize::to_le_bytes`, which would
    // emit 4 bytes on 32-bit targets and shift the footer fields that
    // `Segment::new` reads at fixed positions from the end).
    write.write_all(&(triangles.len() as u64).to_le_bytes())?;
    write.write_all(&root.to_le_bytes())?;
    write.write_all(&branch_position.to_le_bytes())?;
    Ok(())
}
/// Decodes one compressed leaf page back into triangles.
///
/// Page layout (mirrors `write_leaf_pages`): u16 triangle count,
/// delta-compressed doc ids, then the seven delta-compressed coordinate word
/// streams in order.
fn decompress_leaf(mut data: &[u8]) -> io::Result<Vec<Triangle>> {
    use common::BinarySerializable;
    // `deserialize` advances `data` past the count; `offset` then tracks the
    // position within the remaining compressed streams.
    let triangle_count: usize = u16::deserialize(&mut data)? as usize;
    let mut offset: usize = 0;
    let mut triangles: Vec<Triangle> = Vec::with_capacity(triangle_count);
    // First stream: doc ids — create one skeleton triangle per doc id.
    offset += decompress::<u32, _>(&data[offset..], triangle_count, |_, doc_id| {
        triangles.push(Triangle::skeleton(doc_id))
    })?;
    // Remaining streams fill in the seven coordinate words of each triangle.
    for i in 0..7 {
        offset += decompress::<i32, _>(&data[offset..], triangle_count, |j, word| {
            triangles[j].words[i] = word
        })?;
    }
    Ok(triangles)
}
/// On-disk branch record: fixed 32 bytes, decoded field-by-field from the
/// mapped segment (see `Segment::branch_node`).
#[repr(C)]
struct BranchNode {
    // Bounding box as [min_y, min_x, max_y, max_x].
    bbox: [i32; 4],
    // Left child offset in bytes relative to `branch_position`; negative
    // offsets address leaf records, non-negative offsets address branches.
    left: i32,
    // Right child offset; same encoding as `left`.
    right: i32,
    // Padding up to the 32-byte record size.
    pad: [u8; 8],
}
/// On-disk leaf record: fixed 32 bytes, decoded field-by-field from the
/// mapped segment (see `Segment::leaf_node`).
#[repr(C)]
struct LeafNode {
    // Bounding box as [min_y, min_x, max_y, max_x].
    bbox: [i32; 4],
    // Absolute byte position of the compressed leaf page in the segment.
    pos: u64,
    // Byte length of the compressed leaf page.
    len: u16,
    // Padding up to the 32-byte record size.
    pad: [u8; 6],
}
/// A read-only view into a serialized block kd-tree segment.
///
/// Provides access to the tree structure and compressed leaf data through memory-mapped or
/// buffered byte slices. The segment contains compressed leaf pages followed by the tree structure
/// (leaf nodes and branch nodes), with a footer containing metadata for locating the root and
/// interpreting offsets.
pub struct Segment<'a> {
    // The entire serialized segment (pages + tree + footer).
    data: &'a [u8],
    // Byte position where the branch records begin; all node offsets
    // (positive for branches, negative for leaves) are relative to it.
    branch_position: u64,
    /// Offset to the root of the tree, used as the starting point for traversal.
    pub root_offset: i32,
}

impl<'a> Segment<'a> {
    /// Creates a new segment from serialized block kd-tree data.
    ///
    /// Reads the footer metadata from the last 12 bytes to locate the tree structure and root
    /// node: the final 8 bytes hold `branch_position`, preceded by the 4-byte
    /// root offset (see the footer written by `write_block_kd_tree`).
    pub fn new(data: &'a [u8]) -> Self {
        Segment {
            data,
            branch_position: u64::from_le_bytes(data[data.len() - 8..].try_into().unwrap()),
            root_offset: i32::from_le_bytes(
                data[data.len() - 12..data.len() - 8].try_into().unwrap(),
            ),
        }
    }

    /// Reads only the 16-byte bounding box, which both node layouts store
    /// first — callers can prune without decoding the rest of the record.
    #[inline(always)]
    fn bounding_box(&self, offset: i32) -> [i32; 4] {
        let byte_offset = (self.branch_position as i64 + offset as i64) as usize;
        let bytes = &self.data[byte_offset..byte_offset + 16];
        [
            i32::from_le_bytes(bytes[0..4].try_into().unwrap()),
            i32::from_le_bytes(bytes[4..8].try_into().unwrap()),
            i32::from_le_bytes(bytes[8..12].try_into().unwrap()),
            i32::from_le_bytes(bytes[12..16].try_into().unwrap()),
        ]
    }

    /// Decodes the 32-byte branch record at `offset` (non-negative offsets
    /// address branch records).
    #[inline(always)]
    fn branch_node(&self, offset: i32) -> BranchNode {
        let byte_offset = (self.branch_position as i64 + offset as i64) as usize;
        let bytes = &self.data[byte_offset..byte_offset + 32];
        BranchNode {
            bbox: [
                i32::from_le_bytes(bytes[0..4].try_into().unwrap()),
                i32::from_le_bytes(bytes[4..8].try_into().unwrap()),
                i32::from_le_bytes(bytes[8..12].try_into().unwrap()),
                i32::from_le_bytes(bytes[12..16].try_into().unwrap()),
            ],
            left: i32::from_le_bytes(bytes[16..20].try_into().unwrap()),
            right: i32::from_le_bytes(bytes[20..24].try_into().unwrap()),
            pad: [0u8; 8],
        }
    }

    /// Decodes the 32-byte leaf record at `offset` (negative offsets address
    /// leaf records, which sit just below `branch_position`).
    #[inline(always)]
    fn leaf_node(&self, offset: i32) -> LeafNode {
        let byte_offset = (self.branch_position as i64 + offset as i64) as usize;
        let bytes = &self.data[byte_offset..byte_offset + 32];
        LeafNode {
            bbox: [
                i32::from_le_bytes(bytes[0..4].try_into().unwrap()),
                i32::from_le_bytes(bytes[4..8].try_into().unwrap()),
                i32::from_le_bytes(bytes[8..12].try_into().unwrap()),
                i32::from_le_bytes(bytes[12..16].try_into().unwrap()),
            ],
            pos: u64::from_le_bytes(bytes[16..24].try_into().unwrap()),
            len: u16::from_le_bytes(bytes[24..26].try_into().unwrap()),
            pad: [0u8; 6],
        }
    }

    /// Returns the compressed page bytes referenced by a leaf record.
    fn leaf_page(&self, leaf_node: &LeafNode) -> &[u8] {
        &self.data[(leaf_node.pos as usize)..(leaf_node.pos as usize + leaf_node.len as usize)]
    }
}
/// Inserts the doc ids of every triangle in the subtree at `offset` into `result`.
///
/// Used when a node's bounding box lies entirely inside the query: only the
/// doc-id stream of each leaf page needs decoding, not the coordinates.
fn collect_all_docs(segment: &Segment, offset: i32, result: &mut BitSet) -> io::Result<()> {
    if offset < 0 {
        // Leaf: the page starts with a u16 count, then the compressed doc ids.
        let leaf_node = segment.leaf_node(offset);
        let data = segment.leaf_page(&leaf_node);
        let count = u16::from_le_bytes([data[0], data[1]]) as usize;
        decompress::<u32, _>(&data[2..], count, |_, doc_id| result.insert(doc_id))?;
    } else {
        let branch_node = segment.branch_node(offset);
        collect_all_docs(segment, branch_node.left, result)?;
        collect_all_docs(segment, branch_node.right, result)?;
    }
    Ok(())
}
/// True when `bbox` lies entirely inside `query`
/// (both boxes are [min_y, min_x, max_y, max_x]; edges touching counts as inside).
fn bbox_within(bbox: &[i32; 4], query: &[i32; 4]) -> bool {
    let mins_inside = bbox[0] >= query[0] && bbox[1] >= query[1];
    let maxes_inside = bbox[2] <= query[2] && bbox[3] <= query[3];
    mins_inside && maxes_inside
}
/// True when `bbox` and `query` overlap (touching edges count as overlap).
/// Both boxes are [min_y, min_x, max_y, max_x].
fn bbox_intersects(bbox: &[i32; 4], query: &[i32; 4]) -> bool {
    // De Morgan form of "not disjoint on any axis".
    bbox[2] >= query[0] && bbox[0] <= query[2] && bbox[3] >= query[1] && bbox[1] <= query[3]
}
/// Finds documents with triangles that intersect the query bounding box.
///
/// Traverses the tree starting at `offset` (typically `segment.root_offset`), pruning subtrees
/// whose bounding boxes don't intersect the query. When a node's bbox is entirely within the
/// query, all its documents are bulk-collected. Otherwise, individual triangles are tested using
/// exact geometric predicates.
///
/// The query is `[min_y, min_x, max_y, max_x]` in integer coordinates. Documents are inserted into
/// the `result` BitSet, which automatically deduplicates when the same document appears in
/// multiple leaves.
pub fn search_intersects(
    segment: &Segment,
    offset: i32,
    query: &[i32; 4],
    result: &mut BitSet,
) -> io::Result<()> {
    let bbox = segment.bounding_box(offset);
    // Subtree disjoint from the query: prune it entirely.
    if !bbox_intersects(&bbox, query) {
        return Ok(());
    }
    // Subtree fully inside the query: every triangle matches; bulk-collect.
    if bbox_within(&bbox, query) {
        return collect_all_docs(segment, offset, result);
    }
    if offset < 0 {
        // Leaf whose bbox crosses the query: test each triangle exactly.
        let leaf_node = segment.leaf_node(offset);
        for triangle in &decompress_leaf(segment.leaf_page(&leaf_node))? {
            if triangle_intersects(triangle, query) {
                // BitSet insertion deduplicates docs seen in several leaves.
                result.insert(triangle.doc_id);
            }
        }
        Ok(())
    } else {
        // Branch crossing the query: both children must be checked.
        let branch_node = segment.branch_node(offset);
        search_intersects(segment, branch_node.left, query, result)?;
        search_intersects(segment, branch_node.right, query, result)
    }
}
/// Exact segment-segment intersection test for (x1, y1)-(x2, y2) against
/// (x3, y3)-(x4, y4). Parallel segments — including collinear overlapping
/// ones — are reported as non-intersecting.
#[expect(clippy::too_many_arguments)]
fn line_intersects_line(
    x1: i32,
    y1: i32,
    x2: i32,
    y2: i32,
    x3: i32,
    y3: i32,
    x4: i32,
    y4: i32,
) -> bool {
    // Widen to i128 so the cross products below cannot overflow.
    let (x1, y1, x2, y2) = (x1 as i128, y1 as i128, x2 as i128, y2 as i128);
    let (x3, y3, x4, y4) = (x3 as i128, y3 as i128, x4 as i128, y4 as i128);

    // Denominator of the parametric solution; zero means parallel.
    let denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4);
    if denom == 0 {
        return false;
    }

    // Numerators of the two segment parameters t and u.
    let t_num = (x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4);
    let u_num = -((x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3));

    // Both parameters must be in [0, 1], i.e. both numerators must lie
    // between 0 and denom (bounds swapped when denom is negative).
    let (lo, hi) = if denom > 0 { (0, denom) } else { (denom, 0) };
    (lo..=hi).contains(&t_num) && (lo..=hi).contains(&u_num)
}
fn edge_intersects_bbox(x1: i32, y1: i32, x2: i32, y2: i32, bbox: &[i32; 4]) -> bool {
// Test against all 4 rectangle edges, bottom, right, top, left.
line_intersects_line(x1, y1, x2, y2, bbox[1], bbox[0], bbox[3], bbox[0])
|| line_intersects_line(x1, y1, x2, y2, bbox[3], bbox[0], bbox[3], bbox[2])
|| line_intersects_line(x1, y1, x2, y2, bbox[3], bbox[2], bbox[1], bbox[2])
|| line_intersects_line(x1, y1, x2, y2, bbox[1], bbox[2], bbox[1], bbox[0])
}
/// True when exactly one endpoint of the segment (x1, y1)-(x2, y2) lies inside
/// the rectangle `bbox` ([min_y, min_x, max_y, max_x]) — i.e. the segment
/// crosses the rectangle boundary.
fn edge_crosses_bbox(x1: i32, y1: i32, x2: i32, y2: i32, bbox: &[i32; 4]) -> bool {
    let inside = |x: i32, y: i32| y >= bbox[0] && x >= bbox[1] && y <= bbox[2] && x <= bbox[3];
    inside(x1, y1) != inside(x2, y2)
}
/// Returns true when `triangle` lies entirely inside the query rectangle
/// ([min_y, min_x, max_y, max_x]).
///
/// Only boundary edges (the flags from `Triangle::decode`) are tested for
/// crossing; internal tessellation edges are ignored.
///
/// NOTE(review): `edge_crosses_bbox` only detects edges with exactly one
/// endpoint inside the query. A boundary edge passing fully through the query
/// (both endpoints outside) is not detected here — confirm whether such
/// triangles should be classified as not-within.
fn triangle_within(triangle: &Triangle, query: &[i32; 4]) -> bool {
    // words[0..4] hold the triangle's own bounding box.
    let tri_bbox = &triangle.words[0..4];

    // Triangle bbox entirely within query → WITHIN
    if tri_bbox[0] >= query[0]
        && tri_bbox[1] >= query[1]
        && tri_bbox[2] <= query[2]
        && tri_bbox[3] <= query[3]
    {
        return true;
    }

    // Triangle bbox entirely outside → NOT WITHIN
    if tri_bbox[2] < query[0]
        || tri_bbox[3] < query[1]
        || tri_bbox[0] > query[2]
        || tri_bbox[1] > query[3]
    {
        return false;
    }

    // Decode vertices.
    let ([ay, ax, by, bx, cy, cx], [ab, bc, ca]) = triangle.decode();

    // Check each edge - if a boundary edge crosses the query bbox, NOT WITHIN
    if ab && edge_crosses_bbox(ax, ay, bx, by, query) {
        return false;
    }
    if bc && edge_crosses_bbox(bx, by, cx, cy, query) {
        return false;
    }
    if ca && edge_crosses_bbox(cx, cy, ax, ay, query) {
        return false;
    }

    // No boundary edges cross out
    true
}
/// Exact point-in-triangle test for (px, py) against triangle (a, b, c),
/// inclusive of edges and vertices. Degenerate (zero-area) triangles contain
/// nothing.
#[expect(clippy::too_many_arguments)]
fn point_in_triangle(
    px: i32,
    py: i32,
    ax: i32,
    ay: i32,
    bx: i32,
    by: i32,
    cx: i32,
    cy: i32,
) -> bool {
    // Barycentric technique in i128 to avoid overflow: express P - A in the
    // basis (C - A, B - A) and require both scaled coordinates to be
    // non-negative with a sum of at most `denom`.
    let (acx, acy) = ((cx - ax) as i128, (cy - ay) as i128);
    let (abx, aby) = ((bx - ax) as i128, (by - ay) as i128);
    let (apx, apy) = ((px - ax) as i128, (py - ay) as i128);

    let dot_cc = acx * acx + acy * acy;
    let dot_cb = acx * abx + acy * aby;
    let dot_cp = acx * apx + acy * apy;
    let dot_bb = abx * abx + aby * aby;
    let dot_bp = abx * apx + aby * apy;

    let denom = dot_cc * dot_bb - dot_cb * dot_cb;
    if denom == 0 {
        // Collinear vertices: treat as empty.
        return false;
    }
    let u = dot_bb * dot_cp - dot_cb * dot_bp;
    let v = dot_cc * dot_bp - dot_cb * dot_cp;
    u >= 0 && v >= 0 && u + v <= denom
}
/// Exact triangle/rectangle intersection test against the query rectangle
/// ([min_y, min_x, max_y, max_x]).
///
/// The three sufficient conditions checked, in order of cost: a triangle
/// vertex inside the rectangle, a rectangle corner inside the triangle, or a
/// triangle edge intersecting a rectangle edge. If none holds the shapes are
/// disjoint.
fn triangle_intersects(triangle: &Triangle, query: &[i32; 4]) -> bool {
    // words[0..4] hold the triangle's own bounding box.
    let tri_bbox = &triangle.words[0..4];

    // Quick reject: bboxes don't overlap
    if tri_bbox[2] < query[0]
        || tri_bbox[3] < query[1]
        || tri_bbox[0] > query[2]
        || tri_bbox[1] > query[3]
    {
        return false;
    }

    let ([ay, ax, by, bx, cy, cx], _) = triangle.decode();

    // Any triangle vertex inside rectangle?
    if (ax >= query[1] && ax <= query[3] && ay >= query[0] && ay <= query[2])
        || (bx >= query[1] && bx <= query[3] && by >= query[0] && by <= query[2])
        || (cx >= query[1] && cx <= query[3] && cy >= query[0] && cy <= query[2])
    {
        return true;
    }

    // Any rectangle corner inside triangle?
    let corners = [
        (query[1], query[0]), // min_x, min_y
        (query[3], query[0]), // max_x, min_y
        (query[3], query[2]), // max_x, max_y
        (query[1], query[2]), // min_x, max_y
    ];
    for (x, y) in corners {
        if point_in_triangle(x, y, ax, ay, bx, by, cx, cy) {
            return true;
        }
    }

    // Any triangle edge intersect rectangle edges?
    edge_intersects_bbox(ax, ay, bx, by, query)
        || edge_intersects_bbox(bx, by, cx, cy, query)
        || edge_intersects_bbox(cx, cy, ax, ay, query)
}
/// Finds documents where all triangles are within the query bounding box.
///
/// Traverses the tree starting at `offset` (typically `segment.root_offset`), testing each
/// triangle to determine if it lies entirely within the query bounds. Uses two `BitSet` instances
/// to track state: `result` accumulates candidate documents, while `excluded` marks documents that
/// have at least one triangle extending outside the query.
///
/// The query is `[min_y, min_x, max_y, max_x]` in integer coordinates. The final result is
/// documents in `result` that are NOT in `excluded` - the caller must compute this difference.
pub fn search_within(
    segment: &Segment,
    offset: i32,
    query: &[i32; 4], // [min_y, min_x, max_y, max_x]
    result: &mut BitSet,
    excluded: &mut BitSet,
) -> io::Result<()> {
    let bbox = segment.bounding_box(offset);
    // Prune subtrees that cannot touch the query at all.
    if !bbox_intersects(&bbox, query) {
        return Ok(());
    }
    if offset >= 0 {
        // Branch: recurse into both children.
        let branch_node = segment.branch_node(offset);
        search_within(segment, branch_node.left, query, result, excluded)?;
        return search_within(segment, branch_node.right, query, result, excluded);
    }
    // Leaf: classify every intersecting triangle as fully-within or not.
    let leaf_node = segment.leaf_node(offset);
    for triangle in &decompress_leaf(segment.leaf_page(&leaf_node))? {
        if !triangle_intersects(triangle, query) {
            continue;
        }
        if excluded.contains(triangle.doc_id) {
            // A previous triangle already disqualified this doc.
            continue;
        }
        if triangle_within(triangle, query) {
            result.insert(triangle.doc_id);
        } else {
            // One triangle sticking out disqualifies the whole document.
            excluded.insert(triangle.doc_id);
        }
    }
    Ok(())
}
/// Three-state outcome of testing whether a triangle is compatible with its
/// polygon containing the query rectangle.
// NOTE(review): SCREAMING_CASE variants are non-idiomatic Rust (CamelCase is
// conventional); left as-is since renaming touches the call sites.
enum ContainsRelation {
    /// Query might be contained by the polygon this triangle belongs to.
    CANDIDATE,
    /// Query definitely not contained (a boundary edge crosses it).
    NOTWITHIN,
    /// Triangle doesn't overlap the query at all.
    DISJOINT,
}
/// Classifies `triangle` for the contains query ([min_y, min_x, max_y, max_x]):
/// `DISJOINT` when it cannot affect the answer, `NOTWITHIN` when one of its
/// polygon-boundary edges crosses the query boundary (so the polygon cannot
/// contain the query), otherwise `CANDIDATE` when it overlaps the query.
fn triangle_contains_relation(triangle: &Triangle, query: &[i32; 4]) -> ContainsRelation {
    // words[0..4] hold the triangle's own bounding box.
    let tri_bbox = &triangle.words[0..4];

    // Quick reject: query bbox and triangle bbox do not overlap.
    if query[2] < tri_bbox[0]
        || query[3] < tri_bbox[1]
        || query[0] > tri_bbox[2]
        || query[1] > tri_bbox[3]
    {
        return ContainsRelation::DISJOINT;
    }

    // Decode vertices and the per-edge "is polygon boundary" flags.
    let ([ay, ax, by, bx, cy, cx], [ab, bc, ca]) = triangle.decode();

    let corners = [
        (query[1], query[0]),
        (query[3], query[0]),
        (query[3], query[2]),
        (query[1], query[2]),
    ];

    // Does the triangle cover at least one query corner?
    let mut any_corner_inside = false;
    for &(qx, qy) in &corners {
        if point_in_triangle(qx, qy, ax, ay, bx, by, cx, cy) {
            any_corner_inside = true;
            break;
        }
    }

    let ab_intersects = edge_intersects_bbox(ax, ay, bx, by, query);
    let bc_intersects = edge_intersects_bbox(bx, by, cx, cy, query);
    let ca_intersects = edge_intersects_bbox(cx, cy, ax, ay, query);

    // A boundary edge crossing in or out of the query disqualifies the doc;
    // internal tessellation edges (flag false) are ignored.
    if (ab && edge_crosses_bbox(ax, ay, bx, by, query))
        || (bc && edge_crosses_bbox(bx, by, cx, cy, query))
        || (ca && edge_crosses_bbox(cx, cy, ax, ay, query))
    {
        return ContainsRelation::NOTWITHIN;
    }

    // Overlap without a disqualifying boundary crossing: still a candidate.
    if any_corner_inside || ab_intersects || bc_intersects || ca_intersects {
        return ContainsRelation::CANDIDATE;
    }

    ContainsRelation::DISJOINT
}
/// Finds documents whose polygons contain the query bounding box.
///
/// Traverses the tree starting at `offset` (typically `segment.root_offset`), testing each
/// triangle using three-state logic: `CANDIDATE` (query might be contained), `NOTWITHIN` (boundary
/// edge crosses query), or `DISJOINT` (no overlap). Only boundary edges are tested for crossing -
/// internal tessellation edges are ignored.
///
/// The query is `[min_y, min_x, max_y, max_x]` in integer coordinates. Uses two `BitSet`
/// instances: `result` accumulates candidates, `excluded` marks documents with disqualifying
/// boundary crossings. The final result is documents in `result` that are NOT in `excluded`.
pub fn search_contains(
    segment: &Segment,
    offset: i32,
    query: &[i32; 4],
    result: &mut BitSet,
    excluded: &mut BitSet,
) -> io::Result<()> {
    let bbox = segment.bounding_box(offset);
    if !bbox_intersects(&bbox, query) {
        // This subtree cannot touch the query box: prune it.
        return Ok(());
    }
    if offset >= 0 {
        // Branch node: descend into both children.
        let branch_node = segment.branch_node(offset);
        search_contains(segment, branch_node.left, query, result, excluded)?;
        search_contains(segment, branch_node.right, query, result, excluded)?;
        return Ok(());
    }
    // Leaf node: classify every triangle that intersects the query box.
    let leaf_node = segment.leaf_node(offset);
    let triangles = decompress_leaf(segment.leaf_page(&leaf_node))?;
    for triangle in &triangles {
        if !triangle_intersects(triangle, query) {
            continue;
        }
        let doc_id = triangle.doc_id;
        if excluded.contains(doc_id) {
            // Already disqualified by an earlier boundary crossing.
            continue;
        }
        match triangle_contains_relation(triangle, query) {
            ContainsRelation::CANDIDATE => result.insert(doc_id),
            ContainsRelation::NOTWITHIN => excluded.insert(doc_id),
            ContainsRelation::DISJOINT => {}
        }
    }
    Ok(())
}
/// Depth-first iterator over the decompressed leaf pages of a block kd-tree
/// segment, yielding each page's triangles in left-to-right tree order.
pub struct LeafPageIterator<'a> {
    segment: &'a Segment<'a>,
    // Node offsets still to visit; negative offsets denote leaf nodes.
    descent_stack: Vec<i32>,
}
impl<'a> LeafPageIterator<'a> {
/// HUSH
pub fn new(segment: &'a Segment<'a>) -> Self {
Self {
segment,
descent_stack: vec![segment.root_offset],
}
}
}
impl<'a> Iterator for LeafPageIterator<'a> {
    type Item = io::Result<Vec<Triangle>>;

    /// Pops nodes off the descent stack until a leaf is found, then returns
    /// its decompressed triangles (or the decompression error).
    ///
    /// Descends iteratively rather than recursively: Rust does not guarantee
    /// tail-call elimination, so the original self-recursive `next()` could
    /// grow the call stack by one frame per consecutive branch node.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let offset = self.descent_stack.pop()?;
            if offset < 0 {
                let leaf_node = self.segment.leaf_node(offset);
                let leaf_page = self.segment.leaf_page(&leaf_node);
                return Some(decompress_leaf(&leaf_page));
            }
            let branch_node = self.segment.branch_node(offset);
            // Push right first so the left subtree is visited first (LIFO).
            self.descent_stack.push(branch_node.right);
            self.descent_stack.push(branch_node.left);
        }
    }
}

View File

@@ -1,300 +0,0 @@
//! Delta compression for block kd-tree leaves.
//!
//! Delta compression with dimension-major bit-packing for block kd-tree leaves. Each leaf contains
//! ≤512 triangles sorted by the split dimension (the dimension with maximum spread chosen during
//! tree construction). We store all 512 values for dimension 0, then all for dimension 1, etc.,
//! enabling tight bit-packing per dimension and better cache locality during decode.
//!
//! The split dimension is already optimal for compression. Since triangles in a leaf are spatially
//! clustered, sorting by the max-spread dimension naturally orders them by proximity in all
//! dimensions. Testing multiple sort orders would be wasted effort.
//!
//! Our encoding uses ~214 units/meter for latitude, ~107 units/meter for longitude (millimeter
//! precision). A quarter-acre lot (32m × 32m) spans ~6,850 units across 512 sorted triangles = avg
//! delta ~13 units = 4 bits. A baseball field (100m × 100m) is ~42 unit deltas = 6 bits. Even
//! Russia-sized polygons (1000 km) average ~418,000 unit deltas = 19 bits. Time will tell if these
//! numbers are anything to go by in practice.
//!
//! Our format for use with leaf-page triangles: First a count of triangles in the page, then the
//! delta encoded doc_ids followed by delta encoding of each series of the triangle dimensions,
//! followed by delta encoding of the flags. Creates eight parallel arrays from which triangles can
//! be reconstructed.
//!
//! Note: Tantivy also has delta encoding in `sstable/src/delta.rs`, but that's for string
//! dictionary compression (prefix sharing + vint deltas). This module uses bit-packing with zigzag
//! encoding, which is optimal for our signed i32 spatial coordinates with small deltas. It uses
//! the same basic algorithm to compress u32 doc_ids.
use std::io::{self, Write};
// Maps a signed value to an unsigned one so small magnitudes of either sign
// become small numbers: n >= 0 encodes to 2n, n < 0 encodes to -2n - 1.
fn zigzag_encode(x: i32) -> u32 {
    ((x as u32) << 1) ^ ((x >> 31) as u32)
}
// Inverse of `zigzag_encode`: shift out the sign bit, then flip all bits
// (XOR with -1) when the low bit marks a negative value.
fn zigzag_decode(x: u32) -> i32 {
    ((x >> 1) as i32) ^ (-((x & 1) as i32))
}
/// Trait for reading values by index during compression.
///
/// The `Compressible` trait allows `compress()` to work with two different data sources,
/// `Vec<Triangle>` when indexing and memory mapped `Triangle` when merging. The compress function
/// reads values on-demand via `get()`, computing deltas and bit-packing without intermediate
/// allocations.
pub trait Compressible {
    /// The type of the values being compressed.
    type Value: Copy;
    /// Returns the number of values in this source.
    fn len(&self) -> usize;
    /// Returns the value at the given index.
    ///
    /// Callers pass `i < len()`; implementations may panic for out-of-range
    /// indices.
    fn get(&self, i: usize) -> Self::Value;
}
/// Operations for types that can be delta-encoded and bit-packed into four-byte words.
/// Operations for types that can be delta-encoded and bit-packed into four-byte words.
pub trait DeltaEncoder: Copy {
    /// Computes a zigzag-encoded delta between two consecutive values
    /// (`current - previous`, wrapping).
    fn compute_delta(current: Self, previous: Self) -> u32;
    /// Converts a value to little-endian bytes for storage in the header.
    fn to_le_bytes(value: Self) -> [u8; 4];
}
impl DeltaEncoder for i32 {
    fn compute_delta(current: Self, previous: Self) -> u32 {
        // Wrapping subtraction keeps the delta well-defined over the full i32
        // range; zigzag maps it to an unsigned value for tight bit-packing.
        let diff = current.wrapping_sub(previous);
        zigzag_encode(diff)
    }

    fn to_le_bytes(value: Self) -> [u8; 4] {
        i32::to_le_bytes(value)
    }
}
// Delta encoding for u32 values using wrapping arithmetic and zigzag encoding.
//
// This handles arbitrary u32 document IDs that may be non-sequential or widely spaced. The
// strategy uses wrapping subtraction followed by zigzag encoding:
//
// 1. wrapping_sub computes the difference modulo 2^32, producing a u32 result
// 2. Cast to i32 reinterprets the bit pattern as signed (two's complement)
// 3. zigzag_encode maps signed values to unsigned for efficient bit-packing:
// - Positive deltas (0, 1, 2...) encode to even numbers (0, 2, 4...)
// - Negative deltas (-1, -2, -3...) encode to odd numbers (1, 3, 5...)
//
// Example with large jump (doc_id 0 → 4,000,000,000):
// delta = 4_000_000_000u32.wrapping_sub(0) = 4_000_000_000u32
// as i32 = -294,967,296 (bit pattern preserved via two's complement)
// zigzag_encode(-294,967,296) = some u32 value
//
// During decompression, zigzag_decode returns the signed i32 delta, which is cast back to u32 and
// added with wrapping_add. The bit pattern round-trips correctly because wrapping_add and
// wrapping_sub are mathematical inverses modulo 2^32, making this encoding symmetric for the full
// u32 range.
impl DeltaEncoder for u32 {
    fn compute_delta(current: Self, previous: Self) -> u32 {
        // Reinterpret the wrapping difference as signed before zigzag so both
        // forward and backward jumps stay compact (see the comment above on
        // why this round-trips for the full u32 range).
        let diff = current.wrapping_sub(previous) as i32;
        zigzag_encode(diff)
    }

    fn to_le_bytes(value: Self) -> [u8; 4] {
        u32::to_le_bytes(value)
    }
}
/// Compresses values from a `Compressible` source using delta encoding and bit-packing.
///
/// Computes signed deltas between consecutive values, zigzag encodes them, and determines the
/// minimum bit width needed to represent all deltas. Writes a header (1 byte for bit width +
/// 4 bytes for first value in little-endian), then bit-packs the remaining deltas MSB-first.
///
/// An empty source writes nothing; callers must then skip decompression (the
/// previous behavior was a panic on `get(0)`).
pub fn compress<T, W>(compressible: &T, write: &mut W) -> io::Result<()>
where
    T: Compressible,
    T::Value: DeltaEncoder,
    W: Write,
{
    // Guard the unconditional `get(0)` header write below.
    if compressible.len() == 0 {
        return Ok(());
    }
    // First pass: find the widest zigzag delta to size the bit packing.
    let mut max_delta = 0u32;
    for i in 1..compressible.len() {
        let delta = T::Value::compute_delta(compressible.get(i), compressible.get(i - 1));
        max_delta = max_delta.max(delta);
    }
    let bits = if max_delta == 0 {
        0u32
    } else {
        32 - max_delta.leading_zeros() as u32
    };
    // `1u32 << 32` would overflow, so the full-width mask is special-cased.
    let mask = if bits == 32 {
        u32::MAX
    } else {
        (1u32 << bits) - 1
    };
    // Header: bit width, then the first value verbatim.
    write.write_all(&[bits as u8])?;
    write.write_all(&T::Value::to_le_bytes(compressible.get(0)))?;
    // Second pass: pack deltas into a 64-bit staging buffer, flushing whole
    // bytes as they fill up.
    let mut buffer = 0u64;
    let mut buffer_bits = 0u32;
    for i in 1..compressible.len() {
        let delta = T::Value::compute_delta(compressible.get(i), compressible.get(i - 1));
        let value = delta & mask;
        buffer = (buffer << bits) | (value as u64);
        buffer_bits += bits;
        while buffer_bits >= 8 {
            buffer_bits -= 8;
            write.write_all(&[(buffer >> buffer_bits) as u8])?;
        }
    }
    // Flush the final partial byte, left-aligned.
    if buffer_bits > 0 {
        write.write_all(&[(buffer << (8 - buffer_bits)) as u8])?;
    }
    Ok(())
}
/// Operations needed to decompress delta-encoded values back to their original form.
/// Operations needed to decompress delta-encoded values back to their original form.
pub trait DeltaDecoder: Copy + Sized {
    /// Converts from the little-endian header bytes to a value.
    fn from_le_bytes(bytes: [u8; 4]) -> Self;
    /// Applies a zigzag-decoded delta to the previous value to reconstruct
    /// the next one.
    fn apply_delta(value: Self, delta: u32) -> Self;
}
impl DeltaDecoder for i32 {
    fn from_le_bytes(bytes: [u8; 4]) -> Self {
        i32::from_le_bytes(bytes)
    }

    fn apply_delta(value: Self, delta: u32) -> Self {
        // Wrapping add mirrors the wrapping subtraction used while encoding.
        let signed = zigzag_decode(delta);
        value.wrapping_add(signed)
    }
}
impl DeltaDecoder for u32 {
    fn from_le_bytes(bytes: [u8; 4]) -> Self {
        u32::from_le_bytes(bytes)
    }

    fn apply_delta(value: Self, delta: u32) -> Self {
        // The signed delta is reinterpreted as u32; wrapping_add inverts the
        // encoder's wrapping_sub modulo 2^32.
        let signed = zigzag_decode(delta);
        value.wrapping_add(signed as u32)
    }
}
/// Decompresses bit-packed delta-encoded values from a byte slice.
///
/// Reads the header to get bit width and first value, then unpacks the bit-packed deltas, applies
/// zigzag decoding, and reconstructs the original values by accumulating deltas. `process` is
/// invoked with `(index, value)` for each of the `count` values.
///
/// Returns the count of bytes read from `data`.
///
/// # Errors
///
/// Returns `UnexpectedEof` if the header or packed data is truncated, and
/// `InvalidData` if the header declares a bit width `compress` could never
/// have written.
pub fn decompress<T: DeltaDecoder, F>(
    data: &[u8],
    count: usize,
    mut process: F,
) -> io::Result<usize>
where
    F: FnMut(usize, T),
{
    if data.len() < 5 {
        return Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "truncated header",
        ));
    }
    let bits = data[0] as u32;
    // Deltas are packed from u32 values, so a width above 32 can only come
    // from corrupted data. Without this check, widths >= 64 would panic on
    // the `1u64 << bits` mask below.
    if bits > 32 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("invalid delta bit width: {}", bits),
        ));
    }
    let first = T::from_le_bytes([data[1], data[2], data[3], data[4]]);
    // NOTE(review): index 0 is emitted even when `count == 0`, matching the
    // previous behavior — confirm no caller passes count == 0.
    process(0, first);
    let mut offset = 5;
    if bits == 0 {
        // All deltas are zero - all values same as first
        for i in 1..count {
            process(i, first);
        }
        return Ok(offset);
    }
    let mut buffer = 0u64;
    let mut buffer_bits = 0u32;
    let mut prev = first;
    for i in 1..count {
        // Refill the bit buffer until it holds at least one packed delta.
        while buffer_bits < bits {
            if offset >= data.len() {
                // `i` values (indices 0..i) have been decoded at this point.
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    format!("expected {} values but only decoded {}", count, i),
                ));
            }
            buffer = (buffer << 8) | (data[offset] as u64);
            offset += 1;
            buffer_bits += 8;
        }
        // The refill loop guarantees buffer_bits >= bits here.
        buffer_bits -= bits;
        let encoded = ((buffer >> buffer_bits) & ((1u64 << bits) - 1)) as u32;
        let value = T::apply_delta(prev, encoded);
        process(i, value);
        prev = value;
    }
    Ok(offset)
}
#[cfg(test)]
mod test {
    use super::*;

    /// Adapter exposing a `Vec<i32>` through the `Compressible` trait for tests.
    pub struct CompressibleI32Vec {
        vec: Vec<i32>,
    }

    impl CompressibleI32Vec {
        fn new(vec: Vec<i32>) -> Self {
            CompressibleI32Vec { vec }
        }
    }

    impl Compressible for CompressibleI32Vec {
        type Value = i32;

        // `return` statements dropped in favor of tail expressions
        // (clippy::needless_return).
        fn len(&self) -> usize {
            self.vec.len()
        }

        fn get(&self, i: usize) -> i32 {
            self.vec[i]
        }
    }

    #[test]
    fn test_spatial_delta_compress_decompress() {
        let values = vec![
            100000, 99975, 100050, 99980, 100100, 100025, 99950, 100150, 100075, 99925, 100200,
            100100,
        ];
        let compressible = CompressibleI32Vec::new(values.clone());
        let mut buffer = Vec::new();
        compress(&compressible, &mut buffer).unwrap();
        let mut vec = Vec::new();
        decompress::<i32, _>(&buffer, values.len(), |_, value| vec.push(value)).unwrap();
        assert_eq!(vec, values);
    }

    #[test]
    fn test_spatial_delta_bad_header() {
        let mut vec = Vec::new();
        let result = decompress::<i32, _>(&[1, 2], 1, |_, value| vec.push(value));
        assert!(result.is_err());
    }

    #[test]
    fn test_spatial_delta_insufficient_data() {
        let mut vec = Vec::new();
        let result = decompress::<i32, _>(&[5, 0, 0, 0, 1], 12, |_, value| vec.push(value));
        assert!(result.is_err());
    }

    #[test]
    fn test_spatial_delta_single_item() {
        let mut vec = Vec::new();
        decompress::<i32, _>(&[5, 1, 0, 0, 0], 1, |_, value| vec.push(value)).unwrap();
        assert_eq!(vec[0], 1);
    }

    #[test]
    fn test_spatial_delta_zero_length_delta() {
        let values = vec![1, 1, 1];
        let compressible = CompressibleI32Vec::new(values.clone());
        let mut buffer = Vec::new();
        compress(&compressible, &mut buffer).unwrap();
        let mut vec = Vec::new();
        decompress::<i32, _>(&buffer, values.len(), |_, value| vec.push(value)).unwrap();
        assert_eq!(vec, values);
    }
}

View File

@@ -1,490 +0,0 @@
//! HUSH
use std::io::{self, Read, Write};
use common::{BinarySerializable, VInt};
use serde_json::{json, Map, Value};
use crate::spatial::point::GeoPoint;
use crate::spatial::xor::{compress_f64, decompress_f64};
/// HUSH
/// Errors raised while parsing a geometry from a GeoJSON object.
#[derive(Debug)]
pub enum GeometryError {
    /// The object has no `"type"` member, or it is not a string.
    MissingType,
    /// A required member (e.g. `"coordinates"` or `"geometries"`) is absent.
    MissingField(String),
    /// The `"type"` member names a geometry this parser does not handle.
    UnsupportedType(String),
    /// A coordinate is not a finite number within the valid lon/lat range.
    InvalidCoordinate(String), // Can report the actual bad value
    /// The coordinate nesting does not match the geometry type.
    InvalidStructure(String), // "expected array", "wrong nesting depth", etc
}
/// HUSH
/// A parsed GeoJSON geometry, with coordinates held as `GeoPoint`s.
#[derive(Debug, Clone, PartialEq)]
pub enum Geometry {
    /// A single position.
    Point(GeoPoint),
    /// A set of independent positions.
    MultiPoint(Vec<GeoPoint>),
    /// A polyline; parsing enforces at least 2 points.
    LineString(Vec<GeoPoint>),
    /// A set of polylines.
    MultiLineString(Vec<Vec<GeoPoint>>),
    /// A polygon as a list of rings; parsing enforces at least 3 points per ring.
    Polygon(Vec<Vec<GeoPoint>>),
    /// A set of polygons.
    MultiPolygon(Vec<Vec<Vec<GeoPoint>>>),
    /// A heterogeneous collection of geometries.
    GeometryCollection(Vec<Self>),
}
impl Geometry {
    /// Parses a geometry from a GeoJSON object (RFC 7946).
    ///
    /// The `"type"` member selects the variant; coordinates are read from
    /// `"coordinates"` (or `"geometries"` for collections). Line strings must
    /// contain at least 2 points and polygon rings at least 3.
    pub fn from_geojson(object: &Map<String, Value>) -> Result<Self, GeometryError> {
        let geometry_type = object
            .get("type")
            .and_then(|v| v.as_str())
            .ok_or(GeometryError::MissingType)?;
        match geometry_type {
            "Point" => {
                let coordinates = get_coordinates(object)?;
                let point = to_point(coordinates)?;
                Ok(Geometry::Point(point))
            }
            "MultiPoint" => {
                let coordinates = get_coordinates(object)?;
                let multi_point = to_line_string(coordinates)?;
                Ok(Geometry::MultiPoint(multi_point))
            }
            "LineString" => {
                let coordinates = get_coordinates(object)?;
                let line_string = to_line_string(coordinates)?;
                if line_string.len() < 2 {
                    return Err(GeometryError::InvalidStructure(
                        "a line string contains at least 2 points".to_string(),
                    ));
                }
                Ok(Geometry::LineString(line_string))
            }
            "MultiLineString" => {
                let coordinates = get_coordinates(object)?;
                let multi_line_string = to_multi_line_string(coordinates)?;
                for line_string in &multi_line_string {
                    if line_string.len() < 2 {
                        return Err(GeometryError::InvalidStructure(
                            "a line string contains at least 2 points".to_string(),
                        ));
                    }
                }
                Ok(Geometry::MultiLineString(multi_line_string))
            }
            "Polygon" => {
                let coordinates = get_coordinates(object)?;
                let polygon = to_multi_line_string(coordinates)?;
                for ring in &polygon {
                    if ring.len() < 3 {
                        return Err(GeometryError::InvalidStructure(
                            "a polygon ring contains at least 3 points".to_string(),
                        ));
                    }
                }
                Ok(Geometry::Polygon(polygon))
            }
            "MultiPolygon" => {
                let mut result = Vec::new();
                let multi_polygons = get_coordinates(object)?;
                let multi_polygons = multi_polygons.as_array().ok_or_else(|| {
                    GeometryError::InvalidStructure("expected an array of polygons".to_string())
                })?;
                for polygon in multi_polygons {
                    let polygon = to_multi_line_string(polygon)?;
                    for ring in &polygon {
                        if ring.len() < 3 {
                            return Err(GeometryError::InvalidStructure(
                                "a polygon ring contains at least 3 points".to_string(),
                            ));
                        }
                    }
                    result.push(polygon);
                }
                Ok(Geometry::MultiPolygon(result))
            }
            // RFC 7946 names this type "GeometryCollection". The misspelled
            // "GeometriesCollection" was the only spelling previously
            // accepted — which made `to_geojson` output (which correctly
            // emits "GeometryCollection") unparseable — and is kept only for
            // backward compatibility.
            "GeometryCollection" | "GeometriesCollection" => {
                let geometries = object
                    .get("geometries")
                    .ok_or_else(|| GeometryError::MissingField("geometries".to_string()))?;
                let geometries = geometries.as_array().ok_or_else(|| {
                    GeometryError::InvalidStructure("geometries is not an array".to_string())
                })?;
                let mut result = Vec::new();
                for geometry in geometries {
                    let object = geometry.as_object().ok_or_else(|| {
                        GeometryError::InvalidStructure("geometry is not an object".to_string())
                    })?;
                    result.push(Geometry::from_geojson(object)?);
                }
                Ok(Geometry::GeometryCollection(result))
            }
            _ => Err(GeometryError::UnsupportedType(geometry_type.to_string())),
        }
    }

    /// Serialize the geometry to GeoJSON format (RFC 7946).
    /// Coordinates are emitted as `[lon, lat]` pairs.
    pub fn to_geojson(&self) -> Map<String, Value> {
        let mut map = Map::new();
        match self {
            Geometry::Point(point) => {
                map.insert("type".to_string(), Value::String("Point".to_string()));
                let coords = json!([point.lon, point.lat]);
                map.insert("coordinates".to_string(), coords);
            }
            Geometry::MultiPoint(points) => {
                map.insert("type".to_string(), Value::String("MultiPoint".to_string()));
                let coords: Vec<Value> = points.iter().map(|p| json!([p.lon, p.lat])).collect();
                map.insert("coordinates".to_string(), Value::Array(coords));
            }
            Geometry::LineString(line) => {
                map.insert("type".to_string(), Value::String("LineString".to_string()));
                let coords: Vec<Value> = line.iter().map(|p| json!([p.lon, p.lat])).collect();
                map.insert("coordinates".to_string(), Value::Array(coords));
            }
            Geometry::MultiLineString(lines) => {
                map.insert(
                    "type".to_string(),
                    Value::String("MultiLineString".to_string()),
                );
                let coords: Vec<Value> = lines
                    .iter()
                    .map(|line| Value::Array(line.iter().map(|p| json!([p.lon, p.lat])).collect()))
                    .collect();
                map.insert("coordinates".to_string(), Value::Array(coords));
            }
            Geometry::Polygon(rings) => {
                map.insert("type".to_string(), Value::String("Polygon".to_string()));
                let coords: Vec<Value> = rings
                    .iter()
                    .map(|ring| Value::Array(ring.iter().map(|p| json!([p.lon, p.lat])).collect()))
                    .collect();
                map.insert("coordinates".to_string(), Value::Array(coords));
            }
            Geometry::MultiPolygon(polygons) => {
                map.insert(
                    "type".to_string(),
                    Value::String("MultiPolygon".to_string()),
                );
                let coords: Vec<Value> = polygons
                    .iter()
                    .map(|polygon| {
                        Value::Array(
                            polygon
                                .iter()
                                .map(|ring| {
                                    Value::Array(
                                        ring.iter().map(|p| json!([p.lon, p.lat])).collect(),
                                    )
                                })
                                .collect(),
                        )
                    })
                    .collect();
                map.insert("coordinates".to_string(), Value::Array(coords));
            }
            Geometry::GeometryCollection(geometries) => {
                map.insert(
                    "type".to_string(),
                    Value::String("GeometryCollection".to_string()),
                );
                let geoms: Vec<Value> = geometries
                    .iter()
                    .map(|g| Value::Object(g.to_geojson()))
                    .collect();
                map.insert("geometries".to_string(), Value::Array(geoms));
            }
        }
        map
    }
}
// Fetches the "coordinates" member, which every GeoJSON geometry handled
// here (other than GeometryCollection) carries its data under.
fn get_coordinates(object: &Map<String, Value>) -> Result<&Value, GeometryError> {
    object
        .get("coordinates")
        .ok_or(GeometryError::MissingField("coordinates".to_string()))
}
/// Parses a GeoJSON position (a two-element `[lon, lat]` array) into a `GeoPoint`.
///
/// Both coordinates must be finite numbers, with longitude in `[-180, 180]`
/// and latitude in `[-90, 90]`.
fn to_point(value: &Value) -> Result<GeoPoint, GeometryError> {
    // `ok_or_else` avoids allocating the error string on the success path
    // (clippy::or_fun_call).
    let lonlat = value.as_array().ok_or_else(|| {
        GeometryError::InvalidStructure("expected 2 element array pair of lon/lat".to_string())
    })?;
    if lonlat.len() != 2 {
        return Err(GeometryError::InvalidStructure(
            "expected 2 element array pair of lon/lat".to_string(),
        ));
    }
    let lon = lonlat[0]
        .as_f64()
        .ok_or_else(|| GeometryError::InvalidCoordinate("longitude must be f64".to_string()))?;
    if !lon.is_finite() || !(-180.0..=180.0).contains(&lon) {
        return Err(GeometryError::InvalidCoordinate(format!(
            "invalid longitude: {}",
            lon
        )));
    }
    let lat = lonlat[1]
        .as_f64()
        .ok_or_else(|| GeometryError::InvalidCoordinate("latitude must be f64".to_string()))?;
    if !lat.is_finite() || !(-90.0..=90.0).contains(&lat) {
        return Err(GeometryError::InvalidCoordinate(format!(
            "invalid latitude: {}",
            lat
        )));
    }
    Ok(GeoPoint { lon, lat })
}
fn to_line_string(value: &Value) -> Result<Vec<GeoPoint>, GeometryError> {
let mut result = Vec::new();
let coordinates = value.as_array().ok_or(GeometryError::InvalidStructure(
"expected an array of lon/lat arrays".to_string(),
))?;
for coordinate in coordinates {
result.push(to_point(coordinate)?);
}
Ok(result)
}
fn to_multi_line_string(value: &Value) -> Result<Vec<Vec<GeoPoint>>, GeometryError> {
let mut result = Vec::new();
let coordinates = value.as_array().ok_or(GeometryError::InvalidStructure(
"expected an array of an array of lon/lat arrays".to_string(),
))?;
for coordinate in coordinates {
result.push(to_line_string(coordinate)?);
}
Ok(result)
}
impl BinarySerializable for Geometry {
    /// Writes a one-byte variant discriminant followed by the variant's
    /// payload. Coordinates are split into parallel lon/lat streams and
    /// compressed with `compress_f64`; the nesting structure is described by
    /// `VInt` counts written before the streams.
    fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
        match self {
            Geometry::Point(point) => {
                // A single point is stored uncompressed: discriminant + lon + lat.
                0u8.serialize(writer)?;
                point.lon.serialize(writer)?;
                point.lat.serialize(writer)?;
                Ok(())
            }
            Geometry::MultiPoint(points) => {
                1u8.serialize(writer)?;
                serialize_line_string(points, writer)
            }
            Geometry::LineString(line_string) => {
                2u8.serialize(writer)?;
                serialize_line_string(line_string, writer)
            }
            Geometry::MultiLineString(multi_line_string) => {
                // A multi line string shares the ring-based layout of a polygon.
                3u8.serialize(writer)?;
                serialize_polygon(&multi_line_string[..], writer)
            }
            Geometry::Polygon(polygon) => {
                4u8.serialize(writer)?;
                serialize_polygon(polygon, writer)
            }
            Geometry::MultiPolygon(multi_polygon) => {
                5u8.serialize(writer)?;
                // First the full nesting structure: polygon count, then each
                // polygon's ring count and each ring's point count.
                BinarySerializable::serialize(&VInt(multi_polygon.len() as u64), writer)?;
                for polygon in multi_polygon {
                    BinarySerializable::serialize(&VInt(polygon.len() as u64), writer)?;
                    for ring in polygon {
                        BinarySerializable::serialize(&VInt(ring.len() as u64), writer)?;
                    }
                }
                // Then all coordinates flattened into two parallel streams.
                let mut lon = Vec::new();
                let mut lat = Vec::new();
                for polygon in multi_polygon {
                    for ring in polygon {
                        for point in ring {
                            lon.push(point.lon);
                            lat.push(point.lat);
                        }
                    }
                }
                let lon = compress_f64(&lon);
                let lat = compress_f64(&lat);
                // Each compressed stream is prefixed with its byte length.
                VInt(lon.len() as u64).serialize(writer)?;
                writer.write_all(&lon)?;
                VInt(lat.len() as u64).serialize(writer)?;
                writer.write_all(&lat)?;
                Ok(())
            }
            Geometry::GeometryCollection(geometries) => {
                // A collection recursively serializes its members.
                6u8.serialize(writer)?;
                BinarySerializable::serialize(&VInt(geometries.len() as u64), writer)?;
                for geometry in geometries {
                    geometry.serialize(writer)?;
                }
                Ok(())
            }
        }
    }

    /// Reads a geometry previously written by `serialize`, dispatching on the
    /// leading discriminant byte.
    fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
        let discriminant: u8 = BinarySerializable::deserialize(reader)?;
        match discriminant {
            0 => {
                let lon = BinarySerializable::deserialize(reader)?;
                let lat = BinarySerializable::deserialize(reader)?;
                Ok(Geometry::Point(GeoPoint { lon, lat }))
            }
            1 => Ok(Geometry::MultiPoint(deserialize_line_string(reader)?)),
            2 => Ok(Geometry::LineString(deserialize_line_string(reader)?)),
            3 => Ok(Geometry::MultiLineString(deserialize_polygon(reader)?)),
            4 => Ok(Geometry::Polygon(deserialize_polygon(reader)?)),
            5 => {
                // Read back the nesting structure (ring sizes per polygon) and
                // the total point count across all rings.
                let polygon_count = VInt::deserialize(reader)?.0 as usize;
                let mut polygons = Vec::new();
                let mut count = 0;
                for _ in 0..polygon_count {
                    let ring_count = VInt::deserialize(reader)?.0 as usize;
                    let mut rings = Vec::new();
                    for _ in 0..ring_count {
                        let point_count = VInt::deserialize(reader)?.0 as usize;
                        rings.push(point_count);
                        count += point_count;
                    }
                    polygons.push(rings);
                }
                // Vec<u8> deserialization reads the VInt length prefix written above.
                let lon_bytes: Vec<u8> = BinarySerializable::deserialize(reader)?;
                let lat_bytes: Vec<u8> = BinarySerializable::deserialize(reader)?;
                let lon = decompress_f64(&lon_bytes, count);
                let lat = decompress_f64(&lat_bytes, count);
                // Re-slice the flat coordinate streams back into rings.
                let mut multi_polygon = Vec::new();
                let mut offset = 0;
                for rings in polygons {
                    let mut polygon = Vec::new();
                    for point_count in rings {
                        let mut ring = Vec::new();
                        for _ in 0..point_count {
                            ring.push(GeoPoint {
                                lon: lon[offset],
                                lat: lat[offset],
                            });
                            offset += 1;
                        }
                        polygon.push(ring);
                    }
                    multi_polygon.push(polygon);
                }
                Ok(Geometry::MultiPolygon(multi_polygon))
            }
            6 => {
                let geometry_count = VInt::deserialize(reader)?.0 as usize;
                let mut geometries = Vec::new();
                for _ in 0..geometry_count {
                    geometries.push(Geometry::deserialize(reader)?);
                }
                Ok(Geometry::GeometryCollection(geometries))
            }
            _ => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "invalid geometry type",
            )),
        }
    }
}
// Writes a point count followed by the compressed longitude and latitude
// streams, each prefixed with its byte length.
fn serialize_line_string<W: Write + ?Sized>(line: &[GeoPoint], writer: &mut W) -> io::Result<()> {
    BinarySerializable::serialize(&VInt(line.len() as u64), writer)?;
    // Split the points into parallel coordinate streams for compression.
    let (lon, lat): (Vec<f64>, Vec<f64>) = line.iter().map(|p| (p.lon, p.lat)).unzip();
    let lon = compress_f64(&lon);
    let lat = compress_f64(&lat);
    VInt(lon.len() as u64).serialize(writer)?;
    writer.write_all(&lon)?;
    VInt(lat.len() as u64).serialize(writer)?;
    writer.write_all(&lat)?;
    Ok(())
}
// Writes the ring count and per-ring point counts, then the flattened,
// compressed longitude and latitude streams (each with a byte-length prefix)
// so the decoder can rebuild the nesting.
fn serialize_polygon<W: Write + ?Sized>(
    line_string: &[Vec<GeoPoint>],
    writer: &mut W,
) -> io::Result<()> {
    BinarySerializable::serialize(&VInt(line_string.len() as u64), writer)?;
    for ring in line_string {
        BinarySerializable::serialize(&VInt(ring.len() as u64), writer)?;
    }
    let (lon, lat): (Vec<f64>, Vec<f64>) = line_string
        .iter()
        .flatten()
        .map(|p| (p.lon, p.lat))
        .unzip();
    let lon: Vec<u8> = compress_f64(&lon);
    let lat: Vec<u8> = compress_f64(&lat);
    VInt(lon.len() as u64).serialize(writer)?;
    writer.write_all(&lon)?;
    VInt(lat.len() as u64).serialize(writer)?;
    writer.write_all(&lat)?;
    Ok(())
}
// Reads back a line string written by `serialize_line_string`: a point count
// followed by two length-prefixed compressed coordinate streams.
fn deserialize_line_string<R: Read>(reader: &mut R) -> io::Result<Vec<GeoPoint>> {
    let point_count = VInt::deserialize(reader)?.0 as usize;
    let lon_bytes: Vec<u8> = BinarySerializable::deserialize(reader)?;
    let lat_bytes: Vec<u8> = BinarySerializable::deserialize(reader)?;
    let lon: Vec<f64> = decompress_f64(&lon_bytes, point_count);
    let lat: Vec<f64> = decompress_f64(&lat_bytes, point_count);
    // Zip the parallel streams back into points, index by index.
    let line_string: Vec<GeoPoint> = (0..point_count)
        .map(|i| GeoPoint {
            lon: lon[i],
            lat: lat[i],
        })
        .collect();
    Ok(line_string)
}
// Reads back a polygon written by `serialize_polygon`: the per-ring point
// counts describe how to slice the flat coordinate streams into rings.
fn deserialize_polygon<R: Read>(reader: &mut R) -> io::Result<Vec<Vec<GeoPoint>>> {
    let ring_count = VInt::deserialize(reader)?.0 as usize;
    let mut ring_sizes = Vec::new();
    let mut total_points = 0;
    for _ in 0..ring_count {
        let point_count = VInt::deserialize(reader)?.0 as usize;
        ring_sizes.push(point_count);
        total_points += point_count;
    }
    let lon_bytes: Vec<u8> = BinarySerializable::deserialize(reader)?;
    let lat_bytes: Vec<u8> = BinarySerializable::deserialize(reader)?;
    let lon: Vec<f64> = decompress_f64(&lon_bytes, total_points);
    let lat: Vec<f64> = decompress_f64(&lat_bytes, total_points);
    let mut polygon: Vec<Vec<GeoPoint>> = Vec::new();
    let mut cursor = 0;
    for size in ring_sizes {
        let ring: Vec<GeoPoint> = (cursor..cursor + size)
            .map(|i| GeoPoint {
                lon: lon[i],
                lat: lat[i],
            })
            .collect();
        cursor += size;
        polygon.push(ring);
    }
    Ok(polygon)
}

View File

@@ -1,12 +0,0 @@
//! Spatial module (implements a block kd-tree)
pub mod bkd;
pub mod delta;
pub mod geometry;
pub mod point;
pub mod radix_select;
pub mod reader;
pub mod serializer;
pub mod triangle;
pub mod writer;
pub mod xor;

View File

@@ -1,8 +0,0 @@
/// A point in the geographical coordinate system.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GeoPoint {
    /// Longitude in degrees, in `[-180, 180]`.
    pub lon: f64,
    /// Latitude in degrees, in `[-90, 90]`.
    pub lat: f64,
}

View File

@@ -1,122 +0,0 @@
//! Radix selection for block kd-tree tree partitioning.
//!
//! Implements byte-wise histogram selection to find median values without comparisons, enabling
//! efficient partitioning of spatial data during block kd-tree construction. Processes values
//! through multiple passes, building histograms for each byte position after a common prefix,
//! avoiding the need to sort or compare elements directly.
/// Performs radix selection to find the median value without comparisons by building byte-wise
/// histograms.
pub struct RadixSelect {
    /// Frequency of each possible value of the byte following the prefix.
    histogram: [usize; 256],
    /// Big-endian byte prefix a value must start with to be counted.
    prefix: Vec<u8>,
    /// Number of matching elements that fell into earlier buckets in previous passes.
    offset: usize,
    /// Zero-based rank of the element being selected.
    nth: usize,
}

impl RadixSelect {
    /// Creates a new radix selector for finding the nth element among values with a common prefix.
    ///
    /// The offset specifies how many matching elements appeared in previous buckets (from earlier
    /// passes). The nth parameter is 0-indexed, so pass 31 to find the 32nd element (median of
    /// 64).
    pub fn new(prefix: Vec<u8>, offset: usize, nth: usize) -> Self {
        RadixSelect {
            histogram: [0; 256],
            prefix,
            offset,
            nth,
        }
    }

    /// Updates the histogram with a value if it matches the current prefix.
    ///
    /// Values that don't start with the prefix are ignored. For matching values, increments the
    /// count for the byte at position `prefix.len()`.
    pub fn update(&mut self, value: i32) {
        let bytes = value.to_be_bytes();
        if !bytes.starts_with(&self.prefix) {
            return;
        }
        let byte = bytes[self.prefix.len()];
        self.histogram[byte as usize] += 1;
    }

    /// Finds which bucket contains the nth element and returns the bucket value and offset.
    ///
    /// Returns a tuple of `(bucket_byte, count_before)` where bucket_byte is the value of the byte
    /// that contains the nth element, and count_before is the number of elements in earlier
    /// buckets (becomes the offset for the next pass).
    ///
    /// # Panics
    ///
    /// Panics if fewer than `nth + 1` matching values were accumulated.
    pub fn nth(&self) -> (u8, usize) {
        let mut count = self.offset;
        for (bucket, &frequency) in self.histogram.iter().enumerate() {
            // `nth` is already a usize (the previous `as usize` was redundant).
            if count + frequency > self.nth {
                return (bucket as u8, count);
            }
            count += frequency;
        }
        panic!("nth element {} not found in histogram", self.nth);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn radix_selection() {
        // Each entry is (values, expected (bucket_byte, count_before) for the
        // first three byte passes; the fourth entry is unchecked below).
        let dimensions = [
            (
                vec![
                    0x10101010, 0x10101011, 0x10101012, 0x10101013, 0x10101014, 0x10101015,
                    0x10101016, 0x10101017, 0x10101018, 0x10101019, 0x1010101A, 0x1010101B,
                    0x1010101C, 0x1010101D, 0x1010101E, 0x1010101F, 0x10101020, 0x10101021,
                    0x10101022, 0x10101023, 0x10101024, 0x10101025, 0x10101026, 0x10101027,
                    0x10101028, 0x10101029, 0x1010102A, 0x1010102B, 0x1010102C, 0x1010102D,
                    0x1010102E, 0x1010102F, 0x10101030, 0x10101031, 0x10101032, 0x10101033,
                    0x10101034, 0x10101035, 0x10101036, 0x10101037, 0x10101038, 0x10101039,
                    0x1010103A, 0x1010103B, 0x1010103C, 0x1010103D, 0x1010103E, 0x1010103F,
                    0x10101040, 0x10101041, 0x10101042, 0x10101043, 0x10101044, 0x10101045,
                    0x10101046, 0x10101047, 0x10101048, 0x10101049, 0x1010104A, 0x1010104B,
                    0x1010104C, 0x1010104D, 0x1010104E, 0x1010104F,
                ],
                [(0x10, 0), (0x10, 0), (0x10, 0), (0x2F, 31)],
            ),
            (
                vec![
                    0x10101010, 0x10101011, 0x10101012, 0x10101013, 0x10101014, 0x10101015,
                    0x10101016, 0x10101017, 0x10101018, 0x10101019, 0x1010101A, 0x1010101B,
                    0x1010101C, 0x1010101D, 0x1010101E, 0x1010101F, 0x10101020, 0x10101021,
                    0x10101022, 0x10101023, 0x10101024, 0x20101025, 0x20201026, 0x20301027,
                    0x20401028, 0x20501029, 0x2060102A, 0x2070102B, 0x2080102C, 0x2090102D,
                    0x20A0102E, 0x20B0102F, 0x20C01030, 0x20D01031, 0x20E01032, 0x20F01033,
                    0x20F11034, 0x20F21035, 0x20F31036, 0x20F41037, 0x20F51038, 0x20F61039,
                    0x3010103A, 0x3010103B, 0x3010103C, 0x3010103D, 0x3010103E, 0x3010103F,
                    0x30101040, 0x30101041, 0x30101042, 0x30101043, 0x30101044, 0x30101045,
                    0x30101046, 0x30101047, 0x30101048, 0x30101049, 0x3010104A, 0x3010104B,
                    0x3010104C, 0x3010104D, 0x3010104E, 0x3010104F,
                ],
                [(0x20, 21), (0xB0, 31), (0x10, 31), (0x2F, 31)],
            ),
        ];
        for (numbers, expected) in dimensions {
            // Run one selection pass per byte, feeding each pass's bucket and
            // count back in as the next pass's prefix and offset.
            let mut offset = 0;
            let mut prefix = Vec::new();
            for i in 0..4 {
                let mut radix_select = RadixSelect::new(prefix.clone(), offset, 31);
                for &number in &numbers {
                    radix_select.update(number);
                }
                let (byte, count) = radix_select.nth();
                if i != 3 {
                    assert_eq!(expected[i].0, byte);
                    assert_eq!(expected[i].1, count);
                }
                prefix.push(byte);
                offset = count;
            }
            // The four selected bytes must reassemble into the true 32nd
            // smallest value (median of 64).
            let mut sorted = numbers.clone();
            sorted.sort();
            let radix_result = i32::from_be_bytes(prefix.as_slice().try_into().unwrap());
            assert_eq!(radix_result, sorted[31]);
        }
    }
}

View File

@@ -1,70 +0,0 @@
//! HUSH
use std::io;
use std::sync::Arc;
use common::file_slice::FileSlice;
use common::OwnedBytes;
use crate::directory::CompositeFile;
use crate::schema::Field;
use crate::space_usage::PerFieldSpaceUsage;
/// Per-segment registry of spatial data, one entry per spatial field, backed
/// by a single composite file.
#[derive(Clone)]
pub struct SpatialReaders {
    data: Arc<CompositeFile>,
}
impl SpatialReaders {
    /// Returns an empty registry with no per-field spatial data.
    pub fn empty() -> SpatialReaders {
        SpatialReaders {
            data: Arc::new(CompositeFile::empty()),
        }
    }

    /// Opens the spatial readers from the segment's composite spatial file.
    pub fn open(file: FileSlice) -> crate::Result<SpatialReaders> {
        let data = CompositeFile::open(&file)?;
        Ok(SpatialReaders {
            data: Arc::new(data),
        })
    }

    /// Returns the `SpatialReader` for a specific field, or `None` if the
    /// field has no spatial data in this segment.
    pub fn get_field(&self, field: Field) -> crate::Result<Option<SpatialReader>> {
        if let Some(file) = self.data.open_read(field) {
            let spatial_reader = SpatialReader::open(file)?;
            Ok(Some(spatial_reader))
        } else {
            Ok(None)
        }
    }

    /// Return a break down of the space usage per field.
    pub fn space_usage(&self) -> PerFieldSpaceUsage {
        self.data.space_usage()
    }

    /// Returns a handle to the inner composite file.
    pub fn get_inner_file(&self) -> Arc<CompositeFile> {
        self.data.clone()
    }
}
/// Read-side handle on one field's serialized spatial data, held fully in
/// memory as `OwnedBytes`.
#[derive(Clone)]
pub struct SpatialReader {
    data: OwnedBytes,
}
impl SpatialReader {
    /// Opens the spatial reader from a `FileSlice`, reading all of its bytes
    /// into memory.
    pub fn open(spatial_file: FileSlice) -> io::Result<SpatialReader> {
        let data = spatial_file.read_bytes()?;
        Ok(SpatialReader { data })
    }

    /// Returns the raw serialized spatial bytes for this field.
    pub fn get_bytes(&self) -> &[u8] {
        self.data.as_ref()
    }
}

View File

@@ -1,37 +0,0 @@
//! HUSH
use std::io;
use std::io::Write;
use crate::directory::{CompositeWrite, WritePtr};
use crate::schema::Field;
use crate::spatial::bkd::write_block_kd_tree;
use crate::spatial::triangle::Triangle;
/// The spatial serializer is in charge of the serialization of the spatial
/// data (one block kd-tree per field) for all spatial fields.
pub struct SpatialSerializer {
    composite_write: CompositeWrite,
}
impl SpatialSerializer {
    /// Create a composite file from the write pointer.
    pub fn from_write(write: WritePtr) -> io::Result<SpatialSerializer> {
        // just making room for the pointer to header.
        let composite_write = CompositeWrite::wrap(write);
        Ok(SpatialSerializer { composite_write })
    }

    /// Serializes the given field's triangles as a block kd-tree.
    ///
    /// Takes `&mut [Triangle]` because `write_block_kd_tree` presumably
    /// reorders the slice in place during tree construction — confirm against
    /// its implementation.
    pub fn serialize_field(&mut self, field: Field, triangles: &mut [Triangle]) -> io::Result<()> {
        let write = self.composite_write.for_field(field);
        write_block_kd_tree(triangles, write)?;
        write.flush()?;
        Ok(())
    }

    /// Clean up, flush, and close.
    pub fn close(self) -> io::Result<()> {
        self.composite_write.close()?;
        Ok(())
    }
}

View File

@@ -1,515 +0,0 @@
//! A triangle encoding with bounding box in the first four words for efficient spatial pruning.
//!
//! Encodes triangles with the bounding box in the first four words, enabling efficient spatial
//! pruning during tree traversal without reconstructing the full triangle. The remaining words
//! contain an additional vertex and packed reconstruction metadata, allowing exact triangle
//! recovery when needed.
use i_triangle::advanced::delaunay::IntDelaunay;
use i_triangle::i_overlay::i_float::int::point::IntPoint;
use crate::DocId;
// Reconstruction codes. Each name lists, in storage order, the roles of the
// three vertex coordinate pairs: the first two pairs live in the bounding box
// words, the trailing `Y_X` pair is the extra vertex stored in `words[4..6]`.
// E.g. `MINY_MINX_MAXY_MAXX_Y_X`: A = (min_y, min_x), B = (max_y, max_x),
// C = the explicitly stored (y, x). See `Triangle::decode`.
const MINY_MINX_MAXY_MAXX_Y_X: i32 = 0;
const MINY_MINX_Y_X_MAXY_MAXX: i32 = 1;
const MAXY_MINX_Y_X_MINY_MAXX: i32 = 2;
const MAXY_MINX_MINY_MAXX_Y_X: i32 = 3;
const Y_MINX_MINY_X_MAXY_MAXX: i32 = 4;
const Y_MINX_MINY_MAXX_MAXY_X: i32 = 5;
const MAXY_MINX_MINY_X_Y_MAXX: i32 = 6;
const MINY_MINX_Y_MAXX_MAXY_X: i32 = 7;
/// Converts geographic coordinates (WGS84 lat/lon) to integer spatial coordinates.
///
/// Maps the full globe to the i32 range using linear scaling:
/// - Latitude: -90° to +90° → -2³¹ to +2³¹-1
/// - Longitude: -180° to +180° → -2³¹ to +2³¹-1
///
/// Provides approximately 214 units/meter for latitude and 107 units/meter for longitude, giving
/// millimeter-level precision. Uses `floor()` to ensure consistent quantization.
///
/// Returns `(y, x)` where y=latitude coordinate, x=longitude coordinate.
pub fn latlon_to_point(lat: f64, lon: f64) -> (i32, i32) {
    // Size of a single quantization step, in degrees.
    let lat_step = 180.0 / (1i64 << 32) as f64;
    let lon_step = 360.0 / (1i64 << 32) as f64;
    // `as i32` saturates, so exactly +90°/+180° clamp to i32::MAX.
    ((lat / lat_step).floor() as i32, (lon / lon_step).floor() as i32)
}
/// Creates a bounding box from two lat/lon corner coordinates.
///
/// Takes two arbitrary corner points and produces a normalized bounding box in the internal
/// coordinate system. Automatically computes min/max for each dimension.
///
/// Returns `[min_y, min_x, max_y, max_x]` matching the internal storage format used throughout the
/// block kd-tree and triangle encoding.
pub fn latlon_to_bbox(lat1: f64, lon1: f64, lat2: f64, lon2: f64) -> [i32; 4] {
    let (ay, ax) = latlon_to_point(lat1, lon1);
    let (by, bx) = latlon_to_point(lat2, lon2);
    [ay.min(by), ax.min(bx), ay.max(by), ax.max(bx)]
}
/// A triangle encoded with bounding box in the first four words for efficient spatial pruning.
///
/// Encodes the bounding box, one vertex, boundary edge flags, and a reconstruction code that
/// together allow exact triangle recovery while optimizing for spatial query performance. Finally,
/// it contains the document id.
#[repr(C)]
#[derive(Debug)]
pub struct Triangle {
    /// The bounding box, one vertex, followed by a packed integer containing boundary edge flags
    /// and a reconstruction code. Layout: `[min_y, min_x, max_y, max_x, y, x, packed]` where
    /// `packed` holds the reconstruction code in bits 0-2 and the `[ab, bc, ca]` flags in bits 3-5.
    pub words: [i32; 7],
    /// The id of the document associated with this triangle.
    pub doc_id: DocId,
}
impl Triangle {
    /// Encodes a triangle with the bounding box in the first four words for efficient spatial
    /// pruning.
    ///
    /// Takes three vertices as `[y0, x0, y1, x1, y2, x2]` and edge boundary flags `[ab, bc, ca]`
    /// indicating which edges are polygon boundaries. Returns a triangle struct with the bounding
    /// box in the first four words as `[min_y, min_x, max_y, max_x]`. When decoded, the vertex
    /// order may differ from the original input to `new()` due to normalized rotation.
    ///
    /// The edge boundary flags are here to express whether an edge is part of the boundaries
    /// in the tesselation of the larger polygon it belongs to.
    ///
    /// # Panics
    ///
    /// Panics if, after normalization, no reconstruction code matches the vertex layout.
    pub fn new(doc_id: u32, triangle: [i32; 6], boundaries: [bool; 3]) -> Self {
        let mut ay = triangle[0];
        let mut ax = triangle[1];
        let mut by = triangle[2];
        let mut bx = triangle[3];
        let mut cy = triangle[4];
        let mut cx = triangle[5];
        let mut ab = boundaries[0];
        let mut bc = boundaries[1];
        let mut ca = boundaries[2];
        // rotate edges and place minX at the beginning
        if bx < ax || cx < ax {
            let temp_x = ax;
            let temp_y = ay;
            let temp_boundary = ab;
            if bx < cx {
                ax = bx;
                ay = by;
                ab = bc;
                bx = cx;
                by = cy;
                bc = ca;
                cx = temp_x;
                cy = temp_y;
                ca = temp_boundary;
            } else {
                ax = cx;
                ay = cy;
                ab = ca;
                cx = bx;
                cy = by;
                ca = bc;
                bx = temp_x;
                by = temp_y;
                bc = temp_boundary;
            }
        } else if ax == bx && ax == cx {
            // degenerated case, all points with same longitude
            // we need to prevent that ax is in the middle (not part of the MBS)
            if by < ay || cy < ay {
                let temp_x = ax;
                let temp_y = ay;
                let temp_boundary = ab;
                if by < cy {
                    ax = bx;
                    ay = by;
                    ab = bc;
                    bx = cx;
                    by = cy;
                    bc = ca;
                    cx = temp_x;
                    cy = temp_y;
                    ca = temp_boundary;
                } else {
                    ax = cx;
                    ay = cy;
                    ab = ca;
                    cx = bx;
                    cy = by;
                    ca = bc;
                    bx = temp_x;
                    by = temp_y;
                    bc = temp_boundary;
                }
            }
        }
        // change orientation if clockwise (CW)
        if !is_counter_clockwise(
            IntPoint { y: ay, x: ax },
            IntPoint { y: by, x: bx },
            IntPoint { y: cy, x: cx },
        ) {
            // To change the orientation, we simply swap B and C.
            // New edge order: A->C (old ca), C->B (old bc), B->A (old ab).
            let temp_x = bx;
            let temp_y = by;
            let temp_boundary = ab;
            // ax and ay do not change; the new ab edge (A->C) carries the old ca flag
            ab = ca;
            bx = cx;
            by = cy;
            // bc keeps its flag; the new ca edge (B->A) carries the old ab flag
            cx = temp_x;
            cy = temp_y;
            ca = temp_boundary;
        }
        let min_x = ax;
        let min_y = ay.min(by).min(cy);
        let max_x = ax.max(bx).max(cx);
        let max_y = ay.max(by).max(cy);
        // Select the reconstruction code from which vertices supply the bbox
        // extremes; the remaining coordinates are stored explicitly in words[4..6].
        let (y, x, code) = if min_y == ay {
            if max_y == by && max_x == bx {
                (cy, cx, MINY_MINX_MAXY_MAXX_Y_X)
            } else if max_y == cy && max_x == cx {
                (by, bx, MINY_MINX_Y_X_MAXY_MAXX)
            } else {
                (by, cx, MINY_MINX_Y_MAXX_MAXY_X)
            }
        } else if max_y == ay {
            if min_y == by && max_x == bx {
                (cy, cx, MAXY_MINX_MINY_MAXX_Y_X)
            } else if min_y == cy && max_x == cx {
                (by, bx, MAXY_MINX_Y_X_MINY_MAXX)
            } else {
                (cy, bx, MAXY_MINX_MINY_X_Y_MAXX)
            }
        } else if max_x == bx && min_y == by {
            (ay, cx, Y_MINX_MINY_MAXX_MAXY_X)
        } else if max_x == cx && max_y == cy {
            (ay, bx, Y_MINX_MINY_X_MAXY_MAXX)
        } else {
            panic!("Could not encode the provided triangle");
        };
        // Pack boundary flags (bits 3-5) above the reconstruction code (bits 0-2).
        let boundaries_bits = (ab as i32) | ((bc as i32) << 1) | ((ca as i32) << 2);
        let packed = code | (boundaries_bits << 3);
        Triangle {
            words: [min_y, min_x, max_y, max_x, y, x, packed],
            doc_id,
        }
    }
    /// Builds a degenerated triangle degenerating for a single point.
    /// All vertices are that point, and all vertices are boundaries.
    pub fn from_point(doc_id: DocId, point_x: i32, point_y: i32) -> Triangle {
        Triangle::new(
            doc_id,
            [point_y, point_x, point_y, point_x, point_y, point_x],
            [true, true, true],
        )
    }
    /// Builds a degenerated triangle for a segment.
    /// Line segment AB is represented as the triangle ABA.
    pub fn from_line_segment(doc_id: DocId, a_x: i32, a_y: i32, b_x: i32, b_y: i32) -> Triangle {
        Triangle::new(doc_id, [a_y, a_x, b_y, b_x, a_y, a_x], [true, true, true])
    }
    /// Create a triangle with only the doc_id and the words initialized to zero.
    ///
    /// The doc_id and words in the field are delta-compressed as a series with the doc_id
    /// serialized first. When we reconstruct the triangle we can first reconstruct skeleton
    /// triangles with the doc_id series, then populate the words directly from the decompression
    /// as we decompress each series.
    ///
    /// An immutable constructor would require that we decompress first into parallel `Vec`
    /// instances, then loop through the count of triangles building triangles using a constructor
    /// that takes all eight field values at once. This saves a copy, the triangle is the
    /// decompression destination.
    pub fn skeleton(doc_id: u32) -> Self {
        Triangle {
            doc_id,
            words: [0i32; 7],
        }
    }
    /// Decodes the triangle back to vertex coordinates and boundary flags.
    ///
    /// Returns vertices as `[y0, x0, y1, x1, y2, x2]` in CCW order and boundary flags `[ab, bc,
    /// ca]`. The vertex order may differ from the original input to `new()` due to normalized CCW
    /// rotation.
    pub fn decode(&self) -> ([i32; 6], [bool; 3]) {
        let packed = self.words[6];
        let code = packed & 7; // Lower 3 bits
        let boundaries = [
            (packed & (1 << 3)) != 0, // bit 3 = ab
            (packed & (1 << 4)) != 0, // bit 4 = bc
            (packed & (1 << 5)) != 0, // bit 5 = ca
        ];
        // Each code maps the stored words back to the vertex roles its name
        // describes; see the constant definitions.
        let (ay, ax, by, bx, cy, cx) = match code {
            MINY_MINX_MAXY_MAXX_Y_X => (
                self.words[0],
                self.words[1],
                self.words[2],
                self.words[3],
                self.words[4],
                self.words[5],
            ),
            MINY_MINX_Y_X_MAXY_MAXX => (
                self.words[0],
                self.words[1],
                self.words[4],
                self.words[5],
                self.words[2],
                self.words[3],
            ),
            MAXY_MINX_Y_X_MINY_MAXX => (
                self.words[2],
                self.words[1],
                self.words[4],
                self.words[5],
                self.words[0],
                self.words[3],
            ),
            MAXY_MINX_MINY_MAXX_Y_X => (
                self.words[2],
                self.words[1],
                self.words[0],
                self.words[3],
                self.words[4],
                self.words[5],
            ),
            Y_MINX_MINY_X_MAXY_MAXX => (
                self.words[4],
                self.words[1],
                self.words[0],
                self.words[5],
                self.words[2],
                self.words[3],
            ),
            Y_MINX_MINY_MAXX_MAXY_X => (
                self.words[4],
                self.words[1],
                self.words[0],
                self.words[3],
                self.words[2],
                self.words[5],
            ),
            MAXY_MINX_MINY_X_Y_MAXX => (
                self.words[2],
                self.words[1],
                self.words[0],
                self.words[5],
                self.words[4],
                self.words[3],
            ),
            MINY_MINX_Y_MAXX_MAXY_X => (
                self.words[0],
                self.words[1],
                self.words[4],
                self.words[3],
                self.words[2],
                self.words[5],
            ),
            _ => panic!("Could not decode the provided triangle"),
        };
        ([ay, ax, by, bx, cy, cx], boundaries)
    }
    /// Returns the bounding box coordinates of the encoded triangle.
    ///
    /// Provides access to the bounding box `[min_y, min_x, max_y, max_x]` stored in the first four
    /// words of the structure. The bounding box is stored first for efficient spatial pruning,
    /// determining whether it is necessary to decode the triangle for precise intersection or
    /// containment tests.
    pub fn bbox(&self) -> &[i32] {
        &self.words[..4]
    }
}
/// Encodes the triangles of a Delaunay triangulation into block kd-tree triangles.
///
/// Takes the output of a Delaunay triangulation from `i_triangle` and encodes each triangle into
/// the normalized triangle used by the block kd-tree. Each triangle includes its bounding box,
/// vertex coordinates, and boundary edge flags that distinguish original polygon edges from
/// internal tessellation edges.
///
/// The boundary edge information provided by the `i_triangle` Delaunay triangulation is essential
/// for CONTAINS and WITHIN queries to work correctly.
pub fn delaunay_to_triangles(doc_id: u32, delaunay: &IntDelaunay, triangles: &mut Vec<Triangle>) {
    for delaunay_triangle in &delaunay.triangles {
        // An edge with no neighboring triangle lies on the polygon boundary.
        let boundaries = [
            delaunay_triangle.neighbors[0] == usize::MAX,
            delaunay_triangle.neighbors[1] == usize::MAX,
            delaunay_triangle.neighbors[2] == usize::MAX,
        ];
        let a = &delaunay.points[delaunay_triangle.vertices[0].index];
        let b = &delaunay.points[delaunay_triangle.vertices[1].index];
        let c = &delaunay.points[delaunay_triangle.vertices[2].index];
        let coords = [a.y, a.x, b.y, b.x, c.y, c.x];
        triangles.push(Triangle::new(doc_id, coords, boundaries));
    }
}
/// Returns true if the path A -> B -> C is Counter-Clockwise (CCW) or collinear.
/// Returns false if it is Clockwise (CW).
#[inline(always)]
fn is_counter_clockwise(a: IntPoint, b: IntPoint, c: IntPoint) -> bool {
    // 2D cross product (determinant) of vectors AB and AC:
    // (bx - ax)(cy - ay) - (by - ay)(cx - ax).
    // Computed in i64 to prevent overflow: multiplying two i32 deltas can
    // exceed i32::MAX.
    let ab_x = b.x as i64 - a.x as i64;
    let ab_y = b.y as i64 - a.y as i64;
    let ac_x = c.x as i64 - a.x as i64;
    let ac_y = c.y as i64 - a.y as i64;
    // Positive -> CCW, negative -> CW, zero -> collinear (treated as CCW).
    ab_x * ac_y - ab_y * ac_x >= 0
}
#[cfg(test)]
mod tests {
    use i_triangle::i_overlay::i_float::int::point::IntPoint;
    use i_triangle::int::triangulatable::IntTriangulatable;
    use super::*;
    #[test]
    fn encode_triangle() {
        // Every rotation/orientation of the same triangle must round-trip to
        // the same normalized representation (CCW, min-x vertex first).
        let test_cases = [
            ([1, 1, 3, 2, 2, 4], [true, false, false]),
            ([1, 1, 2, 4, 3, 2], [false, false, true]),
            ([2, 4, 1, 1, 3, 2], [false, true, false]),
            ([2, 4, 3, 2, 1, 1], [false, true, false]),
            ([3, 2, 1, 1, 2, 4], [true, false, false]),
            ([3, 2, 2, 4, 1, 1], [false, false, true]),
        ];
        let ccw_coords = [1, 1, 2, 4, 3, 2];
        let ccw_bounds = [false, false, true];
        for (coords, bounds) in test_cases {
            let triangle = Triangle::new(1, coords, bounds);
            let (decoded_coords, decoded_bounds) = triangle.decode();
            assert_eq!(decoded_coords, ccw_coords);
            assert_eq!(decoded_bounds, ccw_bounds);
        }
    }
    #[test]
    fn test_cw_triangle_boundary_and_coord_flip() {
        // 1. Define distinct coordinates for a Clockwise triangle
        // Visual layout:
        // A(50,40): Top Center-ish
        // B(10,60): Bottom Right
        // C(20,10): Bottom Left (Has the Minimum X=10)
        // Path A->B->C is Clockwise.
        let input_coords = [
            50, 40, // A (y, x)
            10, 60, // B
            20, 10, // C
        ];
        // 2. Define Boundaries [ab, bc, ca]
        // We set BC=true and CA=false.
        // The bug (ab=bc) would erroneously put 'true' into the first slot.
        // The fix (ab=ca) should put 'false' into the first slot.
        let input_bounds = [false, true, false];
        // 3. Encode
        let triangle = Triangle::new(1, input_coords, input_bounds);
        let (decoded_coords, decoded_bounds) = triangle.decode();
        // 4. Expected Coordinates
        // The internal logic detects CW, swaps B/C to make it CCW:
        // A(50,40) -> C(20,10) -> B(10,60)
        // Then it rotates to put Min-X first.
        // Min X is 10 (Vertex C).
        // Final Sequence: C -> B -> A
        let expected_coords = [
            20, 10, // C
            10, 60, // B
            50, 40, // A
        ];
        // 5. Expected Boundaries
        // After Flip (A->C->B):
        // Edge AC (was CA) = false
        // Edge CB (was BC) = true
        // Edge BA (was AB) = false
        // Unrotated: [false, true, false]
        // After Rotation (shifting to start at C):
        // Shift left by 1: [true, false, false]
        let expected_bounds = [true, false, false];
        assert_eq!(
            decoded_coords, expected_coords,
            "Coordinates did not decode as expected"
        );
        assert_eq!(
            decoded_bounds, expected_bounds,
            "Boundary flags were incorrect (likely swap bug)"
        );
    }
    #[test]
    fn degenerate_triangle() {
        // All points share the same longitude (x); the encoder must still
        // place a vertex of the bounding extremes first.
        // Tuples: (input coords, input bounds, expected coords, expected bounds).
        let test_cases = [
            (
                [1, 1, 2, 1, 3, 1],
                [true, false, false],
                [1, 1, 2, 1, 3, 1],
                [true, false, false],
            ),
            (
                [2, 1, 1, 1, 3, 1],
                [true, false, false],
                [1, 1, 3, 1, 2, 1],
                [false, false, true],
            ),
            (
                [2, 1, 3, 1, 1, 1],
                [false, false, true],
                [1, 1, 2, 1, 3, 1],
                [true, false, false],
            ),
        ];
        for (coords, bounds, ccw_coords, ccw_bounds) in test_cases {
            let triangle = Triangle::new(1, coords, bounds);
            let (decoded_coords, decoded_bounds) = triangle.decode();
            assert_eq!(decoded_coords, ccw_coords);
            assert_eq!(decoded_bounds, ccw_bounds);
        }
    }
    #[test]
    fn decode_triangle() {
        // distinct values for each coordinate to catch transposition
        // One case per reconstruction code (codes 0..=7), already normalized,
        // so encoding then decoding must be the identity.
        let test_cases = [
            [11, 10, 60, 80, 41, 40],
            [1, 0, 11, 20, 31, 30],
            [30, 0, 11, 10, 1, 20],
            [30, 0, 1, 20, 21, 11],
            [20, 0, 1, 30, 41, 40],
            [20, 0, 1, 30, 31, 10],
            [30, 0, 1, 10, 11, 20],
            [1, 0, 10, 20, 21, 11],
        ];
        for coords in test_cases {
            let triangle = Triangle::new(1, coords, [true, true, true]);
            let (decoded_coords, _) = triangle.decode();
            assert_eq!(decoded_coords, coords);
        }
    }
    #[test]
    fn triangulate_box() {
        // A square splits into exactly two triangles.
        let i_polygon = vec![vec![
            IntPoint::new(0, 0),
            IntPoint::new(10, 0),
            IntPoint::new(10, 10),
            IntPoint::new(0, 10),
        ]];
        let mut triangles = Vec::new();
        let delaunay = i_polygon.triangulate().into_delaunay();
        delaunay_to_triangles(1, &delaunay, &mut triangles);
        assert_eq!(triangles.len(), 2);
    }
}

View File

@@ -1,125 +0,0 @@
//! Writer buffering the tessellated triangles of every spatial field during indexing.
use std::collections::HashMap;
use std::io;
use i_triangle::i_overlay::i_float::int::point::IntPoint;
use i_triangle::int::triangulatable::IntTriangulatable;
use crate::schema::Field;
use crate::spatial::geometry::Geometry;
use crate::spatial::point::GeoPoint;
use crate::spatial::serializer::SpatialSerializer;
use crate::spatial::triangle::{delaunay_to_triangles, Triangle};
use crate::DocId;
/// Buffers the tessellated triangles of every spatial field of a segment
/// until they are handed to the [`SpatialSerializer`].
pub struct SpatialWriter {
    /// Map from field to its triangles buffer
    triangles_by_field: HashMap<Field, Vec<Triangle>>,
}
impl SpatialWriter {
    /// Tessellates `geometry` into triangles and buffers them for `field`.
    ///
    /// Points and line segments become degenerated triangles; polygons are
    /// triangulated via a Delaunay triangulation. Geometry collections are
    /// flattened recursively.
    pub fn add_geometry(&mut self, doc_id: DocId, field: Field, geometry: Geometry) {
        // `or_default` already yields `&mut Vec<Triangle>`; no extra `&mut` needed.
        let triangles = self.triangles_by_field.entry(field).or_default();
        match geometry {
            Geometry::Point(point) => {
                append_point(triangles, doc_id, point);
            }
            Geometry::MultiPoint(multi_point) => {
                for point in multi_point {
                    append_point(triangles, doc_id, point);
                }
            }
            Geometry::LineString(line_string) => {
                append_line_string(triangles, doc_id, line_string);
            }
            Geometry::MultiLineString(multi_line_string) => {
                for line_string in multi_line_string {
                    append_line_string(triangles, doc_id, line_string);
                }
            }
            Geometry::Polygon(polygon) => {
                append_polygon(triangles, doc_id, &polygon);
            }
            Geometry::MultiPolygon(multi_polygon) => {
                for polygon in multi_polygon {
                    append_polygon(triangles, doc_id, &polygon);
                }
            }
            Geometry::GeometryCollection(geometries) => {
                // `triangles` is dead in this arm, so the recursive call may
                // re-borrow `self` (non-lexical lifetimes end the entry borrow).
                for geometry in geometries {
                    self.add_geometry(doc_id, field, geometry);
                }
            }
        }
    }
    /// Memory usage estimate, in bytes.
    ///
    /// Based on `len`, so spare `Vec` capacity is not accounted for.
    pub fn mem_usage(&self) -> usize {
        self.triangles_by_field
            .values()
            .map(|triangles| triangles.len() * std::mem::size_of::<Triangle>())
            .sum()
    }
    /// Serializes every buffered field and closes the serializer.
    pub fn serialize(&mut self, mut serializer: SpatialSerializer) -> io::Result<()> {
        for (field, triangles) in &mut self.triangles_by_field {
            serializer.serialize_field(*field, triangles)?;
        }
        serializer.close()?;
        Ok(())
    }
}
impl Default for SpatialWriter {
    /// Creates a writer with no buffered triangles.
    fn default() -> Self {
        SpatialWriter {
            triangles_by_field: HashMap::default(),
        }
    }
}
/// Converts a `GeoPoint` to an integer point, returned as `(x, y)`.
///
/// Note the `(x, y)` = (longitude, latitude) order — the transpose of
/// `latlon_to_point`, which returns `(y, x)`.
pub fn as_point_i32(point: GeoPoint) -> (i32, i32) {
    // Same quantization as `latlon_to_point`: one step is 1/2³² of the range.
    let x = (point.lon / (360.0 / (1i64 << 32) as f64)).floor() as i32;
    let y = (point.lat / (180.0 / (1i64 << 32) as f64)).floor() as i32;
    (x, y)
}
// Buffers a single point as a degenerated (point) triangle.
fn append_point(triangles: &mut Vec<Triangle>, doc_id: DocId, point: GeoPoint) {
    let (x, y) = as_point_i32(point);
    triangles.push(Triangle::from_point(doc_id, x, y));
}
// Buffers one degenerated (segment) triangle per consecutive pair of points.
// An empty or single-point line string appends nothing (previously an empty
// input panicked on `line_string[0]`).
fn append_line_string(triangles: &mut Vec<Triangle>, doc_id: DocId, line_string: Vec<GeoPoint>) {
    if line_string.is_empty() {
        return;
    }
    let mut previous = as_point_i32(line_string[0]);
    for point in line_string.into_iter().skip(1) {
        let point = as_point_i32(point);
        triangles.push(Triangle::from_line_segment(
            doc_id, previous.0, previous.1, point.0, point.1,
        ));
        previous = point
    }
}
// Converts one polygon ring to integer points and appends it to `i_polygon`.
fn append_ring(i_polygon: &mut Vec<Vec<IntPoint>>, ring: &[GeoPoint]) {
    // NOTE(review): capacity reserves one extra slot — presumably for a
    // closing point added downstream; confirm against the triangulator.
    let mut converted = Vec::with_capacity(ring.len() + 1);
    converted.extend(ring.iter().map(|&geo_point| {
        let (x, y) = as_point_i32(geo_point);
        IntPoint::new(x, y)
    }));
    i_polygon.push(converted);
}
// Triangulates a polygon (outer ring + holes) and buffers the resulting
// triangles with their boundary-edge flags.
fn append_polygon(triangles: &mut Vec<Triangle>, doc_id: DocId, polygon: &[Vec<GeoPoint>]) {
    let mut i_polygon: Vec<Vec<IntPoint>> = Vec::with_capacity(polygon.len());
    for ring in polygon {
        append_ring(&mut i_polygon, ring);
    }
    let delaunay = i_polygon.triangulate().into_delaunay();
    delaunay_to_triangles(doc_id, &delaunay, triangles);
}

View File

@@ -1,136 +0,0 @@
//! XOR delta compression for f64 polygon coordinates.
//!
//! Lossless compression for floating-point lat/lon coordinates using XOR delta encoding on IEEE
//! 754 bit patterns with variable-length integer encoding. Designed for per-polygon random access
//! in the document store, where each polygon compresses independently without requiring sequential
//! decompression.
//!
//! Spatially local coordinates share most high-order bits. A municipal boundary spanning 1km has
//! consecutive vertices typically within 100-500 meters, meaning their f64 bit patterns share
//! 30-40 bits. XOR reveals these common bits as zeros, which varint encoding then compresses
//! efficiently.
//!
//! The format stores the first coordinate as raw 8 bytes, then XOR deltas between consecutive
//! coordinates encoded as variable-length integers. When compression produces larger output than
//! the raw input (random data, compression-hostile patterns), the function automatically falls
//! back to storing coordinates as uncompressed 8-byte values.
//!
//! Unlike delta.rs which uses arithmetic deltas for i32 spatial coordinates in the block kd-tree,
//! this module operates on f64 bit patterns directly to preserve exact floating-point values for
//! returning to users.
use std::io::Read;
use common::VInt;
/// Compresses f64 coordinates using XOR delta encoding with automatic raw fallback.
///
/// Stores the first coordinate as raw bits, then computes XOR between consecutive coordinate bit
/// patterns and encodes as variable-length integers. If the compressed output would be larger than
/// raw storage (8 bytes per coordinate), automatically falls back to raw encoding.
///
/// Returns a byte vector that can be decompressed with `decompress_f64()` to recover exact
/// original values.
pub fn compress_f64(values: &[f64]) -> Vec<u8> {
    if values.is_empty() {
        return Vec::new();
    }
    // First coordinate raw (8 bytes), the rest as varint-encoded XOR deltas.
    // Capacity bound: the fallback below guarantees we never exceed 8 bytes
    // per value in the returned vector.
    let mut output: Vec<u8> = Vec::with_capacity(values.len() * 8);
    let mut previous: u64 = f64_to_le(values[0]);
    output.extend_from_slice(&previous.to_le_bytes());
    for &value in &values[1..] {
        // Use `f64_to_le` consistently for the bit pattern (it equals `to_bits`).
        let bits = f64_to_le(value);
        VInt(bits ^ previous).serialize_into_vec(&mut output);
        previous = bits
    }
    // Fallback: if the XOR encoding is not strictly smaller than raw storage,
    // emit plain 8-byte values. `decompress_f64` detects this layout purely
    // from the byte length (count * 8).
    if output.len() >= values.len() * 8 {
        let mut raw = Vec::with_capacity(values.len() * 8);
        for &value in values {
            raw.extend_from_slice(&f64_to_le(value).to_le_bytes());
        }
        return raw;
    }
    output
}
/// Returns the IEEE 754 bit pattern of `value`.
///
/// Round-tripping through `to_le_bytes`/`from_le_bytes` is exactly `to_bits`
/// on every platform, so use the direct stdlib call.
fn f64_to_le(value: f64) -> u64 {
    value.to_bits()
}
/// Reconstructs an `f64` from its IEEE 754 bit pattern.
///
/// Inverse of `f64_to_le`; equivalent to the original byte round-trip.
fn f64_from_le(value: u64) -> f64 {
    f64::from_bits(value)
}
/// Decompresses f64 coordinates from XOR delta or raw encoding.
///
/// Detects compression format by byte length - if `bytes.len() == count * 8`, data is raw and
/// copied directly. Otherwise, reads first coordinate from 8 bytes, then XOR deltas as varints,
/// reconstructing the original sequence.
///
/// Returns exact f64 values that were passed to `compress_f64()`.
pub fn decompress_f64(bytes: &[u8], count: usize) -> Vec<f64> {
    // Raw fallback layout: exactly 8 bytes per coordinate.
    if bytes.len() == count * 8 {
        return bytes
            .chunks_exact(8)
            .map(|chunk| f64_from_le(u64::from_le_bytes(chunk.try_into().unwrap())))
            .collect();
    }
    let mut values = Vec::with_capacity(count);
    let mut cursor = bytes;
    // First coordinate is stored as raw little-endian bits.
    let mut first_bytes = [0u8; 8];
    cursor.read_exact(&mut first_bytes).unwrap();
    let mut previous = u64::from_le_bytes(first_bytes);
    values.push(f64::from_bits(previous));
    // Each following coordinate is the previous bit pattern XOR a varint delta.
    while values.len() < count {
        let xor = VInt::deserialize_u64(&mut cursor).unwrap();
        previous ^= xor;
        values.push(f64::from_bits(previous));
    }
    values
}
#[cfg(test)]
mod test {
    use super::*;
    #[test]
    fn test_compress_spatial_locality() {
        // Small town polygon - longitude only.
        // Neighboring coordinates share most high-order bits, so the XOR
        // deltas are small and varint-encode well below 8 bytes each.
        let longitudes = vec![
            40.7580, 40.7581, 40.7582, 40.7583, 40.7584, 40.7585, 40.7586, 40.7587,
        ];
        let bytes = compress_f64(&longitudes);
        // Should compress well - XOR deltas will be small
        assert_eq!(bytes.len(), 46);
        // Should decompress to exact original values
        let decompressed = decompress_f64(&bytes, longitudes.len());
        assert_eq!(longitudes, decompressed);
    }
    #[test]
    fn test_fallback_to_raw() {
        // Random, widely scattered values - poor compression
        let values = vec![
            12345.6789,
            -98765.4321,
            0.00001,
            999999.999,
            -0.0,
            std::f64::consts::PI,
            std::f64::consts::E,
            42.0,
        ];
        let bytes = compress_f64(&values);
        // Should fall back to raw storage (detected by length in decompress)
        assert_eq!(bytes.len(), values.len() * 8);
        // Should still decompress correctly
        let decompressed = decompress_f64(&bytes, values.len());
        assert_eq!(values, decompressed);
    }
}