Compare commits

...

66 Commits

Author SHA1 Message Date
Pascal Seitz
806a1e1b1e clarify tokenizer docs 2023-04-03 22:59:38 +08:00
PSeitz
5c4ea6a708 tokenizer option on text fastfield (#1945)
* tokenizer option on text fastfield

allow to set tokenizer option on text fastfield (fixes #1901)
handle PreTokenized strings in fast field

* change visibility

* remove custom de/serialization
2023-03-31 10:03:38 +02:00
PSeitz
4cf93dab7d fix build (#1973) 2023-03-31 13:54:03 +09:00
PSeitz
5c380b76e7 Better mixed types support in aggs and fix serialization issue (#1971)
* Better mixed types support in aggs and fix serialization issue

- Improve support for mixed types in JSON field aggregations (pick the right field, #1913)
- Resolve the issue with JSON serialization for numeric keys (fixes #1967)
- Add JSON round-trip test for term buckets
- Remove `u64_lenient`, as this is a footgun without the type
- move aggregation benchmarks

* remove shadowing
2023-03-31 05:52:11 +02:00
PSeitz
571735c5f7 Fix index sort by on optional/multicolumn (#1972)
Fix index sort by on optional/multicolumn
add optional columns to proptest
extend proptests for sort
add columnar sort tests
2023-03-31 04:24:11 +02:00
zhouhui
8e92f960d3 Fix comment: change max_merge_size to max_docs_before_merge. (#1970) 2023-03-28 22:49:00 +09:00
Paul Masurel
057211c3d8 Fixing build on arm (#1966) 2023-03-27 22:42:57 +09:00
Paul Masurel
059fc767ea Added ::MIN ::MAX DateTime. (#1965) 2023-03-27 15:32:53 +09:00
Paul Masurel
694a056255 Faster range (#1954)
* Faster range queries

This PR does several changes
- ip compact space now uses u32
- the bitunpacker now gets a get_batch function
- we push down range filtering, removing GCD / shift in the bitpacking
  codec.
- we rely on AVX2 routine to do the filtering.

* Apply suggestions from code review

* Apply suggestions from code review

* CR comments
2023-03-27 14:56:32 +09:00
Paul Masurel
2955e34452 Added proptests for building/merging columnar. (#1963) 2023-03-27 14:56:02 +09:00
Paul Masurel
821208480b Adding Debug/Display impl. Refining the ColumnIndex::get_cardinality 2023-03-26 14:40:37 +09:00
Paul Masurel
a2e3c2ed5b Renaming Column::idx -> Column::index (#1961)
There was some variable name ghosting happening.
2023-03-26 13:58:50 +09:00
PSeitz
835f228bfa fix cardinality when merging empty columns (#1960)
fixes #1958
2023-03-25 15:58:15 +09:00
Paul Masurel
2b6a4da640 Exposing empty column builder. (#1959) 2023-03-24 16:34:41 +09:00
PSeitz
d6a95381ee add memory check for term agg (#1957) 2023-03-24 06:47:45 +01:00
PSeitz
da2804644f fetch blocks of vals in aggregation for all cardinality (#1950)
* fetch blocks of vals in aggregation for all cardinality

* move caching in common accessor
2023-03-23 08:41:11 +01:00
PSeitz
5504cfd012 remove IterColumn (#1955)
fixes #1658
2023-03-23 06:43:17 +01:00
trinity-1686a
482b4155e8 fix bug with new sstable index format (#1953) 2023-03-22 10:22:36 +01:00
Till Wegmüller
1a35f6573d Switch fs2 to fs4 as it is now unmaintained and does not support illumos (#1944)
Signed-off-by: Till Wegmueller <toasterson@gmail.com>
2023-03-22 13:48:49 +09:00
trinity-1686a
e5e50603a8 new sstable format (#1943)
* document a new sstable format

* add support for changing target block size

* use new format for sstable index

* handle sstable version errror

* use very small blocks for proptests

* add a footer structure
2023-03-21 15:03:52 +01:00
PSeitz
8f7f1d6be4 add Display for ByteCount (#1949)
* add Display for ByteCount

* export missing AggregationLimits
2023-03-21 08:02:35 +01:00
PSeitz
6a7a1106d6 work in batches of docs (#1937)
* work in batches of docs

* add fill_buffer test
2023-03-21 06:57:44 +01:00
PSeitz
9e2faecf5b add memory limit for aggregations (#1942)
* add memory limit for aggregations

introduce AggregationLimits to set memory consumption limit and bucket limits
memory limit is checked during aggregation, bucket limit is checked before returning the aggregation request.

* Apply suggestions from code review

Co-authored-by: Paul Masurel <paul@quickwit.io>

* add ByteCount with human readable format

---------

Co-authored-by: Paul Masurel <paul@quickwit.io>
2023-03-16 06:21:07 +01:00
PSeitz
b6703f1b3c fix validation in date histogram (#1936)
fix validation in date histogram for parameters interval and date_interval
2023-03-15 06:10:43 +01:00
PSeitz
2fb3740cb0 handle missing column for aggs (#1920)
* handle missing column for aggs

add empty column fallback for missing column in aggs.
Fix sort for term agg on sub-agg with missing value (null is smallest)

* add error when field is not fast
2023-03-15 06:09:59 +01:00
PSeitz
8459efa32c split term collection count and sub_agg (#1921)
use unrolled ColumnValues::get_vals
2023-03-13 04:37:41 +01:00
PSeitz
61cfd8dc57 fix clippy (#1927) 2023-03-13 03:12:02 +01:00
trinity-1686a
064518156f refactor tokenization pipeline to use GATs (#1924)
* refactor tokenization pipeline to use GATs

* fix doctests

* fix clippy lints

* remove commented code
2023-03-09 09:39:37 +01:00
PSeitz
a42a96f470 fix panic in dict column merge (#1930)
* fix panic in dict column merge

* Bugfix and added unit test

---------

Co-authored-by: Paul Masurel <paul@quickwit.io>
2023-03-08 22:04:37 +09:00
trinity-1686a
fcf5a25d93 use DeltaReader directly to implement Dictionnary::ord_to_term (#1928) 2023-03-08 11:15:56 +09:00
dependabot[bot]
c0a5b28fd3 Update lru requirement from 0.9.0 to 0.10.0 (#1932)
Updates the requirements on [lru](https://github.com/jeromefroe/lru-rs) to permit the latest version.
- [Release notes](https://github.com/jeromefroe/lru-rs/releases)
- [Changelog](https://github.com/jeromefroe/lru-rs/blob/master/CHANGELOG.md)
- [Commits](https://github.com/jeromefroe/lru-rs/compare/0.9.0...0.10.0)

---
updated-dependencies:
- dependency-name: lru
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-07 15:09:02 +09:00
trinity-1686a
a4f7ca8309 use DeltaReader directly to implement Dictionnary::term_ord (#1925)
* use DeltaReader directly to implement Dictionnary::term_ord

* add some additional test case for Dictionary::term_ord
2023-03-06 09:45:22 +01:00
Paul Masurel
364e321415 Clippy fix (#1926) 2023-03-06 10:37:17 +09:00
Paul Masurel
ed5a3b3172 Bumped murmurhash version 2023-03-03 21:24:32 +09:00
PSeitz
ca20bfa776 add date_histogram (#1900)
* add date_histogram

* add return result
2023-03-02 05:17:35 +01:00
PSeitz
faa706d804 add coerce option for text and numbers types (#1904)
* add coerce option for text and numbers types

allow to coerce the field type when indexing if the type does not match

* Apply suggestions from code review

Co-authored-by: Paul Masurel <paul@quickwit.io>

* add tests,add COERCE flag, include bool in coercion

---------

Co-authored-by: Paul Masurel <paul@quickwit.io>
2023-03-01 11:36:59 +01:00
PSeitz
850a0d7ae2 add agg benchmark for optional and multi value (#1916)
closes #1870
2023-03-01 17:01:52 +09:00
Paul Masurel
7fae4d98d7 Adapting for quickwit2 (#1912)
* Adapting tantivy to make it possible to be plugged to quickwit.

* Apply suggestions from code review

Co-authored-by: PSeitz <PSeitz@users.noreply.github.com>

* Added unit test

---------

Co-authored-by: PSeitz <PSeitz@users.noreply.github.com>
2023-03-01 16:27:46 +09:00
PSeitz
bc36458334 move buffer in front of dynamic dispatch (#1915)
dynamic dispatch seems to be really expensive, move the buffer in front of the dynamic dispatch, to reduce the number of calls into the dynamic dispatched collector.
2023-02-28 13:07:50 +08:00
trinity-1686a
8a71e00da3 allow limiting the number of matched term in range query (#1899) 2023-02-27 10:44:08 +01:00
PSeitz
e510f699c8 feat: add support for u64,i64,f64 fields in term aggregation (#1883)
* feat: add support for u64,i64,f64 fields in term aggregation

* hash enum values

* fix build

* Apply suggestions from code review

Co-authored-by: Paul Masurel <paul@quickwit.io>

---------

Co-authored-by: Paul Masurel <paul@quickwit.io>
2023-02-27 15:04:41 +08:00
Paul Masurel
d25fc155b2 Making some of the column/termdict operations async-friendly (#1902) 2023-02-27 15:34:47 +09:00
Paul Masurel
8ea97e7d6b Minor refactoring preparing for getting columnar integrated in quickwit. (#1911) 2023-02-27 14:23:30 +09:00
Paul Masurel
0a726a0897 Added Empty ColumnIndex (#1910) 2023-02-27 13:59:22 +09:00
Paul Masurel
66ff53b0f4 Various minor code cleanup (#1909) 2023-02-27 13:48:34 +09:00
Paul Masurel
d002698008 Re-export of query grammar. (#1908) 2023-02-27 12:26:34 +09:00
Paul Masurel
c838aa808b Removedc the extra nesting in unit test file (#1907) 2023-02-27 12:17:52 +09:00
Paul Masurel
06850719dc Renaming .values(DocId) to .values_for_doc(DocId) (#1906) 2023-02-27 12:15:13 +09:00
PSeitz
5f23bb7e65 switch to sparse collection for histogram (#1898)
* switch to sparse collection for histogram

Replaces histogram vec collection with a hashmap. This approach works much better for sparse data and enables use cases like drill downs (filter + small interval).
It is slower for dense cases (1.3x-2x slower). This can be alleviated with a specialized hashmap in the future.
closes #1704
closes #1370

* refactor, clippy

* fix bucket_pos overflow issue
2023-02-23 07:02:58 +01:00
trinity-1686a
533ad99cd5 add PhrasePrefixQuery (#1842)
* add PhrasePrefixQuery
2023-02-22 11:18:33 +01:00
PSeitz
c7278b3258 remove schema in aggs (#1888)
* switch to ColumnType, move tests

* remove Schema dependency in agg
2023-02-22 04:50:28 +01:00
Paul Masurel
6b403e3281 Re-export of columnar 2023-02-22 11:23:54 +09:00
Paul Masurel
789cc8703e Adding unit test testing docfreq after merge (#1895) 2023-02-22 11:05:34 +09:00
Paul Masurel
e5098d9fe8 Moving test around reenabling tests that were disabled. (#1894) 2023-02-22 10:31:52 +09:00
Paul Masurel
f537334e4f Adding a write schema to columnar's merge operations. (#1884)
* Adding a write schema to columnar's merge operations.

* Added unit test checking min/max when columns are empty.

* CR comment

* Rename to value_type_to_column_type
2023-02-21 18:25:16 +09:00
Paul Masurel
e2aa5af075 Clippy warnings fixes (#1885) 2023-02-20 19:04:13 +09:00
Paul Masurel
02bebf4ff5 Cargo fmt 2023-02-20 09:40:04 +09:00
Paul Masurel
0274c982d5 Refactoring. (#1881)
`ColumnValues` wrongly located in column_values/column.rs due to
historical reason moves to column_values/mod.rs

u128 stuff gets its own directory like u64 stuff.
2023-02-17 21:57:14 +09:00
PSeitz
74bf60b4f7 implement SegmentAggregationCollector on bucket aggs (#1878) 2023-02-17 12:53:29 +01:00
PSeitz
bf1449b22d update examples for literate docs (#1880) 2023-02-17 11:48:22 +01:00
PSeitz
111f25a8f7 clippy (#1879)
* fix clippy

* fix clippy

* fmt
2023-02-17 11:34:21 +01:00
PSeitz
019db10e8e refactor aggregations (#1875)
* add specialized version for full cardinality

Pre Columnar
test aggregation::tests::bench::bench_aggregation_average_u64                                                            ... bench:   6,681,850 ns/iter (+/- 1,217,385)
test aggregation::tests::bench::bench_aggregation_average_u64_and_f64                                                    ... bench:  10,576,327 ns/iter (+/- 494,380)

Current
test aggregation::tests::bench::bench_aggregation_average_u64                                                            ... bench:  11,562,084 ns/iter (+/- 3,678,682)
test aggregation::tests::bench::bench_aggregation_average_u64_and_f64                                                    ... bench:  18,925,790 ns/iter (+/- 17,616,771)

Post Change
test aggregation::tests::bench::bench_aggregation_average_u64                                                            ... bench:   9,123,811 ns/iter (+/- 399,720)
test aggregation::tests::bench::bench_aggregation_average_u64_and_f64                                                    ... bench:  13,111,825 ns/iter (+/- 273,547)

* refactor aggregation collection

* add buffering collector
2023-02-16 13:15:16 +01:00
Paul Masurel
7423f99719 Issue/columnar for json (#1876)
Adding support for JSON fast field.
2023-02-16 20:38:32 +09:00
Alex Cole
f2f38c43ce Make BM25 scoring more flexible (#1855)
* Introduce Bm25StatisticsProvider to inject statistics

* fix formatting I accidentally changed
2023-02-16 19:14:12 +09:00
PSeitz
71f43ace1d fix dynamic dispatch regression for range queries (#1871) 2023-02-14 16:56:40 +01:00
PSeitz
347614c841 test error for avg agg on ip field (#1873)
closes #1835
2023-02-14 23:22:56 +08:00
186 changed files with 11470 additions and 5085 deletions

2
.gitignore vendored
View File

@@ -13,3 +13,5 @@ benchmark
.idea
trace.dat
cargo-timing*
control
variable

View File

@@ -32,7 +32,7 @@ log = "0.4.16"
serde = { version = "1.0.136", features = ["derive"] }
serde_json = "1.0.79"
num_cpus = "1.13.1"
fs2 = { version = "0.4.3", optional = true }
fs4 = { version = "0.6.3", optional = true }
levenshtein_automata = "0.2.1"
uuid = { version = "1.0.0", features = ["v4", "serde"] }
crossbeam-channel = "0.5.4"
@@ -44,11 +44,11 @@ rustc-hash = "1.1.0"
thiserror = "1.0.30"
htmlescape = "0.3.1"
fail = "0.5.0"
murmurhash32 = "0.2.0"
murmurhash32 = "0.3.0"
time = { version = "0.3.10", features = ["serde-well-known"] }
smallvec = "1.8.0"
rayon = "1.5.2"
lru = "0.9.0"
lru = "0.10.0"
fastdivide = "0.4.0"
itertools = "0.10.3"
measure_time = "0.8.2"
@@ -58,7 +58,7 @@ arc-swap = "1.5.0"
columnar = { version="0.1", path="./columnar", package ="tantivy-columnar" }
sstable = { version="0.1", path="./sstable", package ="tantivy-sstable", optional = true }
stacker = { version="0.1", path="./stacker", package ="tantivy-stacker" }
tantivy-query-grammar = { version= "0.19.0", path="./query-grammar" }
query-grammar = { version= "0.19.0", path="./query-grammar", package = "tantivy-query-grammar" }
tantivy-bitpacker = { version= "0.3", path="./bitpacker" }
common = { version= "0.5", path = "./common/", package = "tantivy-common" }
tokenizer-api = { version="0.1", path="./tokenizer-api", package="tantivy-tokenizer-api" }
@@ -77,6 +77,7 @@ test-log = "0.2.10"
env_logger = "0.10.0"
pprof = { version = "0.11.0", features = ["flamegraph", "criterion"] }
futures = "0.3.21"
paste = "1.0.11"
[dev-dependencies.fail]
version = "0.5.0"
@@ -93,7 +94,7 @@ overflow-checks = true
[features]
default = ["mmap", "stopwords", "lz4-compression"]
mmap = ["fs2", "tempfile", "memmap2"]
mmap = ["fs4", "tempfile", "memmap2"]
stopwords = []
brotli-compression = ["brotli"]

View File

@@ -15,6 +15,7 @@ homepage = "https://github.com/quickwit-oss/tantivy"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
bitpacking = {version="0.8", default-features=false, features = ["bitpacker1x"]}
[dev-dependencies]
rand = "0.8"

View File

@@ -1,10 +1,14 @@
use std::convert::TryInto;
use std::io;
use std::ops::{Range, RangeInclusive};
use bitpacking::{BitPacker as ExternalBitPackerTrait, BitPacker1x};
pub struct BitPacker {
mini_buffer: u64,
mini_buffer_written: usize,
}
impl Default for BitPacker {
fn default() -> Self {
BitPacker::new()
@@ -118,6 +122,125 @@ impl BitUnpacker {
let val_shifted = val_unshifted_unmasked >> bit_shift;
val_shifted & self.mask
}
// Decodes the range of bitpacked `u32` values with idx
// in [start_idx, start_idx + output.len()).
//
// #Panics
//
// This methods panics if `num_bits` is > 32.
fn get_batch_u32s(&self, start_idx: u32, data: &[u8], output: &mut [u32]) {
assert!(
self.bit_width() <= 32,
"Bitwidth must be <= 32 to use this method."
);
let end_idx = start_idx + output.len() as u32;
let end_bit_read = end_idx * self.num_bits;
let end_byte_read = (end_bit_read + 7) / 8;
assert!(
end_byte_read as usize <= data.len(),
"Requested index is out of bounds."
);
// Simple slow implementation of get_batch_u32s, to deal with our ramps.
let get_batch_ramp = |start_idx: u32, output: &mut [u32]| {
for (out, idx) in output.iter_mut().zip(start_idx..) {
*out = self.get(idx, data) as u32;
}
};
// We use an unrolled routine to decode 32 values at once.
// We therefore decompose our range of values to decode into three ranges:
// - Entrance ramp: [start_idx, fast_track_start) (up to 31 values)
// - Highway: [fast_track_start, fast_track_end) (a length multiple of 32s)
// - Exit ramp: [fast_track_end, start_idx + output.len()) (up to 31 values)
// We want the start of the fast track to start align with bytes.
// A sufficient condition is to start with an idx that is a multiple of 8,
// so highway start is the closest multiple of 8 that is >= start_idx.
let entrance_ramp_len = 8 - (start_idx % 8) % 8;
let highway_start: u32 = start_idx + entrance_ramp_len;
if highway_start + BitPacker1x::BLOCK_LEN as u32 > end_idx {
// We don't have enough values to have even a single block of highway.
// Let's just supply the values the simple way.
get_batch_ramp(start_idx, output);
return;
}
let num_blocks: u32 = (end_idx - highway_start) / BitPacker1x::BLOCK_LEN as u32;
// Entrance ramp
get_batch_ramp(start_idx, &mut output[..entrance_ramp_len as usize]);
// Highway
let mut offset = (highway_start * self.num_bits) as usize / 8;
let mut output_cursor = (highway_start - start_idx) as usize;
for _ in 0..num_blocks {
offset += BitPacker1x.decompress(
&data[offset..],
&mut output[output_cursor..],
self.num_bits as u8,
);
output_cursor += 32;
}
// Exit ramp
let highway_end = highway_start + num_blocks * BitPacker1x::BLOCK_LEN as u32;
get_batch_ramp(highway_end, &mut output[output_cursor..]);
}
pub fn get_ids_for_value_range(
&self,
range: RangeInclusive<u64>,
id_range: Range<u32>,
data: &[u8],
positions: &mut Vec<u32>,
) {
if self.bit_width() > 32 {
self.get_ids_for_value_range_slow(range, id_range, data, positions)
} else {
if *range.start() > u32::MAX as u64 {
positions.clear();
return;
}
let range_u32 = (*range.start() as u32)..=(*range.end()).min(u32::MAX as u64) as u32;
self.get_ids_for_value_range_fast(range_u32, id_range, data, positions)
}
}
fn get_ids_for_value_range_slow(
&self,
range: RangeInclusive<u64>,
id_range: Range<u32>,
data: &[u8],
positions: &mut Vec<u32>,
) {
positions.clear();
for i in id_range {
// If we cared we could make this branchless, but the slow implementation should rarely
// kick in.
let val = self.get(i, data);
if range.contains(&val) {
positions.push(i);
}
}
}
fn get_ids_for_value_range_fast(
&self,
value_range: RangeInclusive<u32>,
id_range: Range<u32>,
data: &[u8],
positions: &mut Vec<u32>,
) {
positions.resize(id_range.len(), 0u32);
self.get_batch_u32s(id_range.start, data, positions);
crate::filter_vec::filter_vec_in_place(value_range, id_range.start, positions)
}
}
#[cfg(test)]
@@ -200,4 +323,58 @@ mod test {
test_bitpacker_aux(num_bits, &vals);
}
}
#[test]
#[should_panic]
fn test_get_batch_panics_over_32_bits() {
let bitunpacker = BitUnpacker::new(33);
let mut output: [u32; 1] = [0u32];
bitunpacker.get_batch_u32s(0, &[0, 0, 0, 0, 0, 0, 0, 0], &mut output[..]);
}
#[test]
fn test_get_batch_limit() {
let bitunpacker = BitUnpacker::new(1);
let mut output: [u32; 3] = [0u32, 0u32, 0u32];
bitunpacker.get_batch_u32s(8 * 4 - 3, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
}
#[test]
#[should_panic]
fn test_get_batch_panics_when_off_scope() {
let bitunpacker = BitUnpacker::new(1);
let mut output: [u32; 3] = [0u32, 0u32, 0u32];
// We are missing exactly one bit.
bitunpacker.get_batch_u32s(8 * 4 - 2, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
}
proptest::proptest! {
#[test]
fn test_get_batch_u32s_proptest(num_bits in 0u8..=32u8) {
let mask =
if num_bits == 32u8 {
u32::MAX
} else {
(1u32 << num_bits) - 1
};
let mut buffer: Vec<u8> = Vec::new();
let mut bitpacker = BitPacker::new();
for val in 0..100 {
bitpacker.write(val & mask as u64, num_bits, &mut buffer).unwrap();
}
bitpacker.flush(&mut buffer).unwrap();
let bitunpacker = BitUnpacker::new(num_bits);
let mut output: Vec<u32> = Vec::new();
for len in [0, 1, 2, 32, 33, 34, 64] {
for start_idx in 0u32..32u32 {
output.resize(len as usize, 0);
bitunpacker.get_batch_u32s(start_idx, &buffer, &mut output);
for i in 0..len {
let expected = (start_idx + i as u32) & mask;
assert_eq!(output[i], expected);
}
}
}
}
}
}

View File

@@ -0,0 +1,365 @@
//! SIMD filtering of a vector as described in the following blog post.
//! https://quickwit.io/blog/filtering%20a%20vector%20with%20simd%20instructions%20avx-2%20and%20avx-512
use std::arch::x86_64::{
__m256i as DataType, _mm256_add_epi32 as op_add, _mm256_cmpgt_epi32 as op_greater,
_mm256_lddqu_si256 as load_unaligned, _mm256_or_si256 as op_or, _mm256_set1_epi32 as set1,
_mm256_storeu_si256 as store_unaligned, _mm256_xor_si256 as op_xor, *,
};
use std::ops::RangeInclusive;
const NUM_LANES: usize = 8;
const HIGHEST_BIT: u32 = 1 << 31;
#[inline]
fn u32_to_i32(val: u32) -> i32 {
(val ^ HIGHEST_BIT) as i32
}
#[inline]
unsafe fn u32_to_i32_avx2(vals_u32x8s: DataType) -> DataType {
const HIGHEST_BIT_MASK: DataType = from_u32x8([HIGHEST_BIT; NUM_LANES]);
op_xor(vals_u32x8s, HIGHEST_BIT_MASK)
}
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
// We use a monotonic mapping from u32 to i32 to make the comparison possible in AVX2.
let range_i32: RangeInclusive<i32> = u32_to_i32(*range.start())..=u32_to_i32(*range.end());
let num_words = output.len() / NUM_LANES;
let mut output_len = unsafe {
filter_vec_avx2_aux(
output.as_ptr() as *const __m256i,
range_i32,
output.as_mut_ptr(),
offset,
num_words,
)
};
let reminder_start = num_words * NUM_LANES;
for i in reminder_start..output.len() {
let val = output[i];
output[output_len] = offset + i as u32;
output_len += if range.contains(&val) { 1 } else { 0 };
}
output.truncate(output_len);
}
#[target_feature(enable = "avx2")]
unsafe fn filter_vec_avx2_aux(
mut input: *const __m256i,
range: RangeInclusive<i32>,
output: *mut u32,
offset: u32,
num_words: usize,
) -> usize {
let mut output_tail = output;
let range_simd = set1(*range.start())..=set1(*range.end());
let mut ids = from_u32x8([
offset,
offset + 1,
offset + 2,
offset + 3,
offset + 4,
offset + 5,
offset + 6,
offset + 7,
]);
const SHIFT: __m256i = from_u32x8([NUM_LANES as u32; NUM_LANES]);
for _ in 0..num_words {
let word = load_unaligned(input);
let word = u32_to_i32_avx2(word);
let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
let added_len = keeper_bitset.count_ones();
let filtered_doc_ids = compact(ids, keeper_bitset);
store_unaligned(output_tail as *mut __m256i, filtered_doc_ids);
output_tail = output_tail.offset(added_len as isize);
ids = op_add(ids, SHIFT);
input = input.offset(1);
}
output_tail.offset_from(output) as usize
}
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn compact(data: DataType, mask: u8) -> DataType {
let vperm_mask = MASK_TO_PERMUTATION[mask as usize];
_mm256_permutevar8x32_epi32(data, vperm_mask)
}
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn compute_filter_bitset(val: __m256i, range: std::ops::RangeInclusive<__m256i>) -> u8 {
let too_low = op_greater(*range.start(), val);
let too_high = op_greater(val, *range.end());
let inside = op_or(too_low, too_high);
255 - std::arch::x86_64::_mm256_movemask_ps(std::mem::transmute::<DataType, __m256>(inside))
as u8
}
union U8x32 {
vector: DataType,
vals: [u32; NUM_LANES],
}
const fn from_u32x8(vals: [u32; NUM_LANES]) -> DataType {
unsafe { U8x32 { vals }.vector }
}
const MASK_TO_PERMUTATION: [DataType; 256] = [
from_u32x8([0, 0, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 0, 0, 0, 0, 0, 0, 0]),
from_u32x8([1, 0, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 0, 0, 0, 0, 0, 0]),
from_u32x8([2, 0, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 2, 0, 0, 0, 0, 0, 0]),
from_u32x8([1, 2, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 2, 0, 0, 0, 0, 0]),
from_u32x8([3, 0, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 3, 0, 0, 0, 0, 0, 0]),
from_u32x8([1, 3, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 3, 0, 0, 0, 0, 0]),
from_u32x8([2, 3, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 2, 3, 0, 0, 0, 0, 0]),
from_u32x8([1, 2, 3, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 2, 3, 0, 0, 0, 0]),
from_u32x8([4, 0, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 4, 0, 0, 0, 0, 0, 0]),
from_u32x8([1, 4, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 4, 0, 0, 0, 0, 0]),
from_u32x8([2, 4, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 2, 4, 0, 0, 0, 0, 0]),
from_u32x8([1, 2, 4, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 2, 4, 0, 0, 0, 0]),
from_u32x8([3, 4, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 3, 4, 0, 0, 0, 0, 0]),
from_u32x8([1, 3, 4, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 3, 4, 0, 0, 0, 0]),
from_u32x8([2, 3, 4, 0, 0, 0, 0, 0]),
from_u32x8([0, 2, 3, 4, 0, 0, 0, 0]),
from_u32x8([1, 2, 3, 4, 0, 0, 0, 0]),
from_u32x8([0, 1, 2, 3, 4, 0, 0, 0]),
from_u32x8([5, 0, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 5, 0, 0, 0, 0, 0, 0]),
from_u32x8([1, 5, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 5, 0, 0, 0, 0, 0]),
from_u32x8([2, 5, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 2, 5, 0, 0, 0, 0, 0]),
from_u32x8([1, 2, 5, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 2, 5, 0, 0, 0, 0]),
from_u32x8([3, 5, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 3, 5, 0, 0, 0, 0, 0]),
from_u32x8([1, 3, 5, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 3, 5, 0, 0, 0, 0]),
from_u32x8([2, 3, 5, 0, 0, 0, 0, 0]),
from_u32x8([0, 2, 3, 5, 0, 0, 0, 0]),
from_u32x8([1, 2, 3, 5, 0, 0, 0, 0]),
from_u32x8([0, 1, 2, 3, 5, 0, 0, 0]),
from_u32x8([4, 5, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 4, 5, 0, 0, 0, 0, 0]),
from_u32x8([1, 4, 5, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 4, 5, 0, 0, 0, 0]),
from_u32x8([2, 4, 5, 0, 0, 0, 0, 0]),
from_u32x8([0, 2, 4, 5, 0, 0, 0, 0]),
from_u32x8([1, 2, 4, 5, 0, 0, 0, 0]),
from_u32x8([0, 1, 2, 4, 5, 0, 0, 0]),
from_u32x8([3, 4, 5, 0, 0, 0, 0, 0]),
from_u32x8([0, 3, 4, 5, 0, 0, 0, 0]),
from_u32x8([1, 3, 4, 5, 0, 0, 0, 0]),
from_u32x8([0, 1, 3, 4, 5, 0, 0, 0]),
from_u32x8([2, 3, 4, 5, 0, 0, 0, 0]),
from_u32x8([0, 2, 3, 4, 5, 0, 0, 0]),
from_u32x8([1, 2, 3, 4, 5, 0, 0, 0]),
from_u32x8([0, 1, 2, 3, 4, 5, 0, 0]),
from_u32x8([6, 0, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 6, 0, 0, 0, 0, 0, 0]),
from_u32x8([1, 6, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 6, 0, 0, 0, 0, 0]),
from_u32x8([2, 6, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 2, 6, 0, 0, 0, 0, 0]),
from_u32x8([1, 2, 6, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 2, 6, 0, 0, 0, 0]),
from_u32x8([3, 6, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 3, 6, 0, 0, 0, 0, 0]),
from_u32x8([1, 3, 6, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 3, 6, 0, 0, 0, 0]),
from_u32x8([2, 3, 6, 0, 0, 0, 0, 0]),
from_u32x8([0, 2, 3, 6, 0, 0, 0, 0]),
from_u32x8([1, 2, 3, 6, 0, 0, 0, 0]),
from_u32x8([0, 1, 2, 3, 6, 0, 0, 0]),
from_u32x8([4, 6, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 4, 6, 0, 0, 0, 0, 0]),
from_u32x8([1, 4, 6, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 4, 6, 0, 0, 0, 0]),
from_u32x8([2, 4, 6, 0, 0, 0, 0, 0]),
from_u32x8([0, 2, 4, 6, 0, 0, 0, 0]),
from_u32x8([1, 2, 4, 6, 0, 0, 0, 0]),
from_u32x8([0, 1, 2, 4, 6, 0, 0, 0]),
from_u32x8([3, 4, 6, 0, 0, 0, 0, 0]),
from_u32x8([0, 3, 4, 6, 0, 0, 0, 0]),
from_u32x8([1, 3, 4, 6, 0, 0, 0, 0]),
from_u32x8([0, 1, 3, 4, 6, 0, 0, 0]),
from_u32x8([2, 3, 4, 6, 0, 0, 0, 0]),
from_u32x8([0, 2, 3, 4, 6, 0, 0, 0]),
from_u32x8([1, 2, 3, 4, 6, 0, 0, 0]),
from_u32x8([0, 1, 2, 3, 4, 6, 0, 0]),
from_u32x8([5, 6, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 5, 6, 0, 0, 0, 0, 0]),
from_u32x8([1, 5, 6, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 5, 6, 0, 0, 0, 0]),
from_u32x8([2, 5, 6, 0, 0, 0, 0, 0]),
from_u32x8([0, 2, 5, 6, 0, 0, 0, 0]),
from_u32x8([1, 2, 5, 6, 0, 0, 0, 0]),
from_u32x8([0, 1, 2, 5, 6, 0, 0, 0]),
from_u32x8([3, 5, 6, 0, 0, 0, 0, 0]),
from_u32x8([0, 3, 5, 6, 0, 0, 0, 0]),
from_u32x8([1, 3, 5, 6, 0, 0, 0, 0]),
from_u32x8([0, 1, 3, 5, 6, 0, 0, 0]),
from_u32x8([2, 3, 5, 6, 0, 0, 0, 0]),
from_u32x8([0, 2, 3, 5, 6, 0, 0, 0]),
from_u32x8([1, 2, 3, 5, 6, 0, 0, 0]),
from_u32x8([0, 1, 2, 3, 5, 6, 0, 0]),
from_u32x8([4, 5, 6, 0, 0, 0, 0, 0]),
from_u32x8([0, 4, 5, 6, 0, 0, 0, 0]),
from_u32x8([1, 4, 5, 6, 0, 0, 0, 0]),
from_u32x8([0, 1, 4, 5, 6, 0, 0, 0]),
from_u32x8([2, 4, 5, 6, 0, 0, 0, 0]),
from_u32x8([0, 2, 4, 5, 6, 0, 0, 0]),
from_u32x8([1, 2, 4, 5, 6, 0, 0, 0]),
from_u32x8([0, 1, 2, 4, 5, 6, 0, 0]),
from_u32x8([3, 4, 5, 6, 0, 0, 0, 0]),
from_u32x8([0, 3, 4, 5, 6, 0, 0, 0]),
from_u32x8([1, 3, 4, 5, 6, 0, 0, 0]),
from_u32x8([0, 1, 3, 4, 5, 6, 0, 0]),
from_u32x8([2, 3, 4, 5, 6, 0, 0, 0]),
from_u32x8([0, 2, 3, 4, 5, 6, 0, 0]),
from_u32x8([1, 2, 3, 4, 5, 6, 0, 0]),
from_u32x8([0, 1, 2, 3, 4, 5, 6, 0]),
from_u32x8([7, 0, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 7, 0, 0, 0, 0, 0, 0]),
from_u32x8([1, 7, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 7, 0, 0, 0, 0, 0]),
from_u32x8([2, 7, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 2, 7, 0, 0, 0, 0, 0]),
from_u32x8([1, 2, 7, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 2, 7, 0, 0, 0, 0]),
from_u32x8([3, 7, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 3, 7, 0, 0, 0, 0, 0]),
from_u32x8([1, 3, 7, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 3, 7, 0, 0, 0, 0]),
from_u32x8([2, 3, 7, 0, 0, 0, 0, 0]),
from_u32x8([0, 2, 3, 7, 0, 0, 0, 0]),
from_u32x8([1, 2, 3, 7, 0, 0, 0, 0]),
from_u32x8([0, 1, 2, 3, 7, 0, 0, 0]),
from_u32x8([4, 7, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 4, 7, 0, 0, 0, 0, 0]),
from_u32x8([1, 4, 7, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 4, 7, 0, 0, 0, 0]),
from_u32x8([2, 4, 7, 0, 0, 0, 0, 0]),
from_u32x8([0, 2, 4, 7, 0, 0, 0, 0]),
from_u32x8([1, 2, 4, 7, 0, 0, 0, 0]),
from_u32x8([0, 1, 2, 4, 7, 0, 0, 0]),
from_u32x8([3, 4, 7, 0, 0, 0, 0, 0]),
from_u32x8([0, 3, 4, 7, 0, 0, 0, 0]),
from_u32x8([1, 3, 4, 7, 0, 0, 0, 0]),
from_u32x8([0, 1, 3, 4, 7, 0, 0, 0]),
from_u32x8([2, 3, 4, 7, 0, 0, 0, 0]),
from_u32x8([0, 2, 3, 4, 7, 0, 0, 0]),
from_u32x8([1, 2, 3, 4, 7, 0, 0, 0]),
from_u32x8([0, 1, 2, 3, 4, 7, 0, 0]),
from_u32x8([5, 7, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 5, 7, 0, 0, 0, 0, 0]),
from_u32x8([1, 5, 7, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 5, 7, 0, 0, 0, 0]),
from_u32x8([2, 5, 7, 0, 0, 0, 0, 0]),
from_u32x8([0, 2, 5, 7, 0, 0, 0, 0]),
from_u32x8([1, 2, 5, 7, 0, 0, 0, 0]),
from_u32x8([0, 1, 2, 5, 7, 0, 0, 0]),
from_u32x8([3, 5, 7, 0, 0, 0, 0, 0]),
from_u32x8([0, 3, 5, 7, 0, 0, 0, 0]),
from_u32x8([1, 3, 5, 7, 0, 0, 0, 0]),
from_u32x8([0, 1, 3, 5, 7, 0, 0, 0]),
from_u32x8([2, 3, 5, 7, 0, 0, 0, 0]),
from_u32x8([0, 2, 3, 5, 7, 0, 0, 0]),
from_u32x8([1, 2, 3, 5, 7, 0, 0, 0]),
from_u32x8([0, 1, 2, 3, 5, 7, 0, 0]),
from_u32x8([4, 5, 7, 0, 0, 0, 0, 0]),
from_u32x8([0, 4, 5, 7, 0, 0, 0, 0]),
from_u32x8([1, 4, 5, 7, 0, 0, 0, 0]),
from_u32x8([0, 1, 4, 5, 7, 0, 0, 0]),
from_u32x8([2, 4, 5, 7, 0, 0, 0, 0]),
from_u32x8([0, 2, 4, 5, 7, 0, 0, 0]),
from_u32x8([1, 2, 4, 5, 7, 0, 0, 0]),
from_u32x8([0, 1, 2, 4, 5, 7, 0, 0]),
from_u32x8([3, 4, 5, 7, 0, 0, 0, 0]),
from_u32x8([0, 3, 4, 5, 7, 0, 0, 0]),
from_u32x8([1, 3, 4, 5, 7, 0, 0, 0]),
from_u32x8([0, 1, 3, 4, 5, 7, 0, 0]),
from_u32x8([2, 3, 4, 5, 7, 0, 0, 0]),
from_u32x8([0, 2, 3, 4, 5, 7, 0, 0]),
from_u32x8([1, 2, 3, 4, 5, 7, 0, 0]),
from_u32x8([0, 1, 2, 3, 4, 5, 7, 0]),
from_u32x8([6, 7, 0, 0, 0, 0, 0, 0]),
from_u32x8([0, 6, 7, 0, 0, 0, 0, 0]),
from_u32x8([1, 6, 7, 0, 0, 0, 0, 0]),
from_u32x8([0, 1, 6, 7, 0, 0, 0, 0]),
from_u32x8([2, 6, 7, 0, 0, 0, 0, 0]),
from_u32x8([0, 2, 6, 7, 0, 0, 0, 0]),
from_u32x8([1, 2, 6, 7, 0, 0, 0, 0]),
from_u32x8([0, 1, 2, 6, 7, 0, 0, 0]),
from_u32x8([3, 6, 7, 0, 0, 0, 0, 0]),
from_u32x8([0, 3, 6, 7, 0, 0, 0, 0]),
from_u32x8([1, 3, 6, 7, 0, 0, 0, 0]),
from_u32x8([0, 1, 3, 6, 7, 0, 0, 0]),
from_u32x8([2, 3, 6, 7, 0, 0, 0, 0]),
from_u32x8([0, 2, 3, 6, 7, 0, 0, 0]),
from_u32x8([1, 2, 3, 6, 7, 0, 0, 0]),
from_u32x8([0, 1, 2, 3, 6, 7, 0, 0]),
from_u32x8([4, 6, 7, 0, 0, 0, 0, 0]),
from_u32x8([0, 4, 6, 7, 0, 0, 0, 0]),
from_u32x8([1, 4, 6, 7, 0, 0, 0, 0]),
from_u32x8([0, 1, 4, 6, 7, 0, 0, 0]),
from_u32x8([2, 4, 6, 7, 0, 0, 0, 0]),
from_u32x8([0, 2, 4, 6, 7, 0, 0, 0]),
from_u32x8([1, 2, 4, 6, 7, 0, 0, 0]),
from_u32x8([0, 1, 2, 4, 6, 7, 0, 0]),
from_u32x8([3, 4, 6, 7, 0, 0, 0, 0]),
from_u32x8([0, 3, 4, 6, 7, 0, 0, 0]),
from_u32x8([1, 3, 4, 6, 7, 0, 0, 0]),
from_u32x8([0, 1, 3, 4, 6, 7, 0, 0]),
from_u32x8([2, 3, 4, 6, 7, 0, 0, 0]),
from_u32x8([0, 2, 3, 4, 6, 7, 0, 0]),
from_u32x8([1, 2, 3, 4, 6, 7, 0, 0]),
from_u32x8([0, 1, 2, 3, 4, 6, 7, 0]),
from_u32x8([5, 6, 7, 0, 0, 0, 0, 0]),
from_u32x8([0, 5, 6, 7, 0, 0, 0, 0]),
from_u32x8([1, 5, 6, 7, 0, 0, 0, 0]),
from_u32x8([0, 1, 5, 6, 7, 0, 0, 0]),
from_u32x8([2, 5, 6, 7, 0, 0, 0, 0]),
from_u32x8([0, 2, 5, 6, 7, 0, 0, 0]),
from_u32x8([1, 2, 5, 6, 7, 0, 0, 0]),
from_u32x8([0, 1, 2, 5, 6, 7, 0, 0]),
from_u32x8([3, 5, 6, 7, 0, 0, 0, 0]),
from_u32x8([0, 3, 5, 6, 7, 0, 0, 0]),
from_u32x8([1, 3, 5, 6, 7, 0, 0, 0]),
from_u32x8([0, 1, 3, 5, 6, 7, 0, 0]),
from_u32x8([2, 3, 5, 6, 7, 0, 0, 0]),
from_u32x8([0, 2, 3, 5, 6, 7, 0, 0]),
from_u32x8([1, 2, 3, 5, 6, 7, 0, 0]),
from_u32x8([0, 1, 2, 3, 5, 6, 7, 0]),
from_u32x8([4, 5, 6, 7, 0, 0, 0, 0]),
from_u32x8([0, 4, 5, 6, 7, 0, 0, 0]),
from_u32x8([1, 4, 5, 6, 7, 0, 0, 0]),
from_u32x8([0, 1, 4, 5, 6, 7, 0, 0]),
from_u32x8([2, 4, 5, 6, 7, 0, 0, 0]),
from_u32x8([0, 2, 4, 5, 6, 7, 0, 0]),
from_u32x8([1, 2, 4, 5, 6, 7, 0, 0]),
from_u32x8([0, 1, 2, 4, 5, 6, 7, 0]),
from_u32x8([3, 4, 5, 6, 7, 0, 0, 0]),
from_u32x8([0, 3, 4, 5, 6, 7, 0, 0]),
from_u32x8([1, 3, 4, 5, 6, 7, 0, 0]),
from_u32x8([0, 1, 3, 4, 5, 6, 7, 0]),
from_u32x8([2, 3, 4, 5, 6, 7, 0, 0]),
from_u32x8([0, 2, 3, 4, 5, 6, 7, 0]),
from_u32x8([1, 2, 3, 4, 5, 6, 7, 0]),
from_u32x8([0, 1, 2, 3, 4, 5, 6, 7]),
];

View File

@@ -0,0 +1,165 @@
use std::ops::RangeInclusive;
#[cfg(any(target_arch = "x86_64"))]
mod avx2;
mod scalar;
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
#[repr(u8)]
enum FilterImplPerInstructionSet {
#[cfg(target_arch = "x86_64")]
AVX2 = 0u8,
Scalar = 1u8,
}
impl FilterImplPerInstructionSet {
#[inline]
pub fn is_available(&self) -> bool {
match *self {
#[cfg(target_arch = "x86_64")]
FilterImplPerInstructionSet::AVX2 => is_x86_feature_detected!("avx2"),
FilterImplPerInstructionSet::Scalar => true,
}
}
}
// List of available implementation in preferred order.
#[cfg(target_arch = "x86_64")]
const IMPLS: [FilterImplPerInstructionSet; 2] = [
FilterImplPerInstructionSet::AVX2,
FilterImplPerInstructionSet::Scalar,
];
#[cfg(not(target_arch = "x86_64"))]
const IMPLS: [FilterImplPerInstructionSet; 1] = [FilterImplPerInstructionSet::Scalar];
impl FilterImplPerInstructionSet {
#[allow(unused_variables)]
#[inline]
fn from(code: u8) -> FilterImplPerInstructionSet {
#[cfg(target_arch = "x86_64")]
if code == FilterImplPerInstructionSet::AVX2 as u8 {
return FilterImplPerInstructionSet::AVX2;
}
FilterImplPerInstructionSet::Scalar
}
#[inline]
fn filter_vec_in_place(self, range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
match self {
#[cfg(target_arch = "x86_64")]
FilterImplPerInstructionSet::AVX2 => avx2::filter_vec_in_place(range, offset, output),
FilterImplPerInstructionSet::Scalar => {
scalar::filter_vec_in_place(range, offset, output)
}
}
}
}
#[inline]
fn get_best_available_instruction_set() -> FilterImplPerInstructionSet {
use std::sync::atomic::{AtomicU8, Ordering};
static INSTRUCTION_SET_BYTE: AtomicU8 = AtomicU8::new(u8::MAX);
let instruction_set_byte: u8 = INSTRUCTION_SET_BYTE.load(Ordering::Relaxed);
if instruction_set_byte == u8::MAX {
// Let's initialize the instruction set and cache it.
let instruction_set = IMPLS
.into_iter()
.find(FilterImplPerInstructionSet::is_available)
.unwrap();
INSTRUCTION_SET_BYTE.store(instruction_set as u8, Ordering::Relaxed);
return instruction_set;
}
FilterImplPerInstructionSet::from(instruction_set_byte)
}
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
get_best_available_instruction_set().filter_vec_in_place(range, offset, output)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_get_best_available_instruction_set() {
// This does not test much unfortunately.
// We just make sure the function returns without crashing and returns the same result.
let instruction_set = get_best_available_instruction_set();
assert_eq!(get_best_available_instruction_set(), instruction_set);
}
#[cfg(target_arch = "x86_64")]
#[test]
fn test_instruction_set_to_code_from_code() {
for instruction_set in [
FilterImplPerInstructionSet::AVX2,
FilterImplPerInstructionSet::Scalar,
] {
let code = instruction_set as u8;
assert_eq!(instruction_set, FilterImplPerInstructionSet::from(code));
}
}
fn test_filter_impl_empty_aux(filter_impl: FilterImplPerInstructionSet) {
let mut output = vec![];
filter_impl.filter_vec_in_place(0..=u32::MAX, 0, &mut output);
assert_eq!(&output, &[]);
}
fn test_filter_impl_simple_aux(filter_impl: FilterImplPerInstructionSet) {
let mut output = vec![3, 2, 1, 5, 11, 2, 5, 10, 2];
filter_impl.filter_vec_in_place(3..=10, 0, &mut output);
assert_eq!(&output, &[0, 3, 6, 7]);
}
fn test_filter_impl_simple_aux_shifted(filter_impl: FilterImplPerInstructionSet) {
let mut output = vec![3, 2, 1, 5, 11, 2, 5, 10, 2];
filter_impl.filter_vec_in_place(3..=10, 10, &mut output);
assert_eq!(&output, &[10, 13, 16, 17]);
}
fn test_filter_impl_simple_outside_i32_range(filter_impl: FilterImplPerInstructionSet) {
let mut output = vec![u32::MAX, i32::MAX as u32 + 1, 0, 1, 3, 1, 1, 1, 1];
filter_impl.filter_vec_in_place(1..=i32::MAX as u32 + 1u32, 0, &mut output);
assert_eq!(&output, &[1, 3, 4, 5, 6, 7, 8]);
}
fn test_filter_impl_test_suite(filter_impl: FilterImplPerInstructionSet) {
test_filter_impl_empty_aux(filter_impl);
test_filter_impl_simple_aux(filter_impl);
test_filter_impl_simple_aux_shifted(filter_impl);
test_filter_impl_simple_outside_i32_range(filter_impl);
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_filter_implementation_avx2() {
if FilterImplPerInstructionSet::AVX2.is_available() {
test_filter_impl_test_suite(FilterImplPerInstructionSet::AVX2);
}
}
#[test]
fn test_filter_implementation_scalar() {
test_filter_impl_test_suite(FilterImplPerInstructionSet::Scalar);
}
#[cfg(target_arch = "x86_64")]
proptest::proptest! {
#[test]
fn test_filter_compare_scalar_and_avx2_impl_proptest(
start in proptest::prelude::any::<u32>(),
end in proptest::prelude::any::<u32>(),
offset in 0u32..2u32,
mut vals in proptest::collection::vec(0..u32::MAX, 0..30)) {
if FilterImplPerInstructionSet::AVX2.is_available() {
let mut vals_clone = vals.clone();
FilterImplPerInstructionSet::AVX2.filter_vec_in_place(start..=end, offset, &mut vals);
FilterImplPerInstructionSet::Scalar.filter_vec_in_place(start..=end, offset, &mut vals_clone);
assert_eq!(&vals, &vals_clone);
}
}
}
}

View File

@@ -0,0 +1,13 @@
use std::ops::RangeInclusive;
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
// We restrict the accepted boundary, because unsigned integers & SIMD don't
// play well.
let mut output_cursor = 0;
for i in 0..output.len() {
let val = output[i];
output[output_cursor] = offset + i as u32;
output_cursor += if range.contains(&val) { 1 } else { 0 };
}
output.truncate(output_cursor);
}

View File

@@ -1,5 +1,6 @@
mod bitpacker;
mod blocked_bitpacker;
mod filter_vec;
use std::cmp::Ordering;

View File

@@ -17,6 +17,7 @@ stacker = { path = "../stacker", package="tantivy-stacker"}
sstable = { path = "../sstable", package = "tantivy-sstable" }
common = { path = "../common", package = "tantivy-common" }
tantivy-bitpacker = { version= "0.3", path = "../bitpacker/" }
serde = "1.0.152"
[dev-dependencies]
proptest = "1"

View File

@@ -1,7 +1,6 @@
# zero to one
* revisit line codec
* removal of all rows of a column in the schema due to deletes
* add columns from schema on merge
* Plugging JSON
* replug examples

View File

@@ -0,0 +1,36 @@
use crate::{Column, DocId, RowId};
#[derive(Debug, Default, Clone)]
pub struct ColumnBlockAccessor<T> {
val_cache: Vec<T>,
docid_cache: Vec<DocId>,
row_id_cache: Vec<RowId>,
}
impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
ColumnBlockAccessor<T>
{
#[inline]
pub fn fetch_block(&mut self, docs: &[u32], accessor: &Column<T>) {
self.docid_cache.clear();
self.row_id_cache.clear();
accessor.row_ids_for_docs(docs, &mut self.docid_cache, &mut self.row_id_cache);
self.val_cache.resize(self.row_id_cache.len(), T::default());
accessor
.values
.get_vals(&self.row_id_cache, &mut self.val_cache);
}
#[inline]
pub fn iter_vals(&self) -> impl Iterator<Item = T> + '_ {
self.val_cache.iter().cloned()
}
#[inline]
pub fn iter_docid_vals(&self) -> impl Iterator<Item = (DocId, T)> + '_ {
self.docid_cache
.iter()
.cloned()
.zip(self.val_cache.iter().cloned())
}
}

View File

@@ -1,6 +1,6 @@
use std::io;
use std::ops::Deref;
use std::sync::Arc;
use std::{fmt, io};
use sstable::{Dictionary, VoidSSTable};
@@ -21,6 +21,14 @@ pub struct BytesColumn {
pub(crate) term_ord_column: Column<u64>,
}
impl fmt::Debug for BytesColumn {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BytesColumn")
.field("term_ord_column", &self.term_ord_column)
.finish()
}
}
impl BytesColumn {
/// Fills the given `output` buffer with the term associated to the ordinal `ord`.
///
@@ -36,7 +44,7 @@ impl BytesColumn {
}
pub fn term_ords(&self, row_id: RowId) -> impl Iterator<Item = u64> + '_ {
self.term_ord_column.values(row_id)
self.term_ord_column.values_for_doc(row_id)
}
/// Returns the column of ordinals
@@ -56,6 +64,12 @@ impl BytesColumn {
#[derive(Clone)]
pub struct StrColumn(BytesColumn);
impl fmt::Debug for StrColumn {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?}", self.term_ord_column)
}
}
impl From<StrColumn> for BytesColumn {
fn from(str_column: StrColumn) -> BytesColumn {
str_column.0

View File

@@ -1,7 +1,7 @@
mod dictionary_encoded;
mod serialize;
use std::fmt::Debug;
use std::fmt::{self, Debug};
use std::io::Write;
use std::ops::{Deref, Range, RangeInclusive};
use std::sync::Arc;
@@ -16,14 +16,33 @@ pub use serialize::{
use crate::column_index::ColumnIndex;
use crate::column_values::monotonic_mapping::StrictlyMonotonicMappingToInternal;
use crate::column_values::{monotonic_map_column, ColumnValues};
use crate::{Cardinality, MonotonicallyMappableToU64, RowId};
use crate::{Cardinality, DocId, EmptyColumnValues, MonotonicallyMappableToU64, RowId};
#[derive(Clone)]
pub struct Column<T = u64> {
pub idx: ColumnIndex,
pub index: ColumnIndex,
pub values: Arc<dyn ColumnValues<T>>,
}
impl<T: Debug + PartialOrd + Send + Sync + Copy + 'static> Debug for Column<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let num_docs = self.num_docs();
let entries = (0..num_docs)
.map(|i| (i, self.values_for_doc(i).collect::<Vec<_>>()))
.filter(|(_, vals)| !vals.is_empty());
f.debug_map().entries(entries).finish()
}
}
impl<T: PartialOrd + Default> Column<T> {
pub fn build_empty_column(num_docs: u32) -> Column<T> {
Column {
index: ColumnIndex::Empty { num_docs },
values: Arc::new(EmptyColumnValues),
}
}
}
impl<T: MonotonicallyMappableToU64> Column<T> {
pub fn to_u64_monotonic(self) -> Column<u64> {
let values = Arc::new(monotonic_map_column(
@@ -31,20 +50,22 @@ impl<T: MonotonicallyMappableToU64> Column<T> {
StrictlyMonotonicMappingToInternal::<T>::new(),
));
Column {
idx: self.idx,
index: self.index,
values,
}
}
}
impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
#[inline]
pub fn get_cardinality(&self) -> Cardinality {
self.idx.get_cardinality()
self.index.get_cardinality()
}
pub fn num_docs(&self) -> RowId {
match &self.idx {
ColumnIndex::Full => self.values.num_vals() as u32,
match &self.index {
ColumnIndex::Empty { num_docs } => *num_docs,
ColumnIndex::Full => self.values.num_vals(),
ColumnIndex::Optional(optional_index) => optional_index.num_docs(),
ColumnIndex::Multivalued(col_index) => {
// The multivalued index contains all value start row_id,
@@ -63,11 +84,28 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
}
pub fn first(&self, row_id: RowId) -> Option<T> {
self.values(row_id).next()
self.values_for_doc(row_id).next()
}
pub fn values(&self, row_id: RowId) -> impl Iterator<Item = T> + '_ {
self.value_row_ids(row_id)
/// Translates a block of docis to row_ids.
///
/// returns the row_ids and the matching docids on the same index
/// e.g.
/// DocId In: [0, 5, 6]
/// DocId Out: [0, 0, 6, 6]
/// RowId Out: [0, 1, 2, 3]
#[inline]
pub fn row_ids_for_docs(
&self,
doc_ids: &[DocId],
doc_ids_out: &mut Vec<DocId>,
row_ids: &mut Vec<RowId>,
) {
self.index.docids_to_rowids(doc_ids, doc_ids_out, row_ids)
}
pub fn values_for_doc(&self, doc_id: DocId) -> impl Iterator<Item = T> + '_ {
self.value_row_ids(doc_id)
.map(|value_row_id: RowId| self.values.get_val(value_row_id))
}
@@ -77,17 +115,19 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
&self,
value_range: RangeInclusive<T>,
selected_docid_range: Range<u32>,
docids: &mut Vec<u32>,
doc_ids: &mut Vec<u32>,
) {
// convert passed docid range to row id range
let rowid_range = self.idx.docid_range_to_rowids(selected_docid_range.clone());
let rowid_range = self
.index
.docid_range_to_rowids(selected_docid_range.clone());
// Load rows
self.values
.get_row_ids_for_value_range(value_range, rowid_range, docids);
.get_row_ids_for_value_range(value_range, rowid_range, doc_ids);
// Convert rows to docids
self.idx
.select_batch_in_place(docids, selected_docid_range.start);
self.index
.select_batch_in_place(selected_docid_range.start, doc_ids);
}
/// Fils the output vector with the (possibly multiple values that are associated_with
@@ -96,7 +136,7 @@ impl<T: PartialOrd + Copy + Debug + Send + Sync + 'static> Column<T> {
/// This method clears the `output` vector.
pub fn fill_vals(&self, row_id: RowId, output: &mut Vec<T>) {
output.clear();
output.extend(self.values(row_id));
output.extend(self.values_for_doc(row_id));
}
pub fn first_or_default_col(self, default_value: T) -> Arc<dyn ColumnValues<T>> {
@@ -111,7 +151,7 @@ impl<T> Deref for Column<T> {
type Target = ColumnIndex;
fn deref(&self) -> &Self::Target {
&self.idx
&self.index
}
}
@@ -149,7 +189,8 @@ impl<T: PartialOrd + Debug + Send + Sync + Copy + 'static> ColumnValues<T>
}
fn num_vals(&self) -> u32 {
match &self.column.idx {
match &self.column.index {
ColumnIndex::Empty { .. } => 0u32,
ColumnIndex::Full => self.column.values.num_vals(),
ColumnIndex::Optional(optional_idx) => optional_idx.num_docs(),
ColumnIndex::Multivalued(multivalue_idx) => multivalue_idx.num_docs(),

View File

@@ -7,9 +7,10 @@ use sstable::Dictionary;
use crate::column::{BytesColumn, Column};
use crate::column_index::{serialize_column_index, SerializableColumnIndex};
use crate::column_values::serialize::serialize_column_values_u128;
use crate::column_values::u64_based::{serialize_u64_based_column_values, CodecType};
use crate::column_values::{MonotonicallyMappableToU128, MonotonicallyMappableToU64};
use crate::column_values::{
load_u64_based_column_values, serialize_column_values_u128, serialize_u64_based_column_values,
CodecType, MonotonicallyMappableToU128, MonotonicallyMappableToU64,
};
use crate::iterable::Iterable;
use crate::StrColumn;
@@ -49,10 +50,9 @@ pub fn open_column_u64<T: MonotonicallyMappableToU64>(bytes: OwnedBytes) -> io::
);
let (column_index_data, column_values_data) = body.split(column_index_num_bytes as usize);
let column_index = crate::column_index::open_column_index(column_index_data)?;
let column_values =
crate::column_values::u64_based::load_u64_based_column_values(column_values_data)?;
let column_values = load_u64_based_column_values(column_values_data)?;
Ok(Column {
idx: column_index,
index: column_index,
values: column_values,
})
}
@@ -71,7 +71,7 @@ pub fn open_column_u128<T: MonotonicallyMappableToU128>(
let column_index = crate::column_index::open_column_index(column_index_data)?;
let column_values = crate::column_values::open_u128_mapped(column_values_data)?;
Ok(Column {
idx: column_index,
index: column_index,
values: column_values,
})
}

View File

@@ -8,17 +8,16 @@ use crate::column_index::SerializableColumnIndex;
use crate::{Cardinality, ColumnIndex, MergeRowOrder};
// For simplification, we never have cardinality go down due to deletes.
fn detect_cardinality(columns: &[Option<ColumnIndex>]) -> Cardinality {
fn detect_cardinality(columns: &[ColumnIndex]) -> Cardinality {
columns
.iter()
.flatten()
.map(ColumnIndex::get_cardinality)
.max()
.unwrap_or(Cardinality::Full)
}
pub fn merge_column_index<'a>(
columns: &'a [Option<ColumnIndex>],
columns: &'a [ColumnIndex],
merge_row_order: &'a MergeRowOrder,
) -> SerializableColumnIndex<'a> {
// For simplification, we do not try to detect whether the cardinality could be
@@ -53,34 +52,33 @@ mod tests {
let optional_index: ColumnIndex = OptionalIndex::for_test(1, &[]).into();
let multivalued_index: ColumnIndex = MultiValueIndex::for_test(&[0, 1]).into();
assert_eq!(
detect_cardinality(&[Some(optional_index.clone()), None]),
detect_cardinality(&[optional_index.clone(), ColumnIndex::Empty { num_docs: 0 }]),
Cardinality::Optional
);
assert_eq!(
detect_cardinality(&[Some(optional_index.clone()), Some(ColumnIndex::Full)]),
detect_cardinality(&[optional_index.clone(), ColumnIndex::Full]),
Cardinality::Optional
);
assert_eq!(
detect_cardinality(&[Some(multivalued_index.clone()), None]),
Cardinality::Multivalued
);
assert_eq!(
detect_cardinality(&[
Some(multivalued_index.clone()),
Some(optional_index.clone())
multivalued_index.clone(),
ColumnIndex::Empty { num_docs: 0 }
]),
Cardinality::Multivalued
);
assert_eq!(
detect_cardinality(&[Some(optional_index), Some(multivalued_index)]),
detect_cardinality(&[multivalued_index.clone(), optional_index.clone()]),
Cardinality::Multivalued
);
assert_eq!(
detect_cardinality(&[optional_index, multivalued_index]),
Cardinality::Multivalued
);
}
#[test]
fn test_merge_index_multivalued_sorted() {
let column_indexes: Vec<Option<ColumnIndex>> =
vec![Some(MultiValueIndex::for_test(&[0, 2, 5]).into())];
let column_indexes: Vec<ColumnIndex> = vec![MultiValueIndex::for_test(&[0, 2, 5]).into()];
let merge_row_order: MergeRowOrder = ShuffleMergeOrder::for_test(
&[2],
vec![
@@ -104,10 +102,10 @@ mod tests {
#[test]
fn test_merge_index_multivalued_sorted_several_segment() {
let column_indexes: Vec<Option<ColumnIndex>> = vec![
Some(MultiValueIndex::for_test(&[0, 2, 5]).into()),
None,
Some(MultiValueIndex::for_test(&[0, 1, 4]).into()),
let column_indexes: Vec<ColumnIndex> = vec![
MultiValueIndex::for_test(&[0, 2, 5]).into(),
ColumnIndex::Empty { num_docs: 0 },
MultiValueIndex::for_test(&[0, 1, 4]).into(),
];
let merge_row_order: MergeRowOrder = ShuffleMergeOrder::for_test(
&[2, 0, 2],

View File

@@ -5,7 +5,7 @@ use crate::iterable::Iterable;
use crate::{Cardinality, ColumnIndex, RowId, ShuffleMergeOrder};
pub fn merge_column_index_shuffled<'a>(
column_indexes: &'a [Option<ColumnIndex>],
column_indexes: &'a [ColumnIndex],
cardinality_after_merge: Cardinality,
shuffle_merge_order: &'a ShuffleMergeOrder,
) -> SerializableColumnIndex<'a> {
@@ -33,41 +33,41 @@ pub fn merge_column_index_shuffled<'a>(
///
/// In other words the column_indexes passed as argument may NOT be multivalued.
fn merge_column_index_shuffled_optional<'a>(
column_indexes: &'a [Option<ColumnIndex>],
column_indexes: &'a [ColumnIndex],
merge_order: &'a ShuffleMergeOrder,
) -> Box<dyn Iterable<RowId> + 'a> {
Box::new(ShuffledOptionalIndex {
Box::new(ShuffledIndex {
column_indexes,
merge_order,
})
}
struct ShuffledOptionalIndex<'a> {
column_indexes: &'a [Option<ColumnIndex>],
struct ShuffledIndex<'a> {
column_indexes: &'a [ColumnIndex],
merge_order: &'a ShuffleMergeOrder,
}
impl<'a> Iterable<u32> for ShuffledOptionalIndex<'a> {
impl<'a> Iterable<u32> for ShuffledIndex<'a> {
fn boxed_iter(&self) -> Box<dyn Iterator<Item = u32> + '_> {
Box::new(self.merge_order
.iter_new_to_old_row_addrs()
.enumerate()
.filter_map(|(new_row_id, old_row_addr)| {
let Some(column_index) = &self.column_indexes[old_row_addr.segment_ord as usize] else {
return None;
};
let row_id = new_row_id as u32;
if column_index.has_value(old_row_addr.row_id) {
Some(row_id)
} else {
None
}
}))
Box::new(
self.merge_order
.iter_new_to_old_row_addrs()
.enumerate()
.filter_map(|(new_row_id, old_row_addr)| {
let column_index = &self.column_indexes[old_row_addr.segment_ord as usize];
let row_id = new_row_id as u32;
if column_index.has_value(old_row_addr.row_id) {
Some(row_id)
} else {
None
}
}),
)
}
}
fn merge_column_index_shuffled_multivalued<'a>(
column_indexes: &'a [Option<ColumnIndex>],
column_indexes: &'a [ColumnIndex],
merge_order: &'a ShuffleMergeOrder,
) -> Box<dyn Iterable<RowId> + 'a> {
Box::new(ShuffledMultivaluedIndex {
@@ -77,20 +77,18 @@ fn merge_column_index_shuffled_multivalued<'a>(
}
struct ShuffledMultivaluedIndex<'a> {
column_indexes: &'a [Option<ColumnIndex>],
column_indexes: &'a [ColumnIndex],
merge_order: &'a ShuffleMergeOrder,
}
fn iter_num_values<'a>(
column_indexes: &'a [Option<ColumnIndex>],
column_indexes: &'a [ColumnIndex],
merge_order: &'a ShuffleMergeOrder,
) -> impl Iterator<Item = u32> + 'a {
merge_order.iter_new_to_old_row_addrs().map(|row_addr| {
let Some(column_index) = &column_indexes[row_addr.segment_ord as usize] else {
// No values in the entire column. It surely means there are 0 values associated to this row.
return 0u32;
};
let column_index = &column_indexes[row_addr.segment_ord as usize];
match column_index {
ColumnIndex::Empty { .. } => 0u32,
ColumnIndex::Full => 1,
ColumnIndex::Optional(optional_index) => {
u32::from(optional_index.contains(row_addr.row_id))
@@ -142,7 +140,7 @@ mod tests {
#[test]
fn test_merge_column_index_optional_shuffle() {
let optional_index: ColumnIndex = OptionalIndex::for_test(2, &[0]).into();
let column_indexes = vec![Some(optional_index), Some(ColumnIndex::Full)];
let column_indexes = vec![optional_index, ColumnIndex::Full];
let row_addrs = vec![
RowAddr {
segment_ord: 0u32,

View File

@@ -9,7 +9,7 @@ use crate::{Cardinality, ColumnIndex, RowId, StackMergeOrder};
///
/// There are no sort nor deletes involved.
pub fn merge_column_index_stacked<'a>(
columns: &'a [Option<ColumnIndex>],
columns: &'a [ColumnIndex],
cardinality_after_merge: Cardinality,
stack_merge_order: &'a StackMergeOrder,
) -> SerializableColumnIndex<'a> {
@@ -33,7 +33,7 @@ pub fn merge_column_index_stacked<'a>(
}
struct StackedOptionalIndex<'a> {
columns: &'a [Option<ColumnIndex>],
columns: &'a [ColumnIndex],
stack_merge_order: &'a StackMergeOrder,
}
@@ -46,16 +46,16 @@ impl<'a> Iterable<RowId> for StackedOptionalIndex<'a> {
.flat_map(|(columnar_id, column_index_opt)| {
let columnar_row_range = self.stack_merge_order.columnar_range(columnar_id);
let rows_it: Box<dyn Iterator<Item = RowId>> = match column_index_opt {
Some(ColumnIndex::Full) => Box::new(columnar_row_range),
Some(ColumnIndex::Optional(optional_index)) => Box::new(
ColumnIndex::Full => Box::new(columnar_row_range),
ColumnIndex::Optional(optional_index) => Box::new(
optional_index
.iter_rows()
.map(move |row_id: RowId| columnar_row_range.start + row_id),
),
Some(ColumnIndex::Multivalued(_)) => {
ColumnIndex::Multivalued(_) => {
panic!("No multivalued index is allowed when stacking column index");
}
None => Box::new(std::iter::empty()),
ColumnIndex::Empty { .. } => Box::new(std::iter::empty()),
};
rows_it
}),
@@ -65,18 +65,18 @@ impl<'a> Iterable<RowId> for StackedOptionalIndex<'a> {
#[derive(Clone, Copy)]
struct StackedMultivaluedIndex<'a> {
columns: &'a [Option<ColumnIndex>],
columns: &'a [ColumnIndex],
stack_merge_order: &'a StackMergeOrder,
}
fn convert_column_opt_to_multivalued_index<'a>(
column_index_opt: Option<&'a ColumnIndex>,
column_index_opt: &'a ColumnIndex,
num_rows: RowId,
) -> Box<dyn Iterator<Item = RowId> + 'a> {
match column_index_opt {
None => Box::new(iter::repeat(0u32).take(num_rows as usize + 1)),
Some(ColumnIndex::Full) => Box::new(0..num_rows + 1),
Some(ColumnIndex::Optional(optional_index)) => {
ColumnIndex::Empty { .. } => Box::new(iter::repeat(0u32).take(num_rows as usize + 1)),
ColumnIndex::Full => Box::new(0..num_rows + 1),
ColumnIndex::Optional(optional_index) => {
Box::new(
(0..num_rows)
// TODO optimize
@@ -84,9 +84,7 @@ fn convert_column_opt_to_multivalued_index<'a>(
.chain(std::iter::once(optional_index.num_non_nulls())),
)
}
Some(ColumnIndex::Multivalued(multivalued_index)) => {
multivalued_index.start_index_column.iter()
}
ColumnIndex::Multivalued(multivalued_index) => multivalued_index.start_index_column.iter(),
}
}
@@ -95,7 +93,6 @@ impl<'a> Iterable<RowId> for StackedMultivaluedIndex<'a> {
let multivalued_indexes =
self.columns
.iter()
.map(Option::as_ref)
.enumerate()
.map(|(columnar_id, column_opt)| {
let num_rows =

View File

@@ -12,8 +12,11 @@ pub use serialize::{open_column_index, serialize_column_index, SerializableColum
use crate::column_index::multivalued_index::MultiValueIndex;
use crate::{Cardinality, DocId, RowId};
#[derive(Clone)]
#[derive(Clone, Debug)]
pub enum ColumnIndex {
Empty {
num_docs: u32,
},
Full,
Optional(OptionalIndex),
/// In addition, at index num_rows, an extra value is added
@@ -34,9 +37,15 @@ impl From<MultiValueIndex> for ColumnIndex {
}
impl ColumnIndex {
// Returns the cardinality of the column index.
//
// By convention, if the column contains no docs, we consider that it is
// full.
#[inline]
pub fn get_cardinality(&self) -> Cardinality {
match self {
ColumnIndex::Full => Cardinality::Full,
ColumnIndex::Empty { num_docs: 0 } | ColumnIndex::Full => Cardinality::Full,
ColumnIndex::Empty { .. } => Cardinality::Optional,
ColumnIndex::Optional(_) => Cardinality::Optional,
ColumnIndex::Multivalued(_) => Cardinality::Multivalued,
}
@@ -45,6 +54,7 @@ impl ColumnIndex {
/// Returns true if and only if there are at least one value associated to the row.
pub fn has_value(&self, doc_id: DocId) -> bool {
match self {
ColumnIndex::Empty { .. } => false,
ColumnIndex::Full => true,
ColumnIndex::Optional(optional_index) => optional_index.contains(doc_id),
ColumnIndex::Multivalued(multivalued_index) => {
@@ -55,6 +65,7 @@ impl ColumnIndex {
pub fn value_row_ids(&self, doc_id: DocId) -> Range<RowId> {
match self {
ColumnIndex::Empty { .. } => 0..0,
ColumnIndex::Full => doc_id..doc_id + 1,
ColumnIndex::Optional(optional_index) => {
if let Some(val) = optional_index.rank_if_exists(doc_id) {
@@ -67,8 +78,48 @@ impl ColumnIndex {
}
}
/// Translates a block of docis to row_ids.
///
/// returns the row_ids and the matching docids on the same index
/// e.g.
/// DocId In: [0, 5, 6]
/// DocId Out: [0, 0, 6, 6]
/// RowId Out: [0, 1, 2, 3]
#[inline]
pub fn docids_to_rowids(
&self,
doc_ids: &[DocId],
doc_ids_out: &mut Vec<DocId>,
row_ids: &mut Vec<RowId>,
) {
match self {
ColumnIndex::Empty { .. } => {}
ColumnIndex::Full => {
doc_ids_out.extend_from_slice(doc_ids);
row_ids.extend_from_slice(doc_ids);
}
ColumnIndex::Optional(optional_index) => {
for doc_id in doc_ids {
if let Some(row_id) = optional_index.rank_if_exists(*doc_id) {
doc_ids_out.push(*doc_id);
row_ids.push(row_id);
}
}
}
ColumnIndex::Multivalued(multivalued_index) => {
for doc_id in doc_ids {
for row_id in multivalued_index.range(*doc_id) {
doc_ids_out.push(*doc_id);
row_ids.push(row_id);
}
}
}
}
}
pub fn docid_range_to_rowids(&self, doc_id: Range<DocId>) -> Range<RowId> {
match self {
ColumnIndex::Empty { .. } => 0..0,
ColumnIndex::Full => doc_id,
ColumnIndex::Optional(optional_index) => {
let row_start = optional_index.rank(doc_id.start);
@@ -87,8 +138,11 @@ impl ColumnIndex {
}
}
pub fn select_batch_in_place(&self, rank_ids: &mut Vec<RowId>, doc_id_start: DocId) {
pub fn select_batch_in_place(&self, doc_id_start: DocId, rank_ids: &mut Vec<RowId>) {
match self {
ColumnIndex::Empty { .. } => {
rank_ids.clear();
}
ColumnIndex::Full => {
// No need to do anything:
// value_idx and row_idx are the same.
@@ -102,3 +156,21 @@ impl ColumnIndex {
}
}
}
#[cfg(test)]
mod tests {
use crate::{Cardinality, ColumnIndex};
#[test]
fn test_column_index_get_cardinality() {
assert_eq!(
ColumnIndex::Empty { num_docs: 0 }.get_cardinality(),
Cardinality::Full
);
assert_eq!(ColumnIndex::Full.get_cardinality(), Cardinality::Full);
assert_eq!(
ColumnIndex::Empty { num_docs: 1 }.get_cardinality(),
Cardinality::Optional
);
}
}

View File

@@ -5,8 +5,9 @@ use std::sync::Arc;
use common::OwnedBytes;
use crate::column_values::u64_based::CodecType;
use crate::column_values::ColumnValues;
use crate::column_values::{
load_u64_based_column_values, serialize_u64_based_column_values, CodecType, ColumnValues,
};
use crate::iterable::Iterable;
use crate::{DocId, RowId};
@@ -14,7 +15,7 @@ pub fn serialize_multivalued_index(
multivalued_index: &dyn Iterable<RowId>,
output: &mut impl Write,
) -> io::Result<()> {
crate::column_values::u64_based::serialize_u64_based_column_values(
serialize_u64_based_column_values(
multivalued_index,
&[CodecType::Bitpacked, CodecType::Linear],
output,
@@ -23,8 +24,7 @@ pub fn serialize_multivalued_index(
}
pub fn open_multivalued_index(bytes: OwnedBytes) -> io::Result<MultiValueIndex> {
let start_index_column: Arc<dyn ColumnValues<RowId>> =
crate::column_values::u64_based::load_u64_based_column_values(bytes)?;
let start_index_column: Arc<dyn ColumnValues<RowId>> = load_u64_based_column_values(bytes)?;
Ok(MultiValueIndex { start_index_column })
}
@@ -35,6 +35,14 @@ pub struct MultiValueIndex {
pub start_index_column: Arc<dyn crate::ColumnValues<RowId>>,
}
impl std::fmt::Debug for MultiValueIndex {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("MultiValuedIndex")
.field("num_rows", &self.start_index_column.num_vals())
.finish_non_exhaustive()
}
}
impl From<Arc<dyn ColumnValues<RowId>>> for MultiValueIndex {
fn from(start_index_column: Arc<dyn ColumnValues<RowId>>) -> Self {
MultiValueIndex { start_index_column }
@@ -83,13 +91,13 @@ impl MultiValueIndex {
let mut cur_doc = docid_start;
let mut last_doc = None;
assert!(self.start_index_column.get_val(docid_start) as u32 <= ranks[0]);
assert!(self.start_index_column.get_val(docid_start) <= ranks[0]);
let mut write_doc_pos = 0;
for i in 0..ranks.len() {
let pos = ranks[i];
loop {
let end = self.start_index_column.get_val(cur_doc + 1) as u32;
let end = self.start_index_column.get_val(cur_doc + 1);
if end > pos {
ranks[write_doc_pos] = cur_doc;
write_doc_pos += if last_doc == Some(cur_doc) { 0 } else { 1 };
@@ -106,11 +114,8 @@ impl MultiValueIndex {
#[cfg(test)]
mod tests {
use std::ops::Range;
use std::sync::Arc;
use super::MultiValueIndex;
use crate::column_values::IterColumn;
use crate::{ColumnValues, RowId};
fn index_to_pos_helper(
index: &MultiValueIndex,
@@ -124,9 +129,7 @@ mod tests {
#[test]
fn test_positions_to_docid() {
let offsets: Vec<RowId> = vec![0, 10, 12, 15, 22, 23]; // docid values are [0..10, 10..12, 12..15, etc.]
let column: Arc<dyn ColumnValues<RowId>> = Arc::new(IterColumn::from(offsets.into_iter()));
let index = MultiValueIndex::from(column);
let index = MultiValueIndex::for_test(&[0, 10, 12, 15, 22, 23]);
assert_eq!(index.num_docs(), 5);
let positions = &[10u32, 11, 15, 20, 21, 22];
assert_eq!(index_to_pos_helper(&index, 0..5, positions), vec![1, 3, 4]);

View File

@@ -88,6 +88,15 @@ pub struct OptionalIndex {
block_metas: Arc<[BlockMeta]>,
}
impl std::fmt::Debug for OptionalIndex {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OptionalIndex")
.field("num_rows", &self.num_rows)
.field("num_non_null_rows", &self.num_non_null_rows)
.finish_non_exhaustive()
}
}
/// Splits a value address into lower and upper 16bits.
/// The lower 16 bits are the value in the block
/// The upper 16 bits are the block index
@@ -440,7 +449,7 @@ impl SerializedBlockMeta {
#[inline]
fn is_sparse(num_rows_in_block: u32) -> bool {
num_rows_in_block < DENSE_BLOCK_THRESHOLD as u32
num_rows_in_block < DENSE_BLOCK_THRESHOLD
}
fn deserialize_optional_index_block_metadatas(
@@ -448,7 +457,7 @@ fn deserialize_optional_index_block_metadatas(
num_rows: u32,
) -> (Box<[BlockMeta]>, u32) {
let num_blocks = data.len() / SERIALIZED_BLOCK_META_NUM_BYTES;
let mut block_metas = Vec::with_capacity(num_blocks as usize + 1);
let mut block_metas = Vec::with_capacity(num_blocks + 1);
let mut start_byte_offset = 0;
let mut non_null_rows_before_block = 0;
for block_meta_bytes in data.chunks_exact(SERIALIZED_BLOCK_META_NUM_BYTES) {
@@ -479,7 +488,7 @@ fn deserialize_optional_index_block_metadatas(
block_variant,
});
start_byte_offset += block_variant.num_bytes_in_block();
non_null_rows_before_block += num_non_null_rows as u32;
non_null_rows_before_block += num_non_null_rows;
}
block_metas.resize(
((num_rows + BLOCK_SIZE - 1) / BLOCK_SIZE) as usize,

View File

@@ -32,7 +32,7 @@ pub const MINI_BLOCK_NUM_BYTES: usize = MINI_BLOCK_BITVEC_NUM_BYTES + MINI_BLOCK
/// Number of bytes in a dense block.
pub const DENSE_BLOCK_NUM_BYTES: u32 =
(ELEMENTS_PER_BLOCK as u32 / ELEMENTS_PER_MINI_BLOCK as u32) * MINI_BLOCK_NUM_BYTES as u32;
(ELEMENTS_PER_BLOCK / ELEMENTS_PER_MINI_BLOCK as u32) * MINI_BLOCK_NUM_BYTES as u32;
pub struct DenseBlockCodec;
@@ -229,7 +229,7 @@ pub fn serialize_dense_codec(
while block_id > current_block_id {
let dense_mini_block = DenseMiniBlock {
bitvec: block,
rank: non_null_rows_before as u16,
rank: non_null_rows_before,
};
output.write_all(&dense_mini_block.to_bytes())?;
non_null_rows_before += block.count_ones() as u16;

View File

@@ -37,7 +37,7 @@ proptest! {
fn test_with_random_sets_simple() {
let vals = 10..BLOCK_SIZE * 2;
let mut out: Vec<u8> = Vec::new();
serialize_optional_index(&vals.clone(), 100, &mut out).unwrap();
serialize_optional_index(&vals, 100, &mut out).unwrap();
let null_index = open_optional_index(OwnedBytes::new(out)).unwrap();
let ranks: Vec<u32> = (65_472u32..65_473u32).collect();
let els: Vec<u32> = ranks.iter().copied().map(|rank| rank + 10).collect();

View File

@@ -0,0 +1,135 @@
use std::sync::Arc;
use common::OwnedBytes;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use test::{self, Bencher};
use super::*;
use crate::column_values::u64_based::*;
fn get_data() -> Vec<u64> {
let mut rng = StdRng::seed_from_u64(2u64);
let mut data: Vec<_> = (100..55000_u64)
.map(|num| num + rng.gen::<u8>() as u64)
.collect();
data.push(99_000);
data.insert(1000, 2000);
data.insert(2000, 100);
data.insert(3000, 4100);
data.insert(4000, 100);
data.insert(5000, 800);
data
}
fn compute_stats(vals: impl Iterator<Item = u64>) -> ColumnStats {
let mut stats_collector = StatsCollector::default();
for val in vals {
stats_collector.collect(val);
}
stats_collector.stats()
}
#[inline(never)]
fn value_iter() -> impl Iterator<Item = u64> {
0..20_000
}
fn get_reader_for_bench<Codec: ColumnCodec>(data: &[u64]) -> Codec::ColumnValues {
let mut bytes = Vec::new();
let stats = compute_stats(data.iter().cloned());
let mut codec_serializer = Codec::estimator();
for val in data {
codec_serializer.collect(*val);
}
codec_serializer.serialize(&stats, Box::new(data.iter().copied()).as_mut(), &mut bytes);
Codec::load(OwnedBytes::new(bytes)).unwrap()
}
fn bench_get<Codec: ColumnCodec>(b: &mut Bencher, data: &[u64]) {
let col = get_reader_for_bench::<Codec>(data);
b.iter(|| {
let mut sum = 0u64;
for pos in value_iter() {
let val = col.get_val(pos as u32);
sum = sum.wrapping_add(val);
}
sum
});
}
#[inline(never)]
fn bench_get_dynamic_helper(b: &mut Bencher, col: Arc<dyn ColumnValues>) {
b.iter(|| {
let mut sum = 0u64;
for pos in value_iter() {
let val = col.get_val(pos as u32);
sum = sum.wrapping_add(val);
}
sum
});
}
fn bench_get_dynamic<Codec: ColumnCodec>(b: &mut Bencher, data: &[u64]) {
let col = Arc::new(get_reader_for_bench::<Codec>(data));
bench_get_dynamic_helper(b, col);
}
fn bench_create<Codec: ColumnCodec>(b: &mut Bencher, data: &[u64]) {
let stats = compute_stats(data.iter().cloned());
let mut bytes = Vec::new();
b.iter(|| {
bytes.clear();
let mut codec_serializer = Codec::estimator();
for val in data.iter().take(1024) {
codec_serializer.collect(*val);
}
codec_serializer.serialize(&stats, Box::new(data.iter().copied()).as_mut(), &mut bytes)
});
}
#[bench]
fn bench_fastfield_bitpack_create(b: &mut Bencher) {
let data: Vec<_> = get_data();
bench_create::<BitpackedCodec>(b, &data);
}
#[bench]
fn bench_fastfield_linearinterpol_create(b: &mut Bencher) {
let data: Vec<_> = get_data();
bench_create::<LinearCodec>(b, &data);
}
#[bench]
fn bench_fastfield_multilinearinterpol_create(b: &mut Bencher) {
let data: Vec<_> = get_data();
bench_create::<BlockwiseLinearCodec>(b, &data);
}
#[bench]
fn bench_fastfield_bitpack_get(b: &mut Bencher) {
let data: Vec<_> = get_data();
bench_get::<BitpackedCodec>(b, &data);
}
#[bench]
fn bench_fastfield_bitpack_get_dynamic(b: &mut Bencher) {
let data: Vec<_> = get_data();
bench_get_dynamic::<BitpackedCodec>(b, &data);
}
#[bench]
fn bench_fastfield_linearinterpol_get(b: &mut Bencher) {
let data: Vec<_> = get_data();
bench_get::<LinearCodec>(b, &data);
}
#[bench]
fn bench_fastfield_linearinterpol_get_dynamic(b: &mut Bencher) {
let data: Vec<_> = get_data();
bench_get_dynamic::<LinearCodec>(b, &data);
}
#[bench]
fn bench_fastfield_multilinearinterpol_get(b: &mut Bencher) {
let data: Vec<_> = get_data();
bench_get::<BlockwiseLinearCodec>(b, &data);
}
#[bench]
fn bench_fastfield_multilinearinterpol_get_dynamic(b: &mut Bencher) {
let data: Vec<_> = get_data();
bench_get_dynamic::<BlockwiseLinearCodec>(b, &data);
}

View File

@@ -1,384 +0,0 @@
use std::fmt::Debug;
use std::marker::PhantomData;
use std::ops::{Range, RangeInclusive};
use std::sync::Arc;
use tantivy_bitpacker::minmax;
use crate::column_values::monotonic_mapping::StrictlyMonotonicFn;
use crate::RowId;
/// `ColumnValues` provides access to a dense field column.
///
/// `Column` are just a wrapper over `ColumnValues` and a `ColumnIndex`.
pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync {
/// Return the value associated with the given idx.
///
/// This accessor should return as fast as possible.
///
/// # Panics
///
/// May panic if `idx` is greater than the column length.
fn get_val(&self, idx: u32) -> T;
/// Fills an output buffer with the fast field values
/// associated with the `DocId` going from
/// `start` to `start + output.len()`.
///
/// # Panics
///
/// Must panic if `start + output.len()` is greater than
/// the segment's `maxdoc`.
#[inline(always)]
fn get_range(&self, start: u64, output: &mut [T]) {
for (out, idx) in output.iter_mut().zip(start..) {
*out = self.get_val(idx as u32);
}
}
/// Get the row ids of values which are in the provided value range.
///
/// Note that position == docid for single value fast fields
#[inline(always)]
fn get_row_ids_for_value_range(
&self,
value_range: RangeInclusive<T>,
row_id_range: Range<RowId>,
row_id_hits: &mut Vec<RowId>,
) {
let row_id_range = row_id_range.start..row_id_range.end.min(self.num_vals());
for idx in row_id_range.start..row_id_range.end {
let val = self.get_val(idx);
if value_range.contains(&val) {
row_id_hits.push(idx);
}
}
}
/// Returns the minimum value for this fast field.
///
/// This min_value may not be exact.
/// For instance, the min value does not take in account of possible
/// deleted document. All values are however guaranteed to be higher than
/// `.min_value()`.
fn min_value(&self) -> T;
/// Returns the maximum value for this fast field.
///
/// This max_value may not be exact.
/// For instance, the max value does not take in account of possible
/// deleted document. All values are however guaranteed to be higher than
/// `.max_value()`.
fn max_value(&self) -> T;
/// The number of values in the column.
fn num_vals(&self) -> u32;
/// Returns a iterator over the data
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = T> + 'a> {
Box::new((0..self.num_vals()).map(|idx| self.get_val(idx)))
}
}
impl<T: Copy + PartialOrd + Debug> ColumnValues<T> for Arc<dyn ColumnValues<T>> {
#[inline(always)]
fn get_val(&self, idx: u32) -> T {
self.as_ref().get_val(idx)
}
#[inline(always)]
fn min_value(&self) -> T {
self.as_ref().min_value()
}
#[inline(always)]
fn max_value(&self) -> T {
self.as_ref().max_value()
}
#[inline(always)]
fn num_vals(&self) -> u32 {
self.as_ref().num_vals()
}
#[inline(always)]
fn iter<'b>(&'b self) -> Box<dyn Iterator<Item = T> + 'b> {
self.as_ref().iter()
}
#[inline(always)]
fn get_range(&self, start: u64, output: &mut [T]) {
self.as_ref().get_range(start, output)
}
}
impl<'a, C: ColumnValues<T> + ?Sized, T: Copy + PartialOrd + Debug> ColumnValues<T> for &'a C {
fn get_val(&self, idx: u32) -> T {
(*self).get_val(idx)
}
fn min_value(&self) -> T {
(*self).min_value()
}
fn max_value(&self) -> T {
(*self).max_value()
}
fn num_vals(&self) -> u32 {
(*self).num_vals()
}
fn iter<'b>(&'b self) -> Box<dyn Iterator<Item = T> + 'b> {
(*self).iter()
}
fn get_range(&self, start: u64, output: &mut [T]) {
(*self).get_range(start, output)
}
}
/// VecColumn provides `Column` over a slice.
pub struct VecColumn<'a, T = u64> {
pub(crate) values: &'a [T],
pub(crate) min_value: T,
pub(crate) max_value: T,
}
impl<'a, T: Copy + PartialOrd + Send + Sync + Debug> ColumnValues<T> for VecColumn<'a, T> {
fn get_val(&self, position: u32) -> T {
self.values[position as usize]
}
fn iter(&self) -> Box<dyn Iterator<Item = T> + '_> {
Box::new(self.values.iter().copied())
}
fn min_value(&self) -> T {
self.min_value
}
fn max_value(&self) -> T {
self.max_value
}
fn num_vals(&self) -> u32 {
self.values.len() as u32
}
fn get_range(&self, start: u64, output: &mut [T]) {
output.copy_from_slice(&self.values[start as usize..][..output.len()])
}
}
impl<'a, T: Copy + PartialOrd + Default, V> From<&'a V> for VecColumn<'a, T>
where V: AsRef<[T]> + ?Sized
{
fn from(values: &'a V) -> Self {
let values = values.as_ref();
let (min_value, max_value) = minmax(values.iter().copied()).unwrap_or_default();
Self {
values,
min_value,
max_value,
}
}
}
struct MonotonicMappingColumn<C, T, Input> {
from_column: C,
monotonic_mapping: T,
_phantom: PhantomData<Input>,
}
/// Creates a view of a column transformed by a strictly monotonic mapping. See
/// [`StrictlyMonotonicFn`].
///
/// E.g. apply a gcd monotonic_mapping([100, 200, 300]) == [1, 2, 3]
/// monotonic_mapping.mapping() is expected to be injective, and we should always have
/// monotonic_mapping.inverse(monotonic_mapping.mapping(el)) == el
///
/// The inverse of the mapping is required for:
/// `fn get_positions_for_value_range(&self, range: RangeInclusive<T>) -> Vec<u64> `
/// The user provides the original value range and we need to monotonic map them in the same way the
/// serialization does before calling the underlying column.
///
/// Note that when opening a codec, the monotonic_mapping should be the inverse of the mapping
/// during serialization. And therefore the monotonic_mapping_inv when opening is the same as
/// monotonic_mapping during serialization.
pub fn monotonic_map_column<C, T, Input, Output>(
from_column: C,
monotonic_mapping: T,
) -> impl ColumnValues<Output>
where
C: ColumnValues<Input>,
T: StrictlyMonotonicFn<Input, Output> + Send + Sync,
Input: PartialOrd + Debug + Send + Sync + Clone,
Output: PartialOrd + Debug + Send + Sync + Clone,
{
MonotonicMappingColumn {
from_column,
monotonic_mapping,
_phantom: PhantomData,
}
}
impl<C, T, Input, Output> ColumnValues<Output> for MonotonicMappingColumn<C, T, Input>
where
C: ColumnValues<Input>,
T: StrictlyMonotonicFn<Input, Output> + Send + Sync,
Input: PartialOrd + Send + Debug + Sync + Clone,
Output: PartialOrd + Send + Debug + Sync + Clone,
{
#[inline]
fn get_val(&self, idx: u32) -> Output {
let from_val = self.from_column.get_val(idx);
self.monotonic_mapping.mapping(from_val)
}
fn min_value(&self) -> Output {
let from_min_value = self.from_column.min_value();
self.monotonic_mapping.mapping(from_min_value)
}
fn max_value(&self) -> Output {
let from_max_value = self.from_column.max_value();
self.monotonic_mapping.mapping(from_max_value)
}
fn num_vals(&self) -> u32 {
self.from_column.num_vals()
}
fn iter(&self) -> Box<dyn Iterator<Item = Output> + '_> {
Box::new(
self.from_column
.iter()
.map(|el| self.monotonic_mapping.mapping(el)),
)
}
fn get_row_ids_for_value_range(
&self,
range: RangeInclusive<Output>,
doc_id_range: Range<u32>,
positions: &mut Vec<u32>,
) {
self.from_column.get_row_ids_for_value_range(
self.monotonic_mapping.inverse(range.start().clone())
..=self.monotonic_mapping.inverse(range.end().clone()),
doc_id_range,
positions,
)
}
// We voluntarily do not implement get_range as it yields a regression,
// and we do not have any specialized implementation anyway.
}
/// Wraps an iterator into a `Column`.
pub struct IterColumn<T>(T);
impl<T> From<T> for IterColumn<T>
where T: Iterator + Clone + ExactSizeIterator
{
fn from(iter: T) -> Self {
IterColumn(iter)
}
}
impl<T> ColumnValues<T::Item> for IterColumn<T>
where
T: Iterator + Clone + ExactSizeIterator + Send + Sync,
T::Item: PartialOrd + Debug,
{
fn get_val(&self, idx: u32) -> T::Item {
self.0.clone().nth(idx as usize).unwrap()
}
fn min_value(&self) -> T::Item {
self.0.clone().next().unwrap()
}
fn max_value(&self) -> T::Item {
self.0.clone().last().unwrap()
}
fn num_vals(&self) -> u32 {
self.0.len() as u32
}
fn iter(&self) -> Box<dyn Iterator<Item = T::Item> + '_> {
Box::new(self.0.clone())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::column_values::monotonic_mapping::{
StrictlyMonotonicMappingInverter, StrictlyMonotonicMappingToInternalBaseval,
StrictlyMonotonicMappingToInternalGCDBaseval,
};
#[test]
fn test_monotonic_mapping() {
let vals = &[3u64, 5u64][..];
let col = VecColumn::from(vals);
let mapped = monotonic_map_column(col, StrictlyMonotonicMappingToInternalBaseval::new(2));
assert_eq!(mapped.min_value(), 1u64);
assert_eq!(mapped.max_value(), 3u64);
assert_eq!(mapped.num_vals(), 2);
assert_eq!(mapped.num_vals(), 2);
assert_eq!(mapped.get_val(0), 1);
assert_eq!(mapped.get_val(1), 3);
}
#[test]
fn test_range_as_col() {
let col = IterColumn::from(10..100);
assert_eq!(col.num_vals(), 90);
assert_eq!(col.max_value(), 99);
}
#[test]
fn test_monotonic_mapping_iter() {
let vals: Vec<u64> = (10..110u64).map(|el| el * 10).collect();
let col = VecColumn::from(&vals);
let mapped = monotonic_map_column(
col,
StrictlyMonotonicMappingInverter::from(
StrictlyMonotonicMappingToInternalGCDBaseval::new(10, 100),
),
);
let val_i64s: Vec<u64> = mapped.iter().collect();
for i in 0..100 {
assert_eq!(val_i64s[i as usize], mapped.get_val(i));
}
}
#[test]
fn test_monotonic_mapping_get_range() {
let vals: Vec<u64> = (0..100u64).map(|el| el * 10).collect();
let col = VecColumn::from(&vals);
let mapped = monotonic_map_column(
col,
StrictlyMonotonicMappingInverter::from(
StrictlyMonotonicMappingToInternalGCDBaseval::new(10, 0),
),
);
assert_eq!(mapped.min_value(), 0u64);
assert_eq!(mapped.max_value(), 9900u64);
assert_eq!(mapped.num_vals(), 100);
let val_u64s: Vec<u64> = mapped.iter().collect();
assert_eq!(val_u64s.len(), 100);
for i in 0..100 {
assert_eq!(val_u64s[i as usize], mapped.get_val(i));
assert_eq!(val_u64s[i as usize], vals[i as usize] * 10);
}
let mut buf = [0u64; 20];
mapped.get_range(7, &mut buf[..]);
assert_eq!(&val_u64s[7..][..20], &buf);
}
}

View File

@@ -0,0 +1,40 @@
use std::fmt::Debug;
use std::sync::Arc;
use crate::iterable::Iterable;
use crate::{ColumnIndex, ColumnValues, MergeRowOrder};
pub(crate) struct MergedColumnValues<'a, T> {
pub(crate) column_indexes: &'a [ColumnIndex],
pub(crate) column_values: &'a [Option<Arc<dyn ColumnValues<T>>>],
pub(crate) merge_row_order: &'a MergeRowOrder,
}
impl<'a, T: Copy + PartialOrd + Debug> Iterable<T> for MergedColumnValues<'a, T> {
fn boxed_iter(&self) -> Box<dyn Iterator<Item = T> + '_> {
match self.merge_row_order {
MergeRowOrder::Stack(_) => Box::new(
self.column_values
.iter()
.flatten()
.flat_map(|column_value| column_value.iter()),
),
MergeRowOrder::Shuffled(shuffle_merge_order) => Box::new(
shuffle_merge_order
.iter_new_to_old_row_addrs()
.flat_map(|row_addr| {
let column_index = &self.column_indexes[row_addr.segment_ord as usize];
let column_values =
self.column_values[row_addr.segment_ord as usize].as_ref()?;
let value_range = column_index.value_row_ids(row_addr.row_id);
Some((value_range, column_values))
})
.flat_map(|(value_range, column_values)| {
value_range
.into_iter()
.map(|val| column_values.get_val(val))
}),
),
}
}
}

View File

@@ -7,260 +7,202 @@
//! - Monotonically map values to u64/u128
use std::fmt::Debug;
use std::io;
use std::io::Write;
use std::ops::{Range, RangeInclusive};
use std::sync::Arc;
use common::{BinarySerializable, OwnedBytes};
use compact_space::CompactSpaceDecompressor;
pub use monotonic_mapping::{MonotonicallyMappableToU64, StrictlyMonotonicFn};
use monotonic_mapping::{StrictlyMonotonicMappingInverter, StrictlyMonotonicMappingToInternal};
pub use monotonic_mapping_u128::MonotonicallyMappableToU128;
use serialize::U128Header;
mod compact_space;
mod merge;
pub(crate) mod monotonic_mapping;
pub(crate) mod monotonic_mapping_u128;
mod stats;
pub(crate) mod u64_based;
mod u128_based;
mod u64_based;
mod vec_column;
mod column;
pub(crate) mod serialize;
mod monotonic_column;
pub use serialize::serialize_column_values_u128;
pub(crate) use merge::MergedColumnValues;
pub use stats::ColumnStats;
pub use u128_based::{open_u128_mapped, serialize_column_values_u128};
pub use u64_based::{
load_u64_based_column_values, serialize_and_load_u64_based_column_values,
serialize_u64_based_column_values, CodecType, ALL_U64_CODEC_TYPES,
};
pub use vec_column::VecColumn;
pub use self::column::{monotonic_map_column, ColumnValues, IterColumn, VecColumn};
use crate::iterable::Iterable;
use crate::{ColumnIndex, MergeRowOrder};
pub use self::monotonic_column::monotonic_map_column;
use crate::RowId;
pub(crate) struct MergedColumnValues<'a, T> {
pub(crate) column_indexes: &'a [Option<ColumnIndex>],
pub(crate) column_values: &'a [Option<Arc<dyn ColumnValues<T>>>],
pub(crate) merge_row_order: &'a MergeRowOrder,
}
/// `ColumnValues` provides access to a dense field column.
///
/// `Column` are just a wrapper over `ColumnValues` and a `ColumnIndex`.
///
/// Any methods with a default and specialized implementation need to be called in the
/// wrappers that implement the trait: Arc and MonotonicMappingColumn
pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync {
/// Return the value associated with the given idx.
///
/// This accessor should return as fast as possible.
///
/// # Panics
///
/// May panic if `idx` is greater than the column length.
fn get_val(&self, idx: u32) -> T;
impl<'a, T: Copy + PartialOrd + Debug> Iterable<T> for MergedColumnValues<'a, T> {
fn boxed_iter(&self) -> Box<dyn Iterator<Item = T> + '_> {
match self.merge_row_order {
MergeRowOrder::Stack(_) => {
Box::new(self
.column_values
.iter()
.flatten()
.flat_map(|column_value| column_value.iter()))
},
MergeRowOrder::Shuffled(shuffle_merge_order) => {
Box::new(shuffle_merge_order
.iter_new_to_old_row_addrs()
.flat_map(|row_addr| {
let Some(column_index) = self.column_indexes[row_addr.segment_ord as usize].as_ref() else {
return None;
};
let Some(column_values) = self.column_values[row_addr.segment_ord as usize].as_ref() else {
return None;
};
let value_range = column_index.value_row_ids(row_addr.row_id);
Some((value_range, column_values))
})
.flat_map(|(value_range, column_values)| {
value_range
.into_iter()
.map(|val| column_values.get_val(val))
})
)
},
/// Allows to push down multiple fetch calls, to avoid dynamic dispatch overhead.
///
/// idx and output should have the same length
///
/// # Panics
///
/// May panic if `idx` is greater than the column length.
fn get_vals(&self, indexes: &[u32], output: &mut [T]) {
assert!(indexes.len() == output.len());
let out_and_idx_chunks = output.chunks_exact_mut(4).zip(indexes.chunks_exact(4));
for (out_x4, idx_x4) in out_and_idx_chunks {
out_x4[0] = self.get_val(idx_x4[0]);
out_x4[1] = self.get_val(idx_x4[1]);
out_x4[2] = self.get_val(idx_x4[2]);
out_x4[3] = self.get_val(idx_x4[3]);
}
let step_size = 4;
let cutoff = indexes.len() - indexes.len() % step_size;
for idx in cutoff..indexes.len() {
output[idx] = self.get_val(indexes[idx]);
}
}
}
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy)]
#[repr(u8)]
/// Available codecs to use to encode the u128 (via [`MonotonicallyMappableToU128`]) converted data.
pub enum U128FastFieldCodecType {
/// This codec takes a large number space (u128) and reduces it to a compact number space, by
/// removing the holes.
CompactSpace = 1,
}
impl BinarySerializable for U128FastFieldCodecType {
fn serialize<W: Write + ?Sized>(&self, wrt: &mut W) -> io::Result<()> {
self.to_code().serialize(wrt)
}
fn deserialize<R: io::Read>(reader: &mut R) -> io::Result<Self> {
let code = u8::deserialize(reader)?;
let codec_type: Self = Self::from_code(code)
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Unknown code `{code}.`"))?;
Ok(codec_type)
}
}
impl U128FastFieldCodecType {
pub(crate) fn to_code(self) -> u8 {
self as u8
}
pub(crate) fn from_code(code: u8) -> Option<Self> {
match code {
1 => Some(Self::CompactSpace),
_ => None,
/// Fills an output buffer with the fast field values
/// associated with the `DocId` going from
/// `start` to `start + output.len()`.
///
/// # Panics
///
/// Must panic if `start + output.len()` is greater than
/// the segment's `maxdoc`.
#[inline(always)]
fn get_range(&self, start: u64, output: &mut [T]) {
for (out, idx) in output.iter_mut().zip(start..) {
*out = self.get_val(idx as u32);
}
}
/// Get the row ids of values which are in the provided value range.
///
/// Note that position == docid for single value fast fields
fn get_row_ids_for_value_range(
&self,
value_range: RangeInclusive<T>,
row_id_range: Range<RowId>,
row_id_hits: &mut Vec<RowId>,
) {
let row_id_range = row_id_range.start..row_id_range.end.min(self.num_vals());
for idx in row_id_range.start..row_id_range.end {
let val = self.get_val(idx);
if value_range.contains(&val) {
row_id_hits.push(idx);
}
}
}
/// Returns a lower bound for this column of values.
///
/// All values are guaranteed to be higher than `.min_value()`
/// but this value is not necessary the best boundary value.
///
/// We have
/// ∀i < self.num_vals(), self.get_val(i) >= self.min_value()
/// But we don't have necessarily
/// ∃i < self.num_vals(), self.get_val(i) == self.min_value()
fn min_value(&self) -> T;
/// Returns an upper bound for this column of values.
///
/// All values are guaranteed to be lower than `.max_value()`
/// but this value is not necessary the best boundary value.
///
/// We have
/// ∀i < self.num_vals(), self.get_val(i) <= self.max_value()
/// But we don't have necessarily
/// ∃i < self.num_vals(), self.get_val(i) == self.max_value()
fn max_value(&self) -> T;
/// The number of values in the column.
fn num_vals(&self) -> u32;
/// Returns a iterator over the data
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = T> + 'a> {
Box::new((0..self.num_vals()).map(|idx| self.get_val(idx)))
}
}
/// Returns the correct codec reader wrapped in the `Arc` for the data.
pub fn open_u128_mapped<T: MonotonicallyMappableToU128 + Debug>(
mut bytes: OwnedBytes,
) -> io::Result<Arc<dyn ColumnValues<T>>> {
let header = U128Header::deserialize(&mut bytes)?;
assert_eq!(header.codec_type, U128FastFieldCodecType::CompactSpace);
let reader = CompactSpaceDecompressor::open(bytes)?;
/// Empty column of values.
pub struct EmptyColumnValues;
let inverted: StrictlyMonotonicMappingInverter<StrictlyMonotonicMappingToInternal<T>> =
StrictlyMonotonicMappingToInternal::<T>::new().into();
Ok(Arc::new(monotonic_map_column(reader, inverted)))
impl<T: PartialOrd + Default> ColumnValues<T> for EmptyColumnValues {
fn get_val(&self, _idx: u32) -> T {
panic!("Internal Error: Called get_val of empty column.")
}
fn min_value(&self) -> T {
T::default()
}
fn max_value(&self) -> T {
T::default()
}
fn num_vals(&self) -> u32 {
0
}
}
impl<T: Copy + PartialOrd + Debug> ColumnValues<T> for Arc<dyn ColumnValues<T>> {
#[inline(always)]
fn get_val(&self, idx: u32) -> T {
self.as_ref().get_val(idx)
}
#[inline(always)]
fn min_value(&self) -> T {
self.as_ref().min_value()
}
#[inline(always)]
fn max_value(&self) -> T {
self.as_ref().max_value()
}
#[inline(always)]
fn num_vals(&self) -> u32 {
self.as_ref().num_vals()
}
#[inline(always)]
fn iter<'b>(&'b self) -> Box<dyn Iterator<Item = T> + 'b> {
self.as_ref().iter()
}
#[inline(always)]
fn get_range(&self, start: u64, output: &mut [T]) {
self.as_ref().get_range(start, output)
}
#[inline(always)]
fn get_row_ids_for_value_range(
&self,
range: RangeInclusive<T>,
doc_id_range: Range<u32>,
positions: &mut Vec<u32>,
) {
self.as_ref()
.get_row_ids_for_value_range(range, doc_id_range, positions)
}
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
use std::sync::Arc;
use common::OwnedBytes;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use test::{self, Bencher};
use super::*;
use crate::column_values::u64_based::*;
fn get_data() -> Vec<u64> {
let mut rng = StdRng::seed_from_u64(2u64);
let mut data: Vec<_> = (100..55000_u64)
.map(|num| num + rng.gen::<u8>() as u64)
.collect();
data.push(99_000);
data.insert(1000, 2000);
data.insert(2000, 100);
data.insert(3000, 4100);
data.insert(4000, 100);
data.insert(5000, 800);
data
}
fn compute_stats(vals: impl Iterator<Item = u64>) -> ColumnStats {
let mut stats_collector = StatsCollector::default();
for val in vals {
stats_collector.collect(val);
}
stats_collector.stats()
}
#[inline(never)]
fn value_iter() -> impl Iterator<Item = u64> {
0..20_000
}
fn get_reader_for_bench<Codec: ColumnCodec>(data: &[u64]) -> Codec::ColumnValues {
let mut bytes = Vec::new();
let stats = compute_stats(data.iter().cloned());
let mut codec_serializer = Codec::estimator();
for val in data {
codec_serializer.collect(*val);
}
codec_serializer.serialize(&stats, Box::new(data.iter().copied()).as_mut(), &mut bytes);
Codec::load(OwnedBytes::new(bytes)).unwrap()
}
fn bench_get<Codec: ColumnCodec>(b: &mut Bencher, data: &[u64]) {
let col = get_reader_for_bench::<Codec>(data);
b.iter(|| {
let mut sum = 0u64;
for pos in value_iter() {
let val = col.get_val(pos as u32);
sum = sum.wrapping_add(val);
}
sum
});
}
#[inline(never)]
fn bench_get_dynamic_helper(b: &mut Bencher, col: Arc<dyn ColumnValues>) {
b.iter(|| {
let mut sum = 0u64;
for pos in value_iter() {
let val = col.get_val(pos as u32);
sum = sum.wrapping_add(val);
}
sum
});
}
fn bench_get_dynamic<Codec: ColumnCodec>(b: &mut Bencher, data: &[u64]) {
let col = Arc::new(get_reader_for_bench::<Codec>(data));
bench_get_dynamic_helper(b, col);
}
fn bench_create<Codec: ColumnCodec>(b: &mut Bencher, data: &[u64]) {
let stats = compute_stats(data.iter().cloned());
let mut bytes = Vec::new();
b.iter(|| {
bytes.clear();
let mut codec_serializer = Codec::estimator();
for val in data.iter().take(1024) {
codec_serializer.collect(*val);
}
codec_serializer.serialize(&stats, Box::new(data.iter().copied()).as_mut(), &mut bytes)
});
}
#[bench]
fn bench_fastfield_bitpack_create(b: &mut Bencher) {
let data: Vec<_> = get_data();
bench_create::<BitpackedCodec>(b, &data);
}
#[bench]
fn bench_fastfield_linearinterpol_create(b: &mut Bencher) {
let data: Vec<_> = get_data();
bench_create::<LinearCodec>(b, &data);
}
#[bench]
fn bench_fastfield_multilinearinterpol_create(b: &mut Bencher) {
let data: Vec<_> = get_data();
bench_create::<BlockwiseLinearCodec>(b, &data);
}
#[bench]
fn bench_fastfield_bitpack_get(b: &mut Bencher) {
let data: Vec<_> = get_data();
bench_get::<BitpackedCodec>(b, &data);
}
#[bench]
fn bench_fastfield_bitpack_get_dynamic(b: &mut Bencher) {
let data: Vec<_> = get_data();
bench_get_dynamic::<BitpackedCodec>(b, &data);
}
#[bench]
fn bench_fastfield_linearinterpol_get(b: &mut Bencher) {
let data: Vec<_> = get_data();
bench_get::<LinearCodec>(b, &data);
}
#[bench]
fn bench_fastfield_linearinterpol_get_dynamic(b: &mut Bencher) {
let data: Vec<_> = get_data();
bench_get_dynamic::<LinearCodec>(b, &data);
}
#[bench]
fn bench_fastfield_multilinearinterpol_get(b: &mut Bencher) {
let data: Vec<_> = get_data();
bench_get::<BlockwiseLinearCodec>(b, &data);
}
#[bench]
fn bench_fastfield_multilinearinterpol_get_dynamic(b: &mut Bencher) {
let data: Vec<_> = get_data();
bench_get_dynamic::<BlockwiseLinearCodec>(b, &data);
}
}
mod bench;

View File

@@ -0,0 +1,120 @@
use std::fmt::Debug;
use std::marker::PhantomData;
use std::ops::{Range, RangeInclusive};
use crate::column_values::monotonic_mapping::StrictlyMonotonicFn;
use crate::ColumnValues;
struct MonotonicMappingColumn<C, T, Input> {
from_column: C,
monotonic_mapping: T,
_phantom: PhantomData<Input>,
}
/// Creates a view of a column transformed by a strictly monotonic mapping. See
/// [`StrictlyMonotonicFn`].
///
/// E.g. apply a gcd monotonic_mapping([100, 200, 300]) == [1, 2, 3]
/// monotonic_mapping.mapping() is expected to be injective, and we should always have
/// monotonic_mapping.inverse(monotonic_mapping.mapping(el)) == el
///
/// The inverse of the mapping is required for:
/// `fn get_positions_for_value_range(&self, range: RangeInclusive<T>) -> Vec<u64> `
/// The user provides the original value range and we need to monotonic map them in the same way the
/// serialization does before calling the underlying column.
///
/// Note that when opening a codec, the monotonic_mapping should be the inverse of the mapping
/// during serialization. And therefore the monotonic_mapping_inv when opening is the same as
/// monotonic_mapping during serialization.
pub fn monotonic_map_column<C, T, Input, Output>(
from_column: C,
monotonic_mapping: T,
) -> impl ColumnValues<Output>
where
C: ColumnValues<Input>,
T: StrictlyMonotonicFn<Input, Output> + Send + Sync,
Input: PartialOrd + Debug + Send + Sync + Clone,
Output: PartialOrd + Debug + Send + Sync + Clone,
{
MonotonicMappingColumn {
from_column,
monotonic_mapping,
_phantom: PhantomData,
}
}
impl<C, T, Input, Output> ColumnValues<Output> for MonotonicMappingColumn<C, T, Input>
where
C: ColumnValues<Input>,
T: StrictlyMonotonicFn<Input, Output> + Send + Sync,
Input: PartialOrd + Send + Debug + Sync + Clone,
Output: PartialOrd + Send + Debug + Sync + Clone,
{
#[inline(always)]
fn get_val(&self, idx: u32) -> Output {
let from_val = self.from_column.get_val(idx);
self.monotonic_mapping.mapping(from_val)
}
fn min_value(&self) -> Output {
let from_min_value = self.from_column.min_value();
self.monotonic_mapping.mapping(from_min_value)
}
fn max_value(&self) -> Output {
let from_max_value = self.from_column.max_value();
self.monotonic_mapping.mapping(from_max_value)
}
fn num_vals(&self) -> u32 {
self.from_column.num_vals()
}
fn iter(&self) -> Box<dyn Iterator<Item = Output> + '_> {
Box::new(
self.from_column
.iter()
.map(|el| self.monotonic_mapping.mapping(el)),
)
}
fn get_row_ids_for_value_range(
&self,
range: RangeInclusive<Output>,
doc_id_range: Range<u32>,
positions: &mut Vec<u32>,
) {
self.from_column.get_row_ids_for_value_range(
self.monotonic_mapping.inverse(range.start().clone())
..=self.monotonic_mapping.inverse(range.end().clone()),
doc_id_range,
positions,
)
}
// We voluntarily do not implement get_range as it yields a regression,
// and we do not have any specialized implementation anyway.
}
#[cfg(test)]
mod tests {
use super::*;
use crate::column_values::monotonic_mapping::{
StrictlyMonotonicMappingInverter, StrictlyMonotonicMappingToInternal,
};
use crate::column_values::VecColumn;
#[test]
fn test_monotonic_mapping_iter() {
let vals: Vec<u64> = (0..100u64).map(|el| el * 10).collect();
let col = VecColumn::from(&vals);
let mapped = monotonic_map_column(
col,
StrictlyMonotonicMappingInverter::from(StrictlyMonotonicMappingToInternal::<i64>::new()),
);
let val_i64s: Vec<u64> = mapped.iter().collect();
for i in 0..100 {
assert_eq!(val_i64s[i as usize], mapped.get_val(i));
}
}
}

View File

@@ -2,7 +2,6 @@ use std::fmt::Debug;
use std::marker::PhantomData;
use common::DateTime;
use fastdivide::DividerU64;
use super::MonotonicallyMappableToU128;
use crate::RowId;
@@ -113,68 +112,6 @@ where T: MonotonicallyMappableToU64
}
}
/// Mapping dividing by gcd and a base value.
///
/// The function is assumed to be only called on values divided by passed
/// gcd value. (It is necessary for the function to be monotonic.)
pub(crate) struct StrictlyMonotonicMappingToInternalGCDBaseval {
gcd_divider: DividerU64,
gcd: u64,
min_value: u64,
}
impl StrictlyMonotonicMappingToInternalGCDBaseval {
/// Creates a linear mapping `x -> gcd*x + min_value`.
pub(crate) fn new(gcd: u64, min_value: u64) -> Self {
let gcd_divider = DividerU64::divide_by(gcd);
Self {
gcd_divider,
gcd,
min_value,
}
}
}
impl<External: MonotonicallyMappableToU64> StrictlyMonotonicFn<External, u64>
for StrictlyMonotonicMappingToInternalGCDBaseval
{
#[inline(always)]
fn mapping(&self, inp: External) -> u64 {
self.gcd_divider
.divide(External::to_u64(inp) - self.min_value)
}
#[inline(always)]
fn inverse(&self, out: u64) -> External {
External::from_u64(self.min_value + out * self.gcd)
}
}
/// Strictly monotonic mapping with a base value.
pub(crate) struct StrictlyMonotonicMappingToInternalBaseval {
min_value: u64,
}
impl StrictlyMonotonicMappingToInternalBaseval {
/// Creates a linear mapping `x -> x + min_value`.
#[inline(always)]
pub(crate) fn new(min_value: u64) -> Self {
Self { min_value }
}
}
impl<External: MonotonicallyMappableToU64> StrictlyMonotonicFn<External, u64>
for StrictlyMonotonicMappingToInternalBaseval
{
#[inline(always)]
fn mapping(&self, val: External) -> u64 {
External::to_u64(val) - self.min_value
}
#[inline(always)]
fn inverse(&self, val: u64) -> External {
External::from_u64(self.min_value + val)
}
}
impl MonotonicallyMappableToU64 for u64 {
#[inline(always)]
fn to_u64(self) -> u64 {
@@ -263,13 +200,6 @@ mod tests {
// TODO
// identity mapping
// test_round_trip(&StrictlyMonotonicMappingToInternal::<u128>::new(), 100u128);
// base value to i64 round trip
let mapping = StrictlyMonotonicMappingToInternalBaseval::new(100);
test_round_trip::<_, _, u64>(&mapping, 100i64);
// base value and gcd to u64 round trip
let mapping = StrictlyMonotonicMappingToInternalGCDBaseval::new(10, 100);
test_round_trip::<_, _, u64>(&mapping, 100u64);
}
fn test_round_trip<T: StrictlyMonotonicFn<K, L>, K: std::fmt::Debug + Eq + Copy, L>(

View File

@@ -10,7 +10,7 @@ use super::{CompactSpace, RangeMapping};
/// Put the blanks for the sorted values into a binary heap
fn get_blanks(values_sorted: &BTreeSet<u128>) -> BinaryHeap<BlankRange> {
let mut blanks: BinaryHeap<BlankRange> = BinaryHeap::new();
for (first, second) in values_sorted.iter().tuple_windows() {
for (first, second) in values_sorted.iter().copied().tuple_windows() {
// Correctness Overflow: the values are deduped and sorted (BTreeSet property), that means
// there's always space between two values.
let blank_range = first + 1..=second - 1;
@@ -65,12 +65,12 @@ pub fn get_compact_space(
return compact_space_builder.finish();
}
let mut blanks: BinaryHeap<BlankRange> = get_blanks(values_deduped_sorted);
// Replace after stabilization of https://github.com/rust-lang/rust/issues/62924
// We start by space that's limited to min_value..=max_value
let min_value = *values_deduped_sorted.iter().next().unwrap_or(&0);
let max_value = *values_deduped_sorted.iter().last().unwrap_or(&0);
// Replace after stabilization of https://github.com/rust-lang/rust/issues/62924
let min_value = values_deduped_sorted.iter().next().copied().unwrap_or(0);
let max_value = values_deduped_sorted.iter().last().copied().unwrap_or(0);
let mut blanks: BinaryHeap<BlankRange> = get_blanks(values_deduped_sorted);
// +1 for null, in case min and max covers the whole space, we are off by one.
let mut amplitude_compact_space = (max_value - min_value).saturating_add(1);
@@ -84,6 +84,7 @@ pub fn get_compact_space(
let mut amplitude_bits: u8 = num_bits(amplitude_compact_space);
let mut blank_collector = BlankCollector::new();
// We will stage blanks until they reduce the compact space by at least 1 bit and then flush
// them if the metadata cost is lower than the total number of saved bits.
// Binary heap to process the gaps by their size
@@ -93,6 +94,7 @@ pub fn get_compact_space(
let staged_spaces_sum: u128 = blank_collector.staged_blanks_sum();
let amplitude_new_compact_space = amplitude_compact_space - staged_spaces_sum;
let amplitude_new_bits = num_bits(amplitude_new_compact_space);
if amplitude_bits == amplitude_new_bits {
continue;
}
@@ -100,7 +102,16 @@ pub fn get_compact_space(
// TODO: Maybe calculate exact cost of blanks and run this more expensive computation only,
// when amplitude_new_bits changes
let cost = blank_collector.num_staged_blanks() * cost_per_blank;
if cost >= saved_bits {
// We want to end up with a compact space that fits into 32 bits.
// In order to deal with pathological cases, we force the algorithm to keep
// refining the compact space the amplitude bits is lower than 32.
//
// The worst case scenario happens for a large number of u128s regularly
// spread over the full u128 space.
//
// This change will force the algorithm to degenerate into dictionary encoding.
if amplitude_bits <= 32 && cost >= saved_bits {
// Continue here, since although we walk over the blanks by size,
// we can potentially save a lot at the last bits, which are smaller blanks
//
@@ -115,6 +126,8 @@ pub fn get_compact_space(
compact_space_builder.add_blanks(blank_collector.drain().map(|blank| blank.blank_range()));
}
assert!(amplitude_bits <= 32);
// special case, when we don't collected any blanks because:
// * the data is empty (early exit)
// * the algorithm did decide it's not worth the cost, which can be the case for single values
@@ -199,7 +212,7 @@ impl CompactSpaceBuilder {
covered_space.push(0..=0); // empty data case
};
let mut compact_start: u64 = 1; // 0 is reserved for `null`
let mut compact_start: u32 = 1; // 0 is reserved for `null`
let mut ranges_mapping: Vec<RangeMapping> = Vec::with_capacity(covered_space.len());
for cov in covered_space {
let range_mapping = super::RangeMapping {
@@ -218,6 +231,7 @@ impl CompactSpaceBuilder {
#[cfg(test)]
mod tests {
use super::*;
use crate::column_values::u128_based::compact_space::COST_PER_BLANK_IN_BITS;
#[test]
fn test_binary_heap_pop_order() {
@@ -228,4 +242,11 @@ mod tests {
assert_eq!(blanks.pop().unwrap().blank_size(), 101);
assert_eq!(blanks.pop().unwrap().blank_size(), 11);
}
#[test]
fn test_worst_case_scenario() {
let vals: BTreeSet<u128> = (0..8).map(|i| i * ((1u128 << 34) / 8)).collect();
let compact_space = get_compact_space(&vals, vals.len() as u32, COST_PER_BLANK_IN_BITS);
assert!(compact_space.amplitude_compact_space() < u32::MAX as u128);
}
}

View File

@@ -17,16 +17,16 @@ use std::{
ops::{Range, RangeInclusive},
};
mod blank_range;
mod build_compact_space;
use build_compact_space::get_compact_space;
use common::{BinarySerializable, CountingWriter, OwnedBytes, VInt, VIntU128};
use tantivy_bitpacker::{self, BitPacker, BitUnpacker};
use crate::column_values::compact_space::build_compact_space::get_compact_space;
use crate::column_values::ColumnValues;
use crate::RowId;
mod blank_range;
mod build_compact_space;
/// The cost per blank is quite hard actually, since blanks are delta encoded, the actual cost of
/// blanks depends on the number of blanks.
///
@@ -42,15 +42,15 @@ pub struct CompactSpace {
#[derive(Debug, Clone, Eq, PartialEq)]
struct RangeMapping {
value_range: RangeInclusive<u128>,
compact_start: u64,
compact_start: u32,
}
impl RangeMapping {
fn range_length(&self) -> u64 {
(self.value_range.end() - self.value_range.start()) as u64 + 1
fn range_length(&self) -> u32 {
(self.value_range.end() - self.value_range.start()) as u32 + 1
}
// The last value of the compact space in this range
fn compact_end(&self) -> u64 {
fn compact_end(&self) -> u32 {
self.compact_start + self.range_length() - 1
}
}
@@ -81,7 +81,7 @@ impl BinarySerializable for CompactSpace {
let num_ranges = VInt::deserialize(reader)?.0;
let mut ranges_mapping: Vec<RangeMapping> = vec![];
let mut value = 0u128;
let mut compact_start = 1u64; // 0 is reserved for `null`
let mut compact_start = 1u32; // 0 is reserved for `null`
for _ in 0..num_ranges {
let blank_delta_start = VIntU128::deserialize(reader)?.0;
value += blank_delta_start;
@@ -122,10 +122,10 @@ impl CompactSpace {
/// Returns either Ok(the value in the compact space) or if it is outside the compact space the
/// Err(position where it would be inserted)
fn u128_to_compact(&self, value: u128) -> Result<u64, usize> {
fn u128_to_compact(&self, value: u128) -> Result<u32, usize> {
self.ranges_mapping
.binary_search_by(|probe| {
let value_range = &probe.value_range;
let value_range: &RangeInclusive<u128> = &probe.value_range;
if value < *value_range.start() {
Ordering::Greater
} else if value > *value_range.end() {
@@ -136,13 +136,13 @@ impl CompactSpace {
})
.map(|pos| {
let range_mapping = &self.ranges_mapping[pos];
let pos_in_range = (value - range_mapping.value_range.start()) as u64;
let pos_in_range: u32 = (value - range_mapping.value_range.start()) as u32;
range_mapping.compact_start + pos_in_range
})
}
/// Unpacks a value from compact space u64 to u128 space
fn compact_to_u128(&self, compact: u64) -> u128 {
/// Unpacks a value from compact space u32 to u128 space
fn compact_to_u128(&self, compact: u32) -> u128 {
let pos = self
.ranges_mapping
.binary_search_by_key(&compact, |range_mapping| range_mapping.compact_start)
@@ -178,11 +178,15 @@ impl CompactSpaceCompressor {
/// Taking the vals as Vec may cost a lot of memory. It is used to sort the vals.
pub fn train_from(iter: impl Iterator<Item = u128>) -> Self {
let mut values_sorted = BTreeSet::new();
// Total number of values, with their redundancy.
let mut total_num_values = 0u32;
for val in iter {
total_num_values += 1u32;
values_sorted.insert(val);
}
let min_value = *values_sorted.iter().next().unwrap_or(&0);
let max_value = *values_sorted.iter().last().unwrap_or(&0);
let compact_space =
get_compact_space(&values_sorted, total_num_values, COST_PER_BLANK_IN_BITS);
let amplitude_compact_space = compact_space.amplitude_compact_space();
@@ -193,13 +197,12 @@ impl CompactSpaceCompressor {
);
let num_bits = tantivy_bitpacker::compute_num_bits(amplitude_compact_space as u64);
let min_value = *values_sorted.iter().next().unwrap_or(&0);
let max_value = *values_sorted.iter().last().unwrap_or(&0);
assert_eq!(
compact_space
.u128_to_compact(max_value)
.expect("could not convert max value to compact space"),
amplitude_compact_space as u64
amplitude_compact_space as u32
);
CompactSpaceCompressor {
params: IPCodecParams {
@@ -240,7 +243,7 @@ impl CompactSpaceCompressor {
"Could not convert value to compact_space. This is a bug.",
)
})?;
bitpacker.write(compact, self.params.num_bits, write)?;
bitpacker.write(compact as u64, self.params.num_bits, write)?;
}
bitpacker.close(write)?;
self.write_footer(write)?;
@@ -314,48 +317,6 @@ impl ColumnValues<u128> for CompactSpaceDecompressor {
#[inline]
fn get_row_ids_for_value_range(
&self,
value_range: RangeInclusive<u128>,
positions_range: Range<u32>,
positions: &mut Vec<u32>,
) {
self.get_positions_for_value_range(value_range, positions_range, positions)
}
}
impl CompactSpaceDecompressor {
pub fn open(data: OwnedBytes) -> io::Result<CompactSpaceDecompressor> {
let (data_slice, footer_len_bytes) = data.split_at(data.len() - 4);
let footer_len = u32::deserialize(&mut &footer_len_bytes[..])?;
let data_footer = &data_slice[data_slice.len() - footer_len as usize..];
let params = IPCodecParams::deserialize(&mut &data_footer[..])?;
let decompressor = CompactSpaceDecompressor { data, params };
Ok(decompressor)
}
/// Converting to compact space for the decompressor is more complex, since we may get values
/// which are outside the compact space. e.g. if we map
/// 1000 => 5
/// 2000 => 6
///
/// and we want a mapping for 1005, there is no equivalent compact space. We instead return an
/// error with the index of the next range.
fn u128_to_compact(&self, value: u128) -> Result<u64, usize> {
self.params.compact_space.u128_to_compact(value)
}
fn compact_to_u128(&self, compact: u64) -> u128 {
self.params.compact_space.compact_to_u128(compact)
}
/// Comparing on compact space: Random dataset 0,24 (50% random hit) - 1.05 GElements/s
/// Comparing on compact space: Real dataset 1.08 GElements/s
///
/// Comparing on original space: Real dataset .06 GElements/s (not completely optimized)
#[inline]
pub fn get_positions_for_value_range(
&self,
value_range: RangeInclusive<u128>,
position_range: Range<u32>,
@@ -395,44 +356,42 @@ impl CompactSpaceDecompressor {
range_mapping.compact_end()
});
let range = compact_from..=compact_to;
let value_range = compact_from..=compact_to;
self.get_positions_for_compact_value_range(value_range, position_range, positions);
}
}
let scan_num_docs = position_range.end - position_range.start;
impl CompactSpaceDecompressor {
pub fn open(data: OwnedBytes) -> io::Result<CompactSpaceDecompressor> {
let (data_slice, footer_len_bytes) = data.split_at(data.len() - 4);
let footer_len = u32::deserialize(&mut &footer_len_bytes[..])?;
let step_size = 4;
let cutoff = position_range.start + scan_num_docs - scan_num_docs % step_size;
let data_footer = &data_slice[data_slice.len() - footer_len as usize..];
let params = IPCodecParams::deserialize(&mut &data_footer[..])?;
let decompressor = CompactSpaceDecompressor { data, params };
let mut push_if_in_range = |idx, val| {
if range.contains(&val) {
positions.push(idx);
}
};
let get_val = |idx| self.params.bit_unpacker.get(idx, &self.data);
// unrolled loop
for idx in (position_range.start..cutoff).step_by(step_size as usize) {
let idx1 = idx;
let idx2 = idx + 1;
let idx3 = idx + 2;
let idx4 = idx + 3;
let val1 = get_val(idx1);
let val2 = get_val(idx2);
let val3 = get_val(idx3);
let val4 = get_val(idx4);
push_if_in_range(idx1, val1);
push_if_in_range(idx2, val2);
push_if_in_range(idx3, val3);
push_if_in_range(idx4, val4);
}
Ok(decompressor)
}
// handle rest
for idx in cutoff..position_range.end {
push_if_in_range(idx, get_val(idx));
}
/// Converting to compact space for the decompressor is more complex, since we may get values
/// which are outside the compact space. e.g. if we map
/// 1000 => 5
/// 2000 => 6
///
/// and we want a mapping for 1005, there is no equivalent compact space. We instead return an
/// error with the index of the next range.
fn u128_to_compact(&self, value: u128) -> Result<u32, usize> {
self.params.compact_space.u128_to_compact(value)
}
fn compact_to_u128(&self, compact: u32) -> u128 {
self.params.compact_space.compact_to_u128(compact)
}
#[inline]
fn iter_compact(&self) -> impl Iterator<Item = u64> + '_ {
(0..self.params.num_vals).map(move |idx| self.params.bit_unpacker.get(idx, &self.data))
fn iter_compact(&self) -> impl Iterator<Item = u32> + '_ {
(0..self.params.num_vals)
.map(move |idx| self.params.bit_unpacker.get(idx, &self.data) as u32)
}
#[inline]
@@ -445,7 +404,7 @@ impl CompactSpaceDecompressor {
#[inline]
pub fn get(&self, idx: u32) -> u128 {
let compact = self.params.bit_unpacker.get(idx, &self.data);
let compact = self.params.bit_unpacker.get(idx, &self.data) as u32;
self.compact_to_u128(compact)
}
@@ -456,6 +415,20 @@ impl CompactSpaceDecompressor {
pub fn max_value(&self) -> u128 {
self.params.max_value
}
fn get_positions_for_compact_value_range(
&self,
value_range: RangeInclusive<u32>,
position_range: Range<u32>,
positions: &mut Vec<u32>,
) {
self.params.bit_unpacker.get_ids_for_value_range(
*value_range.start() as u64..=*value_range.end() as u64,
position_range,
&self.data,
positions,
);
}
}
#[cfg(test)]
@@ -464,17 +437,17 @@ mod tests {
use itertools::Itertools;
use super::*;
use crate::column_values::serialize::U128Header;
use crate::column_values::u128_based::U128Header;
use crate::column_values::{open_u128_mapped, serialize_column_values_u128};
#[test]
fn compact_space_test() {
let ips = &[
let ips: BTreeSet<u128> = [
2u128, 4u128, 1000, 1001, 1002, 1003, 1004, 1005, 1008, 1010, 1012, 1260,
]
.into_iter()
.collect();
let compact_space = get_compact_space(ips, ips.len() as u32, 11);
let compact_space = get_compact_space(&ips, ips.len() as u32, 11);
let amplitude = compact_space.amplitude_compact_space();
assert_eq!(amplitude, 17);
assert_eq!(1, compact_space.u128_to_compact(2).unwrap());
@@ -497,8 +470,8 @@ mod tests {
);
for ip in ips {
let compact = compact_space.u128_to_compact(*ip).unwrap();
assert_eq!(compact_space.compact_to_u128(compact), *ip);
let compact = compact_space.u128_to_compact(ip).unwrap();
assert_eq!(compact_space.compact_to_u128(compact), ip);
}
}
@@ -524,7 +497,7 @@ mod tests {
.map(|pos| pos as u32)
.collect::<Vec<_>>();
let mut positions = Vec::new();
decompressor.get_positions_for_value_range(
decompressor.get_row_ids_for_value_range(
range,
0..decompressor.num_vals(),
&mut positions,
@@ -569,7 +542,7 @@ mod tests {
let val = *val;
let pos = pos as u32;
let mut positions = Vec::new();
decomp.get_positions_for_value_range(val..=val, pos..pos + 1, &mut positions);
decomp.get_row_ids_for_value_range(val..=val, pos..pos + 1, &mut positions);
assert_eq!(positions, vec![pos]);
}

View File

@@ -1,12 +1,19 @@
use std::fmt::Debug;
use std::io;
use std::io::Write;
use std::sync::Arc;
use common::{BinarySerializable, VInt};
mod compact_space;
use crate::column_values::compact_space::CompactSpaceCompressor;
use crate::column_values::U128FastFieldCodecType;
use common::{BinarySerializable, OwnedBytes, VInt};
use compact_space::{CompactSpaceCompressor, CompactSpaceDecompressor};
use crate::column_values::monotonic_map_column;
use crate::column_values::monotonic_mapping::{
StrictlyMonotonicMappingInverter, StrictlyMonotonicMappingToInternal,
};
use crate::iterable::Iterable;
use crate::MonotonicallyMappableToU128;
use crate::{ColumnValues, MonotonicallyMappableToU128};
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) struct U128Header {
@@ -55,6 +62,52 @@ pub fn serialize_column_values_u128<T: MonotonicallyMappableToU128>(
Ok(())
}
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy)]
#[repr(u8)]
/// Available codecs to use to encode the u128 (via [`MonotonicallyMappableToU128`]) converted data.
pub(crate) enum U128FastFieldCodecType {
/// This codec takes a large number space (u128) and reduces it to a compact number space, by
/// removing the holes.
CompactSpace = 1,
}
impl BinarySerializable for U128FastFieldCodecType {
fn serialize<W: Write + ?Sized>(&self, wrt: &mut W) -> io::Result<()> {
self.to_code().serialize(wrt)
}
fn deserialize<R: io::Read>(reader: &mut R) -> io::Result<Self> {
let code = u8::deserialize(reader)?;
let codec_type: Self = Self::from_code(code)
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Unknown code `{code}.`"))?;
Ok(codec_type)
}
}
impl U128FastFieldCodecType {
pub(crate) fn to_code(self) -> u8 {
self as u8
}
pub(crate) fn from_code(code: u8) -> Option<Self> {
match code {
1 => Some(Self::CompactSpace),
_ => None,
}
}
}
/// Returns the correct codec reader wrapped in the `Arc` for the data.
pub fn open_u128_mapped<T: MonotonicallyMappableToU128 + Debug>(
mut bytes: OwnedBytes,
) -> io::Result<Arc<dyn ColumnValues<T>>> {
let header = U128Header::deserialize(&mut bytes)?;
assert_eq!(header.codec_type, U128FastFieldCodecType::CompactSpace);
let reader = CompactSpaceDecompressor::open(bytes)?;
let inverted: StrictlyMonotonicMappingInverter<StrictlyMonotonicMappingToInternal<T>> =
StrictlyMonotonicMappingToInternal::<T>::new().into();
Ok(Arc::new(monotonic_map_column(reader, inverted)))
}
#[cfg(test)]
pub mod tests {
use super::*;

View File

@@ -1,4 +1,6 @@
use std::io::{self, Write};
use std::num::NonZeroU64;
use std::ops::{Range, RangeInclusive};
use common::{BinarySerializable, OwnedBytes};
use fastdivide::DividerU64;
@@ -16,6 +18,46 @@ pub struct BitpackedReader {
stats: ColumnStats,
}
#[inline(always)]
const fn div_ceil(n: u64, q: NonZeroU64) -> u64 {
// copied from unstable rust standard library.
let d = n / q.get();
let r = n % q.get();
if r > 0 {
d + 1
} else {
d
}
}
// The bitpacked codec applies a linear transformation `f` over data that are bitpacked.
// f is defined by:
// f: bitpacked -> stats.min_value + stats.gcd * bitpacked
//
// In order to run range queries, we invert the transformation.
// `transform_range_before_linear_transformation` returns the range of values
// [min_bipacked_value..max_bitpacked_value] such that
// f(bitpacked) ∈ [min_value, max_value] <=> bitpacked ∈ [min_bitpacked_value, max_bitpacked_value]
fn transform_range_before_linear_transformation(
stats: &ColumnStats,
range: RangeInclusive<u64>,
) -> Option<RangeInclusive<u64>> {
if range.is_empty() {
return None;
}
if stats.min_value > *range.end() {
return None;
}
if stats.max_value < *range.start() {
return None;
}
let shifted_range =
range.start().saturating_sub(stats.min_value)..=range.end().saturating_sub(stats.min_value);
let start_before_gcd_multiplication: u64 = div_ceil(*shifted_range.start(), stats.gcd);
let end_before_gcd_multiplication: u64 = *shifted_range.end() / stats.gcd;
Some(start_before_gcd_multiplication..=end_before_gcd_multiplication)
}
impl ColumnValues for BitpackedReader {
#[inline(always)]
fn get_val(&self, doc: u32) -> u64 {
@@ -34,6 +76,25 @@ impl ColumnValues for BitpackedReader {
fn num_vals(&self) -> RowId {
self.stats.num_rows
}
fn get_row_ids_for_value_range(
&self,
range: RangeInclusive<u64>,
doc_id_range: Range<u32>,
positions: &mut Vec<u32>,
) {
let Some(transformed_range) = transform_range_before_linear_transformation(&self.stats, range)
else {
positions.clear();
return;
};
self.bit_unpacker.get_ids_for_value_range(
transformed_range,
doc_id_range,
&self.data,
positions,
);
}
}
fn num_bits(stats: &ColumnStats) -> u8 {

View File

@@ -201,8 +201,8 @@ pub struct BlockwiseLinearReader {
impl ColumnValues for BlockwiseLinearReader {
#[inline(always)]
fn get_val(&self, idx: u32) -> u64 {
let block_id = (idx / BLOCK_SIZE as u32) as usize;
let idx_within_block = idx % (BLOCK_SIZE as u32);
let block_id = (idx / BLOCK_SIZE) as usize;
let idx_within_block = idx % BLOCK_SIZE;
let block = &self.blocks[block_id];
let interpoled_val: u64 = block.line.eval(idx_within_block);
let block_bytes = &self.data[block.data_start_offset..];

View File

@@ -19,6 +19,62 @@ fn test_serialize_and_load_simple() {
assert_eq!(col.get_val(1), 2);
assert_eq!(col.get_val(2), 5);
}
#[test]
fn test_empty_column_i64() {
let vals: [i64; 0] = [];
let mut num_acceptable_codecs = 0;
for codec in ALL_U64_CODEC_TYPES {
let mut buffer = Vec::new();
if serialize_u64_based_column_values(&&vals[..], &[codec], &mut buffer).is_err() {
continue;
}
num_acceptable_codecs += 1;
let col = load_u64_based_column_values::<i64>(OwnedBytes::new(buffer)).unwrap();
assert_eq!(col.num_vals(), 0);
assert_eq!(col.min_value(), i64::MIN);
assert_eq!(col.max_value(), i64::MIN);
}
assert!(num_acceptable_codecs > 0);
}
#[test]
fn test_empty_column_u64() {
let vals: [u64; 0] = [];
let mut num_acceptable_codecs = 0;
for codec in ALL_U64_CODEC_TYPES {
let mut buffer = Vec::new();
if serialize_u64_based_column_values(&&vals[..], &[codec], &mut buffer).is_err() {
continue;
}
num_acceptable_codecs += 1;
let col = load_u64_based_column_values::<u64>(OwnedBytes::new(buffer)).unwrap();
assert_eq!(col.num_vals(), 0);
assert_eq!(col.min_value(), u64::MIN);
assert_eq!(col.max_value(), u64::MIN);
}
assert!(num_acceptable_codecs > 0);
}
#[test]
fn test_empty_column_f64() {
let vals: [f64; 0] = [];
let mut num_acceptable_codecs = 0;
for codec in ALL_U64_CODEC_TYPES {
let mut buffer = Vec::new();
if serialize_u64_based_column_values(&&vals[..], &[codec], &mut buffer).is_err() {
continue;
}
num_acceptable_codecs += 1;
let col = load_u64_based_column_values::<f64>(OwnedBytes::new(buffer)).unwrap();
assert_eq!(col.num_vals(), 0);
// FIXME. f64::MIN would be better!
assert!(col.min_value().is_nan());
assert!(col.max_value().is_nan());
}
assert!(num_acceptable_codecs > 0);
}
pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
vals: &[u64],
name: &str,
@@ -43,14 +99,28 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
let reader = TColumnCodec::load(OwnedBytes::new(buffer)).unwrap();
assert_eq!(reader.num_vals(), vals.len() as u32);
let mut buffer = Vec::new();
for (doc, orig_val) in vals.iter().copied().enumerate() {
let val = reader.get_val(doc as u32);
assert_eq!(
val, orig_val,
"val `{val}` does not match orig_val {orig_val:?}, in data set {name}, data `{vals:?}`",
);
buffer.resize(1, 0);
reader.get_vals(&[doc as u32], &mut buffer);
let val = buffer[0];
assert_eq!(
val, orig_val,
"val `{val}` does not match orig_val {orig_val:?}, in data set {name}, data `{vals:?}`",
);
}
let all_docs: Vec<u32> = (0..vals.len() as u32).collect();
buffer.resize(all_docs.len(), 0);
reader.get_vals(&all_docs, &mut buffer);
assert_eq!(vals, buffer);
if !vals.is_empty() {
let test_rand_idx = rand::thread_rng().gen_range(0..=vals.len() - 1);
let expected_positions: Vec<u32> = vals

View File

@@ -0,0 +1,52 @@
use std::fmt::Debug;
use tantivy_bitpacker::minmax;
use crate::ColumnValues;
/// VecColumn provides `Column` over a slice.
pub struct VecColumn<'a, T = u64> {
pub(crate) values: &'a [T],
pub(crate) min_value: T,
pub(crate) max_value: T,
}
impl<'a, T: Copy + PartialOrd + Send + Sync + Debug> ColumnValues<T> for VecColumn<'a, T> {
fn get_val(&self, position: u32) -> T {
self.values[position as usize]
}
fn iter(&self) -> Box<dyn Iterator<Item = T> + '_> {
Box::new(self.values.iter().copied())
}
fn min_value(&self) -> T {
self.min_value
}
fn max_value(&self) -> T {
self.max_value
}
fn num_vals(&self) -> u32 {
self.values.len() as u32
}
fn get_range(&self, start: u64, output: &mut [T]) {
output.copy_from_slice(&self.values[start as usize..][..output.len()])
}
}
impl<'a, T: Copy + PartialOrd + Default, V> From<&'a V> for VecColumn<'a, T>
where V: AsRef<[T]> + ?Sized
{
fn from(values: &'a V) -> Self {
let values = values.as_ref();
let (min_value, max_value) = minmax(values.iter().copied()).unwrap_or_default();
Self {
values,
min_value,
max_value,
}
}
}

View File

@@ -1,12 +1,15 @@
use std::fmt;
use std::fmt::Debug;
use std::net::Ipv6Addr;
use serde::{Deserialize, Serialize};
use crate::value::NumericalType;
use crate::InvalidData;
/// The column type represents the column type.
/// Any changes need to be propagated to `COLUMN_TYPES`.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Copy, Ord, PartialOrd)]
#[derive(Hash, Eq, PartialEq, Debug, Clone, Copy, Ord, PartialOrd, Serialize, Deserialize)]
#[repr(u8)]
pub enum ColumnType {
I64 = 0u8,
@@ -19,6 +22,22 @@ pub enum ColumnType {
DateTime = 7u8,
}
impl fmt::Display for ColumnType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let short_str = match self {
ColumnType::I64 => "i64",
ColumnType::U64 => "u64",
ColumnType::F64 => "f64",
ColumnType::Bytes => "bytes",
ColumnType::Str => "str",
ColumnType::Bool => "bool",
ColumnType::IpAddr => "ip",
ColumnType::DateTime => "datetime",
};
write!(f, "{}", short_str)
}
}
// The order needs to match _exactly_ the order in the enum
const COLUMN_TYPES: [ColumnType; 8] = [
ColumnType::I64,
@@ -143,7 +162,7 @@ mod tests {
}
}
for code in COLUMN_TYPES.len() as u8..=u8::MAX {
assert!(ColumnType::try_from_code(code as u8).is_err());
assert!(ColumnType::try_from_code(code).is_err());
}
}

View File

@@ -52,21 +52,18 @@ impl<'a> Iterable for RemappedTermOrdinalsValues<'a> {
impl<'a> RemappedTermOrdinalsValues<'a> {
fn boxed_iter_stacked(&self) -> Box<dyn Iterator<Item = u64> + '_> {
let iter = self
.bytes_columns
.iter()
.enumerate()
.flat_map(|(segment_ord, byte_column)| {
let segment_ord = self.term_ord_mapping.get_segment(segment_ord as u32);
byte_column.iter().flat_map(move |bytes_column| {
bytes_column
.ords()
.values
.iter()
.map(move |term_ord| segment_ord[term_ord as usize])
})
});
// TODO see if we can better decompose the mapping / and the stacking
let iter = self.bytes_columns.iter().flatten().enumerate().flat_map(
move |(seg_ord_with_column, bytes_column)| {
let term_ord_after_merge_mapping = self
.term_ord_mapping
.get_segment(seg_ord_with_column as u32);
bytes_column
.ords()
.values
.iter()
.map(move |term_ord| term_ord_after_merge_mapping[term_ord as usize])
},
);
Box::new(iter)
}
@@ -96,7 +93,7 @@ fn compute_term_bitset(column: &BytesColumn, row_bitset: &ReadOnlyBitSet) -> Bit
let num_terms = column.dictionary().num_terms();
let mut term_bitset = BitSet::with_max_value(num_terms as u32);
for row_id in row_bitset.iter() {
for term_ord in column.term_ord_column.values(row_id) {
for term_ord in column.term_ord_column.values_for_doc(row_id) {
term_bitset.insert(term_ord as u32);
}
}
@@ -133,7 +130,6 @@ fn serialize_merged_dict(
let mut merged_terms = TermMerger::new(field_term_streams);
let mut sstable_builder = sstable::VoidSSTable::writer(output);
// TODO support complex `merge_row_order`.
match merge_row_order {
MergeRowOrder::Stack(_) => {
let mut current_term_ord = 0;
@@ -191,7 +187,7 @@ struct TermOrdinalMapping {
impl TermOrdinalMapping {
fn add_segment(&mut self, max_term_ord: usize) {
self.per_segment_new_term_ordinals
.push(vec![TermOrdinal::default(); max_term_ord as usize]);
.push(vec![TermOrdinal::default(); max_term_ord]);
}
fn register_from_to(&mut self, segment_ord: usize, from_ord: TermOrdinal, to_ord: TermOrdinal) {

View File

@@ -2,8 +2,6 @@ mod merge_dict_column;
mod merge_mapping;
mod term_merger;
// mod sorted_doc_id_column;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::io;
use std::net::Ipv6Addr;
@@ -30,7 +28,7 @@ use crate::{
///
/// See also [README.md].
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
enum ColumnTypeCategory {
pub(crate) enum ColumnTypeCategory {
Bool,
Str,
Numerical,
@@ -54,26 +52,49 @@ impl From<ColumnType> for ColumnTypeCategory {
}
}
/// Merge several columnar table together.
///
/// If several columns with the same name are conflicting with the numerical types in the
/// input columnars, the first type compatible out of i64, u64, f64 in that order will be used.
///
/// `require_columns` makes it possible to ensure that some columns will be present in the
/// resulting columnar. When a required column is a numerical column type, one of two things can
/// happen:
/// - If the required column type is compatible with all of the input columnar, the resulsting
/// merged
/// columnar will simply coerce the input column and use the required column type.
/// - If the required column type is incompatible with one of the input columnar, the merged
/// will fail with an InvalidData error.
///
/// `merge_row_order` makes it possible to remove or reorder row in the resulting
/// `Columnar` table.
///
/// Reminder: a string and a numerical column may bare the same column name. This is not
/// considered a conflict.
pub fn merge_columnar(
columnar_readers: &[&ColumnarReader],
required_columns: &[(String, ColumnType)],
merge_row_order: MergeRowOrder,
output: &mut impl io::Write,
) -> io::Result<()> {
let mut serializer = ColumnarSerializer::new(output);
let columns_to_merge = group_columns_for_merge(columnar_readers)?;
let num_rows_per_columnar = columnar_readers
.iter()
.map(|reader| reader.num_rows())
.collect::<Vec<u32>>();
let columns_to_merge = group_columns_for_merge(columnar_readers, required_columns)?;
for ((column_name, column_type), columns) in columns_to_merge {
let mut column_serializer =
serializer.serialize_column(column_name.as_bytes(), column_type);
merge_column(
column_type,
&num_rows_per_columnar,
columns,
&merge_row_order,
&mut column_serializer,
)?;
}
serializer.finalize(merge_row_order.num_rows())?;
Ok(())
}
@@ -90,6 +111,7 @@ fn dynamic_column_to_u64_monotonic(dynamic_column: DynamicColumn) -> Option<Colu
fn merge_column(
column_type: ColumnType,
num_docs_per_column: &[u32],
columns: Vec<Option<DynamicColumn>>,
merge_row_order: &MergeRowOrder,
wrt: &mut impl io::Write,
@@ -100,17 +122,19 @@ fn merge_column(
| ColumnType::F64
| ColumnType::DateTime
| ColumnType::Bool => {
let mut column_indexes: Vec<Option<ColumnIndex>> = Vec::with_capacity(columns.len());
let mut column_indexes: Vec<ColumnIndex> = Vec::with_capacity(columns.len());
let mut column_values: Vec<Option<Arc<dyn ColumnValues>>> =
Vec::with_capacity(columns.len());
for dynamic_column_opt in columns {
if let Some(Column { idx, values }) =
for (i, dynamic_column_opt) in columns.into_iter().enumerate() {
if let Some(Column { index: idx, values }) =
dynamic_column_opt.and_then(dynamic_column_to_u64_monotonic)
{
column_indexes.push(Some(idx));
column_indexes.push(idx);
column_values.push(Some(values));
} else {
column_indexes.push(None);
column_indexes.push(ColumnIndex::Empty {
num_docs: num_docs_per_column[i],
});
column_values.push(None);
}
}
@@ -124,15 +148,19 @@ fn merge_column(
serialize_column_mappable_to_u64(merged_column_index, &merge_column_values, wrt)?;
}
ColumnType::IpAddr => {
let mut column_indexes: Vec<Option<ColumnIndex>> = Vec::with_capacity(columns.len());
let mut column_indexes: Vec<ColumnIndex> = Vec::with_capacity(columns.len());
let mut column_values: Vec<Option<Arc<dyn ColumnValues<Ipv6Addr>>>> =
Vec::with_capacity(columns.len());
for dynamic_column_opt in columns {
if let Some(DynamicColumn::IpAddr(Column { idx, values })) = dynamic_column_opt {
column_indexes.push(Some(idx));
for (i, dynamic_column_opt) in columns.into_iter().enumerate() {
if let Some(DynamicColumn::IpAddr(Column { index: idx, values })) =
dynamic_column_opt
{
column_indexes.push(idx);
column_values.push(Some(values));
} else {
column_indexes.push(None);
column_indexes.push(ColumnIndex::Empty {
num_docs: num_docs_per_column[i],
});
column_values.push(None);
}
}
@@ -148,20 +176,22 @@ fn merge_column(
serialize_column_mappable_to_u128(merged_column_index, &merge_column_values, wrt)?;
}
ColumnType::Bytes | ColumnType::Str => {
let mut column_indexes: Vec<Option<ColumnIndex>> = Vec::with_capacity(columns.len());
let mut column_indexes: Vec<ColumnIndex> = Vec::with_capacity(columns.len());
let mut bytes_columns: Vec<Option<BytesColumn>> = Vec::with_capacity(columns.len());
for dynamic_column_opt in columns {
for (i, dynamic_column_opt) in columns.into_iter().enumerate() {
match dynamic_column_opt {
Some(DynamicColumn::Str(str_column)) => {
column_indexes.push(Some(str_column.term_ord_column.idx.clone()));
column_indexes.push(str_column.term_ord_column.index.clone());
bytes_columns.push(Some(str_column.into()));
}
Some(DynamicColumn::Bytes(bytes_column)) => {
column_indexes.push(Some(bytes_column.term_ord_column.idx.clone()));
column_indexes.push(bytes_column.term_ord_column.index.clone());
bytes_columns.push(Some(bytes_column));
}
_ => {
column_indexes.push(None);
column_indexes.push(ColumnIndex::Empty {
num_docs: num_docs_per_column[i],
});
bytes_columns.push(None);
}
}
@@ -174,98 +204,183 @@ fn merge_column(
Ok(())
}
struct GroupedColumns {
required_column_type: Option<ColumnType>,
columns: Vec<Option<DynamicColumn>>,
column_category: ColumnTypeCategory,
}
impl GroupedColumns {
fn for_category(column_category: ColumnTypeCategory, num_columnars: usize) -> Self {
GroupedColumns {
required_column_type: None,
columns: vec![None; num_columnars],
column_category,
}
}
/// Set the dynamic column for a given columnar.
fn set_column(&mut self, columnar_id: usize, column: DynamicColumn) {
self.columns[columnar_id] = Some(column);
}
/// Force the existence of a column, as well as its type.
fn require_type(&mut self, required_type: ColumnType) -> io::Result<()> {
if let Some(existing_required_type) = self.required_column_type {
if existing_required_type == required_type {
// This was just a duplicate in the `required_columns`.
// Nothing to do.
return Ok(());
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Required column conflicts with another required column of the same type \
category.",
));
}
}
self.required_column_type = Some(required_type);
Ok(())
}
/// Returns the column type after merge.
///
/// This method does not check if the column types can actually be coerced to
/// this type.
fn column_type_after_merge(&self) -> ColumnType {
if let Some(required_type) = self.required_column_type {
return required_type;
}
let column_type: HashSet<ColumnType> = self
.columns
.iter()
.flatten()
.map(|column| column.column_type())
.collect();
if column_type.len() == 1 {
return column_type.into_iter().next().unwrap();
}
// At the moment, only the numerical categorical column type has more than one possible
// column type.
assert_eq!(self.column_category, ColumnTypeCategory::Numerical);
merged_numerical_columns_type(self.columns.iter().flatten()).into()
}
}
/// Returns the type of the merged numerical column.
///
/// This function picks the first numerical type out of i64, u64, f64 (order matters
/// here), that is compatible with all the `columns`.
///
/// # Panics
/// Panics if one of the column is not numerical.
fn merged_numerical_columns_type<'a>(
columns: impl Iterator<Item = &'a DynamicColumn>,
) -> NumericalType {
let mut compatible_numerical_types = CompatibleNumericalTypes::default();
for column in columns {
let (min_value, max_value) =
min_max_if_numerical(column).expect("All columns re required to be numerical");
compatible_numerical_types.accept_value(min_value);
compatible_numerical_types.accept_value(max_value);
}
compatible_numerical_types.to_numerical_type()
}
#[allow(clippy::type_complexity)]
fn group_columns_for_merge(
columnar_readers: &[&ColumnarReader],
required_columns: &[(String, ColumnType)],
) -> io::Result<BTreeMap<(String, ColumnType), Vec<Option<DynamicColumn>>>> {
// Each column name may have multiple types of column associated.
// For merging we are interested in the same column type category since they can be merged.
let mut columns_grouped: HashMap<(String, ColumnTypeCategory), Vec<Option<DynamicColumn>>> =
HashMap::new();
let mut columns_grouped: HashMap<(String, ColumnTypeCategory), GroupedColumns> = HashMap::new();
let num_columnars = columnar_readers.len();
for &(ref column_name, column_type) in required_columns {
columns_grouped
.entry((column_name.clone(), column_type.into()))
.or_insert_with(|| {
GroupedColumns::for_category(column_type.into(), columnar_readers.len())
})
.require_type(column_type)?;
}
for (columnar_id, columnar_reader) in columnar_readers.iter().enumerate() {
let column_name_and_handle = columnar_reader.list_columns()?;
for (column_name, handle) in column_name_and_handle {
let column_type_category: ColumnTypeCategory = handle.column_type().into();
let columns = columns_grouped
.entry((column_name, column_type_category))
.or_insert_with(|| vec![None; num_columnars]);
let column_category: ColumnTypeCategory = handle.column_type().into();
let column = handle.open()?;
columns[columnar_id] = Some(column);
columns_grouped
.entry((column_name, column_category))
.or_insert_with(|| {
GroupedColumns::for_category(column_category, columnar_readers.len())
})
.set_column(columnar_id, column);
}
}
let mut merge_columns: BTreeMap<(String, ColumnType), Vec<Option<DynamicColumn>>> =
BTreeMap::default();
Default::default();
for ((column_name, col_category), mut columns) in columns_grouped {
if col_category == ColumnTypeCategory::Numerical {
coerce_numerical_columns_to_same_type(&mut columns);
}
let column_type = columns
.iter()
.flatten()
.map(|col| col.column_type())
.next()
.unwrap();
merge_columns.insert((column_name, column_type), columns);
for ((column_name, _), mut grouped_columns) in columns_grouped {
let column_type = grouped_columns.column_type_after_merge();
coerce_columns(column_type, &mut grouped_columns.columns)?;
merge_columns.insert((column_name, column_type), grouped_columns.columns);
}
Ok(merge_columns)
}
/// Coerce a set of numerical columns to the same type.
///
/// If all columns are already from the same type, keep this type
/// (even if they could all be coerced to i64).
fn coerce_numerical_columns_to_same_type(columns: &mut [Option<DynamicColumn>]) {
let mut column_types: HashSet<NumericalType> = HashSet::default();
let mut compatible_numerical_types = CompatibleNumericalTypes::default();
for column in columns.iter().flatten() {
let min_value: NumericalValue;
let max_value: NumericalValue;
match column {
DynamicColumn::I64(column) => {
min_value = column.min_value().into();
max_value = column.max_value().into();
}
DynamicColumn::U64(column) => {
min_value = column.min_value().into();
max_value = column.min_value().into();
}
DynamicColumn::F64(column) => {
min_value = column.min_value().into();
max_value = column.min_value().into();
}
DynamicColumn::Bool(_)
| DynamicColumn::IpAddr(_)
| DynamicColumn::DateTime(_)
| DynamicColumn::Bytes(_)
| DynamicColumn::Str(_) => {
panic!("We expected only numerical columns.");
}
}
column_types.insert(column.column_type().numerical_type().unwrap());
compatible_numerical_types.accept_value(min_value);
compatible_numerical_types.accept_value(max_value);
}
if column_types.len() <= 1 {
// No need to do anything. The columns are already all from the same type.
// This is necessary to let use force a given type.
// TODO This works in a world where we do not allow a change of schema,
// but in the future, we will have to pass some kind of schema to enforce
// the logic.
return;
}
let coerce_type = compatible_numerical_types.to_numerical_type();
fn coerce_columns(
column_type: ColumnType,
columns: &mut [Option<DynamicColumn>],
) -> io::Result<()> {
for column_opt in columns.iter_mut() {
if let Some(column) = column_opt.take() {
*column_opt = column.coerce_numerical(coerce_type);
*column_opt = Some(coerce_column(column_type, column)?);
}
}
Ok(())
}
fn coerce_column(column_type: ColumnType, column: DynamicColumn) -> io::Result<DynamicColumn> {
if let Some(numerical_type) = column_type.numerical_type() {
column
.coerce_numerical(numerical_type)
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, ""))
} else {
if column.column_type() != column_type {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!(
"Cannot coerce column of type `{:?}` to `{column_type:?}`",
column.column_type()
),
));
}
Ok(column)
}
}
/// Returns the (min, max) of a column provided it is numerical (i64, u64. f64).
///
/// The min and the max are simply the numerical value as defined by `ColumnValue::min_value()`,
/// and `ColumnValue::max_value()`.
///
/// It is important to note that these values are only guaranteed to be lower/upper bound
/// (as opposed to min/max value).
/// If a column is empty, the min and max values are currently set to 0.
fn min_max_if_numerical(column: &DynamicColumn) -> Option<(NumericalValue, NumericalValue)> {
match column {
DynamicColumn::I64(column) => Some((column.min_value().into(), column.max_value().into())),
DynamicColumn::U64(column) => Some((column.min_value().into(), column.max_value().into())),
DynamicColumn::F64(column) => Some((column.min_value().into(), column.max_value().into())),
DynamicColumn::Bool(_)
| DynamicColumn::IpAddr(_)
| DynamicColumn::DateTime(_)
| DynamicColumn::Bytes(_)
| DynamicColumn::Str(_) => None,
}
}
#[cfg(test)]

View File

@@ -1,107 +0,0 @@
use std::sync::Arc;
use fastfield_codecs::Column;
use itertools::Itertools;
use crate::indexer::doc_id_mapping::SegmentDocIdMapping;
use crate::SegmentReader;
pub(crate) struct RemappedDocIdColumn<'a> {
doc_id_mapping: &'a SegmentDocIdMapping,
fast_field_readers: Vec<Arc<dyn Column<u64>>>,
min_value: u64,
max_value: u64,
num_vals: u32,
}
fn compute_min_max_val(
u64_reader: &dyn Column<u64>,
segment_reader: &SegmentReader,
) -> Option<(u64, u64)> {
if segment_reader.max_doc() == 0 {
return None;
}
if segment_reader.alive_bitset().is_none() {
// no deleted documents,
// we can use the previous min_val, max_val.
return Some((u64_reader.min_value(), u64_reader.max_value()));
}
// some deleted documents,
// we need to recompute the max / min
segment_reader
.doc_ids_alive()
.map(|doc_id| u64_reader.get_val(doc_id))
.minmax()
.into_option()
}
impl<'a> RemappedDocIdColumn<'a> {
pub(crate) fn new(
readers: &'a [SegmentReader],
doc_id_mapping: &'a SegmentDocIdMapping,
field: &str,
) -> Self {
let (min_value, max_value) = readers
.iter()
.filter_map(|reader| {
let u64_reader: Arc<dyn Column<u64>> =
reader.fast_fields().typed_fast_field_reader(field).expect(
"Failed to find a reader for single fast field. This is a tantivy bug and \
it should never happen.",
);
compute_min_max_val(&*u64_reader, reader)
})
.reduce(|a, b| (a.0.min(b.0), a.1.max(b.1)))
.expect("Unexpected error, empty readers in IndexMerger");
let fast_field_readers = readers
.iter()
.map(|reader| {
let u64_reader: Arc<dyn Column<u64>> =
reader.fast_fields().typed_fast_field_reader(field).expect(
"Failed to find a reader for single fast field. This is a tantivy bug and \
it should never happen.",
);
u64_reader
})
.collect::<Vec<_>>();
RemappedDocIdColumn {
doc_id_mapping,
fast_field_readers,
min_value,
max_value,
num_vals: doc_id_mapping.len() as u32,
}
}
}
impl<'a> Column for RemappedDocIdColumn<'a> {
fn get_val(&self, _doc: u32) -> u64 {
unimplemented!()
}
fn iter(&self) -> Box<dyn Iterator<Item = u64> + '_> {
Box::new(
self.doc_id_mapping
.iter_old_doc_addrs()
.map(|old_doc_addr| {
let fast_field_reader =
&self.fast_field_readers[old_doc_addr.segment_ord as usize];
fast_field_reader.get_val(old_doc_addr.doc_id)
}),
)
}
fn min_value(&self) -> u64 {
self.min_value
}
fn max_value(&self) -> u64 {
self.max_value
}
fn num_vals(&self) -> u32 {
self.num_vals
}
}

View File

@@ -1,3 +1,5 @@
use itertools::Itertools;
use super::*;
use crate::{Cardinality, ColumnarWriter, HasAssociatedColumnType, RowId};
@@ -24,7 +26,7 @@ fn test_column_coercion_to_u64() {
// u64 type
let columnar2 = make_columnar("numbers", &[u64::MAX]);
let column_map: BTreeMap<(String, ColumnType), Vec<Option<DynamicColumn>>> =
group_columns_for_merge(&[&columnar1, &columnar2]).unwrap();
group_columns_for_merge(&[&columnar1, &columnar2], &[]).unwrap();
assert_eq!(column_map.len(), 1);
assert!(column_map.contains_key(&("numbers".to_string(), ColumnType::U64)));
}
@@ -34,7 +36,7 @@ fn test_column_no_coercion_if_all_the_same() {
let columnar1 = make_columnar("numbers", &[1u64]);
let columnar2 = make_columnar("numbers", &[2u64]);
let column_map: BTreeMap<(String, ColumnType), Vec<Option<DynamicColumn>>> =
group_columns_for_merge(&[&columnar1, &columnar2]).unwrap();
group_columns_for_merge(&[&columnar1, &columnar2], &[]).unwrap();
assert_eq!(column_map.len(), 1);
assert!(column_map.contains_key(&("numbers".to_string(), ColumnType::U64)));
}
@@ -44,17 +46,74 @@ fn test_column_coercion_to_i64() {
let columnar1 = make_columnar("numbers", &[-1i64]);
let columnar2 = make_columnar("numbers", &[2u64]);
let column_map: BTreeMap<(String, ColumnType), Vec<Option<DynamicColumn>>> =
group_columns_for_merge(&[&columnar1, &columnar2]).unwrap();
group_columns_for_merge(&[&columnar1, &columnar2], &[]).unwrap();
assert_eq!(column_map.len(), 1);
assert!(column_map.contains_key(&("numbers".to_string(), ColumnType::I64)));
}
#[test]
fn test_impossible_coercion_returns_an_error() {
let columnar1 = make_columnar("numbers", &[u64::MAX]);
let group_error =
group_columns_for_merge(&[&columnar1], &[("numbers".to_string(), ColumnType::I64)])
.map(|_| ())
.unwrap_err();
assert_eq!(group_error.kind(), io::ErrorKind::InvalidInput);
}
#[test]
fn test_group_columns_with_required_column() {
let columnar1 = make_columnar("numbers", &[1i64]);
let columnar2 = make_columnar("numbers", &[2u64]);
let column_map: BTreeMap<(String, ColumnType), Vec<Option<DynamicColumn>>> =
group_columns_for_merge(
&[&columnar1, &columnar2],
&[("numbers".to_string(), ColumnType::U64)],
)
.unwrap();
assert_eq!(column_map.len(), 1);
assert!(column_map.contains_key(&("numbers".to_string(), ColumnType::U64)));
}
#[test]
fn test_group_columns_required_column_with_no_existing_columns() {
let columnar1 = make_columnar("numbers", &[2u64]);
let columnar2 = make_columnar("numbers", &[2u64]);
let column_map: BTreeMap<(String, ColumnType), Vec<Option<DynamicColumn>>> =
group_columns_for_merge(
&[&columnar1, &columnar2],
&[("required_col".to_string(), ColumnType::Str)],
)
.unwrap();
assert_eq!(column_map.len(), 2);
let columns = column_map
.get(&("required_col".to_string(), ColumnType::Str))
.unwrap();
assert_eq!(columns.len(), 2);
assert!(columns[0].is_none());
assert!(columns[1].is_none());
}
#[test]
fn test_group_columns_required_column_is_above_all_columns_have_the_same_type_rule() {
let columnar1 = make_columnar("numbers", &[2i64]);
let columnar2 = make_columnar("numbers", &[2i64]);
let column_map: BTreeMap<(String, ColumnType), Vec<Option<DynamicColumn>>> =
group_columns_for_merge(
&[&columnar1, &columnar2],
&[("numbers".to_string(), ColumnType::U64)],
)
.unwrap();
assert_eq!(column_map.len(), 1);
assert!(column_map.contains_key(&("numbers".to_string(), ColumnType::U64)));
}
#[test]
fn test_missing_column() {
let columnar1 = make_columnar("numbers", &[-1i64]);
let columnar2 = make_columnar("numbers2", &[2u64]);
let column_map: BTreeMap<(String, ColumnType), Vec<Option<DynamicColumn>>> =
group_columns_for_merge(&[&columnar1, &columnar2]).unwrap();
group_columns_for_merge(&[&columnar1, &columnar2], &[]).unwrap();
assert_eq!(column_map.len(), 2);
assert!(column_map.contains_key(&("numbers".to_string(), ColumnType::I64)));
{
@@ -96,20 +155,24 @@ fn make_numerical_columnar_multiple_columns(
ColumnarReader::open(buffer).unwrap()
}
fn make_byte_columnar_multiple_columns(columns: &[(&str, &[&[&[u8]]])]) -> ColumnarReader {
#[track_caller]
fn make_byte_columnar_multiple_columns(
columns: &[(&str, &[&[&[u8]]])],
num_rows: u32,
) -> ColumnarReader {
let mut dataframe_writer = ColumnarWriter::default();
for (column_name, column_values) in columns {
assert_eq!(
column_values.len(),
num_rows as usize,
"All columns must have `{num_rows}` rows"
);
for (row_id, vals) in column_values.iter().enumerate() {
for val in vals.iter() {
dataframe_writer.record_bytes(row_id as u32, column_name, *val);
dataframe_writer.record_bytes(row_id as u32, column_name, val);
}
}
}
let num_rows = columns
.iter()
.map(|(_, val_rows)| val_rows.len() as RowId)
.max()
.unwrap_or(0u32);
let mut buffer: Vec<u8> = Vec::new();
dataframe_writer
.serialize(num_rows, None, &mut buffer)
@@ -122,7 +185,7 @@ fn make_text_columnar_multiple_columns(columns: &[(&str, &[&[&str]])]) -> Column
for (column_name, column_values) in columns {
for (row_id, vals) in column_values.iter().enumerate() {
for val in vals.iter() {
dataframe_writer.record_str(row_id as u32, column_name, *val);
dataframe_writer.record_str(row_id as u32, column_name, val);
}
}
}
@@ -151,6 +214,7 @@ fn test_merge_columnar_numbers() {
let stack_merge_order = StackMergeOrder::stack(columnars);
crate::columnar::merge_columnar(
columnars,
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
)
@@ -176,6 +240,7 @@ fn test_merge_columnar_texts() {
let stack_merge_order = StackMergeOrder::stack(columnars);
crate::columnar::merge_columnar(
columnars,
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
)
@@ -186,6 +251,8 @@ fn test_merge_columnar_texts() {
let cols = columnar_reader.read_columns("texts").unwrap();
let dynamic_column = cols[0].open().unwrap();
let DynamicColumn::Str(vals) = dynamic_column else { panic!() };
assert_eq!(vals.ords().get_cardinality(), Cardinality::Optional);
let get_str_for_ord = |ord| {
let mut out = String::new();
vals.ord_to_str(ord, &mut out).unwrap();
@@ -213,13 +280,14 @@ fn test_merge_columnar_texts() {
#[test]
fn test_merge_columnar_byte() {
let columnar1 = make_byte_columnar_multiple_columns(&[("bytes", &[&[b"bbbb"], &[b"baaa"]])]);
let columnar2 = make_byte_columnar_multiple_columns(&[("bytes", &[&[], &[b"a"]])]);
let columnar1 = make_byte_columnar_multiple_columns(&[("bytes", &[&[b"bbbb"], &[b"baaa"]])], 2);
let columnar2 = make_byte_columnar_multiple_columns(&[("bytes", &[&[], &[b"a"]])], 2);
let mut buffer = Vec::new();
let columnars = &[&columnar1, &columnar2];
let stack_merge_order = StackMergeOrder::stack(columnars);
crate::columnar::merge_columnar(
columnars,
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
)
@@ -256,3 +324,149 @@ fn test_merge_columnar_byte() {
assert_eq!(get_bytes_for_row(2), b"");
assert_eq!(get_bytes_for_row(3), b"a");
}
#[test]
fn test_merge_columnar_byte_with_missing() {
let columnar1 = make_byte_columnar_multiple_columns(&[], 3);
let columnar2 = make_byte_columnar_multiple_columns(&[("col", &[&[b"b"], &[]])], 2);
let columnar3 = make_byte_columnar_multiple_columns(
&[
("col", &[&[], &[b"b"], &[b"a", b"b"]]),
("col2", &[&[b"hello"], &[], &[b"a", b"b"]]),
],
3,
);
let mut buffer = Vec::new();
let columnars = &[&columnar1, &columnar2, &columnar3];
let stack_merge_order = StackMergeOrder::stack(columnars);
crate::columnar::merge_columnar(
columnars,
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
)
.unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
assert_eq!(columnar_reader.num_rows(), 3 + 2 + 3);
assert_eq!(columnar_reader.num_columns(), 2);
let cols = columnar_reader.read_columns("col").unwrap();
let dynamic_column = cols[0].open().unwrap();
let DynamicColumn::Bytes(vals) = dynamic_column else { panic!() };
let get_bytes_for_ord = |ord| {
let mut out = Vec::new();
vals.ord_to_bytes(ord, &mut out).unwrap();
out
};
assert_eq!(vals.dictionary.num_terms(), 2);
assert_eq!(get_bytes_for_ord(0), b"a");
assert_eq!(get_bytes_for_ord(1), b"b");
let get_bytes_for_row = |row_id| {
let terms: Vec<Vec<u8>> = vals
.term_ords(row_id)
.map(|term_ord| {
let mut out = Vec::new();
vals.ord_to_bytes(term_ord, &mut out).unwrap();
out
})
.collect();
terms
};
assert!(get_bytes_for_row(0).is_empty());
assert!(get_bytes_for_row(1).is_empty());
assert!(get_bytes_for_row(2).is_empty());
assert_eq!(get_bytes_for_row(3), vec![b"b".to_vec()]);
assert!(get_bytes_for_row(4).is_empty());
assert!(get_bytes_for_row(5).is_empty());
assert_eq!(get_bytes_for_row(6), vec![b"b".to_vec()]);
assert_eq!(get_bytes_for_row(7), vec![b"a".to_vec(), b"b".to_vec()]);
}
#[test]
fn test_merge_columnar_different_types() {
let columnar1 = make_text_columnar_multiple_columns(&[("mixed", &[&["a"]])]);
let columnar2 = make_text_columnar_multiple_columns(&[("mixed", &[&[], &["b"]])]);
let columnar3 = make_columnar("mixed", &[1i64]);
let mut buffer = Vec::new();
let columnars = &[&columnar1, &columnar2, &columnar3];
let stack_merge_order = StackMergeOrder::stack(columnars);
crate::columnar::merge_columnar(
columnars,
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
)
.unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
assert_eq!(columnar_reader.num_rows(), 4);
assert_eq!(columnar_reader.num_columns(), 2);
let cols = columnar_reader.read_columns("mixed").unwrap();
// numeric column
let dynamic_column = cols[0].open().unwrap();
let DynamicColumn::I64(vals) = dynamic_column else { panic!() };
assert_eq!(vals.get_cardinality(), Cardinality::Optional);
assert_eq!(vals.values_for_doc(0).collect_vec(), vec![]);
assert_eq!(vals.values_for_doc(1).collect_vec(), vec![]);
assert_eq!(vals.values_for_doc(2).collect_vec(), vec![]);
assert_eq!(vals.values_for_doc(3).collect_vec(), vec![1]);
assert_eq!(vals.values_for_doc(4).collect_vec(), vec![]);
// text column
let dynamic_column = cols[1].open().unwrap();
let DynamicColumn::Str(vals) = dynamic_column else { panic!() };
assert_eq!(vals.ords().get_cardinality(), Cardinality::Optional);
let get_str_for_ord = |ord| {
let mut out = String::new();
vals.ord_to_str(ord, &mut out).unwrap();
out
};
assert_eq!(vals.dictionary.num_terms(), 2);
assert_eq!(get_str_for_ord(0), "a");
assert_eq!(get_str_for_ord(1), "b");
let get_str_for_row = |row_id| {
let term_ords: Vec<String> = vals
.term_ords(row_id)
.map(|el| {
let mut out = String::new();
vals.ord_to_str(el, &mut out).unwrap();
out
})
.collect();
term_ords
};
assert_eq!(get_str_for_row(0), vec!["a".to_string()]);
assert_eq!(get_str_for_row(1), Vec::<String>::new());
assert_eq!(get_str_for_row(2), vec!["b".to_string()]);
assert_eq!(get_str_for_row(3), Vec::<String>::new());
}
#[test]
fn test_merge_columnar_different_empty_cardinality() {
let columnar1 = make_text_columnar_multiple_columns(&[("mixed", &[&["a"]])]);
let columnar2 = make_columnar("mixed", &[1i64]);
let mut buffer = Vec::new();
let columnars = &[&columnar1, &columnar2];
let stack_merge_order = StackMergeOrder::stack(columnars);
crate::columnar::merge_columnar(
columnars,
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
)
.unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
assert_eq!(columnar_reader.num_rows(), 2);
assert_eq!(columnar_reader.num_columns(), 2);
let cols = columnar_reader.read_columns("mixed").unwrap();
// numeric column
let dynamic_column = cols[0].open().unwrap();
assert_eq!(dynamic_column.get_cardinality(), Cardinality::Optional);
// text column
let dynamic_column = cols[1].open().unwrap();
assert_eq!(dynamic_column.get_cardinality(), Cardinality::Optional);
}

View File

@@ -1 +0,0 @@

View File

@@ -1,11 +1,12 @@
mod column_type;
mod format_version;
mod merge;
mod merge_index;
mod reader;
mod writer;
pub use column_type::{ColumnType, HasAssociatedColumnType};
#[cfg(test)]
pub(crate) use merge::ColumnTypeCategory;
pub use merge::{merge_columnar, MergeRowOrder, ShuffleMergeOrder, StackMergeOrder};
pub use reader::ColumnarReader;
pub use writer::ColumnarWriter;

View File

@@ -1,4 +1,4 @@
use std::{io, mem};
use std::{fmt, io, mem};
use common::file_slice::FileSlice;
use common::BinarySerializable;
@@ -21,6 +21,58 @@ pub struct ColumnarReader {
num_rows: RowId,
}
impl fmt::Debug for ColumnarReader {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let num_rows = self.num_rows();
let columns = self.list_columns().unwrap();
let num_cols = columns.len();
let mut debug_struct = f.debug_struct("Columnar");
debug_struct
.field("num_rows", &num_rows)
.field("num_cols", &num_cols);
for (col_name, dynamic_column_handle) in columns.into_iter().take(5) {
let col = dynamic_column_handle.open().unwrap();
if col.num_values() > 10 {
debug_struct.field(&col_name, &"..");
} else {
debug_struct.field(&col_name, &col);
}
}
if num_cols > 5 {
debug_struct.finish_non_exhaustive()?;
} else {
debug_struct.finish()?;
}
Ok(())
}
}
/// Functions by both the async/sync code listing columns.
/// It takes a stream from the column sstable and return the list of
/// `DynamicColumn` available in it.
fn read_all_columns_in_stream(
mut stream: sstable::Streamer<'_, RangeSSTable>,
column_data: &FileSlice,
) -> io::Result<Vec<DynamicColumnHandle>> {
let mut results = Vec::new();
while stream.advance() {
let key_bytes: &[u8] = stream.key();
let Some(column_code) = key_bytes.last().copied() else {
return Err(io_invalid_data("Empty column name.".to_string()));
};
let column_type = ColumnType::try_from_code(column_code)
.map_err(|_| io_invalid_data(format!("Unknown column code `{column_code}`")))?;
let range = stream.value();
let file_slice = column_data.slice(range.start as usize..range.end as usize);
let dynamic_column_handle = DynamicColumnHandle {
file_slice,
column_type,
};
results.push(dynamic_column_handle);
}
Ok(results)
}
impl ColumnarReader {
/// Opens a new Columnar file.
pub fn open<F>(file_slice: F) -> io::Result<ColumnarReader>
@@ -76,11 +128,7 @@ impl ColumnarReader {
Ok(results)
}
/// Get all columns for the given column name.
///
/// There can be more than one column associated to a given column name, provided they have
/// different types.
pub fn read_columns(&self, column_name: &str) -> io::Result<Vec<DynamicColumnHandle>> {
fn stream_for_column_range(&self, column_name: &str) -> sstable::StreamerBuilder<RangeSSTable> {
// Each column is a associated to a given `column_key`,
// that starts by `column_name\0column_header`.
//
@@ -89,36 +137,35 @@ impl ColumnarReader {
//
// This is in turn equivalent to searching for the range
// `[column_name,\0`..column_name\1)`.
// TODO can we get some more generic `prefix(..)` logic in the dictioanry.
// TODO can we get some more generic `prefix(..)` logic in the dictionary.
let mut start_key = column_name.to_string();
start_key.push('\0');
let mut end_key = column_name.to_string();
end_key.push(1u8 as char);
let mut stream = self
.column_dictionary
self.column_dictionary
.range()
.ge(start_key.as_bytes())
.lt(end_key.as_bytes())
.into_stream()?;
let mut results = Vec::new();
while stream.advance() {
let key_bytes: &[u8] = stream.key();
assert!(key_bytes.starts_with(start_key.as_bytes()));
let column_code: u8 = key_bytes.last().cloned().unwrap();
let column_type = ColumnType::try_from_code(column_code)
.map_err(|_| io_invalid_data(format!("Unknown column code `{column_code}`")))?;
let range = stream.value().clone();
let file_slice = self
.column_data
.slice(range.start as usize..range.end as usize);
let dynamic_column_handle = DynamicColumnHandle {
file_slice,
column_type,
};
results.push(dynamic_column_handle);
}
Ok(results)
}
pub async fn read_columns_async(
&self,
column_name: &str,
) -> io::Result<Vec<DynamicColumnHandle>> {
let stream = self
.stream_for_column_range(column_name)
.into_stream_async()
.await?;
read_all_columns_in_stream(stream, &self.column_data)
}
/// Get all columns for the given column name.
///
/// There can be more than one column associated to a given column name, provided they have
/// different types.
pub fn read_columns(&self, column_name: &str) -> io::Result<Vec<DynamicColumnHandle>> {
let stream = self.stream_for_column_range(column_name).into_stream()?;
read_all_columns_in_stream(stream, &self.column_data)
}
/// Return the number of columns in the columnar.

View File

@@ -310,7 +310,7 @@ mod tests {
buffer.extend_from_slice(b"234234");
let mut bytes = &buffer[..];
let serdeser_symbol = ColumnOperation::deserialize(&mut bytes).unwrap();
assert_eq!(bytes.len() + buf.as_ref().len() as usize, buffer.len());
assert_eq!(bytes.len() + buf.as_ref().len(), buffer.len());
assert_eq!(column_op, serdeser_symbol);
}
@@ -341,7 +341,7 @@ mod tests {
fn test_column_operation_unordered_aux(val: u32, expected_len: usize) {
let column_op = ColumnOperation::Value(UnorderedId(val));
let minibuf = column_op.serialize();
assert_eq!(minibuf.as_ref().len() as usize, expected_len);
assert_eq!({ minibuf.as_ref().len() }, expected_len);
let mut buf = minibuf.as_ref().to_vec();
buf.extend_from_slice(&[2, 2, 2, 2, 2, 2]);
let mut cursor = &buf[..];

View File

@@ -104,16 +104,25 @@ impl ColumnarWriter {
};
let mut symbols_buffer = Vec::new();
let mut values = Vec::new();
let mut last_doc_opt: Option<RowId> = None;
let mut start_doc_check_fill = 0;
let mut current_doc_opt: Option<RowId> = None;
// Assumption: NewDoc will never call the same doc twice and is strictly increasing between
// calls
for op in numerical_col_writer.operation_iterator(&self.arena, None, &mut symbols_buffer) {
match op {
ColumnOperation::NewDoc(doc) => {
last_doc_opt = Some(doc);
current_doc_opt = Some(doc);
}
ColumnOperation::Value(numerical_value) => {
if let Some(last_doc) = last_doc_opt {
if let Some(current_doc) = current_doc_opt {
// Fill up with 0.0 since last doc
values.extend((start_doc_check_fill..current_doc).map(|doc| (0.0, doc)));
start_doc_check_fill = current_doc + 1;
// handle multi values
current_doc_opt = None;
let score: f32 = f64::coerce(numerical_value) as f32;
values.push((score, last_doc));
values.push((score, current_doc));
}
}
}
@@ -123,9 +132,9 @@ impl ColumnarWriter {
}
values.sort_by(|(left_score, _), (right_score, _)| {
if reversed {
right_score.partial_cmp(left_score).unwrap()
right_score.total_cmp(left_score)
} else {
left_score.partial_cmp(right_score).unwrap()
left_score.total_cmp(right_score)
}
});
values.into_iter().map(|(_score, doc)| doc).collect()
@@ -761,7 +770,7 @@ mod tests {
assert_eq!(column_writer.get_cardinality(3), Cardinality::Full);
let mut buffer = Vec::new();
let symbols: Vec<ColumnOperation<NumericalValue>> = column_writer
.operation_iterator(&mut arena, None, &mut buffer)
.operation_iterator(&arena, None, &mut buffer)
.collect();
assert_eq!(symbols.len(), 6);
assert!(matches!(symbols[0], ColumnOperation::NewDoc(0u32)));
@@ -790,7 +799,7 @@ mod tests {
assert_eq!(column_writer.get_cardinality(3), Cardinality::Optional);
let mut buffer = Vec::new();
let symbols: Vec<ColumnOperation<NumericalValue>> = column_writer
.operation_iterator(&mut arena, None, &mut buffer)
.operation_iterator(&arena, None, &mut buffer)
.collect();
assert_eq!(symbols.len(), 4);
assert!(matches!(symbols[0], ColumnOperation::NewDoc(1u32)));
@@ -813,7 +822,7 @@ mod tests {
assert_eq!(column_writer.get_cardinality(2), Cardinality::Optional);
let mut buffer = Vec::new();
let symbols: Vec<ColumnOperation<NumericalValue>> = column_writer
.operation_iterator(&mut arena, None, &mut buffer)
.operation_iterator(&arena, None, &mut buffer)
.collect();
assert_eq!(symbols.len(), 2);
assert!(matches!(symbols[0], ColumnOperation::NewDoc(0u32)));
@@ -832,7 +841,7 @@ mod tests {
assert_eq!(column_writer.get_cardinality(1), Cardinality::Multivalued);
let mut buffer = Vec::new();
let symbols: Vec<ColumnOperation<NumericalValue>> = column_writer
.operation_iterator(&mut arena, None, &mut buffer)
.operation_iterator(&arena, None, &mut buffer)
.collect();
assert_eq!(symbols.len(), 3);
assert!(matches!(symbols[0], ColumnOperation::NewDoc(0u32)));

View File

@@ -150,11 +150,7 @@ mod tests {
multivalued_value_index_builder.record_row(2u32);
multivalued_value_index_builder.record_value();
assert_eq!(
multivalued_value_index_builder
.finish(4u32)
.iter()
.copied()
.collect::<Vec<u32>>(),
multivalued_value_index_builder.finish(4u32).to_vec(),
vec![0, 0, 2, 3, 3]
);
multivalued_value_index_builder.reset();
@@ -162,11 +158,7 @@ mod tests {
multivalued_value_index_builder.record_value();
multivalued_value_index_builder.record_value();
assert_eq!(
multivalued_value_index_builder
.finish(4u32)
.iter()
.copied()
.collect::<Vec<u32>>(),
multivalued_value_index_builder.finish(4u32).to_vec(),
vec![0, 0, 0, 2, 2]
);
}

View File

@@ -1,14 +1,14 @@
use std::io;
use std::net::Ipv6Addr;
use std::sync::Arc;
use std::{fmt, io};
use common::file_slice::FileSlice;
use common::{DateTime, HasLen, OwnedBytes};
use common::{ByteCount, DateTime, HasLen, OwnedBytes};
use crate::column::{BytesColumn, Column, StrColumn};
use crate::column_values::{monotonic_map_column, StrictlyMonotonicFn};
use crate::columnar::ColumnType;
use crate::{Cardinality, NumericalType};
use crate::{Cardinality, ColumnIndex, NumericalType};
#[derive(Clone)]
pub enum DynamicColumn {
@@ -22,19 +22,54 @@ pub enum DynamicColumn {
Str(StrColumn),
}
impl DynamicColumn {
pub fn get_cardinality(&self) -> Cardinality {
impl fmt::Debug for DynamicColumn {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "[{} {} |", self.get_cardinality(), self.column_type())?;
match self {
DynamicColumn::Bool(c) => c.get_cardinality(),
DynamicColumn::I64(c) => c.get_cardinality(),
DynamicColumn::U64(c) => c.get_cardinality(),
DynamicColumn::F64(c) => c.get_cardinality(),
DynamicColumn::IpAddr(c) => c.get_cardinality(),
DynamicColumn::DateTime(c) => c.get_cardinality(),
DynamicColumn::Bytes(c) => c.ords().get_cardinality(),
DynamicColumn::Str(c) => c.ords().get_cardinality(),
DynamicColumn::Bool(col) => write!(f, " {:?}", col)?,
DynamicColumn::I64(col) => write!(f, " {:?}", col)?,
DynamicColumn::U64(col) => write!(f, " {:?}", col)?,
DynamicColumn::F64(col) => write!(f, "{:?}", col)?,
DynamicColumn::IpAddr(col) => write!(f, "{:?}", col)?,
DynamicColumn::DateTime(col) => write!(f, "{:?}", col)?,
DynamicColumn::Bytes(col) => write!(f, "{:?}", col)?,
DynamicColumn::Str(col) => write!(f, "{:?}", col)?,
}
write!(f, "]")
}
}
impl DynamicColumn {
pub fn column_index(&self) -> &ColumnIndex {
match self {
DynamicColumn::Bool(c) => &c.index,
DynamicColumn::I64(c) => &c.index,
DynamicColumn::U64(c) => &c.index,
DynamicColumn::F64(c) => &c.index,
DynamicColumn::IpAddr(c) => &c.index,
DynamicColumn::DateTime(c) => &c.index,
DynamicColumn::Bytes(c) => &c.ords().index,
DynamicColumn::Str(c) => &c.ords().index,
}
}
pub fn get_cardinality(&self) -> Cardinality {
self.column_index().get_cardinality()
}
pub fn num_values(&self) -> u32 {
match self {
DynamicColumn::Bool(c) => c.values.num_vals(),
DynamicColumn::I64(c) => c.values.num_vals(),
DynamicColumn::U64(c) => c.values.num_vals(),
DynamicColumn::F64(c) => c.values.num_vals(),
DynamicColumn::IpAddr(c) => c.values.num_vals(),
DynamicColumn::DateTime(c) => c.values.num_vals(),
DynamicColumn::Bytes(c) => c.ords().values.num_vals(),
DynamicColumn::Str(c) => c.ords().values.num_vals(),
}
}
pub fn column_type(&self) -> ColumnType {
match self {
DynamicColumn::Bool(_) => ColumnType::Bool,
@@ -73,11 +108,11 @@ impl DynamicColumn {
fn coerce_to_f64(self) -> Option<DynamicColumn> {
match self {
DynamicColumn::I64(column) => Some(DynamicColumn::F64(Column {
idx: column.idx,
index: column.index,
values: Arc::new(monotonic_map_column(column.values, MapI64ToF64)),
})),
DynamicColumn::U64(column) => Some(DynamicColumn::F64(Column {
idx: column.idx,
index: column.index,
values: Arc::new(monotonic_map_column(column.values, MapU64ToF64)),
})),
DynamicColumn::F64(_) => Some(self),
@@ -91,7 +126,7 @@ impl DynamicColumn {
return None;
}
Some(DynamicColumn::I64(Column {
idx: column.idx,
index: column.index,
values: Arc::new(monotonic_map_column(column.values, MapU64ToI64)),
}))
}
@@ -106,7 +141,7 @@ impl DynamicColumn {
return None;
}
Some(DynamicColumn::U64(Column {
idx: column.idx,
index: column.index,
values: Arc::new(monotonic_map_column(column.values, MapI64ToU64)),
}))
}
@@ -206,10 +241,9 @@ impl DynamicColumnHandle {
self.open_internal(column_bytes)
}
// TODO rename load_async
pub async fn open_async(&self) -> io::Result<DynamicColumn> {
let column_bytes: OwnedBytes = self.file_slice.read_bytes_async().await?;
self.open_internal(column_bytes)
#[doc(hidden)]
pub fn file_slice(&self) -> &FileSlice {
&self.file_slice
}
/// Returns the `u64` fast field reader reader associated with `fields` of types
@@ -249,8 +283,8 @@ impl DynamicColumnHandle {
Ok(dynamic_column)
}
pub fn num_bytes(&self) -> usize {
self.file_slice.len()
pub fn num_bytes(&self) -> ByteCount {
self.file_slice.len().into()
}
pub fn column_type(&self) -> ColumnType {

View File

@@ -7,8 +7,10 @@ extern crate more_asserts;
#[cfg(all(test, feature = "unstable"))]
extern crate test;
use std::fmt::Display;
use std::io;
mod block_accessor;
mod column;
mod column_index;
pub mod column_values;
@@ -19,9 +21,12 @@ mod iterable;
pub(crate) mod utils;
mod value;
pub use block_accessor::ColumnBlockAccessor;
pub use column::{BytesColumn, Column, StrColumn};
pub use column_index::ColumnIndex;
pub use column_values::{ColumnValues, MonotonicallyMappableToU128, MonotonicallyMappableToU64};
pub use column_values::{
ColumnValues, EmptyColumnValues, MonotonicallyMappableToU128, MonotonicallyMappableToU64,
};
pub use columnar::{
merge_columnar, ColumnType, ColumnarReader, ColumnarWriter, HasAssociatedColumnType,
MergeRowOrder, ShuffleMergeOrder, StackMergeOrder,
@@ -71,6 +76,17 @@ pub enum Cardinality {
Multivalued = 2,
}
impl Display for Cardinality {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let short_str = match self {
Cardinality::Full => "full",
Cardinality::Optional => "opt",
Cardinality::Multivalued => "mult",
};
write!(f, "{short_str}")
}
}
impl Cardinality {
pub fn is_optional(&self) -> bool {
matches!(self, Cardinality::Optional)
@@ -81,7 +97,6 @@ impl Cardinality {
pub(crate) fn to_code(self) -> u8 {
self as u8
}
pub(crate) fn try_from_code(code: u8) -> Result<Cardinality, InvalidData> {
match code {
0 => Ok(Cardinality::Full),

View File

@@ -1,10 +1,17 @@
use std::collections::HashMap;
use std::fmt::Debug;
use std::net::Ipv6Addr;
use common::DateTime;
use proptest::prelude::*;
use crate::column_values::MonotonicallyMappableToU128;
use crate::columnar::ColumnType;
use crate::columnar::{ColumnType, ColumnTypeCategory};
use crate::dynamic_column::{DynamicColumn, DynamicColumnHandle};
use crate::value::NumericalValue;
use crate::{Cardinality, ColumnarReader, ColumnarWriter};
use crate::value::{Coerce, NumericalValue};
use crate::{
BytesColumn, Cardinality, Column, ColumnarReader, ColumnarWriter, RowId, StackMergeOrder,
};
#[test]
fn test_dataframe_writer_str() {
@@ -17,7 +24,7 @@ fn test_dataframe_writer_str() {
assert_eq!(columnar.num_columns(), 1);
let cols: Vec<DynamicColumnHandle> = columnar.read_columns("my_string").unwrap();
assert_eq!(cols.len(), 1);
assert_eq!(cols[0].num_bytes(), 158);
assert_eq!(cols[0].num_bytes(), 89);
}
#[test]
@@ -31,7 +38,7 @@ fn test_dataframe_writer_bytes() {
assert_eq!(columnar.num_columns(), 1);
let cols: Vec<DynamicColumnHandle> = columnar.read_columns("my_string").unwrap();
assert_eq!(cols.len(), 1);
assert_eq!(cols[0].num_bytes(), 158);
assert_eq!(cols[0].num_bytes(), 89);
}
#[test]
@@ -126,7 +133,7 @@ fn test_dataframe_writer_numerical() {
assert_eq!(cols[0].num_bytes(), 33);
let column = cols[0].open().unwrap();
let DynamicColumn::I64(column_i64) = column else { panic!(); };
assert_eq!(column_i64.idx.get_cardinality(), Cardinality::Optional);
assert_eq!(column_i64.index.get_cardinality(), Cardinality::Optional);
assert_eq!(column_i64.first(0), None);
assert_eq!(column_i64.first(1), Some(12i64));
assert_eq!(column_i64.first(2), Some(13i64));
@@ -136,6 +143,46 @@ fn test_dataframe_writer_numerical() {
assert_eq!(column_i64.first(6), None); //< we can change the spec for that one.
}
#[test]
fn test_dataframe_sort_by_full() {
let mut dataframe_writer = ColumnarWriter::default();
dataframe_writer.record_numerical(0u32, "value", NumericalValue::U64(1));
dataframe_writer.record_numerical(1u32, "value", NumericalValue::U64(2));
let data = dataframe_writer.sort_order("value", 2, false);
assert_eq!(data, vec![0, 1]);
}
#[test]
fn test_dataframe_sort_by_opt() {
let mut dataframe_writer = ColumnarWriter::default();
dataframe_writer.record_numerical(1u32, "value", NumericalValue::U64(3));
dataframe_writer.record_numerical(3u32, "value", NumericalValue::U64(2));
let data = dataframe_writer.sort_order("value", 5, false);
// 0, 2, 4 is 0.0
assert_eq!(data, vec![0, 2, 4, 3, 1]);
let data = dataframe_writer.sort_order("value", 5, true);
assert_eq!(
data,
vec![4, 2, 0, 3, 1].into_iter().rev().collect::<Vec<_>>()
);
}
#[test]
fn test_dataframe_sort_by_multi() {
let mut dataframe_writer = ColumnarWriter::default();
// valid for sort
dataframe_writer.record_numerical(1u32, "value", NumericalValue::U64(2));
// those are ignored for sort
dataframe_writer.record_numerical(1u32, "value", NumericalValue::U64(4));
dataframe_writer.record_numerical(1u32, "value", NumericalValue::U64(4));
// valid for sort
dataframe_writer.record_numerical(3u32, "value", NumericalValue::U64(3));
// ignored, would change sort order
dataframe_writer.record_numerical(3u32, "value", NumericalValue::U64(1));
let data = dataframe_writer.sort_order("value", 4, false);
assert_eq!(data, vec![0, 2, 1, 3]);
}
#[test]
fn test_dictionary_encoded_str() {
let mut buffer = Vec::new();
@@ -210,3 +257,497 @@ fn test_dictionary_encoded_bytes() {
.unwrap();
assert_eq!(term_buffer, b"b");
}
fn num_strategy() -> impl Strategy<Value = NumericalValue> {
prop_oneof![
Just(NumericalValue::U64(0u64)),
Just(NumericalValue::U64(u64::MAX)),
Just(NumericalValue::I64(0i64)),
Just(NumericalValue::I64(i64::MIN)),
Just(NumericalValue::I64(i64::MAX)),
Just(NumericalValue::F64(1.2f64)),
]
}
#[derive(Debug, Clone, Copy)]
enum ColumnValue {
Str(&'static str),
Bytes(&'static [u8]),
Numerical(NumericalValue),
IpAddr(Ipv6Addr),
Bool(bool),
DateTime(DateTime),
}
impl ColumnValue {
pub(crate) fn column_type_category(&self) -> ColumnTypeCategory {
match self {
ColumnValue::Str(_) => ColumnTypeCategory::Str,
ColumnValue::Bytes(_) => ColumnTypeCategory::Bytes,
ColumnValue::Numerical(_) => ColumnTypeCategory::Numerical,
ColumnValue::IpAddr(_) => ColumnTypeCategory::IpAddr,
ColumnValue::Bool(_) => ColumnTypeCategory::Bool,
ColumnValue::DateTime(_) => ColumnTypeCategory::DateTime,
}
}
}
fn column_name_strategy() -> impl Strategy<Value = &'static str> {
prop_oneof![Just("c1"), Just("c2")]
}
fn string_strategy() -> impl Strategy<Value = &'static str> {
prop_oneof![Just("a"), Just("b")]
}
fn bytes_strategy() -> impl Strategy<Value = &'static [u8]> {
prop_oneof![Just(&[0u8][..]), Just(&[1u8][..])]
}
// A random column value
fn column_value_strategy() -> impl Strategy<Value = ColumnValue> {
prop_oneof![
10 => string_strategy().prop_map(|s| ColumnValue::Str(s)),
1 => bytes_strategy().prop_map(|b| ColumnValue::Bytes(b)),
40 => num_strategy().prop_map(|n| ColumnValue::Numerical(n)),
1 => (1u16..3u16).prop_map(|ip_addr_byte| ColumnValue::IpAddr(Ipv6Addr::new(
127,
0,
0,
0,
0,
0,
0,
ip_addr_byte
))),
1 => any::<bool>().prop_map(|b| ColumnValue::Bool(b)),
1 => (0_679_723_993i64..1_679_723_995i64)
.prop_map(|val| { ColumnValue::DateTime(DateTime::from_timestamp_secs(val)) })
]
}
// A document contains up to 4 values.
fn doc_strategy() -> impl Strategy<Value = Vec<(&'static str, ColumnValue)>> {
proptest::collection::vec((column_name_strategy(), column_value_strategy()), 0..4)
}
// A columnar contains up to 2 docs.
fn columnar_docs_strategy() -> impl Strategy<Value = Vec<Vec<(&'static str, ColumnValue)>>> {
proptest::collection::vec(doc_strategy(), 0..=2)
}
fn columnar_docs_and_mapping_strategy(
) -> impl Strategy<Value = (Vec<Vec<(&'static str, ColumnValue)>>, Vec<RowId>)> {
columnar_docs_strategy().prop_flat_map(|docs| {
permutation_strategy(docs.len()).prop_map(move |permutation| (docs.clone(), permutation))
})
}
fn permutation_strategy(n: usize) -> impl Strategy<Value = Vec<RowId>> {
Just((0u32..n as RowId).collect()).prop_shuffle()
}
fn build_columnar_with_mapping(
docs: &[Vec<(&'static str, ColumnValue)>],
old_to_new_row_ids_opt: Option<&[RowId]>,
) -> ColumnarReader {
let num_docs = docs.len() as u32;
let mut buffer = Vec::new();
let mut columnar_writer = ColumnarWriter::default();
for (doc_id, vals) in docs.iter().enumerate() {
for (column_name, col_val) in vals {
match *col_val {
ColumnValue::Str(str_val) => {
columnar_writer.record_str(doc_id as u32, column_name, str_val);
}
ColumnValue::Bytes(bytes) => {
columnar_writer.record_bytes(doc_id as u32, column_name, bytes)
}
ColumnValue::Numerical(num) => {
columnar_writer.record_numerical(doc_id as u32, column_name, num);
}
ColumnValue::IpAddr(ip_addr) => {
columnar_writer.record_ip_addr(doc_id as u32, column_name, ip_addr);
}
ColumnValue::Bool(bool_val) => {
columnar_writer.record_bool(doc_id as u32, column_name, bool_val);
}
ColumnValue::DateTime(date_time) => {
columnar_writer.record_datetime(doc_id as u32, column_name, date_time);
}
}
}
}
columnar_writer
.serialize(num_docs, old_to_new_row_ids_opt, &mut buffer)
.unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
columnar_reader
}
fn build_columnar(docs: &[Vec<(&'static str, ColumnValue)>]) -> ColumnarReader {
build_columnar_with_mapping(docs, None)
}
fn assert_columnar_eq(left: &ColumnarReader, right: &ColumnarReader) {
assert_eq!(left.num_rows(), right.num_rows());
let left_columns = left.list_columns().unwrap();
let right_columns = right.list_columns().unwrap();
assert_eq!(left_columns.len(), right_columns.len());
for i in 0..left_columns.len() {
assert_eq!(left_columns[i].0, right_columns[i].0);
let left_column = left_columns[i].1.open().unwrap();
let right_column = right_columns[i].1.open().unwrap();
assert_dyn_column_eq(&left_column, &right_column);
}
}
fn assert_column_eq<T: Copy + PartialOrd + Debug + Send + Sync + 'static>(
left: &Column<T>,
right: &Column<T>,
) {
assert_eq!(left.get_cardinality(), right.get_cardinality());
assert_eq!(left.num_docs(), right.num_docs());
let num_docs = left.num_docs();
for doc in 0..num_docs {
assert_eq!(
left.index.value_row_ids(doc),
right.index.value_row_ids(doc)
);
}
assert_eq!(left.values.num_vals(), right.values.num_vals());
let num_vals = left.values.num_vals();
for i in 0..num_vals {
assert_eq!(left.values.get_val(i), right.values.get_val(i));
}
}
fn assert_bytes_column_eq(left: &BytesColumn, right: &BytesColumn) {
assert_eq!(
left.term_ord_column.get_cardinality(),
right.term_ord_column.get_cardinality()
);
assert_eq!(left.num_rows(), right.num_rows());
assert_column_eq(&left.term_ord_column, &right.term_ord_column);
assert_eq!(left.dictionary.num_terms(), right.dictionary.num_terms());
let num_terms = left.dictionary.num_terms();
let mut left_terms = left.dictionary.stream().unwrap();
let mut right_terms = right.dictionary.stream().unwrap();
for _ in 0..num_terms {
assert!(left_terms.advance());
assert!(right_terms.advance());
assert_eq!(left_terms.key(), right_terms.key());
}
assert!(!left_terms.advance());
assert!(!right_terms.advance());
}
fn assert_dyn_column_eq(left_dyn_column: &DynamicColumn, right_dyn_column: &DynamicColumn) {
assert_eq!(
&left_dyn_column.column_type(),
&right_dyn_column.column_type()
);
assert_eq!(
&left_dyn_column.get_cardinality(),
&right_dyn_column.get_cardinality()
);
match &(left_dyn_column, right_dyn_column) {
(DynamicColumn::Bool(left_col), DynamicColumn::Bool(right_col)) => {
assert_column_eq(left_col, right_col);
}
(DynamicColumn::I64(left_col), DynamicColumn::I64(right_col)) => {
assert_column_eq(left_col, right_col);
}
(DynamicColumn::U64(left_col), DynamicColumn::U64(right_col)) => {
assert_column_eq(left_col, right_col);
}
(DynamicColumn::F64(left_col), DynamicColumn::F64(right_col)) => {
assert_column_eq(left_col, right_col);
}
(DynamicColumn::DateTime(left_col), DynamicColumn::DateTime(right_col)) => {
assert_column_eq(left_col, right_col);
}
(DynamicColumn::IpAddr(left_col), DynamicColumn::IpAddr(right_col)) => {
assert_column_eq(left_col, right_col);
}
(DynamicColumn::Bytes(left_col), DynamicColumn::Bytes(right_col)) => {
assert_bytes_column_eq(left_col, right_col);
}
(DynamicColumn::Str(left_col), DynamicColumn::Str(right_col)) => {
assert_bytes_column_eq(left_col, right_col);
}
_ => {
unreachable!()
}
}
}
trait AssertEqualToColumnValue {
fn assert_equal_to_column_value(&self, column_value: &ColumnValue);
}
impl AssertEqualToColumnValue for bool {
fn assert_equal_to_column_value(&self, column_value: &ColumnValue) {
let ColumnValue::Bool(val) = column_value else { panic!() };
assert_eq!(self, val);
}
}
impl AssertEqualToColumnValue for Ipv6Addr {
fn assert_equal_to_column_value(&self, column_value: &ColumnValue) {
let ColumnValue::IpAddr(val) = column_value else { panic!() };
assert_eq!(self, val);
}
}
impl<T: Coerce + PartialEq + Debug + Into<NumericalValue>> AssertEqualToColumnValue for T {
fn assert_equal_to_column_value(&self, column_value: &ColumnValue) {
let ColumnValue::Numerical(num) = column_value else { panic!() };
assert_eq!(self, &T::coerce(*num));
}
}
impl AssertEqualToColumnValue for DateTime {
fn assert_equal_to_column_value(&self, column_value: &ColumnValue) {
let ColumnValue::DateTime(dt) = column_value else { panic!() };
assert_eq!(self, dt);
}
}
fn assert_column_values<
T: AssertEqualToColumnValue + PartialEq + Copy + PartialOrd + Debug + Send + Sync + 'static,
>(
col: &Column<T>,
expected: &HashMap<u32, Vec<&ColumnValue>>,
) {
let mut num_non_empty_rows = 0;
for doc in 0..col.num_docs() {
let doc_vals: Vec<T> = col.values_for_doc(doc).collect();
if doc_vals.is_empty() {
continue;
}
num_non_empty_rows += 1;
let expected_vals = expected.get(&doc).unwrap();
assert_eq!(doc_vals.len(), expected_vals.len());
for (val, &expected) in doc_vals.iter().zip(expected_vals.iter()) {
val.assert_equal_to_column_value(expected)
}
}
assert_eq!(num_non_empty_rows, expected.len());
}
fn assert_bytes_column_values(
col: &BytesColumn,
expected: &HashMap<u32, Vec<&ColumnValue>>,
is_str: bool,
) {
let mut num_non_empty_rows = 0;
let mut buffer = Vec::new();
for doc in 0..col.term_ord_column.num_docs() {
let doc_vals: Vec<u64> = col.term_ords(doc).collect();
if doc_vals.is_empty() {
continue;
}
let expected_vals = expected.get(&doc).unwrap();
assert_eq!(doc_vals.len(), expected_vals.len());
for (&expected_col_val, &ord) in expected_vals.iter().zip(&doc_vals) {
col.ord_to_bytes(ord, &mut buffer).unwrap();
match expected_col_val {
ColumnValue::Str(str_val) => {
assert!(is_str);
assert_eq!(str_val.as_bytes(), &buffer);
}
ColumnValue::Bytes(bytes_val) => {
assert!(!is_str);
assert_eq!(bytes_val, &buffer);
}
_ => {
panic!();
}
}
}
num_non_empty_rows += 1;
}
assert_eq!(num_non_empty_rows, expected.len());
}
// This proptest attempts to create a tiny columnar based of up to 3 rows, and checks that the
// resulting columnar matches the row data.
proptest! {
#![proptest_config(ProptestConfig::with_cases(500))]
#[test]
fn test_single_columnar_builder_proptest(docs in columnar_docs_strategy()) {
let columnar = build_columnar(&docs[..]);
assert_eq!(columnar.num_rows() as usize, docs.len());
let mut expected_columns: HashMap<(&str, ColumnTypeCategory), HashMap<u32, Vec<&ColumnValue>> > = Default::default();
for (doc_id, doc_vals) in docs.iter().enumerate() {
for (col_name, col_val) in doc_vals {
expected_columns
.entry((col_name, col_val.column_type_category()))
.or_default()
.entry(doc_id as u32)
.or_default()
.push(col_val);
}
}
let column_list = columnar.list_columns().unwrap();
assert_eq!(expected_columns.len(), column_list.len());
for (column_name, column) in column_list {
let dynamic_column = column.open().unwrap();
let col_category: ColumnTypeCategory = dynamic_column.column_type().into();
let expected_col_values: &HashMap<u32, Vec<&ColumnValue>> = expected_columns.get(&(column_name.as_str(), col_category)).unwrap();
match &dynamic_column {
DynamicColumn::Bool(col) =>
assert_column_values(col, expected_col_values),
DynamicColumn::I64(col) =>
assert_column_values(col, expected_col_values),
DynamicColumn::U64(col) =>
assert_column_values(col, expected_col_values),
DynamicColumn::F64(col) =>
assert_column_values(col, expected_col_values),
DynamicColumn::IpAddr(col) =>
assert_column_values(col, expected_col_values),
DynamicColumn::DateTime(col) =>
assert_column_values(col, expected_col_values),
DynamicColumn::Bytes(col) =>
assert_bytes_column_values(col, expected_col_values, false),
DynamicColumn::Str(col) =>
assert_bytes_column_values(col, expected_col_values, true),
}
}
}
}
// Same as `test_single_columnar_builder_proptest` but with a shuffling mapping.
proptest! {
#![proptest_config(ProptestConfig::with_cases(500))]
#[test]
fn test_single_columnar_builder_with_shuffle_proptest((docs, mapping) in columnar_docs_and_mapping_strategy()) {
let columnar = build_columnar_with_mapping(&docs[..], Some(&mapping));
assert_eq!(columnar.num_rows() as usize, docs.len());
let mut expected_columns: HashMap<(&str, ColumnTypeCategory), HashMap<u32, Vec<&ColumnValue>> > = Default::default();
for (doc_id, doc_vals) in docs.iter().enumerate() {
for (col_name, col_val) in doc_vals {
expected_columns
.entry((col_name, col_val.column_type_category()))
.or_default()
.entry(mapping[doc_id])
.or_default()
.push(col_val);
}
}
let column_list = columnar.list_columns().unwrap();
assert_eq!(expected_columns.len(), column_list.len());
for (column_name, column) in column_list {
let dynamic_column = column.open().unwrap();
let col_category: ColumnTypeCategory = dynamic_column.column_type().into();
let expected_col_values: &HashMap<u32, Vec<&ColumnValue>> = expected_columns.get(&(column_name.as_str(), col_category)).unwrap();
for _doc_id in 0..columnar.num_rows() {
match &dynamic_column {
DynamicColumn::Bool(col) =>
assert_column_values(col, expected_col_values),
DynamicColumn::I64(col) =>
assert_column_values(col, expected_col_values),
DynamicColumn::U64(col) =>
assert_column_values(col, expected_col_values),
DynamicColumn::F64(col) =>
assert_column_values(col, expected_col_values),
DynamicColumn::IpAddr(col) =>
assert_column_values(col, expected_col_values),
DynamicColumn::DateTime(col) =>
assert_column_values(col, expected_col_values),
DynamicColumn::Bytes(col) =>
assert_bytes_column_values(col, expected_col_values, false),
DynamicColumn::Str(col) =>
assert_bytes_column_values(col, expected_col_values, true),
}
}
}
}
}
// This tests create 2 or 3 random small columnar and attempts to merge them.
// It compares the resulting merged dataframe with what would have been obtained by building the
// dataframe from the concatenated rows to begin with.
proptest! {
#![proptest_config(ProptestConfig::with_cases(1000))]
#[test]
fn test_columnar_merge_proptest(columnar_docs in proptest::collection::vec(columnar_docs_strategy(), 2..=3)) {
let columnar_readers: Vec<ColumnarReader> = columnar_docs.iter()
.map(|docs| build_columnar(&docs[..]))
.collect::<Vec<_>>();
let columnar_readers_arr: Vec<&ColumnarReader> = columnar_readers.iter().collect();
let mut output: Vec<u8> = Vec::new();
let stack_merge_order = StackMergeOrder::stack(&columnar_readers_arr[..]).into();
crate::merge_columnar(&columnar_readers_arr[..], &[], stack_merge_order, &mut output).unwrap();
let merged_columnar = ColumnarReader::open(output).unwrap();
let concat_rows: Vec<Vec<(&'static str, ColumnValue)>> = columnar_docs.iter().cloned().flatten().collect();
let expected_merged_columnar = build_columnar(&concat_rows[..]);
assert_columnar_eq(&merged_columnar, &expected_merged_columnar);
}
}
#[test]
fn test_columnar_merging_empty_columnar() {
let columnar_docs: Vec<Vec<Vec<(&str, ColumnValue)>>> =
vec![vec![], vec![vec![("c1", ColumnValue::Str("a"))]]];
let columnar_readers: Vec<ColumnarReader> = columnar_docs
.iter()
.map(|docs| build_columnar(&docs[..]))
.collect::<Vec<_>>();
let columnar_readers_arr: Vec<&ColumnarReader> = columnar_readers.iter().collect();
let mut output: Vec<u8> = Vec::new();
let stack_merge_order = StackMergeOrder::stack(&columnar_readers_arr[..]);
crate::merge_columnar(
&columnar_readers_arr[..],
&[],
crate::MergeRowOrder::Stack(stack_merge_order),
&mut output,
)
.unwrap();
let merged_columnar = ColumnarReader::open(output).unwrap();
let concat_rows: Vec<Vec<(&'static str, ColumnValue)>> =
columnar_docs.iter().cloned().flatten().collect();
let expected_merged_columnar = build_columnar(&concat_rows[..]);
assert_columnar_eq(&merged_columnar, &expected_merged_columnar);
}
#[test]
fn test_columnar_merging_number_columns() {
let columnar_docs: Vec<Vec<Vec<(&str, ColumnValue)>>> = vec![
// columnar 1
vec![
// doc 1.1
vec![("c2", ColumnValue::Numerical(0i64.into()))],
],
// columnar2
vec![
// doc 2.1
vec![("c2", ColumnValue::Numerical(0u64.into()))],
// doc 2.2
vec![("c2", ColumnValue::Numerical(u64::MAX.into()))],
],
];
let columnar_readers: Vec<ColumnarReader> = columnar_docs
.iter()
.map(|docs| build_columnar(&docs[..]))
.collect::<Vec<_>>();
let columnar_readers_arr: Vec<&ColumnarReader> = columnar_readers.iter().collect();
let mut output: Vec<u8> = Vec::new();
let stack_merge_order = StackMergeOrder::stack(&columnar_readers_arr[..]);
crate::merge_columnar(
&columnar_readers_arr[..],
&[],
crate::MergeRowOrder::Stack(stack_merge_order),
&mut output,
)
.unwrap();
let merged_columnar = ColumnarReader::open(output).unwrap();
let concat_rows: Vec<Vec<(&'static str, ColumnValue)>> =
columnar_docs.iter().cloned().flatten().collect();
let expected_merged_columnar = build_columnar(&concat_rows[..]);
assert_columnar_eq(&merged_columnar, &expected_merged_columnar);
}
// TODO add non trivial remap and merge
// TODO test required_columns
// TODO document edge case: required_columns incompatible with values.

View File

@@ -4,6 +4,8 @@ use std::{fmt, io, u64};
use ownedbytes::OwnedBytes;
use crate::ByteCount;
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct TinySet(u64);
@@ -386,8 +388,8 @@ impl ReadOnlyBitSet {
}
/// Number of bytes used in the bitset representation.
pub fn num_bytes(&self) -> usize {
self.data.len()
pub fn num_bytes(&self) -> ByteCount {
self.data.len().into()
}
}

114
common/src/byte_count.rs Normal file
View File

@@ -0,0 +1,114 @@
use std::iter::Sum;
use std::ops::{Add, AddAssign};
use serde::{Deserialize, Serialize};
/// Indicates space usage in bytes
#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ByteCount(u64);
impl std::fmt::Debug for ByteCount {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.human_readable())
}
}
impl std::fmt::Display for ByteCount {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.human_readable())
}
}
const SUFFIX_AND_THRESHOLD: [(&str, u64); 5] = [
("KB", 1_000),
("MB", 1_000_000),
("GB", 1_000_000_000),
("TB", 1_000_000_000_000),
("PB", 1_000_000_000_000_000),
];
impl ByteCount {
#[inline]
pub fn get_bytes(&self) -> u64 {
self.0
}
pub fn human_readable(&self) -> String {
for (suffix, threshold) in SUFFIX_AND_THRESHOLD.iter().rev() {
if self.get_bytes() >= *threshold {
let unit_num = self.get_bytes() as f64 / *threshold as f64;
return format!("{:.2} {}", unit_num, suffix);
}
}
format!("{:.2} B", self.get_bytes())
}
}
impl From<u64> for ByteCount {
fn from(value: u64) -> Self {
ByteCount(value)
}
}
impl From<usize> for ByteCount {
fn from(value: usize) -> Self {
ByteCount(value as u64)
}
}
impl Sum for ByteCount {
#[inline]
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(ByteCount::default(), |acc, x| acc + x)
}
}
impl PartialEq<u64> for ByteCount {
#[inline]
fn eq(&self, other: &u64) -> bool {
self.get_bytes() == *other
}
}
impl PartialOrd<u64> for ByteCount {
#[inline]
fn partial_cmp(&self, other: &u64) -> Option<std::cmp::Ordering> {
self.get_bytes().partial_cmp(other)
}
}
impl Add for ByteCount {
type Output = Self;
#[inline]
fn add(self, other: Self) -> Self {
Self(self.get_bytes() + other.get_bytes())
}
}
impl AddAssign for ByteCount {
#[inline]
fn add_assign(&mut self, other: Self) {
*self = Self(self.get_bytes() + other.get_bytes());
}
}
#[cfg(test)]
mod test {
use crate::ByteCount;
#[test]
fn test_bytes() {
assert_eq!(ByteCount::from(0u64).human_readable(), "0 B");
assert_eq!(ByteCount::from(300u64).human_readable(), "300 B");
assert_eq!(ByteCount::from(1_000_000u64).human_readable(), "1.00 MB");
assert_eq!(ByteCount::from(1_500_000u64).human_readable(), "1.50 MB");
assert_eq!(
ByteCount::from(1_500_000_000u64).human_readable(),
"1.50 GB"
);
assert_eq!(
ByteCount::from(3_213_000_000_000u64).human_readable(),
"3.21 TB"
);
}
}

View File

@@ -29,13 +29,23 @@ pub enum DatePrecision {
/// All constructors and conversions are provided as explicit
/// functions and not by implementing any `From`/`Into` traits
/// to prevent unintended usage.
#[derive(Clone, Default, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[derive(Clone, Default, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DateTime {
// Timestamp in microseconds.
pub(crate) timestamp_micros: i64,
}
impl DateTime {
/// Minimum possible `DateTime` value.
pub const MIN: DateTime = DateTime {
timestamp_micros: i64::MIN,
};
/// Maximum possible `DateTime` value.
pub const MAX: DateTime = DateTime {
timestamp_micros: i64::MAX,
};
/// Create new from UNIX timestamp in seconds
pub const fn from_timestamp_secs(seconds: i64) -> Self {
Self {

View File

@@ -0,0 +1,63 @@
use std::io::{self, Read, Write};
use crate::BinarySerializable;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum DictionaryKind {
Fst = 1,
SSTable = 2,
}
#[derive(Debug, Clone, PartialEq)]
pub struct DictionaryFooter {
pub kind: DictionaryKind,
pub version: u32,
}
impl DictionaryFooter {
pub fn verify_equal(&self, other: &DictionaryFooter) -> io::Result<()> {
if self.kind != other.kind {
return Err(io::Error::new(
io::ErrorKind::Other,
format!(
"Invalid dictionary type, expected {:?}, found {:?}",
self.kind, other.kind
),
));
}
if self.version != other.version {
return Err(io::Error::new(
io::ErrorKind::Other,
format!(
"Unsuported dictionary version, expected {}, found {}",
self.version, other.version
),
));
}
Ok(())
}
}
impl BinarySerializable for DictionaryFooter {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
self.version.serialize(writer)?;
(self.kind as u32).serialize(writer)
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
let version = u32::deserialize(reader)?;
let kind = u32::deserialize(reader)?;
let kind = match kind {
1 => DictionaryKind::Fst,
2 => DictionaryKind::SSTable,
_ => {
return Err(io::Error::new(
io::ErrorKind::Other,
format!("invalid dictionary kind: {kind}"),
))
}
};
Ok(DictionaryFooter { kind, version })
}
}

View File

@@ -5,7 +5,7 @@ use std::{fmt, io};
use async_trait::async_trait;
use ownedbytes::{OwnedBytes, StableDeref};
use crate::HasLen;
use crate::{ByteCount, HasLen};
/// Objects that represents files sections in tantivy.
///
@@ -216,6 +216,11 @@ impl FileSlice {
pub fn slice_to(&self, to_offset: usize) -> FileSlice {
self.slice(0..to_offset)
}
/// Returns the byte count of the FileSlice.
pub fn num_bytes(&self) -> ByteCount {
self.range.len().into()
}
}
#[async_trait]

View File

@@ -5,14 +5,18 @@ use std::ops::Deref;
pub use byteorder::LittleEndian as Endianness;
mod bitset;
mod byte_count;
mod datetime;
mod dictionary_footer;
pub mod file_slice;
mod group_by;
mod serialize;
mod vint;
mod writer;
pub use bitset::*;
pub use byte_count::ByteCount;
pub use datetime::{DatePrecision, DateTime};
pub use dictionary_footer::*;
pub use group_by::GroupByIteratorExtended;
pub use ownedbytes::{OwnedBytes, StableDeref};
pub use serialize::{BinarySerializable, DeserializeFrom, FixedSize};
@@ -109,6 +113,21 @@ pub fn u64_to_f64(val: u64) -> f64 {
})
}
/// Replaces a given byte in the `bytes` slice of bytes.
///
/// This function assumes that the needle is rarely contained in the bytes string
/// and offers a fast path if the needle is not present.
pub fn replace_in_place(needle: u8, replacement: u8, bytes: &mut [u8]) {
if !bytes.contains(&needle) {
return;
}
for b in bytes {
if *b == needle {
*b = replacement;
}
}
}
#[cfg(test)]
pub mod test {
@@ -173,4 +192,20 @@ pub mod test {
assert!(f64_to_u64(-2.0) < f64_to_u64(1.0));
assert!(f64_to_u64(-2.0) < f64_to_u64(-1.5));
}
#[test]
fn test_replace_in_place() {
let test_aux = |before_replacement: &[u8], expected: &[u8]| {
let mut bytes: Vec<u8> = before_replacement.to_vec();
super::replace_in_place(b'b', b'c', &mut bytes);
assert_eq!(&bytes[..], expected);
};
test_aux(b"", b"");
test_aux(b"b", b"c");
test_aux(b"baaa", b"caaa");
test_aux(b"aaab", b"aaac");
test_aux(b"aaabaa", b"aaacaa");
test_aux(b"aaaaaa", b"aaaaaa");
test_aux(b"bbbb", b"cccc");
}
}

View File

@@ -1,129 +1,319 @@
// # Aggregation example
//
// This example shows how you can use built-in aggregations.
// We will use range buckets and compute the average in each bucket.
//
// We will use nested aggregations with buckets and metrics:
// - Range buckets and compute the average in each bucket.
// - Term aggregation and compute the min price in each bucket
// ---
use serde_json::Value;
use serde_json::{Deserializer, Value};
use tantivy::aggregation::agg_req::{
Aggregation, Aggregations, BucketAggregation, BucketAggregationType, MetricAggregation,
RangeAggregation,
};
use tantivy::aggregation::agg_result::AggregationResults;
use tantivy::aggregation::bucket::RangeAggregationRange;
use tantivy::aggregation::metric::AverageAggregation;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::TermQuery;
use tantivy::schema::{self, IndexRecordOption, Schema, TextFieldIndexing};
use tantivy::{doc, Index, Term};
use tantivy::query::AllQuery;
use tantivy::schema::{self, IndexRecordOption, Schema, TextFieldIndexing, FAST};
use tantivy::Index;
fn main() -> tantivy::Result<()> {
// # Create Schema
//
// Lets create a schema for a footwear shop, with 4 fields: name, category, stock and price.
// category, stock and price will be fast fields as that's the requirement
// for aggregation queries.
//
let mut schema_builder = Schema::builder();
// In preparation of the `TermsAggregation`, the category field is configured with:
// - `set_fast`
// - `raw` tokenizer
//
// The tokenizer is set to "raw", because the fast field uses the same dictionary as the
// inverted index. (This behaviour will change in tantivy 0.20, where the fast field will
// always be raw tokenized independent from the regular tokenizing)
//
let text_fieldtype = schema::TextOptions::default()
.set_indexing_options(
TextFieldIndexing::default().set_index_option(IndexRecordOption::WithFreqs),
TextFieldIndexing::default()
.set_index_option(IndexRecordOption::WithFreqs)
.set_tokenizer("raw"),
)
.set_fast(None)
.set_stored();
let text_field = schema_builder.add_text_field("text", text_fieldtype);
let score_fieldtype = crate::schema::NumericOptions::default().set_fast();
let highscore_field = schema_builder.add_f64_field("highscore", score_fieldtype.clone());
let price_field = schema_builder.add_f64_field("price", score_fieldtype);
schema_builder.add_text_field("category", text_fieldtype);
schema_builder.add_f64_field("stock", FAST);
schema_builder.add_f64_field("price", FAST);
let schema = schema_builder.build();
// # Indexing documents
//
// Lets index a bunch of documents for this example.
let index = Index::create_in_ram(schema);
let index = Index::create_in_ram(schema.clone());
let data = r#"{
"name": "Almond Toe Court Shoes, Patent Black",
"category": "Womens Footwear",
"price": 99.00,
"stock": 5
}
{
"name": "Suede Shoes, Blue",
"category": "Womens Footwear",
"price": 42.00,
"stock": 4
}
{
"name": "Leather Driver Saddle Loafers, Tan",
"category": "Mens Footwear",
"price": 34.00,
"stock": 12
}
{
"name": "Flip Flops, Red",
"category": "Mens Footwear",
"price": 19.00,
"stock": 6
}
{
"name": "Flip Flops, Blue",
"category": "Mens Footwear",
"price": 19.00,
"stock": 0
}
{
"name": "Gold Button Cardigan, Black",
"category": "Womens Casualwear",
"price": 167.00,
"stock": 6
}
{
"name": "Cotton Shorts, Medium Red",
"category": "Womens Casualwear",
"price": 30.00,
"stock": 5
}
{
"name": "Fine Stripe Short SleeveShirt, Grey",
"category": "Mens Casualwear",
"price": 49.99,
"stock": 9
}
{
"name": "Fine Stripe Short SleeveShirt, Green",
"category": "Mens Casualwear",
"price": 49.99,
"offer": 39.99,
"stock": 9
}
{
"name": "Sharkskin Waistcoat, Charcoal",
"category": "Mens Formalwear",
"price": 75.00,
"stock": 2
}
{
"name": "Lightweight Patch PocketBlazer, Deer",
"category": "Mens Formalwear",
"price": 175.50,
"stock": 1
}
{
"name": "Bird Print Dress, Black",
"category": "Womens Formalwear",
"price": 270.00,
"stock": 10
}
{
"name": "Mid Twist Cut-Out Dress, Pink",
"category": "Womens Formalwear",
"price": 540.00,
"stock": 5
}"#;
let stream = Deserializer::from_str(data).into_iter::<Value>();
let mut index_writer = index.writer(50_000_000)?;
// writing the segment
index_writer.add_document(doc!(
text_field => "cool",
highscore_field => 1f64,
price_field => 0f64,
))?;
index_writer.add_document(doc!(
text_field => "cool",
highscore_field => 3f64,
price_field => 1f64,
))?;
index_writer.add_document(doc!(
text_field => "cool",
highscore_field => 5f64,
price_field => 1f64,
))?;
index_writer.add_document(doc!(
text_field => "nohit",
highscore_field => 6f64,
price_field => 2f64,
))?;
index_writer.add_document(doc!(
text_field => "cool",
highscore_field => 7f64,
price_field => 2f64,
))?;
index_writer.commit()?;
index_writer.add_document(doc!(
text_field => "cool",
highscore_field => 11f64,
price_field => 10f64,
))?;
index_writer.add_document(doc!(
text_field => "cool",
highscore_field => 14f64,
price_field => 15f64,
))?;
index_writer.add_document(doc!(
text_field => "cool",
highscore_field => 15f64,
price_field => 20f64,
))?;
let mut num_indexed = 0;
for value in stream {
let doc = schema.parse_document(&serde_json::to_string(&value.unwrap())?)?;
index_writer.add_document(doc)?;
num_indexed += 1;
if num_indexed > 4 {
// Writing the first segment
index_writer.commit()?;
}
}
// Writing the second segment
index_writer.commit()?;
// We have two segments now. The `AggregationCollector` will run the aggregation on each
// segment and then merge the results into an `IntermediateAggregationResult`.
let reader = index.reader()?;
let text_field = reader.searcher().schema().get_field("text").unwrap();
let searcher = reader.searcher();
// ---
// # Aggregation Query
//
//
// We can construct the query by building the request structure or by deserializing from JSON.
// The JSON API is more stable and therefore recommended.
//
// ## Request 1
let term_query = TermQuery::new(
Term::from_field_text(text_field, "cool"),
IndexRecordOption::Basic,
);
let agg_req_str = r#"
{
"group_by_stock": {
"aggs": {
"average_price": { "avg": { "field": "price" } }
},
"range": {
"field": "stock",
"ranges": [
{ "key": "few", "to": 1.0 },
{ "key": "some", "from": 1.0, "to": 10.0 },
{ "key": "many", "from": 10.0 }
]
}
}
} "#;
let sub_agg_req_1: Aggregations = vec![(
"average_price".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("price".to_string()),
)),
)]
.into_iter()
.collect();
// In this Aggregation we want to get the average price for different groups, depending on how
// many items are in stock. We define custom ranges `few`, `some`, `many` via the
// range aggregation.
// For every bucket we want the average price, so we create a nested metric aggregation on the
// range bucket aggregation. Only buckets support nested aggregations.
// ### Request JSON API
//
let agg_req_1: Aggregations = vec![(
"score_ranges".to_string(),
Aggregation::Bucket(BucketAggregation {
let agg_req: Aggregations = serde_json::from_str(agg_req_str)?;
let collector = AggregationCollector::from_aggs(agg_req, Default::default());
let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
let res2: Value = serde_json::to_value(agg_res)?;
// ### Request Rust API
//
// This is exactly the same request as above, but via the rust structures.
//
let agg_req: Aggregations = vec![(
"group_by_stock".to_string(),
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "highscore".to_string(),
field: "stock".to_string(),
ranges: vec![
(-1f64..9f64).into(),
(9f64..14f64).into(),
(14f64..20f64).into(),
RangeAggregationRange {
key: Some("few".into()),
from: None,
to: Some(1f64),
},
RangeAggregationRange {
key: Some("some".into()),
from: Some(1f64),
to: Some(10f64),
},
RangeAggregationRange {
key: Some("many".into()),
from: Some(10f64),
to: None,
},
],
..Default::default()
}),
sub_aggregation: sub_agg_req_1,
}),
sub_aggregation: vec![(
"average_price".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("price".to_string()),
)),
)]
.into_iter()
.collect(),
})),
)]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None, index.schema());
let collector = AggregationCollector::from_aggs(agg_req, Default::default());
// We use the `AllQuery` which will pass all documents to the AggregationCollector.
let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
let searcher = reader.searcher();
let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();
let res1: Value = serde_json::to_value(agg_res)?;
// ### Aggregation Result
//
// The resulting structure deserializes in the same JSON format as elastic search.
//
let expected_res = r#"
{
"group_by_stock":{
"buckets":[
{"average_price":{"value":19.0},"doc_count":1,"key":"few","to":1.0},
{"average_price":{"value":124.748},"doc_count":10,"from":1.0,"key":"some","to":10.0},
{"average_price":{"value":152.0},"doc_count":2,"from":10.0,"key":"many"}
]
}
}
"#;
let expected_json: Value = serde_json::from_str(expected_res)?;
assert_eq!(expected_json, res1);
assert_eq!(expected_json, res2);
// ### Request 2
//
// Now we are interested in the minimum price per category, so we create a bucket per
// category via `TermsAggregation`. We are interested in the highest minimum prices, and set the
// order of the buckets `"order": { "min_price": "desc" }` to be sorted by the the metric of
// the sub aggregation. (awesome)
//
let agg_req_str = r#"
{
"min_price_per_category": {
"aggs": {
"min_price": { "min": { "field": "price" } }
},
"terms": {
"field": "category",
"min_doc_count": 1,
"order": { "min_price": "desc" }
}
}
} "#;
let agg_req: Aggregations = serde_json::from_str(agg_req_str)?;
let collector = AggregationCollector::from_aggs(agg_req, Default::default());
let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
let res: Value = serde_json::to_value(agg_res)?;
println!("{}", serde_json::to_string_pretty(&res)?);
// Minimum price per category, sorted by minimum price descending
//
// As you can see, the starting prices for `Formalwear` are higher than `Casualwear`.
//
let expected_res = r#"
{
"min_price_per_category": {
"buckets": [
{ "doc_count": 2, "key": "Womens Formalwear", "min_price": { "value": 270.0 } },
{ "doc_count": 2, "key": "Mens Formalwear", "min_price": { "value": 75.0 } },
{ "doc_count": 2, "key": "Mens Casualwear", "min_price": { "value": 49.99 } },
{ "doc_count": 2, "key": "Womens Footwear", "min_price": { "value": 42.0 } },
{ "doc_count": 2, "key": "Womens Casualwear", "min_price": { "value": 30.0 } },
{ "doc_count": 3, "key": "Mens Footwear", "min_price": { "value": 19.0 } }
],
"sum_other_doc_count": 0
}
}
"#;
let expected_json: Value = serde_json::from_str(expected_res)?;
assert_eq!(expected_json, res);
Ok(())
}

View File

@@ -105,7 +105,7 @@ impl SegmentCollector for StatsSegmentCollector {
fn collect(&mut self, doc: u32, _score: Score) {
// Since we know the values are single value, we could call `first_or_default_col` on the
// column and fetch single values.
for value in self.fast_field_reader.values(doc) {
for value in self.fast_field_reader.values_for_doc(doc) {
let value = value as f64;
self.stats.count += 1;
self.stats.sum += value;
@@ -171,7 +171,7 @@ fn main() -> tantivy::Result<()> {
let searcher = reader.searcher();
let query_parser = QueryParser::for_index(&index, vec![product_name, product_description]);
// here we want to get a hit on the 'ken' in Frankenstein
// here we want to search for `broom` and use `StatsCollector` on the hits.
let query = query_parser.parse_query("broom")?;
if let Some(stats) =
searcher.search(&query, &StatsCollector::with_field("price".to_string()))?

View File

@@ -1,7 +1,7 @@
// # Defining a tokenizer pipeline
//
// In this example, we'll see how to define a tokenizer pipeline
// by aligning a bunch of `TokenFilter`.
// In this example, we'll see how to define a tokenizer
// by creating a custom `NgramTokenizer`.
use tantivy::collector::TopDocs;
use tantivy::query::QueryParser;
use tantivy::schema::*;

View File

@@ -14,6 +14,7 @@ fn main() -> tantivy::Result<()> {
.set_stored()
.set_fast()
.set_precision(tantivy::DatePrecision::Seconds);
// Add `occurred_at` date field type
let occurred_at = schema_builder.add_date_field("occurred_at", opts);
let event_type = schema_builder.add_text_field("event", STRING | STORED);
let schema = schema_builder.build();
@@ -22,6 +23,7 @@ fn main() -> tantivy::Result<()> {
let index = Index::create_in_ram(schema.clone());
let mut index_writer = index.writer(50_000_000)?;
// The dates are passed as string in the RFC3339 format
let doc = schema.parse_document(
r#"{
"occurred_at": "2022-06-22T12:53:50.53Z",
@@ -41,14 +43,16 @@ fn main() -> tantivy::Result<()> {
let reader = index.reader()?;
let searcher = reader.searcher();
// # Default fields: event_type
// # Search
let query_parser = QueryParser::for_index(&index, vec![event_type]);
{
let query = query_parser.parse_query("event:comment")?;
// Simple exact search on the date
let query = query_parser.parse_query("occurred_at:\"2022-06-22T12:53:50.53Z\"")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5))?;
assert_eq!(count_docs.len(), 1);
}
{
// Range query on the date field
let query = query_parser
.parse_query(r#"occurred_at:[2022-06-22T12:58:00Z TO 2022-06-23T00:00:00Z}"#)?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(4))?;

View File

@@ -1,3 +1,12 @@
// # Faceted Search With Tweak Score
//
// This example covers the faceted search functionalities of
// tantivy.
//
// We will :
// - define a text field "name" in our schema
// - define a facet field "classification" in our schema
use std::collections::HashSet;
use tantivy::collector::TopDocs;
@@ -55,6 +64,7 @@ fn main() -> tantivy::Result<()> {
.collect(),
);
let top_docs_by_custom_score =
// Call TopDocs with a custom tweak score
TopDocs::with_limit(2).tweak_score(move |segment_reader: &SegmentReader| {
let ingredient_reader = segment_reader.facet_reader("ingredient").unwrap();
let facet_dict = ingredient_reader.facet_dict();
@@ -65,6 +75,7 @@ fn main() -> tantivy::Result<()> {
.collect();
move |doc: DocId, original_score: Score| {
// Update the original score with a tweaked score
let missing_ingredients = ingredient_reader
.facet_ords(doc)
.filter(|ord| !query_ords.contains(ord))

167
examples/fuzzy_search.rs Normal file
View File

@@ -0,0 +1,167 @@
// # Basic Example
//
// This example covers the basic functionalities of
// tantivy.
//
// We will :
// - define our schema
// - create an index in a directory
// - index a few documents into our index
// - search for the best document matching a basic query
// - retrieve the best document's original content.
// ---
// Importing tantivy...
use tantivy::collector::{Count, TopDocs};
use tantivy::query::FuzzyTermQuery;
use tantivy::schema::*;
use tantivy::{doc, Index, ReloadPolicy};
use tempfile::TempDir;
fn main() -> tantivy::Result<()> {
// Let's create a temporary directory for the
// sake of this example
let index_path = TempDir::new()?;
// # Defining the schema
//
// The Tantivy index requires a very strict schema.
// The schema declares which fields are in the index,
// and for each field, its type and "the way it should
// be indexed".
// First we need to define a schema ...
let mut schema_builder = Schema::builder();
// Our first field is title.
// We want full-text search for it, and we also want
// to be able to retrieve the document after the search.
//
// `TEXT | STORED` is some syntactic sugar to describe
// that.
//
// `TEXT` means the field should be tokenized and indexed,
// along with its term frequency and term positions.
//
// `STORED` means that the field will also be saved
// in a compressed, row-oriented key-value store.
// This store is useful for reconstructing the
// documents that were selected during the search phase.
let title = schema_builder.add_text_field("title", TEXT | STORED);
let schema = schema_builder.build();
// # Indexing documents
//
// Let's create a brand new index.
//
// This will actually just save a meta.json
// with our schema in the directory.
let index = Index::create_in_dir(&index_path, schema.clone())?;
// To insert a document we will need an index writer.
// There must be only one writer at a time.
// This single `IndexWriter` is already
// multithreaded.
//
// Here we give tantivy a budget of `50MB`.
// Using a bigger memory_arena for the indexer may increase
// throughput, but 50 MB is already plenty.
let mut index_writer = index.writer(50_000_000)?;
// Let's index our documents!
// We first need a handle on the title and the body field.
// ### Adding documents
//
index_writer.add_document(doc!(
title => "The Name of the Wind",
))?;
index_writer.add_document(doc!(
title => "The Diary of Muadib",
))?;
index_writer.add_document(doc!(
title => "A Dairy Cow",
))?;
index_writer.add_document(doc!(
title => "The Diary of a Young Girl",
))?;
index_writer.commit()?;
// ### Committing
//
// At this point our documents are not searchable.
//
//
// We need to call `.commit()` explicitly to force the
// `index_writer` to finish processing the documents in the queue,
// flush the current index to the disk, and advertise
// the existence of new documents.
//
// This call is blocking.
index_writer.commit()?;
// If `.commit()` returns correctly, then all of the
// documents that have been added are guaranteed to be
// persistently indexed.
//
// In the scenario of a crash or a power failure,
// tantivy behaves as if it has rolled back to its last
// commit.
// # Searching
//
// ### Searcher
//
// A reader is required first in order to search an index.
// It acts as a `Searcher` pool that reloads itself,
// depending on a `ReloadPolicy`.
//
// For a search server you will typically create one reader for the entire lifetime of your
// program, and acquire a new searcher for every single request.
//
// In the code below, we rely on the 'ON_COMMIT' policy: the reader
// will reload the index automatically after each commit.
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::OnCommit)
.try_into()?;
// We now need to acquire a searcher.
//
// A searcher points to a snapshotted, immutable version of the index.
//
// Some search experience might require more than
// one query. Using the same searcher ensures that all of these queries will run on the
// same version of the index.
//
// Acquiring a `searcher` is very cheap.
//
// You should acquire a searcher every time you start processing a request and
// and release it right after your query is finished.
let searcher = reader.searcher();
// ### FuzzyTermQuery
{
let term = Term::from_field_text(title, "Diary");
let query = FuzzyTermQuery::new(term, 2, true);
let (top_docs, count) = searcher
.search(&query, &(TopDocs::with_limit(5), Count))
.unwrap();
assert_eq!(count, 3);
assert_eq!(top_docs.len(), 3);
for (score, doc_address) in top_docs {
let retrieved_doc = searcher.doc(doc_address)?;
// Note that the score is not lower for the fuzzy hit.
// There's an issue open for that: https://github.com/quickwit-oss/tantivy/issues/563
println!("score {score:?} doc {}", schema.to_json(&retrieved_doc));
// score 1.0 doc {"title":["The Diary of Muadib"]}
//
// score 1.0 doc {"title":["The Diary of a Young Girl"]}
//
// score 1.0 doc {"title":["A Dairy Cow"]}
}
}
Ok(())
}

View File

@@ -10,6 +10,10 @@ use tantivy::Index;
fn main() -> tantivy::Result<()> {
// # Defining the schema
// We set the IP field as `INDEXED`, so it can be searched
// `FAST` will create a fast field. The fast field will be used to execute search queries.
// `FAST` is not a requirement for range queries, it can also be executed on the inverted index
// which is created by `INDEXED`.
let mut schema_builder = Schema::builder();
let event_type = schema_builder.add_text_field("event_type", STRING | STORED);
let ip = schema_builder.add_ip_addr_field("ip", STORED | INDEXED | FAST);
@@ -19,51 +23,81 @@ fn main() -> tantivy::Result<()> {
let index = Index::create_in_ram(schema.clone());
let mut index_writer = index.writer(50_000_000)?;
// ### IPv4
// Adding documents that contain an IPv4 address. Notice that the IP addresses are passed as
// `String`. Since the field is of type ip, we parse the IP address from the string and store it
// internally as IPv6.
let doc = schema.parse_document(
r#"{
"ip": "192.168.0.33",
"event_type": "login"
}"#,
"ip": "192.168.0.33",
"event_type": "login"
}"#,
)?;
index_writer.add_document(doc)?;
let doc = schema.parse_document(
r#"{
"ip": "192.168.0.80",
"event_type": "checkout"
}"#,
"ip": "192.168.0.80",
"event_type": "checkout"
}"#,
)?;
index_writer.add_document(doc)?;
// ### IPv6
// Adding a document that contains an IPv6 address.
let doc = schema.parse_document(
r#"{
"ip": "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
"event_type": "checkout"
}"#,
"ip": "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
"event_type": "checkout"
}"#,
)?;
index_writer.add_document(doc)?;
// Commit will create a segment containing our documents.
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
// # Search
// Range queries on IPv4. Since we created a fast field, the fast field will be used to execute
// the search.
// ### Range Queries
let query_parser = QueryParser::for_index(&index, vec![event_type, ip]);
{
let query = query_parser.parse_query("ip:[192.168.0.0 TO 192.168.0.100]")?;
// Inclusive range queries
let query = query_parser.parse_query("ip:[192.168.0.80 TO 192.168.0.100]")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(5))?;
assert_eq!(count_docs.len(), 2);
assert_eq!(count_docs.len(), 1);
}
{
let query = query_parser.parse_query("ip:[192.168.1.0 TO 192.168.1.100]")?;
// Exclusive range queries
let query = query_parser.parse_query("ip:{192.168.0.80 TO 192.168.1.100]")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
assert_eq!(count_docs.len(), 0);
}
{
// Find docs with IP addresses smaller equal 192.168.1.100
let query = query_parser.parse_query("ip:[* TO 192.168.1.100]")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
assert_eq!(count_docs.len(), 2);
}
{
// Find docs with IP addresses smaller than 192.168.1.100
let query = query_parser.parse_query("ip:[* TO 192.168.1.100}")?;
let count_docs = searcher.search(&*query, &TopDocs::with_limit(2))?;
assert_eq!(count_docs.len(), 2);
}
// ### Exact Queries
// Exact search on IPv4.
{
let query = query_parser.parse_query("ip:192.168.0.80")?;
let count_docs = searcher.search(&*query, &Count)?;
assert_eq!(count_docs, 1);
}
// Exact search on IPv6.
// IpV6 addresses need to be quoted because they contain `:`
{
// IpV6 needs to be escaped because it contains `:`
let query = query_parser.parse_query("ip:\"2001:0db8:85a3:0000:0000:8a2e:0370:7334\"")?;
let count_docs = searcher.search(&*query, &Count)?;
assert_eq!(count_docs, 1);

View File

@@ -12,7 +12,7 @@
use tantivy::collector::{Count, TopDocs};
use tantivy::query::TermQuery;
use tantivy::schema::*;
use tantivy::tokenizer::{PreTokenizedString, SimpleTokenizer, Token, Tokenizer};
use tantivy::tokenizer::{PreTokenizedString, SimpleTokenizer, Token, TokenStream, Tokenizer};
use tantivy::{doc, Index, ReloadPolicy};
use tempfile::TempDir;

View File

@@ -50,12 +50,13 @@ fn main() -> tantivy::Result<()> {
// This tokenizer lowers all of the text (to help with stop word matching)
// then removes all instances of `the` and `and` from the corpus
let tokenizer = TextAnalyzer::from(SimpleTokenizer)
let tokenizer = TextAnalyzer::builder(SimpleTokenizer)
.filter(LowerCaser)
.filter(StopWordFilter::remove(vec![
"the".to_string(),
"and".to_string(),
]));
]))
.build();
index.tokenizers().register("stoppy", tokenizer);

View File

@@ -17,7 +17,6 @@ use tantivy::{
type ProductId = u64;
/// Price
type Price = u32;
pub trait PriceFetcher: Send + Sync + 'static {
@@ -90,10 +89,10 @@ impl Warmer for DynamicPriceColumn {
}
}
/// For the sake of this example, the table is just an editable HashMap behind a RwLock.
/// This map represents a map (ProductId -> Price)
///
/// In practise, it could be fetching things from an external service, like a SQL table.
// For the sake of this example, the table is just an editable HashMap behind a RwLock.
// This map represents a map (ProductId -> Price)
//
// In practise, it could be fetching things from an external service, like a SQL table.
#[derive(Default, Clone)]
pub struct ExternalPriceTable {
prices: Arc<RwLock<HashMap<ProductId, Price>>>,

View File

@@ -0,0 +1,611 @@
#[cfg(all(test, feature = "unstable"))]
mod bench {
use columnar::Cardinality;
use rand::prelude::SliceRandom;
use rand::{thread_rng, Rng};
use test::{self, Bencher};
use super::*;
use crate::aggregation::bucket::{
CustomOrder, HistogramAggregation, HistogramBounds, Order, OrderTarget, TermsAggregation,
};
use crate::aggregation::metric::StatsAggregation;
use crate::query::AllQuery;
use crate::schema::{Schema, TextFieldIndexing, FAST, STRING};
use crate::Index;
fn get_test_index_bench(cardinality: Cardinality) -> crate::Result<Index> {
let mut schema_builder = Schema::builder();
let text_fieldtype = crate::schema::TextOptions::default()
.set_indexing_options(
TextFieldIndexing::default().set_index_option(IndexRecordOption::WithFreqs),
)
.set_stored();
let text_field = schema_builder.add_text_field("text", text_fieldtype);
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST);
let score_fieldtype = crate::schema::NumericOptions::default().set_fast();
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
let index = Index::create_from_tempdir(schema_builder.build())?;
let few_terms_data = vec!["INFO", "ERROR", "WARN", "DEBUG"];
let many_terms_data = (0..150_000)
.map(|num| format!("author{}", num))
.collect::<Vec<_>>();
{
let mut rng = thread_rng();
let mut index_writer = index.writer_with_num_threads(1, 100_000_000)?;
// To make the different test cases comparable we just change one doc to force the
// cardinality
if cardinality == Cardinality::Optional {
index_writer.add_document(doc!())?;
}
if cardinality == Cardinality::Multivalued {
index_writer.add_document(doc!(
text_field => "cool",
text_field => "cool",
text_field_many_terms => "cool",
text_field_many_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms => "cool",
score_field => 1u64,
score_field => 1u64,
score_field_f64 => 1.0,
score_field_f64 => 1.0,
score_field_i64 => 1i64,
score_field_i64 => 1i64,
))?;
}
for _ in 0..1_000_000 {
let val: f64 = rng.gen_range(0.0..1_000_000.0);
index_writer.add_document(doc!(
text_field => "cool",
text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
score_field => val as u64,
score_field_f64 => val,
score_field_i64 => val as i64,
))?;
}
// writing the segment
index_writer.commit()?;
}
Ok(index)
}
use paste::paste;
#[macro_export]
macro_rules! bench_all_cardinalities {
( $x:ident ) => {
paste! {
#[bench]
fn $x(b: &mut Bencher) {
[<$x _card>](b, Cardinality::Full)
}
#[bench]
fn [<$x _opt>](b: &mut Bencher) {
[<$x _card>](b, Cardinality::Optional)
}
#[bench]
fn [<$x _multi>](b: &mut Bencher) {
[<$x _card>](b, Cardinality::Multivalued)
}
}
};
}
bench_all_cardinalities!(bench_aggregation_average_u64);
fn bench_aggregation_average_u64_card(b: &mut Bencher, cardinality: Cardinality) {
let index = get_test_index_bench(cardinality).unwrap();
let reader = index.reader().unwrap();
let text_field = reader.searcher().schema().get_field("text").unwrap();
b.iter(|| {
let term_query = TermQuery::new(
Term::from_field_text(text_field, "cool"),
IndexRecordOption::Basic,
);
let agg_req_1: Aggregations = vec![(
"average".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score".to_string()),
)),
)]
.into_iter()
.collect();
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()
});
}
bench_all_cardinalities!(bench_aggregation_stats_f64);
fn bench_aggregation_stats_f64_card(b: &mut Bencher, cardinality: Cardinality) {
let index = get_test_index_bench(cardinality).unwrap();
let reader = index.reader().unwrap();
let text_field = reader.searcher().schema().get_field("text").unwrap();
b.iter(|| {
let term_query = TermQuery::new(
Term::from_field_text(text_field, "cool"),
IndexRecordOption::Basic,
);
let agg_req_1: Aggregations = vec![(
"average_f64".to_string(),
Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name(
"score_f64".to_string(),
))),
)]
.into_iter()
.collect();
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()
});
}
bench_all_cardinalities!(bench_aggregation_average_f64);
fn bench_aggregation_average_f64_card(b: &mut Bencher, cardinality: Cardinality) {
let index = get_test_index_bench(cardinality).unwrap();
let reader = index.reader().unwrap();
let text_field = reader.searcher().schema().get_field("text").unwrap();
b.iter(|| {
let term_query = TermQuery::new(
Term::from_field_text(text_field, "cool"),
IndexRecordOption::Basic,
);
let agg_req_1: Aggregations = vec![(
"average_f64".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score_f64".to_string()),
)),
)]
.into_iter()
.collect();
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()
});
}
bench_all_cardinalities!(bench_aggregation_average_u64_and_f64);
fn bench_aggregation_average_u64_and_f64_card(b: &mut Bencher, cardinality: Cardinality) {
let index = get_test_index_bench(cardinality).unwrap();
let reader = index.reader().unwrap();
let text_field = reader.searcher().schema().get_field("text").unwrap();
b.iter(|| {
let term_query = TermQuery::new(
Term::from_field_text(text_field, "cool"),
IndexRecordOption::Basic,
);
let agg_req_1: Aggregations = vec![
(
"average_f64".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score_f64".to_string()),
)),
),
(
"average".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score".to_string()),
)),
),
]
.into_iter()
.collect();
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()
});
}
bench_all_cardinalities!(bench_aggregation_terms_few);
fn bench_aggregation_terms_few_card(b: &mut Bencher, cardinality: Cardinality) {
let index = get_test_index_bench(cardinality).unwrap();
let reader = index.reader().unwrap();
b.iter(|| {
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_few_terms".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let collector = get_collector(agg_req);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
});
}
bench_all_cardinalities!(bench_aggregation_terms_many_with_sub_agg);
fn bench_aggregation_terms_many_with_sub_agg_card(b: &mut Bencher, cardinality: Cardinality) {
let index = get_test_index_bench(cardinality).unwrap();
let reader = index.reader().unwrap();
b.iter(|| {
let sub_agg_req: Aggregations = vec![(
"average_f64".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score_f64".to_string()),
)),
)]
.into_iter()
.collect();
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_many_terms".to_string(),
..Default::default()
}),
sub_aggregation: sub_agg_req,
}
.into(),
),
)]
.into_iter()
.collect();
let collector = get_collector(agg_req);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
});
}
bench_all_cardinalities!(bench_aggregation_terms_many2);
fn bench_aggregation_terms_many2_card(b: &mut Bencher, cardinality: Cardinality) {
let index = get_test_index_bench(cardinality).unwrap();
let reader = index.reader().unwrap();
b.iter(|| {
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_many_terms".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let collector = get_collector(agg_req);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
});
}
bench_all_cardinalities!(bench_aggregation_terms_many_order_by_term);
fn bench_aggregation_terms_many_order_by_term_card(b: &mut Bencher, cardinality: Cardinality) {
let index = get_test_index_bench(cardinality).unwrap();
let reader = index.reader().unwrap();
b.iter(|| {
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_many_terms".to_string(),
order: Some(CustomOrder {
order: Order::Desc,
target: OrderTarget::Key,
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let collector = get_collector(agg_req);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
});
}
bench_all_cardinalities!(bench_aggregation_range_only);
fn bench_aggregation_range_only_card(b: &mut Bencher, cardinality: Cardinality) {
let index = get_test_index_bench(cardinality).unwrap();
let reader = index.reader().unwrap();
b.iter(|| {
let agg_req_1: Aggregations = vec![(
"rangef64".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_f64".to_string(),
ranges: vec![
(3f64..7000f64).into(),
(7000f64..20000f64).into(),
(20000f64..30000f64).into(),
(30000f64..40000f64).into(),
(40000f64..50000f64).into(),
(50000f64..60000f64).into(),
],
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
});
}
bench_all_cardinalities!(bench_aggregation_range_with_avg);
fn bench_aggregation_range_with_avg_card(b: &mut Bencher, cardinality: Cardinality) {
let index = get_test_index_bench(cardinality).unwrap();
let reader = index.reader().unwrap();
b.iter(|| {
let sub_agg_req: Aggregations = vec![(
"average_f64".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score_f64".to_string()),
)),
)]
.into_iter()
.collect();
let agg_req_1: Aggregations = vec![(
"rangef64".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_f64".to_string(),
ranges: vec![
(3f64..7000f64).into(),
(7000f64..20000f64).into(),
(20000f64..30000f64).into(),
(30000f64..40000f64).into(),
(40000f64..50000f64).into(),
(50000f64..60000f64).into(),
],
..Default::default()
}),
sub_aggregation: sub_agg_req,
}
.into(),
),
)]
.into_iter()
.collect();
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
});
}
// hard bounds has a different algorithm, because it actually limits collection range
//
bench_all_cardinalities!(bench_aggregation_histogram_only_hard_bounds);
fn bench_aggregation_histogram_only_hard_bounds_card(
b: &mut Bencher,
cardinality: Cardinality,
) {
let index = get_test_index_bench(cardinality).unwrap();
let reader = index.reader().unwrap();
b.iter(|| {
let agg_req_1: Aggregations = vec![(
"rangef64".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 100f64,
hard_bounds: Some(HistogramBounds {
min: 1000.0,
max: 300_000.0,
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
});
}
bench_all_cardinalities!(bench_aggregation_histogram_with_avg);
fn bench_aggregation_histogram_with_avg_card(b: &mut Bencher, cardinality: Cardinality) {
let index = get_test_index_bench(cardinality).unwrap();
let reader = index.reader().unwrap();
b.iter(|| {
let sub_agg_req: Aggregations = vec![(
"average_f64".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score_f64".to_string()),
)),
)]
.into_iter()
.collect();
let agg_req_1: Aggregations = vec![(
"rangef64".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 100f64, // 1000 buckets
..Default::default()
}),
sub_aggregation: sub_agg_req,
}
.into(),
),
)]
.into_iter()
.collect();
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
});
}
bench_all_cardinalities!(bench_aggregation_histogram_only);
fn bench_aggregation_histogram_only_card(b: &mut Bencher, cardinality: Cardinality) {
let index = get_test_index_bench(cardinality).unwrap();
let reader = index.reader().unwrap();
b.iter(|| {
let agg_req_1: Aggregations = vec![(
"rangef64".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 100f64, // 1000 buckets
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
});
}
bench_all_cardinalities!(bench_aggregation_avg_and_range_with_avg);
fn bench_aggregation_avg_and_range_with_avg_card(b: &mut Bencher, cardinality: Cardinality) {
let index = get_test_index_bench(cardinality).unwrap();
let reader = index.reader().unwrap();
let text_field = reader.searcher().schema().get_field("text").unwrap();
b.iter(|| {
let term_query = TermQuery::new(
Term::from_field_text(text_field, "cool"),
IndexRecordOption::Basic,
);
let sub_agg_req_1: Aggregations = vec![(
"average_in_range".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score".to_string()),
)),
)]
.into_iter()
.collect();
let agg_req_1: Aggregations = vec![
(
"average".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score".to_string()),
)),
),
(
"rangef64".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_f64".to_string(),
ranges: vec![
(3f64..7000f64).into(),
(7000f64..20000f64).into(),
(20000f64..60000f64).into(),
],
..Default::default()
}),
sub_aggregation: sub_agg_req_1,
}
.into(),
),
),
]
.into_iter()
.collect();
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()
});
}
}

View File

@@ -0,0 +1,95 @@
use std::collections::HashMap;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use common::ByteCount;
use super::collector::DEFAULT_MEMORY_LIMIT;
use super::{AggregationError, DEFAULT_BUCKET_LIMIT};
use crate::TantivyError;
/// An estimate for memory consumption. Non recursive
pub trait MemoryConsumption {
fn memory_consumption(&self) -> usize;
}
impl<K, V, S> MemoryConsumption for HashMap<K, V, S> {
fn memory_consumption(&self) -> usize {
let num_items = self.capacity();
(std::mem::size_of::<K>() + std::mem::size_of::<V>()) * num_items
}
}
/// Aggregation memory limit after which the request fails. Defaults to DEFAULT_MEMORY_LIMIT
/// (500MB). The limit is shared by all SegmentCollectors
pub struct AggregationLimits {
/// The counter which is shared between the aggregations for one request.
memory_consumption: Arc<AtomicU64>,
/// The memory_limit in bytes
memory_limit: ByteCount,
/// The maximum number of buckets _returned_
/// This is not counting intermediate buckets.
bucket_limit: u32,
}
impl Clone for AggregationLimits {
fn clone(&self) -> Self {
Self {
memory_consumption: Arc::clone(&self.memory_consumption),
memory_limit: self.memory_limit,
bucket_limit: self.bucket_limit,
}
}
}
impl Default for AggregationLimits {
fn default() -> Self {
Self {
memory_consumption: Default::default(),
memory_limit: DEFAULT_MEMORY_LIMIT.into(),
bucket_limit: DEFAULT_BUCKET_LIMIT,
}
}
}
impl AggregationLimits {
/// *memory_limit*
/// memory_limit is defined in bytes.
/// Aggregation fails when the estimated memory consumption of the aggregation is higher than
/// memory_limit.
/// memory_limit will default to `DEFAULT_MEMORY_LIMIT` (500MB)
///
/// *bucket_limit*
/// Limits the maximum number of buckets returned from an aggregation request.
/// bucket_limit will default to `DEFAULT_BUCKET_LIMIT` (65000)
pub fn new(memory_limit: Option<u64>, bucket_limit: Option<u32>) -> Self {
Self {
memory_consumption: Default::default(),
memory_limit: memory_limit.unwrap_or(DEFAULT_MEMORY_LIMIT).into(),
bucket_limit: bucket_limit.unwrap_or(DEFAULT_BUCKET_LIMIT),
}
}
pub(crate) fn validate_memory_consumption(&self) -> crate::Result<()> {
if self.get_memory_consumed() > self.memory_limit {
return Err(TantivyError::AggregationError(
AggregationError::MemoryExceeded {
limit: self.memory_limit,
current: self.get_memory_consumed(),
},
));
}
Ok(())
}
pub(crate) fn add_memory_consumed(&self, num_bytes: u64) {
self.memory_consumption
.fetch_add(num_bytes, std::sync::atomic::Ordering::Relaxed);
}
/// Returns the estimated memory consumed by the aggregations
pub fn get_memory_consumed(&self) -> ByteCount {
self.memory_consumption
.load(std::sync::atomic::Ordering::Relaxed)
.into()
}
pub(crate) fn get_bucket_limit(&self) -> u32 {
self.bucket_limit
}
}

View File

@@ -16,14 +16,14 @@
//! let agg_req1: Aggregations = vec![
//! (
//! "range".to_string(),
//! Aggregation::Bucket(BucketAggregation {
//! Aggregation::Bucket(Box::new(BucketAggregation {
//! bucket_agg: BucketAggregationType::Range(RangeAggregation{
//! field: "score".to_string(),
//! ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
//! keyed: false,
//! }),
//! sub_aggregation: Default::default(),
//! }),
//! })),
//! ),
//! ]
//! .into_iter()
@@ -50,7 +50,7 @@ use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
pub use super::bucket::RangeAggregation;
use super::bucket::{HistogramAggregation, TermsAggregation};
use super::bucket::{DateHistogramAggregationReq, HistogramAggregation, TermsAggregation};
use super::metric::{
AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, StatsAggregation,
SumAggregation,
@@ -110,10 +110,13 @@ impl BucketAggregationInternal {
_ => None,
}
}
pub(crate) fn as_histogram(&self) -> Option<&HistogramAggregation> {
pub(crate) fn as_histogram(&self) -> crate::Result<Option<HistogramAggregation>> {
match &self.bucket_agg {
BucketAggregationType::Histogram(histogram) => Some(histogram),
_ => None,
BucketAggregationType::Histogram(histogram) => Ok(Some(histogram.clone())),
BucketAggregationType::DateHistogram(histogram) => {
Ok(Some(histogram.to_histogram_req()?))
}
_ => Ok(None),
}
}
pub(crate) fn as_term(&self) -> Option<&TermsAggregation> {
@@ -124,15 +127,6 @@ impl BucketAggregationInternal {
}
}
/// Extract all fields, where the term directory is used in the tree.
pub fn get_term_dict_field_names(aggs: &Aggregations) -> HashSet<String> {
let mut term_dict_field_names = Default::default();
for el in aggs.values() {
el.get_term_dict_field_names(&mut term_dict_field_names)
}
term_dict_field_names
}
/// Extract all fast field names used in the tree.
pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet<String> {
let mut fast_field_names = Default::default();
@@ -149,22 +143,18 @@ pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet<String> {
#[serde(untagged)]
pub enum Aggregation {
/// Bucket aggregation, see [`BucketAggregation`] for details.
Bucket(BucketAggregation),
Bucket(Box<BucketAggregation>),
/// Metric aggregation, see [`MetricAggregation`] for details.
Metric(MetricAggregation),
}
impl Aggregation {
fn get_term_dict_field_names(&self, term_field_names: &mut HashSet<String>) {
if let Aggregation::Bucket(bucket) = self {
bucket.get_term_dict_field_names(term_field_names)
}
}
fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
match self {
Aggregation::Bucket(bucket) => bucket.get_fast_field_names(fast_field_names),
Aggregation::Metric(metric) => metric.get_fast_field_names(fast_field_names),
Aggregation::Metric(metric) => {
fast_field_names.insert(metric.get_fast_field_name().to_string());
}
}
}
}
@@ -193,14 +183,9 @@ pub struct BucketAggregation {
}
impl BucketAggregation {
fn get_term_dict_field_names(&self, term_dict_field_names: &mut HashSet<String>) {
if let BucketAggregationType::Terms(terms) = &self.bucket_agg {
term_dict_field_names.insert(terms.field.to_string());
}
term_dict_field_names.extend(get_term_dict_field_names(&self.sub_aggregation));
}
fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
self.bucket_agg.get_fast_field_names(fast_field_names);
let fast_field_name = self.bucket_agg.get_fast_field_name();
fast_field_names.insert(fast_field_name.to_string());
fast_field_names.extend(get_fast_field_names(&self.sub_aggregation));
}
}
@@ -214,20 +199,22 @@ pub enum BucketAggregationType {
/// Put data into buckets of user-defined ranges.
#[serde(rename = "histogram")]
Histogram(HistogramAggregation),
/// Put data into buckets of user-defined ranges.
#[serde(rename = "date_histogram")]
DateHistogram(DateHistogramAggregationReq),
/// Put data into buckets of terms.
#[serde(rename = "terms")]
Terms(TermsAggregation),
}
impl BucketAggregationType {
fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
fn get_fast_field_name(&self) -> &str {
match self {
BucketAggregationType::Terms(terms) => fast_field_names.insert(terms.field.to_string()),
BucketAggregationType::Range(range) => fast_field_names.insert(range.field.to_string()),
BucketAggregationType::Histogram(histogram) => {
fast_field_names.insert(histogram.field.to_string())
}
};
BucketAggregationType::Terms(terms) => terms.field.as_str(),
BucketAggregationType::Range(range) => range.field.as_str(),
BucketAggregationType::Histogram(histogram) => histogram.field.as_str(),
BucketAggregationType::DateHistogram(histogram) => histogram.field.as_str(),
}
}
}
@@ -262,16 +249,15 @@ pub enum MetricAggregation {
}
impl MetricAggregation {
fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
let fast_field_name = match self {
fn get_fast_field_name(&self) -> &str {
match self {
MetricAggregation::Average(avg) => avg.field_name(),
MetricAggregation::Count(count) => count.field_name(),
MetricAggregation::Max(max) => max.field_name(),
MetricAggregation::Min(min) => min.field_name(),
MetricAggregation::Stats(stats) => stats.field_name(),
MetricAggregation::Sum(sum) => sum.field_name(),
};
fast_field_names.insert(fast_field_name.to_string());
}
}
}
@@ -315,7 +301,7 @@ mod tests {
fn serialize_to_json_test() {
let agg_req1: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score".to_string(),
ranges: vec![
@@ -327,7 +313,7 @@ mod tests {
keyed: true,
}),
sub_aggregation: Default::default(),
}),
})),
)]
.into_iter()
.collect();
@@ -365,7 +351,7 @@ mod tests {
let agg_req2: Aggregations = vec![
(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score2".to_string(),
ranges: vec![
@@ -377,7 +363,7 @@ mod tests {
..Default::default()
}),
sub_aggregation: Default::default(),
}),
})),
),
(
"metric".to_string(),
@@ -391,7 +377,7 @@ mod tests {
let agg_req1: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score".to_string(),
ranges: vec![
@@ -403,7 +389,7 @@ mod tests {
..Default::default()
}),
sub_aggregation: agg_req2,
}),
})),
)]
.into_iter()
.collect();

View File

@@ -1,20 +1,18 @@
//! This will enhance the request tree with access to the fastfield and metadata.
use std::rc::Rc;
use std::sync::atomic::AtomicU32;
use columnar::{Column, StrColumn};
use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn};
use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation};
use super::bucket::{HistogramAggregation, RangeAggregation, TermsAggregation};
use super::bucket::{
DateHistogramAggregationReq, HistogramAggregation, RangeAggregation, TermsAggregation,
};
use super::metric::{
AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, StatsAggregation,
SumAggregation,
};
use super::segment_agg_result::BucketCount;
use super::segment_agg_result::AggregationLimits;
use super::VecWithNames;
use crate::schema::Type;
use crate::{SegmentReader, TantivyError};
use crate::SegmentReader;
#[derive(Clone, Default)]
pub(crate) struct AggregationsWithAccessor {
@@ -41,10 +39,20 @@ pub struct BucketAggregationWithAccessor {
/// based on search terms. So eventually this needs to be Option or moved.
pub(crate) accessor: Column<u64>,
pub(crate) str_dict_column: Option<StrColumn>,
pub(crate) field_type: Type,
pub(crate) field_type: ColumnType,
pub(crate) bucket_agg: BucketAggregationType,
pub(crate) sub_aggregation: AggregationsWithAccessor,
pub(crate) bucket_count: BucketCount,
pub(crate) limits: AggregationLimits,
pub(crate) column_block_accessor: ColumnBlockAccessor<u64>,
}
fn get_numeric_or_date_column_types() -> &'static [ColumnType] {
&[
ColumnType::F64,
ColumnType::U64,
ColumnType::I64,
ColumnType::DateTime,
]
}
impl BucketAggregationWithAccessor {
@@ -52,22 +60,37 @@ impl BucketAggregationWithAccessor {
bucket: &BucketAggregationType,
sub_aggregation: &Aggregations,
reader: &SegmentReader,
bucket_count: Rc<AtomicU32>,
max_bucket_count: u32,
limits: AggregationLimits,
) -> crate::Result<BucketAggregationWithAccessor> {
let mut str_dict_column = None;
let (accessor, field_type) = match &bucket {
BucketAggregationType::Range(RangeAggregation {
field: field_name, ..
}) => get_ff_reader_and_validate(reader, field_name)?,
}) => get_ff_reader_and_validate(
reader,
field_name,
Some(get_numeric_or_date_column_types()),
)?,
BucketAggregationType::Histogram(HistogramAggregation {
field: field_name, ..
}) => get_ff_reader_and_validate(reader, field_name)?,
}) => get_ff_reader_and_validate(
reader,
field_name,
Some(get_numeric_or_date_column_types()),
)?,
BucketAggregationType::DateHistogram(DateHistogramAggregationReq {
field: field_name,
..
}) => get_ff_reader_and_validate(
reader,
field_name,
Some(get_numeric_or_date_column_types()),
)?,
BucketAggregationType::Terms(TermsAggregation {
field: field_name, ..
}) => {
str_dict_column = reader.fast_fields().str(field_name)?;
get_ff_reader_and_validate(reader, field_name)?
get_ff_reader_and_validate(reader, field_name, None)?
}
};
let sub_aggregation = sub_aggregation.clone();
@@ -77,15 +100,12 @@ impl BucketAggregationWithAccessor {
sub_aggregation: get_aggs_with_accessor_and_validate(
&sub_aggregation,
reader,
bucket_count.clone(),
max_bucket_count,
&limits.clone(),
)?,
bucket_agg: bucket.clone(),
str_dict_column,
bucket_count: BucketCount {
bucket_count,
max_bucket_count,
},
limits,
column_block_accessor: Default::default(),
})
}
}
@@ -94,8 +114,9 @@ impl BucketAggregationWithAccessor {
#[derive(Clone)]
pub struct MetricAggregationWithAccessor {
pub metric: MetricAggregation,
pub field_type: Type,
pub field_type: ColumnType,
pub accessor: Column<u64>,
pub column_block_accessor: ColumnBlockAccessor<u64>,
}
impl MetricAggregationWithAccessor {
@@ -110,12 +131,17 @@ impl MetricAggregationWithAccessor {
| MetricAggregation::Min(MinAggregation { field: field_name })
| MetricAggregation::Stats(StatsAggregation { field: field_name })
| MetricAggregation::Sum(SumAggregation { field: field_name }) => {
let (accessor, field_type) = get_ff_reader_and_validate(reader, field_name)?;
let (accessor, field_type) = get_ff_reader_and_validate(
reader,
field_name,
Some(get_numeric_or_date_column_types()),
)?;
Ok(MetricAggregationWithAccessor {
accessor,
field_type,
metric: metric.clone(),
column_block_accessor: Default::default(),
})
}
}
@@ -125,8 +151,7 @@ impl MetricAggregationWithAccessor {
pub(crate) fn get_aggs_with_accessor_and_validate(
aggs: &Aggregations,
reader: &SegmentReader,
bucket_count: Rc<AtomicU32>,
max_bucket_count: u32,
limits: &AggregationLimits,
) -> crate::Result<AggregationsWithAccessor> {
let mut metrics = vec![];
let mut buckets = vec![];
@@ -138,8 +163,7 @@ pub(crate) fn get_aggs_with_accessor_and_validate(
&bucket.bucket_agg,
&bucket.sub_aggregation,
reader,
Rc::clone(&bucket_count),
max_bucket_count,
limits.clone(),
)?,
)),
Aggregation::Metric(metric) => metrics.push((
@@ -158,22 +182,16 @@ pub(crate) fn get_aggs_with_accessor_and_validate(
fn get_ff_reader_and_validate(
reader: &SegmentReader,
field_name: &str,
) -> crate::Result<(columnar::Column<u64>, Type)> {
let field = reader.schema().get_field(field_name)?;
// TODO we should get type metadata from columnar
let field_type = reader
.schema()
.get_field_entry(field)
.field_type()
.value_type();
// TODO Do validation
allowed_column_types: Option<&[ColumnType]>,
) -> crate::Result<(columnar::Column<u64>, ColumnType)> {
let ff_fields = reader.fast_fields();
let ff_field = ff_fields.u64_lenient(field_name)?.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"No numerical fast field found for field: {}",
field_name
))
})?;
Ok((ff_field, field_type))
let ff_field_with_type = ff_fields
.u64_lenient_for_type(allowed_column_types, field_name)?
.unwrap_or_else(|| {
(
Column::build_empty_column(reader.num_docs()),
ColumnType::U64,
)
});
Ok(ff_field_with_type)
}

View File

@@ -11,8 +11,8 @@ use super::agg_req::BucketAggregationInternal;
use super::bucket::GetDocCount;
use super::intermediate_agg_result::{IntermediateBucketResult, IntermediateMetricResult};
use super::metric::{SingleMetricResult, Stats};
use super::segment_agg_result::AggregationLimits;
use super::Key;
use crate::schema::Schema;
use crate::TantivyError;
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
@@ -20,6 +20,13 @@ use crate::TantivyError;
pub struct AggregationResults(pub FxHashMap<String, AggregationResult>);
impl AggregationResults {
pub(crate) fn get_bucket_count(&self) -> u64 {
self.0
.values()
.map(|agg| agg.get_bucket_count())
.sum::<u64>()
}
pub(crate) fn get_value_from_aggregation(
&self,
name: &str,
@@ -48,6 +55,13 @@ pub enum AggregationResult {
}
impl AggregationResult {
pub(crate) fn get_bucket_count(&self) -> u64 {
match self {
AggregationResult::BucketResult(bucket) => bucket.get_bucket_count(),
AggregationResult::MetricResult(_) => 0,
}
}
pub(crate) fn get_value_from_aggregation(
&self,
_name: &str,
@@ -154,12 +168,28 @@ pub enum BucketResult {
}
impl BucketResult {
pub(crate) fn get_bucket_count(&self) -> u64 {
match self {
BucketResult::Range { buckets } => {
buckets.iter().map(|bucket| bucket.get_bucket_count()).sum()
}
BucketResult::Histogram { buckets } => {
buckets.iter().map(|bucket| bucket.get_bucket_count()).sum()
}
BucketResult::Terms {
buckets,
sum_other_doc_count: _,
doc_count_error_upper_bound: _,
} => buckets.iter().map(|bucket| bucket.get_bucket_count()).sum(),
}
}
pub(crate) fn empty_from_req(
req: &BucketAggregationInternal,
schema: &Schema,
limits: &AggregationLimits,
) -> crate::Result<Self> {
let empty_bucket = IntermediateBucketResult::empty_from_req(&req.bucket_agg);
empty_bucket.into_final_bucket_result(req, schema)
empty_bucket.into_final_bucket_result(req, limits)
}
}
@@ -174,6 +204,15 @@ pub enum BucketEntries<T> {
HashMap(FxHashMap<String, T>),
}
impl<T> BucketEntries<T> {
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = &T> + 'a> {
match self {
BucketEntries::Vec(vec) => Box::new(vec.iter()),
BucketEntries::HashMap(map) => Box::new(map.values()),
}
}
}
/// This is the default entry for a bucket, which contains a key, count, and optionally
/// sub-aggregations.
///
@@ -213,6 +252,11 @@ pub struct BucketEntry {
/// Sub-aggregations in this bucket.
pub sub_aggregation: AggregationResults,
}
impl BucketEntry {
pub(crate) fn get_bucket_count(&self) -> u64 {
1 + self.sub_aggregation.get_bucket_count()
}
}
impl GetDocCount for &BucketEntry {
fn doc_count(&self) -> u64 {
self.doc_count
@@ -276,3 +320,8 @@ pub struct RangeBucketEntry {
#[serde(skip_serializing_if = "Option::is_none")]
pub to_as_string: Option<String>,
}
impl RangeBucketEntry {
pub(crate) fn get_bucket_count(&self) -> u64 {
1 + self.sub_aggregation.get_bucket_count()
}
}

View File

@@ -0,0 +1,866 @@
use serde_json::Value;
use crate::aggregation::agg_req::{
Aggregation, Aggregations, BucketAggregation, BucketAggregationType, MetricAggregation,
};
use crate::aggregation::agg_result::AggregationResults;
use crate::aggregation::bucket::{RangeAggregation, TermsAggregation};
use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
use crate::aggregation::collector::AggregationCollector;
use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::metric::AverageAggregation;
use crate::aggregation::segment_agg_result::AggregationLimits;
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
use crate::aggregation::DistributedAggregationCollector;
use crate::query::{AllQuery, TermQuery};
use crate::schema::{IndexRecordOption, Schema, FAST};
use crate::{Index, Term};
fn get_avg_req(field_name: &str) -> Aggregation {
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name(field_name.to_string()),
))
}
fn get_collector(agg_req: Aggregations) -> AggregationCollector {
AggregationCollector::from_aggs(agg_req, Default::default())
}
// *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE ***
fn test_aggregation_flushing(
merge_segments: bool,
use_distributed_collector: bool,
) -> crate::Result<()> {
let mut values_and_terms = (0..80)
.map(|val| vec![(val as f64, "terma".to_string())])
.collect::<Vec<_>>();
values_and_terms.last_mut().unwrap()[0].1 = "termb".to_string();
let index = get_test_index_from_values_and_terms(merge_segments, &values_and_terms)?;
let reader = index.reader()?;
assert_eq!(DOC_BLOCK_SIZE, 64);
// In the tree we cache Documents of DOC_BLOCK_SIZE, before passing them down as one block.
//
// Build a request so that on the first level we have one full cache, which is then flushed.
// The same cache should have some residue docs at the end, which are flushed (Range 0-70)
// -> 70 docs
//
// The second level should also have some residue docs in the cache that are flushed at the
// end.
//
// A second bucket on the first level should have the cache unfilled
// let elasticsearch_compatible_json_req = r#"
let elasticsearch_compatible_json = json!(
{
"bucketsL1": {
"range": {
"field": "score",
"ranges": [ { "to": 3.0f64 }, { "from": 3.0f64, "to": 70.0f64 }, { "from": 70.0f64 } ]
},
"aggs": {
"bucketsL2": {
"range": {
"field": "score",
"ranges": [ { "to": 30.0f64 }, { "from": 30.0f64, "to": 70.0f64 }, { "from": 70.0f64 } ]
}
}
}
},
"histogram_test":{
"histogram": {
"field": "score",
"interval": 70.0,
"offset": 3.0
},
"aggs": {
"bucketsL2": {
"histogram": {
"field": "score",
"interval": 70.0
}
}
}
},
"term_agg_test":{
"terms": {
"field": "string_id"
},
"aggs": {
"bucketsL2": {
"histogram": {
"field": "score",
"interval": 70.0
}
}
}
}
});
let agg_req: Aggregations =
serde_json::from_str(&serde_json::to_string(&elasticsearch_compatible_json).unwrap())
.unwrap();
let agg_res: AggregationResults = if use_distributed_collector {
let collector = DistributedAggregationCollector::from_aggs(
agg_req.clone(),
AggregationLimits::default(),
);
let searcher = reader.searcher();
let intermediate_agg_result = searcher.search(&AllQuery, &collector).unwrap();
intermediate_agg_result
.into_final_bucket_result(agg_req, &Default::default())
.unwrap()
} else {
let collector = get_collector(agg_req);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
};
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
assert_eq!(res["bucketsL1"]["buckets"][0]["doc_count"], 3);
assert_eq!(
res["bucketsL1"]["buckets"][0]["bucketsL2"]["buckets"][0]["doc_count"],
3
);
assert_eq!(res["bucketsL1"]["buckets"][1]["key"], "3-70");
assert_eq!(res["bucketsL1"]["buckets"][1]["doc_count"], 70 - 3);
assert_eq!(
res["bucketsL1"]["buckets"][1]["bucketsL2"]["buckets"][0]["doc_count"],
27
);
assert_eq!(
res["bucketsL1"]["buckets"][1]["bucketsL2"]["buckets"][1]["doc_count"],
40
);
assert_eq!(
res["bucketsL1"]["buckets"][1]["bucketsL2"]["buckets"][2]["doc_count"],
0
);
assert_eq!(
res["bucketsL1"]["buckets"][2]["bucketsL2"]["buckets"][2]["doc_count"],
80 - 70
);
assert_eq!(res["bucketsL1"]["buckets"][2]["doc_count"], 80 - 70);
assert_eq!(
res["term_agg_test"],
json!(
{
"buckets": [
{
"bucketsL2": {
"buckets": [
{
"doc_count": 70,
"key": 0.0
},
{
"doc_count": 9,
"key": 70.0
}
]
},
"doc_count": 79,
"key": "terma"
},
{
"bucketsL2": {
"buckets": [
{
"doc_count": 1,
"key": 70.0
}
]
},
"doc_count": 1,
"key": "termb"
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
}
)
);
Ok(())
}
#[test]
fn test_aggregation_flushing_variants() {
test_aggregation_flushing(false, false).unwrap();
test_aggregation_flushing(false, true).unwrap();
test_aggregation_flushing(true, false).unwrap();
test_aggregation_flushing(true, true).unwrap();
}
#[test]
fn test_aggregation_level1() -> crate::Result<()> {
let index = get_test_index_2_segments(true)?;
let reader = index.reader()?;
let text_field = reader.searcher().schema().get_field("text").unwrap();
let term_query = TermQuery::new(
Term::from_field_text(text_field, "cool"),
IndexRecordOption::Basic,
);
let agg_req_1: Aggregations = vec![
("average_i64".to_string(), get_avg_req("score_i64")),
("average_f64".to_string(), get_avg_req("score_f64")),
("average".to_string(), get_avg_req("score")),
(
"range".to_string(),
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score".to_string(),
ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
..Default::default()
}),
sub_aggregation: Default::default(),
})),
),
(
"rangef64".to_string(),
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_f64".to_string(),
ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
..Default::default()
}),
sub_aggregation: Default::default(),
})),
),
(
"rangei64".to_string(),
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_i64".to_string(),
ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
..Default::default()
}),
sub_aggregation: Default::default(),
})),
),
]
.into_iter()
.collect();
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
assert_eq!(res["average"]["value"], 12.142857142857142);
assert_eq!(res["average_f64"]["value"], 12.214285714285714);
assert_eq!(res["average_i64"]["value"], 12.142857142857142);
assert_eq!(
res["range"]["buckets"],
json!(
[
{
"key": "*-3",
"doc_count": 1,
"to": 3.0
},
{
"key": "3-7",
"doc_count": 2,
"from": 3.0,
"to": 7.0
},
{
"key": "7-20",
"doc_count": 3,
"from": 7.0,
"to": 20.0
},
{
"key": "20-*",
"doc_count": 1,
"from": 20.0
}
])
);
Ok(())
}
fn test_aggregation_level2(
merge_segments: bool,
use_distributed_collector: bool,
use_elastic_json_req: bool,
) -> crate::Result<()> {
let index = get_test_index_2_segments(merge_segments)?;
let reader = index.reader()?;
let text_field = reader.searcher().schema().get_field("text").unwrap();
let term_query = TermQuery::new(
Term::from_field_text(text_field, "cool"),
IndexRecordOption::Basic,
);
let query_with_no_hits = TermQuery::new(
Term::from_field_text(text_field, "thistermdoesnotexist"),
IndexRecordOption::Basic,
);
let sub_agg_req: Aggregations = vec![
("average_in_range".to_string(), get_avg_req("score")),
(
"term_agg".to_string(),
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
})),
),
]
.into_iter()
.collect();
let agg_req: Aggregations = if use_elastic_json_req {
let elasticsearch_compatible_json_req = r#"
{
"rangef64": {
"range": {
"field": "score_f64",
"ranges": [
{ "to": 3.0 },
{ "from": 3.0, "to": 7.0 },
{ "from": 7.0, "to": 19.0 },
{ "from": 19.0, "to": 20.0 },
{ "from": 20.0 }
]
},
"aggs": {
"average_in_range": { "avg": { "field": "score" } },
"term_agg": { "terms": { "field": "text" } }
}
},
"rangei64": {
"range": {
"field": "score_i64",
"ranges": [
{ "to": 3.0 },
{ "from": 3.0, "to": 7.0 },
{ "from": 7.0, "to": 19.0 },
{ "from": 19.0, "to": 20.0 },
{ "from": 20.0 }
]
},
"aggs": {
"average_in_range": { "avg": { "field": "score" } },
"term_agg": { "terms": { "field": "text" } }
}
},
"average": {
"avg": { "field": "score" }
},
"range": {
"range": {
"field": "score",
"ranges": [
{ "to": 3.0 },
{ "from": 3.0, "to": 7.0 },
{ "from": 7.0, "to": 19.0 },
{ "from": 19.0, "to": 20.0 },
{ "from": 20.0 }
]
},
"aggs": {
"average_in_range": { "avg": { "field": "score" } },
"term_agg": { "terms": { "field": "text" } }
}
}
}
"#;
let value: Aggregations = serde_json::from_str(elasticsearch_compatible_json_req).unwrap();
value
} else {
let agg_req: Aggregations = vec![
("average".to_string(), get_avg_req("score")),
(
"range".to_string(),
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score".to_string(),
ranges: vec![
(3f64..7f64).into(),
(7f64..19f64).into(),
(19f64..20f64).into(),
],
..Default::default()
}),
sub_aggregation: sub_agg_req.clone(),
})),
),
(
"rangef64".to_string(),
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_f64".to_string(),
ranges: vec![
(3f64..7f64).into(),
(7f64..19f64).into(),
(19f64..20f64).into(),
],
..Default::default()
}),
sub_aggregation: sub_agg_req.clone(),
})),
),
(
"rangei64".to_string(),
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_i64".to_string(),
ranges: vec![
(3f64..7f64).into(),
(7f64..19f64).into(),
(19f64..20f64).into(),
],
..Default::default()
}),
sub_aggregation: sub_agg_req,
})),
),
]
.into_iter()
.collect();
agg_req
};
let agg_res: AggregationResults = if use_distributed_collector {
let collector =
DistributedAggregationCollector::from_aggs(agg_req.clone(), Default::default());
let searcher = reader.searcher();
let res = searcher.search(&term_query, &collector).unwrap();
// Test de/serialization roundtrip on intermediate_agg_result
let res: IntermediateAggregationResults =
serde_json::from_str(&serde_json::to_string(&res).unwrap()).unwrap();
res.into_final_bucket_result(agg_req.clone(), &Default::default())
.unwrap()
} else {
let collector = get_collector(agg_req.clone());
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()
};
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
assert_eq!(res["range"]["buckets"][1]["key"], "3-7");
assert_eq!(res["range"]["buckets"][1]["doc_count"], 2u64);
assert_eq!(res["rangef64"]["buckets"][1]["doc_count"], 2u64);
assert_eq!(res["rangei64"]["buckets"][1]["doc_count"], 2u64);
assert_eq!(res["average"]["value"], 12.142857142857142f64);
assert_eq!(res["range"]["buckets"][2]["key"], "7-19");
assert_eq!(res["range"]["buckets"][2]["doc_count"], 3u64);
assert_eq!(res["rangef64"]["buckets"][2]["doc_count"], 3u64);
assert_eq!(res["rangei64"]["buckets"][2]["doc_count"], 3u64);
assert_eq!(res["rangei64"]["buckets"][5], serde_json::Value::Null);
assert_eq!(res["range"]["buckets"][4]["key"], "20-*");
assert_eq!(res["range"]["buckets"][4]["doc_count"], 1u64);
assert_eq!(res["rangef64"]["buckets"][4]["doc_count"], 1u64);
assert_eq!(res["rangei64"]["buckets"][4]["doc_count"], 1u64);
assert_eq!(res["range"]["buckets"][3]["key"], "19-20");
assert_eq!(res["range"]["buckets"][3]["doc_count"], 0u64);
assert_eq!(res["rangef64"]["buckets"][3]["doc_count"], 0u64);
assert_eq!(res["rangei64"]["buckets"][3]["doc_count"], 0u64);
assert_eq!(
res["range"]["buckets"][3]["average_in_range"]["value"],
serde_json::Value::Null
);
assert_eq!(
res["range"]["buckets"][4]["average_in_range"]["value"],
44.0f64
);
assert_eq!(
res["rangef64"]["buckets"][4]["average_in_range"]["value"],
44.0f64
);
assert_eq!(
res["rangei64"]["buckets"][4]["average_in_range"]["value"],
44.0f64
);
assert_eq!(
res["range"]["7-19"]["average_in_range"]["value"],
res["rangef64"]["7-19"]["average_in_range"]["value"]
);
assert_eq!(
res["range"]["7-19"]["average_in_range"]["value"],
res["rangei64"]["7-19"]["average_in_range"]["value"]
);
// Test empty result set
let collector = get_collector(agg_req);
let searcher = reader.searcher();
searcher.search(&query_with_no_hits, &collector).unwrap();
Ok(())
}
#[test]
fn test_aggregation_level2_multi_segments() -> crate::Result<()> {
test_aggregation_level2(false, false, false)
}
#[test]
fn test_aggregation_level2_single_segment() -> crate::Result<()> {
test_aggregation_level2(true, false, false)
}
#[test]
fn test_aggregation_level2_multi_segments_distributed_collector() -> crate::Result<()> {
test_aggregation_level2(false, true, false)
}
#[test]
fn test_aggregation_level2_single_segment_distributed_collector() -> crate::Result<()> {
test_aggregation_level2(true, true, false)
}
#[test]
fn test_aggregation_level2_multi_segments_use_json() -> crate::Result<()> {
test_aggregation_level2(false, false, true)
}
#[test]
fn test_aggregation_level2_single_segment_use_json() -> crate::Result<()> {
test_aggregation_level2(true, false, true)
}
#[test]
fn test_aggregation_level2_multi_segments_distributed_collector_use_json() -> crate::Result<()> {
test_aggregation_level2(false, true, true)
}
#[test]
fn test_aggregation_level2_single_segment_distributed_collector_use_json() -> crate::Result<()> {
test_aggregation_level2(true, true, true)
}
#[test]
fn test_aggregation_invalid_requests() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
let reader = index.reader()?;
let avg_on_field = |field_name: &str| {
let agg_req_1: Aggregations = vec![(
"average".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name(field_name.to_string()),
)),
)]
.into_iter()
.collect();
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector)
};
let agg_res = avg_on_field("dummy_text").unwrap_err();
assert_eq!(
format!("{:?}", agg_res),
r#"InvalidArgument("Field \"dummy_text\" is not configured as fast field")"#
);
// TODO: This should return an error
// let agg_res = avg_on_field("not_exist_field").unwrap_err();
// assert_eq!(
// format!("{:?}", agg_res),
// r#"InvalidArgument("No fast field found for field: not_exist_field")"#
//);
// TODO: This should return an error
// let agg_res = avg_on_field("ip_addr").unwrap_err();
// assert_eq!(
// format!("{:?}", agg_res),
// r#"InvalidArgument("No fast field found for field: ip_addr")"#
//);
Ok(())
}
#[test]
fn test_aggregation_on_json_object() {
let mut schema_builder = Schema::builder();
let json = schema_builder.add_json_field("json", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(doc!(json => json!({"color": "red"})))
.unwrap();
index_writer
.add_document(doc!(json => json!({"color": "blue"})))
.unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let agg: Aggregations = vec![(
"jsonagg".to_string(),
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "json.color".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
})),
)]
.into_iter()
.collect();
let aggregation_collector = get_collector(agg);
let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap();
let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap();
assert_eq!(
&aggregation_res_json,
&serde_json::json!({
"jsonagg": {
"buckets": [
{"doc_count": 1, "key": "blue"},
{"doc_count": 1, "key": "red"}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
}
})
);
}
#[test]
fn test_aggregation_on_json_object_empty_columns() {
let mut schema_builder = Schema::builder();
let json = schema_builder.add_json_field("json", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
// => Empty column when accessing color
index_writer
.add_document(doc!(json => json!({"price": 10.0})))
.unwrap();
index_writer.commit().unwrap();
// => Empty column when accessing price
index_writer
.add_document(doc!(json => json!({"color": "blue"})))
.unwrap();
index_writer.commit().unwrap();
// => Non Empty columns
index_writer
.add_document(doc!(json => json!({"color": "red", "price": 10.0})))
.unwrap();
index_writer
.add_document(doc!(json => json!({"color": "red", "price": 10.0})))
.unwrap();
index_writer
.add_document(doc!(json => json!({"color": "green", "price": 20.0})))
.unwrap();
index_writer
.add_document(doc!(json => json!({"color": "green", "price": 20.0})))
.unwrap();
index_writer
.add_document(doc!(json => json!({"color": "green", "price": 20.0})))
.unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let agg: Aggregations = vec![(
"jsonagg".to_string(),
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "json.color".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
})),
)]
.into_iter()
.collect();
let aggregation_collector = get_collector(agg);
let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap();
let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap();
assert_eq!(
&aggregation_res_json,
&serde_json::json!({
"jsonagg": {
"buckets": [
{"doc_count": 3, "key": "green"},
{"doc_count": 2, "key": "red"},
{"doc_count": 1, "key": "blue"}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
}
})
);
let agg_req_str = r#"
{
"jsonagg": {
"aggs": {
"min_price": { "min": { "field": "json.price" } }
},
"terms": {
"field": "json.color",
"order": { "min_price": "desc" }
}
}
} "#;
let agg: Aggregations = serde_json::from_str(agg_req_str).unwrap();
let aggregation_collector = get_collector(agg);
let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap();
let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap();
assert_eq!(
&aggregation_res_json,
&serde_json::json!(
{
"jsonagg": {
"buckets": [
{
"key": "green",
"doc_count": 3,
"min_price": {
"value": 20.0
}
},
{
"key": "red",
"doc_count": 2,
"min_price": {
"value": 10.0
}
},
{
"key": "blue",
"doc_count": 1,
"min_price": {
"value": null
}
}
],
"sum_other_doc_count": 0
}
}
)
);
}
#[test]
fn test_aggregation_on_json_object_mixed_types() {
let mut schema_builder = Schema::builder();
let json = schema_builder.add_json_field("json", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
// => Segment with all values numeric
index_writer
.add_document(doc!(json => json!({"mixed_type": 10.0})))
.unwrap();
index_writer.commit().unwrap();
// => Segment with all values text
index_writer
.add_document(doc!(json => json!({"mixed_type": "blue"})))
.unwrap();
index_writer.commit().unwrap();
// => Segment with all boolen
index_writer
.add_document(doc!(json => json!({"mixed_type": true})))
.unwrap();
index_writer.commit().unwrap();
// => Segment with mixed values
index_writer
.add_document(doc!(json => json!({"mixed_type": "red"})))
.unwrap();
index_writer
.add_document(doc!(json => json!({"mixed_type": -20.5})))
.unwrap();
index_writer
.add_document(doc!(json => json!({"mixed_type": true})))
.unwrap();
index_writer.commit().unwrap();
// All bucket types
let agg_req_str = r#"
{
"termagg": {
"terms": {
"field": "json.mixed_type",
"order": { "min_price": "desc" }
},
"aggs": {
"min_price": { "min": { "field": "json.mixed_type" } }
}
},
"rangeagg": {
"range": {
"field": "json.mixed_type",
"ranges": [
{ "to": 3.0 },
{ "from": 19.0, "to": 20.0 },
{ "from": 20.0 }
]
},
"aggs": {
"average_in_range": { "avg": { "field": "json.mixed_type" } }
}
}
} "#;
let agg: Aggregations = serde_json::from_str(agg_req_str).unwrap();
let aggregation_collector = get_collector(agg);
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap();
let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap();
assert_eq!(
&aggregation_res_json,
&serde_json::json!({
"rangeagg": {
"buckets": [
{ "average_in_range": { "value": -20.5 }, "doc_count": 1, "key": "*-3", "to": 3.0 },
{ "average_in_range": { "value": 10.0 }, "doc_count": 1, "from": 3.0, "key": "3-19", "to": 19.0 },
{ "average_in_range": { "value": null }, "doc_count": 0, "from": 19.0, "key": "19-20", "to": 20.0 },
{ "average_in_range": { "value": null }, "doc_count": 0, "from": 20.0, "key": "20-*" }
]
},
"termagg": {
"buckets": [
{ "doc_count": 1, "key": 10.0, "min_price": { "value": 10.0 } },
{ "doc_count": 1, "key": -20.5, "min_price": { "value": -20.5 } },
// TODO red is missing since there is no multi aggregation within one
// segment for multiple types
// TODO bool is also not yet handled in aggregation
{ "doc_count": 1, "key": "blue", "min_price": { "value": null } }
],
"sum_other_doc_count": 0
}
}
)
);
}

View File

@@ -1,5 +1,8 @@
use serde::{Deserialize, Serialize};
use super::{HistogramAggregation, HistogramBounds};
use crate::aggregation::AggregationError;
/// DateHistogramAggregation is similar to `HistogramAggregation`, but it can only be used with date
/// type.
///
@@ -29,8 +32,16 @@ use serde::{Deserialize, Serialize};
/// See [`BucketEntry`](crate::aggregation::agg_result::BucketEntry)
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct DateHistogramAggregationReq {
#[doc(hidden)]
/// Only for validation
interval: Option<String>,
#[doc(hidden)]
/// Only for validation
date_interval: Option<String>,
/// The field to aggregate on.
pub field: String,
/// The format to format dates.
pub format: Option<String>,
/// The interval to chunk your data range. Each bucket spans a value range of
/// [0..fixed_interval). Accepted values
///
@@ -51,33 +62,143 @@ pub struct DateHistogramAggregationReq {
///
/// Fractional time values are not supported, but you can address this by shifting to another
/// time unit (e.g., `1.5h` could instead be specified as `90m`).
pub fixed_interval: String,
///
/// `Option` for validation, the parameter is not optional
pub fixed_interval: Option<String>,
/// Intervals implicitly defines an absolute grid of buckets `[interval * k, interval * (k +
/// 1))`.
pub offset: Option<String>,
/// The minimum number of documents in a bucket to be returned. Defaults to 0.
pub min_doc_count: Option<u64>,
/// Limits the data range to `[min, max]` closed interval.
///
/// This can be used to filter values if they are not in the data range.
///
/// hard_bounds only limits the buckets, to force a range set both extended_bounds and
/// hard_bounds to the same range.
///
/// Needs to be provided as timestamp in microseconds precision.
///
/// ## Example
/// ```json
/// {
/// "sales_over_time": {
/// "date_histogram": {
/// "field": "dates",
/// "interval": "1d",
/// "hard_bounds": {
/// "min": 0,
/// "max": 1420502400000000
/// }
/// }
/// }
/// }
/// ```
pub hard_bounds: Option<HistogramBounds>,
/// Can be set to extend your bounds. The range of the buckets is by default defined by the
/// data range of the values of the documents. As the name suggests, this can only be used to
/// extend the value range. If the bounds for min or max are not extending the range, the value
/// has no effect on the returned buckets.
///
/// Cannot be set in conjunction with min_doc_count > 0, since the empty buckets from extended
/// bounds would not be returned.
pub extended_bounds: Option<HistogramBounds>,
/// Whether to return the buckets as a hash map
#[serde(default)]
pub keyed: bool,
}
impl DateHistogramAggregationReq {
pub(crate) fn to_histogram_req(&self) -> crate::Result<HistogramAggregation> {
self.validate()?;
Ok(HistogramAggregation {
field: self.field.to_string(),
interval: parse_into_microseconds(self.fixed_interval.as_ref().unwrap())? as f64,
offset: self
.offset
.as_ref()
.map(|offset| parse_offset_into_microseconds(offset))
.transpose()?
.map(|el| el as f64),
min_doc_count: self.min_doc_count,
hard_bounds: None,
extended_bounds: None,
keyed: self.keyed,
})
}
fn validate(&self) -> crate::Result<()> {
if let Some(interval) = self.interval.as_ref() {
return Err(crate::TantivyError::InvalidArgument(format!(
"`interval` parameter {:?} in date histogram is unsupported, only \
`fixed_interval` is supported",
interval
)));
}
if let Some(interval) = self.date_interval.as_ref() {
return Err(crate::TantivyError::InvalidArgument(format!(
"`date_interval` parameter {:?} in date histogram is unsupported, only \
`fixed_interval` is supported",
interval
)));
}
if self.format.is_some() {
return Err(crate::TantivyError::InvalidArgument(
"format parameter on date_histogram is unsupported".to_string(),
));
}
if self.fixed_interval.is_none() {
return Err(crate::TantivyError::InvalidArgument(
"fixed_interval in date histogram is missing".to_string(),
));
}
parse_into_microseconds(self.fixed_interval.as_ref().unwrap())?;
Ok(())
}
}
#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq, Error)]
/// Errors when parsing the fixed interval for `DateHistogramAggregationReq`.
pub enum DateHistogramParseError {
/// Unit not recognized in passed String
#[error("Unit not recognized in passed String {0:?}")]
UnitNotRecognized(String),
/// Number not found in passed String
#[error("Number not found in passed String {0:?}")]
NumberMissing(String),
/// Unit not found in passed String
#[error("Unit not found in passed String {0:?}")]
UnitMissing(String),
/// Offset invalid
#[error("passed offset is invalid {0:?}")]
InvalidOffset(String),
}
fn parse_into_milliseconds(input: &str) -> Result<u64, DateHistogramParseError> {
fn parse_offset_into_microseconds(input: &str) -> Result<i64, AggregationError> {
let is_sign = |byte| &[byte] == b"-" || &[byte] == b"+";
if input.is_empty() {
return Err(DateHistogramParseError::InvalidOffset(input.to_string()).into());
}
let has_sign = is_sign(input.as_bytes()[0]);
if has_sign {
let (sign, input) = input.split_at(1);
let val = parse_into_microseconds(input)?;
if sign == "-" {
Ok(-val)
} else {
Ok(val)
}
} else {
parse_into_microseconds(input)
}
}
fn parse_into_microseconds(input: &str) -> Result<i64, AggregationError> {
let split_boundary = input
.as_bytes()
.iter()
@@ -85,12 +206,12 @@ fn parse_into_milliseconds(input: &str) -> Result<u64, DateHistogramParseError>
.count();
let (number, unit) = input.split_at(split_boundary);
if number.is_empty() {
return Err(DateHistogramParseError::NumberMissing(input.to_string()));
return Err(DateHistogramParseError::NumberMissing(input.to_string()).into());
}
if unit.is_empty() {
return Err(DateHistogramParseError::UnitMissing(input.to_string()));
return Err(DateHistogramParseError::UnitMissing(input.to_string()).into());
}
let number: u64 = number
let number: i64 = number
.parse()
// Technically this should never happen, but there was a bug
// here and being defensive does not hurt.
@@ -102,36 +223,288 @@ fn parse_into_milliseconds(input: &str) -> Result<u64, DateHistogramParseError>
"m" => 60 * 1000,
"h" => 60 * 60 * 1000,
"d" => 24 * 60 * 60 * 1000,
_ => return Err(DateHistogramParseError::UnitNotRecognized(unit.to_string())),
_ => return Err(DateHistogramParseError::UnitNotRecognized(unit.to_string()).into()),
};
Ok(number * multiplier_from_unit)
Ok(number * multiplier_from_unit * 1000)
}
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use super::*;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::tests::exec_request;
use crate::indexer::NoMergePolicy;
use crate::schema::{Schema, FAST};
use crate::Index;
#[test]
fn test_parse_into_milliseconds() {
assert_eq!(parse_into_milliseconds("1m").unwrap(), 60_000);
assert_eq!(parse_into_milliseconds("2m").unwrap(), 120_000);
fn test_parse_into_microseconds() {
assert_eq!(parse_into_microseconds("1m").unwrap(), 60_000_000);
assert_eq!(parse_into_microseconds("2m").unwrap(), 120_000_000);
assert_eq!(
parse_into_milliseconds("2y").unwrap_err(),
DateHistogramParseError::UnitNotRecognized("y".to_string())
parse_into_microseconds("2y").unwrap_err(),
DateHistogramParseError::UnitNotRecognized("y".to_string()).into()
);
assert_eq!(
parse_into_milliseconds("2000").unwrap_err(),
DateHistogramParseError::UnitMissing("2000".to_string())
parse_into_microseconds("2000").unwrap_err(),
DateHistogramParseError::UnitMissing("2000".to_string()).into()
);
assert_eq!(
parse_into_milliseconds("ms").unwrap_err(),
DateHistogramParseError::NumberMissing("ms".to_string())
parse_into_microseconds("ms").unwrap_err(),
DateHistogramParseError::NumberMissing("ms".to_string()).into()
);
}
#[test]
fn test_parse_offset_into_microseconds() {
assert_eq!(parse_offset_into_microseconds("1m").unwrap(), 60_000_000);
assert_eq!(parse_offset_into_microseconds("+1m").unwrap(), 60_000_000);
assert_eq!(parse_offset_into_microseconds("-1m").unwrap(), -60_000_000);
assert_eq!(parse_offset_into_microseconds("2m").unwrap(), 120_000_000);
assert_eq!(parse_offset_into_microseconds("+2m").unwrap(), 120_000_000);
assert_eq!(parse_offset_into_microseconds("-2m").unwrap(), -120_000_000);
assert_eq!(parse_offset_into_microseconds("-2ms").unwrap(), -2_000);
assert_eq!(
parse_offset_into_microseconds("2y").unwrap_err(),
DateHistogramParseError::UnitNotRecognized("y".to_string()).into()
);
assert_eq!(
parse_offset_into_microseconds("2000").unwrap_err(),
DateHistogramParseError::UnitMissing("2000".to_string()).into()
);
assert_eq!(
parse_offset_into_microseconds("ms").unwrap_err(),
DateHistogramParseError::NumberMissing("ms".to_string()).into()
);
}
#[test]
fn test_parse_into_milliseconds_do_not_accept_non_ascii() {
assert!(parse_into_milliseconds("m").is_err());
assert!(parse_into_microseconds("m").is_err());
}
pub fn get_test_index_from_docs(
merge_segments: bool,
segment_and_docs: &[Vec<&str>],
) -> crate::Result<Index> {
let mut schema_builder = Schema::builder();
schema_builder.add_date_field("date", FAST);
schema_builder.add_text_field("text", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
{
let mut index_writer = index.writer_with_num_threads(1, 30_000_000)?;
index_writer.set_merge_policy(Box::new(NoMergePolicy));
for values in segment_and_docs {
for doc_str in values {
let doc = schema.parse_document(doc_str)?;
index_writer.add_document(doc)?;
}
// writing the segment
index_writer.commit()?;
}
}
if merge_segments {
let segment_ids = index
.searchable_segment_ids()
.expect("Searchable segments failed.");
if segment_ids.len() > 1 {
let mut index_writer = index.writer_for_tests()?;
index_writer.merge(&segment_ids).wait()?;
index_writer.wait_merging_threads()?;
}
}
Ok(index)
}
#[test]
fn histogram_test_date_force_merge_segments() -> crate::Result<()> {
histogram_test_date_merge_segments(true)
}
#[test]
fn histogram_test_date() -> crate::Result<()> {
histogram_test_date_merge_segments(false)
}
fn histogram_test_date_merge_segments(merge_segments: bool) -> crate::Result<()> {
let docs = vec![
vec![r#"{ "date": "2015-01-01T12:10:30Z", "text": "aaa" }"#],
vec![r#"{ "date": "2015-01-01T11:11:30Z", "text": "bbb" }"#],
vec![r#"{ "date": "2015-01-02T00:00:00Z", "text": "bbb" }"#],
vec![r#"{ "date": "2015-01-06T00:00:00Z", "text": "ccc" }"#],
];
let index = get_test_index_from_docs(merge_segments, &docs)?;
// 30day + offset
let elasticsearch_compatible_json = json!(
{
"sales_over_time": {
"date_histogram": {
"field": "date",
"fixed_interval": "30d",
"offset": "-4d"
}
}
}
);
let agg_req: Aggregations =
serde_json::from_str(&serde_json::to_string(&elasticsearch_compatible_json).unwrap())
.unwrap();
let res = exec_request(agg_req, &index)?;
let expected_res = json!({
"sales_over_time" : {
"buckets" : [
{
"key_as_string" : "2015-01-01T00:00:00Z",
"key" : 1420070400000000.0,
"doc_count" : 4
}
]
}
});
assert_eq!(res, expected_res);
// 30day + offset + sub_agg
let elasticsearch_compatible_json = json!(
{
"sales_over_time": {
"date_histogram": {
"field": "date",
"fixed_interval": "30d",
"offset": "-4d"
},
"aggs": {
"texts": {
"terms": {"field": "text"}
}
}
}
}
);
let agg_req: Aggregations =
serde_json::from_str(&serde_json::to_string(&elasticsearch_compatible_json).unwrap())
.unwrap();
let res = exec_request(agg_req, &index)?;
println!("{}", serde_json::to_string_pretty(&res).unwrap());
let expected_res = json!({
"sales_over_time" : {
"buckets" : [
{
"key_as_string" : "2015-01-01T00:00:00Z",
"key" : 1420070400000000.0,
"doc_count" : 4,
"texts": {
"buckets": [
{
"doc_count": 2,
"key": "bbb"
},
{
"doc_count": 1,
"key": "ccc"
},
{
"doc_count": 1,
"key": "aaa"
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
}
}
]
}
});
assert_eq!(res, expected_res);
// 1day
let elasticsearch_compatible_json = json!(
{
"sales_over_time": {
"date_histogram": {
"field": "date",
"fixed_interval": "1d"
}
}
}
);
let agg_req: Aggregations =
serde_json::from_str(&serde_json::to_string(&elasticsearch_compatible_json).unwrap())
.unwrap();
let res = exec_request(agg_req, &index)?;
let expected_res = json!( {
"sales_over_time": {
"buckets": [
{
"doc_count": 2,
"key": 1420070400000000.0,
"key_as_string": "2015-01-01T00:00:00Z"
},
{
"doc_count": 1,
"key": 1420156800000000.0,
"key_as_string": "2015-01-02T00:00:00Z"
},
{
"doc_count": 0,
"key": 1420243200000000.0,
"key_as_string": "2015-01-03T00:00:00Z"
},
{
"doc_count": 0,
"key": 1420329600000000.0,
"key_as_string": "2015-01-04T00:00:00Z"
},
{
"doc_count": 0,
"key": 1420416000000000.0,
"key_as_string": "2015-01-05T00:00:00Z"
},
{
"doc_count": 1,
"key": 1420502400000000.0,
"key_as_string": "2015-01-06T00:00:00Z"
}
]
}
});
assert_eq!(res, expected_res);
Ok(())
}
#[test]
fn histogram_test_invalid_req() -> crate::Result<()> {
let docs = vec![];
let index = get_test_index_from_docs(false, &docs)?;
let elasticsearch_compatible_json = json!(
{
"sales_over_time": {
"date_histogram": {
"field": "date",
"interval": "30d",
"offset": "-4d"
}
}
}
);
let agg_req: Aggregations =
serde_json::from_str(&serde_json::to_string(&elasticsearch_compatible_json).unwrap())
.unwrap();
let err = exec_request(agg_req, &index).unwrap_err();
assert_eq!(
err.to_string(),
r#"An invalid argument was passed: '`interval` parameter "30d" in date histogram is unsupported, only `fixed_interval` is supported'"#
);
Ok(())
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
// mod date_histogram;
mod date_histogram;
mod histogram;
// pub use date_histogram::*;
pub use date_histogram::*;
pub use histogram::*;

View File

@@ -21,28 +21,25 @@ use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
pub use term_agg::*;
/// Order for buckets in a bucket aggregation.
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize, Default)]
pub enum Order {
/// Asc order
#[serde(rename = "asc")]
Asc,
/// Desc order
#[serde(rename = "desc")]
#[default]
Desc,
}
impl Default for Order {
fn default() -> Self {
Order::Desc
}
}
#[derive(Clone, Debug, PartialEq)]
/// Order property by which to apply the order
#[derive(Default)]
pub enum OrderTarget {
/// The key of the bucket
Key,
/// The doc count of the bucket
#[default]
Count,
/// Order by value of the sub aggregation metric with identified by given `String`.
///
@@ -50,11 +47,6 @@ pub enum OrderTarget {
SubAggregation(String),
}
impl Default for OrderTarget {
fn default() -> Self {
OrderTarget::Count
}
}
impl From<&str> for OrderTarget {
fn from(val: &str) -> Self {
match val {

View File

@@ -1,24 +1,22 @@
use std::fmt::Debug;
use std::ops::Range;
use columnar::MonotonicallyMappableToU64;
use columnar::{ColumnType, MonotonicallyMappableToU64};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use crate::aggregation::agg_req_with_accessor::{
AggregationsWithAccessor, BucketAggregationWithAccessor,
};
use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor;
use crate::aggregation::intermediate_agg_result::{
IntermediateBucketResult, IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
IntermediateAggregationResults, IntermediateBucketResult, IntermediateRangeBucketEntry,
IntermediateRangeBucketResult,
};
use crate::aggregation::segment_agg_result::{
BucketCount, GenericSegmentAggregationResultsCollector, SegmentAggregationCollector,
build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector,
};
use crate::aggregation::{
f64_from_fastfield_u64, f64_to_fastfield_u64, format_date, Key, SerializedKey,
f64_from_fastfield_u64, f64_to_fastfield_u64, format_date, Key, SerializedKey, VecWithNames,
};
use crate::schema::Type;
use crate::{DocId, TantivyError};
use crate::TantivyError;
/// Provide user-defined buckets to aggregate on.
/// Two special buckets will automatically be created to cover the whole range of values.
@@ -128,14 +126,15 @@ pub(crate) struct SegmentRangeAndBucketEntry {
pub struct SegmentRangeCollector {
/// The buckets containing the aggregation data.
buckets: Vec<SegmentRangeAndBucketEntry>,
field_type: Type,
column_type: ColumnType,
pub(crate) accessor_idx: usize,
}
#[derive(Clone)]
pub(crate) struct SegmentRangeBucketEntry {
pub key: Key,
pub doc_count: u64,
pub sub_aggregation: Option<GenericSegmentAggregationResultsCollector>,
pub sub_aggregation: Option<Box<dyn SegmentAggregationCollector>>,
/// The from range of the bucket. Equals `f64::MIN` when `None`.
pub from: Option<f64>,
/// The to range of the bucket. Equals `f64::MAX` when `None`. Open interval, `to` is not
@@ -174,12 +173,14 @@ impl SegmentRangeBucketEntry {
}
}
impl SegmentRangeCollector {
pub fn into_intermediate_bucket_result(
self,
agg_with_accessor: &BucketAggregationWithAccessor,
) -> crate::Result<IntermediateBucketResult> {
let field_type = self.field_type;
impl SegmentAggregationCollector for SegmentRangeCollector {
fn into_intermediate_aggregations_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateAggregationResults> {
let field_type = self.column_type;
let name = agg_with_accessor.buckets.keys[self.accessor_idx].to_string();
let sub_agg = &agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation;
let buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry> = self
.buckets
@@ -189,21 +190,80 @@ impl SegmentRangeCollector {
range_to_string(&range_bucket.range, &field_type)?,
range_bucket
.bucket
.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
.into_intermediate_bucket_entry(sub_agg)?,
))
})
.collect::<crate::Result<_>>()?;
Ok(IntermediateBucketResult::Range(
IntermediateRangeBucketResult { buckets },
))
let bucket = IntermediateBucketResult::Range(IntermediateRangeBucketResult {
buckets,
column_type: Some(self.column_type),
});
let buckets = Some(VecWithNames::from_entries(vec![(name, bucket)]));
Ok(IntermediateAggregationResults {
metrics: None,
buckets,
})
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
self.collect_block(&[doc], agg_with_accessor)
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let bucket_agg_accessor = &mut agg_with_accessor.buckets.values[self.accessor_idx];
bucket_agg_accessor
.column_block_accessor
.fetch_block(docs, &bucket_agg_accessor.accessor);
for (doc, val) in bucket_agg_accessor.column_block_accessor.iter_docid_vals() {
let bucket_pos = self.get_bucket_pos(val);
let bucket = &mut self.buckets[bucket_pos];
bucket.bucket.doc_count += 1;
if let Some(sub_aggregation) = &mut bucket.bucket.sub_aggregation {
sub_aggregation.collect(doc, &mut bucket_agg_accessor.sub_aggregation)?;
}
}
Ok(())
}
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
let sub_aggregation_accessor =
&mut agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation;
for bucket in self.buckets.iter_mut() {
if let Some(sub_agg) = bucket.bucket.sub_aggregation.as_mut() {
sub_agg.flush(sub_aggregation_accessor)?;
}
}
Ok(())
}
}
impl SegmentRangeCollector {
pub(crate) fn from_req_and_validate(
req: &RangeAggregation,
sub_aggregation: &AggregationsWithAccessor,
bucket_count: &BucketCount,
field_type: Type,
limits: &AggregationLimits,
field_type: ColumnType,
accessor_idx: usize,
) -> crate::Result<Self> {
// The range input on the request is f64.
// We need to convert to u64 ranges, because we read the values as u64.
@@ -229,11 +289,7 @@ impl SegmentRangeCollector {
let sub_aggregation = if sub_aggregation.is_empty() {
None
} else {
Some(
GenericSegmentAggregationResultsCollector::from_req_and_validate(
sub_aggregation,
)?,
)
Some(build_segment_agg_collector(sub_aggregation)?)
};
Ok(SegmentRangeAndBucketEntry {
@@ -249,57 +305,18 @@ impl SegmentRangeCollector {
})
.collect::<crate::Result<_>>()?;
bucket_count.add_count(buckets.len() as u32);
bucket_count.validate_bucket_count()?;
limits.add_memory_consumed(
buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
);
limits.validate_memory_consumption()?;
Ok(SegmentRangeCollector {
buckets,
field_type,
column_type: field_type,
accessor_idx,
})
}
#[inline]
pub(crate) fn collect_block(
&mut self,
docs: &[DocId],
bucket_with_accessor: &BucketAggregationWithAccessor,
force_flush: bool,
) -> crate::Result<()> {
let accessor = &bucket_with_accessor.accessor;
for doc in docs {
for val in accessor.values(*doc) {
let bucket_pos = self.get_bucket_pos(val);
self.increment_bucket(bucket_pos, *doc, &bucket_with_accessor.sub_aggregation)?;
}
}
if force_flush {
for bucket in &mut self.buckets {
if let Some(sub_aggregation) = &mut bucket.bucket.sub_aggregation {
sub_aggregation
.flush_staged_docs(&bucket_with_accessor.sub_aggregation, force_flush)?;
}
}
}
Ok(())
}
#[inline]
fn increment_bucket(
&mut self,
bucket_pos: usize,
doc: DocId,
bucket_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<()> {
let bucket = &mut self.buckets[bucket_pos];
bucket.bucket.doc_count += 1;
if let Some(sub_aggregation) = &mut bucket.bucket.sub_aggregation {
sub_aggregation.collect(doc, bucket_with_accessor)?;
}
Ok(())
}
#[inline]
fn get_bucket_pos(&self, val: u64) -> usize {
let pos = self
@@ -325,7 +342,7 @@ impl SegmentRangeCollector {
/// more computational expensive when many documents are hit.
fn to_u64_range(
range: &RangeAggregationRange,
field_type: &Type,
field_type: &ColumnType,
) -> crate::Result<InternalRangeAggregationRange> {
let start = if let Some(from) = range.from {
f64_to_fastfield_u64(from, field_type)
@@ -351,7 +368,7 @@ fn to_u64_range(
/// beginning and end and filling gaps.
fn extend_validate_ranges(
buckets: &[RangeAggregationRange],
field_type: &Type,
field_type: &ColumnType,
) -> crate::Result<Vec<InternalRangeAggregationRange>> {
let mut converted_buckets = buckets
.iter()
@@ -393,13 +410,16 @@ fn extend_validate_ranges(
Ok(converted_buckets)
}
pub(crate) fn range_to_string(range: &Range<u64>, field_type: &Type) -> crate::Result<String> {
pub(crate) fn range_to_string(
range: &Range<u64>,
field_type: &ColumnType,
) -> crate::Result<String> {
// is_start is there for malformed requests, e.g. ig the user passes the range u64::MIN..0.0,
// it should be rendered as "*-0" and not "*-*"
let to_str = |val: u64, is_start: bool| {
if (is_start && val == u64::MIN) || (!is_start && val == u64::MAX) {
Ok("*".to_string())
} else if *field_type == Type::Date {
} else if *field_type == ColumnType::DateTime {
let val = i64::from_u64(val);
format_date(val)
} else {
@@ -414,7 +434,7 @@ pub(crate) fn range_to_string(range: &Range<u64>, field_type: &Type) -> crate::R
))
}
pub(crate) fn range_to_key(range: &Range<u64>, field_type: &Type) -> crate::Result<Key> {
pub(crate) fn range_to_key(range: &Range<u64>, field_type: &ColumnType) -> crate::Result<Key> {
Ok(Key::Str(range_to_string(range, field_type)?))
}
@@ -426,8 +446,9 @@ mod tests {
use super::*;
use crate::aggregation::agg_req::{
Aggregation, Aggregations, BucketAggregation, BucketAggregationType,
Aggregation, Aggregations, BucketAggregation, BucketAggregationType, MetricAggregation,
};
use crate::aggregation::metric::AverageAggregation;
use crate::aggregation::tests::{
exec_request, exec_request_with_query, get_test_index_2_segments,
get_test_index_with_num_docs,
@@ -435,7 +456,7 @@ mod tests {
pub fn get_collector_from_ranges(
ranges: Vec<RangeAggregationRange>,
field_type: Type,
field_type: ColumnType,
) -> SegmentRangeCollector {
let req = RangeAggregation {
field: "dummy".to_string(),
@@ -448,6 +469,7 @@ mod tests {
&Default::default(),
&Default::default(),
field_type,
0,
)
.expect("unexpected error")
}
@@ -458,14 +480,61 @@ mod tests {
let agg_req: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let res = exec_request_with_query(agg_req, &index, None)?;
assert_eq!(res["range"]["buckets"][0]["key"], "*-0");
assert_eq!(res["range"]["buckets"][0]["doc_count"], 0);
assert_eq!(res["range"]["buckets"][1]["key"], "0-0.1");
assert_eq!(res["range"]["buckets"][1]["doc_count"], 10);
assert_eq!(res["range"]["buckets"][2]["key"], "0.1-0.2");
assert_eq!(res["range"]["buckets"][2]["doc_count"], 10);
assert_eq!(res["range"]["buckets"][3]["key"], "0.2-*");
assert_eq!(res["range"]["buckets"][3]["doc_count"], 80);
Ok(())
}
#[test]
fn range_fraction_test_with_sub_agg() -> crate::Result<()> {
let index = get_test_index_with_num_docs(false, 100)?;
let sub_agg_req: Aggregations = vec![(
"score_f64".to_string(),
Aggregation::Metric(MetricAggregation::Average(
AverageAggregation::from_field_name("score_f64".to_string()),
)),
)]
.into_iter()
.collect();
let agg_req: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
..Default::default()
}),
sub_aggregation: sub_agg_req,
}
.into(),
),
)]
.into_iter()
.collect();
@@ -490,14 +559,17 @@ mod tests {
let agg_req: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
keyed: true,
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
keyed: true,
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -527,25 +599,28 @@ mod tests {
let agg_req: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![
RangeAggregationRange {
key: Some("custom-key-0-to-0.1".to_string()),
from: Some(0f64),
to: Some(0.1f64),
},
RangeAggregationRange {
key: None,
from: Some(0.1f64),
to: Some(0.2f64),
},
],
keyed: false,
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![
RangeAggregationRange {
key: Some("custom-key-0-to-0.1".to_string()),
from: Some(0f64),
to: Some(0.1f64),
},
RangeAggregationRange {
key: None,
from: Some(0.1f64),
to: Some(0.2f64),
},
],
keyed: false,
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -584,25 +659,28 @@ mod tests {
let agg_req: Aggregations = vec![(
"date_ranges".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "date".to_string(),
ranges: vec![
RangeAggregationRange {
key: None,
from: None,
to: Some(1546300800000000.0f64),
},
RangeAggregationRange {
key: None,
from: Some(1546300800000000.0f64),
to: Some(1546387200000000.0f64),
},
],
keyed: false,
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "date".to_string(),
ranges: vec![
RangeAggregationRange {
key: None,
from: None,
to: Some(1546300800000000.0f64),
},
RangeAggregationRange {
key: None,
from: Some(1546300800000000.0f64),
to: Some(1546387200000000.0f64),
},
],
keyed: false,
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -646,18 +724,21 @@ mod tests {
let agg_req: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![RangeAggregationRange {
key: Some("custom-key-0-to-0.1".to_string()),
from: Some(0f64),
to: Some(0.1f64),
}],
keyed: true,
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![RangeAggregationRange {
key: Some("custom-key-0-to-0.1".to_string()),
from: Some(0f64),
to: Some(0.1f64),
}],
keyed: true,
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -683,7 +764,7 @@ mod tests {
#[test]
fn bucket_test_extend_range_hole() {
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
let collector = get_collector_from_ranges(buckets, Type::F64);
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
assert_eq!(buckets[0].range.start, u64::MIN);
@@ -706,7 +787,7 @@ mod tests {
(10f64..20f64).into(),
(20f64..f64::MAX).into(),
];
let collector = get_collector_from_ranges(buckets, Type::F64);
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
assert_eq!(buckets[0].range.start, u64::MIN);
@@ -721,7 +802,7 @@ mod tests {
#[test]
fn bucket_range_test_negative_vals() {
let buckets = vec![(-10f64..-1f64).into()];
let collector = get_collector_from_ranges(buckets, Type::F64);
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
assert_eq!(&buckets[0].bucket.key.to_string(), "*--10");
@@ -730,7 +811,7 @@ mod tests {
#[test]
fn bucket_range_test_positive_vals() {
let buckets = vec![(0f64..10f64).into()];
let collector = get_collector_from_ranges(buckets, Type::F64);
let collector = get_collector_from_ranges(buckets, ColumnType::F64);
let buckets = collector.buckets;
assert_eq!(&buckets[0].bucket.key.to_string(), "*-0");
@@ -740,7 +821,7 @@ mod tests {
#[test]
fn range_binary_search_test_u64() {
let check_ranges = |ranges: Vec<RangeAggregationRange>| {
let collector = get_collector_from_ranges(ranges, Type::U64);
let collector = get_collector_from_ranges(ranges, ColumnType::U64);
let search = |val: u64| collector.get_bucket_pos(val);
assert_eq!(search(u64::MIN), 0);
@@ -786,7 +867,7 @@ mod tests {
fn range_binary_search_test_f64() {
let ranges = vec![(10.0..100.0).into()];
let collector = get_collector_from_ranges(ranges, Type::F64);
let collector = get_collector_from_ranges(ranges, ColumnType::F64);
let search = |val: u64| collector.get_bucket_pos(val);
assert_eq!(search(u64::MIN), 0);
@@ -821,7 +902,7 @@ mod bench {
buckets.push((bucket_start..bucket_start + bucket_size as f64).into())
}
get_collector_from_ranges(buckets, Type::U64)
get_collector_from_ranges(buckets, ColumnType::U64)
}
fn get_rand_docs(total_docs: u64, num_docs_returned: u64) -> Vec<u64> {

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,82 @@
use super::agg_req_with_accessor::AggregationsWithAccessor;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::SegmentAggregationCollector;
use crate::DocId;
pub(crate) const DOC_BLOCK_SIZE: usize = 64;
pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
/// BufAggregationCollector buffers documents before calling collect_block().
#[derive(Clone)]
pub(crate) struct BufAggregationCollector {
pub(crate) collector: Box<dyn SegmentAggregationCollector>,
staged_docs: DocBlock,
num_staged_docs: usize,
}
impl std::fmt::Debug for BufAggregationCollector {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentAggregationResultsCollector")
.field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
.field("num_staged_docs", &self.num_staged_docs)
.finish()
}
}
impl BufAggregationCollector {
pub fn new(collector: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
collector,
num_staged_docs: 0,
staged_docs: [0; DOC_BLOCK_SIZE],
}
}
}
impl SegmentAggregationCollector for BufAggregationCollector {
#[inline]
fn into_intermediate_aggregations_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateAggregationResults> {
Box::new(self.collector).into_intermediate_aggregations_result(agg_with_accessor)
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
self.staged_docs[self.num_staged_docs] = doc;
self.num_staged_docs += 1;
if self.num_staged_docs == self.staged_docs.len() {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor)?;
self.num_staged_docs = 0;
}
Ok(())
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
self.collector.collect_block(docs, agg_with_accessor)?;
Ok(())
}
#[inline]
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor)?;
self.num_staged_docs = 0;
self.collector.flush(agg_with_accessor)?;
Ok(())
}
}

View File

@@ -1,38 +1,36 @@
use std::rc::Rc;
use super::agg_req::Aggregations;
use super::agg_req_with_accessor::AggregationsWithAccessor;
use super::agg_result::AggregationResults;
use super::buf_collector::BufAggregationCollector;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::{build_segment_agg_collector, SegmentAggregationCollector};
use super::segment_agg_result::{
build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector,
};
use crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_validate;
use crate::collector::{Collector, SegmentCollector};
use crate::schema::Schema;
use crate::{SegmentReader, TantivyError};
use crate::{DocId, SegmentReader, TantivyError};
/// The default max bucket count, before the aggregation fails.
pub const MAX_BUCKET_COUNT: u32 = 65000;
pub const DEFAULT_BUCKET_LIMIT: u32 = 65000;
/// The default memory limit in bytes before the aggregation fails. 500MB
pub const DEFAULT_MEMORY_LIMIT: u64 = 500_000_000;
/// Collector for aggregations.
///
/// The collector collects all aggregations by the underlying aggregation request.
pub struct AggregationCollector {
schema: Schema,
agg: Aggregations,
max_bucket_count: u32,
limits: AggregationLimits,
}
impl AggregationCollector {
/// Create collector from aggregation request.
///
/// Aggregation fails when the total bucket count is higher than max_bucket_count.
/// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset
pub fn from_aggs(agg: Aggregations, max_bucket_count: Option<u32>, schema: Schema) -> Self {
Self {
schema,
agg,
max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT),
}
/// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and
/// bucket limit)
pub fn from_aggs(agg: Aggregations, limits: AggregationLimits) -> Self {
Self { agg, limits }
}
}
@@ -46,18 +44,16 @@ impl AggregationCollector {
/// into the final `AggregationResults` via the `into_final_result()` method.
pub struct DistributedAggregationCollector {
agg: Aggregations,
max_bucket_count: u32,
limits: AggregationLimits,
}
impl DistributedAggregationCollector {
/// Create collector from aggregation request.
///
/// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset
pub fn from_aggs(agg: Aggregations, max_bucket_count: Option<u32>) -> Self {
Self {
agg,
max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT),
}
/// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and
/// bucket limit)
pub fn from_aggs(agg: Aggregations, limits: AggregationLimits) -> Self {
Self { agg, limits }
}
}
@@ -71,11 +67,7 @@ impl Collector for DistributedAggregationCollector {
_segment_local_id: crate::SegmentOrdinal,
reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
AggregationSegmentCollector::from_agg_req_and_reader(
&self.agg,
reader,
self.max_bucket_count,
)
AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader, &self.limits)
}
fn requires_scoring(&self) -> bool {
@@ -100,11 +92,7 @@ impl Collector for AggregationCollector {
_segment_local_id: crate::SegmentOrdinal,
reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
AggregationSegmentCollector::from_agg_req_and_reader(
&self.agg,
reader,
self.max_bucket_count,
)
AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader, &self.limits)
}
fn requires_scoring(&self) -> bool {
@@ -116,7 +104,7 @@ impl Collector for AggregationCollector {
segment_fruits: Vec<<Self::Child as SegmentCollector>::Fruit>,
) -> crate::Result<Self::Fruit> {
let res = merge_fruits(segment_fruits)?;
res.into_final_bucket_result(self.agg.clone(), &self.schema)
res.into_final_bucket_result(self.agg.clone(), &self.limits)
}
}
@@ -137,7 +125,7 @@ fn merge_fruits(
/// `AggregationSegmentCollector` does the aggregation collection on a segment.
pub struct AggregationSegmentCollector {
aggs_with_accessor: AggregationsWithAccessor,
result: Box<dyn SegmentAggregationCollector>,
agg_collector: BufAggregationCollector,
error: Option<TantivyError>,
}
@@ -147,14 +135,14 @@ impl AggregationSegmentCollector {
pub fn from_agg_req_and_reader(
agg: &Aggregations,
reader: &SegmentReader,
max_bucket_count: u32,
limits: &AggregationLimits,
) -> crate::Result<Self> {
let aggs_with_accessor =
get_aggs_with_accessor_and_validate(agg, reader, Rc::default(), max_bucket_count)?;
let result = build_segment_agg_collector(&aggs_with_accessor)?;
let aggs_with_accessor = get_aggs_with_accessor_and_validate(agg, reader, limits)?;
let result =
BufAggregationCollector::new(build_segment_agg_collector(&aggs_with_accessor)?);
Ok(AggregationSegmentCollector {
aggs_with_accessor,
result,
agg_collector: result,
error: None,
})
}
@@ -164,11 +152,29 @@ impl SegmentCollector for AggregationSegmentCollector {
type Fruit = crate::Result<IntermediateAggregationResults>;
#[inline]
fn collect(&mut self, doc: crate::DocId, _score: crate::Score) {
fn collect(&mut self, doc: DocId, _score: crate::Score) {
if self.error.is_some() {
return;
}
if let Err(err) = self.result.collect(doc, &self.aggs_with_accessor) {
if let Err(err) = self
.agg_collector
.collect(doc, &mut self.aggs_with_accessor)
{
self.error = Some(err);
}
}
/// The query pushes the documents to the collector via this method.
///
/// Only valid for Collectors that ignore docs
fn collect_block(&mut self, docs: &[DocId]) {
if self.error.is_some() {
return;
}
if let Err(err) = self
.agg_collector
.collect_block(docs, &mut self.aggs_with_accessor)
{
self.error = Some(err);
}
}
@@ -177,9 +183,7 @@ impl SegmentCollector for AggregationSegmentCollector {
if let Some(err) = self.error {
return Err(err);
}
self.result
.flush_staged_docs(&self.aggs_with_accessor, true)?;
self.result
.into_intermediate_aggregations_result(&self.aggs_with_accessor)
self.agg_collector.flush(&mut self.aggs_with_accessor)?;
Box::new(self.agg_collector).into_intermediate_aggregations_result(&self.aggs_with_accessor)
}
}

33
src/aggregation/error.rs Normal file
View File

@@ -0,0 +1,33 @@
use common::ByteCount;
use super::bucket::DateHistogramParseError;
/// Error that may occur when opening a directory
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum AggregationError {
/// Date histogram parse error
#[error("Date histogram parse error: {0:?}")]
DateHistogramParseError(#[from] DateHistogramParseError),
/// Memory limit exceeded
#[error(
"Aborting aggregation because memory limit was exceeded. Limit: {limit:?}, Current: \
{current:?}"
)]
MemoryExceeded {
/// Memory consumption limit
limit: ByteCount,
/// Current memory consumption
current: ByteCount,
},
/// Bucket limit exceeded
#[error(
"Aborting aggregation because bucket limit was exceeded. Limit: {limit:?}, Current: \
{current:?}"
)]
BucketLimitExceeded {
/// Bucket limit
limit: u32,
/// Current num buckets
current: u32,
},
}

View File

@@ -4,9 +4,11 @@
use std::cmp::Ordering;
use columnar::ColumnType;
use itertools::Itertools;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use serde::ser::SerializeSeq;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use super::agg_req::{
Aggregations, AggregationsInternal, BucketAggregationInternal, BucketAggregationType,
@@ -21,11 +23,11 @@ use super::metric::{
IntermediateAverage, IntermediateCount, IntermediateMax, IntermediateMin, IntermediateStats,
IntermediateSum,
};
use super::segment_agg_result::SegmentMetricResultCollector;
use super::{format_date, Key, SerializedKey, VecWithNames};
use super::segment_agg_result::AggregationLimits;
use super::{format_date, AggregationError, Key, SerializedKey, VecWithNames};
use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
use crate::aggregation::bucket::TermsAggregationInternal;
use crate::schema::Schema;
use crate::TantivyError;
/// Contains the intermediate aggregation result, which is optimized to be merged with other
/// intermediate results.
@@ -42,9 +44,19 @@ impl IntermediateAggregationResults {
pub fn into_final_bucket_result(
self,
req: Aggregations,
schema: &Schema,
limits: &AggregationLimits,
) -> crate::Result<AggregationResults> {
self.into_final_bucket_result_internal(&(req.into()), schema)
let res = self.into_final_bucket_result_internal(&(req.into()), limits)?;
let bucket_count = res.get_bucket_count() as u32;
if bucket_count > limits.get_bucket_limit() {
return Err(TantivyError::AggregationError(
AggregationError::BucketLimitExceeded {
limit: limits.get_bucket_limit(),
current: bucket_count,
},
));
}
Ok(res)
}
/// Convert intermediate result and its aggregation request to the final result.
@@ -54,7 +66,7 @@ impl IntermediateAggregationResults {
pub(crate) fn into_final_bucket_result_internal(
self,
req: &AggregationsInternal,
schema: &Schema,
limits: &AggregationLimits,
) -> crate::Result<AggregationResults> {
// Important assumption:
// When the tree contains buckets/metric, we expect it to have all buckets/metrics from the
@@ -62,11 +74,11 @@ impl IntermediateAggregationResults {
let mut results: FxHashMap<String, AggregationResult> = FxHashMap::default();
if let Some(buckets) = self.buckets {
convert_and_add_final_buckets_to_result(&mut results, buckets, &req.buckets, schema)?
convert_and_add_final_buckets_to_result(&mut results, buckets, &req.buckets, limits)?
} else {
// When there are no buckets, we create empty buckets, so that the serialized json
// format is constant
add_empty_final_buckets_to_result(&mut results, &req.buckets, schema)?
add_empty_final_buckets_to_result(&mut results, &req.buckets, limits)?
};
if let Some(metrics) = self.metrics {
@@ -167,12 +179,12 @@ fn add_empty_final_metrics_to_result(
fn add_empty_final_buckets_to_result(
results: &mut FxHashMap<String, AggregationResult>,
req_buckets: &VecWithNames<BucketAggregationInternal>,
schema: &Schema,
limits: &AggregationLimits,
) -> crate::Result<()> {
let requested_buckets = req_buckets.iter();
for (key, req) in requested_buckets {
let empty_bucket =
AggregationResult::BucketResult(BucketResult::empty_from_req(req, schema)?);
AggregationResult::BucketResult(BucketResult::empty_from_req(req, limits)?);
results.insert(key.to_string(), empty_bucket);
}
Ok(())
@@ -182,13 +194,13 @@ fn convert_and_add_final_buckets_to_result(
results: &mut FxHashMap<String, AggregationResult>,
buckets: VecWithNames<IntermediateBucketResult>,
req_buckets: &VecWithNames<BucketAggregationInternal>,
schema: &Schema,
limits: &AggregationLimits,
) -> crate::Result<()> {
assert_eq!(buckets.len(), req_buckets.len());
let buckets_with_request = buckets.into_iter().zip(req_buckets.values());
for ((key, bucket), req) in buckets_with_request {
let result = AggregationResult::BucketResult(bucket.into_final_bucket_result(req, schema)?);
let result = AggregationResult::BucketResult(bucket.into_final_bucket_result(req, limits)?);
results.insert(key, result);
}
Ok(())
@@ -220,32 +232,6 @@ pub enum IntermediateMetricResult {
Sum(IntermediateSum),
}
impl From<SegmentMetricResultCollector> for IntermediateMetricResult {
fn from(tree: SegmentMetricResultCollector) -> Self {
use super::metric::SegmentStatsType;
match tree {
SegmentMetricResultCollector::Stats(collector) => match collector.collecting_for {
SegmentStatsType::Average => IntermediateMetricResult::Average(
IntermediateAverage::from_collector(collector),
),
SegmentStatsType::Count => {
IntermediateMetricResult::Count(IntermediateCount::from_collector(collector))
}
SegmentStatsType::Max => {
IntermediateMetricResult::Max(IntermediateMax::from_collector(collector))
}
SegmentStatsType::Min => {
IntermediateMetricResult::Min(IntermediateMin::from_collector(collector))
}
SegmentStatsType::Stats => IntermediateMetricResult::Stats(collector.stats),
SegmentStatsType::Sum => {
IntermediateMetricResult::Sum(IntermediateSum::from_collector(collector))
}
},
}
}
}
impl IntermediateMetricResult {
pub(crate) fn empty_from_req(req: &MetricAggregation) -> Self {
match req {
@@ -309,6 +295,8 @@ pub enum IntermediateBucketResult {
/// This is the histogram entry for a bucket, which contains a key, count, and optionally
/// sub_aggregations.
Histogram {
/// The column_type of the underlying `Column`
column_type: Option<ColumnType>,
/// The buckets
buckets: Vec<IntermediateHistogramBucketEntry>,
},
@@ -320,7 +308,7 @@ impl IntermediateBucketResult {
pub(crate) fn into_final_bucket_result(
self,
req: &BucketAggregationInternal,
schema: &Schema,
limits: &AggregationLimits,
) -> crate::Result<BucketResult> {
match self {
IntermediateBucketResult::Range(range_res) => {
@@ -330,9 +318,10 @@ impl IntermediateBucketResult {
.map(|bucket| {
bucket.into_final_bucket_entry(
&req.sub_aggregation,
schema,
req.as_range()
.expect("unexpected aggregation, expected histogram aggregation"),
range_res.column_type,
limits,
)
})
.collect::<crate::Result<Vec<_>>>()?;
@@ -359,16 +348,22 @@ impl IntermediateBucketResult {
};
Ok(BucketResult::Range { buckets })
}
IntermediateBucketResult::Histogram { buckets } => {
IntermediateBucketResult::Histogram {
column_type,
buckets,
} => {
let histogram_req = &req
.as_histogram()?
.expect("unexpected aggregation, expected histogram aggregation");
let buckets = intermediate_histogram_buckets_to_final_buckets(
buckets,
req.as_histogram()
.expect("unexpected aggregation, expected histogram aggregation"),
column_type,
histogram_req,
&req.sub_aggregation,
schema,
limits,
)?;
let buckets = if req.as_histogram().unwrap().keyed {
let buckets = if histogram_req.keyed {
let mut bucket_map =
FxHashMap::with_capacity_and_hasher(buckets.len(), Default::default());
for bucket in buckets {
@@ -384,7 +379,7 @@ impl IntermediateBucketResult {
req.as_term()
.expect("unexpected aggregation, expected term aggregation"),
&req.sub_aggregation,
schema,
limits,
),
}
}
@@ -393,8 +388,11 @@ impl IntermediateBucketResult {
match req {
BucketAggregationType::Terms(_) => IntermediateBucketResult::Terms(Default::default()),
BucketAggregationType::Range(_) => IntermediateBucketResult::Range(Default::default()),
BucketAggregationType::Histogram(_) => {
IntermediateBucketResult::Histogram { buckets: vec![] }
BucketAggregationType::Histogram(_) | BucketAggregationType::DateHistogram(_) => {
IntermediateBucketResult::Histogram {
buckets: vec![],
column_type: None,
}
}
}
}
@@ -404,7 +402,7 @@ impl IntermediateBucketResult {
IntermediateBucketResult::Terms(term_res_left),
IntermediateBucketResult::Terms(term_res_right),
) => {
merge_maps(&mut term_res_left.entries, term_res_right.entries);
merge_key_maps(&mut term_res_left.entries, term_res_right.entries);
term_res_left.sum_other_doc_count += term_res_right.sum_other_doc_count;
term_res_left.doc_count_error_upper_bound +=
term_res_right.doc_count_error_upper_bound;
@@ -414,7 +412,7 @@ impl IntermediateBucketResult {
IntermediateBucketResult::Range(range_res_left),
IntermediateBucketResult::Range(range_res_right),
) => {
merge_maps(&mut range_res_left.buckets, range_res_right.buckets);
merge_serialized_key_maps(&mut range_res_left.buckets, range_res_right.buckets);
}
(
IntermediateBucketResult::Histogram {
@@ -460,22 +458,51 @@ impl IntermediateBucketResult {
/// Range aggregation including error counts
pub struct IntermediateRangeBucketResult {
pub(crate) buckets: FxHashMap<SerializedKey, IntermediateRangeBucketEntry>,
pub(crate) column_type: Option<ColumnType>,
}
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
/// Term aggregation including error counts
pub struct IntermediateTermBucketResult {
pub(crate) entries: FxHashMap<String, IntermediateTermBucketEntry>,
#[serde(
serialize_with = "serialize_entries",
deserialize_with = "deserialize_entries"
)]
pub(crate) entries: FxHashMap<Key, IntermediateTermBucketEntry>,
pub(crate) sum_other_doc_count: u64,
pub(crate) doc_count_error_upper_bound: u64,
}
// Serialize into a Vec to circument the JSON limitation, where keys can't be numbers
fn serialize_entries<S>(
entries: &FxHashMap<Key, IntermediateTermBucketEntry>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut seq = serializer.serialize_seq(Some(entries.len()))?;
for (k, v) in entries {
seq.serialize_element(&(k, v))?;
}
seq.end()
}
fn deserialize_entries<'de, D>(
deserializer: D,
) -> Result<FxHashMap<Key, IntermediateTermBucketEntry>, D::Error>
where D: Deserializer<'de> {
let vec_entries: Vec<(Key, IntermediateTermBucketEntry)> =
Deserialize::deserialize(deserializer)?;
Ok(vec_entries.into_iter().collect())
}
impl IntermediateTermBucketResult {
pub(crate) fn into_final_result(
self,
req: &TermsAggregation,
sub_aggregation_req: &AggregationsInternal,
schema: &Schema,
limits: &AggregationLimits,
) -> crate::Result<BucketResult> {
let req = TermsAggregationInternal::from_req(req);
let mut buckets: Vec<BucketEntry> = self
@@ -485,11 +512,11 @@ impl IntermediateTermBucketResult {
.map(|(key, entry)| {
Ok(BucketEntry {
key_as_string: None,
key: Key::Str(key),
key,
doc_count: entry.doc_count,
sub_aggregation: entry
.sub_aggregation
.into_final_bucket_result_internal(sub_aggregation_req, schema)?,
.into_final_bucket_result_internal(sub_aggregation_req, limits)?,
})
})
.collect::<crate::Result<_>>()?;
@@ -521,7 +548,7 @@ impl IntermediateTermBucketResult {
let val = bucket
.sub_aggregation
.get_value_from_aggregation(agg_name, agg_property)?
.unwrap_or(f64::NAN);
.unwrap_or(f64::MIN);
Ok((bucket, val))
})
.collect::<crate::Result<Vec<_>>>()?;
@@ -563,7 +590,7 @@ trait MergeFruits {
fn merge_fruits(&mut self, other: Self);
}
fn merge_maps<V: MergeFruits + Clone>(
fn merge_serialized_key_maps<V: MergeFruits + Clone>(
entries_left: &mut FxHashMap<SerializedKey, V>,
mut entries_right: FxHashMap<SerializedKey, V>,
) {
@@ -578,6 +605,21 @@ fn merge_maps<V: MergeFruits + Clone>(
}
}
fn merge_key_maps<V: MergeFruits + Clone>(
entries_left: &mut FxHashMap<Key, V>,
mut entries_right: FxHashMap<Key, V>,
) {
for (name, entry_left) in entries_left.iter_mut() {
if let Some(entry_right) = entries_right.remove(name) {
entry_left.merge_fruits(entry_right);
}
}
for (key, res) in entries_right.into_iter() {
entries_left.entry(key).or_insert(res);
}
}
/// This is the histogram entry for a bucket, which contains a key, count, and optionally
/// sub_aggregations.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
@@ -594,7 +636,7 @@ impl IntermediateHistogramBucketEntry {
pub(crate) fn into_final_bucket_entry(
self,
req: &AggregationsInternal,
schema: &Schema,
limits: &AggregationLimits,
) -> crate::Result<BucketEntry> {
Ok(BucketEntry {
key_as_string: None,
@@ -602,7 +644,7 @@ impl IntermediateHistogramBucketEntry {
doc_count: self.doc_count,
sub_aggregation: self
.sub_aggregation
.into_final_bucket_result_internal(req, schema)?,
.into_final_bucket_result_internal(req, limits)?,
})
}
}
@@ -639,15 +681,16 @@ impl IntermediateRangeBucketEntry {
pub(crate) fn into_final_bucket_entry(
self,
req: &AggregationsInternal,
schema: &Schema,
range_req: &RangeAggregation,
_range_req: &RangeAggregation,
column_type: Option<ColumnType>,
limits: &AggregationLimits,
) -> crate::Result<RangeBucketEntry> {
let mut range_bucket_entry = RangeBucketEntry {
key: self.key,
doc_count: self.doc_count,
sub_aggregation: self
.sub_aggregation
.into_final_bucket_result_internal(req, schema)?,
.into_final_bucket_result_internal(req, limits)?,
to: self.to,
from: self.from,
to_as_string: None,
@@ -656,8 +699,7 @@ impl IntermediateRangeBucketEntry {
// If we have a date type on the histogram buckets, we add the `key_as_string` field as
// rfc339
let field = schema.get_field(&range_req.field)?;
if schema.get_field_entry(field).field_type().is_date() {
if column_type == Some(ColumnType::DateTime) {
if let Some(val) = range_bucket_entry.to {
let key_as_string = format_date(val as i64)?;
range_bucket_entry.to_as_string = Some(key_as_string);
@@ -728,7 +770,10 @@ mod tests {
}
map.insert(
"my_agg_level2".to_string(),
IntermediateBucketResult::Range(IntermediateRangeBucketResult { buckets }),
IntermediateBucketResult::Range(IntermediateRangeBucketResult {
buckets,
column_type: None,
}),
);
IntermediateAggregationResults {
buckets: Some(VecWithNames::from_entries(map.into_iter().collect())),
@@ -758,7 +803,10 @@ mod tests {
}
map.insert(
"my_agg_level1".to_string(),
IntermediateBucketResult::Range(IntermediateRangeBucketResult { buckets }),
IntermediateBucketResult::Range(IntermediateRangeBucketResult {
buckets,
column_type: None,
}),
);
IntermediateAggregationResults {
buckets: Some(VecWithNames::from_entries(map.into_iter().collect())),
@@ -822,4 +870,26 @@ mod tests {
assert_eq!(tree_left, orig);
}
#[test]
fn test_term_bucket_json_roundtrip() {
let term_buckets = IntermediateTermBucketResult {
entries: vec![(
Key::F64(5.0),
IntermediateTermBucketEntry {
doc_count: 10,
sub_aggregation: Default::default(),
},
)]
.into_iter()
.collect(),
sum_other_doc_count: 0,
doc_count_error_upper_bound: 0,
};
let term_buckets_round: IntermediateTermBucketResult =
serde_json::from_str(&serde_json::to_string(&term_buckets).unwrap()).unwrap();
assert_eq!(term_buckets, term_buckets_round);
}
}

View File

@@ -81,7 +81,7 @@ mod tests {
"price_sum": { "sum": { "field": "price" } }
}"#;
let aggregations: Aggregations = serde_json::from_str(aggregations_json).unwrap();
let collector = AggregationCollector::from_aggs(aggregations, None, index.schema());
let collector = AggregationCollector::from_aggs(aggregations, Default::default());
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let aggregations_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();

View File

@@ -1,14 +1,15 @@
use columnar::Column;
use columnar::ColumnType;
use serde::{Deserialize, Serialize};
use super::*;
use crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor;
use crate::aggregation::agg_req_with_accessor::{
AggregationsWithAccessor, MetricAggregationWithAccessor,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResults, IntermediateMetricResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationCollector;
use crate::aggregation::{f64_from_fastfield_u64, VecWithNames};
use crate::schema::Type;
use crate::{DocId, TantivyError};
/// A multi-value metric aggregation that computes a collection of statistics on numeric values that
@@ -153,36 +154,51 @@ pub(crate) enum SegmentStatsType {
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct SegmentStatsCollector {
field_type: Type,
field_type: ColumnType,
pub(crate) collecting_for: SegmentStatsType,
pub(crate) stats: IntermediateStats,
pub(crate) accessor_idx: usize,
val_cache: Vec<u64>,
}
impl SegmentStatsCollector {
pub fn from_req(field_type: Type, collecting_for: SegmentStatsType) -> Self {
pub fn from_req(
field_type: ColumnType,
collecting_for: SegmentStatsType,
accessor_idx: usize,
) -> Self {
Self {
field_type,
collecting_for,
stats: IntermediateStats::default(),
accessor_idx,
val_cache: Default::default(),
}
}
pub(crate) fn collect_block(&mut self, docs: &[DocId], field: &Column<u64>) {
// TODO special case for Required, Optional column type
for doc in docs {
for val in field.values(*doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.stats.collect(val1);
}
#[inline]
pub(crate) fn collect_block_with_field(
&mut self,
docs: &[DocId],
agg_accessor: &mut MetricAggregationWithAccessor,
) {
agg_accessor
.column_block_accessor
.fetch_block(docs, &agg_accessor.accessor);
for val in agg_accessor.column_block_accessor.iter_vals() {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.stats.collect(val1);
}
}
}
impl SegmentAggregationCollector for SegmentStatsCollector {
#[inline]
fn into_intermediate_aggregations_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateAggregationResults> {
let name = agg_with_accessor.metrics.keys[0].to_string();
let name = agg_with_accessor.metrics.keys[self.accessor_idx].to_string();
let intermediate_metric_result = match self.collecting_for {
SegmentStatsType::Average => {
@@ -214,13 +230,15 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
})
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &AggregationsWithAccessor,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let accessor = &agg_with_accessor.metrics.values[0].accessor;
for val in accessor.values(doc) {
let field = &agg_with_accessor.metrics.values[self.accessor_idx].accessor;
for val in field.values_for_doc(doc) {
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.stats.collect(val1);
}
@@ -228,11 +246,14 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
Ok(())
}
fn flush_staged_docs(
#[inline]
fn collect_block(
&mut self,
_agg_with_accessor: &AggregationsWithAccessor,
_force_flush: bool,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
let field = &mut agg_with_accessor.metrics.values[self.accessor_idx];
self.collect_block_with_field(docs, field);
Ok(())
}
}
@@ -272,7 +293,7 @@ mod tests {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None, index.schema());
let collector = AggregationCollector::from_aggs(agg_req_1, Default::default());
let reader = index.reader()?;
let searcher = reader.searcher();
@@ -293,6 +314,43 @@ mod tests {
Ok(())
}
#[test]
fn test_aggregation_stats_simple() -> crate::Result<()> {
// test index without segments
let values = vec![10.0];
let index = get_test_index_from_values(false, &values)?;
let agg_req_1: Aggregations = vec![(
"stats".to_string(),
Aggregation::Metric(MetricAggregation::Stats(StatsAggregation::from_field_name(
"score".to_string(),
))),
)]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, Default::default());
let reader = index.reader()?;
let searcher = reader.searcher();
let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
assert_eq!(
res["stats"],
json!({
"avg": 10.0,
"count": 1,
"max": 10.0,
"min": 10.0,
"sum": 10.0
})
);
Ok(())
}
#[test]
fn test_aggregation_stats() -> crate::Result<()> {
let index = get_test_index_2_segments(false)?;
@@ -326,30 +384,33 @@ mod tests {
),
(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score".to_string(),
ranges: vec![
(3f64..7f64).into(),
(7f64..19f64).into(),
(19f64..20f64).into(),
],
..Default::default()
}),
sub_aggregation: iter::once((
"stats".to_string(),
Aggregation::Metric(MetricAggregation::Stats(
StatsAggregation::from_field_name("score".to_string()),
)),
))
.collect(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score".to_string(),
ranges: vec![
(3f64..7f64).into(),
(7f64..19f64).into(),
(19f64..20f64).into(),
],
..Default::default()
}),
sub_aggregation: iter::once((
"stats".to_string(),
Aggregation::Metric(MetricAggregation::Stats(
StatsAggregation::from_field_name("score".to_string()),
)),
))
.collect(),
}
.into(),
),
),
]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None, index.schema());
let collector = AggregationCollector::from_aggs(agg_req_1, Default::default());
let searcher = reader.searcher();
let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();

File diff suppressed because it is too large Load Diff

View File

@@ -4,26 +4,20 @@
//! merging.
use std::fmt::Debug;
use std::rc::Rc;
use std::sync::atomic::AtomicU32;
pub(crate) use super::agg_limits::AggregationLimits;
use super::agg_req::MetricAggregation;
use super::agg_req_with_accessor::{
AggregationsWithAccessor, BucketAggregationWithAccessor, MetricAggregationWithAccessor,
};
use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector};
use super::collector::MAX_BUCKET_COUNT;
use super::intermediate_agg_result::{IntermediateAggregationResults, IntermediateBucketResult};
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::metric::{
AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, SegmentStatsCollector,
SegmentStatsType, StatsAggregation, SumAggregation,
};
use super::VecWithNames;
use crate::aggregation::agg_req::BucketAggregationType;
use crate::{DocId, TantivyError};
pub(crate) const DOC_BLOCK_SIZE: usize = 64;
pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug {
fn into_intermediate_aggregations_result(
@@ -34,14 +28,20 @@ pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug {
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &AggregationsWithAccessor,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()>;
fn flush_staged_docs(
fn collect_block(
&mut self,
agg_with_accessor: &AggregationsWithAccessor,
force_flush: bool,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()>;
/// Finalize method. Some Aggregator collect blocks of docs before calling `collect_block`.
/// This method ensures those staged docs will be collected.
fn flush(&mut self, _agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
Ok(())
}
}
pub(crate) trait CollectorClone {
@@ -68,54 +68,97 @@ pub(crate) fn build_segment_agg_collector(
// Single metric special case
if req.buckets.is_empty() && req.metrics.len() == 1 {
let req = &req.metrics.values[0];
let stats_collector = match &req.metric {
MetricAggregation::Average(AverageAggregation { .. }) => {
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Average)
}
MetricAggregation::Count(CountAggregation { .. }) => {
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Count)
}
MetricAggregation::Max(MaxAggregation { .. }) => {
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Max)
}
MetricAggregation::Min(MinAggregation { .. }) => {
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Min)
}
MetricAggregation::Stats(StatsAggregation { .. }) => {
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Stats)
}
MetricAggregation::Sum(SumAggregation { .. }) => {
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Sum)
}
};
let accessor_idx = 0;
return build_metric_segment_agg_collector(req, accessor_idx);
}
return Ok(Box::new(stats_collector));
// Single bucket special case
if req.metrics.is_empty() && req.buckets.len() == 1 {
let req = &req.buckets.values[0];
let accessor_idx = 0;
return build_bucket_segment_agg_collector(req, accessor_idx);
}
let agg = GenericSegmentAggregationResultsCollector::from_req_and_validate(req)?;
Ok(Box::new(agg))
}
#[derive(Clone)]
pub(crate) fn build_metric_segment_agg_collector(
req: &MetricAggregationWithAccessor,
accessor_idx: usize,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
let stats_collector = match &req.metric {
MetricAggregation::Average(AverageAggregation { .. }) => {
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Average, accessor_idx)
}
MetricAggregation::Count(CountAggregation { .. }) => {
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Count, accessor_idx)
}
MetricAggregation::Max(MaxAggregation { .. }) => {
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Max, accessor_idx)
}
MetricAggregation::Min(MinAggregation { .. }) => {
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Min, accessor_idx)
}
MetricAggregation::Stats(StatsAggregation { .. }) => {
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Stats, accessor_idx)
}
MetricAggregation::Sum(SumAggregation { .. }) => {
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Sum, accessor_idx)
}
};
Ok(Box::new(stats_collector))
}
pub(crate) fn build_bucket_segment_agg_collector(
req: &BucketAggregationWithAccessor,
accessor_idx: usize,
) -> crate::Result<Box<dyn SegmentAggregationCollector>> {
match &req.bucket_agg {
BucketAggregationType::Terms(terms_req) => {
Ok(Box::new(SegmentTermCollector::from_req_and_validate(
terms_req,
&req.sub_aggregation,
req.field_type,
accessor_idx,
)?))
}
BucketAggregationType::Range(range_req) => {
Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
range_req,
&req.sub_aggregation,
&req.limits,
req.field_type,
accessor_idx,
)?))
}
BucketAggregationType::Histogram(histogram) => {
Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
histogram,
&req.sub_aggregation,
req.field_type,
accessor_idx,
)?))
}
BucketAggregationType::DateHistogram(histogram) => {
Ok(Box::new(SegmentHistogramCollector::from_req_and_validate(
&histogram.to_histogram_req()?,
&req.sub_aggregation,
req.field_type,
accessor_idx,
)?))
}
}
}
#[derive(Clone, Default)]
/// The GenericSegmentAggregationResultsCollector is the generic version of the collector, which
/// can handle arbitrary complexity of sub-aggregations. Ideally we never have to pick this one
/// and can provide specialized versions instead, that remove some of its overhead.
pub(crate) struct GenericSegmentAggregationResultsCollector {
pub(crate) metrics: Option<VecWithNames<SegmentMetricResultCollector>>,
pub(crate) buckets: Option<VecWithNames<SegmentBucketResultCollector>>,
staged_docs: DocBlock,
num_staged_docs: usize,
}
impl Default for GenericSegmentAggregationResultsCollector {
fn default() -> Self {
Self {
metrics: Default::default(),
buckets: Default::default(),
staged_docs: [0; DOC_BLOCK_SIZE],
num_staged_docs: Default::default(),
}
}
pub(crate) metrics: Option<Vec<Box<dyn SegmentAggregationCollector>>>,
pub(crate) buckets: Option<Vec<Box<dyn SegmentAggregationCollector>>>,
}
impl Debug for GenericSegmentAggregationResultsCollector {
@@ -123,8 +166,6 @@ impl Debug for GenericSegmentAggregationResultsCollector {
f.debug_struct("SegmentAggregationResultsCollector")
.field("metrics", &self.metrics)
.field("buckets", &self.buckets)
.field("staged_docs", &&self.staged_docs[..self.num_staged_docs])
.field("num_staged_docs", &self.num_staged_docs)
.finish()
}
}
@@ -135,16 +176,29 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateAggregationResults> {
let buckets = if let Some(buckets) = self.buckets {
let entries = buckets
.into_iter()
.zip(agg_with_accessor.buckets.values())
.map(|((key, bucket), acc)| Ok((key, bucket.into_intermediate_bucket_result(acc)?)))
.collect::<crate::Result<Vec<(String, _)>>>()?;
Some(VecWithNames::from_entries(entries))
let mut intermeditate_buckets = VecWithNames::default();
for bucket in buckets {
// TODO too many allocations?
let res = bucket.into_intermediate_aggregations_result(agg_with_accessor)?;
// unwrap is fine since we only have buckets here
intermeditate_buckets.extend(res.buckets.unwrap());
}
Some(intermeditate_buckets)
} else {
None
};
let metrics = if let Some(metrics) = self.metrics {
let mut intermeditate_metrics = VecWithNames::default();
for metric in metrics {
// TODO too many allocations?
let res = metric.into_intermediate_aggregations_result(agg_with_accessor)?;
// unwrap is fine since we only have metrics here
intermeditate_metrics.extend(res.metrics.unwrap());
}
Some(intermeditate_metrics)
} else {
None
};
let metrics = self.metrics.map(VecWithNames::from_other);
Ok(IntermediateAggregationResults { metrics, buckets })
}
@@ -152,264 +206,78 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
fn collect(
&mut self,
doc: crate::DocId,
agg_with_accessor: &AggregationsWithAccessor,
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
self.staged_docs[self.num_staged_docs] = doc;
self.num_staged_docs += 1;
if self.num_staged_docs == self.staged_docs.len() {
self.flush_staged_docs(agg_with_accessor, false)?;
}
self.collect_block(&[doc], agg_with_accessor)?;
Ok(())
}
fn flush_staged_docs(
fn collect_block(
&mut self,
agg_with_accessor: &AggregationsWithAccessor,
force_flush: bool,
docs: &[crate::DocId],
agg_with_accessor: &mut AggregationsWithAccessor,
) -> crate::Result<()> {
if self.num_staged_docs == 0 {
return Ok(());
if let Some(metrics) = self.metrics.as_mut() {
for collector in metrics {
collector.collect_block(docs, agg_with_accessor)?;
}
}
if let Some(buckets) = self.buckets.as_mut() {
for collector in buckets {
collector.collect_block(docs, agg_with_accessor)?;
}
}
Ok(())
}
fn flush(&mut self, agg_with_accessor: &mut AggregationsWithAccessor) -> crate::Result<()> {
if let Some(metrics) = &mut self.metrics {
for (collector, agg_with_accessor) in
metrics.values_mut().zip(agg_with_accessor.metrics.values())
{
collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor);
for collector in metrics {
collector.flush(agg_with_accessor)?;
}
}
if let Some(buckets) = &mut self.buckets {
for (collector, agg_with_accessor) in
buckets.values_mut().zip(agg_with_accessor.buckets.values())
{
collector.collect_block(
&self.staged_docs[..self.num_staged_docs],
agg_with_accessor,
force_flush,
)?;
for collector in buckets {
collector.flush(agg_with_accessor)?;
}
}
self.num_staged_docs = 0;
Ok(())
}
}
impl GenericSegmentAggregationResultsCollector {
pub fn into_intermediate_aggregations_result(
self,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateAggregationResults> {
let buckets = if let Some(buckets) = self.buckets {
let entries = buckets
.into_iter()
.zip(agg_with_accessor.buckets.values())
.map(|((key, bucket), acc)| Ok((key, bucket.into_intermediate_bucket_result(acc)?)))
.collect::<crate::Result<Vec<(String, _)>>>()?;
Some(VecWithNames::from_entries(entries))
} else {
None
};
let metrics = self.metrics.map(VecWithNames::from_other);
Ok(IntermediateAggregationResults { metrics, buckets })
}
pub(crate) fn from_req_and_validate(req: &AggregationsWithAccessor) -> crate::Result<Self> {
let buckets = req
.buckets
.iter()
.map(|(key, req)| {
Ok((
key.to_string(),
SegmentBucketResultCollector::from_req_and_validate(req)?,
))
.enumerate()
.map(|(accessor_idx, (_key, req))| {
build_bucket_segment_agg_collector(req, accessor_idx)
})
.collect::<crate::Result<Vec<(String, _)>>>()?;
.collect::<crate::Result<Vec<Box<dyn SegmentAggregationCollector>>>>()?;
let metrics = req
.metrics
.iter()
.map(|(key, req)| {
Ok((
key.to_string(),
SegmentMetricResultCollector::from_req_and_validate(req)?,
))
.enumerate()
.map(|(accessor_idx, (_key, req))| {
build_metric_segment_agg_collector(req, accessor_idx)
})
.collect::<crate::Result<Vec<(String, _)>>>()?;
.collect::<crate::Result<Vec<Box<dyn SegmentAggregationCollector>>>>()?;
let metrics = if metrics.is_empty() {
None
} else {
Some(VecWithNames::from_entries(metrics))
Some(metrics)
};
let buckets = if buckets.is_empty() {
None
} else {
Some(VecWithNames::from_entries(buckets))
Some(buckets)
};
Ok(GenericSegmentAggregationResultsCollector {
metrics,
buckets,
staged_docs: [0; DOC_BLOCK_SIZE],
num_staged_docs: 0,
})
}
}
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum SegmentMetricResultCollector {
Stats(SegmentStatsCollector),
}
impl SegmentMetricResultCollector {
pub fn from_req_and_validate(req: &MetricAggregationWithAccessor) -> crate::Result<Self> {
match &req.metric {
MetricAggregation::Average(AverageAggregation { .. }) => {
Ok(SegmentMetricResultCollector::Stats(
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Average),
))
}
MetricAggregation::Count(CountAggregation { .. }) => {
Ok(SegmentMetricResultCollector::Stats(
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Count),
))
}
MetricAggregation::Max(MaxAggregation { .. }) => {
Ok(SegmentMetricResultCollector::Stats(
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Max),
))
}
MetricAggregation::Min(MinAggregation { .. }) => {
Ok(SegmentMetricResultCollector::Stats(
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Min),
))
}
MetricAggregation::Stats(StatsAggregation { .. }) => {
Ok(SegmentMetricResultCollector::Stats(
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Stats),
))
}
MetricAggregation::Sum(SumAggregation { .. }) => {
Ok(SegmentMetricResultCollector::Stats(
SegmentStatsCollector::from_req(req.field_type, SegmentStatsType::Sum),
))
}
}
}
pub(crate) fn collect_block(&mut self, doc: &[DocId], metric: &MetricAggregationWithAccessor) {
match self {
SegmentMetricResultCollector::Stats(stats_collector) => {
stats_collector.collect_block(doc, &metric.accessor);
}
}
}
}
/// SegmentBucketAggregationResultCollectors will have specialized buckets for collection inside
/// segments.
/// The typical structure of Map<Key, Bucket> is not suitable during collection for performance
/// reasons.
#[derive(Clone, Debug)]
pub(crate) enum SegmentBucketResultCollector {
Range(SegmentRangeCollector),
Histogram(Box<SegmentHistogramCollector>),
Terms(Box<SegmentTermCollector>),
}
impl SegmentBucketResultCollector {
pub fn into_intermediate_bucket_result(
self,
agg_with_accessor: &BucketAggregationWithAccessor,
) -> crate::Result<IntermediateBucketResult> {
match self {
SegmentBucketResultCollector::Terms(terms) => {
terms.into_intermediate_bucket_result(agg_with_accessor)
}
SegmentBucketResultCollector::Range(range) => {
range.into_intermediate_bucket_result(agg_with_accessor)
}
SegmentBucketResultCollector::Histogram(histogram) => {
histogram.into_intermediate_bucket_result(agg_with_accessor)
}
}
}
pub fn from_req_and_validate(req: &BucketAggregationWithAccessor) -> crate::Result<Self> {
match &req.bucket_agg {
BucketAggregationType::Terms(terms_req) => Ok(Self::Terms(Box::new(
SegmentTermCollector::from_req_and_validate(terms_req, &req.sub_aggregation)?,
))),
BucketAggregationType::Range(range_req) => {
Ok(Self::Range(SegmentRangeCollector::from_req_and_validate(
range_req,
&req.sub_aggregation,
&req.bucket_count,
req.field_type,
)?))
}
BucketAggregationType::Histogram(histogram) => Ok(Self::Histogram(Box::new(
SegmentHistogramCollector::from_req_and_validate(
histogram,
&req.sub_aggregation,
req.field_type,
&req.accessor,
)?,
))),
}
}
#[inline]
pub(crate) fn collect_block(
&mut self,
doc: &[DocId],
bucket_with_accessor: &BucketAggregationWithAccessor,
force_flush: bool,
) -> crate::Result<()> {
match self {
SegmentBucketResultCollector::Range(range) => {
range.collect_block(doc, bucket_with_accessor, force_flush)?;
}
SegmentBucketResultCollector::Histogram(histogram) => {
histogram.collect_block(doc, bucket_with_accessor, force_flush)?;
}
SegmentBucketResultCollector::Terms(terms) => {
terms.collect_block(doc, bucket_with_accessor, force_flush)?;
}
}
Ok(())
}
}
#[derive(Clone)]
pub(crate) struct BucketCount {
/// The counter which is shared between the aggregations for one request.
pub(crate) bucket_count: Rc<AtomicU32>,
pub(crate) max_bucket_count: u32,
}
impl Default for BucketCount {
fn default() -> Self {
Self {
bucket_count: Default::default(),
max_bucket_count: MAX_BUCKET_COUNT,
}
}
}
impl BucketCount {
pub(crate) fn validate_bucket_count(&self) -> crate::Result<()> {
if self.get_count() > self.max_bucket_count {
return Err(TantivyError::InvalidArgument(
"Aborting aggregation because too many buckets were created".to_string(),
));
}
Ok(())
}
pub(crate) fn add_count(&self, count: u32) {
self.bucket_count
.fetch_add(count, std::sync::atomic::Ordering::Relaxed);
}
pub(crate) fn get_count(&self) -> u32 {
self.bucket_count.load(std::sync::atomic::Ordering::Relaxed)
Ok(GenericSegmentAggregationResultsCollector { metrics, buckets })
}
}

View File

@@ -515,8 +515,7 @@ mod tests {
expected_compressed_collapsed_mapping: &[usize],
expected_unique_facet_ords: &[(u64, usize)],
) {
let (compressed_collapsed_mapping, unique_facet_ords) =
compress_mapping(&collapsed_mapping);
let (compressed_collapsed_mapping, unique_facet_ords) = compress_mapping(collapsed_mapping);
assert_eq!(
compressed_collapsed_mapping,
expected_compressed_collapsed_mapping

View File

@@ -113,7 +113,7 @@ impl Collector for HistogramCollector {
segment: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
let column_opt = segment.fast_fields().u64_lenient(&self.field)?;
let column = column_opt.ok_or_else(|| FastFieldNotAvailableError {
let (column, _column_type) = column_opt.ok_or_else(|| FastFieldNotAvailableError {
field_name: self.field.clone(),
})?;
let column_u64 = column.first_or_default_col(0u64);

View File

@@ -180,9 +180,11 @@ pub trait Collector: Sync + Send {
})?;
}
(Some(alive_bitset), false) => {
weight.for_each_no_score(reader, &mut |doc| {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, 0.0);
weight.for_each_no_score(reader, &mut |docs| {
for doc in docs.iter().cloned() {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, 0.0);
}
}
})?;
}
@@ -192,8 +194,8 @@ pub trait Collector: Sync + Send {
})?;
}
(None, false) => {
weight.for_each_no_score(reader, &mut |doc| {
segment_collector.collect(doc, 0.0);
weight.for_each_no_score(reader, &mut |docs| {
segment_collector.collect_block(docs);
})?;
}
}
@@ -270,6 +272,13 @@ pub trait SegmentCollector: 'static {
/// The query pushes the scored document to the collector via this method.
fn collect(&mut self, doc: DocId, score: Score);
/// The query pushes the scored document to the collector via this method.
fn collect_block(&mut self, docs: &[DocId]) {
for doc in docs {
self.collect(*doc, 0.0);
}
}
/// Extract the fruit of the collection from the `SegmentCollector`.
fn harvest(self) -> Self::Fruit;
}

View File

@@ -56,9 +56,8 @@ pub fn test_filter_collector() -> crate::Result<()> {
assert_eq!(filtered_top_docs.len(), 0);
fn date_filter(value: DateTime) -> bool {
(crate::DateTime::from(value).into_utc()
- OffsetDateTime::parse("2019-04-09T00:00:00+00:00", &Rfc3339).unwrap())
.whole_weeks()
(value.into_utc() - OffsetDateTime::parse("2019-04-09T00:00:00+00:00", &Rfc3339).unwrap())
.whole_weeks()
> 0
}
@@ -201,7 +200,7 @@ impl SegmentCollector for FastFieldSegmentCollector {
type Fruit = Vec<u64>;
fn collect(&mut self, doc: DocId, _score: Score) {
self.vals.extend(self.reader.values(doc));
self.vals.extend(self.reader.values_for_doc(doc));
}
fn harvest(self) -> Vec<u64> {

View File

@@ -155,12 +155,13 @@ impl CustomScorer<u64> for ScorerByField {
//
// The conversion will then happen only on the top-K docs.
let sort_column_opt = segment_reader.fast_fields().u64_lenient(&self.field)?;
let sort_column = sort_column_opt
.ok_or_else(|| FastFieldNotAvailableError {
let (sort_column, _sort_column_type) =
sort_column_opt.ok_or_else(|| FastFieldNotAvailableError {
field_name: self.field.clone(),
})?
.first_or_default_col(0u64);
Ok(ScorerByFastFieldReader { sort_column })
})?;
Ok(ScorerByFastFieldReader {
sort_column: sort_column.first_or_default_col(0u64),
})
}
}
@@ -1030,7 +1031,7 @@ mod tests {
let segment = searcher.segment_reader(0);
let top_collector = TopDocs::with_limit(4).order_by_u64_field(SIZE);
let err = top_collector.for_segment(0, segment).err().unwrap();
assert!(matches!(err, crate::TantivyError::SchemaError(_)));
assert!(matches!(err, crate::TantivyError::InvalidArgument(_)));
Ok(())
}

View File

@@ -662,304 +662,3 @@ impl fmt::Debug for Index {
write!(f, "Index({:?})", self.directory)
}
}
#[cfg(test)]
mod tests {
use crate::collector::Count;
use crate::directory::{RamDirectory, WatchCallback};
use crate::query::TermQuery;
use crate::schema::{Field, IndexRecordOption, Schema, INDEXED, TEXT};
use crate::tokenizer::TokenizerManager;
use crate::{Directory, Index, IndexBuilder, IndexReader, IndexSettings, ReloadPolicy, Term};
#[test]
fn test_indexer_for_field() {
let mut schema_builder = Schema::builder();
let num_likes_field = schema_builder.add_u64_field("num_likes", INDEXED);
let body_field = schema_builder.add_text_field("body", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
assert!(index.tokenizer_for_field(body_field).is_ok());
assert_eq!(
format!("{:?}", index.tokenizer_for_field(num_likes_field).err()),
"Some(SchemaError(\"\\\"num_likes\\\" is not a text field.\"))"
);
}
#[test]
fn test_set_tokenizer_manager() {
let mut schema_builder = Schema::builder();
schema_builder.add_u64_field("num_likes", INDEXED);
schema_builder.add_text_field("body", TEXT);
let schema = schema_builder.build();
let index = IndexBuilder::new()
// set empty tokenizer manager
.tokenizers(TokenizerManager::new())
.schema(schema)
.create_in_ram()
.unwrap();
assert!(index.tokenizers().get("raw").is_none());
}
#[test]
fn test_index_exists() {
let directory: Box<dyn Directory> = Box::new(RamDirectory::create());
assert!(!Index::exists(directory.as_ref()).unwrap());
assert!(Index::create(
directory.clone(),
throw_away_schema(),
IndexSettings::default()
)
.is_ok());
assert!(Index::exists(directory.as_ref()).unwrap());
}
#[test]
fn open_or_create_should_create() {
let directory = RamDirectory::create();
assert!(!Index::exists(&directory).unwrap());
assert!(Index::open_or_create(directory.clone(), throw_away_schema()).is_ok());
assert!(Index::exists(&directory).unwrap());
}
#[test]
fn open_or_create_should_open() {
let directory: Box<dyn Directory> = Box::new(RamDirectory::create());
assert!(Index::create(
directory.clone(),
throw_away_schema(),
IndexSettings::default()
)
.is_ok());
assert!(Index::exists(directory.as_ref()).unwrap());
assert!(Index::open_or_create(directory, throw_away_schema()).is_ok());
}
#[test]
fn create_should_wipeoff_existing() {
let directory: Box<dyn Directory> = Box::new(RamDirectory::create());
assert!(Index::create(
directory.clone(),
throw_away_schema(),
IndexSettings::default()
)
.is_ok());
assert!(Index::exists(directory.as_ref()).unwrap());
assert!(Index::create(
directory,
Schema::builder().build(),
IndexSettings::default()
)
.is_ok());
}
#[test]
fn open_or_create_exists_but_schema_does_not_match() {
let directory = RamDirectory::create();
assert!(Index::create(
directory.clone(),
throw_away_schema(),
IndexSettings::default()
)
.is_ok());
assert!(Index::exists(&directory).unwrap());
assert!(Index::open_or_create(directory.clone(), throw_away_schema()).is_ok());
let err = Index::open_or_create(directory, Schema::builder().build());
assert_eq!(
format!("{:?}", err.unwrap_err()),
"SchemaError(\"An index exists but the schema does not match.\")"
);
}
fn throw_away_schema() -> Schema {
let mut schema_builder = Schema::builder();
let _ = schema_builder.add_u64_field("num_likes", INDEXED);
schema_builder.build()
}
#[test]
fn test_index_on_commit_reload_policy() -> crate::Result<()> {
let schema = throw_away_schema();
let field = schema.get_field("num_likes").unwrap();
let index = Index::create_in_ram(schema);
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::OnCommit)
.try_into()
.unwrap();
assert_eq!(reader.searcher().num_docs(), 0);
test_index_on_commit_reload_policy_aux(field, &index, &reader)
}
#[cfg(feature = "mmap")]
mod mmap_specific {
use std::path::PathBuf;
use tempfile::TempDir;
use super::*;
use crate::Directory;
#[test]
fn test_index_on_commit_reload_policy_mmap() -> crate::Result<()> {
let schema = throw_away_schema();
let field = schema.get_field("num_likes").unwrap();
let tempdir = TempDir::new().unwrap();
let tempdir_path = PathBuf::from(tempdir.path());
let index = Index::create_in_dir(tempdir_path, schema).unwrap();
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::OnCommit)
.try_into()
.unwrap();
assert_eq!(reader.searcher().num_docs(), 0);
test_index_on_commit_reload_policy_aux(field, &index, &reader)
}
#[test]
fn test_index_manual_policy_mmap() -> crate::Result<()> {
let schema = throw_away_schema();
let field = schema.get_field("num_likes").unwrap();
let mut index = Index::create_from_tempdir(schema)?;
let mut writer = index.writer_for_tests()?;
writer.commit()?;
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()?;
assert_eq!(reader.searcher().num_docs(), 0);
writer.add_document(doc!(field=>1u64))?;
let (sender, receiver) = crossbeam_channel::unbounded();
let _handle = index.directory_mut().watch(WatchCallback::new(move || {
let _ = sender.send(());
}));
writer.commit()?;
assert!(receiver.recv().is_ok());
assert_eq!(reader.searcher().num_docs(), 0);
reader.reload()?;
assert_eq!(reader.searcher().num_docs(), 1);
Ok(())
}
#[test]
fn test_index_on_commit_reload_policy_different_directories() -> crate::Result<()> {
let schema = throw_away_schema();
let field = schema.get_field("num_likes").unwrap();
let tempdir = TempDir::new().unwrap();
let tempdir_path = PathBuf::from(tempdir.path());
let write_index = Index::create_in_dir(&tempdir_path, schema).unwrap();
let read_index = Index::open_in_dir(&tempdir_path).unwrap();
let reader = read_index
.reader_builder()
.reload_policy(ReloadPolicy::OnCommit)
.try_into()
.unwrap();
assert_eq!(reader.searcher().num_docs(), 0);
test_index_on_commit_reload_policy_aux(field, &write_index, &reader)
}
}
fn test_index_on_commit_reload_policy_aux(
field: Field,
index: &Index,
reader: &IndexReader,
) -> crate::Result<()> {
let mut reader_index = reader.index();
let (sender, receiver) = crossbeam_channel::unbounded();
let _watch_handle = reader_index
.directory_mut()
.watch(WatchCallback::new(move || {
let _ = sender.send(());
}));
let mut writer = index.writer_for_tests()?;
assert_eq!(reader.searcher().num_docs(), 0);
writer.add_document(doc!(field=>1u64))?;
writer.commit().unwrap();
// We need a loop here because it is possible for notify to send more than
// one modify event. It was observed on CI on MacOS.
loop {
assert!(receiver.recv().is_ok());
if reader.searcher().num_docs() == 1 {
break;
}
}
writer.add_document(doc!(field=>2u64))?;
writer.commit().unwrap();
// ... Same as above
loop {
assert!(receiver.recv().is_ok());
if reader.searcher().num_docs() == 2 {
break;
}
}
Ok(())
}
// This test will not pass on windows, because windows
// prevent deleting files that are MMapped.
#[cfg(not(target_os = "windows"))]
#[test]
fn garbage_collect_works_as_intended() -> crate::Result<()> {
let directory = RamDirectory::create();
let schema = throw_away_schema();
let field = schema.get_field("num_likes").unwrap();
let index = Index::create(directory.clone(), schema, IndexSettings::default())?;
let mut writer = index.writer_with_num_threads(1, 32_000_000).unwrap();
for _seg in 0..8 {
for i in 0u64..1_000u64 {
writer.add_document(doc!(field => i))?;
}
writer.commit()?;
}
let mem_right_after_commit = directory.total_mem_usage();
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()?;
assert_eq!(reader.searcher().num_docs(), 8_000);
assert_eq!(reader.searcher().segment_readers().len(), 8);
writer.wait_merging_threads()?;
let mem_right_after_merge_finished = directory.total_mem_usage();
reader.reload().unwrap();
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
assert_eq!(searcher.num_docs(), 8_000);
assert!(
mem_right_after_merge_finished < mem_right_after_commit,
"(mem after merge){} is expected < (mem before merge){}",
mem_right_after_merge_finished,
mem_right_after_commit
);
Ok(())
}
#[test]
fn test_single_segment_index_writer() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let directory = RamDirectory::default();
let mut single_segment_index_writer = Index::builder()
.schema(schema)
.single_segment_index_writer(directory, 10_000_000)?;
for _ in 0..10 {
let doc = doc!(text_field=>"hello");
single_segment_index_writer.add_document(doc)?;
}
let index = single_segment_index_writer.finalize()?;
let searcher = index.reader()?.searcher();
let term_query = TermQuery::new(
Term::from_field_text(text_field, "hello"),
IndexRecordOption::Basic,
);
let count = searcher.search(&term_query, &Count)?;
assert_eq!(count, 10);
Ok(())
}
}

View File

@@ -1,10 +1,11 @@
use columnar::MonotonicallyMappableToU64;
use common::replace_in_place;
use murmurhash32::murmurhash2;
use rustc_hash::FxHashMap;
use crate::fastfield::FastValue;
use crate::postings::{IndexingContext, IndexingPosition, PostingsWriter};
use crate::schema::term::{JSON_END_OF_PATH, JSON_PATH_SEGMENT_SEP};
use crate::schema::term::{JSON_END_OF_PATH, JSON_PATH_SEGMENT_SEP, JSON_PATH_SEGMENT_SEP_STR};
use crate::schema::{Field, Type};
use crate::time::format_description::well_known::Rfc3339;
use crate::time::{OffsetDateTime, UtcOffset};
@@ -199,7 +200,7 @@ fn infer_type_from_str(text: &str) -> TextOrDateTime {
}
}
// Tries to infer a JSON type from a string
// Tries to infer a JSON type from a string.
pub(crate) fn convert_to_fast_value_and_get_term(
json_term_writer: &mut JsonTermWriter,
phrase: &str,
@@ -295,6 +296,32 @@ fn split_json_path(json_path: &str) -> Vec<String> {
json_path_segments
}
/// Takes a field name, a json path as supplied by a user, and whether we should expand dots, and
/// return a column key, as expected by the columnar crate.
///
/// This function will detect unescaped dots in the path, and split over them.
/// If expand_dots is enabled, then even escaped dots will be split over.
///
/// The resulting list of segment then gets stitched together, joined by \1 separator,
/// as defined in the columnar crate.
pub(crate) fn encode_column_name(
field_name: &str,
json_path: &str,
expand_dots_enabled: bool,
) -> String {
let mut column_key: String = String::with_capacity(field_name.len() + json_path.len() + 1);
column_key.push_str(field_name);
for mut segment in split_json_path(json_path) {
column_key.push_str(JSON_PATH_SEGMENT_SEP_STR);
if expand_dots_enabled {
// We need to replace `.` by JSON_PATH_SEGMENT_SEP.
unsafe { replace_in_place(b'.', JSON_PATH_SEGMENT_SEP, segment.as_bytes_mut()) };
}
column_key.push_str(&segment);
}
column_key
}
impl<'a> JsonTermWriter<'a> {
pub fn from_field_and_json_path(
field: Field,
@@ -343,18 +370,10 @@ impl<'a> JsonTermWriter<'a> {
if self.path_stack.len() > 1 {
buffer[buffer_len - 1] = JSON_PATH_SEGMENT_SEP;
}
if self.expand_dots_enabled && segment.as_bytes().contains(&b'.') {
let appended_segment = self.term_buffer.append_bytes(segment.as_bytes());
if self.expand_dots_enabled {
// We need to replace `.` by JSON_PATH_SEGMENT_SEP.
self.term_buffer
.append_bytes(segment.as_bytes())
.iter_mut()
.for_each(|byte| {
if *byte == b'.' {
*byte = JSON_PATH_SEGMENT_SEP;
}
});
} else {
self.term_buffer.append_bytes(segment.as_bytes());
replace_in_place(b'.', JSON_PATH_SEGMENT_SEP, appended_segment);
}
self.term_buffer.push_byte(JSON_PATH_SEGMENT_SEP);
self.path_stack.push(self.term_buffer.len_bytes());
@@ -373,7 +392,7 @@ impl<'a> JsonTermWriter<'a> {
&self.term().value_bytes()[..end_of_path - 1]
}
pub fn set_fast_value<T: FastValue>(&mut self, val: T) {
pub(crate) fn set_fast_value<T: FastValue>(&mut self, val: T) {
self.close_path_and_set_type(T::to_type());
let value = if T::to_type() == Type::Date {
DateTime::from_u64(val.to_u64())

View File

@@ -2,6 +2,7 @@ mod executor;
pub mod index;
mod index_meta;
mod inverted_index_reader;
pub mod json_utils;
pub mod searcher;
mod segment;
mod segment_component;
@@ -36,3 +37,6 @@ pub static META_FILEPATH: Lazy<&'static Path> = Lazy::new(|| Path::new("meta.jso
/// Removing this file is safe, but will prevent the garbage collection of all of the file that
/// are currently in the directory
pub static MANAGED_FILEPATH: Lazy<&'static Path> = Lazy::new(|| Path::new(".managed.json"));
#[cfg(test)]
mod tests;

View File

@@ -4,7 +4,7 @@ use std::{fmt, io};
use crate::collector::Collector;
use crate::core::{Executor, SegmentReader};
use crate::query::{EnableScoring, Query};
use crate::query::{Bm25StatisticsProvider, EnableScoring, Query};
use crate::schema::{Document, Schema, Term};
use crate::space_usage::SearcherSpaceUsage;
use crate::store::{CacheStats, StoreReader};
@@ -176,8 +176,27 @@ impl Searcher {
query: &dyn Query,
collector: &C,
) -> crate::Result<C::Fruit> {
self.search_with_statistics_provider(query, collector, self)
}
/// Same as [`search(...)`](Searcher::search) but allows specifying
/// a [Bm25StatisticsProvider].
///
/// This can be used to adjust the statistics used in computing BM25
/// scores.
pub fn search_with_statistics_provider<C: Collector>(
&self,
query: &dyn Query,
collector: &C,
statistics_provider: &dyn Bm25StatisticsProvider,
) -> crate::Result<C::Fruit> {
let enabled_scoring = if collector.requires_scoring() {
EnableScoring::enabled_from_statistics_provider(statistics_provider, self)
} else {
EnableScoring::disabled_from_searcher(self)
};
let executor = self.inner.index.search_executor();
self.search_with_executor(query, collector, executor)
self.search_with_executor(query, collector, executor, enabled_scoring)
}
/// Same as [`search(...)`](Searcher::search) but multithreaded.
@@ -197,12 +216,8 @@ impl Searcher {
query: &dyn Query,
collector: &C,
executor: &Executor,
enabled_scoring: EnableScoring,
) -> crate::Result<C::Fruit> {
let enabled_scoring = if collector.requires_scoring() {
EnableScoring::enabled_from_searcher(self)
} else {
EnableScoring::disabled_from_searcher(self)
};
let weight = query.weight(enabled_scoring)?;
let segment_readers = self.segment_readers();
let fruits = executor.map(

View File

@@ -38,7 +38,7 @@ pub struct SegmentReader {
termdict_composite: CompositeFile,
postings_composite: CompositeFile,
positions_composite: CompositeFile,
fast_fields_readers: Arc<FastFieldReaders>,
fast_fields_readers: FastFieldReaders,
fieldnorm_readers: FieldNormReaders,
store_file: FileSlice,
@@ -167,7 +167,7 @@ impl SegmentReader {
let schema = segment.schema();
let fast_fields_data = segment.open_read(SegmentComponent::FastFields)?;
let fast_fields_readers = Arc::new(FastFieldReaders::open(fast_fields_data)?);
let fast_fields_readers = FastFieldReaders::open(fast_fields_data, schema.clone())?;
let fieldnorm_data = segment.open_read(SegmentComponent::FieldNorms)?;
let fieldnorm_readers = FieldNormReaders::open(fieldnorm_data)?;
@@ -327,7 +327,7 @@ impl SegmentReader {
self.alive_bitset_opt
.as_ref()
.map(AliveBitSet::space_usage)
.unwrap_or(0),
.unwrap_or_default(),
))
}
}

347
src/core/tests.rs Normal file
View File

@@ -0,0 +1,347 @@
use crate::collector::Count;
use crate::directory::{RamDirectory, WatchCallback};
use crate::indexer::NoMergePolicy;
use crate::query::TermQuery;
use crate::schema::{Field, IndexRecordOption, Schema, INDEXED, STRING, TEXT};
use crate::tokenizer::TokenizerManager;
use crate::{
Directory, Document, Index, IndexBuilder, IndexReader, IndexSettings, ReloadPolicy, SegmentId,
Term,
};
#[test]
fn test_indexer_for_field() {
let mut schema_builder = Schema::builder();
let num_likes_field = schema_builder.add_u64_field("num_likes", INDEXED);
let body_field = schema_builder.add_text_field("body", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
assert!(index.tokenizer_for_field(body_field).is_ok());
assert_eq!(
format!("{:?}", index.tokenizer_for_field(num_likes_field).err()),
"Some(SchemaError(\"\\\"num_likes\\\" is not a text field.\"))"
);
}
#[test]
fn test_set_tokenizer_manager() {
let mut schema_builder = Schema::builder();
schema_builder.add_u64_field("num_likes", INDEXED);
schema_builder.add_text_field("body", TEXT);
let schema = schema_builder.build();
let index = IndexBuilder::new()
// set empty tokenizer manager
.tokenizers(TokenizerManager::new())
.schema(schema)
.create_in_ram()
.unwrap();
assert!(index.tokenizers().get("raw").is_none());
}
#[test]
fn test_index_exists() {
let directory: Box<dyn Directory> = Box::new(RamDirectory::create());
assert!(!Index::exists(directory.as_ref()).unwrap());
assert!(Index::create(
directory.clone(),
throw_away_schema(),
IndexSettings::default()
)
.is_ok());
assert!(Index::exists(directory.as_ref()).unwrap());
}
#[test]
fn open_or_create_should_create() {
let directory = RamDirectory::create();
assert!(!Index::exists(&directory).unwrap());
assert!(Index::open_or_create(directory.clone(), throw_away_schema()).is_ok());
assert!(Index::exists(&directory).unwrap());
}
#[test]
fn open_or_create_should_open() {
let directory: Box<dyn Directory> = Box::new(RamDirectory::create());
assert!(Index::create(
directory.clone(),
throw_away_schema(),
IndexSettings::default()
)
.is_ok());
assert!(Index::exists(directory.as_ref()).unwrap());
assert!(Index::open_or_create(directory, throw_away_schema()).is_ok());
}
#[test]
fn create_should_wipeoff_existing() {
let directory: Box<dyn Directory> = Box::new(RamDirectory::create());
assert!(Index::create(
directory.clone(),
throw_away_schema(),
IndexSettings::default()
)
.is_ok());
assert!(Index::exists(directory.as_ref()).unwrap());
assert!(Index::create(
directory,
Schema::builder().build(),
IndexSettings::default()
)
.is_ok());
}
#[test]
fn open_or_create_exists_but_schema_does_not_match() {
let directory = RamDirectory::create();
assert!(Index::create(
directory.clone(),
throw_away_schema(),
IndexSettings::default()
)
.is_ok());
assert!(Index::exists(&directory).unwrap());
assert!(Index::open_or_create(directory.clone(), throw_away_schema()).is_ok());
let err = Index::open_or_create(directory, Schema::builder().build());
assert_eq!(
format!("{:?}", err.unwrap_err()),
"SchemaError(\"An index exists but the schema does not match.\")"
);
}
fn throw_away_schema() -> Schema {
let mut schema_builder = Schema::builder();
let _ = schema_builder.add_u64_field("num_likes", INDEXED);
schema_builder.build()
}
#[test]
fn test_index_on_commit_reload_policy() -> crate::Result<()> {
let schema = throw_away_schema();
let field = schema.get_field("num_likes").unwrap();
let index = Index::create_in_ram(schema);
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::OnCommit)
.try_into()
.unwrap();
assert_eq!(reader.searcher().num_docs(), 0);
test_index_on_commit_reload_policy_aux(field, &index, &reader)
}
#[cfg(feature = "mmap")]
mod mmap_specific {
use std::path::PathBuf;
use tempfile::TempDir;
use super::*;
use crate::Directory;
#[test]
fn test_index_on_commit_reload_policy_mmap() -> crate::Result<()> {
let schema = throw_away_schema();
let field = schema.get_field("num_likes").unwrap();
let tempdir = TempDir::new().unwrap();
let tempdir_path = PathBuf::from(tempdir.path());
let index = Index::create_in_dir(tempdir_path, schema).unwrap();
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::OnCommit)
.try_into()
.unwrap();
assert_eq!(reader.searcher().num_docs(), 0);
test_index_on_commit_reload_policy_aux(field, &index, &reader)
}
#[test]
fn test_index_manual_policy_mmap() -> crate::Result<()> {
let schema = throw_away_schema();
let field = schema.get_field("num_likes").unwrap();
let mut index = Index::create_from_tempdir(schema)?;
let mut writer = index.writer_for_tests()?;
writer.commit()?;
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()?;
assert_eq!(reader.searcher().num_docs(), 0);
writer.add_document(doc!(field=>1u64))?;
let (sender, receiver) = crossbeam_channel::unbounded();
let _handle = index.directory_mut().watch(WatchCallback::new(move || {
let _ = sender.send(());
}));
writer.commit()?;
assert!(receiver.recv().is_ok());
assert_eq!(reader.searcher().num_docs(), 0);
reader.reload()?;
assert_eq!(reader.searcher().num_docs(), 1);
Ok(())
}
#[test]
fn test_index_on_commit_reload_policy_different_directories() -> crate::Result<()> {
let schema = throw_away_schema();
let field = schema.get_field("num_likes").unwrap();
let tempdir = TempDir::new().unwrap();
let tempdir_path = PathBuf::from(tempdir.path());
let write_index = Index::create_in_dir(&tempdir_path, schema).unwrap();
let read_index = Index::open_in_dir(&tempdir_path).unwrap();
let reader = read_index
.reader_builder()
.reload_policy(ReloadPolicy::OnCommit)
.try_into()
.unwrap();
assert_eq!(reader.searcher().num_docs(), 0);
test_index_on_commit_reload_policy_aux(field, &write_index, &reader)
}
}
fn test_index_on_commit_reload_policy_aux(
field: Field,
index: &Index,
reader: &IndexReader,
) -> crate::Result<()> {
let mut reader_index = reader.index();
let (sender, receiver) = crossbeam_channel::unbounded();
let _watch_handle = reader_index
.directory_mut()
.watch(WatchCallback::new(move || {
let _ = sender.send(());
}));
let mut writer = index.writer_for_tests()?;
assert_eq!(reader.searcher().num_docs(), 0);
writer.add_document(doc!(field=>1u64))?;
writer.commit().unwrap();
// We need a loop here because it is possible for notify to send more than
// one modify event. It was observed on CI on MacOS.
loop {
assert!(receiver.recv().is_ok());
if reader.searcher().num_docs() == 1 {
break;
}
}
writer.add_document(doc!(field=>2u64))?;
writer.commit().unwrap();
// ... Same as above
loop {
assert!(receiver.recv().is_ok());
if reader.searcher().num_docs() == 2 {
break;
}
}
Ok(())
}
// This test will not pass on windows, because windows
// prevent deleting files that are MMapped.
#[cfg(not(target_os = "windows"))]
#[test]
fn garbage_collect_works_as_intended() -> crate::Result<()> {
let directory = RamDirectory::create();
let schema = throw_away_schema();
let field = schema.get_field("num_likes").unwrap();
let index = Index::create(directory.clone(), schema, IndexSettings::default())?;
let mut writer = index.writer_with_num_threads(1, 32_000_000).unwrap();
for _seg in 0..8 {
for i in 0u64..1_000u64 {
writer.add_document(doc!(field => i))?;
}
writer.commit()?;
}
let mem_right_after_commit = directory.total_mem_usage();
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()?;
assert_eq!(reader.searcher().num_docs(), 8_000);
assert_eq!(reader.searcher().segment_readers().len(), 8);
writer.wait_merging_threads()?;
let mem_right_after_merge_finished = directory.total_mem_usage();
reader.reload().unwrap();
let searcher = reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
assert_eq!(searcher.num_docs(), 8_000);
assert!(
mem_right_after_merge_finished < mem_right_after_commit,
"(mem after merge){} is expected < (mem before merge){}",
mem_right_after_merge_finished,
mem_right_after_commit
);
Ok(())
}
#[test]
fn test_single_segment_index_writer() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let directory = RamDirectory::default();
let mut single_segment_index_writer = Index::builder()
.schema(schema)
.single_segment_index_writer(directory, 10_000_000)?;
for _ in 0..10 {
let doc = doc!(text_field=>"hello");
single_segment_index_writer.add_document(doc)?;
}
let index = single_segment_index_writer.finalize()?;
let searcher = index.reader()?.searcher();
let term_query = TermQuery::new(
Term::from_field_text(text_field, "hello"),
IndexRecordOption::Basic,
);
let count = searcher.search(&term_query, &Count)?;
assert_eq!(count, 10);
Ok(())
}
#[test]
fn test_merging_segment_update_docfreq() {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let id_field = schema_builder.add_text_field("id", STRING);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut writer = index.writer_for_tests().unwrap();
writer.set_merge_policy(Box::new(NoMergePolicy));
for _ in 0..5 {
writer.add_document(doc!(text_field=>"hello")).unwrap();
}
writer
.add_document(doc!(text_field=>"hello", id_field=>"TO_BE_DELETED"))
.unwrap();
writer
.add_document(doc!(text_field=>"hello", id_field=>"TO_BE_DELETED"))
.unwrap();
writer.add_document(Document::default()).unwrap();
writer.commit().unwrap();
for _ in 0..7 {
writer.add_document(doc!(text_field=>"hello")).unwrap();
}
writer.add_document(Document::default()).unwrap();
writer.add_document(Document::default()).unwrap();
writer.delete_term(Term::from_field_text(id_field, "TO_BE_DELETED"));
writer.commit().unwrap();
let segment_ids: Vec<SegmentId> = index
.list_all_segment_metas()
.into_iter()
.map(|reader| reader.id())
.collect();
writer.merge(&segment_ids[..]).wait().unwrap();
let index_reader = index.reader().unwrap();
let searcher = index_reader.searcher();
assert_eq!(searcher.segment_readers().len(), 1);
assert_eq!(searcher.num_docs(), 15);
let segment_reader = searcher.segment_reader(0);
assert_eq!(segment_reader.max_doc(), 15);
let inv_index = segment_reader.inverted_index(text_field).unwrap();
let term = Term::from_field_text(text_field, "hello");
let term_info = inv_index.get_term_info(&term).unwrap().unwrap();
assert_eq!(term_info.doc_freq, 12);
}

View File

@@ -172,7 +172,7 @@ impl CompositeFile {
let mut fields = Vec::new();
for (&field_addr, byte_range) in &self.offsets_index {
let mut field_usage = FieldUsage::empty(field_addr.field);
field_usage.add_field_idx(field_addr.idx, byte_range.len());
field_usage.add_field_idx(field_addr.idx, byte_range.len().into());
fields.push(field_usage);
}
PerFieldSpaceUsage::new(fields)

Some files were not shown because too many files have changed in this diff Show More