Compare commits

..

14 Commits

Author SHA1 Message Date
Paul Masurel
ab703486aa Updated columnar todo 2023-03-21 18:55:23 +09:00
PSeitz
8f7f1d6be4 add Display for ByteCount (#1949)
* add Display for ByteCount

* export missing AggregationLimits
2023-03-21 08:02:35 +01:00
PSeitz
6a7a1106d6 work in batches of docs (#1937)
* work in batches of docs

* add fill_buffer test
2023-03-21 06:57:44 +01:00
PSeitz
9e2faecf5b add memory limit for aggregations (#1942)
* add memory limit for aggregations

introduce AggregationLimits to set memory consumption limit and bucket limits
memory limit is checked during aggregation, bucket limit is checked before returning the aggregation request.

* Apply suggestions from code review

Co-authored-by: Paul Masurel <paul@quickwit.io>

* add ByteCount with human readable format

---------

Co-authored-by: Paul Masurel <paul@quickwit.io>
2023-03-16 06:21:07 +01:00
PSeitz
b6703f1b3c fix validation in date histogram (#1936)
fix validation in date histogram for parameters interval and date_interval
2023-03-15 06:10:43 +01:00
PSeitz
2fb3740cb0 handle missing column for aggs (#1920)
* handle missing column for aggs

add empty column fallback for missing column in aggs.
Fix sort for term agg on sub-agg with missing value (null is smallest)

* add error when field is not fast
2023-03-15 06:09:59 +01:00
PSeitz
8459efa32c split term collection count and sub_agg (#1921)
use unrolled ColumnValues::get_vals
2023-03-13 04:37:41 +01:00
PSeitz
61cfd8dc57 fix clippy (#1927) 2023-03-13 03:12:02 +01:00
trinity-1686a
064518156f refactor tokenization pipeline to use GATs (#1924)
* refactor tokenization pipeline to use GATs

* fix doctests

* fix clippy lints

* remove commented code
2023-03-09 09:39:37 +01:00
PSeitz
a42a96f470 fix panic in dict column merge (#1930)
* fix panic in dict column merge

* Bugfix and added unit test

---------

Co-authored-by: Paul Masurel <paul@quickwit.io>
2023-03-08 22:04:37 +09:00
trinity-1686a
fcf5a25d93 use DeltaReader directly to implement Dictionnary::ord_to_term (#1928) 2023-03-08 11:15:56 +09:00
dependabot[bot]
c0a5b28fd3 Update lru requirement from 0.9.0 to 0.10.0 (#1932)
Updates the requirements on [lru](https://github.com/jeromefroe/lru-rs) to permit the latest version.
- [Release notes](https://github.com/jeromefroe/lru-rs/releases)
- [Changelog](https://github.com/jeromefroe/lru-rs/blob/master/CHANGELOG.md)
- [Commits](https://github.com/jeromefroe/lru-rs/compare/0.9.0...0.10.0)

---
updated-dependencies:
- dependency-name: lru
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-07 15:09:02 +09:00
trinity-1686a
a4f7ca8309 use DeltaReader directly to implement Dictionnary::term_ord (#1925)
* use DeltaReader directly to implement Dictionnary::term_ord

* add some additional test case for Dictionary::term_ord
2023-03-06 09:45:22 +01:00
Paul Masurel
364e321415 Clippy fix (#1926) 2023-03-06 10:37:17 +09:00
82 changed files with 2428 additions and 1373 deletions

2
.gitignore vendored
View File

@@ -13,3 +13,5 @@ benchmark
.idea
trace.dat
cargo-timing*
control
variable

View File

@@ -16,6 +16,7 @@ rust-version = "1.62"
[dependencies]
oneshot = "0.1.5"
base64 = "0.21.0"
byteorder = "1.4.3"
crc32fast = "1.3.2"
once_cell = "1.10.0"
regex = { version = "1.5.5", default-features = false, features = ["std", "unicode"] }
@@ -47,7 +48,7 @@ murmurhash32 = "0.3.0"
time = { version = "0.3.10", features = ["serde-well-known"] }
smallvec = "1.8.0"
rayon = "1.5.2"
lru = "0.9.0"
lru = "0.10.0"
fastdivide = "0.4.0"
itertools = "0.10.3"
measure_time = "0.8.2"

View File

@@ -1,28 +1,22 @@
# zero to one
* revisit line codec
* add columns from schema on merge
* Plugging JSON
* replug examples
* move datetime to quickwit common
* switch to nanos
* reintroduce the gcd map.
# Perf and Size
* remove alloc in `ord_to_term`
+ multivaued range queries restrat frm the beginning all of the time.
* re-add ZSTD compression for dictionaries
no systematic monotonic mapping
consider removing multilinear
f32?
adhoc solution for bool?
add metrics helper for aggregate. sum(row_id)
review inline absence/presence
improv perf of select using PDEP
compare with roaring bitmap/elias fano etc etc.
SIMD range? (see blog post)
Add alignment?
Consider another codec to bridge the gap between few and 5k elements
* no systematic monotonic mapping
* consider removing multilinear
* f32?
* adhoc solution for bool?
* add metrics helper for aggregate. sum(row_id)
* review inline absence/presence
* improv perf of select using PDEP
* compare with roaring bitmap/elias fano etc etc.
* SIMD range? (see blog post)
* Add alignment?
* Consider another codec to bridge the gap between few and 5k elements
# Cleanup and rationalization
in benchmark, unify percent vs ratio, f32 vs f64.
@@ -30,15 +24,10 @@ investigate if should have better errors? io::Error is overused at the moment.
rename rank/select in unit tests
Review the public API via cargo doc
go through TODOs
remove all doc_id occurences -> row_id
use the rank & select naming in unit tests branch.
multi-linear -> blockwise
linear codec -> simply a multiplication for the index column
rename columnar to something more explicit, like column_dictionary or columnar_table
rename fastfield -> column
document changes
rationalization FastFieldValue, HasColumnType
isolate u128_based and uniform naming
# Other
fix enhance column-cli

View File

@@ -58,10 +58,21 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync {
/// # Panics
///
/// May panic if `idx` is greater than the column length.
fn get_vals(&self, idx: &[u32], output: &mut [T]) {
assert!(idx.len() == output.len());
for (out, idx) in output.iter_mut().zip(idx.iter()) {
*out = self.get_val(*idx as u32);
fn get_vals(&self, indexes: &[u32], output: &mut [T]) {
assert!(indexes.len() == output.len());
let out_and_idx_chunks = output.chunks_exact_mut(4).zip(indexes.chunks_exact(4));
for (out_x4, idx_x4) in out_and_idx_chunks {
out_x4[0] = self.get_val(idx_x4[0]);
out_x4[1] = self.get_val(idx_x4[1]);
out_x4[2] = self.get_val(idx_x4[2]);
out_x4[3] = self.get_val(idx_x4[3]);
}
let step_size = 4;
let cutoff = indexes.len() - indexes.len() % step_size;
for idx in cutoff..indexes.len() {
output[idx] = self.get_val(indexes[idx]);
}
}

View File

@@ -50,7 +50,7 @@ where
Input: PartialOrd + Send + Debug + Sync + Clone,
Output: PartialOrd + Send + Debug + Sync + Clone,
{
#[inline]
#[inline(always)]
fn get_val(&self, idx: u32) -> Output {
let from_val = self.from_column.get_val(idx);
self.monotonic_mapping.mapping(from_val)

View File

@@ -1,6 +1,6 @@
use proptest::prelude::*;
use proptest::strategy::Strategy;
use proptest::{num, prop_oneof, proptest};
use proptest::{prop_oneof, proptest};
#[test]
fn test_serialize_and_load_simple() {
@@ -99,14 +99,28 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
let reader = TColumnCodec::load(OwnedBytes::new(buffer)).unwrap();
assert_eq!(reader.num_vals(), vals.len() as u32);
let mut buffer = Vec::new();
for (doc, orig_val) in vals.iter().copied().enumerate() {
let val = reader.get_val(doc as u32);
assert_eq!(
val, orig_val,
"val `{val}` does not match orig_val {orig_val:?}, in data set {name}, data `{vals:?}`",
);
buffer.resize(1, 0);
reader.get_vals(&[doc as u32], &mut buffer);
let val = buffer[0];
assert_eq!(
val, orig_val,
"val `{val}` does not match orig_val {orig_val:?}, in data set {name}, data `{vals:?}`",
);
}
let all_docs: Vec<u32> = (0..vals.len() as u32).collect();
buffer.resize(all_docs.len(), 0);
reader.get_vals(&all_docs, &mut buffer);
assert_eq!(vals, buffer);
if !vals.is_empty() {
let test_rand_idx = rand::thread_rng().gen_range(0..=vals.len() - 1);
let expected_positions: Vec<u32> = vals

View File

@@ -52,21 +52,18 @@ impl<'a> Iterable for RemappedTermOrdinalsValues<'a> {
impl<'a> RemappedTermOrdinalsValues<'a> {
fn boxed_iter_stacked(&self) -> Box<dyn Iterator<Item = u64> + '_> {
let iter = self
.bytes_columns
.iter()
.enumerate()
.flat_map(|(segment_ord, byte_column)| {
let segment_ord = self.term_ord_mapping.get_segment(segment_ord as u32);
byte_column.iter().flat_map(move |bytes_column| {
bytes_column
.ords()
.values
.iter()
.map(move |term_ord| segment_ord[term_ord as usize])
})
});
// TODO see if we can better decompose the mapping / and the stacking
let iter = self.bytes_columns.iter().flatten().enumerate().flat_map(
move |(seg_ord_with_column, bytes_column)| {
let term_ord_after_merge_mapping = self
.term_ord_mapping
.get_segment(seg_ord_with_column as u32);
bytes_column
.ords()
.values
.iter()
.map(move |term_ord| term_ord_after_merge_mapping[term_ord as usize])
},
);
Box::new(iter)
}
@@ -133,7 +130,6 @@ fn serialize_merged_dict(
let mut merged_terms = TermMerger::new(field_term_streams);
let mut sstable_builder = sstable::VoidSSTable::writer(output);
// TODO support complex `merge_row_order`.
match merge_row_order {
MergeRowOrder::Stack(_) => {
let mut current_term_ord = 0;

View File

@@ -153,20 +153,24 @@ fn make_numerical_columnar_multiple_columns(
ColumnarReader::open(buffer).unwrap()
}
fn make_byte_columnar_multiple_columns(columns: &[(&str, &[&[&[u8]]])]) -> ColumnarReader {
#[track_caller]
fn make_byte_columnar_multiple_columns(
columns: &[(&str, &[&[&[u8]]])],
num_rows: u32,
) -> ColumnarReader {
let mut dataframe_writer = ColumnarWriter::default();
for (column_name, column_values) in columns {
assert_eq!(
column_values.len(),
num_rows as usize,
"All columns must have `{num_rows}` rows"
);
for (row_id, vals) in column_values.iter().enumerate() {
for val in vals.iter() {
dataframe_writer.record_bytes(row_id as u32, column_name, val);
}
}
}
let num_rows = columns
.iter()
.map(|(_, val_rows)| val_rows.len() as RowId)
.max()
.unwrap_or(0u32);
let mut buffer: Vec<u8> = Vec::new();
dataframe_writer
.serialize(num_rows, None, &mut buffer)
@@ -272,8 +276,8 @@ fn test_merge_columnar_texts() {
#[test]
fn test_merge_columnar_byte() {
let columnar1 = make_byte_columnar_multiple_columns(&[("bytes", &[&[b"bbbb"], &[b"baaa"]])]);
let columnar2 = make_byte_columnar_multiple_columns(&[("bytes", &[&[], &[b"a"]])]);
let columnar1 = make_byte_columnar_multiple_columns(&[("bytes", &[&[b"bbbb"], &[b"baaa"]])], 2);
let columnar2 = make_byte_columnar_multiple_columns(&[("bytes", &[&[], &[b"a"]])], 2);
let mut buffer = Vec::new();
let columnars = &[&columnar1, &columnar2];
let stack_merge_order = StackMergeOrder::stack(columnars);
@@ -316,3 +320,59 @@ fn test_merge_columnar_byte() {
assert_eq!(get_bytes_for_row(2), b"");
assert_eq!(get_bytes_for_row(3), b"a");
}
#[test]
fn test_merge_columnar_byte_with_missing() {
let columnar1 = make_byte_columnar_multiple_columns(&[], 3);
let columnar2 = make_byte_columnar_multiple_columns(&[("col", &[&[b"b"], &[]])], 2);
let columnar3 = make_byte_columnar_multiple_columns(
&[
("col", &[&[], &[b"b"], &[b"a", b"b"]]),
("col2", &[&[b"hello"], &[], &[b"a", b"b"]]),
],
3,
);
let mut buffer = Vec::new();
let columnars = &[&columnar1, &columnar2, &columnar3];
let stack_merge_order = StackMergeOrder::stack(columnars);
crate::columnar::merge_columnar(
columnars,
&[],
MergeRowOrder::Stack(stack_merge_order),
&mut buffer,
)
.unwrap();
let columnar_reader = ColumnarReader::open(buffer).unwrap();
assert_eq!(columnar_reader.num_rows(), 3 + 2 + 3);
assert_eq!(columnar_reader.num_columns(), 2);
let cols = columnar_reader.read_columns("col").unwrap();
let dynamic_column = cols[0].open().unwrap();
let DynamicColumn::Bytes(vals) = dynamic_column else { panic!() };
let get_bytes_for_ord = |ord| {
let mut out = Vec::new();
vals.ord_to_bytes(ord, &mut out).unwrap();
out
};
assert_eq!(vals.dictionary.num_terms(), 2);
assert_eq!(get_bytes_for_ord(0), b"a");
assert_eq!(get_bytes_for_ord(1), b"b");
let get_bytes_for_row = |row_id| {
let terms: Vec<Vec<u8>> = vals
.term_ords(row_id)
.map(|term_ord| {
let mut out = Vec::new();
vals.ord_to_bytes(term_ord, &mut out).unwrap();
out
})
.collect();
terms
};
assert!(get_bytes_for_row(0).is_empty());
assert!(get_bytes_for_row(1).is_empty());
assert!(get_bytes_for_row(2).is_empty());
assert_eq!(get_bytes_for_row(3), vec![b"b".to_vec()]);
assert!(get_bytes_for_row(4).is_empty());
assert!(get_bytes_for_row(5).is_empty());
assert_eq!(get_bytes_for_row(6), vec![b"b".to_vec()]);
assert_eq!(get_bytes_for_row(7), vec![b"a".to_vec(), b"b".to_vec()]);
}

View File

@@ -3,7 +3,7 @@ use std::net::Ipv6Addr;
use std::sync::Arc;
use common::file_slice::FileSlice;
use common::{DateTime, HasLen, OwnedBytes};
use common::{ByteCount, DateTime, HasLen, OwnedBytes};
use crate::column::{BytesColumn, Column, StrColumn};
use crate::column_values::{monotonic_map_column, StrictlyMonotonicFn};
@@ -248,8 +248,8 @@ impl DynamicColumnHandle {
Ok(dynamic_column)
}
pub fn num_bytes(&self) -> usize {
self.file_slice.len()
pub fn num_bytes(&self) -> ByteCount {
self.file_slice.len().into()
}
pub fn column_type(&self) -> ColumnType {

View File

@@ -13,6 +13,7 @@ repository = "https://github.com/quickwit-oss/tantivy"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
byteorder = "1.4.3"
ownedbytes = { version= "0.5", path="../ownedbytes" }
async-trait = "0.1"
time = { version = "0.3.10", features = ["serde-well-known"] }

View File

@@ -4,6 +4,8 @@ use std::{fmt, io, u64};
use ownedbytes::OwnedBytes;
use crate::ByteCount;
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct TinySet(u64);
@@ -386,8 +388,8 @@ impl ReadOnlyBitSet {
}
/// Number of bytes used in the bitset representation.
pub fn num_bytes(&self) -> usize {
self.data.len()
pub fn num_bytes(&self) -> ByteCount {
self.data.len().into()
}
}

114
common/src/byte_count.rs Normal file
View File

@@ -0,0 +1,114 @@
use std::iter::Sum;
use std::ops::{Add, AddAssign};
use serde::{Deserialize, Serialize};
/// Indicates space usage in bytes
#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ByteCount(u64);
impl std::fmt::Debug for ByteCount {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.human_readable())
}
}
impl std::fmt::Display for ByteCount {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.human_readable())
}
}
const SUFFIX_AND_THRESHOLD: [(&str, u64); 5] = [
("KB", 1_000),
("MB", 1_000_000),
("GB", 1_000_000_000),
("TB", 1_000_000_000_000),
("PB", 1_000_000_000_000_000),
];
impl ByteCount {
#[inline]
pub fn get_bytes(&self) -> u64 {
self.0
}
pub fn human_readable(&self) -> String {
for (suffix, threshold) in SUFFIX_AND_THRESHOLD.iter().rev() {
if self.get_bytes() >= *threshold {
let unit_num = self.get_bytes() as f64 / *threshold as f64;
return format!("{:.2} {}", unit_num, suffix);
}
}
format!("{:.2} B", self.get_bytes())
}
}
impl From<u64> for ByteCount {
fn from(value: u64) -> Self {
ByteCount(value)
}
}
impl From<usize> for ByteCount {
fn from(value: usize) -> Self {
ByteCount(value as u64)
}
}
impl Sum for ByteCount {
#[inline]
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(ByteCount::default(), |acc, x| acc + x)
}
}
impl PartialEq<u64> for ByteCount {
#[inline]
fn eq(&self, other: &u64) -> bool {
self.get_bytes() == *other
}
}
impl PartialOrd<u64> for ByteCount {
#[inline]
fn partial_cmp(&self, other: &u64) -> Option<std::cmp::Ordering> {
self.get_bytes().partial_cmp(other)
}
}
impl Add for ByteCount {
type Output = Self;
#[inline]
fn add(self, other: Self) -> Self {
Self(self.get_bytes() + other.get_bytes())
}
}
impl AddAssign for ByteCount {
#[inline]
fn add_assign(&mut self, other: Self) {
*self = Self(self.get_bytes() + other.get_bytes());
}
}
#[cfg(test)]
mod test {
use crate::ByteCount;
#[test]
fn test_bytes() {
assert_eq!(ByteCount::from(0u64).human_readable(), "0 B");
assert_eq!(ByteCount::from(300u64).human_readable(), "300 B");
assert_eq!(ByteCount::from(1_000_000u64).human_readable(), "1.00 MB");
assert_eq!(ByteCount::from(1_500_000u64).human_readable(), "1.50 MB");
assert_eq!(
ByteCount::from(1_500_000_000u64).human_readable(),
"1.50 GB"
);
assert_eq!(
ByteCount::from(3_213_000_000_000u64).human_readable(),
"3.21 TB"
);
}
}

View File

@@ -5,7 +5,7 @@ use std::{fmt, io};
use async_trait::async_trait;
use ownedbytes::{OwnedBytes, StableDeref};
use crate::HasLen;
use crate::{ByteCount, HasLen};
/// Objects that represents files sections in tantivy.
///
@@ -216,6 +216,11 @@ impl FileSlice {
pub fn slice_to(&self, to_offset: usize) -> FileSlice {
self.slice(0..to_offset)
}
/// Returns the byte count of the FileSlice.
pub fn num_bytes(&self) -> ByteCount {
self.range.len().into()
}
}
#[async_trait]

View File

@@ -2,7 +2,10 @@
use std::ops::Deref;
pub use byteorder::LittleEndian as Endianness;
mod bitset;
mod byte_count;
mod datetime;
pub mod file_slice;
mod group_by;
@@ -10,6 +13,7 @@ mod serialize;
mod vint;
mod writer;
pub use bitset::*;
pub use byte_count::ByteCount;
pub use datetime::{DatePrecision, DateTime};
pub use group_by::GroupByIteratorExtended;
pub use ownedbytes::{OwnedBytes, StableDeref};

View File

@@ -1,7 +1,9 @@
use std::io::{Read, Write};
use std::{fmt, io};
use crate::VInt;
use byteorder::{ReadBytesExt, WriteBytesExt};
use crate::{Endianness, VInt};
#[derive(Default)]
struct Counter(u64);
@@ -105,13 +107,11 @@ impl<Left: BinarySerializable + FixedSize, Right: BinarySerializable + FixedSize
impl BinarySerializable for u32 {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
writer.write_all(&self.to_le_bytes())
writer.write_u32::<Endianness>(*self)
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<u32> {
let mut buf = [0u8; 4];
reader.read_exact(&mut buf)?;
Ok(u32::from_le_bytes(buf))
reader.read_u32::<Endianness>()
}
}
@@ -121,13 +121,11 @@ impl FixedSize for u32 {
impl BinarySerializable for u16 {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
writer.write_all(&self.to_le_bytes())
writer.write_u16::<Endianness>(*self)
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<u16> {
let mut buf = [0u8; 2];
reader.read_exact(&mut buf)?;
Ok(Self::from_le_bytes(buf))
reader.read_u16::<Endianness>()
}
}
@@ -137,12 +135,10 @@ impl FixedSize for u16 {
impl BinarySerializable for u64 {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
writer.write_all(&self.to_le_bytes())
writer.write_u64::<Endianness>(*self)
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
let mut buf = [0u8; 8];
reader.read_exact(&mut buf)?;
Ok(Self::from_le_bytes(buf))
reader.read_u64::<Endianness>()
}
}
@@ -152,12 +148,10 @@ impl FixedSize for u64 {
impl BinarySerializable for u128 {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
writer.write_all(&self.to_le_bytes())
writer.write_u128::<Endianness>(*self)
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
let mut buf = [0u8; 16];
reader.read_exact(&mut buf)?;
Ok(Self::from_le_bytes(buf))
reader.read_u128::<Endianness>()
}
}
@@ -167,12 +161,10 @@ impl FixedSize for u128 {
impl BinarySerializable for f32 {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
writer.write_all(&self.to_le_bytes())
writer.write_f32::<Endianness>(*self)
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
let mut buf = [0u8; 4];
reader.read_exact(&mut buf)?;
Ok(Self::from_le_bytes(buf))
reader.read_f32::<Endianness>()
}
}
@@ -182,12 +174,10 @@ impl FixedSize for f32 {
impl BinarySerializable for i64 {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
writer.write_all(&self.to_le_bytes())
writer.write_i64::<Endianness>(*self)
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
let mut buf = [0u8; Self::SIZE_IN_BYTES];
reader.read_exact(&mut buf)?;
Ok(Self::from_le_bytes(buf))
reader.read_i64::<Endianness>()
}
}
@@ -197,12 +187,10 @@ impl FixedSize for i64 {
impl BinarySerializable for f64 {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
writer.write_all(&self.to_le_bytes())
writer.write_f64::<Endianness>(*self)
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
let mut buf = [0u8; Self::SIZE_IN_BYTES];
reader.read_exact(&mut buf)?;
Ok(Self::from_le_bytes(buf))
reader.read_f64::<Endianness>()
}
}
@@ -212,12 +200,10 @@ impl FixedSize for f64 {
impl BinarySerializable for u8 {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
writer.write_all(&self.to_le_bytes())
writer.write_u8(*self)
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> {
let mut buf = [0u8; Self::SIZE_IN_BYTES];
reader.read_exact(&mut buf)?;
Ok(Self::from_le_bytes(buf))
fn deserialize<R: Read>(reader: &mut R) -> io::Result<u8> {
reader.read_u8()
}
}
@@ -227,10 +213,10 @@ impl FixedSize for u8 {
impl BinarySerializable for bool {
fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> {
(*self as u8).serialize(writer)
writer.write_u8(u8::from(*self))
}
fn deserialize<R: Read>(reader: &mut R) -> io::Result<bool> {
let val = u8::deserialize(reader)?;
let val = reader.read_u8()?;
match val {
0 => Ok(false),
1 => Ok(true),

View File

@@ -1,6 +1,8 @@
use std::io;
use std::io::{Read, Write};
use byteorder::{ByteOrder, LittleEndian};
use super::BinarySerializable;
/// Variable int serializes a u128 number
@@ -125,7 +127,7 @@ pub fn serialize_vint_u32(val: u32, buf: &mut [u8; 8]) -> &[u8] {
5,
),
};
buf.copy_from_slice(&res.to_le_bytes());
LittleEndian::write_u64(&mut buf[..], res);
&buf[0..num_bytes]
}

View File

@@ -192,7 +192,7 @@ fn main() -> tantivy::Result<()> {
//
let agg_req: Aggregations = serde_json::from_str(agg_req_str)?;
let collector = AggregationCollector::from_aggs(agg_req, None);
let collector = AggregationCollector::from_aggs(agg_req, Default::default());
let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
let res2: Value = serde_json::to_value(agg_res)?;
@@ -204,7 +204,7 @@ fn main() -> tantivy::Result<()> {
let agg_req: Aggregations = vec![(
"group_by_stock".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "stock".to_string(),
ranges: vec![
@@ -234,12 +234,12 @@ fn main() -> tantivy::Result<()> {
)]
.into_iter()
.collect(),
}),
})),
)]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req, None);
let collector = AggregationCollector::from_aggs(agg_req, Default::default());
// We use the `AllQuery` which will pass all documents to the AggregationCollector.
let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
@@ -287,7 +287,7 @@ fn main() -> tantivy::Result<()> {
let agg_req: Aggregations = serde_json::from_str(agg_req_str)?;
let collector = AggregationCollector::from_aggs(agg_req, None);
let collector = AggregationCollector::from_aggs(agg_req, Default::default());
let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
let res: Value = serde_json::to_value(agg_res)?;

View File

@@ -12,7 +12,7 @@
use tantivy::collector::{Count, TopDocs};
use tantivy::query::TermQuery;
use tantivy::schema::*;
use tantivy::tokenizer::{PreTokenizedString, SimpleTokenizer, Token, Tokenizer};
use tantivy::tokenizer::{PreTokenizedString, SimpleTokenizer, Token, TokenStream, Tokenizer};
use tantivy::{doc, Index, ReloadPolicy};
use tempfile::TempDir;

View File

@@ -50,12 +50,13 @@ fn main() -> tantivy::Result<()> {
// This tokenizer lowers all of the text (to help with stop word matching)
// then removes all instances of `the` and `and` from the corpus
let tokenizer = TextAnalyzer::from(SimpleTokenizer)
let tokenizer = TextAnalyzer::builder(SimpleTokenizer)
.filter(LowerCaser)
.filter(StopWordFilter::remove(vec![
"the".to_string(),
"and".to_string(),
]));
]))
.build();
index.tokenizers().register("stoppy", tokenizer);

View File

@@ -0,0 +1,94 @@
use std::collections::HashMap;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use common::ByteCount;
use super::collector::DEFAULT_MEMORY_LIMIT;
use super::{AggregationError, DEFAULT_BUCKET_LIMIT};
use crate::TantivyError;
/// An estimate for memory consumption
pub trait MemoryConsumption {
fn memory_consumption(&self) -> usize;
}
impl<K, V, S> MemoryConsumption for HashMap<K, V, S> {
fn memory_consumption(&self) -> usize {
let num_items = self.capacity();
(std::mem::size_of::<K>() + std::mem::size_of::<V>()) * num_items
}
}
/// Aggregation memory limit after which the request fails. Defaults to DEFAULT_MEMORY_LIMIT
/// (500MB). The limit is shared by all SegmentCollectors
pub struct AggregationLimits {
/// The counter which is shared between the aggregations for one request.
memory_consumption: Arc<AtomicU64>,
/// The memory_limit in bytes
memory_limit: ByteCount,
/// The maximum number of buckets _returned_
/// This is not counting intermediate buckets.
bucket_limit: u32,
}
impl Clone for AggregationLimits {
fn clone(&self) -> Self {
Self {
memory_consumption: Arc::clone(&self.memory_consumption),
memory_limit: self.memory_limit,
bucket_limit: self.bucket_limit,
}
}
}
impl Default for AggregationLimits {
fn default() -> Self {
Self {
memory_consumption: Default::default(),
memory_limit: DEFAULT_MEMORY_LIMIT.into(),
bucket_limit: DEFAULT_BUCKET_LIMIT,
}
}
}
impl AggregationLimits {
/// *memory_limit*
/// memory_limit is defined in bytes.
/// Aggregation fails when the estimated memory consumption of the aggregation is higher than
/// memory_limit.
/// memory_limit will default to `DEFAULT_MEMORY_LIMIT` (500MB)
///
/// *bucket_limit*
/// Limits the maximum number of buckets returned from an aggregation request.
/// bucket_limit will default to `DEFAULT_BUCKET_LIMIT` (65000)
pub fn new(memory_limit: Option<u64>, bucket_limit: Option<u32>) -> Self {
Self {
memory_consumption: Default::default(),
memory_limit: memory_limit.unwrap_or(DEFAULT_MEMORY_LIMIT).into(),
bucket_limit: bucket_limit.unwrap_or(DEFAULT_BUCKET_LIMIT),
}
}
pub(crate) fn validate_memory_consumption(&self) -> crate::Result<()> {
if self.get_memory_consumed() > self.memory_limit {
return Err(TantivyError::AggregationError(
AggregationError::MemoryExceeded {
limit: self.memory_limit,
current: self.get_memory_consumed(),
},
));
}
Ok(())
}
pub(crate) fn add_memory_consumed(&self, num_bytes: u64) {
self.memory_consumption
.fetch_add(num_bytes, std::sync::atomic::Ordering::Relaxed);
}
pub fn get_memory_consumed(&self) -> ByteCount {
self.memory_consumption
.load(std::sync::atomic::Ordering::Relaxed)
.into()
}
pub fn get_bucket_limit(&self) -> u32 {
self.bucket_limit
}
}

View File

@@ -16,14 +16,14 @@
//! let agg_req1: Aggregations = vec![
//! (
//! "range".to_string(),
//! Aggregation::Bucket(BucketAggregation {
//! Aggregation::Bucket(Box::new(BucketAggregation {
//! bucket_agg: BucketAggregationType::Range(RangeAggregation{
//! field: "score".to_string(),
//! ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
//! keyed: false,
//! }),
//! sub_aggregation: Default::default(),
//! }),
//! })),
//! ),
//! ]
//! .into_iter()
@@ -143,7 +143,7 @@ pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet<String> {
#[serde(untagged)]
pub enum Aggregation {
/// Bucket aggregation, see [`BucketAggregation`] for details.
Bucket(BucketAggregation),
Bucket(Box<BucketAggregation>),
/// Metric aggregation, see [`MetricAggregation`] for details.
Metric(MetricAggregation),
}
@@ -301,7 +301,7 @@ mod tests {
fn serialize_to_json_test() {
let agg_req1: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score".to_string(),
ranges: vec![
@@ -313,7 +313,7 @@ mod tests {
keyed: true,
}),
sub_aggregation: Default::default(),
}),
})),
)]
.into_iter()
.collect();
@@ -351,7 +351,7 @@ mod tests {
let agg_req2: Aggregations = vec![
(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score2".to_string(),
ranges: vec![
@@ -363,7 +363,7 @@ mod tests {
..Default::default()
}),
sub_aggregation: Default::default(),
}),
})),
),
(
"metric".to_string(),
@@ -377,7 +377,7 @@ mod tests {
let agg_req1: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score".to_string(),
ranges: vec![
@@ -389,7 +389,7 @@ mod tests {
..Default::default()
}),
sub_aggregation: agg_req2,
}),
})),
)]
.into_iter()
.collect();

View File

@@ -1,9 +1,8 @@
//! This will enhance the request tree with access to the fastfield and metadata.
use std::rc::Rc;
use std::sync::atomic::AtomicU32;
use std::sync::Arc;
use columnar::{Column, ColumnType, StrColumn};
use columnar::{Column, ColumnType, ColumnValues, StrColumn};
use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation};
use super::bucket::{
@@ -13,9 +12,9 @@ use super::metric::{
AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, StatsAggregation,
SumAggregation,
};
use super::segment_agg_result::BucketCount;
use super::segment_agg_result::AggregationLimits;
use super::VecWithNames;
use crate::{SegmentReader, TantivyError};
use crate::SegmentReader;
#[derive(Clone, Default)]
pub(crate) struct AggregationsWithAccessor {
@@ -45,7 +44,7 @@ pub struct BucketAggregationWithAccessor {
pub(crate) field_type: ColumnType,
pub(crate) bucket_agg: BucketAggregationType,
pub(crate) sub_aggregation: AggregationsWithAccessor,
pub(crate) bucket_count: BucketCount,
pub(crate) limits: AggregationLimits,
}
impl BucketAggregationWithAccessor {
@@ -53,8 +52,7 @@ impl BucketAggregationWithAccessor {
bucket: &BucketAggregationType,
sub_aggregation: &Aggregations,
reader: &SegmentReader,
bucket_count: Rc<AtomicU32>,
max_bucket_count: u32,
limits: AggregationLimits,
) -> crate::Result<BucketAggregationWithAccessor> {
let mut str_dict_column = None;
let (accessor, field_type) = match &bucket {
@@ -82,15 +80,11 @@ impl BucketAggregationWithAccessor {
sub_aggregation: get_aggs_with_accessor_and_validate(
&sub_aggregation,
reader,
bucket_count.clone(),
max_bucket_count,
&limits.clone(),
)?,
bucket_agg: bucket.clone(),
str_dict_column,
bucket_count: BucketCount {
bucket_count,
max_bucket_count,
},
limits,
})
}
}
@@ -130,8 +124,7 @@ impl MetricAggregationWithAccessor {
pub(crate) fn get_aggs_with_accessor_and_validate(
aggs: &Aggregations,
reader: &SegmentReader,
bucket_count: Rc<AtomicU32>,
max_bucket_count: u32,
limits: &AggregationLimits,
) -> crate::Result<AggregationsWithAccessor> {
let mut metrics = vec![];
let mut buckets = vec![];
@@ -143,8 +136,7 @@ pub(crate) fn get_aggs_with_accessor_and_validate(
&bucket.bucket_agg,
&bucket.sub_aggregation,
reader,
Rc::clone(&bucket_count),
max_bucket_count,
limits.clone(),
)?,
)),
Aggregation::Metric(metric) => metrics.push((
@@ -167,8 +159,31 @@ fn get_ff_reader_and_validate(
let ff_fields = reader.fast_fields();
let ff_field_with_type = ff_fields
.u64_lenient_with_type(field_name)?
.ok_or_else(|| {
TantivyError::InvalidArgument(format!("No fast field found for field: {}", field_name))
})?;
.unwrap_or_else(|| (build_empty_column(reader.num_docs()), ColumnType::U64));
Ok(ff_field_with_type)
}
// Empty Column
fn build_empty_column(num_docs: u32) -> Column {
struct EmptyValues;
impl ColumnValues for EmptyValues {
fn get_val(&self, _idx: u32) -> u64 {
unimplemented!("Internal Error: Called get_val of empty column.")
}
fn min_value(&self) -> u64 {
unimplemented!("Internal Error: Called min_value of empty column.")
}
fn max_value(&self) -> u64 {
unimplemented!("Internal Error: Called max_value of empty column.")
}
fn num_vals(&self) -> u32 {
0
}
}
Column {
idx: columnar::ColumnIndex::Empty { num_docs },
values: Arc::new(EmptyValues),
}
}

View File

@@ -11,6 +11,7 @@ use super::agg_req::BucketAggregationInternal;
use super::bucket::GetDocCount;
use super::intermediate_agg_result::{IntermediateBucketResult, IntermediateMetricResult};
use super::metric::{SingleMetricResult, Stats};
use super::segment_agg_result::AggregationLimits;
use super::Key;
use crate::TantivyError;
@@ -19,6 +20,13 @@ use crate::TantivyError;
pub struct AggregationResults(pub FxHashMap<String, AggregationResult>);
impl AggregationResults {
pub(crate) fn get_bucket_count(&self) -> u64 {
self.0
.values()
.map(|agg| agg.get_bucket_count())
.sum::<u64>()
}
pub(crate) fn get_value_from_aggregation(
&self,
name: &str,
@@ -47,6 +55,13 @@ pub enum AggregationResult {
}
impl AggregationResult {
pub(crate) fn get_bucket_count(&self) -> u64 {
match self {
AggregationResult::BucketResult(bucket) => bucket.get_bucket_count(),
AggregationResult::MetricResult(_) => 0,
}
}
pub(crate) fn get_value_from_aggregation(
&self,
_name: &str,
@@ -153,9 +168,28 @@ pub enum BucketResult {
}
impl BucketResult {
pub(crate) fn empty_from_req(req: &BucketAggregationInternal) -> crate::Result<Self> {
pub(crate) fn get_bucket_count(&self) -> u64 {
match self {
BucketResult::Range { buckets } => {
buckets.iter().map(|bucket| bucket.get_bucket_count()).sum()
}
BucketResult::Histogram { buckets } => {
buckets.iter().map(|bucket| bucket.get_bucket_count()).sum()
}
BucketResult::Terms {
buckets,
sum_other_doc_count: _,
doc_count_error_upper_bound: _,
} => buckets.iter().map(|bucket| bucket.get_bucket_count()).sum(),
}
}
pub(crate) fn empty_from_req(
req: &BucketAggregationInternal,
limits: &AggregationLimits,
) -> crate::Result<Self> {
let empty_bucket = IntermediateBucketResult::empty_from_req(&req.bucket_agg);
empty_bucket.into_final_bucket_result(req)
empty_bucket.into_final_bucket_result(req, limits)
}
}
@@ -170,6 +204,15 @@ pub enum BucketEntries<T> {
HashMap(FxHashMap<String, T>),
}
impl<T> BucketEntries<T> {
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = &T> + 'a> {
match self {
BucketEntries::Vec(vec) => Box::new(vec.iter()),
BucketEntries::HashMap(map) => Box::new(map.values()),
}
}
}
/// This is the default entry for a bucket, which contains a key, count, and optionally
/// sub-aggregations.
///
@@ -209,6 +252,11 @@ pub struct BucketEntry {
/// Sub-aggregations in this bucket.
pub sub_aggregation: AggregationResults,
}
impl BucketEntry {
pub(crate) fn get_bucket_count(&self) -> u64 {
1 + self.sub_aggregation.get_bucket_count()
}
}
impl GetDocCount for &BucketEntry {
fn doc_count(&self) -> u64 {
self.doc_count
@@ -272,3 +320,8 @@ pub struct RangeBucketEntry {
#[serde(skip_serializing_if = "Option::is_none")]
pub to_as_string: Option<String>,
}
impl RangeBucketEntry {
pub(crate) fn get_bucket_count(&self) -> u64 {
1 + self.sub_aggregation.get_bucket_count()
}
}

View File

@@ -9,11 +9,12 @@ use crate::aggregation::buf_collector::DOC_BLOCK_SIZE;
use crate::aggregation::collector::AggregationCollector;
use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
use crate::aggregation::metric::AverageAggregation;
use crate::aggregation::segment_agg_result::AggregationLimits;
use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms};
use crate::aggregation::DistributedAggregationCollector;
use crate::query::{AllQuery, TermQuery};
use crate::schema::IndexRecordOption;
use crate::Term;
use crate::schema::{IndexRecordOption, Schema, FAST};
use crate::{Index, Term};
fn get_avg_req(field_name: &str) -> Aggregation {
Aggregation::Metric(MetricAggregation::Average(
@@ -21,6 +22,10 @@ fn get_avg_req(field_name: &str) -> Aggregation {
))
}
fn get_collector(agg_req: Aggregations) -> AggregationCollector {
AggregationCollector::from_aggs(agg_req, Default::default())
}
// *** EVERY BUCKET-TYPE SHOULD BE TESTED HERE ***
fn test_aggregation_flushing(
merge_segments: bool,
@@ -98,15 +103,18 @@ fn test_aggregation_flushing(
.unwrap();
let agg_res: AggregationResults = if use_distributed_collector {
let collector = DistributedAggregationCollector::from_aggs(agg_req.clone(), None);
let collector = DistributedAggregationCollector::from_aggs(
agg_req.clone(),
AggregationLimits::default(),
);
let searcher = reader.searcher();
let intermediate_agg_result = searcher.search(&AllQuery, &collector).unwrap();
intermediate_agg_result
.into_final_bucket_result(agg_req)
.into_final_bucket_result(agg_req, &Default::default())
.unwrap()
} else {
let collector = AggregationCollector::from_aggs(agg_req, None);
let collector = get_collector(agg_req);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
@@ -208,42 +216,42 @@ fn test_aggregation_level1() -> crate::Result<()> {
("average".to_string(), get_avg_req("score")),
(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score".to_string(),
ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
..Default::default()
}),
sub_aggregation: Default::default(),
}),
})),
),
(
"rangef64".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_f64".to_string(),
ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
..Default::default()
}),
sub_aggregation: Default::default(),
}),
})),
),
(
"rangei64".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_i64".to_string(),
ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
..Default::default()
}),
sub_aggregation: Default::default(),
}),
})),
),
]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();
@@ -308,13 +316,13 @@ fn test_aggregation_level2(
("average_in_range".to_string(), get_avg_req("score")),
(
"term_agg".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
})),
),
]
.into_iter()
@@ -382,7 +390,7 @@ fn test_aggregation_level2(
("average".to_string(), get_avg_req("score")),
(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score".to_string(),
ranges: vec![
@@ -393,11 +401,11 @@ fn test_aggregation_level2(
..Default::default()
}),
sub_aggregation: sub_agg_req.clone(),
}),
})),
),
(
"rangef64".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_f64".to_string(),
ranges: vec![
@@ -408,11 +416,11 @@ fn test_aggregation_level2(
..Default::default()
}),
sub_aggregation: sub_agg_req.clone(),
}),
})),
),
(
"rangei64".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_i64".to_string(),
ranges: vec![
@@ -423,7 +431,7 @@ fn test_aggregation_level2(
..Default::default()
}),
sub_aggregation: sub_agg_req,
}),
})),
),
]
.into_iter()
@@ -432,16 +440,18 @@ fn test_aggregation_level2(
};
let agg_res: AggregationResults = if use_distributed_collector {
let collector = DistributedAggregationCollector::from_aggs(agg_req.clone(), None);
let collector =
DistributedAggregationCollector::from_aggs(agg_req.clone(), Default::default());
let searcher = reader.searcher();
let res = searcher.search(&term_query, &collector).unwrap();
// Test de/serialization roundtrip on intermediate_agg_result
let res: IntermediateAggregationResults =
serde_json::from_str(&serde_json::to_string(&res).unwrap()).unwrap();
res.into_final_bucket_result(agg_req.clone()).unwrap()
res.into_final_bucket_result(agg_req.clone(), &Default::default())
.unwrap()
} else {
let collector = AggregationCollector::from_aggs(agg_req.clone(), None);
let collector = get_collector(agg_req.clone());
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()
@@ -499,7 +509,7 @@ fn test_aggregation_level2(
);
// Test empty result set
let collector = AggregationCollector::from_aggs(agg_req, None);
let collector = get_collector(agg_req);
let searcher = reader.searcher();
searcher.search(&query_with_no_hits, &collector).unwrap();
@@ -562,34 +572,203 @@ fn test_aggregation_invalid_requests() -> crate::Result<()> {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap_err()
searcher.search(&AllQuery, &collector)
};
let agg_res = avg_on_field("dummy_text");
let agg_res = avg_on_field("dummy_text").unwrap_err();
assert_eq!(
format!("{:?}", agg_res),
r#"InvalidArgument("No fast field found for field: dummy_text")"#
r#"InvalidArgument("Field \"dummy_text\" is not configured as fast field")"#
);
let agg_res = avg_on_field("not_exist_field");
assert_eq!(
format!("{:?}", agg_res),
r#"InvalidArgument("No fast field found for field: not_exist_field")"#
);
// TODO: This should return an error
// let agg_res = avg_on_field("not_exist_field").unwrap_err();
// assert_eq!(
// format!("{:?}", agg_res),
// r#"InvalidArgument("No fast field found for field: not_exist_field")"#
//);
let agg_res = avg_on_field("ip_addr");
assert_eq!(
format!("{:?}", agg_res),
r#"InvalidArgument("No fast field found for field: ip_addr")"#
);
// TODO: This should return an error
// let agg_res = avg_on_field("ip_addr").unwrap_err();
// assert_eq!(
// format!("{:?}", agg_res),
// r#"InvalidArgument("No fast field found for field: ip_addr")"#
//);
Ok(())
}
#[test]
fn test_aggregation_on_json_object() {
let mut schema_builder = Schema::builder();
let json = schema_builder.add_json_field("json", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(doc!(json => json!({"color": "red"})))
.unwrap();
index_writer
.add_document(doc!(json => json!({"color": "blue"})))
.unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let agg: Aggregations = vec![(
"jsonagg".to_string(),
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "json.color".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
})),
)]
.into_iter()
.collect();
let aggregation_collector = get_collector(agg);
let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap();
let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap();
assert_eq!(
&aggregation_res_json,
&serde_json::json!({
"jsonagg": {
"buckets": [
{"doc_count": 1, "key": "blue"},
{"doc_count": 1, "key": "red"}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
}
})
);
}
#[test]
fn test_aggregation_on_json_object_empty_columns() {
let mut schema_builder = Schema::builder();
let json = schema_builder.add_json_field("json", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
// => Empty column when accessing color
index_writer
.add_document(doc!(json => json!({"price": 10.0})))
.unwrap();
index_writer.commit().unwrap();
// => Empty column when accessing price
index_writer
.add_document(doc!(json => json!({"color": "blue"})))
.unwrap();
index_writer.commit().unwrap();
// => Non Empty columns
index_writer
.add_document(doc!(json => json!({"color": "red", "price": 10.0})))
.unwrap();
index_writer
.add_document(doc!(json => json!({"color": "red", "price": 10.0})))
.unwrap();
index_writer
.add_document(doc!(json => json!({"color": "green", "price": 20.0})))
.unwrap();
index_writer
.add_document(doc!(json => json!({"color": "green", "price": 20.0})))
.unwrap();
index_writer
.add_document(doc!(json => json!({"color": "green", "price": 20.0})))
.unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let agg: Aggregations = vec![(
"jsonagg".to_string(),
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "json.color".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
})),
)]
.into_iter()
.collect();
let aggregation_collector = get_collector(agg);
let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap();
let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap();
assert_eq!(
&aggregation_res_json,
&serde_json::json!({
"jsonagg": {
"buckets": [
{"doc_count": 3, "key": "green"},
{"doc_count": 2, "key": "red"},
{"doc_count": 1, "key": "blue"}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
}
})
);
let agg_req_str = r#"
{
"jsonagg": {
"aggs": {
"min_price": { "min": { "field": "json.price" } }
},
"terms": {
"field": "json.color",
"order": { "min_price": "desc" }
}
}
} "#;
let agg: Aggregations = serde_json::from_str(agg_req_str).unwrap();
let aggregation_collector = get_collector(agg);
let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap();
let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap();
assert_eq!(
&aggregation_res_json,
&serde_json::json!(
{
"jsonagg": {
"buckets": [
{
"key": "green",
"doc_count": 3,
"min_price": {
"value": 20.0
}
},
{
"key": "red",
"doc_count": 2,
"min_price": {
"value": 10.0
}
},
{
"key": "blue",
"doc_count": 1,
"min_price": {
"value": null
}
}
],
"sum_other_doc_count": 0
}
}
)
);
}
#[cfg(all(test, feature = "unstable"))]
mod bench {
@@ -713,7 +892,7 @@ mod bench {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()
@@ -742,7 +921,7 @@ mod bench {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()
@@ -771,7 +950,7 @@ mod bench {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()
@@ -808,7 +987,7 @@ mod bench {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()
@@ -824,18 +1003,21 @@ mod bench {
b.iter(|| {
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_few_terms".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_few_terms".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req, None);
let collector = get_collector(agg_req);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
@@ -860,18 +1042,21 @@ mod bench {
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_many_terms".to_string(),
..Default::default()
}),
sub_aggregation: sub_agg_req,
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_many_terms".to_string(),
..Default::default()
}),
sub_aggregation: sub_agg_req,
}
.into(),
),
)]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req, None);
let collector = get_collector(agg_req);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
@@ -887,18 +1072,21 @@ mod bench {
b.iter(|| {
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_many_terms".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_many_terms".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req, None);
let collector = get_collector(agg_req);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
@@ -914,22 +1102,25 @@ mod bench {
b.iter(|| {
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_many_terms".to_string(),
order: Some(CustomOrder {
order: Order::Desc,
target: OrderTarget::Key,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_many_terms".to_string(),
order: Some(CustomOrder {
order: Order::Desc,
target: OrderTarget::Key,
}),
..Default::default()
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req, None);
let collector = get_collector(agg_req);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
@@ -945,26 +1136,29 @@ mod bench {
b.iter(|| {
let agg_req_1: Aggregations = vec![(
"rangef64".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_f64".to_string(),
ranges: vec![
(3f64..7000f64).into(),
(7000f64..20000f64).into(),
(20000f64..30000f64).into(),
(30000f64..40000f64).into(),
(40000f64..50000f64).into(),
(50000f64..60000f64).into(),
],
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_f64".to_string(),
ranges: vec![
(3f64..7000f64).into(),
(7000f64..20000f64).into(),
(20000f64..30000f64).into(),
(30000f64..40000f64).into(),
(40000f64..50000f64).into(),
(50000f64..60000f64).into(),
],
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
@@ -989,26 +1183,29 @@ mod bench {
let agg_req_1: Aggregations = vec![(
"rangef64".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_f64".to_string(),
ranges: vec![
(3f64..7000f64).into(),
(7000f64..20000f64).into(),
(20000f64..30000f64).into(),
(30000f64..40000f64).into(),
(40000f64..50000f64).into(),
(50000f64..60000f64).into(),
],
..Default::default()
}),
sub_aggregation: sub_agg_req,
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_f64".to_string(),
ranges: vec![
(3f64..7000f64).into(),
(7000f64..20000f64).into(),
(20000f64..30000f64).into(),
(30000f64..40000f64).into(),
(40000f64..50000f64).into(),
(50000f64..60000f64).into(),
],
..Default::default()
}),
sub_aggregation: sub_agg_req,
}
.into(),
),
)]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
@@ -1029,24 +1226,26 @@ mod bench {
b.iter(|| {
let agg_req_1: Aggregations = vec![(
"rangef64".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 100f64,
hard_bounds: Some(HistogramBounds {
min: 1000.0,
max: 300_000.0,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 100f64,
hard_bounds: Some(HistogramBounds {
min: 1000.0,
max: 300_000.0,
}),
..Default::default()
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
});
@@ -1070,19 +1269,22 @@ mod bench {
let agg_req_1: Aggregations = vec![(
"rangef64".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 100f64, // 1000 buckets
..Default::default()
}),
sub_aggregation: sub_agg_req,
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 100f64, // 1000 buckets
..Default::default()
}),
sub_aggregation: sub_agg_req,
}
.into(),
),
)]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
@@ -1098,19 +1300,22 @@ mod bench {
b.iter(|| {
let agg_req_1: Aggregations = vec![(
"rangef64".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 100f64, // 1000 buckets
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 100f64, // 1000 buckets
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&AllQuery, &collector).unwrap()
@@ -1148,24 +1353,27 @@ mod bench {
),
(
"rangef64".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_f64".to_string(),
ranges: vec![
(3f64..7000f64).into(),
(7000f64..20000f64).into(),
(20000f64..60000f64).into(),
],
..Default::default()
}),
sub_aggregation: sub_agg_req_1,
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score_f64".to_string(),
ranges: vec![
(3f64..7000f64).into(),
(7000f64..20000f64).into(),
(20000f64..60000f64).into(),
],
..Default::default()
}),
sub_aggregation: sub_agg_req_1,
}
.into(),
),
),
]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = get_collector(agg_req_1);
let searcher = reader.searcher();
searcher.search(&term_query, &collector).unwrap()

View File

@@ -62,7 +62,9 @@ pub struct DateHistogramAggregationReq {
///
/// Fractional time values are not supported, but you can address this by shifting to another
/// time unit (e.g., `1.5h` could instead be specified as `90m`).
pub fixed_interval: String,
///
/// `Option` for validation, the parameter is not optional
pub fixed_interval: Option<String>,
/// Intervals implicitly defines an absolute grid of buckets `[interval * k, interval * (k +
/// 1))`.
pub offset: Option<String>,
@@ -112,7 +114,7 @@ impl DateHistogramAggregationReq {
self.validate()?;
Ok(HistogramAggregation {
field: self.field.to_string(),
interval: parse_into_microseconds(&self.fixed_interval)? as f64,
interval: parse_into_microseconds(self.fixed_interval.as_ref().unwrap())? as f64,
offset: self
.offset
.as_ref()
@@ -127,11 +129,18 @@ impl DateHistogramAggregationReq {
}
fn validate(&self) -> crate::Result<()> {
if self.interval.is_some() {
if let Some(interval) = self.interval.as_ref() {
return Err(crate::TantivyError::InvalidArgument(format!(
"`interval` parameter {:?} in date histogram is unsupported, only \
`fixed_interval` is supported",
self.interval
interval
)));
}
if let Some(interval) = self.date_interval.as_ref() {
return Err(crate::TantivyError::InvalidArgument(format!(
"`date_interval` parameter {:?} in date histogram is unsupported, only \
`fixed_interval` is supported",
interval
)));
}
if self.format.is_some() {
@@ -140,15 +149,13 @@ impl DateHistogramAggregationReq {
));
}
if self.date_interval.is_some() {
if self.fixed_interval.is_none() {
return Err(crate::TantivyError::InvalidArgument(
"date_interval in date histogram is unsupported, only `fixed_interval` is \
supported"
.to_string(),
"fixed_interval in date histogram is missing".to_string(),
));
}
parse_into_microseconds(&self.fixed_interval)?;
parse_into_microseconds(self.fixed_interval.as_ref().unwrap())?;
Ok(())
}
@@ -470,6 +477,34 @@ mod tests {
});
assert_eq!(res, expected_res);
Ok(())
}
#[test]
fn histogram_test_invalid_req() -> crate::Result<()> {
let docs = vec![];
let index = get_test_index_from_docs(false, &docs)?;
let elasticsearch_compatible_json = json!(
{
"sales_over_time": {
"date_histogram": {
"field": "date",
"interval": "30d",
"offset": "-4d"
}
}
}
);
let agg_req: Aggregations =
serde_json::from_str(&serde_json::to_string(&elasticsearch_compatible_json).unwrap())
.unwrap();
let err = exec_request(agg_req, &index).unwrap_err();
assert_eq!(
err.to_string(),
r#"An invalid argument was passed: '`interval` parameter "30d" in date histogram is unsupported, only `fixed_interval` is supported'"#
);
Ok(())
}
}

View File

@@ -7,6 +7,7 @@ use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use tantivy_bitpacker::minmax;
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req::AggregationsInternal;
use crate::aggregation::agg_req_with_accessor::{
AggregationsWithAccessor, BucketAggregationWithAccessor,
@@ -16,7 +17,7 @@ use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry,
};
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, SegmentAggregationCollector,
build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector,
};
use crate::aggregation::{f64_from_fastfield_u64, format_date, VecWithNames};
use crate::{DocId, TantivyError};
@@ -230,6 +231,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
})
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
@@ -238,6 +240,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
self.collect_block(&[doc], agg_with_accessor)
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
@@ -247,6 +250,8 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
let sub_aggregation_accessor =
&agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation;
let mem_pre = self.get_memory_consumption();
let bounds = self.bounds;
let interval = self.interval;
let offset = self.offset;
@@ -269,6 +274,12 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
}
}
}
let mem_delta = self.get_memory_consumption() - mem_pre;
let limits = &agg_with_accessor.buckets.values[self.accessor_idx].limits;
limits.add_memory_consumed(mem_delta as u64);
limits.validate_memory_consumption()?;
Ok(())
}
@@ -285,6 +296,12 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
}
impl SegmentHistogramCollector {
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let sub_aggs_mem = self.sub_aggregations.memory_consumption();
let buckets_mem = self.buckets.memory_consumption();
self_mem + sub_aggs_mem + buckets_mem
}
pub fn into_intermediate_bucket_result(
self,
agg_with_accessor: &BucketAggregationWithAccessor,
@@ -387,6 +404,7 @@ fn intermediate_buckets_to_final_buckets_fill_gaps(
buckets: Vec<IntermediateHistogramBucketEntry>,
histogram_req: &HistogramAggregation,
sub_aggregation: &AggregationsInternal,
limits: &AggregationLimits,
) -> crate::Result<Vec<BucketEntry>> {
// Generate the full list of buckets without gaps.
//
@@ -394,7 +412,17 @@ fn intermediate_buckets_to_final_buckets_fill_gaps(
// extended_bounds from the request
let min_max = minmax(buckets.iter().map(|bucket| bucket.key));
// TODO add memory check
// memory check upfront
let (_, first_bucket_num, last_bucket_num) =
generate_bucket_pos_with_opt_minmax(histogram_req, min_max);
let added_buckets = (first_bucket_num..=last_bucket_num)
.count()
.saturating_sub(buckets.len());
limits.add_memory_consumed(
added_buckets as u64 * std::mem::size_of::<IntermediateHistogramBucketEntry>() as u64,
);
limits.validate_memory_consumption()?;
// create buckets
let fill_gaps_buckets = generate_buckets_with_opt_minmax(histogram_req, min_max);
let empty_sub_aggregation = IntermediateAggregationResults::empty_from_req(sub_aggregation);
@@ -423,7 +451,9 @@ fn intermediate_buckets_to_final_buckets_fill_gaps(
sub_aggregation: empty_sub_aggregation.clone(),
},
})
.map(|intermediate_bucket| intermediate_bucket.into_final_bucket_entry(sub_aggregation))
.map(|intermediate_bucket| {
intermediate_bucket.into_final_bucket_entry(sub_aggregation, limits)
})
.collect::<crate::Result<Vec<_>>>()
}
@@ -433,18 +463,26 @@ pub(crate) fn intermediate_histogram_buckets_to_final_buckets(
column_type: Option<ColumnType>,
histogram_req: &HistogramAggregation,
sub_aggregation: &AggregationsInternal,
limits: &AggregationLimits,
) -> crate::Result<Vec<BucketEntry>> {
let mut buckets = if histogram_req.min_doc_count() == 0 {
// With min_doc_count != 0, we may need to add buckets, so that there are no
// gaps, since intermediate result does not contain empty buckets (filtered to
// reduce serialization size).
intermediate_buckets_to_final_buckets_fill_gaps(buckets, histogram_req, sub_aggregation)?
intermediate_buckets_to_final_buckets_fill_gaps(
buckets,
histogram_req,
sub_aggregation,
limits,
)?
} else {
buckets
.into_iter()
.filter(|histogram_bucket| histogram_bucket.doc_count >= histogram_req.min_doc_count())
.map(|histogram_bucket| histogram_bucket.into_final_bucket_entry(sub_aggregation))
.map(|histogram_bucket| {
histogram_bucket.into_final_bucket_entry(sub_aggregation, limits)
})
.collect::<crate::Result<Vec<_>>>()?
};
@@ -483,15 +521,27 @@ fn get_req_min_max(req: &HistogramAggregation, min_max: Option<(f64, f64)>) -> (
/// Generates buckets with req.interval
/// Range is computed for provided min_max and request extended_bounds/hard_bounds
/// returns empty vec when there is no range to span
pub(crate) fn generate_buckets_with_opt_minmax(
pub(crate) fn generate_bucket_pos_with_opt_minmax(
req: &HistogramAggregation,
min_max: Option<(f64, f64)>,
) -> Vec<f64> {
) -> (f64, i64, i64) {
let (min, max) = get_req_min_max(req, min_max);
let offset = req.offset.unwrap_or(0.0);
let first_bucket_num = get_bucket_pos_f64(min, req.interval, offset) as i64;
let last_bucket_num = get_bucket_pos_f64(max, req.interval, offset) as i64;
(offset, first_bucket_num, last_bucket_num)
}
/// Generates buckets with req.interval
/// Range is computed for provided min_max and request extended_bounds/hard_bounds
/// returns empty vec when there is no range to span
pub(crate) fn generate_buckets_with_opt_minmax(
req: &HistogramAggregation,
min_max: Option<(f64, f64)>,
) -> Vec<f64> {
let (offset, first_bucket_num, last_bucket_num) =
generate_bucket_pos_with_opt_minmax(req, min_max);
let mut buckets = Vec::with_capacity((first_bucket_num..=last_bucket_num).count());
for bucket_pos in first_bucket_num..=last_bucket_num {
let bucket_key = bucket_pos as f64 * req.interval + offset;
@@ -513,8 +563,8 @@ mod tests {
};
use crate::aggregation::metric::{AverageAggregation, StatsAggregation};
use crate::aggregation::tests::{
exec_request, exec_request_with_query, get_test_index_2_segments,
get_test_index_from_values, get_test_index_with_num_docs,
exec_request, exec_request_with_query, exec_request_with_query_and_memory_limit,
get_test_index_2_segments, get_test_index_from_values, get_test_index_with_num_docs,
};
#[test]
@@ -525,7 +575,7 @@ mod tests {
let agg_req: Aggregations = vec![(
"my_interval".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 3.5,
@@ -533,7 +583,7 @@ mod tests {
..Default::default()
}),
sub_aggregation: Default::default(),
}),
})),
)]
.into_iter()
.collect();
@@ -551,7 +601,7 @@ mod tests {
// With offset
let agg_req: Aggregations = vec![(
"my_interval".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 3.5,
@@ -559,7 +609,7 @@ mod tests {
..Default::default()
}),
sub_aggregation: Default::default(),
}),
})),
)]
.into_iter()
.collect();
@@ -600,14 +650,14 @@ mod tests {
let agg_req: Aggregations = vec![(
"my_interval".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
..Default::default()
}),
sub_aggregation: Default::default(),
}),
})),
)]
.into_iter()
.collect();
@@ -635,14 +685,14 @@ mod tests {
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
..Default::default()
}),
sub_aggregation: Default::default(),
}),
})),
)]
.into_iter()
.collect();
@@ -659,6 +709,40 @@ mod tests {
Ok(())
}
#[test]
fn histogram_memory_limit() -> crate::Result<()> {
let index = get_test_index_with_num_docs(true, 100)?;
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 0.1,
..Default::default()
}),
sub_aggregation: Default::default(),
})),
)]
.into_iter()
.collect();
let res = exec_request_with_query_and_memory_limit(
agg_req,
&index,
None,
AggregationLimits::new(Some(5_000), None),
)
.unwrap_err();
assert_eq!(
res.to_string(),
"Aborting aggregation because memory limit was exceeded. Limit: 5.00 KB, Current: \
102.48 KB"
);
Ok(())
}
#[test]
fn histogram_merge_test() -> crate::Result<()> {
// Merge buckets counts from different segments
@@ -668,14 +752,14 @@ mod tests {
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
..Default::default()
}),
sub_aggregation: Default::default(),
}),
})),
)]
.into_iter()
.collect();
@@ -708,7 +792,7 @@ mod tests {
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
@@ -716,7 +800,7 @@ mod tests {
..Default::default()
}),
sub_aggregation: Default::default(),
}),
})),
)]
.into_iter()
.collect();
@@ -746,7 +830,7 @@ mod tests {
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
@@ -757,7 +841,7 @@ mod tests {
..Default::default()
}),
sub_aggregation: Default::default(),
}),
})),
)]
.into_iter()
.collect();
@@ -778,7 +862,7 @@ mod tests {
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
@@ -786,7 +870,7 @@ mod tests {
..Default::default()
}),
sub_aggregation: Default::default(),
}),
})),
)]
.into_iter()
.collect();
@@ -809,7 +893,7 @@ mod tests {
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
Aggregation::Bucket(Box::new(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
@@ -818,7 +902,7 @@ mod tests {
..Default::default()
}),
sub_aggregation: Default::default(),
}),
})),
)]
.into_iter()
.collect();
@@ -853,18 +937,21 @@ mod tests {
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
hard_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
hard_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
}),
..Default::default()
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -884,22 +971,25 @@ mod tests {
//
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
hard_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
hard_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
}),
extended_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
}),
..Default::default()
}),
extended_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -918,22 +1008,25 @@ mod tests {
// Invalid request
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
hard_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
hard_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
}),
extended_bounds: Some(HistogramBounds {
min: 1.0,
max: 12.0,
}),
..Default::default()
}),
extended_bounds: Some(HistogramBounds {
min: 1.0,
max: 12.0,
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -963,14 +1056,17 @@ mod tests {
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1011,18 +1107,21 @@ mod tests {
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
extended_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
extended_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
}),
..Default::default()
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1039,19 +1138,22 @@ mod tests {
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
extended_bounds: Some(HistogramBounds { min: 2.0, max: 5.0 }),
hard_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
extended_bounds: Some(HistogramBounds { min: 2.0, max: 5.0 }),
hard_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
}),
..Default::default()
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1068,18 +1170,21 @@ mod tests {
// hard_bounds will not extend the result
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
hard_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
hard_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
}),
..Default::default()
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1114,18 +1219,21 @@ mod tests {
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
extended_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 1.0,
extended_bounds: Some(HistogramBounds {
min: 2.0,
max: 12.0,
}),
..Default::default()
}),
..Default::default()
}),
sub_aggregation: agg_req,
}),
sub_aggregation: agg_req,
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1175,14 +1283,17 @@ mod tests {
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 100000.0,
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 100000.0,
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1213,14 +1324,17 @@ mod tests {
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "date".to_string(),
interval: 86400000000.0, // one day in microseconds
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "date".to_string(),
interval: 86400000000.0, // one day in microseconds
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1261,14 +1375,17 @@ mod tests {
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 0.0,
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 0.0,
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1286,15 +1403,18 @@ mod tests {
let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 50.0,
keyed: true,
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 50.0,
keyed: true,
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();

View File

@@ -11,7 +11,7 @@ use crate::aggregation::intermediate_agg_result::{
IntermediateRangeBucketResult,
};
use crate::aggregation::segment_agg_result::{
build_segment_agg_collector, BucketCount, SegmentAggregationCollector,
build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector,
};
use crate::aggregation::{
f64_from_fastfield_u64, f64_to_fastfield_u64, format_date, Key, SerializedKey, VecWithNames,
@@ -208,6 +208,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
})
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
@@ -216,6 +217,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
self.collect_block(&[doc], agg_with_accessor)
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
@@ -258,7 +260,7 @@ impl SegmentRangeCollector {
pub(crate) fn from_req_and_validate(
req: &RangeAggregation,
sub_aggregation: &AggregationsWithAccessor,
bucket_count: &BucketCount,
limits: &AggregationLimits,
field_type: ColumnType,
accessor_idx: usize,
) -> crate::Result<Self> {
@@ -302,8 +304,10 @@ impl SegmentRangeCollector {
})
.collect::<crate::Result<_>>()?;
bucket_count.add_count(buckets.len() as u32);
bucket_count.validate_bucket_count()?;
limits.add_memory_consumed(
buckets.len() as u64 * std::mem::size_of::<SegmentRangeAndBucketEntry>() as u64,
);
limits.validate_memory_consumption()?;
Ok(SegmentRangeCollector {
buckets,
@@ -475,14 +479,17 @@ mod tests {
let agg_req: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -516,14 +523,17 @@ mod tests {
let agg_req: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
..Default::default()
}),
sub_aggregation: sub_agg_req,
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
..Default::default()
}),
sub_aggregation: sub_agg_req,
}
.into(),
),
)]
.into_iter()
.collect();
@@ -548,14 +558,17 @@ mod tests {
let agg_req: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
keyed: true,
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
keyed: true,
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -585,25 +598,28 @@ mod tests {
let agg_req: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![
RangeAggregationRange {
key: Some("custom-key-0-to-0.1".to_string()),
from: Some(0f64),
to: Some(0.1f64),
},
RangeAggregationRange {
key: None,
from: Some(0.1f64),
to: Some(0.2f64),
},
],
keyed: false,
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![
RangeAggregationRange {
key: Some("custom-key-0-to-0.1".to_string()),
from: Some(0f64),
to: Some(0.1f64),
},
RangeAggregationRange {
key: None,
from: Some(0.1f64),
to: Some(0.2f64),
},
],
keyed: false,
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -642,25 +658,28 @@ mod tests {
let agg_req: Aggregations = vec![(
"date_ranges".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "date".to_string(),
ranges: vec![
RangeAggregationRange {
key: None,
from: None,
to: Some(1546300800000000.0f64),
},
RangeAggregationRange {
key: None,
from: Some(1546300800000000.0f64),
to: Some(1546387200000000.0f64),
},
],
keyed: false,
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "date".to_string(),
ranges: vec![
RangeAggregationRange {
key: None,
from: None,
to: Some(1546300800000000.0f64),
},
RangeAggregationRange {
key: None,
from: Some(1546300800000000.0f64),
to: Some(1546387200000000.0f64),
},
],
keyed: false,
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -704,18 +723,21 @@ mod tests {
let agg_req: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![RangeAggregationRange {
key: Some("custom-key-0-to-0.1".to_string()),
from: Some(0f64),
to: Some(0.1f64),
}],
keyed: true,
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![RangeAggregationRange {
key: Some("custom-key-0-to-0.1".to_string()),
from: Some(0f64),
to: Some(0.1f64),
}],
keyed: true,
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();

View File

@@ -53,7 +53,7 @@ use crate::TantivyError;
/// into segment_size.
///
/// Result type is [`BucketResult`](crate::aggregation::agg_result::BucketResult) with
/// [`TermBucketEntry`](crate::aggregation::agg_result::BucketEntry) on the
/// [`BucketEntry`](crate::aggregation::agg_result::BucketEntry) on the
/// `AggregationCollector`.
///
/// Result type is
@@ -205,54 +205,14 @@ impl TermsAggregationInternal {
#[derive(Clone, Debug, Default)]
/// Container to store term_ids/or u64 values and their buckets.
struct TermBuckets {
pub(crate) entries: FxHashMap<u64, TermBucketEntry>,
}
#[derive(Clone, Default)]
struct TermBucketEntry {
doc_count: u64,
sub_aggregations: Option<Box<dyn SegmentAggregationCollector>>,
}
impl Debug for TermBucketEntry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TermBucketEntry")
.field("doc_count", &self.doc_count)
.finish()
}
}
impl TermBucketEntry {
fn from_blueprint(blueprint: &Option<Box<dyn SegmentAggregationCollector>>) -> Self {
Self {
doc_count: 0,
sub_aggregations: blueprint.clone(),
}
}
pub(crate) fn into_intermediate_bucket_entry(
self,
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<IntermediateTermBucketEntry> {
let sub_aggregation = if let Some(sub_aggregation) = self.sub_aggregations {
sub_aggregation.into_intermediate_aggregations_result(agg_with_accessor)?
} else {
Default::default()
};
Ok(IntermediateTermBucketEntry {
doc_count: self.doc_count,
sub_aggregation,
})
}
pub(crate) entries: FxHashMap<u64, u64>,
pub(crate) sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>>,
}
impl TermBuckets {
fn force_flush(&mut self, agg_with_accessor: &AggregationsWithAccessor) -> crate::Result<()> {
for entry in &mut self.entries.values_mut() {
if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
sub_aggregations.flush(agg_with_accessor)?;
}
for sub_aggregations in &mut self.sub_aggs.values_mut() {
sub_aggregations.as_mut().flush(agg_with_accessor)?;
}
Ok(())
}
@@ -268,6 +228,7 @@ pub struct SegmentTermCollector {
blueprint: Option<Box<dyn SegmentAggregationCollector>>,
field_type: ColumnType,
accessor_idx: usize,
val_cache: Vec<u64>,
}
pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) {
@@ -292,6 +253,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
})
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
@@ -300,6 +262,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
self.collect_block(&[doc], agg_with_accessor)
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
@@ -310,28 +273,35 @@ impl SegmentAggregationCollector for SegmentTermCollector {
&agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation;
if accessor.get_cardinality() == Cardinality::Full {
for doc in docs {
let term_id = accessor.values.get_val(*doc);
let entry = self
.term_buckets
.entries
.entry(term_id)
.or_insert_with(|| TermBucketEntry::from_blueprint(&self.blueprint));
entry.doc_count += 1;
if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
self.val_cache.resize(docs.len(), 0);
accessor.values.get_vals(docs, &mut self.val_cache);
for term_id in self.val_cache.iter().cloned() {
let entry = self.term_buckets.entries.entry(term_id).or_default();
*entry += 1;
}
// has subagg
if let Some(blueprint) = self.blueprint.as_ref() {
for (doc, term_id) in docs.iter().zip(self.val_cache.iter().cloned()) {
let sub_aggregations = self
.term_buckets
.sub_aggs
.entry(term_id)
.or_insert_with(|| blueprint.clone());
sub_aggregations.collect(*doc, sub_aggregation_accessor)?;
}
}
} else {
for doc in docs {
for term_id in accessor.values_for_doc(*doc) {
let entry = self
.term_buckets
.entries
.entry(term_id)
.or_insert_with(|| TermBucketEntry::from_blueprint(&self.blueprint));
entry.doc_count += 1;
if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
let entry = self.term_buckets.entries.entry(term_id).or_default();
*entry += 1;
// TODO: check if seperate loop is faster (may depend on the codec)
if let Some(blueprint) = self.blueprint.as_ref() {
let sub_aggregations = self
.term_buckets
.sub_aggs
.entry(term_id)
.or_insert_with(|| blueprint.clone());
sub_aggregations.collect(*doc, sub_aggregation_accessor)?;
}
}
@@ -386,15 +356,16 @@ impl SegmentTermCollector {
blueprint,
field_type,
accessor_idx,
val_cache: Default::default(),
})
}
#[inline]
pub(crate) fn into_intermediate_bucket_result(
self,
mut self,
agg_with_accessor: &BucketAggregationWithAccessor,
) -> crate::Result<IntermediateBucketResult> {
let mut entries: Vec<(u64, TermBucketEntry)> =
self.term_buckets.entries.into_iter().collect();
let mut entries: Vec<(u64, u64)> = self.term_buckets.entries.into_iter().collect();
let order_by_sub_aggregation =
matches!(self.req.order.target, OrderTarget::SubAggregation(_));
@@ -417,9 +388,9 @@ impl SegmentTermCollector {
}
OrderTarget::Count => {
if self.req.order.order == Order::Desc {
entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.doc_count()));
entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.1));
} else {
entries.sort_unstable_by_key(|bucket| bucket.doc_count());
entries.sort_unstable_by_key(|bucket| bucket.1);
}
}
}
@@ -432,6 +403,35 @@ impl SegmentTermCollector {
let mut dict: FxHashMap<Key, IntermediateTermBucketEntry> = Default::default();
dict.reserve(entries.len());
let mut into_intermediate_bucket_entry =
|id, doc_count| -> crate::Result<IntermediateTermBucketEntry> {
let intermediate_entry = if self.blueprint.as_ref().is_some() {
IntermediateTermBucketEntry {
doc_count,
sub_aggregation: self
.term_buckets
.sub_aggs
.remove(&id)
.unwrap_or_else(|| {
panic!(
"Internal Error: could not find subaggregation for id {}",
id
)
})
.into_intermediate_aggregations_result(
&agg_with_accessor.sub_aggregation,
)?,
}
} else {
IntermediateTermBucketEntry {
doc_count,
sub_aggregation: Default::default(),
}
};
Ok(intermediate_entry)
};
if self.field_type == ColumnType::Str {
let term_dict = agg_with_accessor
.str_dict_column
@@ -439,17 +439,17 @@ impl SegmentTermCollector {
.expect("internal error: term dictionary not found for term aggregation");
let mut buffer = String::new();
for (term_id, entry) in entries {
for (term_id, doc_count) in entries {
if !term_dict.ord_to_str(term_id, &mut buffer)? {
return Err(TantivyError::InternalError(format!(
"Couldn't find term_id {} in dict",
term_id
)));
}
dict.insert(
Key::Str(buffer.to_string()),
entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
);
let intermediate_entry = into_intermediate_bucket_entry(term_id, doc_count)?;
dict.insert(Key::Str(buffer.to_string()), intermediate_entry);
}
if self.req.min_doc_count == 0 {
// TODO: Handle rev streaming for descending sorting by keys
@@ -468,12 +468,10 @@ impl SegmentTermCollector {
}
}
} else {
for (val, entry) in entries {
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
let val = f64_from_fastfield_u64(val, &self.field_type);
dict.insert(
Key::F64(val),
entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
);
dict.insert(Key::F64(val), intermediate_entry);
}
};
@@ -490,14 +488,9 @@ impl SegmentTermCollector {
pub(crate) trait GetDocCount {
fn doc_count(&self) -> u64;
}
impl GetDocCount for (u32, TermBucketEntry) {
impl GetDocCount for (u64, u64) {
fn doc_count(&self) -> u64 {
self.1.doc_count
}
}
impl GetDocCount for (u64, TermBucketEntry) {
fn doc_count(&self) -> u64 {
self.1.doc_count
self.1
}
}
impl GetDocCount for (String, IntermediateTermBucketEntry) {
@@ -559,13 +552,16 @@ mod tests {
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -581,15 +577,18 @@ mod tests {
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
size: Some(2),
split_size: Some(2),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
size: Some(2),
split_size: Some(2),
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -608,20 +607,23 @@ mod tests {
// test min_doc_count
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
size: Some(2),
min_doc_count: Some(3),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
size: Some(2),
min_doc_count: Some(3),
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
let res = exec_request(agg_req.clone(), &index)?;
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 5);
assert_eq!(
@@ -676,17 +678,20 @@ mod tests {
// sub agg desc
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::Count,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::Count,
}),
..Default::default()
}),
..Default::default()
}),
sub_aggregation: sub_agg.clone(),
}),
sub_aggregation: sub_agg.clone(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -711,45 +716,54 @@ mod tests {
let agg_req: Aggregations = vec![
(
"my_scores1".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "score".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::Count,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "score".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::Count,
}),
..Default::default()
}),
..Default::default()
}),
sub_aggregation: sub_agg.clone(),
}),
sub_aggregation: sub_agg.clone(),
}
.into(),
),
),
(
"my_scores2".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "score_f64".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::Count,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "score_f64".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::Count,
}),
..Default::default()
}),
..Default::default()
}),
sub_aggregation: sub_agg.clone(),
}),
sub_aggregation: sub_agg.clone(),
}
.into(),
),
),
(
"my_scores3".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "score_i64".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::Count,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "score_i64".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::Count,
}),
..Default::default()
}),
..Default::default()
}),
sub_aggregation: sub_agg,
}),
sub_aggregation: sub_agg,
}
.into(),
),
),
]
.into_iter()
@@ -850,17 +864,20 @@ mod tests {
// sub agg desc
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Desc,
target: OrderTarget::SubAggregation("avg_score".to_string()),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Desc,
target: OrderTarget::SubAggregation("avg_score".to_string()),
}),
..Default::default()
}),
..Default::default()
}),
sub_aggregation: sub_agg.clone(),
}),
sub_aggregation: sub_agg.clone(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -883,17 +900,20 @@ mod tests {
// sub agg asc
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::SubAggregation("avg_score".to_string()),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::SubAggregation("avg_score".to_string()),
}),
..Default::default()
}),
..Default::default()
}),
sub_aggregation: sub_agg.clone(),
}),
sub_aggregation: sub_agg.clone(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -917,17 +937,20 @@ mod tests {
// sub agg multi value asc
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::SubAggregation("stats_score.avg".to_string()),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::SubAggregation("stats_score.avg".to_string()),
}),
..Default::default()
}),
..Default::default()
}),
sub_aggregation: sub_agg.clone(),
}),
sub_aggregation: sub_agg.clone(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -951,17 +974,20 @@ mod tests {
// sub agg invalid request
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::SubAggregation("doesnotexist".to_string()),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::SubAggregation("doesnotexist".to_string()),
}),
..Default::default()
}),
..Default::default()
}),
sub_aggregation: sub_agg,
}),
sub_aggregation: sub_agg,
}
.into(),
),
)]
.into_iter()
.collect();
@@ -998,17 +1024,20 @@ mod tests {
// key asc
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::Key,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::Key,
}),
..Default::default()
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1025,18 +1054,21 @@ mod tests {
// key desc and size cut_off
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::Key,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::Key,
}),
size: Some(2),
..Default::default()
}),
size: Some(2),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1056,19 +1088,22 @@ mod tests {
// key asc and segment_size cut_off
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::Key,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Asc,
target: OrderTarget::Key,
}),
size: Some(2),
segment_size: Some(2),
..Default::default()
}),
size: Some(2),
segment_size: Some(2),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1086,17 +1121,20 @@ mod tests {
// key desc
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Desc,
target: OrderTarget::Key,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Desc,
target: OrderTarget::Key,
}),
..Default::default()
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1113,18 +1151,21 @@ mod tests {
// key desc, size cut_off
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Desc,
target: OrderTarget::Key,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Desc,
target: OrderTarget::Key,
}),
size: Some(2),
..Default::default()
}),
size: Some(2),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1143,19 +1184,22 @@ mod tests {
// key desc, segment_size cut_off
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Desc,
target: OrderTarget::Key,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
order: Some(CustomOrder {
order: Order::Desc,
target: OrderTarget::Key,
}),
size: Some(2),
segment_size: Some(2),
..Default::default()
}),
size: Some(2),
segment_size: Some(2),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1184,14 +1228,17 @@ mod tests {
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
min_doc_count: Some(0),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
min_doc_count: Some(0),
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1224,15 +1271,18 @@ mod tests {
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
size: Some(2),
segment_size: Some(2),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
size: Some(2),
segment_size: Some(2),
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1254,16 +1304,19 @@ mod tests {
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
size: Some(2),
segment_size: Some(2),
show_term_doc_count_error: Some(false),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
size: Some(2),
segment_size: Some(2),
show_term_doc_count_error: Some(false),
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1316,14 +1369,17 @@ mod tests {
let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_id".to_string(),
min_doc_count: Some(0),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "text_id".to_string(),
min_doc_count: Some(0),
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1344,19 +1400,22 @@ mod tests {
fn test_json_format() -> crate::Result<()> {
let agg_req: Aggregations = vec![(
"term_agg_test".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
size: Some(2),
segment_size: Some(2),
order: Some(CustomOrder {
target: OrderTarget::Key,
order: Order::Desc,
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
size: Some(2),
segment_size: Some(2),
order: Some(CustomOrder {
target: OrderTarget::Key,
order: Order::Desc,
}),
..Default::default()
}),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();
@@ -1391,14 +1450,17 @@ mod tests {
// test alias shard_size, split_size
let agg_req: Aggregations = vec![(
"term_agg_test".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
split_size: Some(2),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
split_size: Some(2),
..Default::default()
}),
sub_aggregation: Default::default(),
}
.into(),
),
)]
.into_iter()
.collect();

View File

@@ -34,6 +34,7 @@ impl BufAggregationCollector {
}
impl SegmentAggregationCollector for BufAggregationCollector {
#[inline]
fn into_intermediate_aggregations_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
@@ -41,6 +42,7 @@ impl SegmentAggregationCollector for BufAggregationCollector {
Box::new(self.collector).into_intermediate_aggregations_result(agg_with_accessor)
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
@@ -56,17 +58,18 @@ impl SegmentAggregationCollector for BufAggregationCollector {
Ok(())
}
#[inline]
fn collect_block(
&mut self,
docs: &[crate::DocId],
agg_with_accessor: &AggregationsWithAccessor,
) -> crate::Result<()> {
for doc in docs {
self.collect(*doc, agg_with_accessor)?;
}
self.collector.collect_block(docs, agg_with_accessor)?;
Ok(())
}
#[inline]
fn flush(&mut self, agg_with_accessor: &AggregationsWithAccessor) -> crate::Result<()> {
self.collector
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor)?;

View File

@@ -1,36 +1,36 @@
use std::rc::Rc;
use super::agg_req::Aggregations;
use super::agg_req_with_accessor::AggregationsWithAccessor;
use super::agg_result::AggregationResults;
use super::buf_collector::BufAggregationCollector;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::segment_agg_result::{build_segment_agg_collector, SegmentAggregationCollector};
use super::segment_agg_result::{
build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector,
};
use crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_validate;
use crate::collector::{Collector, SegmentCollector};
use crate::{SegmentReader, TantivyError};
use crate::{DocId, SegmentReader, TantivyError};
/// The default max bucket count, before the aggregation fails.
pub const MAX_BUCKET_COUNT: u32 = 65000;
pub const DEFAULT_BUCKET_LIMIT: u32 = 65000;
/// The default memory limit in bytes before the aggregation fails. 500MB
pub const DEFAULT_MEMORY_LIMIT: u64 = 500_000_000;
/// Collector for aggregations.
///
/// The collector collects all aggregations by the underlying aggregation request.
pub struct AggregationCollector {
agg: Aggregations,
max_bucket_count: u32,
limits: AggregationLimits,
}
impl AggregationCollector {
/// Create collector from aggregation request.
///
/// Aggregation fails when the total bucket count is higher than max_bucket_count.
/// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset
pub fn from_aggs(agg: Aggregations, max_bucket_count: Option<u32>) -> Self {
Self {
agg,
max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT),
}
/// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and
/// bucket limit)
pub fn from_aggs(agg: Aggregations, limits: AggregationLimits) -> Self {
Self { agg, limits }
}
}
@@ -44,18 +44,16 @@ impl AggregationCollector {
/// into the final `AggregationResults` via the `into_final_result()` method.
pub struct DistributedAggregationCollector {
agg: Aggregations,
max_bucket_count: u32,
limits: AggregationLimits,
}
impl DistributedAggregationCollector {
/// Create collector from aggregation request.
///
/// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset
pub fn from_aggs(agg: Aggregations, max_bucket_count: Option<u32>) -> Self {
Self {
agg,
max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT),
}
/// Aggregation fails when the limits in `AggregationLimits` is exceeded. (memory limit and
/// bucket limit)
pub fn from_aggs(agg: Aggregations, limits: AggregationLimits) -> Self {
Self { agg, limits }
}
}
@@ -69,11 +67,7 @@ impl Collector for DistributedAggregationCollector {
_segment_local_id: crate::SegmentOrdinal,
reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
AggregationSegmentCollector::from_agg_req_and_reader(
&self.agg,
reader,
self.max_bucket_count,
)
AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader, &self.limits)
}
fn requires_scoring(&self) -> bool {
@@ -98,11 +92,7 @@ impl Collector for AggregationCollector {
_segment_local_id: crate::SegmentOrdinal,
reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
AggregationSegmentCollector::from_agg_req_and_reader(
&self.agg,
reader,
self.max_bucket_count,
)
AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader, &self.limits)
}
fn requires_scoring(&self) -> bool {
@@ -114,7 +104,7 @@ impl Collector for AggregationCollector {
segment_fruits: Vec<<Self::Child as SegmentCollector>::Fruit>,
) -> crate::Result<Self::Fruit> {
let res = merge_fruits(segment_fruits)?;
res.into_final_bucket_result(self.agg.clone())
res.into_final_bucket_result(self.agg.clone(), &self.limits)
}
}
@@ -135,7 +125,7 @@ fn merge_fruits(
/// `AggregationSegmentCollector` does the aggregation collection on a segment.
pub struct AggregationSegmentCollector {
aggs_with_accessor: AggregationsWithAccessor,
result: BufAggregationCollector,
agg_collector: BufAggregationCollector,
error: Option<TantivyError>,
}
@@ -145,15 +135,14 @@ impl AggregationSegmentCollector {
pub fn from_agg_req_and_reader(
agg: &Aggregations,
reader: &SegmentReader,
max_bucket_count: u32,
limits: &AggregationLimits,
) -> crate::Result<Self> {
let aggs_with_accessor =
get_aggs_with_accessor_and_validate(agg, reader, Rc::default(), max_bucket_count)?;
let aggs_with_accessor = get_aggs_with_accessor_and_validate(agg, reader, limits)?;
let result =
BufAggregationCollector::new(build_segment_agg_collector(&aggs_with_accessor)?);
Ok(AggregationSegmentCollector {
aggs_with_accessor,
result,
agg_collector: result,
error: None,
})
}
@@ -163,11 +152,26 @@ impl SegmentCollector for AggregationSegmentCollector {
type Fruit = crate::Result<IntermediateAggregationResults>;
#[inline]
fn collect(&mut self, doc: crate::DocId, _score: crate::Score) {
fn collect(&mut self, doc: DocId, _score: crate::Score) {
if self.error.is_some() {
return;
}
if let Err(err) = self.result.collect(doc, &self.aggs_with_accessor) {
if let Err(err) = self.agg_collector.collect(doc, &self.aggs_with_accessor) {
self.error = Some(err);
}
}
/// The query pushes the documents to the collector via this method.
///
/// Only valid for Collectors that ignore docs
fn collect_block(&mut self, docs: &[DocId]) {
if self.error.is_some() {
return;
}
if let Err(err) = self
.agg_collector
.collect_block(docs, &self.aggs_with_accessor)
{
self.error = Some(err);
}
}
@@ -176,7 +180,7 @@ impl SegmentCollector for AggregationSegmentCollector {
if let Some(err) = self.error {
return Err(err);
}
self.result.flush(&self.aggs_with_accessor)?;
Box::new(self.result).into_intermediate_aggregations_result(&self.aggs_with_accessor)
self.agg_collector.flush(&self.aggs_with_accessor)?;
Box::new(self.agg_collector).into_intermediate_aggregations_result(&self.aggs_with_accessor)
}
}

View File

@@ -1,9 +1,33 @@
use common::ByteCount;
use super::bucket::DateHistogramParseError;
/// Error that may occur when opening a directory
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum AggregationError {
/// Failed to open the directory.
/// Date histogram parse error
#[error("Date histogram parse error: {0:?}")]
DateHistogramParseError(#[from] DateHistogramParseError),
/// Memory limit exceeded
#[error(
"Aborting aggregation because memory limit was exceeded. Limit: {limit:?}, Current: \
{current:?}"
)]
MemoryExceeded {
/// Memory consumption limit
limit: ByteCount,
/// Current memory consumption
current: ByteCount,
},
/// Bucket limit exceeded
#[error(
"Aborting aggregation because bucket limit was exceeded. Limit: {limit:?}, Current: \
{current:?}"
)]
BucketLimitExceeded {
/// Bucket limit
limit: u32,
/// Current num buckets
current: u32,
},
}

View File

@@ -22,9 +22,11 @@ use super::metric::{
IntermediateAverage, IntermediateCount, IntermediateMax, IntermediateMin, IntermediateStats,
IntermediateSum,
};
use super::{format_date, Key, SerializedKey, VecWithNames};
use super::segment_agg_result::AggregationLimits;
use super::{format_date, AggregationError, Key, SerializedKey, VecWithNames};
use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
use crate::aggregation::bucket::TermsAggregationInternal;
use crate::TantivyError;
/// Contains the intermediate aggregation result, which is optimized to be merged with other
/// intermediate results.
@@ -38,8 +40,23 @@ pub struct IntermediateAggregationResults {
impl IntermediateAggregationResults {
/// Convert intermediate result and its aggregation request to the final result.
pub fn into_final_bucket_result(self, req: Aggregations) -> crate::Result<AggregationResults> {
self.into_final_bucket_result_internal(&(req.into()))
pub fn into_final_bucket_result(
self,
req: Aggregations,
limits: &AggregationLimits,
) -> crate::Result<AggregationResults> {
// TODO count and validate buckets
let res = self.into_final_bucket_result_internal(&(req.into()), limits)?;
let bucket_count = res.get_bucket_count() as u32;
if bucket_count > limits.get_bucket_limit() {
return Err(TantivyError::AggregationError(
AggregationError::BucketLimitExceeded {
limit: limits.get_bucket_limit(),
current: bucket_count,
},
));
}
Ok(res)
}
/// Convert intermediate result and its aggregation request to the final result.
@@ -49,6 +66,7 @@ impl IntermediateAggregationResults {
pub(crate) fn into_final_bucket_result_internal(
self,
req: &AggregationsInternal,
limits: &AggregationLimits,
) -> crate::Result<AggregationResults> {
// Important assumption:
// When the tree contains buckets/metric, we expect it to have all buckets/metrics from the
@@ -56,11 +74,11 @@ impl IntermediateAggregationResults {
let mut results: FxHashMap<String, AggregationResult> = FxHashMap::default();
if let Some(buckets) = self.buckets {
convert_and_add_final_buckets_to_result(&mut results, buckets, &req.buckets)?
convert_and_add_final_buckets_to_result(&mut results, buckets, &req.buckets, limits)?
} else {
// When there are no buckets, we create empty buckets, so that the serialized json
// format is constant
add_empty_final_buckets_to_result(&mut results, &req.buckets)?
add_empty_final_buckets_to_result(&mut results, &req.buckets, limits)?
};
if let Some(metrics) = self.metrics {
@@ -161,10 +179,12 @@ fn add_empty_final_metrics_to_result(
fn add_empty_final_buckets_to_result(
results: &mut FxHashMap<String, AggregationResult>,
req_buckets: &VecWithNames<BucketAggregationInternal>,
limits: &AggregationLimits,
) -> crate::Result<()> {
let requested_buckets = req_buckets.iter();
for (key, req) in requested_buckets {
let empty_bucket = AggregationResult::BucketResult(BucketResult::empty_from_req(req)?);
let empty_bucket =
AggregationResult::BucketResult(BucketResult::empty_from_req(req, limits)?);
results.insert(key.to_string(), empty_bucket);
}
Ok(())
@@ -174,12 +194,13 @@ fn convert_and_add_final_buckets_to_result(
results: &mut FxHashMap<String, AggregationResult>,
buckets: VecWithNames<IntermediateBucketResult>,
req_buckets: &VecWithNames<BucketAggregationInternal>,
limits: &AggregationLimits,
) -> crate::Result<()> {
assert_eq!(buckets.len(), req_buckets.len());
let buckets_with_request = buckets.into_iter().zip(req_buckets.values());
for ((key, bucket), req) in buckets_with_request {
let result = AggregationResult::BucketResult(bucket.into_final_bucket_result(req)?);
let result = AggregationResult::BucketResult(bucket.into_final_bucket_result(req, limits)?);
results.insert(key, result);
}
Ok(())
@@ -287,6 +308,7 @@ impl IntermediateBucketResult {
pub(crate) fn into_final_bucket_result(
self,
req: &BucketAggregationInternal,
limits: &AggregationLimits,
) -> crate::Result<BucketResult> {
match self {
IntermediateBucketResult::Range(range_res) => {
@@ -299,6 +321,7 @@ impl IntermediateBucketResult {
req.as_range()
.expect("unexpected aggregation, expected histogram aggregation"),
range_res.column_type,
limits,
)
})
.collect::<crate::Result<Vec<_>>>()?;
@@ -337,6 +360,7 @@ impl IntermediateBucketResult {
column_type,
histogram_req,
&req.sub_aggregation,
limits,
)?;
let buckets = if histogram_req.keyed {
@@ -355,6 +379,7 @@ impl IntermediateBucketResult {
req.as_term()
.expect("unexpected aggregation, expected term aggregation"),
&req.sub_aggregation,
limits,
),
}
}
@@ -449,6 +474,7 @@ impl IntermediateTermBucketResult {
self,
req: &TermsAggregation,
sub_aggregation_req: &AggregationsInternal,
limits: &AggregationLimits,
) -> crate::Result<BucketResult> {
let req = TermsAggregationInternal::from_req(req);
let mut buckets: Vec<BucketEntry> = self
@@ -462,7 +488,7 @@ impl IntermediateTermBucketResult {
doc_count: entry.doc_count,
sub_aggregation: entry
.sub_aggregation
.into_final_bucket_result_internal(sub_aggregation_req)?,
.into_final_bucket_result_internal(sub_aggregation_req, limits)?,
})
})
.collect::<crate::Result<_>>()?;
@@ -494,7 +520,7 @@ impl IntermediateTermBucketResult {
let val = bucket
.sub_aggregation
.get_value_from_aggregation(agg_name, agg_property)?
.unwrap_or(f64::NAN);
.unwrap_or(f64::MIN);
Ok((bucket, val))
})
.collect::<crate::Result<Vec<_>>>()?;
@@ -582,6 +608,7 @@ impl IntermediateHistogramBucketEntry {
pub(crate) fn into_final_bucket_entry(
self,
req: &AggregationsInternal,
limits: &AggregationLimits,
) -> crate::Result<BucketEntry> {
Ok(BucketEntry {
key_as_string: None,
@@ -589,7 +616,7 @@ impl IntermediateHistogramBucketEntry {
doc_count: self.doc_count,
sub_aggregation: self
.sub_aggregation
.into_final_bucket_result_internal(req)?,
.into_final_bucket_result_internal(req, limits)?,
})
}
}
@@ -628,13 +655,14 @@ impl IntermediateRangeBucketEntry {
req: &AggregationsInternal,
_range_req: &RangeAggregation,
column_type: Option<ColumnType>,
limits: &AggregationLimits,
) -> crate::Result<RangeBucketEntry> {
let mut range_bucket_entry = RangeBucketEntry {
key: self.key,
doc_count: self.doc_count,
sub_aggregation: self
.sub_aggregation
.into_final_bucket_result_internal(req)?,
.into_final_bucket_result_internal(req, limits)?,
to: self.to,
from: self.from,
to_as_string: None,

View File

@@ -81,7 +81,7 @@ mod tests {
"price_sum": { "sum": { "field": "price" } }
}"#;
let aggregations: Aggregations = serde_json::from_str(aggregations_json).unwrap();
let collector = AggregationCollector::from_aggs(aggregations, None);
let collector = AggregationCollector::from_aggs(aggregations, Default::default());
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let aggregations_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();

View File

@@ -156,6 +156,7 @@ pub(crate) struct SegmentStatsCollector {
pub(crate) collecting_for: SegmentStatsType,
pub(crate) stats: IntermediateStats,
pub(crate) accessor_idx: usize,
val_cache: Vec<u64>,
}
impl SegmentStatsCollector {
@@ -169,14 +170,16 @@ impl SegmentStatsCollector {
collecting_for,
stats: IntermediateStats::default(),
accessor_idx,
val_cache: Default::default(),
}
}
#[inline]
pub(crate) fn collect_block_with_field(&mut self, docs: &[DocId], field: &Column<u64>) {
if field.get_cardinality() == Cardinality::Full {
for doc in docs {
let val = field.values.get_val(*doc);
let val1 = f64_from_fastfield_u64(val, &self.field_type);
self.val_cache.resize(docs.len(), 0);
field.values.get_vals(docs, &mut self.val_cache);
for val in self.val_cache.iter() {
let val1 = f64_from_fastfield_u64(*val, &self.field_type);
self.stats.collect(val1);
}
} else {
@@ -191,6 +194,7 @@ impl SegmentStatsCollector {
}
impl SegmentAggregationCollector for SegmentStatsCollector {
#[inline]
fn into_intermediate_aggregations_result(
self: Box<Self>,
agg_with_accessor: &AggregationsWithAccessor,
@@ -227,6 +231,7 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
})
}
#[inline]
fn collect(
&mut self,
doc: crate::DocId,
@@ -289,7 +294,7 @@ mod tests {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = AggregationCollector::from_aggs(agg_req_1, Default::default());
let reader = index.reader()?;
let searcher = reader.searcher();
@@ -326,7 +331,7 @@ mod tests {
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = AggregationCollector::from_aggs(agg_req_1, Default::default());
let reader = index.reader()?;
let searcher = reader.searcher();
@@ -380,30 +385,33 @@ mod tests {
),
(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score".to_string(),
ranges: vec![
(3f64..7f64).into(),
(7f64..19f64).into(),
(19f64..20f64).into(),
],
..Default::default()
}),
sub_aggregation: iter::once((
"stats".to_string(),
Aggregation::Metric(MetricAggregation::Stats(
StatsAggregation::from_field_name("score".to_string()),
)),
))
.collect(),
}),
Aggregation::Bucket(
BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "score".to_string(),
ranges: vec![
(3f64..7f64).into(),
(7f64..19f64).into(),
(19f64..20f64).into(),
],
..Default::default()
}),
sub_aggregation: iter::once((
"stats".to_string(),
Aggregation::Metric(MetricAggregation::Stats(
StatsAggregation::from_field_name("score".to_string()),
)),
))
.collect(),
}
.into(),
),
),
]
.into_iter()
.collect();
let collector = AggregationCollector::from_aggs(agg_req_1, None);
let collector = AggregationCollector::from_aggs(agg_req_1, Default::default());
let searcher = reader.searcher();
let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();

View File

@@ -70,7 +70,7 @@
//! .into_iter()
//! .collect();
//!
//! let collector = AggregationCollector::from_aggs(agg_req, None);
//! let collector = AggregationCollector::from_aggs(agg_req, Default::default());
//!
//! let searcher = reader.searcher();
//! let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
@@ -130,14 +130,14 @@
//! let agg_req_1: Aggregations = vec![
//! (
//! "range".to_string(),
//! Aggregation::Bucket(BucketAggregation {
//! Aggregation::Bucket(Box::new(BucketAggregation {
//! bucket_agg: BucketAggregationType::Range(RangeAggregation{
//! field: "score".to_string(),
//! ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
//! keyed: false,
//! }),
//! sub_aggregation: sub_agg_req_1.clone(),
//! }),
//! })),
//! ),
//! ]
//! .into_iter()
@@ -155,6 +155,7 @@
//! [`AggregationResults`](agg_result::AggregationResults) via the
//! [`into_final_bucket_result`](intermediate_agg_result::IntermediateAggregationResults::into_final_bucket_result) method.
mod agg_limits;
pub mod agg_req;
mod agg_req_with_accessor;
pub mod agg_result;
@@ -165,6 +166,7 @@ mod date;
mod error;
pub mod intermediate_agg_result;
pub mod metric;
mod segment_agg_result;
use std::collections::HashMap;
use std::fmt::Display;
@@ -172,9 +174,10 @@ use std::fmt::Display;
#[cfg(test)]
mod agg_tests;
pub use agg_limits::AggregationLimits;
pub use collector::{
AggregationCollector, AggregationSegmentCollector, DistributedAggregationCollector,
MAX_BUCKET_COUNT,
DEFAULT_BUCKET_LIMIT,
};
use columnar::{ColumnType, MonotonicallyMappableToU64};
pub(crate) use date::format_date;
@@ -345,9 +348,8 @@ mod tests {
use time::OffsetDateTime;
use super::agg_req::Aggregations;
use super::segment_agg_result::AggregationLimits;
use super::*;
use crate::aggregation::agg_req::{Aggregation, BucketAggregation, BucketAggregationType};
use crate::aggregation::bucket::TermsAggregation;
use crate::indexer::NoMergePolicy;
use crate::query::{AllQuery, TermQuery};
use crate::schema::{IndexRecordOption, Schema, TextFieldIndexing, FAST, STRING};
@@ -371,7 +373,16 @@ mod tests {
index: &Index,
query: Option<(&str, &str)>,
) -> crate::Result<Value> {
let collector = AggregationCollector::from_aggs(agg_req, None);
exec_request_with_query_and_memory_limit(agg_req, index, query, Default::default())
}
pub fn exec_request_with_query_and_memory_limit(
agg_req: Aggregations,
index: &Index,
query: Option<(&str, &str)>,
limits: AggregationLimits,
) -> crate::Result<Value> {
let collector = AggregationCollector::from_aggs(agg_req, limits);
let reader = index.reader()?;
let searcher = reader.searcher();
@@ -595,50 +606,4 @@ mod tests {
Ok(index)
}
#[test]
fn test_aggregation_on_json_object() {
let mut schema_builder = Schema::builder();
let json = schema_builder.add_json_field("json", FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer = index.writer_for_tests().unwrap();
index_writer
.add_document(doc!(json => json!({"color": "red"})))
.unwrap();
index_writer
.add_document(doc!(json => json!({"color": "blue"})))
.unwrap();
index_writer.commit().unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let agg: Aggregations = vec![(
"jsonagg".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "json.color".to_string(),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();
let aggregation_collector = AggregationCollector::from_aggs(agg, None);
let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap();
let aggregation_res_json = serde_json::to_value(aggregation_results).unwrap();
assert_eq!(
&aggregation_res_json,
&serde_json::json!({
"jsonagg": {
"buckets": [
{"doc_count": 1, "key": "blue"},
{"doc_count": 1, "key": "red"}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0
}
})
);
}
}

View File

@@ -4,15 +4,13 @@
//! merging.
use std::fmt::Debug;
use std::rc::Rc;
use std::sync::atomic::AtomicU32;
pub(crate) use super::agg_limits::AggregationLimits;
use super::agg_req::MetricAggregation;
use super::agg_req_with_accessor::{
AggregationsWithAccessor, BucketAggregationWithAccessor, MetricAggregationWithAccessor,
};
use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector};
use super::collector::MAX_BUCKET_COUNT;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::metric::{
AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, SegmentStatsCollector,
@@ -20,7 +18,6 @@ use super::metric::{
};
use super::VecWithNames;
use crate::aggregation::agg_req::BucketAggregationType;
use crate::TantivyError;
pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug {
fn into_intermediate_aggregations_result(
@@ -131,7 +128,7 @@ pub(crate) fn build_bucket_segment_agg_collector(
Ok(Box::new(SegmentRangeCollector::from_req_and_validate(
range_req,
&req.sub_aggregation,
&req.bucket_count,
&req.limits,
req.field_type,
accessor_idx,
)?))
@@ -284,37 +281,3 @@ impl GenericSegmentAggregationResultsCollector {
Ok(GenericSegmentAggregationResultsCollector { metrics, buckets })
}
}
#[derive(Clone)]
pub(crate) struct BucketCount {
/// The counter which is shared between the aggregations for one request.
pub(crate) bucket_count: Rc<AtomicU32>,
pub(crate) max_bucket_count: u32,
}
impl Default for BucketCount {
fn default() -> Self {
Self {
bucket_count: Default::default(),
max_bucket_count: MAX_BUCKET_COUNT,
}
}
}
impl BucketCount {
pub(crate) fn validate_bucket_count(&self) -> crate::Result<()> {
if self.get_count() > self.max_bucket_count {
return Err(TantivyError::InvalidArgument(
"Aborting aggregation because too many buckets were created".to_string(),
));
}
Ok(())
}
pub(crate) fn add_count(&self, count: u32) {
self.bucket_count
.fetch_add(count, std::sync::atomic::Ordering::Relaxed);
}
pub(crate) fn get_count(&self) -> u32 {
self.bucket_count.load(std::sync::atomic::Ordering::Relaxed)
}
}

View File

@@ -180,9 +180,11 @@ pub trait Collector: Sync + Send {
})?;
}
(Some(alive_bitset), false) => {
weight.for_each_no_score(reader, &mut |doc| {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, 0.0);
weight.for_each_no_score(reader, &mut |docs| {
for doc in docs.iter().cloned() {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, 0.0);
}
}
})?;
}
@@ -192,8 +194,8 @@ pub trait Collector: Sync + Send {
})?;
}
(None, false) => {
weight.for_each_no_score(reader, &mut |doc| {
segment_collector.collect(doc, 0.0);
weight.for_each_no_score(reader, &mut |docs| {
segment_collector.collect_block(docs);
})?;
}
}
@@ -270,6 +272,13 @@ pub trait SegmentCollector: 'static {
/// The query pushes the scored document to the collector via this method.
fn collect(&mut self, doc: DocId, score: Score);
/// The query pushes the scored document to the collector via this method.
fn collect_block(&mut self, docs: &[DocId]) {
for doc in docs {
self.collect(*doc, 0.0);
}
}
/// Extract the fruit of the collection from the `SegmentCollector`.
fn harvest(self) -> Self::Fruit;
}

View File

@@ -1030,7 +1030,7 @@ mod tests {
let segment = searcher.segment_reader(0);
let top_collector = TopDocs::with_limit(4).order_by_u64_field(SIZE);
let err = top_collector.for_segment(0, segment).err().unwrap();
assert!(matches!(err, crate::TantivyError::SchemaError(_)));
assert!(matches!(err, crate::TantivyError::InvalidArgument(_)));
Ok(())
}

View File

@@ -327,7 +327,7 @@ impl SegmentReader {
self.alive_bitset_opt
.as_ref()
.map(AliveBitSet::space_usage)
.unwrap_or(0),
.unwrap_or_default(),
))
}
}

View File

@@ -172,7 +172,7 @@ impl CompositeFile {
let mut fields = Vec::new();
for (&field_addr, byte_range) in &self.offsets_index {
let mut field_usage = FieldUsage::empty(field_addr.field);
field_usage.add_field_idx(field_addr.idx, byte_range.len());
field_usage.add_field_idx(field_addr.idx, byte_range.len().into());
fields.push(field_usage);
}
PerFieldSpaceUsage::new(fields)

View File

@@ -9,6 +9,8 @@ use crate::DocId;
/// to compare `[u32; 4]`.
pub const TERMINATED: DocId = i32::MAX as u32;
pub const BUFFER_LEN: usize = 64;
/// Represents an iterable set of sorted doc ids.
pub trait DocSet: Send {
/// Goes to the next element.
@@ -59,7 +61,7 @@ pub trait DocSet: Send {
/// This method is only here for specific high-performance
/// use case where batching. The normal way to
/// go through the `DocId`'s is to call `.advance()`.
fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize {
fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize {
if self.doc() == TERMINATED {
return 0;
}
@@ -149,6 +151,11 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
unboxed.seek(target)
}
fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize {
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.fill_buffer(buffer)
}
fn doc(&self) -> DocId {
let unboxed: &TDocSet = self.borrow();
unboxed.doc()

View File

@@ -55,7 +55,7 @@ impl fmt::Debug for DataCorruption {
#[derive(Debug, Clone, Error)]
pub enum TantivyError {
/// Error when handling aggregations.
#[error("AggregationError {0:?}")]
#[error(transparent)]
AggregationError(#[from] AggregationError),
/// Failed to open the directory.
#[error("Failed to open the directory: '{0:?}'")]

View File

@@ -1,9 +1,8 @@
use std::io;
use std::io::Write;
use common::{intersect_bitsets, BitSet, OwnedBytes, ReadOnlyBitSet};
use common::{intersect_bitsets, BitSet, ByteCount, OwnedBytes, ReadOnlyBitSet};
use crate::space_usage::ByteCount;
use crate::DocId;
/// Write an alive `BitSet`

View File

@@ -80,7 +80,7 @@ mod tests {
use std::path::Path;
use columnar::{Column, MonotonicallyMappableToU64, StrColumn};
use common::{HasLen, TerminatingWrite};
use common::{ByteCount, HasLen, TerminatingWrite};
use once_cell::sync::Lazy;
use rand::prelude::SliceRandom;
use rand::rngs::StdRng;
@@ -862,16 +862,16 @@ mod tests {
#[test]
pub fn test_gcd_date() {
let size_prec_sec = test_gcd_date_with_codec(DatePrecision::Seconds);
assert!((1000 * 13 / 8..100 + 1000 * 13 / 8).contains(&size_prec_sec)); // 13 bits per val = ceil(log_2(number of seconds in 2hours);
assert!((1000 * 13 / 8..100 + 1000 * 13 / 8).contains(&size_prec_sec.get_bytes())); // 13 bits per val = ceil(log_2(number of seconds in 2hours);
let size_prec_micros = test_gcd_date_with_codec(DatePrecision::Microseconds);
assert!((1000 * 33 / 8..100 + 1000 * 33 / 8).contains(&size_prec_micros));
assert!((1000 * 33 / 8..100 + 1000 * 33 / 8).contains(&size_prec_micros.get_bytes()));
// 33 bits per
// val = ceil(log_2(number
// of microsecsseconds
// in 2hours);
}
fn test_gcd_date_with_codec(precision: DatePrecision) -> usize {
fn test_gcd_date_with_codec(precision: DatePrecision) -> ByteCount {
let mut rng = StdRng::seed_from_u64(2u64);
const T0: i64 = 1_662_345_825_012_529i64;
const ONE_HOUR_IN_MICROSECS: i64 = 3_600 * 1_000_000;
@@ -1068,8 +1068,8 @@ mod tests {
let searcher = index.reader().unwrap().searcher();
let segment_reader = searcher.segment_reader(0u32);
let fast_fields = segment_reader.fast_fields();
let column_without_opt: Option<StrColumn> = fast_fields.str("without.hello").unwrap();
assert!(column_without_opt.is_none());
let column_without_opt = fast_fields.str("without.hello");
assert!(column_without_opt.is_err());
let column_with_opt: Option<StrColumn> = fast_fields.str("with.hello").unwrap();
let column_with: StrColumn = column_with_opt.unwrap();
assert!(column_with.term_ords(0).next().is_none());
@@ -1166,7 +1166,7 @@ mod tests {
let searcher = index.reader().unwrap().searcher();
let fast_field_reader = searcher.segment_reader(0u32).fast_fields();
let column = fast_field_reader
.column_opt::<i64>(&"jsonfield.attr.age")
.column_opt::<i64>("jsonfield.attr.age")
.unwrap()
.unwrap();
let vals: Vec<i64> = column.values_for_doc(0u32).collect();
@@ -1191,7 +1191,7 @@ mod tests {
let searcher = index.reader().unwrap().searcher();
let fast_field_reader = searcher.segment_reader(0u32).fast_fields();
let column = fast_field_reader
.column_opt::<i64>(&"jsonfield.attr.age")
.column_opt::<i64>("jsonfield.attr.age")
.unwrap()
.unwrap();
let vals: Vec<i64> = column.values_for_doc(0u32).collect();

View File

@@ -6,11 +6,13 @@ use columnar::{
BytesColumn, Column, ColumnType, ColumnValues, ColumnarReader, DynamicColumn,
DynamicColumnHandle, HasAssociatedColumnType, StrColumn,
};
use common::ByteCount;
use crate::core::json_utils::encode_column_name;
use crate::directory::FileSlice;
use crate::schema::{Field, FieldEntry, FieldType, Schema};
use crate::space_usage::{FieldUsage, PerFieldSpaceUsage};
use crate::TantivyError;
/// Provides access to all of the BitpackedFastFieldReader.
///
@@ -28,7 +30,7 @@ impl FastFieldReaders {
Ok(FastFieldReaders { columnar, schema })
}
fn resolve_field(&self, column_name: &str) -> Option<String> {
fn resolve_field(&self, column_name: &str) -> crate::Result<Option<String>> {
let default_field_opt: Option<Field> = if cfg!(feature = "quickwit") {
self.schema.get_field("_dynamic").ok()
} else {
@@ -41,7 +43,7 @@ impl FastFieldReaders {
let mut per_field_usages: Vec<FieldUsage> = Default::default();
for (field, field_entry) in schema.fields() {
let column_handles = self.columnar.read_columns(field_entry.name())?;
let num_bytes: usize = column_handles
let num_bytes: ByteCount = column_handles
.iter()
.map(|column_handle| column_handle.num_bytes())
.sum();
@@ -82,27 +84,35 @@ impl FastFieldReaders {
&'a self,
field_name: &'a str,
default_field_opt: Option<Field>,
) -> Option<String> {
let (field, path): (Field, &str) = self
) -> crate::Result<Option<String>> {
let Some((field, path)): Option<(Field, &str)> = self
.schema
.find_field(field_name)
.or_else(|| default_field_opt.map(|default_field| (default_field, field_name)))?;
.or_else(|| default_field_opt.map(|default_field| (default_field, field_name)))
else{
return Ok(None);
};
let field_entry: &FieldEntry = self.schema.get_field_entry(field);
if !field_entry.is_fast() {
return Err(TantivyError::InvalidArgument(format!(
"Field {field_name:?} is not configured as fast field"
)));
}
let field_name = self.schema.get_field_name(field);
if path.is_empty() {
return Some(field_name.to_string());
return Ok(Some(field_name.to_string()));
}
let field_entry: &FieldEntry = self.schema.get_field_entry(field);
let field_type = field_entry.field_type();
match (field_type, path) {
(FieldType::JsonObject(json_options), path) if !path.is_empty() => {
Some(encode_column_name(
Ok(Some(encode_column_name(
field_entry.name(),
path,
json_options.is_expand_dots_enabled(),
))
)))
}
(_, "") => Some(field_entry.name().to_string()),
_ => None,
(_, "") => Ok(Some(field_entry.name().to_string())),
_ => Ok(None),
}
}
@@ -127,9 +137,9 @@ impl FastFieldReaders {
/// Returns the number of `bytes` associated with a column.
///
/// Returns 0 if the column does not exist.
pub fn column_num_bytes(&self, field: &str) -> crate::Result<usize> {
let Some(resolved_field_name) = self.resolve_field(field) else {
return Ok(0);
pub fn column_num_bytes(&self, field: &str) -> crate::Result<ByteCount> {
let Some(resolved_field_name) = self.resolve_field(field)? else {
return Ok(0u64.into());
};
Ok(self
.columnar
@@ -216,7 +226,7 @@ impl FastFieldReaders {
field_name: &str,
column_type: ColumnType,
) -> crate::Result<Option<DynamicColumnHandle>> {
let Some(resolved_field_name) = self.resolve_field(field_name) else {
let Some(resolved_field_name) = self.resolve_field(field_name)? else {
return Ok(None);
};
let dynamic_column_handle_opt = self
@@ -232,7 +242,7 @@ impl FastFieldReaders {
&self,
field_name: &str,
) -> crate::Result<Vec<DynamicColumnHandle>> {
let Some(resolved_field_name) = self.resolve_field(field_name) else {
let Some(resolved_field_name) = self.resolve_field(field_name)? else {
return Ok(Vec::new());
};
let columns = self
@@ -251,12 +261,14 @@ impl FastFieldReaders {
}
/// Returns the `u64` column used to represent any `u64`-mapped typed (i64, u64, f64, DateTime).
///
/// Returns Ok(None) for empty columns
#[doc(hidden)]
pub fn u64_lenient_with_type(
&self,
field_name: &str,
) -> crate::Result<Option<(Column<u64>, ColumnType)>> {
let Some(resolved_field_name) = self.resolve_field(field_name) else {
let Some(resolved_field_name) = self.resolve_field(field_name)? else {
return Ok(None);
};
for col in self.columnar.read_columns(&resolved_field_name)? {
@@ -316,44 +328,57 @@ mod tests {
let reader = searcher.segment_reader(0u32);
let fast_field_readers = reader.fast_fields();
assert_eq!(
fast_field_readers.resolve_column_name_given_default_field("age", None),
fast_field_readers
.resolve_column_name_given_default_field("age", None)
.unwrap(),
Some("age".to_string())
);
assert_eq!(
fast_field_readers.resolve_column_name_given_default_field("age", Some(dynamic_field)),
fast_field_readers
.resolve_column_name_given_default_field("age", Some(dynamic_field))
.unwrap(),
Some("age".to_string())
);
assert_eq!(
fast_field_readers.resolve_column_name_given_default_field(
"json_expand_dots_disabled.attr.color",
None
),
fast_field_readers
.resolve_column_name_given_default_field(
"json_expand_dots_disabled.attr.color",
None
)
.unwrap(),
Some("json_expand_dots_disabled\u{1}attr\u{1}color".to_string())
);
assert_eq!(
fast_field_readers.resolve_column_name_given_default_field(
"json_expand_dots_disabled.attr\\.color",
Some(dynamic_field)
),
fast_field_readers
.resolve_column_name_given_default_field(
"json_expand_dots_disabled.attr\\.color",
Some(dynamic_field)
)
.unwrap(),
Some("json_expand_dots_disabled\u{1}attr.color".to_string())
);
assert_eq!(
fast_field_readers.resolve_column_name_given_default_field(
"json_expand_dots_enabled.attr\\.color",
Some(dynamic_field)
),
fast_field_readers
.resolve_column_name_given_default_field(
"json_expand_dots_enabled.attr\\.color",
Some(dynamic_field)
)
.unwrap(),
Some("json_expand_dots_enabled\u{1}attr\u{1}color".to_string())
);
assert_eq!(
fast_field_readers
.resolve_column_name_given_default_field("notinschema.attr.color", None),
.resolve_column_name_given_default_field("notinschema.attr.color", None)
.unwrap(),
None
);
assert_eq!(
fast_field_readers.resolve_column_name_given_default_field(
"notinschema.attr.color",
Some(dynamic_field)
),
fast_field_readers
.resolve_column_name_given_default_field(
"notinschema.attr.color",
Some(dynamic_field)
)
.unwrap(),
Some("_dyna\u{1}notinschema\u{1}attr\u{1}color".to_string())
);
}

View File

@@ -94,10 +94,12 @@ fn compute_deleted_bitset(
// document that were inserted before it.
delete_op
.target
.for_each_no_score(segment_reader, &mut |doc_matching_delete_query| {
if doc_opstamps.is_deleted(doc_matching_delete_query, delete_op.opstamp) {
alive_bitset.remove(doc_matching_delete_query);
might_have_changed = true;
.for_each_no_score(segment_reader, &mut |docs_matching_delete_query| {
for doc_matching_delete_query in docs_matching_delete_query.iter().cloned() {
if doc_opstamps.is_deleted(doc_matching_delete_query, delete_op.opstamp) {
alive_bitset.remove(doc_matching_delete_query);
might_have_changed = true;
}
}
})?;
delete_cursor.advance();

View File

@@ -188,7 +188,7 @@ impl SegmentWriter {
let mut indexing_position = IndexingPosition::default();
postings_writer.index_text(
doc_id,
&mut *facet_tokenizer,
&mut facet_tokenizer,
term_buffer,
ctx,
&mut indexing_position,

View File

@@ -1,5 +1,5 @@
use crate::core::SegmentReader;
use crate::docset::{DocSet, TERMINATED};
use crate::docset::{DocSet, BUFFER_LEN, TERMINATED};
use crate::query::boost_query::BoostScorer;
use crate::query::explanation::does_not_match;
use crate::query::{EnableScoring, Explanation, Query, Scorer, Weight};
@@ -44,6 +44,7 @@ pub struct AllScorer {
}
impl DocSet for AllScorer {
#[inline(always)]
fn advance(&mut self) -> DocId {
if self.doc + 1 >= self.max_doc {
self.doc = TERMINATED;
@@ -53,6 +54,30 @@ impl DocSet for AllScorer {
self.doc
}
fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize {
if self.doc() == TERMINATED {
return 0;
}
let is_safe_distance = self.doc() + (buffer.len() as u32) < self.max_doc;
if is_safe_distance {
let num_items = buffer.len();
for buffer_val in buffer {
*buffer_val = self.doc();
self.doc += 1;
}
num_items
} else {
for (i, buffer_val) in buffer.iter_mut().enumerate() {
*buffer_val = self.doc();
if self.advance() == TERMINATED {
return i + 1;
}
}
buffer.len()
}
}
#[inline(always)]
fn doc(&self) -> DocId {
self.doc
}
@@ -71,8 +96,8 @@ impl Scorer for AllScorer {
#[cfg(test)]
mod tests {
use super::AllQuery;
use crate::docset::TERMINATED;
use crate::query::{EnableScoring, Query};
use crate::docset::{DocSet, BUFFER_LEN, TERMINATED};
use crate::query::{AllScorer, EnableScoring, Query};
use crate::schema::{Schema, TEXT};
use crate::Index;
@@ -132,4 +157,22 @@ mod tests {
}
Ok(())
}
#[test]
pub fn test_fill_buffer() {
let mut postings = AllScorer {
doc: 0u32,
max_doc: BUFFER_LEN as u32 * 2 + 9,
};
let mut buffer = [0u32; BUFFER_LEN];
assert_eq!(postings.fill_buffer(&mut buffer), BUFFER_LEN);
for i in 0u32..BUFFER_LEN as u32 {
assert_eq!(buffer[i as usize], i);
}
assert_eq!(postings.fill_buffer(&mut buffer), BUFFER_LEN);
for i in 0u32..BUFFER_LEN as u32 {
assert_eq!(buffer[i as usize], i + BUFFER_LEN as u32);
}
assert_eq!(postings.fill_buffer(&mut buffer), 9);
}
}

View File

@@ -45,6 +45,7 @@ impl From<BitSet> for BitSetDocSet {
}
impl DocSet for BitSetDocSet {
#[inline]
fn advance(&mut self) -> DocId {
if let Some(lower) = self.cursor_tinybitset.pop_lowest() {
self.doc = (self.cursor_bucket * 64u32) | lower;

View File

@@ -1,11 +1,12 @@
use std::collections::HashMap;
use crate::core::SegmentReader;
use crate::docset::BUFFER_LEN;
use crate::postings::FreqReadingOption;
use crate::query::explanation::does_not_match;
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
use crate::query::term_query::TermScorer;
use crate::query::weight::{for_each_docset, for_each_pruning_scorer, for_each_scorer};
use crate::query::weight::{for_each_docset_buffered, for_each_pruning_scorer, for_each_scorer};
use crate::query::{
intersect_scorers, EmptyScorer, Exclude, Explanation, Occur, RequiredOptionalScorer, Scorer,
Union, Weight,
@@ -222,16 +223,18 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
fn for_each_no_score(
&self,
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId),
callback: &mut dyn FnMut(&[DocId]),
) -> crate::Result<()> {
let scorer = self.complex_scorer(reader, 1.0, || DoNothingCombiner)?;
let mut buffer = [0u32; BUFFER_LEN];
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let mut union_scorer = Union::build(term_scorers, &self.score_combiner_fn);
for_each_docset(&mut union_scorer, callback);
for_each_docset_buffered(&mut union_scorer, &mut buffer, callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_docset(scorer.as_mut(), callback);
for_each_docset_buffered(scorer.as_mut(), &mut buffer, callback);
}
}
Ok(())

View File

@@ -1,5 +1,6 @@
use std::fmt;
use crate::docset::BUFFER_LEN;
use crate::fastfield::AliveBitSet;
use crate::query::explanation::does_not_match;
use crate::query::{EnableScoring, Explanation, Query, Scorer, Weight};
@@ -106,7 +107,7 @@ impl<S: Scorer> DocSet for BoostScorer<S> {
self.underlying.seek(target)
}
fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize {
fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize {
self.underlying.fill_buffer(buffer)
}

View File

@@ -1,5 +1,6 @@
use std::fmt;
use crate::docset::BUFFER_LEN;
use crate::query::{EnableScoring, Explanation, Query, Scorer, Weight};
use crate::{DocId, DocSet, Score, SegmentReader, TantivyError, Term};
@@ -119,7 +120,7 @@ impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> {
self.docset.seek(target)
}
fn fill_buffer(&mut self, buffer: &mut [DocId]) -> usize {
fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize {
self.docset.fill_buffer(buffer)
}

View File

@@ -4,7 +4,9 @@ use std::collections::{BinaryHeap, HashMap};
use crate::query::bm25::idf;
use crate::query::{BooleanQuery, BoostQuery, Occur, Query, TermQuery};
use crate::schema::{Field, FieldType, IndexRecordOption, Term, Value};
use crate::tokenizer::{BoxTokenStream, FacetTokenizer, PreTokenizedStream, Tokenizer};
use crate::tokenizer::{
BoxTokenStream, FacetTokenizer, PreTokenizedStream, TokenStream, Tokenizer,
};
use crate::{DocAddress, Result, Searcher, TantivyError};
#[derive(Debug, PartialEq)]

View File

@@ -913,9 +913,10 @@ mod test {
let tokenizer_manager = TokenizerManager::default();
tokenizer_manager.register(
"en_with_stop_words",
TextAnalyzer::from(SimpleTokenizer)
TextAnalyzer::builder(SimpleTokenizer)
.filter(LowerCaser)
.filter(StopWordFilter::remove(vec!["the".to_string()])),
.filter(StopWordFilter::remove(vec!["the".to_string()]))
.build(),
);
QueryParser::new(schema, default_fields, tokenizer_manager)
}

View File

@@ -1,11 +1,11 @@
use super::term_scorer::TermScorer;
use crate::core::SegmentReader;
use crate::docset::DocSet;
use crate::docset::{DocSet, BUFFER_LEN};
use crate::fieldnorm::FieldNormReader;
use crate::postings::SegmentPostings;
use crate::query::bm25::Bm25Weight;
use crate::query::explanation::does_not_match;
use crate::query::weight::{for_each_docset, for_each_scorer};
use crate::query::weight::{for_each_docset_buffered, for_each_scorer};
use crate::query::{Explanation, Scorer, Weight};
use crate::schema::IndexRecordOption;
use crate::{DocId, Score, Term};
@@ -61,10 +61,11 @@ impl Weight for TermWeight {
fn for_each_no_score(
&self,
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId),
callback: &mut dyn FnMut(&[DocId]),
) -> crate::Result<()> {
let mut scorer = self.specialized_scorer(reader, 1.0)?;
for_each_docset(&mut scorer, callback);
let mut buffer = [0u32; BUFFER_LEN];
for_each_docset_buffered(&mut scorer, &mut buffer, callback);
Ok(())
}

View File

@@ -53,7 +53,7 @@ impl HasLen for VecDocSet {
pub mod tests {
use super::*;
use crate::docset::DocSet;
use crate::docset::{DocSet, BUFFER_LEN};
use crate::DocId;
#[test]
@@ -72,17 +72,17 @@ pub mod tests {
#[test]
pub fn test_fill_buffer() {
let doc_ids: Vec<DocId> = (1u32..210u32).collect();
let doc_ids: Vec<DocId> = (1u32..=(BUFFER_LEN as u32 * 2 + 9)).collect();
let mut postings = VecDocSet::from(doc_ids);
let mut buffer = vec![1000u32; 100];
assert_eq!(postings.fill_buffer(&mut buffer[..]), 100);
for i in 0u32..100u32 {
let mut buffer = [0u32; BUFFER_LEN];
assert_eq!(postings.fill_buffer(&mut buffer), BUFFER_LEN);
for i in 0u32..BUFFER_LEN as u32 {
assert_eq!(buffer[i as usize], i + 1);
}
assert_eq!(postings.fill_buffer(&mut buffer[..]), 100);
for i in 0u32..100u32 {
assert_eq!(buffer[i as usize], i + 101);
assert_eq!(postings.fill_buffer(&mut buffer), BUFFER_LEN);
for i in 0u32..BUFFER_LEN as u32 {
assert_eq!(buffer[i as usize], i + 1 + BUFFER_LEN as u32);
}
assert_eq!(postings.fill_buffer(&mut buffer[..]), 9);
assert_eq!(postings.fill_buffer(&mut buffer), 9);
}
}

View File

@@ -1,5 +1,6 @@
use super::Scorer;
use crate::core::SegmentReader;
use crate::docset::BUFFER_LEN;
use crate::query::Explanation;
use crate::{DocId, DocSet, Score, TERMINATED};
@@ -18,11 +19,18 @@ pub(crate) fn for_each_scorer<TScorer: Scorer + ?Sized>(
/// Iterates through all of the documents matched by the DocSet
/// `DocSet`.
pub(crate) fn for_each_docset<T: DocSet + ?Sized>(docset: &mut T, callback: &mut dyn FnMut(DocId)) {
let mut doc = docset.doc();
while doc != TERMINATED {
callback(doc);
doc = docset.advance();
#[inline]
pub(crate) fn for_each_docset_buffered<T: DocSet + ?Sized>(
docset: &mut T,
buffer: &mut [DocId; BUFFER_LEN],
mut callback: impl FnMut(&[DocId]),
) {
loop {
let num_items = docset.fill_buffer(buffer);
callback(&buffer[..num_items]);
if num_items != buffer.len() {
break;
}
}
}
@@ -93,10 +101,12 @@ pub trait Weight: Send + Sync + 'static {
fn for_each_no_score(
&self,
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId),
callback: &mut dyn FnMut(&[DocId]),
) -> crate::Result<()> {
let mut docset = self.scorer(reader, 1.0)?;
for_each_docset(docset.as_mut(), callback);
let mut buffer = [0u32; BUFFER_LEN];
for_each_docset_buffered(&mut docset, &mut buffer, callback);
Ok(())
}

View File

@@ -364,7 +364,8 @@ where B: AsRef<[u8]>
/// (this does not include the field.)
///
/// If the term is a string, its value is utf-8 encoded.
/// If the term is a u64, its value is encoded in big endian.
/// If the term is a u64, its value is encoded according
/// to `byteorder::BigEndian`.
pub fn value_bytes(&self) -> &[u8] {
&self.0.as_ref()[TERM_METADATA_LENGTH..]
}

View File

@@ -9,14 +9,12 @@
use std::collections::HashMap;
use common::ByteCount;
use serde::{Deserialize, Serialize};
use crate::schema::Field;
use crate::SegmentComponent;
/// Indicates space usage in bytes
pub type ByteCount = usize;
/// Enum containing any of the possible space usage results for segment components.
pub enum ComponentSpaceUsage {
/// Data is stored per field in a uniform way
@@ -38,7 +36,7 @@ impl SearcherSpaceUsage {
pub(crate) fn new() -> SearcherSpaceUsage {
SearcherSpaceUsage {
segments: Vec::new(),
total: 0,
total: Default::default(),
}
}
@@ -260,7 +258,7 @@ impl FieldUsage {
pub(crate) fn empty(field: Field) -> FieldUsage {
FieldUsage {
field,
num_bytes: 0,
num_bytes: Default::default(),
sub_num_bytes: Vec::new(),
}
}
@@ -294,7 +292,7 @@ impl FieldUsage {
mod test {
use crate::core::Index;
use crate::schema::{Field, Schema, FAST, INDEXED, STORED, TEXT};
use crate::space_usage::{ByteCount, PerFieldSpaceUsage};
use crate::space_usage::PerFieldSpaceUsage;
use crate::Term;
#[test]
@@ -304,14 +302,14 @@ mod test {
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let searcher_space_usage = searcher.space_usage().unwrap();
assert_eq!(0, searcher_space_usage.total());
assert_eq!(searcher_space_usage.total(), 0u64);
}
fn expect_single_field(
field_space: &PerFieldSpaceUsage,
field: &Field,
min_size: ByteCount,
max_size: ByteCount,
min_size: u64,
max_size: u64,
) {
assert!(field_space.total() >= min_size);
assert!(field_space.total() <= max_size);
@@ -353,12 +351,12 @@ mod test {
expect_single_field(segment.termdict(), &name, 1, 512);
expect_single_field(segment.postings(), &name, 1, 512);
assert_eq!(0, segment.positions().total());
assert_eq!(segment.positions().total(), 0);
expect_single_field(segment.fast_fields(), &name, 1, 512);
expect_single_field(segment.fieldnorms(), &name, 1, 512);
// TODO: understand why the following fails
// assert_eq!(0, segment.store().total());
assert_eq!(0, segment.deletes());
assert_eq!(segment.deletes(), 0);
Ok(())
}
@@ -394,11 +392,11 @@ mod test {
expect_single_field(segment.termdict(), &name, 1, 512);
expect_single_field(segment.postings(), &name, 1, 512);
expect_single_field(segment.positions(), &name, 1, 512);
assert_eq!(0, segment.fast_fields().total());
assert_eq!(segment.fast_fields().total(), 0);
expect_single_field(segment.fieldnorms(), &name, 1, 512);
// TODO: understand why the following fails
// assert_eq!(0, segment.store().total());
assert_eq!(0, segment.deletes());
assert_eq!(segment.deletes(), 0);
Ok(())
}
@@ -430,14 +428,14 @@ mod test {
assert_eq!(4, segment.num_docs());
assert_eq!(0, segment.termdict().total());
assert_eq!(0, segment.postings().total());
assert_eq!(0, segment.positions().total());
assert_eq!(0, segment.fast_fields().total());
assert_eq!(0, segment.fieldnorms().total());
assert_eq!(segment.termdict().total(), 0);
assert_eq!(segment.postings().total(), 0);
assert_eq!(segment.positions().total(), 0);
assert_eq!(segment.fast_fields().total(), 0);
assert_eq!(segment.fieldnorms().total(), 0);
assert!(segment.store().total() > 0);
assert!(segment.store().total() < 512);
assert_eq!(0, segment.deletes());
assert_eq!(segment.deletes(), 0);
Ok(())
}
@@ -478,8 +476,8 @@ mod test {
expect_single_field(segment_space_usage.termdict(), &name, 1, 512);
expect_single_field(segment_space_usage.postings(), &name, 1, 512);
assert_eq!(0, segment_space_usage.positions().total());
assert_eq!(0, segment_space_usage.fast_fields().total());
assert_eq!(segment_space_usage.positions().total(), 0u64);
assert_eq!(segment_space_usage.fast_fields().total(), 0u64);
expect_single_field(segment_space_usage.fieldnorms(), &name, 1, 512);
assert!(segment_space_usage.deletes() > 0);
Ok(())

View File

@@ -5,7 +5,7 @@ use std::ops::{AddAssign, Range};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use common::{BinarySerializable, HasLen, OwnedBytes};
use common::{BinarySerializable, OwnedBytes};
use lru::LruCache;
use super::footer::DocStoreFooter;
@@ -122,7 +122,8 @@ impl StoreReader {
let (data_file, offset_index_file) = data_and_offset.split(footer.offset as usize);
let index_data = offset_index_file.read_bytes()?;
let space_usage = StoreSpaceUsage::new(data_file.len(), offset_index_file.len());
let space_usage =
StoreSpaceUsage::new(data_file.num_bytes(), offset_index_file.num_bytes());
let skip_index = SkipIndex::open(index_data);
Ok(StoreReader {
decompressor: footer.decompressor,

View File

@@ -1,6 +1,7 @@
use std::cmp;
use std::io::{self, Read, Write};
use byteorder::{ByteOrder, LittleEndian};
use common::{BinarySerializable, FixedSize};
use tantivy_bitpacker::{compute_num_bits, BitPacker};
@@ -103,7 +104,7 @@ fn extract_bits(data: &[u8], addr_bits: usize, num_bits: u8) -> u64 {
let addr_byte = addr_bits / 8;
let bit_shift = (addr_bits % 8) as u64;
let val_unshifted_unmasked: u64 = if data.len() >= addr_byte + 8 {
u64::from_le_bytes(data[addr_byte..][..8].try_into().unwrap())
LittleEndian::read_u64(&data[addr_byte..][..8])
} else {
// the buffer is not large enough.
// Let's copy the few remaining bytes to a 8 byte buffer
@@ -112,7 +113,7 @@ fn extract_bits(data: &[u8], addr_bits: usize, num_bits: u8) -> u64 {
let data_to_copy = &data[addr_byte..];
let nbytes = data_to_copy.len();
buf[..nbytes].copy_from_slice(data_to_copy);
u64::from_le_bytes(buf)
LittleEndian::read_u64(&buf)
};
let val_shifted_unmasked = val_unshifted_unmasked >> bit_shift;
let mask = (1u64 << u64::from(num_bits)) - 1;

View File

@@ -2,16 +2,18 @@
//! ```rust
//! use tantivy::tokenizer::*;
//!
//! let tokenizer = TextAnalyzer::from(RawTokenizer)
//! .filter(AlphaNumOnlyFilter);
//! let tokenizer = TextAnalyzer::builder(RawTokenizer)
//! .filter(AlphaNumOnlyFilter)
//! .build();
//!
//! let mut stream = tokenizer.token_stream("hello there");
//! // is none because the raw filter emits one token that
//! // contains a space
//! assert!(stream.next().is_none());
//!
//! let tokenizer = TextAnalyzer::from(SimpleTokenizer)
//! .filter(AlphaNumOnlyFilter);
//! let tokenizer = TextAnalyzer::builder(SimpleTokenizer)
//! .filter(AlphaNumOnlyFilter)
//! .build();
//!
//! let mut stream = tokenizer.token_stream("hello there 💣");
//! assert!(stream.next().is_some());
@@ -19,30 +21,45 @@
//! // the "emoji" is dropped because its not an alphanum
//! assert!(stream.next().is_none());
//! ```
use super::{BoxTokenStream, Token, TokenFilter, TokenStream};
use super::{Token, TokenFilter, TokenStream, Tokenizer};
/// `TokenFilter` that removes all tokens that contain non
/// ascii alphanumeric characters.
#[derive(Clone)]
pub struct AlphaNumOnlyFilter;
pub struct AlphaNumOnlyFilterStream<'a> {
tail: BoxTokenStream<'a>,
pub struct AlphaNumOnlyFilterStream<T> {
tail: T,
}
impl<'a> AlphaNumOnlyFilterStream<'a> {
impl<T> AlphaNumOnlyFilterStream<T> {
fn predicate(&self, token: &Token) -> bool {
token.text.chars().all(|c| c.is_ascii_alphanumeric())
}
}
impl TokenFilter for AlphaNumOnlyFilter {
fn transform<'a>(&self, token_stream: BoxTokenStream<'a>) -> BoxTokenStream<'a> {
BoxTokenStream::from(AlphaNumOnlyFilterStream { tail: token_stream })
type Tokenizer<T: Tokenizer> = AlphaNumOnlyFilterWrapper<T>;
fn transform<T: Tokenizer>(self, tokenizer: T) -> AlphaNumOnlyFilterWrapper<T> {
AlphaNumOnlyFilterWrapper(tokenizer)
}
}
impl<'a> TokenStream for AlphaNumOnlyFilterStream<'a> {
#[derive(Clone)]
pub struct AlphaNumOnlyFilterWrapper<T>(T);
impl<T: Tokenizer> Tokenizer for AlphaNumOnlyFilterWrapper<T> {
type TokenStream<'a> = AlphaNumOnlyFilterStream<T::TokenStream<'a>>;
fn token_stream<'a>(&self, text: &'a str) -> Self::TokenStream<'a> {
AlphaNumOnlyFilterStream {
tail: self.0.token_stream(text),
}
}
}
impl<T: TokenStream> TokenStream for AlphaNumOnlyFilterStream<T> {
fn advance(&mut self) -> bool {
while self.tail.advance() {
if self.predicate(self.tail.token()) {
@@ -79,7 +96,9 @@ mod tests {
}
fn token_stream_helper(text: &str) -> Vec<Token> {
let a = TextAnalyzer::from(SimpleTokenizer).filter(AlphaNumOnlyFilter);
let a = TextAnalyzer::builder(SimpleTokenizer)
.filter(AlphaNumOnlyFilter)
.build();
let mut token_stream = a.token_stream(text);
let mut tokens: Vec<Token> = vec![];
let mut add_token = |token: &Token| {

View File

@@ -1,6 +1,6 @@
use std::mem;
use super::{BoxTokenStream, Token, TokenFilter, TokenStream};
use super::{Token, TokenFilter, TokenStream, Tokenizer};
/// This class converts alphabetic, numeric, and symbolic Unicode characters
/// which are not in the first 127 ASCII characters (the "Basic Latin" Unicode
@@ -9,20 +9,33 @@ use super::{BoxTokenStream, Token, TokenFilter, TokenStream};
pub struct AsciiFoldingFilter;
impl TokenFilter for AsciiFoldingFilter {
fn transform<'a>(&self, token_stream: BoxTokenStream<'a>) -> BoxTokenStream<'a> {
From::from(AsciiFoldingFilterTokenStream {
tail: token_stream,
buffer: String::with_capacity(100),
})
type Tokenizer<T: Tokenizer> = AsciiFoldingFilterWrapper<T>;
fn transform<T: Tokenizer>(self, tokenizer: T) -> AsciiFoldingFilterWrapper<T> {
AsciiFoldingFilterWrapper(tokenizer)
}
}
pub struct AsciiFoldingFilterTokenStream<'a> {
buffer: String,
tail: BoxTokenStream<'a>,
#[derive(Clone)]
pub struct AsciiFoldingFilterWrapper<T>(T);
impl<T: Tokenizer> Tokenizer for AsciiFoldingFilterWrapper<T> {
type TokenStream<'a> = AsciiFoldingFilterTokenStream<T::TokenStream<'a>>;
fn token_stream<'a>(&self, text: &'a str) -> Self::TokenStream<'a> {
AsciiFoldingFilterTokenStream {
buffer: String::with_capacity(100),
tail: self.0.token_stream(text),
}
}
}
impl<'a> TokenStream for AsciiFoldingFilterTokenStream<'a> {
pub struct AsciiFoldingFilterTokenStream<T> {
buffer: String,
tail: T,
}
impl<T: TokenStream> TokenStream for AsciiFoldingFilterTokenStream<T> {
fn advance(&mut self) -> bool {
if !self.tail.advance() {
return false;
@@ -1560,8 +1573,9 @@ mod tests {
fn folding_helper(text: &str) -> Vec<String> {
let mut tokens = Vec::new();
TextAnalyzer::from(SimpleTokenizer)
TextAnalyzer::builder(SimpleTokenizer)
.filter(AsciiFoldingFilter)
.build()
.token_stream(text)
.process(&mut |token| {
tokens.push(token.text.clone());
@@ -1570,8 +1584,9 @@ mod tests {
}
fn folding_using_raw_tokenizer_helper(text: &str) -> String {
let mut token_stream = TextAnalyzer::from(RawTokenizer)
let mut token_stream = TextAnalyzer::builder(RawTokenizer)
.filter(AsciiFoldingFilter)
.build()
.token_stream(text);
token_stream.advance();
token_stream.token().text.clone()

View File

@@ -1,16 +1,17 @@
use crate::tokenizer::{BoxTokenStream, Token, TokenStream, Tokenizer};
use crate::tokenizer::{Token, TokenStream, Tokenizer};
#[derive(Clone)]
pub(crate) struct EmptyTokenizer;
impl Tokenizer for EmptyTokenizer {
fn token_stream<'a>(&self, _text: &'a str) -> BoxTokenStream<'a> {
EmptyTokenStream::default().into()
type TokenStream<'a> = EmptyTokenStream;
fn token_stream(&self, _text: &str) -> EmptyTokenStream {
EmptyTokenStream::default()
}
}
#[derive(Default)]
struct EmptyTokenStream {
pub struct EmptyTokenStream {
token: Token,
}
@@ -30,7 +31,7 @@ impl TokenStream for EmptyTokenStream {
#[cfg(test)]
mod tests {
use crate::tokenizer::Tokenizer;
use crate::tokenizer::{TokenStream, Tokenizer};
#[test]
fn test_empty_tokenizer() {

View File

@@ -1,4 +1,4 @@
use super::{BoxTokenStream, Token, TokenStream, Tokenizer};
use super::{Token, TokenStream, Tokenizer};
use crate::schema::FACET_SEP_BYTE;
/// The `FacetTokenizer` process a `Facet` binary representation
@@ -26,7 +26,8 @@ pub struct FacetTokenStream<'a> {
}
impl Tokenizer for FacetTokenizer {
fn token_stream<'a>(&self, text: &'a str) -> BoxTokenStream<'a> {
type TokenStream<'a> = FacetTokenStream<'a>;
fn token_stream<'a>(&self, text: &'a str) -> FacetTokenStream<'a> {
let token = Token {
position: 0,
..Default::default()
@@ -36,7 +37,6 @@ impl Tokenizer for FacetTokenizer {
state: State::RootFacetNotEmitted, //< pos is the first char that has not been processed yet.
token,
}
.into()
}
}
@@ -87,7 +87,7 @@ mod tests {
use super::FacetTokenizer;
use crate::schema::Facet;
use crate::tokenizer::{Token, Tokenizer};
use crate::tokenizer::{Token, TokenStream, Tokenizer};
#[test]
fn test_facet_tokenizer() {

View File

@@ -1,29 +1,42 @@
use std::mem;
use super::{Token, TokenFilter, TokenStream};
use crate::tokenizer::BoxTokenStream;
impl TokenFilter for LowerCaser {
fn transform<'a>(&self, token_stream: BoxTokenStream<'a>) -> BoxTokenStream<'a> {
BoxTokenStream::from(LowerCaserTokenStream {
tail: token_stream,
buffer: String::with_capacity(100),
})
}
}
use super::{Token, TokenFilter, TokenStream, Tokenizer};
/// Token filter that lowercase terms.
#[derive(Clone)]
pub struct LowerCaser;
pub struct LowerCaserTokenStream<'a> {
impl TokenFilter for LowerCaser {
type Tokenizer<T: Tokenizer> = LowerCaserFilter<T>;
fn transform<T: Tokenizer>(self, tokenizer: T) -> Self::Tokenizer<T> {
LowerCaserFilter(tokenizer)
}
}
#[derive(Clone)]
pub struct LowerCaserFilter<T>(T);
impl<T: Tokenizer> Tokenizer for LowerCaserFilter<T> {
type TokenStream<'a> = LowerCaserTokenStream<T::TokenStream<'a>>;
fn token_stream<'a>(&self, text: &'a str) -> Self::TokenStream<'a> {
LowerCaserTokenStream {
tail: self.0.token_stream(text),
buffer: String::new(),
}
}
}
pub struct LowerCaserTokenStream<T> {
buffer: String,
tail: BoxTokenStream<'a>,
tail: T,
}
// writes a lowercased version of text into output.
fn to_lowercase_unicode(text: &str, output: &mut String) {
output.clear();
output.reserve(50);
for c in text.chars() {
// Contrary to the std, we do not take care of sigma special case.
// This will have an normalizationo effect, which is ok for search.
@@ -31,7 +44,7 @@ fn to_lowercase_unicode(text: &str, output: &mut String) {
}
}
impl<'a> TokenStream for LowerCaserTokenStream<'a> {
impl<T: TokenStream> TokenStream for LowerCaserTokenStream<T> {
fn advance(&mut self) -> bool {
if !self.tail.advance() {
return false;
@@ -73,8 +86,9 @@ mod tests {
}
fn token_stream_helper(text: &str) -> Vec<Token> {
let mut token_stream = TextAnalyzer::from(SimpleTokenizer)
let mut token_stream = TextAnalyzer::builder(SimpleTokenizer)
.filter(LowerCaser)
.build()
.token_stream(text);
let mut tokens = vec![];
let mut add_token = |token: &Token| {

View File

@@ -66,10 +66,11 @@
//! ```rust
//! use tantivy::tokenizer::*;
//!
//! let en_stem = TextAnalyzer::from(SimpleTokenizer)
//! let en_stem = TextAnalyzer::builder(SimpleTokenizer)
//! .filter(RemoveLongFilter::limit(40))
//! .filter(LowerCaser)
//! .filter(Stemmer::new(Language::English));
//! .filter(Stemmer::new(Language::English))
//! .build();
//! ```
//!
//! Once your tokenizer is defined, you need to
@@ -112,9 +113,10 @@
//! let index = Index::create_in_ram(schema);
//!
//! // We need to register our tokenizer :
//! let custom_en_tokenizer = TextAnalyzer::from(SimpleTokenizer)
//! let custom_en_tokenizer = TextAnalyzer::builder(SimpleTokenizer)
//! .filter(RemoveLongFilter::limit(40))
//! .filter(LowerCaser);
//! .filter(LowerCaser)
//! .build();
//! index
//! .tokenizers()
//! .register("custom_en", custom_en_tokenizer);
@@ -137,9 +139,7 @@ mod tokenizer;
mod tokenizer_manager;
mod whitespace_tokenizer;
pub use tokenizer_api::{
BoxTokenFilter, BoxTokenStream, Token, TokenFilter, TokenStream, Tokenizer,
};
pub use tokenizer_api::{BoxTokenStream, Token, TokenFilter, TokenStream, Tokenizer};
pub use self::alphanum_only::AlphaNumOnlyFilter;
pub use self::ascii_folding_filter::AsciiFoldingFilter;
@@ -237,10 +237,11 @@ pub mod tests {
let tokenizer_manager = TokenizerManager::default();
tokenizer_manager.register(
"el_stem",
TextAnalyzer::from(SimpleTokenizer)
TextAnalyzer::builder(SimpleTokenizer)
.filter(RemoveLongFilter::limit(40))
.filter(LowerCaser)
.filter(Stemmer::new(Language::Greek)),
.filter(Stemmer::new(Language::Greek))
.build(),
);
let en_tokenizer = tokenizer_manager.get("el_stem").unwrap();
let mut tokens: Vec<Token> = vec![];

View File

@@ -1,5 +1,4 @@
use super::{Token, TokenStream, Tokenizer};
use crate::tokenizer::BoxTokenStream;
/// Tokenize the text by splitting words into n-grams of the given size(s)
///
@@ -132,8 +131,9 @@ pub struct NgramTokenStream<'a> {
}
impl Tokenizer for NgramTokenizer {
fn token_stream<'a>(&self, text: &'a str) -> BoxTokenStream<'a> {
From::from(NgramTokenStream {
type TokenStream<'a> = NgramTokenStream<'a>;
fn token_stream<'a>(&self, text: &'a str) -> NgramTokenStream<'a> {
NgramTokenStream {
ngram_charidx_iterator: StutteringIterator::new(
CodepointFrontiers::for_str(text),
self.min_gram,
@@ -142,7 +142,7 @@ impl Tokenizer for NgramTokenizer {
prefix_only: self.prefix_only,
text,
token: Token::default(),
})
}
}
}
@@ -303,9 +303,9 @@ mod tests {
use super::{utf8_codepoint_width, CodepointFrontiers, NgramTokenizer, StutteringIterator};
use crate::tokenizer::tests::assert_token;
use crate::tokenizer::{BoxTokenStream, Token, Tokenizer};
use crate::tokenizer::{Token, TokenStream, Tokenizer};
fn test_helper(mut tokenizer: BoxTokenStream) -> Vec<Token> {
fn test_helper<T: TokenStream>(mut tokenizer: T) -> Vec<Token> {
let mut tokens: Vec<Token> = vec![];
tokenizer.process(&mut |token: &Token| tokens.push(token.clone()));
tokens

View File

@@ -1,5 +1,4 @@
use super::{Token, TokenStream, Tokenizer};
use crate::tokenizer::BoxTokenStream;
/// For each value of the field, emit a single unprocessed token.
#[derive(Clone)]
@@ -11,7 +10,8 @@ pub struct RawTokenStream {
}
impl Tokenizer for RawTokenizer {
fn token_stream<'a>(&self, text: &'a str) -> BoxTokenStream<'a> {
type TokenStream<'a> = RawTokenStream;
fn token_stream(&self, text: &str) -> RawTokenStream {
let token = Token {
offset_from: 0,
offset_to: text.len(),
@@ -23,7 +23,6 @@ impl Tokenizer for RawTokenizer {
token,
has_token: true,
}
.into()
}
}

View File

@@ -1,6 +1,6 @@
use regex::Regex;
use super::{BoxTokenStream, Token, TokenStream, Tokenizer};
use super::{Token, TokenStream, Tokenizer};
use crate::TantivyError;
/// Tokenize the text by using a regex pattern to split.
@@ -60,13 +60,14 @@ impl RegexTokenizer {
}
impl Tokenizer for RegexTokenizer {
fn token_stream<'a>(&self, text: &'a str) -> BoxTokenStream<'a> {
BoxTokenStream::from(RegexTokenStream {
type TokenStream<'a> = RegexTokenStream<'a>;
fn token_stream<'a>(&self, text: &'a str) -> RegexTokenStream<'a> {
RegexTokenStream {
regex: self.regex.clone(),
text,
token: Token::default(),
cursor: 0,
})
}
}
}

View File

@@ -2,8 +2,9 @@
//! ```rust
//! use tantivy::tokenizer::*;
//!
//! let tokenizer = TextAnalyzer::from(SimpleTokenizer)
//! .filter(RemoveLongFilter::limit(5));
//! let tokenizer = TextAnalyzer::builder(SimpleTokenizer)
//! .filter(RemoveLongFilter::limit(5))
//! .build();
//!
//! let mut stream = tokenizer.token_stream("toolong nice");
//! // because `toolong` is more than 5 characters, it is filtered
@@ -11,8 +12,7 @@
//! assert_eq!(stream.next().unwrap().text, "nice");
//! assert!(stream.next().is_none());
//! ```
use super::{Token, TokenFilter, TokenStream};
use crate::tokenizer::BoxTokenStream;
use super::{Token, TokenFilter, TokenStream, Tokenizer};
/// `RemoveLongFilter` removes tokens that are longer
/// than a given number of bytes (in UTF-8 representation).
@@ -31,27 +31,46 @@ impl RemoveLongFilter {
}
}
impl<'a> RemoveLongFilterStream<'a> {
impl<T> RemoveLongFilterStream<T> {
fn predicate(&self, token: &Token) -> bool {
token.text.len() < self.token_length_limit
}
}
impl TokenFilter for RemoveLongFilter {
fn transform<'a>(&self, token_stream: BoxTokenStream<'a>) -> BoxTokenStream<'a> {
BoxTokenStream::from(RemoveLongFilterStream {
token_length_limit: self.length_limit,
tail: token_stream,
})
type Tokenizer<T: Tokenizer> = RemoveLongFilterWrapper<T>;
fn transform<T: Tokenizer>(self, tokenizer: T) -> RemoveLongFilterWrapper<T> {
RemoveLongFilterWrapper {
length_limit: self.length_limit,
inner: tokenizer,
}
}
}
pub struct RemoveLongFilterStream<'a> {
token_length_limit: usize,
tail: BoxTokenStream<'a>,
#[derive(Clone)]
pub struct RemoveLongFilterWrapper<T: Tokenizer> {
length_limit: usize,
inner: T,
}
impl<'a> TokenStream for RemoveLongFilterStream<'a> {
impl<T: Tokenizer> Tokenizer for RemoveLongFilterWrapper<T> {
type TokenStream<'a> = RemoveLongFilterStream<T::TokenStream<'a>>;
fn token_stream<'a>(&self, text: &'a str) -> Self::TokenStream<'a> {
RemoveLongFilterStream {
token_length_limit: self.length_limit,
tail: self.inner.token_stream(text),
}
}
}
pub struct RemoveLongFilterStream<T> {
token_length_limit: usize,
tail: T,
}
impl<T: TokenStream> TokenStream for RemoveLongFilterStream<T> {
fn advance(&mut self) -> bool {
while self.tail.advance() {
if self.predicate(self.tail.token()) {
@@ -84,7 +103,9 @@ mod tests {
}
fn token_stream_helper(text: &str) -> Vec<Token> {
let a = TextAnalyzer::from(SimpleTokenizer).filter(RemoveLongFilter::limit(6));
let a = TextAnalyzer::builder(SimpleTokenizer)
.filter(RemoveLongFilter::limit(6))
.build();
let mut token_stream = a.token_stream(text);
let mut tokens: Vec<Token> = vec![];
let mut add_token = |token: &Token| {

View File

@@ -1,6 +1,6 @@
use std::str::CharIndices;
use super::{BoxTokenStream, Token, TokenStream, Tokenizer};
use super::{Token, TokenStream, Tokenizer};
/// Tokenize the text by splitting on whitespaces and punctuation.
#[derive(Clone)]
@@ -13,12 +13,13 @@ pub struct SimpleTokenStream<'a> {
}
impl Tokenizer for SimpleTokenizer {
fn token_stream<'a>(&self, text: &'a str) -> BoxTokenStream<'a> {
BoxTokenStream::from(SimpleTokenStream {
type TokenStream<'a> = SimpleTokenStream<'a>;
fn token_stream<'a>(&self, text: &'a str) -> SimpleTokenStream<'a> {
SimpleTokenStream {
text,
chars: text.char_indices(),
token: Token::default(),
})
}
}
}

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind, StateID};
use super::{BoxTokenStream, Token, TokenFilter, TokenStream};
use super::{Token, TokenFilter, TokenStream, Tokenizer};
/// A [`TokenFilter`] which splits compound words into their parts
/// based on a given dictionary.
@@ -23,9 +23,11 @@ use super::{BoxTokenStream, Token, TokenFilter, TokenStream};
/// use tantivy::tokenizer::{SimpleTokenizer, SplitCompoundWords, TextAnalyzer};
///
/// let tokenizer =
/// TextAnalyzer::from(SimpleTokenizer).filter(SplitCompoundWords::from_dictionary([
/// TextAnalyzer::builder(SimpleTokenizer)
/// .filter(SplitCompoundWords::from_dictionary([
/// "dampf", "schiff", "fahrt", "brot", "backen", "automat",
/// ]));
/// ]))
/// .build();
///
/// let mut stream = tokenizer.token_stream("dampfschifffahrt");
/// assert_eq!(stream.next().unwrap().text, "dampf");
@@ -76,24 +78,45 @@ impl<S: StateID> SplitCompoundWords<S> {
}
impl<S: StateID + Send + Sync + 'static> TokenFilter for SplitCompoundWords<S> {
fn transform<'a>(&self, stream: BoxTokenStream<'a>) -> BoxTokenStream<'a> {
BoxTokenStream::from(SplitCompoundWordsTokenStream {
dict: self.dict.clone(),
tail: stream,
cuts: Vec::new(),
parts: Vec::new(),
})
type Tokenizer<T: Tokenizer> = SplitCompoundWordsFilter<T, S>;
fn transform<T: Tokenizer>(self, tokenizer: T) -> SplitCompoundWordsFilter<T, S> {
SplitCompoundWordsFilter {
dict: self.dict,
inner: tokenizer,
}
}
}
struct SplitCompoundWordsTokenStream<'a, S: StateID> {
#[derive(Clone)]
pub struct SplitCompoundWordsFilter<T, S: StateID> {
dict: Arc<AhoCorasick<S>>,
tail: BoxTokenStream<'a>,
inner: T,
}
impl<T: Tokenizer, S: StateID + Send + Sync + 'static> Tokenizer
for SplitCompoundWordsFilter<T, S>
{
type TokenStream<'a> = SplitCompoundWordsTokenStream<T::TokenStream<'a>, S>;
fn token_stream<'a>(&self, text: &'a str) -> Self::TokenStream<'a> {
SplitCompoundWordsTokenStream {
dict: self.dict.clone(),
tail: self.inner.token_stream(text),
cuts: Vec::new(),
parts: Vec::new(),
}
}
}
pub struct SplitCompoundWordsTokenStream<T, S: StateID> {
dict: Arc<AhoCorasick<S>>,
tail: T,
cuts: Vec<usize>,
parts: Vec<Token>,
}
impl<'a, S: StateID> SplitCompoundWordsTokenStream<'a, S> {
impl<T: TokenStream, S: StateID> SplitCompoundWordsTokenStream<T, S> {
// Will use `self.cuts` to fill `self.parts` if `self.tail.token()`
// can fully be split into consecutive matches against `self.dict`.
fn split(&mut self) {
@@ -129,7 +152,7 @@ impl<'a, S: StateID> SplitCompoundWordsTokenStream<'a, S> {
}
}
impl<'a, S: StateID> TokenStream for SplitCompoundWordsTokenStream<'a, S> {
impl<T: TokenStream, S: StateID> TokenStream for SplitCompoundWordsTokenStream<T, S> {
fn advance(&mut self) -> bool {
self.parts.pop();
@@ -165,8 +188,9 @@ mod tests {
#[test]
fn splitting_compound_words_works() {
let tokenizer = TextAnalyzer::from(SimpleTokenizer)
.filter(SplitCompoundWords::from_dictionary(["foo", "bar"]));
let tokenizer = TextAnalyzer::builder(SimpleTokenizer)
.filter(SplitCompoundWords::from_dictionary(["foo", "bar"]))
.build();
{
let mut stream = tokenizer.token_stream("");

View File

@@ -4,8 +4,7 @@ use std::mem;
use rust_stemmers::{self, Algorithm};
use serde::{Deserialize, Serialize};
use super::{Token, TokenFilter, TokenStream};
use crate::tokenizer::BoxTokenStream;
use super::{Token, TokenFilter, TokenStream, Tokenizer};
/// Available stemmer languages.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Copy, Clone)]
@@ -82,23 +81,42 @@ impl Default for Stemmer {
}
impl TokenFilter for Stemmer {
fn transform<'a>(&self, token_stream: BoxTokenStream<'a>) -> BoxTokenStream<'a> {
let inner_stemmer = rust_stemmers::Stemmer::create(self.stemmer_algorithm);
BoxTokenStream::from(StemmerTokenStream {
tail: token_stream,
stemmer: inner_stemmer,
buffer: String::new(),
})
type Tokenizer<T: Tokenizer> = StemmerFilter<T>;
fn transform<T: Tokenizer>(self, tokenizer: T) -> StemmerFilter<T> {
StemmerFilter {
stemmer_algorithm: self.stemmer_algorithm,
inner: tokenizer,
}
}
}
pub struct StemmerTokenStream<'a> {
tail: BoxTokenStream<'a>,
#[derive(Clone)]
pub struct StemmerFilter<T> {
stemmer_algorithm: Algorithm,
inner: T,
}
impl<T: Tokenizer> Tokenizer for StemmerFilter<T> {
type TokenStream<'a> = StemmerTokenStream<T::TokenStream<'a>>;
fn token_stream<'a>(&self, text: &'a str) -> Self::TokenStream<'a> {
let stemmer = rust_stemmers::Stemmer::create(self.stemmer_algorithm);
StemmerTokenStream {
tail: self.inner.token_stream(text),
stemmer,
buffer: String::new(),
}
}
}
pub struct StemmerTokenStream<T> {
tail: T,
stemmer: rust_stemmers::Stemmer,
buffer: String,
}
impl<'a> TokenStream for StemmerTokenStream<'a> {
impl<T: TokenStream> TokenStream for StemmerTokenStream<T> {
fn advance(&mut self) -> bool {
if !self.tail.advance() {
return false;

View File

@@ -2,8 +2,9 @@
//! ```rust
//! use tantivy::tokenizer::*;
//!
//! let tokenizer = TextAnalyzer::from(SimpleTokenizer)
//! .filter(StopWordFilter::remove(vec!["the".to_string(), "is".to_string()]));
//! let tokenizer = TextAnalyzer::builder(SimpleTokenizer)
//! .filter(StopWordFilter::remove(vec!["the".to_string(), "is".to_string()]))
//! .build();
//!
//! let mut stream = tokenizer.token_stream("the fox is crafty");
//! assert_eq!(stream.next().unwrap().text, "fox");
@@ -20,7 +21,7 @@ use rustc_hash::FxHashSet;
#[cfg(feature = "stopwords")]
use super::Language;
use super::{BoxTokenStream, Token, TokenFilter, TokenStream};
use super::{Token, TokenFilter, TokenStream, Tokenizer};
/// `TokenFilter` that removes stop words from a token stream
#[derive(Clone)]
@@ -69,27 +70,46 @@ impl StopWordFilter {
}
}
pub struct StopWordFilterStream<'a> {
words: Arc<FxHashSet<String>>,
tail: BoxTokenStream<'a>,
}
impl TokenFilter for StopWordFilter {
fn transform<'a>(&self, token_stream: BoxTokenStream<'a>) -> BoxTokenStream<'a> {
BoxTokenStream::from(StopWordFilterStream {
words: self.words.clone(),
tail: token_stream,
})
type Tokenizer<T: Tokenizer> = StopWordFilterWrapper<T>;
fn transform<T: Tokenizer>(self, tokenizer: T) -> StopWordFilterWrapper<T> {
StopWordFilterWrapper {
words: self.words,
inner: tokenizer,
}
}
}
impl<'a> StopWordFilterStream<'a> {
#[derive(Clone)]
pub struct StopWordFilterWrapper<T> {
words: Arc<FxHashSet<String>>,
inner: T,
}
impl<T: Tokenizer> Tokenizer for StopWordFilterWrapper<T> {
type TokenStream<'a> = StopWordFilterStream<T::TokenStream<'a>>;
fn token_stream<'a>(&self, text: &'a str) -> Self::TokenStream<'a> {
StopWordFilterStream {
words: self.words.clone(),
tail: self.inner.token_stream(text),
}
}
}
pub struct StopWordFilterStream<T> {
words: Arc<FxHashSet<String>>,
tail: T,
}
impl<T> StopWordFilterStream<T> {
fn predicate(&self, token: &Token) -> bool {
!self.words.contains(&token.text)
}
}
impl<'a> TokenStream for StopWordFilterStream<'a> {
impl<T: TokenStream> TokenStream for StopWordFilterStream<T> {
fn advance(&mut self) -> bool {
while self.tail.advance() {
if self.predicate(self.tail.token()) {
@@ -131,7 +151,9 @@ mod tests {
"am".to_string(),
"i".to_string(),
];
let a = TextAnalyzer::from(SimpleTokenizer).filter(StopWordFilter::remove(stops));
let a = TextAnalyzer::builder(SimpleTokenizer)
.filter(StopWordFilter::remove(stops))
.build();
let mut token_stream = a.token_stream(text);
let mut tokens: Vec<Token> = vec![];
let mut add_token = |token: &Token| {

View File

@@ -1,15 +1,12 @@
/// The tokenizer module contains all of the tools used to process
/// text in `tantivy`.
use tokenizer_api::{BoxTokenFilter, BoxTokenStream, Tokenizer};
use tokenizer_api::{BoxTokenStream, BoxableTokenizer, TokenFilter, Tokenizer};
use crate::tokenizer::empty_tokenizer::EmptyTokenizer;
/// `TextAnalyzer` tokenizes an input text into tokens and modifies the resulting `TokenStream`.
///
/// It simply wraps a `Tokenizer` and a list of `TokenFilter` that are applied sequentially.
pub struct TextAnalyzer {
tokenizer: Box<dyn Tokenizer>,
token_filters: Vec<BoxTokenFilter>,
tokenizer: Box<dyn BoxableTokenizer>,
}
impl Default for TextAnalyzer {
@@ -18,52 +15,21 @@ impl Default for TextAnalyzer {
}
}
impl<T: Tokenizer> From<T> for TextAnalyzer {
impl<T: Tokenizer + Clone> From<T> for TextAnalyzer {
fn from(tokenizer: T) -> Self {
TextAnalyzer::new(tokenizer, Vec::new())
TextAnalyzer::builder(tokenizer).build()
}
}
impl TextAnalyzer {
/// Creates a new `TextAnalyzer` given a tokenizer and a vector of `BoxTokenFilter`.
///
/// When creating a `TextAnalyzer` from a `Tokenizer` alone, prefer using
/// `TextAnalyzer::from(tokenizer)`.
pub fn new<T: Tokenizer>(tokenizer: T, token_filters: Vec<BoxTokenFilter>) -> TextAnalyzer {
TextAnalyzer {
tokenizer: Box::new(tokenizer),
token_filters,
}
}
/// Appends a token filter to the current tokenizer.
///
/// The method consumes the current `TokenStream` and returns a
/// new one.
///
/// # Example
///
/// ```rust
/// use tantivy::tokenizer::*;
///
/// let en_stem = TextAnalyzer::from(SimpleTokenizer)
/// .filter(RemoveLongFilter::limit(40))
/// .filter(LowerCaser)
/// .filter(Stemmer::default());
/// ```
#[must_use]
pub fn filter<F: Into<BoxTokenFilter>>(mut self, token_filter: F) -> Self {
self.token_filters.push(token_filter.into());
self
/// Create a new TextAnalyzerBuilder
pub fn builder<T: Tokenizer>(tokenizer: T) -> TextAnalyzerBuilder<T> {
TextAnalyzerBuilder { tokenizer }
}
/// Creates a token stream for a given `str`.
pub fn token_stream<'a>(&self, text: &'a str) -> BoxTokenStream<'a> {
let mut token_stream = self.tokenizer.token_stream(text);
for token_filter in &self.token_filters {
token_stream = token_filter.transform(token_stream);
}
token_stream
self.tokenizer.box_token_stream(text)
}
}
@@ -71,11 +37,39 @@ impl Clone for TextAnalyzer {
fn clone(&self) -> Self {
TextAnalyzer {
tokenizer: self.tokenizer.box_clone(),
token_filters: self
.token_filters
.iter()
.map(|token_filter| token_filter.box_clone())
.collect(),
}
}
}
/// Builder helper for [`TextAnalyzer`]
pub struct TextAnalyzerBuilder<T> {
tokenizer: T,
}
impl<T: Tokenizer> TextAnalyzerBuilder<T> {
/// Appends a token filter to the current builder.
///
/// # Example
///
/// ```rust
/// use tantivy::tokenizer::*;
///
/// let en_stem = TextAnalyzer::builder(SimpleTokenizer)
/// .filter(RemoveLongFilter::limit(40))
/// .filter(LowerCaser)
/// .filter(Stemmer::default())
/// .build();
/// ```
pub fn filter<F: TokenFilter>(self, token_filter: F) -> TextAnalyzerBuilder<F::Tokenizer<T>> {
TextAnalyzerBuilder {
tokenizer: token_filter.transform(self.tokenizer),
}
}
/// Finalize building the TextAnalyzer
pub fn build(self) -> TextAnalyzer {
TextAnalyzer {
tokenizer: Box::new(self.tokenizer),
}
}
}

View File

@@ -61,16 +61,18 @@ impl Default for TokenizerManager {
manager.register("raw", RawTokenizer);
manager.register(
"default",
TextAnalyzer::from(SimpleTokenizer)
TextAnalyzer::builder(SimpleTokenizer)
.filter(RemoveLongFilter::limit(40))
.filter(LowerCaser),
.filter(LowerCaser)
.build(),
);
manager.register(
"en_stem",
TextAnalyzer::from(SimpleTokenizer)
TextAnalyzer::builder(SimpleTokenizer)
.filter(RemoveLongFilter::limit(40))
.filter(LowerCaser)
.filter(Stemmer::new(Language::English)),
.filter(Stemmer::new(Language::English))
.build(),
);
manager.register("whitespace", WhitespaceTokenizer);
manager

View File

@@ -1,6 +1,6 @@
use std::str::CharIndices;
use super::{BoxTokenStream, Token, TokenStream, Tokenizer};
use super::{Token, TokenStream, Tokenizer};
/// Tokenize the text by splitting on whitespaces.
#[derive(Clone)]
@@ -13,12 +13,13 @@ pub struct WhitespaceTokenStream<'a> {
}
impl Tokenizer for WhitespaceTokenizer {
fn token_stream<'a>(&self, text: &'a str) -> BoxTokenStream<'a> {
BoxTokenStream::from(WhitespaceTokenStream {
type TokenStream<'a> = WhitespaceTokenStream<'a>;
fn token_stream<'a>(&self, text: &'a str) -> WhitespaceTokenStream<'a> {
WhitespaceTokenStream {
text,
chars: text.char_indices(),
token: Token::default(),
})
}
}
}

View File

@@ -1,3 +1,4 @@
use std::cmp::Ordering;
use std::io;
use std::marker::PhantomData;
use std::ops::{Bound, RangeBounds};
@@ -96,6 +97,14 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
Ok(TSSTable::delta_reader(data))
}
pub(crate) fn sstable_delta_reader_block(
&self,
block_addr: BlockAddr,
) -> io::Result<DeltaReader<'static, TSSTable::ValueReader>> {
let data = self.sstable_slice.read_bytes_slice(block_addr.byte_range)?;
Ok(TSSTable::delta_reader(data))
}
/// This function returns a file slice covering a set of sstable blocks
/// that include the key range passed in arguments. Optionally returns
/// only block for up to `limit` matching terms.
@@ -215,13 +224,43 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
};
let mut term_ord = block_addr.first_ordinal;
let mut sstable_reader = self.sstable_reader_block(block_addr)?;
while sstable_reader.advance()? {
if sstable_reader.key() == key_bytes {
return Ok(Some(term_ord));
let mut ok_bytes = 0;
let mut sstable_delta_reader = self.sstable_delta_reader_block(block_addr)?;
while sstable_delta_reader.advance()? {
let prefix_len = sstable_delta_reader.common_prefix_len();
let suffix = sstable_delta_reader.suffix();
match prefix_len.cmp(&ok_bytes) {
Ordering::Less => return Ok(None), // poped bytes already matched => too far
Ordering::Equal => (),
Ordering::Greater => {
// the ok prefix is less than current entry prefix => continue to next elem
term_ord += 1;
continue;
}
}
// we have ok_bytes byte of common prefix, check if this key adds more
for (key_byte, suffix_byte) in key_bytes[ok_bytes..].iter().zip(suffix) {
match suffix_byte.cmp(key_byte) {
Ordering::Less => break, // byte too small
Ordering::Equal => ok_bytes += 1, // new matching byte
Ordering::Greater => return Ok(None), // too far
}
}
if ok_bytes == key_bytes.len() {
if prefix_len + suffix.len() == ok_bytes {
return Ok(Some(term_ord));
} else {
// current key is a prefix of current element, not a match
return Ok(None);
}
}
term_ord += 1;
}
Ok(None)
}
@@ -240,14 +279,14 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> {
let first_ordinal = block_addr.first_ordinal;
// then search inside that block only
let mut sstable_reader = self.sstable_reader_block(block_addr)?;
let mut sstable_delta_reader = self.sstable_delta_reader_block(block_addr)?;
for _ in first_ordinal..=ord {
if !sstable_reader.advance()? {
if !sstable_delta_reader.advance()? {
return Ok(false);
}
bytes.truncate(sstable_delta_reader.common_prefix_len());
bytes.extend_from_slice(sstable_delta_reader.suffix());
}
bytes.clear();
bytes.extend_from_slice(sstable_reader.key());
Ok(true)
}
@@ -456,6 +495,12 @@ mod tests {
slice.restrict(0..0);
assert!(dic.get(b"~~~").unwrap().is_none());
assert!(dic.term_ord(b"~~~").unwrap().is_none());
slice.restrict(0..slice.bytes.len());
// between 1000F and 10010, test case where matched prefix > prefix kept
assert!(dic.term_ord(b"1000G").unwrap().is_none());
// shorter than 10000, tests prefix case
assert!(dic.term_ord(b"1000").unwrap().is_none());
}
#[test]

View File

@@ -6,4 +6,5 @@ license = "MIT"
[dependencies]
murmurhash32 = "0.3"
byteorder = "1"
common = { version = "0.5", path = "../common/", package = "tantivy-common" }

View File

@@ -1,5 +1,6 @@
use std::{iter, mem, slice};
use byteorder::{ByteOrder, NativeEndian};
use murmurhash32::murmurhash2;
use super::{Addr, MemoryArena};
@@ -154,7 +155,7 @@ impl ArenaHashMap {
#[inline]
fn get_key_value(&self, addr: Addr) -> (&[u8], Addr) {
let data = self.memory_arena.slice_from(addr);
let key_bytes_len = u16::from_ne_bytes(data[..2].try_into().unwrap()) as usize;
let key_bytes_len = NativeEndian::read_u16(data) as usize;
let key_bytes: &[u8] = &data[2..][..key_bytes_len];
(key_bytes, addr.offset(2u32 + key_bytes_len as u32))
}
@@ -272,7 +273,7 @@ impl ArenaHashMap {
let key_addr = self.memory_arena.allocate_space(num_bytes);
{
let data = self.memory_arena.slice_mut(key_addr, num_bytes);
data[..2].copy_from_slice(&u16::to_ne_bytes(key.len() as u16));
NativeEndian::write_u16(data, key.len() as u16);
let stop = 2 + key.len();
data[2..stop].copy_from_slice(key);
store(&mut data[stop..], val);

View File

@@ -42,28 +42,31 @@ impl Default for Token {
/// `Tokenizer` are in charge of splitting text into a stream of token
/// before indexing.
///
/// # Warning
///
/// This API may change to use associated types.
pub trait Tokenizer: 'static + Send + Sync + TokenizerClone {
pub trait Tokenizer: 'static + Clone + Send + Sync {
/// The token stream returned by this Tokenizer.
type TokenStream<'a>: TokenStream;
/// Creates a token stream for a given `str`.
fn token_stream<'a>(&self, text: &'a str) -> BoxTokenStream<'a>;
fn token_stream<'a>(&self, text: &'a str) -> Self::TokenStream<'a>;
}
pub trait TokenizerClone {
fn box_clone(&self) -> Box<dyn Tokenizer>;
/// A boxable `Tokenizer`, with its `TokenStream` type erased.
pub trait BoxableTokenizer: 'static + Send + Sync {
/// Creates a boxed token stream for a given `str`.
fn box_token_stream<'a>(&self, text: &'a str) -> BoxTokenStream<'a>;
/// Clone this tokenizer.
fn box_clone(&self) -> Box<dyn BoxableTokenizer>;
}
impl<T: Tokenizer + Clone> TokenizerClone for T {
fn box_clone(&self) -> Box<dyn Tokenizer> {
impl<T: Tokenizer> BoxableTokenizer for T {
fn box_token_stream<'a>(&self, text: &'a str) -> BoxTokenStream<'a> {
self.token_stream(text).into()
}
fn box_clone(&self) -> Box<dyn BoxableTokenizer> {
Box::new(self.clone())
}
}
/// Simple wrapper of `Box<dyn TokenStream + 'a>`.
///
/// See [`TokenStream`] for more information.
pub struct BoxTokenStream<'a>(Box<dyn TokenStream + 'a>);
impl<'a, T> From<T> for BoxTokenStream<'a>
@@ -139,39 +142,13 @@ pub trait TokenStream {
}
}
/// Simple wrapper of `Box<dyn TokenFilter + 'a>`.
///
/// See [`TokenFilter`] for more information.
pub struct BoxTokenFilter(Box<dyn TokenFilter>);
impl Deref for BoxTokenFilter {
type Target = dyn TokenFilter;
fn deref(&self) -> &dyn TokenFilter {
&*self.0
}
}
impl<T: TokenFilter> From<T> for BoxTokenFilter {
fn from(tokenizer: T) -> BoxTokenFilter {
BoxTokenFilter(Box::new(tokenizer))
}
}
pub trait TokenFilterClone {
fn box_clone(&self) -> BoxTokenFilter;
}
/// Trait for the pluggable components of `Tokenizer`s.
pub trait TokenFilter: 'static + Send + Sync + TokenFilterClone {
/// Wraps a token stream and returns the modified one.
fn transform<'a>(&self, token_stream: BoxTokenStream<'a>) -> BoxTokenStream<'a>;
}
impl<T: TokenFilter + Clone> TokenFilterClone for T {
fn box_clone(&self) -> BoxTokenFilter {
BoxTokenFilter::from(self.clone())
}
pub trait TokenFilter: 'static + Send + Sync {
/// The Tokenizer type returned by this filter, typically parametrized by the underlying
/// Tokenizer.
type Tokenizer<T: Tokenizer>: Tokenizer;
/// Wraps a Tokenizer and returns a new one.
fn transform<T: Tokenizer>(self, tokenizer: T) -> Self::Tokenizer<T>;
}
#[cfg(test)]