diff --git a/CHANGELOG.md b/CHANGELOG.md index d63ccef7d..0263e8865 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,64 @@ -Tantivy 0.21.2 +Tantivy 0.22 ================================ -#### Bugfixes -- Bugfix: Merge operations would panic for JsonObject with position enabled, when they contain numbers or booleans. [#2251](https://github.com/quickwit-oss/tantivy/issues/2251). -#### Features/Improvements +Tantivy 0.22 will be able to read indices created with Tantivy 0.21. + +#### Bugfixes +- Fix null byte handling in JSON paths (null bytes in json keys caused panic during indexing) [#2345](https://github.com/quickwit-oss/tantivy/pull/2345)(@PSeitz) +- Fix bug that can cause `get_docids_for_value_range` to panic. [#2295](https://github.com/quickwit-oss/tantivy/pull/2295)(@fulmicoton) +- Avoid 1 document indices by increasing min memory to 15MB for indexing [#2176](https://github.com/quickwit-oss/tantivy/pull/2176)(@PSeitz) +- Fix merge panic for JSON fields [#2284](https://github.com/quickwit-oss/tantivy/pull/2284)(@PSeitz) +- Fix bug occurring when merging JSON object indexed with positions. 
[#2253](https://github.com/quickwit-oss/tantivy/pull/2253)(@fulmicoton) +- Fix empty DateHistogram gap bug [#2183](https://github.com/quickwit-oss/tantivy/pull/2183)(@PSeitz) +- Fix range query end check (fields with less than 1 value per doc are affected) [#2226](https://github.com/quickwit-oss/tantivy/pull/2226)(@PSeitz) +- Handle exclusive out of bounds ranges on fastfield range queries [#2174](https://github.com/quickwit-oss/tantivy/pull/2174)(@PSeitz) + +#### Breaking API Changes +- rename ReloadPolicy onCommit to onCommitWithDelay [#2235](https://github.com/quickwit-oss/tantivy/pull/2235)(@giovannicuccu) +- Move exports from the root into modules [#2220](https://github.com/quickwit-oss/tantivy/pull/2220)(@PSeitz) +- Accept field name instead of `Field` in FilterCollector [#2196](https://github.com/quickwit-oss/tantivy/pull/2196)(@PSeitz) +- remove deprecated IntOptions and DateTime [#2353](https://github.com/quickwit-oss/tantivy/pull/2353)(@PSeitz) + +#### Features/Improvements +- Tantivy documents as a trait: Index data directly without converting to tantivy types first [#2071](https://github.com/quickwit-oss/tantivy/pull/2071)(@ChillFish8) +- encode some part of posting list as -1 instead of direct values (smaller inverted indices) [#2185](https://github.com/quickwit-oss/tantivy/pull/2185)(@trinity-1686a) +- **Aggregation** + - Support to deserialize f64 from string [#2311](https://github.com/quickwit-oss/tantivy/pull/2311)(@PSeitz) + - Add a top_hits aggregator [#2198](https://github.com/quickwit-oss/tantivy/pull/2198)(@ditsuke) + - Support bool type in term aggregation [#2318](https://github.com/quickwit-oss/tantivy/pull/2318)(@PSeitz) + - Support ip addresses in term aggregation [#2319](https://github.com/quickwit-oss/tantivy/pull/2319)(@PSeitz) + - Support date type in term aggregation [#2172](https://github.com/quickwit-oss/tantivy/pull/2172)(@PSeitz) + - Support escaped dot when addressing field 
[#2250](https://github.com/quickwit-oss/tantivy/pull/2250)(@PSeitz) + +- Add ExistsQuery to check documents that have a value [#2160](https://github.com/quickwit-oss/tantivy/pull/2160)(@imotov) +- Expose TopDocs::order_by_u64_field again [#2282](https://github.com/quickwit-oss/tantivy/pull/2282)(@ditsuke) + +- **Memory/Performance** + - Faster TopN: replace BinaryHeap with TopNComputer [#2186](https://github.com/quickwit-oss/tantivy/pull/2186)(@PSeitz) + - reduce number of allocations during indexing [#2257](https://github.com/quickwit-oss/tantivy/pull/2257)(@PSeitz) + - Less Memory while indexing: docid deltas while indexing [#2249](https://github.com/quickwit-oss/tantivy/pull/2249)(@PSeitz) + - Faster indexing: use term hashmap in fastfield [#2243](https://github.com/quickwit-oss/tantivy/pull/2243)(@PSeitz) + - term hashmap remove copy in is_empty, unused unordered_id [#2229](https://github.com/quickwit-oss/tantivy/pull/2229)(@PSeitz) + - add method to fetch block of first values in columnar [#2330](https://github.com/quickwit-oss/tantivy/pull/2330)(@PSeitz) + - Faster aggregations: add fast path for full columns in fetch_block [#2328](https://github.com/quickwit-oss/tantivy/pull/2328)(@PSeitz) + - Faster sstable loading: use fst for sstable index [#2268](https://github.com/quickwit-oss/tantivy/pull/2268)(@trinity-1686a) + +- **QueryParser** + - allow newline where we allow space in query parser [#2302](https://github.com/quickwit-oss/tantivy/pull/2302)(@trinity-1686a) + - allow some mixing of occur and bool in strict query parser [#2323](https://github.com/quickwit-oss/tantivy/pull/2323)(@trinity-1686a) + - handle * inside term in lenient query parser [#2228](https://github.com/quickwit-oss/tantivy/pull/2228)(@trinity-1686a) + - add support for exists query syntax in query parser [#2170](https://github.com/quickwit-oss/tantivy/pull/2170)(@trinity-1686a) +- Add shared search executor [#2312](https://github.com/quickwit-oss/tantivy/pull/2312)(@MochiXu) +- Truncate 
keys to u16::MAX in term hashmap [#2299](https://github.com/quickwit-oss/tantivy/pull/2299)(@PSeitz) +- report if a term matched when warming up posting list [#2309](https://github.com/quickwit-oss/tantivy/pull/2309)(@trinity-1686a) +- Support json fields in FuzzyTermQuery [#2173](https://github.com/quickwit-oss/tantivy/pull/2173)(@PingXia-at) +- Read list of fields encoded in term dictionary for JSON fields [#2184](https://github.com/quickwit-oss/tantivy/pull/2184)(@PSeitz) +- add collect_block to BoxableSegmentCollector [#2331](https://github.com/quickwit-oss/tantivy/pull/2331)(@PSeitz) +- expose collect_block buffer size [#2326](https://github.com/quickwit-oss/tantivy/pull/2326)(@PSeitz) +- Forward regex parser errors [#2288](https://github.com/quickwit-oss/tantivy/pull/2288)(@adamreichold) +- Make FacetCounts defaultable and cloneable. [#2322](https://github.com/quickwit-oss/tantivy/pull/2322)(@adamreichold) +- Derive Debug for SchemaBuilder [#2254](https://github.com/quickwit-oss/tantivy/pull/2254)(@GodTamIt) +- add missing inlines to tantivy options [#2245](https://github.com/quickwit-oss/tantivy/pull/2245)(@PSeitz) Tantivy 0.21.1 ================================ diff --git a/Cargo.toml b/Cargo.toml index 7bb412a7c..3580168e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tantivy" -version = "0.21.2" +version = "0.22.0" authors = ["Paul Masurel "] license = "MIT" categories = ["database-implementations", "data-structures"] @@ -11,12 +11,12 @@ repository = "https://github.com/quickwit-oss/tantivy" readme = "README.md" keywords = ["search", "information", "retrieval"] edition = "2021" -rust-version = "1.62" +rust-version = "1.63" exclude = ["benches/*.json", "benches/*.txt"] [dependencies] oneshot = "0.1.5" -base64 = "0.21.0" +base64 = "0.22.0" byteorder = "1.4.3" crc32fast = "1.3.2" once_cell = "1.10.0" @@ -31,14 +31,14 @@ log = "0.4.16" serde = { version = "1.0.136", features = ["derive"] } serde_json = "1.0.79" num_cpus = "1.13.1" 
-fs4 = { version = "0.7.0", optional = true } +fs4 = { version = "0.8.0", optional = true } levenshtein_automata = "0.2.1" uuid = { version = "1.0.0", features = ["v4", "serde"] } crossbeam-channel = "0.5.4" rust-stemmers = "1.2.0" downcast-rs = "1.2.0" -bitpacking = { git = "https://github.com/quickwit-oss/bitpacking", rev = "f730b75", default-features = false, features = ["bitpacker4x"] } -census = "0.4.0" +bitpacking = { version = "0.9.2", default-features = false, features = ["bitpacker4x"] } +census = "0.4.2" rustc-hash = "1.1.0" thiserror = "1.0.30" htmlescape = "0.3.1" @@ -52,13 +52,13 @@ itertools = "0.12.0" measure_time = "0.8.2" arc-swap = "1.5.0" -columnar = { version= "0.2", path="./columnar", package ="tantivy-columnar" } -sstable = { version= "0.2", path="./sstable", package ="tantivy-sstable", optional = true } -stacker = { version= "0.2", path="./stacker", package ="tantivy-stacker" } -query-grammar = { version= "0.21.0", path="./query-grammar", package = "tantivy-query-grammar" } -tantivy-bitpacker = { version= "0.5", path="./bitpacker" } -common = { version= "0.6", path = "./common/", package = "tantivy-common" } -tokenizer-api = { version= "0.2", path="./tokenizer-api", package="tantivy-tokenizer-api" } +columnar = { version= "0.3", path="./columnar", package ="tantivy-columnar" } +sstable = { version= "0.3", path="./sstable", package ="tantivy-sstable", optional = true } +stacker = { version= "0.3", path="./stacker", package ="tantivy-stacker" } +query-grammar = { version= "0.22.0", path="./query-grammar", package = "tantivy-query-grammar" } +tantivy-bitpacker = { version= "0.6", path="./bitpacker" } +common = { version= "0.7", path = "./common/", package = "tantivy-common" } +tokenizer-api = { version= "0.3", path="./tokenizer-api", package="tantivy-tokenizer-api" } sketches-ddsketch = { version = "0.2.1", features = ["use_serde"] } futures-util = { version = "0.3.28", optional = true } fnv = "1.0.7" @@ -77,6 +77,10 @@ futures = "0.3.21" paste 
= "1.0.11" more-asserts = "0.3.1" rand_distr = "0.4.3" +time = { version = "0.3.10", features = ["serde-well-known", "macros"] } +postcard = { version = "1.0.4", features = [ + "use-std", +], default-features = false } [target.'cfg(not(windows))'.dev-dependencies] criterion = { version = "0.5", default-features = false } diff --git a/README.md b/README.md index 2cd8e8d76..ac6b15d45 100644 --- a/README.md +++ b/README.md @@ -5,19 +5,18 @@ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) [![Crates.io](https://img.shields.io/crates/v/tantivy.svg)](https://crates.io/crates/tantivy) -![Tantivy](https://tantivy-search.github.io/logo/tantivy-logo.png) +Tantivy, the fastest full-text search engine library written in Rust -**Tantivy** is a **full-text search engine library** written in Rust. +## Fast full-text search engine library written in Rust -It is closer to [Apache Lucene](https://lucene.apache.org/) than to [Elasticsearch](https://www.elastic.co/products/elasticsearch) or [Apache Solr](https://lucene.apache.org/solr/) in the sense it is not -an off-the-shelf search engine server, but rather a crate that can be used -to build such a search engine. +**If you are looking for an alternative to Elasticsearch or Apache Solr, check out [Quickwit](https://github.com/quickwit-oss/quickwit), our distributed search engine built on top of Tantivy.** + +Tantivy is closer to [Apache Lucene](https://lucene.apache.org/) than to [Elasticsearch](https://www.elastic.co/products/elasticsearch) or [Apache Solr](https://lucene.apache.org/solr/) in the sense it is not +an off-the-shelf search engine server, but rather a crate that can be used to build such a search engine. Tantivy is, in fact, strongly inspired by Lucene's design. -If you are looking for an alternative to Elasticsearch or Apache Solr, check out [Quickwit](https://github.com/quickwit-oss/quickwit), our search engine built on top of Tantivy. 
- -# Benchmark +## Benchmark The following [benchmark](https://tantivy-search.github.io/bench/) breakdowns performance for different types of queries/collections. @@ -28,7 +27,7 @@ Your mileage WILL vary depending on the nature of queries and their load. Details about the benchmark can be found at this [repository](https://github.com/quickwit-oss/search-benchmark-game). -# Features +## Features - Full-text search - Configurable tokenizer (stemming available for 17 Latin languages) with third party support for Chinese ([tantivy-jieba](https://crates.io/crates/tantivy-jieba) and [cang-jie](https://crates.io/crates/cang-jie)), Japanese ([lindera](https://github.com/lindera-morphology/lindera-tantivy), [Vaporetto](https://crates.io/crates/vaporetto_tantivy), and [tantivy-tokenizer-tiny-segmenter](https://crates.io/crates/tantivy-tokenizer-tiny-segmenter)) and Korean ([lindera](https://github.com/lindera-morphology/lindera-tantivy) + [lindera-ko-dic-builder](https://github.com/lindera-morphology/lindera-ko-dic-builder)) @@ -54,11 +53,11 @@ Details about the benchmark can be found at this [repository](https://github.com - Searcher Warmer API - Cheesy logo with a horse -## Non-features +### Non-features Distributed search is out of the scope of Tantivy, but if you are looking for this feature, check out [Quickwit](https://github.com/quickwit-oss/quickwit/). -# Getting started +## Getting started Tantivy works on stable Rust and supports Linux, macOS, and Windows. @@ -68,7 +67,7 @@ index documents, and search via the CLI or a small server with a REST API. It walks you through getting a Wikipedia search engine up and running in a few minutes. - [Reference doc for the last released version](https://docs.rs/tantivy/) -# How can I support this project? +## How can I support this project? There are many ways to support this project. @@ -79,16 +78,16 @@ There are many ways to support this project. 
- Contribute code (you can join [our Discord server](https://discord.gg/MT27AG5EVE)) - Talk about Tantivy around you -# Contributing code +## Contributing code We use the GitHub Pull Request workflow: reference a GitHub ticket and/or include a comprehensive commit message when opening a PR. Feel free to update CHANGELOG.md with your contribution. -## Tokenizer +### Tokenizer When implementing a tokenizer for tantivy depend on the `tantivy-tokenizer-api` crate. -## Clone and build locally +### Clone and build locally Tantivy compiles on stable Rust. To check out and run tests, you can simply run: @@ -99,7 +98,7 @@ cd tantivy cargo test ``` -# Companies Using Tantivy +## Companies Using Tantivy

Etsy  @@ -111,7 +110,7 @@ cargo test Element.io

-# FAQ +## FAQ ### Can I use Tantivy in other languages? diff --git a/benches/index-bench.rs b/benches/index-bench.rs index 00a181982..f9ae63b68 100644 --- a/benches/index-bench.rs +++ b/benches/index-bench.rs @@ -1,4 +1,4 @@ -use criterion::{criterion_group, criterion_main, Criterion, Throughput}; +use criterion::{criterion_group, criterion_main, BatchSize, Bencher, Criterion, Throughput}; use tantivy::schema::{TantivyDocument, FAST, INDEXED, STORED, STRING, TEXT}; use tantivy::{tokenizer, Index, IndexWriter}; @@ -6,8 +6,94 @@ const HDFS_LOGS: &str = include_str!("hdfs.json"); const GH_LOGS: &str = include_str!("gh.json"); const WIKI: &str = include_str!("wiki.json"); -fn get_lines(input: &str) -> Vec<&str> { - input.trim().split('\n').collect() +fn benchmark( + b: &mut Bencher, + input: &str, + schema: tantivy::schema::Schema, + commit: bool, + parse_json: bool, + is_dynamic: bool, +) { + if is_dynamic { + benchmark_dynamic_json(b, input, schema, commit, parse_json) + } else { + _benchmark(b, input, schema, commit, parse_json, |schema, doc_json| { + TantivyDocument::parse_json(&schema, doc_json).unwrap() + }) + } +} + +fn get_index(schema: tantivy::schema::Schema) -> Index { + let mut index = Index::create_in_ram(schema.clone()); + let ff_tokenizer_manager = tokenizer::TokenizerManager::default(); + ff_tokenizer_manager.register( + "raw", + tokenizer::TextAnalyzer::builder(tokenizer::RawTokenizer::default()) + .filter(tokenizer::RemoveLongFilter::limit(255)) + .build(), + ); + index.set_fast_field_tokenizers(ff_tokenizer_manager.clone()); + index +} + +fn _benchmark( + b: &mut Bencher, + input: &str, + schema: tantivy::schema::Schema, + commit: bool, + include_json_parsing: bool, + create_doc: impl Fn(&tantivy::schema::Schema, &str) -> TantivyDocument, +) { + if include_json_parsing { + let lines: Vec<&str> = input.trim().split('\n').collect(); + b.iter(|| { + let index = get_index(schema.clone()); + let mut index_writer: IndexWriter = + 
index.writer_with_num_threads(1, 100_000_000).unwrap(); + for doc_json in &lines { + let doc = create_doc(&schema, doc_json); + index_writer.add_document(doc).unwrap(); + } + if commit { + index_writer.commit().unwrap(); + } + }) + } else { + let docs: Vec<_> = input + .trim() + .split('\n') + .map(|doc_json| create_doc(&schema, doc_json)) + .collect(); + b.iter_batched( + || docs.clone(), + |docs| { + let index = get_index(schema.clone()); + let mut index_writer: IndexWriter = + index.writer_with_num_threads(1, 100_000_000).unwrap(); + for doc in docs { + index_writer.add_document(doc).unwrap(); + } + if commit { + index_writer.commit().unwrap(); + } + }, + BatchSize::SmallInput, + ) + } +} +fn benchmark_dynamic_json( + b: &mut Bencher, + input: &str, + schema: tantivy::schema::Schema, + commit: bool, + parse_json: bool, +) { + let json_field = schema.get_field("json").unwrap(); + _benchmark(b, input, schema, commit, parse_json, |_schema, doc_json| { + let json_val: serde_json::Map = + serde_json::from_str(doc_json).unwrap(); + tantivy::doc!(json_field=>json_val) + }) } pub fn hdfs_index_benchmark(c: &mut Criterion) { @@ -25,7 +111,7 @@ pub fn hdfs_index_benchmark(c: &mut Criterion) { schema_builder.add_text_field("severity", FAST); schema_builder.build() }; - let schema_with_store = { + let _schema_with_store = { let mut schema_builder = tantivy::schema::SchemaBuilder::new(); schema_builder.add_u64_field("timestamp", INDEXED | STORED); schema_builder.add_text_field("body", TEXT | STORED); @@ -34,101 +120,39 @@ pub fn hdfs_index_benchmark(c: &mut Criterion) { }; let dynamic_schema = { let mut schema_builder = tantivy::schema::SchemaBuilder::new(); - schema_builder.add_json_field("json", TEXT); + schema_builder.add_json_field("json", TEXT | FAST); schema_builder.build() }; let mut group = c.benchmark_group("index-hdfs"); group.throughput(Throughput::Bytes(HDFS_LOGS.len() as u64)); group.sample_size(20); - group.bench_function("index-hdfs-no-commit", |b| { - let 
lines = get_lines(HDFS_LOGS); - b.iter(|| { - let index = Index::create_in_ram(schema.clone()); - let index_writer: IndexWriter = index.writer_with_num_threads(1, 100_000_000).unwrap(); - for doc_json in &lines { - let doc = TantivyDocument::parse_json(&schema, doc_json).unwrap(); - index_writer.add_document(doc).unwrap(); + + let benches = [ + ("only-indexed-".to_string(), schema, false), + //("stored-".to_string(), _schema_with_store, false), + ("only-fast-".to_string(), schema_only_fast, false), + ("dynamic-".to_string(), dynamic_schema, true), + ]; + + for (prefix, schema, is_dynamic) in benches { + for commit in [false, true] { + let suffix = if commit { "with-commit" } else { "no-commit" }; + for parse_json in [false] { + // for parse_json in [false, true] { + let suffix = if parse_json { + format!("{}-with-json-parsing", suffix) + } else { + format!("{}", suffix) + }; + + let bench_name = format!("{}{}", prefix, suffix); + group.bench_function(bench_name, |b| { + benchmark(b, HDFS_LOGS, schema.clone(), commit, parse_json, is_dynamic) + }); } - }) - }); - group.bench_function("index-hdfs-with-commit", |b| { - let lines = get_lines(HDFS_LOGS); - b.iter(|| { - let index = Index::create_in_ram(schema.clone()); - let mut index_writer: IndexWriter = - index.writer_with_num_threads(1, 100_000_000).unwrap(); - for doc_json in &lines { - let doc = TantivyDocument::parse_json(&schema, doc_json).unwrap(); - index_writer.add_document(doc).unwrap(); - } - index_writer.commit().unwrap(); - }) - }); - group.bench_function("index-hdfs-no-commit-with-docstore", |b| { - let lines = get_lines(HDFS_LOGS); - b.iter(|| { - let index = Index::create_in_ram(schema_with_store.clone()); - let index_writer: IndexWriter = index.writer_with_num_threads(1, 100_000_000).unwrap(); - for doc_json in &lines { - let doc = TantivyDocument::parse_json(&schema, doc_json).unwrap(); - index_writer.add_document(doc).unwrap(); - } - }) - }); - 
group.bench_function("index-hdfs-with-commit-with-docstore", |b| { - let lines = get_lines(HDFS_LOGS); - b.iter(|| { - let index = Index::create_in_ram(schema_with_store.clone()); - let mut index_writer: IndexWriter = - index.writer_with_num_threads(1, 100_000_000).unwrap(); - for doc_json in &lines { - let doc = TantivyDocument::parse_json(&schema, doc_json).unwrap(); - index_writer.add_document(doc).unwrap(); - } - index_writer.commit().unwrap(); - }) - }); - group.bench_function("index-hdfs-no-commit-fastfield", |b| { - let lines = get_lines(HDFS_LOGS); - b.iter(|| { - let index = Index::create_in_ram(schema_only_fast.clone()); - let index_writer: IndexWriter = index.writer_with_num_threads(1, 100_000_000).unwrap(); - for doc_json in &lines { - let doc = TantivyDocument::parse_json(&schema, doc_json).unwrap(); - index_writer.add_document(doc).unwrap(); - } - }) - }); - group.bench_function("index-hdfs-with-commit-fastfield", |b| { - let lines = get_lines(HDFS_LOGS); - b.iter(|| { - let index = Index::create_in_ram(schema_only_fast.clone()); - let mut index_writer: IndexWriter = - index.writer_with_num_threads(1, 100_000_000).unwrap(); - for doc_json in &lines { - let doc = TantivyDocument::parse_json(&schema, doc_json).unwrap(); - index_writer.add_document(doc).unwrap(); - } - index_writer.commit().unwrap(); - }) - }); - group.bench_function("index-hdfs-no-commit-json-without-docstore", |b| { - let lines = get_lines(HDFS_LOGS); - b.iter(|| { - let index = Index::create_in_ram(dynamic_schema.clone()); - let json_field = dynamic_schema.get_field("json").unwrap(); - let mut index_writer: IndexWriter = - index.writer_with_num_threads(1, 100_000_000).unwrap(); - for doc_json in &lines { - let json_val: serde_json::Map = - serde_json::from_str(doc_json).unwrap(); - let doc = tantivy::doc!(json_field=>json_val); - index_writer.add_document(doc).unwrap(); - } - index_writer.commit().unwrap(); - }) - }); + } + } } pub fn gh_index_benchmark(c: &mut Criterion) { @@ -142,64 
+166,19 @@ pub fn gh_index_benchmark(c: &mut Criterion) { schema_builder.add_json_field("json", FAST); schema_builder.build() }; - let ff_tokenizer_manager = tokenizer::TokenizerManager::default(); - ff_tokenizer_manager.register( - "raw", - tokenizer::TextAnalyzer::builder(tokenizer::RawTokenizer::default()) - .filter(tokenizer::RemoveLongFilter::limit(255)) - .build(), - ); let mut group = c.benchmark_group("index-gh"); group.throughput(Throughput::Bytes(GH_LOGS.len() as u64)); group.bench_function("index-gh-no-commit", |b| { - let lines = get_lines(GH_LOGS); - b.iter(|| { - let json_field = dynamic_schema.get_field("json").unwrap(); - let mut index = Index::create_in_ram(dynamic_schema.clone()); - index.set_fast_field_tokenizers(ff_tokenizer_manager.clone()); - let index_writer: IndexWriter = index.writer_with_num_threads(1, 100_000_000).unwrap(); - for doc_json in &lines { - let json_val: serde_json::Map = - serde_json::from_str(doc_json).unwrap(); - let doc = tantivy::doc!(json_field=>json_val); - index_writer.add_document(doc).unwrap(); - } - }) + benchmark_dynamic_json(b, GH_LOGS, dynamic_schema.clone(), false, false) }); group.bench_function("index-gh-fast", |b| { - let lines = get_lines(GH_LOGS); - b.iter(|| { - let json_field = dynamic_schema_fast.get_field("json").unwrap(); - let mut index = Index::create_in_ram(dynamic_schema_fast.clone()); - index.set_fast_field_tokenizers(ff_tokenizer_manager.clone()); - let index_writer: IndexWriter = index.writer_with_num_threads(1, 100_000_000).unwrap(); - for doc_json in &lines { - let json_val: serde_json::Map = - serde_json::from_str(doc_json).unwrap(); - let doc = tantivy::doc!(json_field=>json_val); - index_writer.add_document(doc).unwrap(); - } - }) + benchmark_dynamic_json(b, GH_LOGS, dynamic_schema_fast.clone(), false, false) }); - group.bench_function("index-gh-with-commit", |b| { - let lines = get_lines(GH_LOGS); - b.iter(|| { - let json_field = dynamic_schema.get_field("json").unwrap(); - let mut index = 
Index::create_in_ram(dynamic_schema.clone()); - index.set_fast_field_tokenizers(ff_tokenizer_manager.clone()); - let mut index_writer: IndexWriter = - index.writer_with_num_threads(1, 100_000_000).unwrap(); - for doc_json in &lines { - let json_val: serde_json::Map = - serde_json::from_str(doc_json).unwrap(); - let doc = tantivy::doc!(json_field=>json_val); - index_writer.add_document(doc).unwrap(); - } - index_writer.commit().unwrap(); - }) + group.bench_function("index-gh-fast-with-commit", |b| { + benchmark_dynamic_json(b, GH_LOGS, dynamic_schema_fast.clone(), true, false) }); } @@ -214,34 +193,10 @@ pub fn wiki_index_benchmark(c: &mut Criterion) { group.throughput(Throughput::Bytes(WIKI.len() as u64)); group.bench_function("index-wiki-no-commit", |b| { - let lines = get_lines(WIKI); - b.iter(|| { - let json_field = dynamic_schema.get_field("json").unwrap(); - let index = Index::create_in_ram(dynamic_schema.clone()); - let index_writer: IndexWriter = index.writer_with_num_threads(1, 100_000_000).unwrap(); - for doc_json in &lines { - let json_val: serde_json::Map = - serde_json::from_str(doc_json).unwrap(); - let doc = tantivy::doc!(json_field=>json_val); - index_writer.add_document(doc).unwrap(); - } - }) + benchmark_dynamic_json(b, WIKI, dynamic_schema.clone(), false, false) }); group.bench_function("index-wiki-with-commit", |b| { - let lines = get_lines(WIKI); - b.iter(|| { - let json_field = dynamic_schema.get_field("json").unwrap(); - let index = Index::create_in_ram(dynamic_schema.clone()); - let mut index_writer: IndexWriter = - index.writer_with_num_threads(1, 100_000_000).unwrap(); - for doc_json in &lines { - let json_val: serde_json::Map = - serde_json::from_str(doc_json).unwrap(); - let doc = tantivy::doc!(json_field=>json_val); - index_writer.add_document(doc).unwrap(); - } - index_writer.commit().unwrap(); - }) + benchmark_dynamic_json(b, WIKI, dynamic_schema.clone(), true, false) }); } diff --git a/bitpacker/Cargo.toml b/bitpacker/Cargo.toml index 
7b1283293..104f5f805 100644 --- a/bitpacker/Cargo.toml +++ b/bitpacker/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tantivy-bitpacker" -version = "0.5.0" +version = "0.6.0" edition = "2021" authors = ["Paul Masurel "] license = "MIT" @@ -15,7 +15,7 @@ homepage = "https://github.com/quickwit-oss/tantivy" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -bitpacking = {version="0.8", default-features=false, features = ["bitpacker1x"]} +bitpacking = { version = "0.9.2", default-features = false, features = ["bitpacker1x"] } [dev-dependencies] rand = "0.8" diff --git a/bitpacker/src/bitpacker.rs b/bitpacker/src/bitpacker.rs index 903daccf8..11ea37566 100644 --- a/bitpacker/src/bitpacker.rs +++ b/bitpacker/src/bitpacker.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::io; use std::ops::{Range, RangeInclusive}; diff --git a/cliff.toml b/cliff.toml index 99bd50662..03424f52b 100644 --- a/cliff.toml +++ b/cliff.toml @@ -1,6 +1,10 @@ # configuration file for git-cliff{ pattern = "foo", replace = "bar"} # see https://github.com/orhun/git-cliff#configuration-file +[remote.github] +owner = "quickwit-oss" +repo = "tantivy" + [changelog] # changelog header header = """ @@ -8,15 +12,43 @@ header = """ # template for the changelog body # https://tera.netlify.app/docs/#introduction body = """ -{% if version %}\ - {{ version | trim_start_matches(pat="v") }} ({{ timestamp | date(format="%Y-%m-%d") }}) - ================== -{% else %}\ - ## [unreleased] -{% endif %}\ +## What's Changed + +{%- if version %} in {{ version }}{%- endif -%} {% for commit in commits %} - - {% if commit.breaking %}[**breaking**] {% endif %}{{ commit.message | split(pat="\n") | first | trim | upper_first }}(@{{ commit.author.name }})\ -{% endfor %} + {% if commit.github.pr_title -%} + {%- set commit_message = commit.github.pr_title -%} + {%- else -%} + {%- set commit_message = commit.message -%} + {%- endif -%} + - {{ commit_message | 
split(pat="\n") | first | trim }}\ + {% if commit.github.pr_number %} \ + [#{{ commit.github.pr_number }}]({{ self::remote_url() }}/pull/{{ commit.github.pr_number }}){% if commit.github.username %}(@{{ commit.github.username }}){%- endif -%} \ + {%- endif %} +{%- endfor -%} + +{% if github.contributors | filter(attribute="is_first_time", value=true) | length != 0 %} + {% raw %}\n{% endraw -%} + ## New Contributors +{%- endif %}\ +{% for contributor in github.contributors | filter(attribute="is_first_time", value=true) %} + * @{{ contributor.username }} made their first contribution + {%- if contributor.pr_number %} in \ + [#{{ contributor.pr_number }}]({{ self::remote_url() }}/pull/{{ contributor.pr_number }}) \ + {%- endif %} +{%- endfor -%} + +{% if version %} + {% if previous.version %} + **Full Changelog**: {{ self::remote_url() }}/compare/{{ previous.version }}...{{ version }} + {% endif %} +{% else -%} + {% raw %}\n{% endraw %} +{% endif %} + +{%- macro remote_url() -%} + https://github.com/{{ remote.github.owner }}/{{ remote.github.repo }} +{%- endmacro -%} """ # remove the leading and trailing whitespace from the template trim = true @@ -25,53 +57,24 @@ footer = """ """ postprocessors = [ - { pattern = 'Paul Masurel', replace = "fulmicoton"}, # replace with github user - { pattern = 'PSeitz', replace = "PSeitz"}, # replace with github user - { pattern = 'Adam Reichold', replace = "adamreichold"}, # replace with github user - { pattern = 'trinity-1686a', replace = "trinity-1686a"}, # replace with github user - { pattern = 'Michael Kleen', replace = "mkleen"}, # replace with github user - { pattern = 'Adrien Guillo', replace = "guilload"}, # replace with github user - { pattern = 'François Massot', replace = "fmassot"}, # replace with github user - { pattern = 'Naveen Aiathurai', replace = "naveenann"}, # replace with github user - { pattern = '', replace = ""}, # replace with github user ] [git] # parse the commits based on 
https://www.conventionalcommits.org # This is required or commit.message contains the whole commit message and not just the title -conventional_commits = true +conventional_commits = false # filter out the commits that are not conventional -filter_unconventional = false +filter_unconventional = true # process each line of a commit as an individual commit split_commits = false # regex for preprocessing the commit messages commit_preprocessors = [ - { pattern = '\((\w+\s)?#([0-9]+)\)', replace = "[#${2}](https://github.com/quickwit-oss/tantivy/issues/${2})"}, # replace issue numbers + { pattern = '\((\w+\s)?#([0-9]+)\)', replace = ""}, ] #link_parsers = [ #{ pattern = "#(\\d+)", href = "https://github.com/quickwit-oss/tantivy/pulls/$1"}, #] # regex for parsing and grouping commits -commit_parsers = [ - { message = "^feat", group = "Features"}, - { message = "^fix", group = "Bug Fixes"}, - { message = "^doc", group = "Documentation"}, - { message = "^perf", group = "Performance"}, - { message = "^refactor", group = "Refactor"}, - { message = "^style", group = "Styling"}, - { message = "^test", group = "Testing"}, - { message = "^chore\\(release\\): prepare for", skip = true}, - { message = "(?i)clippy", skip = true}, - { message = "(?i)dependabot", skip = true}, - { message = "(?i)fmt", skip = true}, - { message = "(?i)bump", skip = true}, - { message = "(?i)readme", skip = true}, - { message = "(?i)comment", skip = true}, - { message = "(?i)spelling", skip = true}, - { message = "^chore", group = "Miscellaneous Tasks"}, - { body = ".*security", group = "Security"}, - { message = ".*", group = "Other", default_scope = "other"}, -] # protect breaking changes from being skipped due to matching a skipping commit_parser protect_breaking_commits = false # filter out the commits that are not matched by commit parsers diff --git a/columnar/Cargo.toml b/columnar/Cargo.toml index c100f185a..36a5a55d5 100644 --- a/columnar/Cargo.toml +++ b/columnar/Cargo.toml @@ -1,6 +1,6 @@ 
[package] name = "tantivy-columnar" -version = "0.2.0" +version = "0.3.0" edition = "2021" license = "MIT" homepage = "https://github.com/quickwit-oss/tantivy" @@ -12,11 +12,12 @@ categories = ["database-implementations", "data-structures", "compression"] itertools = "0.12.0" fastdivide = "0.4.0" -stacker = { version= "0.2", path = "../stacker", package="tantivy-stacker"} -sstable = { version= "0.2", path = "../sstable", package = "tantivy-sstable" } -common = { version= "0.6", path = "../common", package = "tantivy-common" } -tantivy-bitpacker = { version= "0.5", path = "../bitpacker/" } +stacker = { version= "0.3", path = "../stacker", package="tantivy-stacker"} +sstable = { version= "0.3", path = "../sstable", package = "tantivy-sstable" } +common = { version= "0.7", path = "../common", package = "tantivy-common" } +tantivy-bitpacker = { version= "0.6", path = "../bitpacker/" } serde = "1.0.152" +downcast-rs = "1.2.0" [dev-dependencies] proptest = "1" diff --git a/columnar/benches/bench_first_vals.rs b/columnar/benches/bench_first_vals.rs new file mode 100644 index 000000000..b7bc49dc7 --- /dev/null +++ b/columnar/benches/bench_first_vals.rs @@ -0,0 +1,155 @@ +#![feature(test)] +extern crate test; + +use std::sync::Arc; + +use rand::prelude::*; +use tantivy_columnar::column_values::{serialize_and_load_u64_based_column_values, CodecType}; +use tantivy_columnar::*; +use test::{black_box, Bencher}; + +struct Columns { + pub optional: Column, + pub full: Column, + pub multi: Column, +} + +fn get_test_columns() -> Columns { + let data = generate_permutation(); + let mut dataframe_writer = ColumnarWriter::default(); + for (idx, val) in data.iter().enumerate() { + dataframe_writer.record_numerical(idx as u32, "full_values", NumericalValue::U64(*val)); + if idx % 2 == 0 { + dataframe_writer.record_numerical( + idx as u32, + "optional_values", + NumericalValue::U64(*val), + ); + } + dataframe_writer.record_numerical(idx as u32, "multi_values", NumericalValue::U64(*val)); 
+ dataframe_writer.record_numerical(idx as u32, "multi_values", NumericalValue::U64(*val)); + } + let mut buffer: Vec = Vec::new(); + dataframe_writer + .serialize(data.len() as u32, None, &mut buffer) + .unwrap(); + let columnar = ColumnarReader::open(buffer).unwrap(); + + let cols: Vec = columnar.read_columns("optional_values").unwrap(); + assert_eq!(cols.len(), 1); + let optional = cols[0].open_u64_lenient().unwrap().unwrap(); + assert_eq!(optional.index.get_cardinality(), Cardinality::Optional); + + let cols: Vec = columnar.read_columns("full_values").unwrap(); + assert_eq!(cols.len(), 1); + let column_full = cols[0].open_u64_lenient().unwrap().unwrap(); + assert_eq!(column_full.index.get_cardinality(), Cardinality::Full); + + let cols: Vec = columnar.read_columns("multi_values").unwrap(); + assert_eq!(cols.len(), 1); + let multi = cols[0].open_u64_lenient().unwrap().unwrap(); + assert_eq!(multi.index.get_cardinality(), Cardinality::Multivalued); + + Columns { + optional, + full: column_full, + multi, + } +} + +const NUM_VALUES: u64 = 100_000; +fn generate_permutation() -> Vec { + let mut permutation: Vec = (0u64..NUM_VALUES).collect(); + permutation.shuffle(&mut StdRng::from_seed([1u8; 32])); + permutation +} + +pub fn serialize_and_load(column: &[u64], codec_type: CodecType) -> Arc> { + serialize_and_load_u64_based_column_values(&column, &[codec_type]) +} + +fn run_bench_on_column_full_scan(b: &mut Bencher, column: Column) { + let num_iter = black_box(NUM_VALUES); + b.iter(|| { + let mut sum = 0u64; + for i in 0..num_iter as u32 { + let val = column.first(i); + sum += val.unwrap_or(0); + } + sum + }); +} +fn run_bench_on_column_block_fetch(b: &mut Bencher, column: Column) { + let mut block: Vec> = vec![None; 64]; + let fetch_docids = (0..64).collect::>(); + b.iter(move || { + column.first_vals(&fetch_docids, &mut block); + block[0] + }); +} +fn run_bench_on_column_block_single_calls(b: &mut Bencher, column: Column) { + let mut block: Vec> = vec![None; 64]; + 
let fetch_docids = (0..64).collect::>(); + b.iter(move || { + for i in 0..fetch_docids.len() { + block[i] = column.first(fetch_docids[i]); + } + block[0] + }); +} + +/// Column first method +#[bench] +fn bench_get_first_on_full_column_full_scan(b: &mut Bencher) { + let column = get_test_columns().full; + run_bench_on_column_full_scan(b, column); +} + +#[bench] +fn bench_get_first_on_optional_column_full_scan(b: &mut Bencher) { + let column = get_test_columns().optional; + run_bench_on_column_full_scan(b, column); +} + +#[bench] +fn bench_get_first_on_multi_column_full_scan(b: &mut Bencher) { + let column = get_test_columns().multi; + run_bench_on_column_full_scan(b, column); +} + +/// Block fetch column accessor +#[bench] +fn bench_get_block_first_on_optional_column(b: &mut Bencher) { + let column = get_test_columns().optional; + run_bench_on_column_block_fetch(b, column); +} + +#[bench] +fn bench_get_block_first_on_multi_column(b: &mut Bencher) { + let column = get_test_columns().multi; + run_bench_on_column_block_fetch(b, column); +} + +#[bench] +fn bench_get_block_first_on_full_column(b: &mut Bencher) { + let column = get_test_columns().full; + run_bench_on_column_block_fetch(b, column); +} + +#[bench] +fn bench_get_block_first_on_optional_column_single_calls(b: &mut Bencher) { + let column = get_test_columns().optional; + run_bench_on_column_block_single_calls(b, column); +} + +#[bench] +fn bench_get_block_first_on_multi_column_single_calls(b: &mut Bencher) { + let column = get_test_columns().multi; + run_bench_on_column_block_single_calls(b, column); +} + +#[bench] +fn bench_get_block_first_on_full_column_single_calls(b: &mut Bencher) { + let column = get_test_columns().full; + run_bench_on_column_block_single_calls(b, column); +} diff --git a/columnar/benches/bench_u128.rs b/columnar/benches/bench_values_u128.rs similarity index 100% rename from columnar/benches/bench_u128.rs rename to columnar/benches/bench_values_u128.rs diff --git 
a/columnar/benches/bench_u64.rs b/columnar/benches/bench_values_u64.rs similarity index 96% rename from columnar/benches/bench_u64.rs rename to columnar/benches/bench_values_u64.rs index 556cb8f02..313a85754 100644 --- a/columnar/benches/bench_u64.rs +++ b/columnar/benches/bench_values_u64.rs @@ -16,14 +16,6 @@ fn generate_permutation() -> Vec { permutation } -fn generate_random() -> Vec { - let mut permutation: Vec = (0u64..100_000u64) - .map(|el| el + random::() as u64) - .collect(); - permutation.shuffle(&mut StdRng::from_seed([1u8; 32])); - permutation -} - // Warning: this generates the same permutation at each call fn generate_permutation_gcd() -> Vec { let mut permutation: Vec = (1u64..100_000u64).map(|el| el * 1000).collect(); diff --git a/columnar/src/block_accessor.rs b/columnar/src/block_accessor.rs index 378f36104..d746c598a 100644 --- a/columnar/src/block_accessor.rs +++ b/columnar/src/block_accessor.rs @@ -14,20 +14,32 @@ impl ColumnBlockAccessor { #[inline] - pub fn fetch_block(&mut self, docs: &[u32], accessor: &Column) { - self.docid_cache.clear(); - self.row_id_cache.clear(); - accessor.row_ids_for_docs(docs, &mut self.docid_cache, &mut self.row_id_cache); - self.val_cache.resize(self.row_id_cache.len(), T::default()); - accessor - .values - .get_vals(&self.row_id_cache, &mut self.val_cache); + pub fn fetch_block<'a>(&'a mut self, docs: &'a [u32], accessor: &Column) { + if accessor.index.get_cardinality().is_full() { + self.val_cache.resize(docs.len(), T::default()); + accessor.values.get_vals(docs, &mut self.val_cache); + } else { + self.docid_cache.clear(); + self.row_id_cache.clear(); + accessor.row_ids_for_docs(docs, &mut self.docid_cache, &mut self.row_id_cache); + self.val_cache.resize(self.row_id_cache.len(), T::default()); + accessor + .values + .get_vals(&self.row_id_cache, &mut self.val_cache); + } } #[inline] pub fn fetch_block_with_missing(&mut self, docs: &[u32], accessor: &Column, missing: T) { self.fetch_block(docs, accessor); - // 
We can compare docid_cache with docs to find missing docs - if docs.len() != self.docid_cache.len() || accessor.index.is_multivalue() { + // no missing values + if accessor.index.get_cardinality().is_full() { + return; + } + + // We can compare docid_cache length with docs to find missing docs + // For multi value columns we can't rely on the length and always need to scan + if accessor.index.get_cardinality().is_multivalue() || docs.len() != self.docid_cache.len() + { self.missing_docids_cache.clear(); find_missing_docs(docs, &self.docid_cache, |doc| { self.missing_docids_cache.push(doc); @@ -44,11 +56,25 @@ impl } #[inline] - pub fn iter_docid_vals(&self) -> impl Iterator + '_ { - self.docid_cache - .iter() - .cloned() - .zip(self.val_cache.iter().cloned()) + /// Returns an iterator over the docids and values + /// The passed in `docs` slice needs to be the same slice that was passed to `fetch_block` or + /// `fetch_block_with_missing`. + /// + /// The docs is used if the column is full (each docs has exactly one value), otherwise the + /// internal docid vec is used for the iterator, which e.g. may contain duplicate docs. 
+ pub fn iter_docid_vals<'a>( + &'a self, + docs: &'a [u32], + accessor: &Column, + ) -> impl Iterator + '_ { + if accessor.index.get_cardinality().is_full() { + docs.iter().cloned().zip(self.val_cache.iter().cloned()) + } else { + self.docid_cache + .iter() + .cloned() + .zip(self.val_cache.iter().cloned()) + } } } diff --git a/columnar/src/column/mod.rs b/columnar/src/column/mod.rs index 37db03e1b..dd6dc0f21 100644 --- a/columnar/src/column/mod.rs +++ b/columnar/src/column/mod.rs @@ -3,17 +3,17 @@ mod serialize; use std::fmt::{self, Debug}; use std::io::Write; -use std::ops::{Deref, Range, RangeInclusive}; +use std::ops::{Range, RangeInclusive}; use std::sync::Arc; use common::BinarySerializable; pub use dictionary_encoded::{BytesColumn, StrColumn}; pub use serialize::{ - open_column_bytes, open_column_str, open_column_u128, open_column_u64, - serialize_column_mappable_to_u128, serialize_column_mappable_to_u64, + open_column_bytes, open_column_str, open_column_u128, open_column_u128_as_compact_u64, + open_column_u64, serialize_column_mappable_to_u128, serialize_column_mappable_to_u64, }; -use crate::column_index::ColumnIndex; +use crate::column_index::{ColumnIndex, Set}; use crate::column_values::monotonic_mapping::StrictlyMonotonicMappingToInternal; use crate::column_values::{monotonic_map_column, ColumnValues}; use crate::{Cardinality, DocId, EmptyColumnValues, MonotonicallyMappableToU64, RowId}; @@ -83,10 +83,36 @@ impl Column { self.values.max_value() } + #[inline] pub fn first(&self, row_id: RowId) -> Option { self.values_for_doc(row_id).next() } + /// Load the first value for each docid in the provided slice. + #[inline] + pub fn first_vals(&self, docids: &[DocId], output: &mut [Option]) { + match &self.index { + ColumnIndex::Empty { .. 
} => {} + ColumnIndex::Full => self.values.get_vals_opt(docids, output), + ColumnIndex::Optional(optional_index) => { + for (i, docid) in docids.iter().enumerate() { + output[i] = optional_index + .rank_if_exists(*docid) + .map(|rowid| self.values.get_val(rowid)); + } + } + ColumnIndex::Multivalued(multivalued_index) => { + for (i, docid) in docids.iter().enumerate() { + let range = multivalued_index.range(*docid); + let is_empty = range.start == range.end; + if !is_empty { + output[i] = Some(self.values.get_val(range.start)); + } + } + } + } + } + /// Translates a block of docis to row_ids. /// /// returns the row_ids and the matching docids on the same index @@ -105,7 +131,8 @@ impl Column { } pub fn values_for_doc(&self, doc_id: DocId) -> impl Iterator + '_ { - self.value_row_ids(doc_id) + self.index + .value_row_ids(doc_id) .map(|value_row_id: RowId| self.values.get_val(value_row_id)) } @@ -147,14 +174,6 @@ impl Column { } } -impl Deref for Column { - type Target = ColumnIndex; - - fn deref(&self) -> &Self::Target { - &self.index - } -} - impl BinarySerializable for Cardinality { fn serialize(&self, writer: &mut W) -> std::io::Result<()> { self.to_code().serialize(writer) @@ -176,6 +195,7 @@ struct FirstValueWithDefault { impl ColumnValues for FirstValueWithDefault { + #[inline(always)] fn get_val(&self, idx: u32) -> T { self.column.first(idx).unwrap_or(self.default_value) } diff --git a/columnar/src/column/serialize.rs b/columnar/src/column/serialize.rs index 5b6b0efc5..4198487bb 100644 --- a/columnar/src/column/serialize.rs +++ b/columnar/src/column/serialize.rs @@ -76,6 +76,26 @@ pub fn open_column_u128( }) } +/// Open the column as u64. +/// +/// See [`open_u128_as_compact_u64`] for more details. 
+pub fn open_column_u128_as_compact_u64(bytes: OwnedBytes) -> io::Result> { + let (body, column_index_num_bytes_payload) = bytes.rsplit(4); + let column_index_num_bytes = u32::from_le_bytes( + column_index_num_bytes_payload + .as_slice() + .try_into() + .unwrap(), + ); + let (column_index_data, column_values_data) = body.split(column_index_num_bytes as usize); + let column_index = crate::column_index::open_column_index(column_index_data)?; + let column_values = crate::column_values::open_u128_as_compact_u64(column_values_data)?; + Ok(Column { + index: column_index, + values: column_values, + }) +} + pub fn open_column_bytes(data: OwnedBytes) -> io::Result { let (body, dictionary_len_bytes) = data.rsplit(4); let dictionary_len = u32::from_le_bytes(dictionary_len_bytes.as_slice().try_into().unwrap()); diff --git a/columnar/src/column_index/merge/shuffled.rs b/columnar/src/column_index/merge/shuffled.rs index 6acf199ff..f93b89635 100644 --- a/columnar/src/column_index/merge/shuffled.rs +++ b/columnar/src/column_index/merge/shuffled.rs @@ -140,7 +140,7 @@ mod tests { #[test] fn test_merge_column_index_optional_shuffle() { let optional_index: ColumnIndex = OptionalIndex::for_test(2, &[0]).into(); - let column_indexes = vec![optional_index, ColumnIndex::Full]; + let column_indexes = [optional_index, ColumnIndex::Full]; let row_addrs = vec![ RowAddr { segment_ord: 0u32, diff --git a/columnar/src/column_index/merge/stacked.rs b/columnar/src/column_index/merge/stacked.rs index 9ef294b60..ba91b8d64 100644 --- a/columnar/src/column_index/merge/stacked.rs +++ b/columnar/src/column_index/merge/stacked.rs @@ -111,10 +111,7 @@ fn stack_multivalued_indexes<'a>( let mut last_row_id = 0; let mut current_it = multivalued_indexes.next(); Box::new(std::iter::from_fn(move || loop { - let Some(multivalued_index) = current_it.as_mut() else { - return None; - }; - if let Some(row_id) = multivalued_index.next() { + if let Some(row_id) = current_it.as_mut()?.next() { last_row_id = offset + 
row_id; return Some(last_row_id); } diff --git a/columnar/src/column_index/mod.rs b/columnar/src/column_index/mod.rs index d6711566d..f52e26ff4 100644 --- a/columnar/src/column_index/mod.rs +++ b/columnar/src/column_index/mod.rs @@ -42,10 +42,6 @@ impl From for ColumnIndex { } impl ColumnIndex { - #[inline] - pub fn is_multivalue(&self) -> bool { - matches!(self, ColumnIndex::Multivalued(_)) - } /// Returns the cardinality of the column index. /// /// By convention, if the column contains no docs, we consider that it is @@ -126,18 +122,18 @@ impl ColumnIndex { } } - pub fn docid_range_to_rowids(&self, doc_id: Range) -> Range { + pub fn docid_range_to_rowids(&self, doc_id_range: Range) -> Range { match self { ColumnIndex::Empty { .. } => 0..0, - ColumnIndex::Full => doc_id, + ColumnIndex::Full => doc_id_range, ColumnIndex::Optional(optional_index) => { - let row_start = optional_index.rank(doc_id.start); - let row_end = optional_index.rank(doc_id.end); + let row_start = optional_index.rank(doc_id_range.start); + let row_end = optional_index.rank(doc_id_range.end); row_start..row_end } ColumnIndex::Multivalued(multivalued_index) => { - let end_docid = doc_id.end.min(multivalued_index.num_docs() - 1) + 1; - let start_docid = doc_id.start.min(end_docid); + let end_docid = doc_id_range.end.min(multivalued_index.num_docs() - 1) + 1; + let start_docid = doc_id_range.start.min(end_docid); let row_start = multivalued_index.start_index_column.get_val(start_docid); let row_end = multivalued_index.start_index_column.get_val(end_docid); diff --git a/columnar/src/column_index/optional_index/mod.rs b/columnar/src/column_index/optional_index/mod.rs index e885ee5bc..d48415284 100644 --- a/columnar/src/column_index/optional_index/mod.rs +++ b/columnar/src/column_index/optional_index/mod.rs @@ -21,8 +21,6 @@ const DENSE_BLOCK_THRESHOLD: u32 = const ELEMENTS_PER_BLOCK: u32 = u16::MAX as u32 + 1; -const BLOCK_SIZE: RowId = 1 << 16; - #[derive(Copy, Clone, Debug)] struct BlockMeta { 
non_null_rows_before_block: u32, @@ -109,8 +107,8 @@ struct RowAddr { #[inline(always)] fn row_addr_from_row_id(row_id: RowId) -> RowAddr { RowAddr { - block_id: (row_id / BLOCK_SIZE) as u16, - in_block_row_id: (row_id % BLOCK_SIZE) as u16, + block_id: (row_id / ELEMENTS_PER_BLOCK) as u16, + in_block_row_id: (row_id % ELEMENTS_PER_BLOCK) as u16, } } @@ -185,8 +183,13 @@ impl Set for OptionalIndex { } } + /// Any value doc_id is allowed. + /// In particular, doc_id = num_rows. #[inline] fn rank(&self, doc_id: DocId) -> RowId { + if doc_id >= self.num_docs() { + return self.num_non_nulls(); + } let RowAddr { block_id, in_block_row_id, @@ -200,13 +203,15 @@ impl Set for OptionalIndex { block_meta.non_null_rows_before_block + block_offset_row_id } + /// Any value doc_id is allowed. + /// In particular, doc_id = num_rows. #[inline] fn rank_if_exists(&self, doc_id: DocId) -> Option { let RowAddr { block_id, in_block_row_id, } = row_addr_from_row_id(doc_id); - let block_meta = self.block_metas[block_id as usize]; + let block_meta = *self.block_metas.get(block_id as usize)?; let block = self.block(block_meta); let block_offset_row_id = match block { Block::Dense(dense_block) => dense_block.rank_if_exists(in_block_row_id), @@ -491,7 +496,7 @@ fn deserialize_optional_index_block_metadatas( non_null_rows_before_block += num_non_null_rows; } block_metas.resize( - ((num_rows + BLOCK_SIZE - 1) / BLOCK_SIZE) as usize, + ((num_rows + ELEMENTS_PER_BLOCK - 1) / ELEMENTS_PER_BLOCK) as usize, BlockMeta { non_null_rows_before_block, start_byte_offset, diff --git a/columnar/src/column_index/optional_index/set.rs b/columnar/src/column_index/optional_index/set.rs index 15b527dbd..b2bf9cbe2 100644 --- a/columnar/src/column_index/optional_index/set.rs +++ b/columnar/src/column_index/optional_index/set.rs @@ -39,7 +39,8 @@ pub trait Set { /// /// # Panics /// - /// May panic if rank is greater than the number of elements in the Set. 
+ /// May panic if rank is greater or equal to the number of + /// elements in the Set. fn select(&self, rank: T) -> T; /// Creates a brand new select cursor. diff --git a/columnar/src/column_index/optional_index/set_block/dense.rs b/columnar/src/column_index/optional_index/set_block/dense.rs index 8d041e441..08ca31b19 100644 --- a/columnar/src/column_index/optional_index/set_block/dense.rs +++ b/columnar/src/column_index/optional_index/set_block/dense.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::io::{self, Write}; use common::BinarySerializable; diff --git a/columnar/src/column_index/optional_index/tests.rs b/columnar/src/column_index/optional_index/tests.rs index 2acc5f6e6..d25f267c2 100644 --- a/columnar/src/column_index/optional_index/tests.rs +++ b/columnar/src/column_index/optional_index/tests.rs @@ -1,8 +1,31 @@ -use proptest::prelude::{any, prop, *}; -use proptest::strategy::Strategy; +use proptest::prelude::*; use proptest::{prop_oneof, proptest}; use super::*; +use crate::{ColumnarReader, ColumnarWriter, DynamicColumnHandle}; + +#[test] +fn test_optional_index_bug_2293() { + // tests for panic in docid_range_to_rowids for docid == num_docs + test_optional_index_with_num_docs(ELEMENTS_PER_BLOCK - 1); + test_optional_index_with_num_docs(ELEMENTS_PER_BLOCK); + test_optional_index_with_num_docs(ELEMENTS_PER_BLOCK + 1); +} +fn test_optional_index_with_num_docs(num_docs: u32) { + let mut dataframe_writer = ColumnarWriter::default(); + dataframe_writer.record_numerical(100, "score", 80i64); + let mut buffer: Vec = Vec::new(); + dataframe_writer + .serialize(num_docs, None, &mut buffer) + .unwrap(); + let columnar = ColumnarReader::open(buffer).unwrap(); + assert_eq!(columnar.num_columns(), 1); + let cols: Vec = columnar.read_columns("score").unwrap(); + assert_eq!(cols.len(), 1); + + let col = cols[0].open().unwrap(); + col.column_index().docid_range_to_rowids(0..num_docs); +} #[test] fn test_dense_block_threshold() { @@ -35,7 +58,7 @@ proptest! 
{ #[test] fn test_with_random_sets_simple() { - let vals = 10..BLOCK_SIZE * 2; + let vals = 10..ELEMENTS_PER_BLOCK * 2; let mut out: Vec = Vec::new(); serialize_optional_index(&vals, 100, &mut out).unwrap(); let null_index = open_optional_index(OwnedBytes::new(out)).unwrap(); @@ -171,7 +194,7 @@ fn test_optional_index_rank() { test_optional_index_rank_aux(&[0u32, 1u32]); let mut block = Vec::new(); block.push(3u32); - block.extend((0..BLOCK_SIZE).map(|i| i + BLOCK_SIZE + 1)); + block.extend((0..ELEMENTS_PER_BLOCK).map(|i| i + ELEMENTS_PER_BLOCK + 1)); test_optional_index_rank_aux(&block); } @@ -185,8 +208,8 @@ fn test_optional_index_iter_empty_one() { fn test_optional_index_iter_dense_block() { let mut block = Vec::new(); block.push(3u32); - block.extend((0..BLOCK_SIZE).map(|i| i + BLOCK_SIZE + 1)); - test_optional_index_iter_aux(&block, 3 * BLOCK_SIZE); + block.extend((0..ELEMENTS_PER_BLOCK).map(|i| i + ELEMENTS_PER_BLOCK + 1)); + test_optional_index_iter_aux(&block, 3 * ELEMENTS_PER_BLOCK); } #[test] diff --git a/columnar/src/column_values/merge.rs b/columnar/src/column_values/merge.rs index ff3d657f4..a3b2df18a 100644 --- a/columnar/src/column_values/merge.rs +++ b/columnar/src/column_values/merge.rs @@ -10,7 +10,7 @@ pub(crate) struct MergedColumnValues<'a, T> { pub(crate) merge_row_order: &'a MergeRowOrder, } -impl<'a, T: Copy + PartialOrd + Debug> Iterable for MergedColumnValues<'a, T> { +impl<'a, T: Copy + PartialOrd + Debug + 'static> Iterable for MergedColumnValues<'a, T> { fn boxed_iter(&self) -> Box + '_> { match self.merge_row_order { MergeRowOrder::Stack(_) => Box::new( diff --git a/columnar/src/column_values/mod.rs b/columnar/src/column_values/mod.rs index f2e1b036a..ef5de5154 100644 --- a/columnar/src/column_values/mod.rs +++ b/columnar/src/column_values/mod.rs @@ -10,6 +10,7 @@ use std::fmt::Debug; use std::ops::{Range, RangeInclusive}; use std::sync::Arc; +use downcast_rs::DowncastSync; pub use monotonic_mapping::{MonotonicallyMappableToU64, 
StrictlyMonotonicFn}; pub use monotonic_mapping_u128::MonotonicallyMappableToU128; @@ -25,7 +26,10 @@ mod monotonic_column; pub(crate) use merge::MergedColumnValues; pub use stats::ColumnStats; -pub use u128_based::{open_u128_mapped, serialize_column_values_u128}; +pub use u128_based::{ + open_u128_as_compact_u64, open_u128_mapped, serialize_column_values_u128, + CompactSpaceU64Accessor, +}; pub use u64_based::{ load_u64_based_column_values, serialize_and_load_u64_based_column_values, serialize_u64_based_column_values, CodecType, ALL_U64_CODEC_TYPES, @@ -41,7 +45,7 @@ use crate::RowId; /// /// Any methods with a default and specialized implementation need to be called in the /// wrappers that implement the trait: Arc and MonotonicMappingColumn -pub trait ColumnValues: Send + Sync { +pub trait ColumnValues: Send + Sync + DowncastSync { /// Return the value associated with the given idx. /// /// This accessor should return as fast as possible. @@ -68,11 +72,40 @@ pub trait ColumnValues: Send + Sync { out_x4[3] = self.get_val(idx_x4[3]); } - let step_size = 4; - let cutoff = indexes.len() - indexes.len() % step_size; + let out_and_idx_chunks = output + .chunks_exact_mut(4) + .into_remainder() + .iter_mut() + .zip(indexes.chunks_exact(4).remainder()); + for (out, idx) in out_and_idx_chunks { + *out = self.get_val(*idx); + } + } - for idx in cutoff..indexes.len() { - output[idx] = self.get_val(indexes[idx]); + /// Allows to push down multiple fetch calls, to avoid dynamic dispatch overhead. + /// The slightly weird `Option` in output allows pushdown to full columns. + /// + /// idx and output should have the same length + /// + /// # Panics + /// + /// May panic if `idx` is greater than the column length. 
+ fn get_vals_opt(&self, indexes: &[u32], output: &mut [Option]) { + assert!(indexes.len() == output.len()); + let out_and_idx_chunks = output.chunks_exact_mut(4).zip(indexes.chunks_exact(4)); + for (out_x4, idx_x4) in out_and_idx_chunks { + out_x4[0] = Some(self.get_val(idx_x4[0])); + out_x4[1] = Some(self.get_val(idx_x4[1])); + out_x4[2] = Some(self.get_val(idx_x4[2])); + out_x4[3] = Some(self.get_val(idx_x4[3])); + } + let out_and_idx_chunks = output + .chunks_exact_mut(4) + .into_remainder() + .iter_mut() + .zip(indexes.chunks_exact(4).remainder()); + for (out, idx) in out_and_idx_chunks { + *out = Some(self.get_val(*idx)); } } @@ -101,7 +134,7 @@ pub trait ColumnValues: Send + Sync { row_id_hits: &mut Vec, ) { let row_id_range = row_id_range.start..row_id_range.end.min(self.num_vals()); - for idx in row_id_range.start..row_id_range.end { + for idx in row_id_range { let val = self.get_val(idx); if value_range.contains(&val) { row_id_hits.push(idx); @@ -139,6 +172,7 @@ pub trait ColumnValues: Send + Sync { Box::new((0..self.num_vals()).map(|idx| self.get_val(idx))) } } +downcast_rs::impl_downcast!(sync ColumnValues where T: PartialOrd); /// Empty column of values. 
pub struct EmptyColumnValues; @@ -161,12 +195,17 @@ impl ColumnValues for EmptyColumnValues { } } -impl ColumnValues for Arc> { +impl ColumnValues for Arc> { #[inline(always)] fn get_val(&self, idx: u32) -> T { self.as_ref().get_val(idx) } + #[inline(always)] + fn get_vals_opt(&self, indexes: &[u32], output: &mut [Option]) { + self.as_ref().get_vals_opt(indexes, output) + } + #[inline(always)] fn min_value(&self) -> T { self.as_ref().min_value() diff --git a/columnar/src/column_values/monotonic_column.rs b/columnar/src/column_values/monotonic_column.rs index de48d7a0f..506650be3 100644 --- a/columnar/src/column_values/monotonic_column.rs +++ b/columnar/src/column_values/monotonic_column.rs @@ -31,10 +31,10 @@ pub fn monotonic_map_column( monotonic_mapping: T, ) -> impl ColumnValues where - C: ColumnValues, - T: StrictlyMonotonicFn + Send + Sync, - Input: PartialOrd + Debug + Send + Sync + Clone, - Output: PartialOrd + Debug + Send + Sync + Clone, + C: ColumnValues + 'static, + T: StrictlyMonotonicFn + Send + Sync + 'static, + Input: PartialOrd + Debug + Send + Sync + Clone + 'static, + Output: PartialOrd + Debug + Send + Sync + Clone + 'static, { MonotonicMappingColumn { from_column, @@ -45,10 +45,10 @@ where impl ColumnValues for MonotonicMappingColumn where - C: ColumnValues, - T: StrictlyMonotonicFn + Send + Sync, - Input: PartialOrd + Send + Debug + Sync + Clone, - Output: PartialOrd + Send + Debug + Sync + Clone, + C: ColumnValues + 'static, + T: StrictlyMonotonicFn + Send + Sync + 'static, + Input: PartialOrd + Send + Debug + Sync + Clone + 'static, + Output: PartialOrd + Send + Debug + Sync + Clone + 'static, { #[inline(always)] fn get_val(&self, idx: u32) -> Output { @@ -107,7 +107,7 @@ mod tests { #[test] fn test_monotonic_mapping_iter() { let vals: Vec = (0..100u64).map(|el| el * 10).collect(); - let col = VecColumn::from(&vals); + let col = VecColumn::from(vals); let mapped = monotonic_map_column( col, 
StrictlyMonotonicMappingInverter::from(StrictlyMonotonicMappingToInternal::::new()), diff --git a/columnar/src/column_values/u128_based/compact_space/mod.rs b/columnar/src/column_values/u128_based/compact_space/mod.rs index 3b1069657..f246c7b0c 100644 --- a/columnar/src/column_values/u128_based/compact_space/mod.rs +++ b/columnar/src/column_values/u128_based/compact_space/mod.rs @@ -22,7 +22,7 @@ mod build_compact_space; use build_compact_space::get_compact_space; use common::{BinarySerializable, CountingWriter, OwnedBytes, VInt, VIntU128}; -use tantivy_bitpacker::{self, BitPacker, BitUnpacker}; +use tantivy_bitpacker::{BitPacker, BitUnpacker}; use crate::column_values::ColumnValues; use crate::RowId; @@ -148,7 +148,7 @@ impl CompactSpace { .binary_search_by_key(&compact, |range_mapping| range_mapping.compact_start) // Correctness: Overflow. The first range starts at compact space 0, the error from // binary search can never be 0 - .map_or_else(|e| e - 1, |v| v); + .unwrap_or_else(|e| e - 1); let range_mapping = &self.ranges_mapping[pos]; let diff = compact - range_mapping.compact_start; @@ -292,6 +292,63 @@ impl BinarySerializable for IPCodecParams { } } +/// Exposes the compact space compressed values as u64. +/// +/// This allows faster access to the values, as u64 is faster to work with than u128. +/// It also allows to handle u128 values like u64, via the `open_u64_lenient` as a uniform +/// access interface. +/// +/// When converting from the internal u64 to u128 `compact_to_u128` can be used. 
+pub struct CompactSpaceU64Accessor(CompactSpaceDecompressor); +impl CompactSpaceU64Accessor { + pub(crate) fn open(data: OwnedBytes) -> io::Result { + let decompressor = CompactSpaceU64Accessor(CompactSpaceDecompressor::open(data)?); + Ok(decompressor) + } + /// Convert a compact space value to u128 + pub fn compact_to_u128(&self, compact: u32) -> u128 { + self.0.compact_to_u128(compact) + } +} + +impl ColumnValues for CompactSpaceU64Accessor { + #[inline] + fn get_val(&self, doc: u32) -> u64 { + let compact = self.0.get_compact(doc); + compact as u64 + } + + fn min_value(&self) -> u64 { + self.0.u128_to_compact(self.0.min_value()).unwrap() as u64 + } + + fn max_value(&self) -> u64 { + self.0.u128_to_compact(self.0.max_value()).unwrap() as u64 + } + + fn num_vals(&self) -> u32 { + self.0.params.num_vals + } + + #[inline] + fn iter(&self) -> Box + '_> { + Box::new(self.0.iter_compact().map(|el| el as u64)) + } + + #[inline] + fn get_row_ids_for_value_range( + &self, + value_range: RangeInclusive, + position_range: Range, + positions: &mut Vec, + ) { + let value_range = self.0.compact_to_u128(*value_range.start() as u32) + ..=self.0.compact_to_u128(*value_range.end() as u32); + self.0 + .get_row_ids_for_value_range(value_range, position_range, positions) + } +} + impl ColumnValues for CompactSpaceDecompressor { #[inline] fn get_val(&self, doc: u32) -> u128 { @@ -402,9 +459,14 @@ impl CompactSpaceDecompressor { .map(|compact| self.compact_to_u128(compact)) } + #[inline] + pub fn get_compact(&self, idx: u32) -> u32 { + self.params.bit_unpacker.get(idx, &self.data) as u32 + } + #[inline] pub fn get(&self, idx: u32) -> u128 { - let compact = self.params.bit_unpacker.get(idx, &self.data) as u32; + let compact = self.get_compact(idx); self.compact_to_u128(compact) } diff --git a/columnar/src/column_values/u128_based/mod.rs b/columnar/src/column_values/u128_based/mod.rs index 0cae841c5..0e827b460 100644 --- a/columnar/src/column_values/u128_based/mod.rs +++ 
b/columnar/src/column_values/u128_based/mod.rs @@ -6,7 +6,9 @@ use std::sync::Arc; mod compact_space; use common::{BinarySerializable, OwnedBytes, VInt}; -use compact_space::{CompactSpaceCompressor, CompactSpaceDecompressor}; +pub use compact_space::{ + CompactSpaceCompressor, CompactSpaceDecompressor, CompactSpaceU64Accessor, +}; use crate::column_values::monotonic_map_column; use crate::column_values::monotonic_mapping::{ @@ -108,6 +110,23 @@ pub fn open_u128_mapped( StrictlyMonotonicMappingToInternal::::new().into(); Ok(Arc::new(monotonic_map_column(reader, inverted))) } + +/// Returns the u64 representation of the u128 data. +/// The internal representation of the data as u64 is useful for faster processing. +/// +/// In order to convert to u128 back cast to `CompactSpaceU64Accessor` and call +/// `compact_to_u128`. +/// +/// # Notice +/// In case there are new codecs added, check for usages of `CompactSpaceDecompressorU64` and +/// also handle the new codecs. +pub fn open_u128_as_compact_u64(mut bytes: OwnedBytes) -> io::Result>> { + let header = U128Header::deserialize(&mut bytes)?; + assert_eq!(header.codec_type, U128FastFieldCodecType::CompactSpace); + let reader = CompactSpaceU64Accessor::open(bytes)?; + Ok(Arc::new(reader)) +} + #[cfg(test)] pub mod tests { use super::*; diff --git a/columnar/src/column_values/u64_based/bitpacked.rs b/columnar/src/column_values/u64_based/bitpacked.rs index d9800d9f8..3ed999648 100644 --- a/columnar/src/column_values/u64_based/bitpacked.rs +++ b/columnar/src/column_values/u64_based/bitpacked.rs @@ -63,7 +63,6 @@ impl ColumnValues for BitpackedReader { fn get_val(&self, doc: u32) -> u64 { self.stats.min_value + self.stats.gcd.get() * self.bit_unpacker.get(doc, &self.data) } - #[inline] fn min_value(&self) -> u64 { self.stats.min_value diff --git a/columnar/src/column_values/u64_based/blockwise_linear.rs b/columnar/src/column_values/u64_based/blockwise_linear.rs index 9e8e0cc29..2abf8205b 100644 --- 
a/columnar/src/column_values/u64_based/blockwise_linear.rs +++ b/columnar/src/column_values/u64_based/blockwise_linear.rs @@ -63,7 +63,10 @@ impl BlockwiseLinearEstimator { if self.block.is_empty() { return; } - let line = Line::train(&VecColumn::from(&self.block)); + let column = VecColumn::from(std::mem::take(&mut self.block)); + let line = Line::train(&column); + self.block = column.into(); + let mut max_value = 0u64; for (i, buffer_val) in self.block.iter().enumerate() { let interpolated_val = line.eval(i as u32); @@ -125,7 +128,7 @@ impl ColumnCodecEstimator for BlockwiseLinearEstimator { *buffer_val = gcd_divider.divide(*buffer_val - stats.min_value); } - let line = Line::train(&VecColumn::from(&buffer)); + let line = Line::train(&VecColumn::from(buffer.to_vec())); assert!(!buffer.is_empty()); diff --git a/columnar/src/column_values/u64_based/line.rs b/columnar/src/column_values/u64_based/line.rs index e84b5b228..f3d5504fd 100644 --- a/columnar/src/column_values/u64_based/line.rs +++ b/columnar/src/column_values/u64_based/line.rs @@ -184,7 +184,7 @@ mod tests { } fn test_eval_max_err(ys: &[u64]) -> Option { - let line = Line::train(&VecColumn::from(&ys)); + let line = Line::train(&VecColumn::from(ys.to_vec())); ys.iter() .enumerate() .map(|(x, y)| y.wrapping_sub(line.eval(x as u32))) diff --git a/columnar/src/column_values/u64_based/linear.rs b/columnar/src/column_values/u64_based/linear.rs index b5b49679c..ba0c9e641 100644 --- a/columnar/src/column_values/u64_based/linear.rs +++ b/columnar/src/column_values/u64_based/linear.rs @@ -173,7 +173,9 @@ impl LinearCodecEstimator { fn collect_before_line_estimation(&mut self, value: u64) { self.block.push(value); if self.block.len() == LINE_ESTIMATION_BLOCK_LEN { - let line = Line::train(&VecColumn::from(&self.block)); + let column = VecColumn::from(std::mem::take(&mut self.block)); + let line = Line::train(&column); + self.block = column.into(); let block = std::mem::take(&mut self.block); for val in block { 
self.collect_after_line_estimation(&line, val); diff --git a/columnar/src/column_values/u64_based/tests.rs b/columnar/src/column_values/u64_based/tests.rs index 4ab45906c..973ff6d90 100644 --- a/columnar/src/column_values/u64_based/tests.rs +++ b/columnar/src/column_values/u64_based/tests.rs @@ -1,5 +1,4 @@ use proptest::prelude::*; -use proptest::strategy::Strategy; use proptest::{prop_oneof, proptest}; #[test] diff --git a/columnar/src/column_values/vec_column.rs b/columnar/src/column_values/vec_column.rs index 59f5d72ab..bc8599343 100644 --- a/columnar/src/column_values/vec_column.rs +++ b/columnar/src/column_values/vec_column.rs @@ -4,14 +4,14 @@ use tantivy_bitpacker::minmax; use crate::ColumnValues; -/// VecColumn provides `Column` over a slice. -pub struct VecColumn<'a, T = u64> { - pub(crate) values: &'a [T], +/// VecColumn provides `Column` over a `Vec`. +pub struct VecColumn { + pub(crate) values: Vec, pub(crate) min_value: T, pub(crate) max_value: T, } -impl<'a, T: Copy + PartialOrd + Send + Sync + Debug> ColumnValues for VecColumn<'a, T> { +impl ColumnValues for VecColumn { fn get_val(&self, position: u32) -> T { self.values[position as usize] } @@ -37,11 +37,8 @@ impl<'a, T: Copy + PartialOrd + Send + Sync + Debug> ColumnValues for VecColu } } -impl<'a, T: Copy + PartialOrd + Default, V> From<&'a V> for VecColumn<'a, T> -where V: AsRef<[T]> + ?Sized -{ - fn from(values: &'a V) -> Self { - let values = values.as_ref(); +impl From> for VecColumn { + fn from(values: Vec) -> Self { let (min_value, max_value) = minmax(values.iter().copied()).unwrap_or_default(); Self { values, @@ -50,3 +47,8 @@ where V: AsRef<[T]> + ?Sized } } } +impl From for Vec { + fn from(column: VecColumn) -> Self { + column.values + } +} diff --git a/columnar/src/columnar/merge/tests.rs b/columnar/src/columnar/merge/tests.rs index 2e688e319..32f29bccd 100644 --- a/columnar/src/columnar/merge/tests.rs +++ b/columnar/src/columnar/merge/tests.rs @@ -1,7 +1,3 @@ -use 
std::collections::BTreeMap; - -use itertools::Itertools; - use super::*; use crate::{Cardinality, ColumnarWriter, HasAssociatedColumnType, RowId}; diff --git a/columnar/src/columnar/writer/mod.rs b/columnar/src/columnar/writer/mod.rs index 53f0088c8..32b31b901 100644 --- a/columnar/src/columnar/writer/mod.rs +++ b/columnar/src/columnar/writer/mod.rs @@ -13,9 +13,7 @@ pub(crate) use serializer::ColumnarSerializer; use stacker::{Addr, ArenaHashMap, MemoryArena}; use crate::column_index::SerializableColumnIndex; -use crate::column_values::{ - ColumnValues, MonotonicallyMappableToU128, MonotonicallyMappableToU64, VecColumn, -}; +use crate::column_values::{MonotonicallyMappableToU128, MonotonicallyMappableToU64}; use crate::columnar::column_type::ColumnType; use crate::columnar::writer::column_writers::{ ColumnWriter, NumericalColumnWriter, StrOrBytesColumnWriter, @@ -645,10 +643,7 @@ fn send_to_serialize_column_mappable_to_u128< value_index_builders: &mut PreallocatedIndexBuilders, values: &mut Vec, mut wrt: impl io::Write, -) -> io::Result<()> -where - for<'a> VecColumn<'a, T>: ColumnValues, -{ +) -> io::Result<()> { values.clear(); // TODO: split index and values let serializable_column_index = match cardinality { @@ -701,10 +696,7 @@ fn send_to_serialize_column_mappable_to_u64( value_index_builders: &mut PreallocatedIndexBuilders, values: &mut Vec, mut wrt: impl io::Write, -) -> io::Result<()> -where - for<'a> VecColumn<'a, u64>: ColumnValues, -{ +) -> io::Result<()> { values.clear(); let serializable_column_index = match cardinality { Cardinality::Full => { diff --git a/columnar/src/columnar/writer/serializer.rs b/columnar/src/columnar/writer/serializer.rs index 0d99a76c7..394e61cd9 100644 --- a/columnar/src/columnar/writer/serializer.rs +++ b/columnar/src/columnar/writer/serializer.rs @@ -18,7 +18,12 @@ pub struct ColumnarSerializer { /// code. 
fn prepare_key(key: &[u8], column_type: ColumnType, buffer: &mut Vec) { buffer.clear(); - buffer.extend_from_slice(key); + // Convert 0 bytes to '0' string, as 0 bytes are reserved for the end of the path. + if key.contains(&0u8) { + buffer.extend(key.iter().map(|&b| if b == 0 { b'0' } else { b })); + } else { + buffer.extend_from_slice(key); + } buffer.push(0u8); buffer.push(column_type.to_code()); } @@ -96,14 +101,13 @@ impl<'a, W: io::Write> io::Write for ColumnSerializer<'a, W> { #[cfg(test)] mod tests { use super::*; - use crate::columnar::column_type::ColumnType; #[test] fn test_prepare_key_bytes() { let mut buffer: Vec = b"somegarbage".to_vec(); prepare_key(b"root\0child", ColumnType::Str, &mut buffer); assert_eq!(buffer.len(), 12); - assert_eq!(&buffer[..10], b"root\0child"); + assert_eq!(&buffer[..10], b"root0child"); assert_eq!(buffer[10], 0u8); assert_eq!(buffer[11], ColumnType::Str.to_code()); } diff --git a/columnar/src/dynamic_column.rs b/columnar/src/dynamic_column.rs index 0c566382e..0a18d4207 100644 --- a/columnar/src/dynamic_column.rs +++ b/columnar/src/dynamic_column.rs @@ -8,7 +8,7 @@ use common::{ByteCount, DateTime, HasLen, OwnedBytes}; use crate::column::{BytesColumn, Column, StrColumn}; use crate::column_values::{monotonic_map_column, StrictlyMonotonicFn}; use crate::columnar::ColumnType; -use crate::{Cardinality, ColumnIndex, NumericalType}; +use crate::{Cardinality, ColumnIndex, ColumnValues, NumericalType}; #[derive(Clone)] pub enum DynamicColumn { @@ -247,7 +247,12 @@ impl DynamicColumnHandle { } /// Returns the `u64` fast field reader reader associated with `fields` of types - /// Str, u64, i64, f64, bool, or datetime. + /// Str, u64, i64, f64, bool, ip, or datetime. + /// + /// Notice that for IpAddr, the fastfield reader will return the u64 representation of the + /// IpAddr. + /// In order to convert to u128 back cast to `CompactSpaceU64Accessor` and call + /// `compact_to_u128`. 
/// /// If not, the fastfield reader will returns the u64-value associated with the original /// FastValue. @@ -258,7 +263,10 @@ impl DynamicColumnHandle { let column: BytesColumn = crate::column::open_column_bytes(column_bytes)?; Ok(Some(column.term_ord_column)) } - ColumnType::IpAddr => Ok(None), + ColumnType::IpAddr => { + let column = crate::column::open_column_u128_as_compact_u64(column_bytes)?; + Ok(Some(column)) + } ColumnType::Bool | ColumnType::I64 | ColumnType::U64 diff --git a/columnar/src/lib.rs b/columnar/src/lib.rs index a20b8363b..7236ea5bc 100644 --- a/columnar/src/lib.rs +++ b/columnar/src/lib.rs @@ -113,6 +113,9 @@ impl Cardinality { pub fn is_multivalue(&self) -> bool { matches!(self, Cardinality::Multivalued) } + pub fn is_full(&self) -> bool { + matches!(self, Cardinality::Full) + } pub(crate) fn to_code(self) -> u8 { self as u8 } diff --git a/columnar/src/tests.rs b/columnar/src/tests.rs index 2d45080b9..5e5c50f55 100644 --- a/columnar/src/tests.rs +++ b/columnar/src/tests.rs @@ -26,7 +26,7 @@ fn test_dataframe_writer_str() { assert_eq!(columnar.num_columns(), 1); let cols: Vec = columnar.read_columns("my_string").unwrap(); assert_eq!(cols.len(), 1); - assert_eq!(cols[0].num_bytes(), 87); + assert_eq!(cols[0].num_bytes(), 73); } #[test] @@ -40,7 +40,7 @@ fn test_dataframe_writer_bytes() { assert_eq!(columnar.num_columns(), 1); let cols: Vec = columnar.read_columns("my_string").unwrap(); assert_eq!(cols.len(), 1); - assert_eq!(cols[0].num_bytes(), 87); + assert_eq!(cols[0].num_bytes(), 73); } #[test] diff --git a/common/Cargo.toml b/common/Cargo.toml index 91765b8f7..a04bfcdb3 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tantivy-common" -version = "0.6.0" +version = "0.7.0" authors = ["Paul Masurel ", "Pascal Seitz "] license = "MIT" edition = "2021" @@ -14,7 +14,7 @@ repository = "https://github.com/quickwit-oss/tantivy" [dependencies] byteorder = "1.4.3" -ownedbytes = { version= "0.6", 
path="../ownedbytes" } +ownedbytes = { version= "0.7", path="../ownedbytes" } async-trait = "0.1" time = { version = "0.3.10", features = ["serde-well-known"] } serde = { version = "1.0.136", features = ["derive"] } diff --git a/common/src/bitset.rs b/common/src/bitset.rs index 6932b0416..b25a52845 100644 --- a/common/src/bitset.rs +++ b/common/src/bitset.rs @@ -1,6 +1,5 @@ -use std::convert::TryInto; use std::io::Write; -use std::{fmt, io, u64}; +use std::{fmt, io}; use ownedbytes::OwnedBytes; diff --git a/common/src/datetime.rs b/common/src/datetime.rs index 3aeadad3e..0dd80b147 100644 --- a/common/src/datetime.rs +++ b/common/src/datetime.rs @@ -1,5 +1,3 @@ -#![allow(deprecated)] - use std::fmt; use std::io::{Read, Write}; @@ -27,9 +25,6 @@ pub enum DateTimePrecision { Nanoseconds, } -#[deprecated(since = "0.20.0", note = "Use `DateTimePrecision` instead")] -pub type DatePrecision = DateTimePrecision; - /// A date/time value with nanoseconds precision. /// /// This timestamp does not carry any explicit time zone information. @@ -40,7 +35,7 @@ pub type DatePrecision = DateTimePrecision; /// All constructors and conversions are provided as explicit /// functions and not by implementing any `From`/`Into` traits /// to prevent unintended usage. -#[derive(Clone, Default, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Clone, Default, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub struct DateTime { // Timestamp in nanoseconds. pub(crate) timestamp_nanos: i64, diff --git a/common/src/json_path_writer.rs b/common/src/json_path_writer.rs index 43a5da8eb..555f343e9 100644 --- a/common/src/json_path_writer.rs +++ b/common/src/json_path_writer.rs @@ -5,6 +5,12 @@ pub const JSON_PATH_SEGMENT_SEP: u8 = 1u8; pub const JSON_PATH_SEGMENT_SEP_STR: &str = unsafe { std::str::from_utf8_unchecked(&[JSON_PATH_SEGMENT_SEP]) }; +/// Separates the json path and the value in +/// a JSON term binary representation. 
+pub const JSON_END_OF_PATH: u8 = 0u8; +pub const JSON_END_OF_PATH_STR: &str = + unsafe { std::str::from_utf8_unchecked(&[JSON_END_OF_PATH]) }; + /// Create a new JsonPathWriter, that creates flattened json paths for tantivy. #[derive(Clone, Debug, Default)] pub struct JsonPathWriter { @@ -14,6 +20,14 @@ pub struct JsonPathWriter { } impl JsonPathWriter { + pub fn with_expand_dots(expand_dots: bool) -> Self { + JsonPathWriter { + path: String::new(), + indices: Vec::new(), + expand_dots, + } + } + pub fn new() -> Self { JsonPathWriter { path: String::new(), @@ -39,8 +53,8 @@ impl JsonPathWriter { pub fn push(&mut self, segment: &str) { let len_path = self.path.len(); self.indices.push(len_path); - if !self.path.is_empty() { - self.path.push_str(JSON_PATH_SEGMENT_SEP_STR); + if self.indices.len() > 1 { + self.path.push(JSON_PATH_SEGMENT_SEP as char); } self.path.push_str(segment); if self.expand_dots { @@ -55,6 +69,12 @@ impl JsonPathWriter { } } + /// Set the end of JSON path marker. + #[inline] + pub fn set_end(&mut self) { + self.path.push_str(JSON_END_OF_PATH_STR); + } + /// Remove the last segment. Does nothing if the path is empty. 
#[inline] pub fn pop(&mut self) { @@ -91,6 +111,7 @@ mod tests { #[test] fn json_path_writer_test() { let mut writer = JsonPathWriter::new(); + writer.set_expand_dots(false); writer.push("root"); assert_eq!(writer.as_str(), "root"); @@ -109,4 +130,15 @@ mod tests { writer.push("k8s.node.id"); assert_eq!(writer.as_str(), "root\u{1}k8s\u{1}node\u{1}id"); } + + #[test] + fn test_json_path_expand_dots_enabled_pop_segment() { + let mut json_writer = JsonPathWriter::with_expand_dots(true); + json_writer.push("hello"); + assert_eq!(json_writer.as_str(), "hello"); + json_writer.push("color.hue"); + assert_eq!(json_writer.as_str(), "hello\x01color\x01hue"); + json_writer.pop(); + assert_eq!(json_writer.as_str(), "hello"); + } } diff --git a/common/src/lib.rs b/common/src/lib.rs index 054378ee5..bfbccecd9 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -9,14 +9,12 @@ mod byte_count; mod datetime; pub mod file_slice; mod group_by; -mod json_path_writer; +pub mod json_path_writer; mod serialize; mod vint; mod writer; pub use bitset::*; pub use byte_count::ByteCount; -#[allow(deprecated)] -pub use datetime::DatePrecision; pub use datetime::{DateTime, DateTimePrecision}; pub use group_by::GroupByIteratorExtended; pub use json_path_writer::JsonPathWriter; diff --git a/common/src/serialize.rs b/common/src/serialize.rs index 69b94090f..181d61e54 100644 --- a/common/src/serialize.rs +++ b/common/src/serialize.rs @@ -290,8 +290,7 @@ impl<'a> BinarySerializable for Cow<'a, [u8]> { #[cfg(test)] pub mod test { - use super::{VInt, *}; - use crate::serialize::BinarySerializable; + use super::*; pub fn fixed_size_test() { let mut buffer = Vec::new(); O::default().serialize(&mut buffer).unwrap(); diff --git a/ownedbytes/Cargo.toml b/ownedbytes/Cargo.toml index 8f990b3d3..2391dbef6 100644 --- a/ownedbytes/Cargo.toml +++ b/ownedbytes/Cargo.toml @@ -1,7 +1,7 @@ [package] authors = ["Paul Masurel ", "Pascal Seitz "] name = "ownedbytes" -version = "0.6.0" +version = "0.7.0" edition = 
"2021" description = "Expose data as static slice" license = "MIT" diff --git a/ownedbytes/src/lib.rs b/ownedbytes/src/lib.rs index 67feb0312..9266af386 100644 --- a/ownedbytes/src/lib.rs +++ b/ownedbytes/src/lib.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::ops::{Deref, Range}; use std::sync::Arc; use std::{fmt, io}; diff --git a/query-grammar/Cargo.toml b/query-grammar/Cargo.toml index 26be4e72a..b9fecb25a 100644 --- a/query-grammar/Cargo.toml +++ b/query-grammar/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tantivy-query-grammar" -version = "0.21.0" +version = "0.22.0" authors = ["Paul Masurel "] license = "MIT" categories = ["database-implementations", "data-structures"] diff --git a/query-grammar/src/infallible.rs b/query-grammar/src/infallible.rs index 6de085b13..0f9edec8e 100644 --- a/query-grammar/src/infallible.rs +++ b/query-grammar/src/infallible.rs @@ -81,8 +81,8 @@ where T: InputTakeAtPosition + Clone, ::Item: AsChar + Clone, { - opt_i(nom::character::complete::space0)(input) - .map(|(left, (spaces, errors))| (left, (spaces.expect("space0 can't fail"), errors))) + opt_i(nom::character::complete::multispace0)(input) + .map(|(left, (spaces, errors))| (left, (spaces.expect("multispace0 can't fail"), errors))) } pub(crate) fn space1_infallible(input: T) -> JResult> @@ -90,7 +90,7 @@ where T: InputTakeAtPosition + Clone + InputLength, ::Item: AsChar + Clone, { - opt_i(nom::character::complete::space1)(input).map(|(left, (spaces, mut errors))| { + opt_i(nom::character::complete::multispace1)(input).map(|(left, (spaces, mut errors))| { if spaces.is_none() { errors.push(LenientErrorInternal { pos: left.input_len(), diff --git a/query-grammar/src/query_grammar.rs b/query-grammar/src/query_grammar.rs index ed3dd1fdb..15802002b 100644 --- a/query-grammar/src/query_grammar.rs +++ b/query-grammar/src/query_grammar.rs @@ -3,11 +3,11 @@ use std::iter::once; use nom::branch::alt; use nom::bytes::complete::tag; use nom::character::complete::{ - anychar, char, 
digit1, none_of, one_of, satisfy, space0, space1, u32, + anychar, char, digit1, multispace0, multispace1, none_of, one_of, satisfy, u32, }; use nom::combinator::{eof, map, map_res, opt, peek, recognize, value, verify}; use nom::error::{Error, ErrorKind}; -use nom::multi::{many0, many1, separated_list0, separated_list1}; +use nom::multi::{many0, many1, separated_list0}; use nom::sequence::{delimited, preceded, separated_pair, terminated, tuple}; use nom::IResult; @@ -65,7 +65,7 @@ fn word_infallible(delimiter: &str) -> impl Fn(&str) -> JResult<&str, Option<&st |inp| { opt_i_err( preceded( - space0, + multispace0, recognize(many1(satisfy(|c| { !c.is_whitespace() && !delimiter.contains(c) }))), @@ -218,27 +218,14 @@ fn term_or_phrase_infallible(inp: &str) -> JResult<&str, Option> } fn term_group(inp: &str) -> IResult<&str, UserInputAst> { - let occur_symbol = alt(( - value(Occur::MustNot, char('-')), - value(Occur::Must, char('+')), - )); - map( tuple(( - terminated(field_name, space0), - delimited( - tuple((char('('), space0)), - separated_list0(space1, tuple((opt(occur_symbol), term_or_phrase))), - char(')'), - ), + terminated(field_name, multispace0), + delimited(tuple((char('('), multispace0)), ast, char(')')), )), - |(field_name, terms)| { - UserInputAst::Clause( - terms - .into_iter() - .map(|(occur, leaf)| (occur, leaf.set_field(Some(field_name.clone())).into())) - .collect(), - ) + |(field_name, mut ast)| { + ast.set_default_field(field_name); + ast }, )(inp) } @@ -250,7 +237,7 @@ fn term_group_precond(inp: &str) -> IResult<&str, (), ()> { (), peek(tuple(( field_name, - space0, + multispace0, char('('), // when we are here, we know it can't be anything but a term group ))), )(inp) @@ -258,46 +245,18 @@ fn term_group_precond(inp: &str) -> IResult<&str, (), ()> { } fn term_group_infallible(inp: &str) -> JResult<&str, UserInputAst> { - let (mut inp, (field_name, _, _, _)) = - tuple((field_name, space0, char('('), space0))(inp).expect("precondition failed"); + let 
(inp, (field_name, _, _, _)) = + tuple((field_name, multispace0, char('('), multispace0))(inp).expect("precondition failed"); - let mut terms = Vec::new(); - let mut errs = Vec::new(); - - let mut first_round = true; - loop { - let mut space_error = if first_round { - first_round = false; - Vec::new() - } else { - let (rest, (_, err)) = space1_infallible(inp)?; - inp = rest; - err - }; - if inp.is_empty() { - errs.push(LenientErrorInternal { - pos: inp.len(), - message: "missing )".to_string(), - }); - break Ok((inp, (UserInputAst::Clause(terms), errs))); - } - if let Some(inp) = inp.strip_prefix(')') { - break Ok((inp, (UserInputAst::Clause(terms), errs))); - } - // only append missing space error if we did not reach the end of group - errs.append(&mut space_error); - - // here we do the assumption term_or_phrase_infallible always consume something if the - // first byte is not `)` or ' '. If it did not, we would end up looping. - - let (rest, ((occur, leaf), mut err)) = - tuple_infallible((occur_symbol, term_or_phrase_infallible))(inp)?; - errs.append(&mut err); - if let Some(leaf) = leaf { - terms.push((occur, leaf.set_field(Some(field_name.clone())).into())); - } - inp = rest; - } + let res = delimited_infallible( + nothing, + map(ast_infallible, |(mut ast, errors)| { + ast.set_default_field(field_name.to_string()); + (ast, errors) + }), + opt_i_err(char(')'), "expected ')'"), + )(inp); + res } fn exists(inp: &str) -> IResult<&str, UserInputLeaf> { @@ -305,7 +264,7 @@ fn exists(inp: &str) -> IResult<&str, UserInputLeaf> { UserInputLeaf::Exists { field: String::new(), }, - tuple((space0, char('*'))), + tuple((multispace0, char('*'))), )(inp) } @@ -314,7 +273,7 @@ fn exists_precond(inp: &str) -> IResult<&str, (), ()> { (), peek(tuple(( field_name, - space0, + multispace0, char('*'), // when we are here, we know it can't be anything but a exists ))), )(inp) @@ -323,7 +282,7 @@ fn exists_precond(inp: &str) -> IResult<&str, (), ()> { fn exists_infallible(inp: &str) 
-> JResult<&str, UserInputAst> { let (inp, (field_name, _, _)) = - tuple((field_name, space0, char('*')))(inp).expect("precondition failed"); + tuple((field_name, multispace0, char('*')))(inp).expect("precondition failed"); let exists = UserInputLeaf::Exists { field: field_name }.into(); Ok((inp, (exists, Vec::new()))) @@ -349,7 +308,7 @@ fn literal_no_group_infallible(inp: &str) -> JResult<&str, Option> alt_infallible( ( ( - value((), tuple((tag("IN"), space0, char('[')))), + value((), tuple((tag("IN"), multispace0, char('[')))), map(set_infallible, |(set, errs)| (Some(set), errs)), ), ( @@ -430,8 +389,8 @@ fn range(inp: &str) -> IResult<&str, UserInputLeaf> { // check for unbounded range in the form of <5, <=10, >5, >=5 let elastic_unbounded_range = map( tuple(( - preceded(space0, alt((tag(">="), tag("<="), tag("<"), tag(">")))), - preceded(space0, range_term_val()), + preceded(multispace0, alt((tag(">="), tag("<="), tag("<"), tag(">")))), + preceded(multispace0, range_term_val()), )), |(comparison_sign, bound)| match comparison_sign { ">=" => (UserInputBound::Inclusive(bound), UserInputBound::Unbounded), @@ -444,7 +403,7 @@ fn range(inp: &str) -> IResult<&str, UserInputLeaf> { ); let lower_bound = map( - separated_pair(one_of("{["), space0, range_term_val()), + separated_pair(one_of("{["), multispace0, range_term_val()), |(boundary_char, lower_bound)| { if lower_bound == "*" { UserInputBound::Unbounded @@ -457,7 +416,7 @@ fn range(inp: &str) -> IResult<&str, UserInputLeaf> { ); let upper_bound = map( - separated_pair(range_term_val(), space0, one_of("}]")), + separated_pair(range_term_val(), multispace0, one_of("}]")), |(upper_bound, boundary_char)| { if upper_bound == "*" { UserInputBound::Unbounded @@ -469,8 +428,11 @@ fn range(inp: &str) -> IResult<&str, UserInputLeaf> { }, ); - let lower_to_upper = - separated_pair(lower_bound, tuple((space1, tag("TO"), space1)), upper_bound); + let lower_to_upper = separated_pair( + lower_bound, + tuple((multispace1, 
tag("TO"), multispace1)), + upper_bound, + ); map( alt((elastic_unbounded_range, lower_to_upper)), @@ -490,13 +452,16 @@ fn range_infallible(inp: &str) -> JResult<&str, UserInputLeaf> { word_infallible("]}"), space1_infallible, opt_i_err( - terminated(tag("TO"), alt((value((), space1), value((), eof)))), + terminated(tag("TO"), alt((value((), multispace1), value((), eof)))), "missing keyword TO", ), word_infallible("]}"), opt_i_err(one_of("]}"), "missing range delimiter"), )), - |((lower_bound_kind, _space0, lower, _space1, to, upper, upper_bound_kind), errs)| { + |( + (lower_bound_kind, _multispace0, lower, _multispace1, to, upper, upper_bound_kind), + errs, + )| { let lower_bound = match (lower_bound_kind, lower) { (_, Some("*")) => UserInputBound::Unbounded, (_, None) => UserInputBound::Unbounded, @@ -596,10 +561,10 @@ fn range_infallible(inp: &str) -> JResult<&str, UserInputLeaf> { fn set(inp: &str) -> IResult<&str, UserInputLeaf> { map( preceded( - tuple((space0, tag("IN"), space1)), + tuple((multispace0, tag("IN"), multispace1)), delimited( - tuple((char('['), space0)), - separated_list0(space1, map(simple_term, |(_, term)| term)), + tuple((char('['), multispace0)), + separated_list0(multispace1, map(simple_term, |(_, term)| term)), char(']'), ), ), @@ -667,7 +632,7 @@ fn leaf(inp: &str) -> IResult<&str, UserInputAst> { alt(( delimited(char('('), ast, char(')')), map(char('*'), |_| UserInputAst::from(UserInputLeaf::All)), - map(preceded(tuple((tag("NOT"), space1)), leaf), negate), + map(preceded(tuple((tag("NOT"), multispace1)), leaf), negate), literal, ))(inp) } @@ -780,27 +745,23 @@ fn binary_operand(inp: &str) -> IResult<&str, BinaryOperand> { } fn aggregate_binary_expressions( - left: UserInputAst, - others: Vec<(BinaryOperand, UserInputAst)>, -) -> UserInputAst { - let mut dnf: Vec> = vec![vec![left]]; - for (operator, operand_ast) in others { - match operator { - BinaryOperand::And => { - if let Some(last) = dnf.last_mut() { - last.push(operand_ast); - 
} - } - BinaryOperand::Or => { - dnf.push(vec![operand_ast]); - } - } - } - if dnf.len() == 1 { - UserInputAst::and(dnf.into_iter().next().unwrap()) //< safe + left: (Option, UserInputAst), + others: Vec<(Option, Option, UserInputAst)>, +) -> Result { + let mut leafs = Vec::with_capacity(others.len() + 1); + leafs.push((None, left.0, Some(left.1))); + leafs.extend( + others + .into_iter() + .map(|(operand, occur, ast)| (operand, occur, Some(ast))), + ); + // the parameters we pass should statically guarantee we can't get errors + // (no prefix BinaryOperand is provided) + let (res, mut errors) = aggregate_infallible_expressions(leafs); + if errors.is_empty() { + Ok(res) } else { - let conjunctions = dnf.into_iter().map(UserInputAst::and).collect(); - UserInputAst::or(conjunctions) + Err(errors.swap_remove(0)) } } @@ -816,30 +777,10 @@ fn aggregate_infallible_expressions( return (UserInputAst::empty_query(), err); } - let use_operand = leafs.iter().any(|(operand, _, _)| operand.is_some()); - let all_operand = leafs - .iter() - .skip(1) - .all(|(operand, _, _)| operand.is_some()); let early_operand = leafs .iter() .take(1) .all(|(operand, _, _)| operand.is_some()); - let use_occur = leafs.iter().any(|(_, occur, _)| occur.is_some()); - - if use_operand && use_occur { - err.push(LenientErrorInternal { - pos: 0, - message: "Use of mixed occur and boolean operator".to_string(), - }); - } - - if use_operand && !all_operand { - err.push(LenientErrorInternal { - pos: 0, - message: "Missing boolean operator".to_string(), - }); - } if early_operand { err.push(LenientErrorInternal { @@ -866,7 +807,15 @@ fn aggregate_infallible_expressions( Some(BinaryOperand::And) => Some(Occur::Must), _ => Some(Occur::Should), }; - clauses.push(vec![(occur.or(default_op), ast.clone())]); + if occur == &Some(Occur::MustNot) && default_op == Some(Occur::Should) { + // if occur is MustNot *and* operation is OR, we synthetize a ShouldNot + clauses.push(vec![( + Some(Occur::Should), + 
ast.clone().unary(Occur::MustNot), + )]) + } else { + clauses.push(vec![(occur.or(default_op), ast.clone())]); + } } None => { let default_op = match next_operator { @@ -874,7 +823,15 @@ fn aggregate_infallible_expressions( Some(BinaryOperand::Or) => Some(Occur::Should), None => None, }; - clauses.push(vec![(occur.or(default_op), ast.clone())]) + if occur == &Some(Occur::MustNot) && default_op == Some(Occur::Should) { + // if occur is MustNot *and* operation is OR, we synthetize a ShouldNot + clauses.push(vec![( + Some(Occur::Should), + ast.clone().unary(Occur::MustNot), + )]) + } else { + clauses.push(vec![(occur.or(default_op), ast.clone())]) + } } } } @@ -891,7 +848,12 @@ fn aggregate_infallible_expressions( } } Some(BinaryOperand::Or) => { - clauses.push(vec![(last_occur.or(Some(Occur::Should)), last_ast)]); + if last_occur == Some(Occur::MustNot) { + // if occur is MustNot *and* operation is OR, we synthetize a ShouldNot + clauses.push(vec![(Some(Occur::Should), last_ast.unary(Occur::MustNot))]); + } else { + clauses.push(vec![(last_occur.or(Some(Occur::Should)), last_ast)]); + } } None => clauses.push(vec![(last_occur, last_ast)]), } @@ -917,35 +879,29 @@ fn aggregate_infallible_expressions( } } -fn operand_leaf(inp: &str) -> IResult<&str, (BinaryOperand, UserInputAst)> { - tuple(( - terminated(binary_operand, space0), - terminated(boosted_leaf, space0), - ))(inp) +fn operand_leaf(inp: &str) -> IResult<&str, (Option, Option, UserInputAst)> { + map( + tuple(( + terminated(opt(binary_operand), multispace0), + terminated(occur_leaf, multispace0), + )), + |(operand, (occur, ast))| (operand, occur, ast), + )(inp) } fn ast(inp: &str) -> IResult<&str, UserInputAst> { - let boolean_expr = map( - separated_pair(boosted_leaf, space1, many1(operand_leaf)), + let boolean_expr = map_res( + separated_pair(occur_leaf, multispace1, many1(operand_leaf)), |(left, right)| aggregate_binary_expressions(left, right), ); - let whitespace_separated_leaves = 
map(separated_list1(space1, occur_leaf), |subqueries| { - if subqueries.len() == 1 { - let (occur_opt, ast) = subqueries.into_iter().next().unwrap(); - match occur_opt.unwrap_or(Occur::Should) { - Occur::Must | Occur::Should => ast, - Occur::MustNot => UserInputAst::Clause(vec![(Some(Occur::MustNot), ast)]), - } + let single_leaf = map(occur_leaf, |(occur, ast)| { + if occur == Some(Occur::MustNot) { + ast.unary(Occur::MustNot) } else { - UserInputAst::Clause(subqueries.into_iter().collect()) + ast } }); - - delimited( - space0, - alt((boolean_expr, whitespace_separated_leaves)), - space0, - )(inp) + delimited(multispace0, alt((boolean_expr, single_leaf)), multispace0)(inp) } fn ast_infallible(inp: &str) -> JResult<&str, UserInputAst> { @@ -969,7 +925,7 @@ fn ast_infallible(inp: &str) -> JResult<&str, UserInputAst> { } pub fn parse_to_ast(inp: &str) -> IResult<&str, UserInputAst> { - map(delimited(space0, opt(ast), eof), |opt_ast| { + map(delimited(multispace0, opt(ast), eof), |opt_ast| { rewrite_ast(opt_ast.unwrap_or_else(UserInputAst::empty_query)) })(inp) } @@ -1145,24 +1101,43 @@ mod test { #[test] fn test_parse_query_to_ast_binary_op() { test_parse_query_to_ast_helper("a AND b", "(+a +b)"); + test_parse_query_to_ast_helper("a\nAND b", "(+a +b)"); test_parse_query_to_ast_helper("a OR b", "(?a ?b)"); test_parse_query_to_ast_helper("a OR b AND c", "(?a ?(+b +c))"); test_parse_query_to_ast_helper("a AND b AND c", "(+a +b +c)"); - test_is_parse_err("a OR b aaa", "(?a ?b *aaa)"); - test_is_parse_err("a AND b aaa", "(?(+a +b) *aaa)"); - test_is_parse_err("aaa a OR b ", "(*aaa ?a ?b)"); - test_is_parse_err("aaa ccc a OR b ", "(*aaa *ccc ?a ?b)"); - test_is_parse_err("aaa a AND b ", "(*aaa ?(+a +b))"); - test_is_parse_err("aaa ccc a AND b ", "(*aaa *ccc ?(+a +b))"); + test_parse_query_to_ast_helper("a OR b aaa", "(?a ?b *aaa)"); + test_parse_query_to_ast_helper("a AND b aaa", "(?(+a +b) *aaa)"); + test_parse_query_to_ast_helper("aaa a OR b ", "(*aaa ?a ?b)"); + 
test_parse_query_to_ast_helper("aaa ccc a OR b ", "(*aaa *ccc ?a ?b)"); + test_parse_query_to_ast_helper("aaa a AND b ", "(*aaa ?(+a +b))"); + test_parse_query_to_ast_helper("aaa ccc a AND b ", "(*aaa *ccc ?(+a +b))"); } #[test] fn test_parse_mixed_bool_occur() { - test_is_parse_err("a OR b +aaa", "(?a ?b +aaa)"); - test_is_parse_err("a AND b -aaa", "(?(+a +b) -aaa)"); - test_is_parse_err("+a OR +b aaa", "(+a +b *aaa)"); - test_is_parse_err("-a AND -b aaa", "(?(-a -b) *aaa)"); - test_is_parse_err("-aaa +ccc -a OR b ", "(-aaa +ccc -a ?b)"); + test_parse_query_to_ast_helper("+a OR +b", "(+a +b)"); + + test_parse_query_to_ast_helper("a AND -b", "(+a -b)"); + test_parse_query_to_ast_helper("-a AND b", "(-a +b)"); + test_parse_query_to_ast_helper("a AND NOT b", "(+a +(-b))"); + test_parse_query_to_ast_helper("NOT a AND b", "(+(-a) +b)"); + + test_parse_query_to_ast_helper("a AND NOT b AND c", "(+a +(-b) +c)"); + test_parse_query_to_ast_helper("a AND -b AND c", "(+a -b +c)"); + + test_parse_query_to_ast_helper("a OR -b", "(?a ?(-b))"); + test_parse_query_to_ast_helper("-a OR b", "(?(-a) ?b)"); + test_parse_query_to_ast_helper("a OR NOT b", "(?a ?(-b))"); + test_parse_query_to_ast_helper("NOT a OR b", "(?(-a) ?b)"); + + test_parse_query_to_ast_helper("a OR NOT b OR c", "(?a ?(-b) ?c)"); + test_parse_query_to_ast_helper("a OR -b OR c", "(?a ?(-b) ?c)"); + + test_parse_query_to_ast_helper("a OR b +aaa", "(?a ?b +aaa)"); + test_parse_query_to_ast_helper("a AND b -aaa", "(?(+a +b) -aaa)"); + test_parse_query_to_ast_helper("+a OR +b aaa", "(+a +b *aaa)"); + test_parse_query_to_ast_helper("-a AND -b aaa", "(?(-a -b) *aaa)"); + test_parse_query_to_ast_helper("-aaa +ccc -a OR b ", "(-aaa +ccc ?(-a) ?b)"); } #[test] @@ -1452,8 +1427,18 @@ mod test { #[test] fn test_parse_query_term_group() { - test_parse_query_to_ast_helper(r#"field:(abc)"#, r#"(*"field":abc)"#); + test_parse_query_to_ast_helper(r#"field:(abc)"#, r#""field":abc"#); test_parse_query_to_ast_helper(r#"field:(+a -"b 
c")"#, r#"(+"field":a -"field":"b c")"#); + test_parse_query_to_ast_helper(r#"field:(a AND "b c")"#, r#"(+"field":a +"field":"b c")"#); + test_parse_query_to_ast_helper(r#"field:(a OR "b c")"#, r#"(?"field":a ?"field":"b c")"#); + test_parse_query_to_ast_helper( + r#"field:(a OR (b AND c))"#, + r#"(?"field":a ?(+"field":b +"field":c))"#, + ); + test_parse_query_to_ast_helper( + r#"field:(a [b TO c])"#, + r#"(*"field":a *"field":["b" TO "c"])"#, + ); test_is_parse_err(r#"field:(+a -"b c""#, r#"(+"field":a -"field":"b c")"#); } diff --git a/query-grammar/src/user_input_ast.rs b/query-grammar/src/user_input_ast.rs index d0d1a0266..7289da55f 100644 --- a/query-grammar/src/user_input_ast.rs +++ b/query-grammar/src/user_input_ast.rs @@ -44,6 +44,26 @@ impl UserInputLeaf { }, } } + + pub(crate) fn set_default_field(&mut self, default_field: String) { + match self { + UserInputLeaf::Literal(ref mut literal) if literal.field_name.is_none() => { + literal.field_name = Some(default_field) + } + UserInputLeaf::All => { + *self = UserInputLeaf::Exists { + field: default_field, + } + } + UserInputLeaf::Range { ref mut field, .. } if field.is_none() => { + *field = Some(default_field) + } + UserInputLeaf::Set { ref mut field, .. 
} if field.is_none() => { + *field = Some(default_field) + } + _ => (), // field was already set, do nothing + } + } } impl Debug for UserInputLeaf { @@ -205,6 +225,16 @@ impl UserInputAst { pub fn or(asts: Vec) -> UserInputAst { UserInputAst::compose(Occur::Should, asts) } + + pub(crate) fn set_default_field(&mut self, field: String) { + match self { + UserInputAst::Clause(clauses) => clauses + .iter_mut() + .for_each(|(_, ast)| ast.set_default_field(field.clone())), + UserInputAst::Leaf(leaf) => leaf.set_default_field(field), + UserInputAst::Boost(ref mut ast, _) => ast.set_default_field(field), + } + } } impl From for UserInputLeaf { diff --git a/src/aggregation/agg_bench.rs b/src/aggregation/agg_bench.rs index ec534d994..84c0bb382 100644 --- a/src/aggregation/agg_bench.rs +++ b/src/aggregation/agg_bench.rs @@ -290,6 +290,41 @@ mod bench { }); } + bench_all_cardinalities!(bench_aggregation_terms_many_with_top_hits_agg); + + fn bench_aggregation_terms_many_with_top_hits_agg_card( + b: &mut Bencher, + cardinality: Cardinality, + ) { + let index = get_test_index_bench(cardinality).unwrap(); + let reader = index.reader().unwrap(); + + b.iter(|| { + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_texts": { + "terms": { "field": "text_many_terms" }, + "aggs": { + "top_hits": { "top_hits": + { + "sort": [ + { "score": "desc" } + ], + "size": 2, + "doc_value_fields": ["score_f64"] + } + } + } + }, + })) + .unwrap(); + + let collector = get_collector(agg_req); + + let searcher = reader.searcher(); + searcher.search(&AllQuery, &collector).unwrap() + }); + } + bench_all_cardinalities!(bench_aggregation_terms_many_with_sub_agg); fn bench_aggregation_terms_many_with_sub_agg_card(b: &mut Bencher, cardinality: Cardinality) { diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs index a6157d594..fea06bdf9 100644 --- a/src/aggregation/agg_req.rs +++ b/src/aggregation/agg_req.rs @@ -35,7 +35,7 @@ use super::bucket::{ }; use super::metric::{ 
AverageAggregation, CountAggregation, ExtendedStatsAggregation, MaxAggregation, MinAggregation, - PercentilesAggregationReq, StatsAggregation, SumAggregation, + PercentilesAggregationReq, StatsAggregation, SumAggregation, TopHitsAggregation, }; /// The top-level aggregation request structure, which contains [`Aggregation`] and their user @@ -93,7 +93,12 @@ impl Aggregation { } fn get_fast_field_names(&self, fast_field_names: &mut HashSet) { - fast_field_names.insert(self.agg.get_fast_field_name().to_string()); + fast_field_names.extend( + self.agg + .get_fast_field_names() + .iter() + .map(|s| s.to_string()), + ); fast_field_names.extend(get_fast_field_names(&self.sub_aggregation)); } } @@ -152,24 +157,28 @@ pub enum AggregationVariants { /// Computes the sum of the extracted values. #[serde(rename = "percentiles")] Percentiles(PercentilesAggregationReq), + /// Finds the top k values matching some order + #[serde(rename = "top_hits")] + TopHits(TopHitsAggregation), } impl AggregationVariants { - /// Returns the name of the field used by the aggregation. - pub fn get_fast_field_name(&self) -> &str { + /// Returns the name of the fields used by the aggregation. 
+ pub fn get_fast_field_names(&self) -> Vec<&str> { match self { - AggregationVariants::Terms(terms) => terms.field.as_str(), - AggregationVariants::Range(range) => range.field.as_str(), - AggregationVariants::Histogram(histogram) => histogram.field.as_str(), - AggregationVariants::DateHistogram(histogram) => histogram.field.as_str(), - AggregationVariants::Average(avg) => avg.field_name(), - AggregationVariants::Count(count) => count.field_name(), - AggregationVariants::Max(max) => max.field_name(), - AggregationVariants::Min(min) => min.field_name(), - AggregationVariants::Stats(stats) => stats.field_name(), - AggregationVariants::ExtendedStats(extended_stats) => extended_stats.field_name(), - AggregationVariants::Sum(sum) => sum.field_name(), - AggregationVariants::Percentiles(per) => per.field_name(), + AggregationVariants::Terms(terms) => vec![terms.field.as_str()], + AggregationVariants::Range(range) => vec![range.field.as_str()], + AggregationVariants::Histogram(histogram) => vec![histogram.field.as_str()], + AggregationVariants::DateHistogram(histogram) => vec![histogram.field.as_str()], + AggregationVariants::Average(avg) => vec![avg.field_name()], + AggregationVariants::Count(count) => vec![count.field_name()], + AggregationVariants::Max(max) => vec![max.field_name()], + AggregationVariants::Min(min) => vec![min.field_name()], + AggregationVariants::Stats(stats) => vec![stats.field_name()], + AggregationVariants::ExtendedStats(extended_stats) => vec![extended_stats.field_name()], + AggregationVariants::Sum(sum) => vec![sum.field_name()], + AggregationVariants::Percentiles(per) => vec![per.field_name()], + AggregationVariants::TopHits(top_hits) => top_hits.field_names(), } } diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index 6dda293e8..1680876c7 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -1,6 +1,9 @@ //! 
This will enhance the request tree with access to the fastfield and metadata. -use columnar::{Column, ColumnBlockAccessor, ColumnType, StrColumn}; +use std::collections::HashMap; +use std::io; + +use columnar::{Column, ColumnBlockAccessor, ColumnType, DynamicColumn, StrColumn}; use super::agg_limits::ResourceLimitGuard; use super::agg_req::{Aggregation, AggregationVariants, Aggregations}; @@ -14,7 +17,7 @@ use super::metric::{ use super::segment_agg_result::AggregationLimits; use super::VecWithNames; use crate::aggregation::{f64_to_fastfield_u64, Key}; -use crate::SegmentReader; +use crate::{SegmentOrdinal, SegmentReader}; #[derive(Default)] pub(crate) struct AggregationsWithAccessor { @@ -32,6 +35,7 @@ impl AggregationsWithAccessor { } pub struct AggregationWithAccessor { + pub(crate) segment_ordinal: SegmentOrdinal, /// In general there can be buckets without fast field access, e.g. buckets that are created /// based on search terms. That is not that case currently, but eventually this needs to be /// Option or moved. @@ -44,10 +48,16 @@ pub struct AggregationWithAccessor { pub(crate) limits: ResourceLimitGuard, pub(crate) column_block_accessor: ColumnBlockAccessor, /// Used for missing term aggregation, which checks all columns for existence. + /// And also for `top_hits` aggregation, which may sort on multiple fields. /// By convention the missing aggregation is chosen, when this property is set /// (instead bein set in `agg`). /// If this needs to used by other aggregations, we need to refactor this. - pub(crate) accessors: Vec>, + // NOTE: we can make all other aggregations use this instead of the `accessor` and `field_type` + // (making them obsolete) But will it have a performance impact? + pub(crate) accessors: Vec<(Column, ColumnType)>, + /// Map field names to all associated column accessors. + /// This field is used for `docvalue_fields`, which is currently only supported for `top_hits`. 
+ pub(crate) value_accessors: HashMap>, pub(crate) agg: Aggregation, } @@ -57,19 +67,55 @@ impl AggregationWithAccessor { agg: &Aggregation, sub_aggregation: &Aggregations, reader: &SegmentReader, + segment_ordinal: SegmentOrdinal, limits: AggregationLimits, ) -> crate::Result> { - let add_agg_with_accessor = |accessor: Column, + let mut agg = agg.clone(); + + let add_agg_with_accessor = |agg: &Aggregation, + accessor: Column, column_type: ColumnType, aggs: &mut Vec| -> crate::Result<()> { let res = AggregationWithAccessor { + segment_ordinal, accessor, - accessors: Vec::new(), + accessors: Default::default(), + value_accessors: Default::default(), field_type: column_type, sub_aggregation: get_aggs_with_segment_accessor_and_validate( sub_aggregation, reader, + segment_ordinal, + &limits, + )?, + agg: agg.clone(), + limits: limits.new_guard(), + missing_value_for_accessor: None, + str_dict_column: None, + column_block_accessor: Default::default(), + }; + aggs.push(res); + Ok(()) + }; + + let add_agg_with_accessors = |agg: &Aggregation, + accessors: Vec<(Column, ColumnType)>, + aggs: &mut Vec, + value_accessors: HashMap>| + -> crate::Result<()> { + let (accessor, field_type) = accessors.first().expect("at least one accessor"); + let res = AggregationWithAccessor { + segment_ordinal, + // TODO: We should do away with the `accessor` field altogether + accessor: accessor.clone(), + value_accessors, + field_type: *field_type, + accessors, + sub_aggregation: get_aggs_with_segment_accessor_and_validate( + sub_aggregation, + reader, + segment_ordinal, &limits, )?, agg: agg.clone(), @@ -84,32 +130,36 @@ impl AggregationWithAccessor { let mut res: Vec = Vec::new(); use AggregationVariants::*; - match &agg.agg { + + match agg.agg { Range(RangeAggregation { - field: field_name, .. + field: ref field_name, + .. 
}) => { let (accessor, column_type) = get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?; - add_agg_with_accessor(accessor, column_type, &mut res)?; + add_agg_with_accessor(&agg, accessor, column_type, &mut res)?; } Histogram(HistogramAggregation { - field: field_name, .. + field: ref field_name, + .. }) => { let (accessor, column_type) = get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?; - add_agg_with_accessor(accessor, column_type, &mut res)?; + add_agg_with_accessor(&agg, accessor, column_type, &mut res)?; } DateHistogram(DateHistogramAggregationReq { - field: field_name, .. + field: ref field_name, + .. }) => { let (accessor, column_type) = // Only DateTime is supported for DateHistogram get_ff_reader(reader, field_name, Some(&[ColumnType::DateTime]))?; - add_agg_with_accessor(accessor, column_type, &mut res)?; + add_agg_with_accessor(&agg, accessor, column_type, &mut res)?; } Terms(TermsAggregation { - field: field_name, - missing, + field: ref field_name, + ref missing, .. 
}) => { let str_dict_column = reader.fast_fields().str(field_name)?; @@ -119,9 +169,9 @@ impl AggregationWithAccessor { ColumnType::F64, ColumnType::Str, ColumnType::DateTime, + ColumnType::Bool, + ColumnType::IpAddr, // ColumnType::Bytes Unsupported - // ColumnType::Bool Unsupported - // ColumnType::IpAddr Unsupported ]; // In case the column is empty we want the shim column to match the missing type @@ -162,24 +212,11 @@ impl AggregationWithAccessor { let column_and_types = get_all_ff_reader_or_empty(reader, field_name, None, fallback_type)?; - let accessors: Vec = - column_and_types.iter().map(|(a, _)| a.clone()).collect(); - let agg_wit_acc = AggregationWithAccessor { - missing_value_for_accessor: None, - accessor: accessors[0].clone(), - accessors, - field_type: ColumnType::U64, - sub_aggregation: get_aggs_with_segment_accessor_and_validate( - sub_aggregation, - reader, - &limits, - )?, - agg: agg.clone(), - str_dict_column: str_dict_column.clone(), - limits: limits.new_guard(), - column_block_accessor: Default::default(), - }; - res.push(agg_wit_acc); + let accessors = column_and_types + .iter() + .map(|c_t| (c_t.0.clone(), c_t.1)) + .collect(); + add_agg_with_accessors(&agg, accessors, &mut res, Default::default())?; } for (accessor, column_type) in column_and_types { @@ -189,21 +226,25 @@ impl AggregationWithAccessor { missing.clone() }; - let missing_value_for_accessor = - if let Some(missing) = missing_value_term_agg.as_ref() { - get_missing_val(column_type, missing, agg.agg.get_fast_field_name())? - } else { - None - }; + let missing_value_for_accessor = if let Some(missing) = + missing_value_term_agg.as_ref() + { + get_missing_val(column_type, missing, agg.agg.get_fast_field_names()[0])? 
+ } else { + None + }; let agg = AggregationWithAccessor { + segment_ordinal, missing_value_for_accessor, accessor, - accessors: Vec::new(), + accessors: Default::default(), + value_accessors: Default::default(), field_type: column_type, sub_aggregation: get_aggs_with_segment_accessor_and_validate( sub_aggregation, reader, + segment_ordinal, &limits, )?, agg: agg.clone(), @@ -215,37 +256,66 @@ impl AggregationWithAccessor { } } Average(AverageAggregation { - field: field_name, .. + field: ref field_name, + .. }) | Count(CountAggregation { - field: field_name, .. + field: ref field_name, + .. }) | Max(MaxAggregation { - field: field_name, .. + field: ref field_name, + .. }) | Min(MinAggregation { - field: field_name, .. + field: ref field_name, + .. }) | Stats(StatsAggregation { - field: field_name, .. + field: ref field_name, + .. }) | ExtendedStats(ExtendedStatsAggregation { field: field_name, .. }) | Sum(SumAggregation { - field: field_name, .. + field: ref field_name, + .. }) => { let (accessor, column_type) = get_ff_reader(reader, field_name, Some(get_numeric_or_date_column_types()))?; - add_agg_with_accessor(accessor, column_type, &mut res)?; + add_agg_with_accessor(&agg, accessor, column_type, &mut res)?; } - Percentiles(percentiles) => { + Percentiles(ref percentiles) => { let (accessor, column_type) = get_ff_reader( reader, percentiles.field_name(), Some(get_numeric_or_date_column_types()), )?; - add_agg_with_accessor(accessor, column_type, &mut res)?; + add_agg_with_accessor(&agg, accessor, column_type, &mut res)?; + } + TopHits(ref mut top_hits) => { + top_hits.validate_and_resolve_field_names(reader.fast_fields().columnar())?; + let accessors: Vec<(Column, ColumnType)> = top_hits + .field_names() + .iter() + .map(|field| { + get_ff_reader(reader, field, Some(get_numeric_or_date_column_types())) + }) + .collect::>()?; + + let value_accessors = top_hits + .value_field_names() + .iter() + .map(|field_name| { + Ok(( + field_name.to_string(), + 
get_dynamic_columns(reader, field_name)?, + )) + }) + .collect::>()?; + + add_agg_with_accessors(&agg, accessors, &mut res, value_accessors)?; } }; @@ -287,6 +357,7 @@ fn get_numeric_or_date_column_types() -> &'static [ColumnType] { pub(crate) fn get_aggs_with_segment_accessor_and_validate( aggs: &Aggregations, reader: &SegmentReader, + segment_ordinal: SegmentOrdinal, limits: &AggregationLimits, ) -> crate::Result { let mut aggss = Vec::new(); @@ -295,6 +366,7 @@ pub(crate) fn get_aggs_with_segment_accessor_and_validate( agg, agg.sub_aggregation(), reader, + segment_ordinal, limits.clone(), )?; for agg in aggs { @@ -324,6 +396,19 @@ fn get_ff_reader( Ok(ff_field_with_type) } +fn get_dynamic_columns( + reader: &SegmentReader, + field_name: &str, +) -> crate::Result> { + let ff_fields = reader.fast_fields().dynamic_column_handles(field_name)?; + let cols = ff_fields + .iter() + .map(|h| h.open()) + .collect::>()?; + assert!(!ff_fields.is_empty(), "field {} not found", field_name); + Ok(cols) +} + /// Get all fast field reader or empty as default. /// /// Is guaranteed to return at least one column. diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs index 6a83e41d7..a032ad121 100644 --- a/src/aggregation/agg_result.rs +++ b/src/aggregation/agg_result.rs @@ -8,7 +8,7 @@ use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use super::bucket::GetDocCount; -use super::metric::{ExtendedStats, PercentilesMetricResult, SingleMetricResult, Stats}; +use super::metric::{ExtendedStats, PercentilesMetricResult, SingleMetricResult, Stats, TopHitsMetricResult}; use super::{AggregationError, Key}; use crate::TantivyError; @@ -92,8 +92,10 @@ pub enum MetricResult { ExtendedStats(Box), /// Sum metric result. Sum(SingleMetricResult), - /// Sum metric result. + /// Percentiles metric result. 
Percentiles(PercentilesMetricResult), + /// Top hits metric result + TopHits(TopHitsMetricResult), } impl MetricResult { @@ -109,6 +111,9 @@ impl MetricResult { MetricResult::Percentiles(_) => Err(TantivyError::AggregationError( AggregationError::InvalidRequest("percentiles can't be used to order".to_string()), )), + MetricResult::TopHits(_) => Err(TantivyError::AggregationError( + AggregationError::InvalidRequest("top_hits can't be used to order".to_string()), + )), } } } diff --git a/src/aggregation/agg_tests.rs b/src/aggregation/agg_tests.rs index 1fea7fe6f..126c2240e 100644 --- a/src/aggregation/agg_tests.rs +++ b/src/aggregation/agg_tests.rs @@ -4,6 +4,7 @@ use crate::aggregation::agg_req::{Aggregation, Aggregations}; use crate::aggregation::agg_result::AggregationResults; use crate::aggregation::buf_collector::DOC_BLOCK_SIZE; use crate::aggregation::collector::AggregationCollector; +use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults; use crate::aggregation::segment_agg_result::AggregationLimits; use crate::aggregation::tests::{get_test_index_2_segments, get_test_index_from_values_and_terms}; use crate::aggregation::DistributedAggregationCollector; @@ -66,6 +67,22 @@ fn test_aggregation_flushing( } } }, + "top_hits_test":{ + "terms": { + "field": "string_id" + }, + "aggs": { + "bucketsL2": { + "top_hits": { + "size": 2, + "sort": [ + { "score": "asc" } + ], + "docvalue_fields": ["score"] + } + } + } + }, "histogram_test":{ "histogram": { "field": "score", @@ -108,6 +125,16 @@ fn test_aggregation_flushing( let searcher = reader.searcher(); let intermediate_agg_result = searcher.search(&AllQuery, &collector).unwrap(); + + // Test postcard roundtrip serialization + let intermediate_agg_result_bytes = postcard::to_allocvec(&intermediate_agg_result).expect( + "Postcard Serialization failed, flatten etc. 
is not supported in the intermediate \ + result", + ); + let intermediate_agg_result: IntermediateAggregationResults = + postcard::from_bytes(&intermediate_agg_result_bytes) + .expect("Post deserialization failed"); + intermediate_agg_result .into_final_result(agg_req, &Default::default()) .unwrap() @@ -587,6 +614,9 @@ fn test_aggregation_on_json_object() { let schema = schema_builder.build(); let index = Index::create_in_ram(schema); let mut index_writer: IndexWriter = index.writer_for_tests().unwrap(); + index_writer + .add_document(doc!(json => json!({"color": "red"}))) + .unwrap(); index_writer .add_document(doc!(json => json!({"color": "red"}))) .unwrap(); @@ -614,8 +644,8 @@ fn test_aggregation_on_json_object() { &serde_json::json!({ "jsonagg": { "buckets": [ + {"doc_count": 2, "key": "red"}, {"doc_count": 1, "key": "blue"}, - {"doc_count": 1, "key": "red"} ], "doc_count_error_upper_bound": 0, "sum_other_doc_count": 0 @@ -637,6 +667,9 @@ fn test_aggregation_on_nested_json_object() { index_writer .add_document(doc!(json => json!({"color.dot": "blue", "color": {"nested":"blue"} }))) .unwrap(); + index_writer + .add_document(doc!(json => json!({"color.dot": "blue", "color": {"nested":"blue"} }))) + .unwrap(); index_writer.commit().unwrap(); let reader = index.reader().unwrap(); let searcher = reader.searcher(); @@ -664,7 +697,7 @@ fn test_aggregation_on_nested_json_object() { &serde_json::json!({ "jsonagg1": { "buckets": [ - {"doc_count": 1, "key": "blue"}, + {"doc_count": 2, "key": "blue"}, {"doc_count": 1, "key": "red"} ], "doc_count_error_upper_bound": 0, @@ -672,7 +705,7 @@ fn test_aggregation_on_nested_json_object() { }, "jsonagg2": { "buckets": [ - {"doc_count": 1, "key": "blue"}, + {"doc_count": 2, "key": "blue"}, {"doc_count": 1, "key": "red"} ], "doc_count_error_upper_bound": 0, @@ -810,29 +843,38 @@ fn test_aggregation_on_json_object_mixed_types() { let mut index_writer: IndexWriter = index.writer_for_tests().unwrap(); // => Segment with all values 
numeric index_writer - .add_document(doc!(json => json!({"mixed_type": 10.0}))) + .add_document(doc!(json => json!({"mixed_type": 10.0, "mixed_price": 10.0}))) .unwrap(); index_writer.commit().unwrap(); // => Segment with all values text index_writer - .add_document(doc!(json => json!({"mixed_type": "blue"}))) + .add_document(doc!(json => json!({"mixed_type": "blue", "mixed_price": 5.0}))) + .unwrap(); + index_writer + .add_document(doc!(json => json!({"mixed_type": "blue", "mixed_price": 5.0}))) + .unwrap(); + index_writer + .add_document(doc!(json => json!({"mixed_type": "blue", "mixed_price": 5.0}))) .unwrap(); index_writer.commit().unwrap(); // => Segment with all boolen index_writer - .add_document(doc!(json => json!({"mixed_type": true}))) + .add_document(doc!(json => json!({"mixed_type": true, "mixed_price": "no_price"}))) .unwrap(); index_writer.commit().unwrap(); // => Segment with mixed values index_writer - .add_document(doc!(json => json!({"mixed_type": "red"}))) + .add_document(doc!(json => json!({"mixed_type": "red", "mixed_price": 1.0}))) .unwrap(); index_writer - .add_document(doc!(json => json!({"mixed_type": -20.5}))) + .add_document(doc!(json => json!({"mixed_type": "red", "mixed_price": 1.0}))) .unwrap(); index_writer - .add_document(doc!(json => json!({"mixed_type": true}))) + .add_document(doc!(json => json!({"mixed_type": -20.5, "mixed_price": -20.5}))) + .unwrap(); + index_writer + .add_document(doc!(json => json!({"mixed_type": true, "mixed_price": "no_price"}))) .unwrap(); index_writer.commit().unwrap(); @@ -846,7 +888,7 @@ fn test_aggregation_on_json_object_mixed_types() { "order": { "min_price": "desc" } }, "aggs": { - "min_price": { "min": { "field": "json.mixed_type" } } + "min_price": { "min": { "field": "json.mixed_price" } } } }, "rangeagg": { @@ -870,6 +912,7 @@ fn test_aggregation_on_json_object_mixed_types() { let aggregation_results = searcher.search(&AllQuery, &aggregation_collector).unwrap(); let aggregation_res_json = 
serde_json::to_value(aggregation_results).unwrap(); + use pretty_assertions::assert_eq; assert_eq!( &aggregation_res_json, &serde_json::json!({ @@ -884,10 +927,10 @@ fn test_aggregation_on_json_object_mixed_types() { "termagg": { "buckets": [ { "doc_count": 1, "key": 10.0, "min_price": { "value": 10.0 } }, + { "doc_count": 3, "key": "blue", "min_price": { "value": 5.0 } }, + { "doc_count": 2, "key": "red", "min_price": { "value": 1.0 } }, { "doc_count": 1, "key": -20.5, "min_price": { "value": -20.5 } }, - // TODO bool is also not yet handled in aggregation - { "doc_count": 1, "key": "blue", "min_price": { "value": null } }, - { "doc_count": 1, "key": "red", "min_price": { "value": null } }, + { "doc_count": 2, "key": 1.0, "key_as_string": "true", "min_price": { "value": null } }, ], "sum_other_doc_count": 0 } diff --git a/src/aggregation/bucket/histogram/date_histogram.rs b/src/aggregation/bucket/histogram/date_histogram.rs index d0502af73..e1ca3426b 100644 --- a/src/aggregation/bucket/histogram/date_histogram.rs +++ b/src/aggregation/bucket/histogram/date_histogram.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; use super::{HistogramAggregation, HistogramBounds}; -use crate::aggregation::AggregationError; +use crate::aggregation::*; /// DateHistogramAggregation is similar to `HistogramAggregation`, but it can only be used with date /// type. 
@@ -307,6 +307,7 @@ pub mod tests { ) -> crate::Result { let mut schema_builder = Schema::builder(); schema_builder.add_date_field("date", FAST); + schema_builder.add_json_field("mixed", FAST); schema_builder.add_text_field("text", FAST | STRING); schema_builder.add_text_field("text2", FAST | STRING); let schema = schema_builder.build(); @@ -351,8 +352,10 @@ pub mod tests { let docs = vec![ vec![r#"{ "date": "2015-01-01T12:10:30Z", "text": "aaa" }"#], vec![r#"{ "date": "2015-01-01T11:11:30Z", "text": "bbb" }"#], + vec![r#"{ "date": "2015-01-01T11:11:30Z", "text": "bbb" }"#], vec![r#"{ "date": "2015-01-02T00:00:00Z", "text": "bbb" }"#], vec![r#"{ "date": "2015-01-06T00:00:00Z", "text": "ccc" }"#], + vec![r#"{ "date": "2015-01-06T00:00:00Z", "text": "ccc" }"#], ]; let index = get_test_index_from_docs(merge_segments, &docs).unwrap(); @@ -381,7 +384,7 @@ pub mod tests { { "key_as_string" : "2015-01-01T00:00:00Z", "key" : 1420070400000.0, - "doc_count" : 4 + "doc_count" : 6 } ] } @@ -419,15 +422,15 @@ pub mod tests { { "key_as_string" : "2015-01-01T00:00:00Z", "key" : 1420070400000.0, - "doc_count" : 4, + "doc_count" : 6, "texts": { "buckets": [ { - "doc_count": 2, + "doc_count": 3, "key": "bbb" }, { - "doc_count": 1, + "doc_count": 2, "key": "ccc" }, { @@ -466,7 +469,7 @@ pub mod tests { "sales_over_time": { "buckets": [ { - "doc_count": 2, + "doc_count": 3, "key": 1420070400000.0, "key_as_string": "2015-01-01T00:00:00Z" }, @@ -491,7 +494,7 @@ pub mod tests { "key_as_string": "2015-01-05T00:00:00Z" }, { - "doc_count": 1, + "doc_count": 2, "key": 1420502400000.0, "key_as_string": "2015-01-06T00:00:00Z" } @@ -532,7 +535,7 @@ pub mod tests { "key_as_string": "2014-12-31T00:00:00Z" }, { - "doc_count": 2, + "doc_count": 3, "key": 1420070400000.0, "key_as_string": "2015-01-01T00:00:00Z" }, @@ -557,7 +560,7 @@ pub mod tests { "key_as_string": "2015-01-05T00:00:00Z" }, { - "doc_count": 1, + "doc_count": 2, "key": 1420502400000.0, "key_as_string": "2015-01-06T00:00:00Z" }, diff 
--git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index a3597995e..26853c4af 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -1,8 +1,5 @@ use std::cmp::Ordering; -use std::fmt::Display; -use columnar::ColumnType; -use itertools::Itertools; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use tantivy_bitpacker::minmax; @@ -18,9 +15,9 @@ use crate::aggregation::intermediate_agg_result::{ IntermediateHistogramBucketEntry, }; use crate::aggregation::segment_agg_result::{ - build_segment_agg_collector, AggregationLimits, SegmentAggregationCollector, + build_segment_agg_collector, SegmentAggregationCollector, }; -use crate::aggregation::{f64_from_fastfield_u64, format_date}; +use crate::aggregation::*; use crate::TantivyError; /// Histogram is a bucket aggregation, where buckets are created dynamically for given `interval`. @@ -73,6 +70,7 @@ pub struct HistogramAggregation { pub field: String, /// The interval to chunk your data range. Each bucket spans a value range of [0..interval). /// Must be a positive value. + #[serde(deserialize_with = "deserialize_f64")] pub interval: f64, /// Intervals implicitly defines an absolute grid of buckets `[interval * k, interval * (k + /// 1))`. @@ -85,6 +83,7 @@ pub struct HistogramAggregation { /// fall into the buckets with the key 0 and 10. /// With offset 5 and interval 10, they would both fall into the bucket with they key 5 and the /// range [5..15) + #[serde(default, deserialize_with = "deserialize_option_f64")] pub offset: Option, /// The minimum number of documents in a bucket to be returned. Defaults to 0. 
pub min_doc_count: Option, @@ -308,7 +307,10 @@ impl SegmentAggregationCollector for SegmentHistogramCollector { .column_block_accessor .fetch_block(docs, &bucket_agg_accessor.accessor); - for (doc, val) in bucket_agg_accessor.column_block_accessor.iter_docid_vals() { + for (doc, val) in bucket_agg_accessor + .column_block_accessor + .iter_docid_vals(docs, &bucket_agg_accessor.accessor) + { let val = self.f64_from_fastfield_u64(val); let bucket_pos = get_bucket_pos(val); @@ -595,11 +597,12 @@ mod tests { use serde_json::Value; use super::*; - use crate::aggregation::agg_req::Aggregations; + use crate::aggregation::agg_result::AggregationResults; use crate::aggregation::tests::{ exec_request, exec_request_with_query, exec_request_with_query_and_memory_limit, get_test_index_2_segments, get_test_index_from_values, get_test_index_with_num_docs, }; + use crate::query::AllQuery; #[test] fn histogram_test_crooked_values() -> crate::Result<()> { @@ -1351,6 +1354,35 @@ mod tests { }) ); + Ok(()) + } + #[test] + fn test_aggregation_histogram_empty_index() -> crate::Result<()> { + // test index without segments + let values = vec![]; + + let index = get_test_index_from_values(false, &values)?; + + let agg_req_1: Aggregations = serde_json::from_value(json!({ + "myhisto": { + "histogram": { + "field": "score", + "interval": 10.0 + }, + } + })) + .unwrap(); + + let collector = AggregationCollector::from_aggs(agg_req_1, Default::default()); + + let reader = index.reader()?; + let searcher = reader.searcher(); + let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); + + let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?; + // Make sure the result structure is correct + assert_eq!(res["myhisto"]["buckets"].as_array().unwrap().len(), 0); + Ok(()) } } diff --git a/src/aggregation/bucket/mod.rs b/src/aggregation/bucket/mod.rs index cd6d980cd..f1eaa975b 100644 --- a/src/aggregation/bucket/mod.rs +++ b/src/aggregation/bucket/mod.rs @@ 
-28,6 +28,7 @@ mod term_agg; mod term_missing_agg; use std::collections::HashMap; +use std::fmt; pub use histogram::*; pub use range::*; @@ -72,12 +73,12 @@ impl From<&str> for OrderTarget { } } -impl ToString for OrderTarget { - fn to_string(&self) -> String { +impl fmt::Display for OrderTarget { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - OrderTarget::Key => "_key".to_string(), - OrderTarget::Count => "_count".to_string(), - OrderTarget::SubAggregation(agg) => agg.to_string(), + OrderTarget::Key => f.write_str("_key"), + OrderTarget::Count => f.write_str("_count"), + OrderTarget::SubAggregation(agg) => agg.fmt(f), } } } diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index a50761e63..2e29d97ae 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -1,7 +1,6 @@ use std::fmt::Debug; use std::ops::Range; -use columnar::{ColumnType, MonotonicallyMappableToU64}; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; @@ -14,9 +13,7 @@ use crate::aggregation::intermediate_agg_result::{ use crate::aggregation::segment_agg_result::{ build_segment_agg_collector, SegmentAggregationCollector, }; -use crate::aggregation::{ - f64_from_fastfield_u64, f64_to_fastfield_u64, format_date, Key, SerializedKey, -}; +use crate::aggregation::*; use crate::TantivyError; /// Provide user-defined buckets to aggregate on. @@ -72,11 +69,19 @@ pub struct RangeAggregationRange { pub key: Option, /// The from range value, which is inclusive in the range. /// `None` equals to an open ended interval. - #[serde(skip_serializing_if = "Option::is_none", default)] + #[serde( + skip_serializing_if = "Option::is_none", + default, + deserialize_with = "deserialize_option_f64" + )] pub from: Option, /// The to range value, which is not inclusive in the range. /// `None` equals to an open ended interval. 
- #[serde(skip_serializing_if = "Option::is_none", default)] + #[serde( + skip_serializing_if = "Option::is_none", + default, + deserialize_with = "deserialize_option_f64" + )] pub to: Option, } @@ -230,7 +235,10 @@ impl SegmentAggregationCollector for SegmentRangeCollector { .column_block_accessor .fetch_block(docs, &bucket_agg_accessor.accessor); - for (doc, val) in bucket_agg_accessor.column_block_accessor.iter_docid_vals() { + for (doc, val) in bucket_agg_accessor + .column_block_accessor + .iter_docid_vals(docs, &bucket_agg_accessor.accessor) + { let bucket_pos = self.get_bucket_pos(val); let bucket = &mut self.buckets[bucket_pos]; @@ -441,7 +449,6 @@ pub(crate) fn range_to_key(range: &Range, field_type: &ColumnType) -> crate #[cfg(test)] mod tests { - use columnar::MonotonicallyMappableToU64; use serde_json::Value; use super::*; @@ -450,7 +457,6 @@ mod tests { exec_request, exec_request_with_query, get_test_index_2_segments, get_test_index_with_num_docs, }; - use crate::aggregation::AggregationLimits; pub fn get_collector_from_ranges( ranges: Vec, diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 1b29e361e..2488a53b7 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -1,6 +1,10 @@ use std::fmt::Debug; +use std::net::Ipv6Addr; -use columnar::{BytesColumn, ColumnType, MonotonicallyMappableToU64, StrColumn}; +use columnar::column_values::CompactSpaceU64Accessor; +use columnar::{ + BytesColumn, ColumnType, MonotonicallyMappableToU128, MonotonicallyMappableToU64, StrColumn, +}; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; @@ -99,23 +103,14 @@ pub struct TermsAggregation { #[serde(skip_serializing_if = "Option::is_none", default)] pub size: Option, - /// Unused by tantivy. - /// - /// Since tantivy doesn't know shards, this parameter is merely there to be used by consumers - /// of tantivy. shard_size is the number of terms returned by each shard. 
- /// The default value in elasticsearch is size * 1.5 + 10. - /// - /// Should never be smaller than size. - #[serde(skip_serializing_if = "Option::is_none", default)] - #[serde(alias = "shard_size")] - pub split_size: Option, - - /// The get more accurate results, we fetch more than `size` from each segment. + /// To get more accurate results, we fetch more than `size` from each segment. /// /// Increasing this value is will increase the cost for more accuracy. /// /// Defaults to 10 * size. #[serde(skip_serializing_if = "Option::is_none", default)] + #[serde(alias = "shard_size")] + #[serde(alias = "split_size")] pub segment_size: Option, /// If you set the `show_term_doc_count_error` parameter to true, the terms aggregation will @@ -256,7 +251,7 @@ pub struct SegmentTermCollector { term_buckets: TermBuckets, req: TermsAggregationInternal, blueprint: Option>, - field_type: ColumnType, + column_type: ColumnType, accessor_idx: usize, } @@ -315,7 +310,10 @@ impl SegmentAggregationCollector for SegmentTermCollector { } // has subagg if let Some(blueprint) = self.blueprint.as_ref() { - for (doc, term_id) in bucket_agg_accessor.column_block_accessor.iter_docid_vals() { + for (doc, term_id) in bucket_agg_accessor + .column_block_accessor + .iter_docid_vals(docs, &bucket_agg_accessor.accessor) + { let sub_aggregations = self .term_buckets .sub_aggs @@ -355,7 +353,7 @@ impl SegmentTermCollector { field_type: ColumnType, accessor_idx: usize, ) -> crate::Result { - if field_type == ColumnType::Bytes || field_type == ColumnType::Bool { + if field_type == ColumnType::Bytes { return Err(TantivyError::InvalidArgument(format!( "terms aggregation is not supported for column type {:?}", field_type @@ -389,7 +387,7 @@ impl SegmentTermCollector { req: TermsAggregationInternal::from_req(req), term_buckets, blueprint, - field_type, + column_type: field_type, accessor_idx, }) } @@ -466,7 +464,7 @@ impl SegmentTermCollector { Ok(intermediate_entry) }; - if self.field_type == 
ColumnType::Str { + if self.column_type == ColumnType::Str { let term_dict = agg_with_accessor .str_dict_column .as_ref() @@ -531,28 +529,55 @@ impl SegmentTermCollector { }); } } - } else if self.field_type == ColumnType::DateTime { + } else if self.column_type == ColumnType::DateTime { for (val, doc_count) in entries { let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; let val = i64::from_u64(val); let date = format_date(val)?; dict.insert(IntermediateKey::Str(date), intermediate_entry); } + } else if self.column_type == ColumnType::Bool { + for (val, doc_count) in entries { + let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; + let val = bool::from_u64(val); + dict.insert(IntermediateKey::Bool(val), intermediate_entry); + } + } else if self.column_type == ColumnType::IpAddr { + let compact_space_accessor = agg_with_accessor + .accessor + .values + .clone() + .downcast_arc::() + .map_err(|_| { + TantivyError::AggregationError( + crate::aggregation::AggregationError::InternalError( + "Type mismatch: Could not downcast to CompactSpaceU64Accessor" + .to_string(), + ), + ) + })?; + + for (val, doc_count) in entries { + let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; + let val: u128 = compact_space_accessor.compact_to_u128(val as u32); + let val = Ipv6Addr::from_u128(val); + dict.insert(IntermediateKey::IpAddr(val), intermediate_entry); + } } else { for (val, doc_count) in entries { let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?; - let val = f64_from_fastfield_u64(val, &self.field_type); + let val = f64_from_fastfield_u64(val, &self.column_type); dict.insert(IntermediateKey::F64(val), intermediate_entry); } }; - Ok(IntermediateBucketResult::Terms( - IntermediateTermBucketResult { + Ok(IntermediateBucketResult::Terms { + buckets: IntermediateTermBucketResult { entries: dict, sum_other_doc_count, doc_count_error_upper_bound: term_doc_count_before_cutoff, }, - )) + }) } } @@ 
-590,6 +615,9 @@ pub(crate) fn cut_off_buckets( #[cfg(test)] mod tests { + use std::net::IpAddr; + use std::str::FromStr; + use common::DateTime; use time::{Date, Month}; @@ -600,7 +628,7 @@ mod tests { }; use crate::aggregation::AggregationLimits; use crate::indexer::NoMergePolicy; - use crate::schema::{Schema, FAST, STRING}; + use crate::schema::{IntoIpv6Addr, Schema, FAST, STRING}; use crate::{Index, IndexWriter}; #[test] @@ -1182,9 +1210,9 @@ mod tests { assert_eq!(res["my_texts"]["buckets"][0]["key"], "terma"); assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 4); - assert_eq!(res["my_texts"]["buckets"][1]["key"], "termc"); + assert_eq!(res["my_texts"]["buckets"][1]["key"], "termb"); assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 0); - assert_eq!(res["my_texts"]["buckets"][2]["key"], "termb"); + assert_eq!(res["my_texts"]["buckets"][2]["key"], "termc"); assert_eq!(res["my_texts"]["buckets"][2]["doc_count"], 0); assert_eq!(res["my_texts"]["sum_other_doc_count"], 0); assert_eq!(res["my_texts"]["doc_count_error_upper_bound"], 0); @@ -1365,7 +1393,7 @@ mod tests { #[test] fn terms_aggregation_different_tokenizer_on_ff_test() -> crate::Result<()> { - let terms = vec!["Hello Hello", "Hallo Hallo"]; + let terms = vec!["Hello Hello", "Hallo Hallo", "Hallo Hallo"]; let index = get_test_index_from_terms(true, &[terms])?; @@ -1383,7 +1411,7 @@ mod tests { println!("{}", serde_json::to_string_pretty(&res).unwrap()); assert_eq!(res["my_texts"]["buckets"][0]["key"], "Hallo Hallo"); - assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 1); + assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 2); assert_eq!(res["my_texts"]["buckets"][1]["key"], "Hello Hello"); assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 1); @@ -1894,4 +1922,80 @@ mod tests { Ok(()) } + + #[test] + fn terms_aggregation_bool() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let field = schema_builder.add_bool_field("bool_field", FAST); + let schema = 
schema_builder.build(); + let index = Index::create_in_ram(schema); + { + let mut writer = index.writer_with_num_threads(1, 15_000_000)?; + writer.add_document(doc!(field=>true))?; + writer.add_document(doc!(field=>false))?; + writer.add_document(doc!(field=>true))?; + writer.commit()?; + } + + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_bool": { + "terms": { + "field": "bool_field" + }, + } + })) + .unwrap(); + + let res = exec_request_with_query(agg_req, &index, None)?; + + assert_eq!(res["my_bool"]["buckets"][0]["key"], 1.0); + assert_eq!(res["my_bool"]["buckets"][0]["key_as_string"], "true"); + assert_eq!(res["my_bool"]["buckets"][0]["doc_count"], 2); + assert_eq!(res["my_bool"]["buckets"][1]["key"], 0.0); + assert_eq!(res["my_bool"]["buckets"][1]["key_as_string"], "false"); + assert_eq!(res["my_bool"]["buckets"][1]["doc_count"], 1); + assert_eq!(res["my_bool"]["buckets"][2]["key"], serde_json::Value::Null); + + Ok(()) + } + + #[test] + fn terms_aggregation_ip_addr() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let field = schema_builder.add_ip_addr_field("ip_field", FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + { + let mut writer = index.writer_with_num_threads(1, 15_000_000)?; + // IpV6 loopback + writer.add_document(doc!(field=>IpAddr::from_str("::1").unwrap().into_ipv6_addr()))?; + writer.add_document(doc!(field=>IpAddr::from_str("::1").unwrap().into_ipv6_addr()))?; + // IpV4 + writer.add_document( + doc!(field=>IpAddr::from_str("127.0.0.1").unwrap().into_ipv6_addr()), + )?; + writer.commit()?; + } + + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_bool": { + "terms": { + "field": "ip_field" + }, + } + })) + .unwrap(); + + let res = exec_request_with_query(agg_req, &index, None)?; + // print as json + // println!("{}", serde_json::to_string_pretty(&res).unwrap()); + + assert_eq!(res["my_bool"]["buckets"][0]["key"], "::1"); + 
assert_eq!(res["my_bool"]["buckets"][0]["doc_count"], 2); + assert_eq!(res["my_bool"]["buckets"][1]["key"], "127.0.0.1"); + assert_eq!(res["my_bool"]["buckets"][1]["doc_count"], 1); + assert_eq!(res["my_bool"]["buckets"][2]["key"], serde_json::Value::Null); + + Ok(()) + } } diff --git a/src/aggregation/bucket/term_missing_agg.rs b/src/aggregation/bucket/term_missing_agg.rs index 1d43a2e65..bb8b295b4 100644 --- a/src/aggregation/bucket/term_missing_agg.rs +++ b/src/aggregation/bucket/term_missing_agg.rs @@ -73,11 +73,13 @@ impl SegmentAggregationCollector for TermMissingAgg { entries.insert(missing.into(), missing_entry); - let bucket = IntermediateBucketResult::Terms(IntermediateTermBucketResult { - entries, - sum_other_doc_count: 0, - doc_count_error_upper_bound: 0, - }); + let bucket = IntermediateBucketResult::Terms { + buckets: IntermediateTermBucketResult { + entries, + sum_other_doc_count: 0, + doc_count_error_upper_bound: 0, + }, + }; results.push(name, IntermediateAggregationResult::Bucket(bucket))?; @@ -90,7 +92,10 @@ impl SegmentAggregationCollector for TermMissingAgg { agg_with_accessor: &mut AggregationsWithAccessor, ) -> crate::Result<()> { let agg = &mut agg_with_accessor.aggs.values[self.accessor_idx]; - let has_value = agg.accessors.iter().any(|acc| acc.index.has_value(doc)); + let has_value = agg + .accessors + .iter() + .any(|(acc, _)| acc.index.has_value(doc)); if !has_value { self.missing_count += 1; if let Some(sub_agg) = self.sub_agg.as_mut() { diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index b3a0ed917..d0e9ec5b8 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -8,7 +8,7 @@ use super::segment_agg_result::{ }; use crate::aggregation::agg_req_with_accessor::get_aggs_with_segment_accessor_and_validate; use crate::collector::{Collector, SegmentCollector}; -use crate::{DocId, SegmentReader, TantivyError}; +use crate::{DocId, SegmentOrdinal, SegmentReader, TantivyError}; /// The default 
max bucket count, before the aggregation fails. pub const DEFAULT_BUCKET_LIMIT: u32 = 65000; @@ -64,10 +64,15 @@ impl Collector for DistributedAggregationCollector { fn for_segment( &self, - _segment_local_id: crate::SegmentOrdinal, + segment_local_id: crate::SegmentOrdinal, reader: &crate::SegmentReader, ) -> crate::Result { - AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader, &self.limits) + AggregationSegmentCollector::from_agg_req_and_reader( + &self.agg, + reader, + segment_local_id, + &self.limits, + ) } fn requires_scoring(&self) -> bool { @@ -89,10 +94,15 @@ impl Collector for AggregationCollector { fn for_segment( &self, - _segment_local_id: crate::SegmentOrdinal, + segment_local_id: crate::SegmentOrdinal, reader: &crate::SegmentReader, ) -> crate::Result { - AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader, &self.limits) + AggregationSegmentCollector::from_agg_req_and_reader( + &self.agg, + reader, + segment_local_id, + &self.limits, + ) } fn requires_scoring(&self) -> bool { @@ -135,10 +145,11 @@ impl AggregationSegmentCollector { pub fn from_agg_req_and_reader( agg: &Aggregations, reader: &SegmentReader, + segment_ordinal: SegmentOrdinal, limits: &AggregationLimits, ) -> crate::Result { let mut aggs_with_accessor = - get_aggs_with_segment_accessor_and_validate(agg, reader, limits)?; + get_aggs_with_segment_accessor_and_validate(agg, reader, segment_ordinal, limits)?; let result = BufAggregationCollector::new(build_segment_agg_collector(&mut aggs_with_accessor)?); Ok(AggregationSegmentCollector { diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index ed913aebb..3f472854b 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -5,6 +5,7 @@ use std::cmp::Ordering; use std::collections::hash_map::Entry; use std::hash::Hash; +use std::net::Ipv6Addr; use columnar::ColumnType; use itertools::Itertools; @@ -19,7 
+20,7 @@ use super::bucket::{ }; use super::metric::{ IntermediateAverage, IntermediateCount, IntermediateExtendedStats, IntermediateMax, - IntermediateMin, IntermediateStats, IntermediateSum, PercentilesCollector, + IntermediateMin, IntermediateStats, IntermediateSum, PercentilesCollector, TopHitsTopNComputer, }; use super::segment_agg_result::AggregationLimits; use super::{format_date, AggregationError, Key, SerializedKey}; @@ -41,6 +42,10 @@ pub struct IntermediateAggregationResults { /// This might seem redundant with `Key`, but the point is to have a different /// Serialize implementation. pub enum IntermediateKey { + /// Ip Addr key + IpAddr(Ipv6Addr), + /// Bool key + Bool(bool), /// String key Str(String), /// `f64` key @@ -58,7 +63,16 @@ impl From for Key { fn from(value: IntermediateKey) -> Self { match value { IntermediateKey::Str(s) => Self::Str(s), + IntermediateKey::IpAddr(s) => { + // Prefer to use the IPv4 representation if possible + if let Some(ip) = s.to_ipv4_mapped() { + Self::Str(ip.to_string()) + } else { + Self::Str(s.to_string()) + } + } IntermediateKey::F64(f) => Self::F64(f), + IntermediateKey::Bool(f) => Self::F64(f as u64 as f64), } } } @@ -71,6 +85,8 @@ impl std::hash::Hash for IntermediateKey { match self { IntermediateKey::Str(text) => text.hash(state), IntermediateKey::F64(val) => val.to_bits().hash(state), + IntermediateKey::Bool(val) => val.hash(state), + IntermediateKey::IpAddr(val) => val.hash(state), } } } @@ -166,9 +182,9 @@ impl IntermediateAggregationResults { pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult { use AggregationVariants::*; match req.agg { - Terms(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Terms( - Default::default(), - )), + Terms(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Terms { + buckets: Default::default(), + }), Range(_) => IntermediateAggregationResult::Bucket(IntermediateBucketResult::Range( Default::default(), )), @@ 
-208,6 +224,9 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult Percentiles(_) => IntermediateAggregationResult::Metric( IntermediateMetricResult::Percentiles(PercentilesCollector::default()), ), + TopHits(ref req) => IntermediateAggregationResult::Metric( + IntermediateMetricResult::TopHits(TopHitsTopNComputer::new(req.clone())), + ), } } @@ -270,6 +289,8 @@ pub enum IntermediateMetricResult { ExtendedStats(IntermediateExtendedStats), /// Intermediate sum result. Sum(IntermediateSum), + /// Intermediate top_hits result + TopHits(TopHitsTopNComputer), } impl IntermediateMetricResult { @@ -300,9 +321,13 @@ impl IntermediateMetricResult { percentiles .into_final_result(req.agg.as_percentile().expect("unexpected metric type")), ), + IntermediateMetricResult::TopHits(top_hits) => { + MetricResult::TopHits(top_hits.into_final_result()) + } } } + // TODO: this is our top-of-the-chain fruit merge mech fn merge_fruits(&mut self, other: IntermediateMetricResult) -> crate::Result<()> { match (self, other) { ( @@ -344,6 +369,9 @@ impl IntermediateMetricResult { ) => { left.merge_fruits(right)?; } + (IntermediateMetricResult::TopHits(left), IntermediateMetricResult::TopHits(right)) => { + left.merge_fruits(right)?; + } _ => { panic!("incompatible fruit types in tree or missing merge_fruits handler"); } @@ -365,11 +393,14 @@ pub enum IntermediateBucketResult { Histogram { /// The column_type of the underlying `Column` is DateTime is_date_agg: bool, - /// The buckets + /// The histogram buckets buckets: Vec, }, /// Term aggregation - Terms(IntermediateTermBucketResult), + Terms { + /// The term buckets + buckets: IntermediateTermBucketResult, + }, } impl IntermediateBucketResult { @@ -446,7 +477,7 @@ impl IntermediateBucketResult { }; Ok(BucketResult::Histogram { buckets }) } - IntermediateBucketResult::Terms(terms) => terms.into_final_result( + IntermediateBucketResult::Terms { buckets: terms } => terms.into_final_result( req.agg .as_term() 
.expect("unexpected aggregation, expected term aggregation"), @@ -459,8 +490,12 @@ impl IntermediateBucketResult { fn merge_fruits(&mut self, other: IntermediateBucketResult) -> crate::Result<()> { match (self, other) { ( - IntermediateBucketResult::Terms(term_res_left), - IntermediateBucketResult::Terms(term_res_right), + IntermediateBucketResult::Terms { + buckets: term_res_left, + }, + IntermediateBucketResult::Terms { + buckets: term_res_right, + }, ) => { merge_maps(&mut term_res_left.entries, term_res_right.entries)?; term_res_left.sum_other_doc_count += term_res_right.sum_other_doc_count; @@ -544,8 +579,15 @@ impl IntermediateTermBucketResult { .into_iter() .filter(|bucket| bucket.1.doc_count as u64 >= req.min_doc_count) .map(|(key, entry)| { + let key_as_string = match key { + IntermediateKey::Bool(key) => { + let val = if key { "true" } else { "false" }; + Some(val.to_string()) + } + _ => None, + }; Ok(BucketEntry { - key_as_string: None, + key_as_string, key: key.into(), doc_count: entry.doc_count as u64, sub_aggregation: entry diff --git a/src/aggregation/metric/average.rs b/src/aggregation/metric/average.rs index 8e9ad69ad..70d426c51 100644 --- a/src/aggregation/metric/average.rs +++ b/src/aggregation/metric/average.rs @@ -24,7 +24,7 @@ pub struct AverageAggregation { /// By default they will be ignored but it is also possible to treat them as if they had a /// value. 
Examples in JSON format: /// { "field": "my_numbers", "missing": "10.0" } - #[serde(default)] + #[serde(default, deserialize_with = "deserialize_option_f64")] pub missing: Option, } @@ -63,3 +63,71 @@ impl IntermediateAverage { self.stats.finalize().avg } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn deserialization_with_missing_test1() { + let json = r#"{ + "field": "score", + "missing": "10.0" + }"#; + let avg: AverageAggregation = serde_json::from_str(json).unwrap(); + assert_eq!(avg.field, "score"); + assert_eq!(avg.missing, Some(10.0)); + // no dot + let json = r#"{ + "field": "score", + "missing": "10" + }"#; + let avg: AverageAggregation = serde_json::from_str(json).unwrap(); + assert_eq!(avg.field, "score"); + assert_eq!(avg.missing, Some(10.0)); + + // from value + let avg: AverageAggregation = serde_json::from_value(json!({ + "field": "score_f64", + "missing": 10u64, + })) + .unwrap(); + assert_eq!(avg.missing, Some(10.0)); + // from value + let avg: AverageAggregation = serde_json::from_value(json!({ + "field": "score_f64", + "missing": 10u32, + })) + .unwrap(); + assert_eq!(avg.missing, Some(10.0)); + let avg: AverageAggregation = serde_json::from_value(json!({ + "field": "score_f64", + "missing": 10i8, + })) + .unwrap(); + assert_eq!(avg.missing, Some(10.0)); + } + + #[test] + fn deserialization_with_missing_test_fail() { + let json = r#"{ + "field": "score", + "missing": "a" + }"#; + let avg: Result = serde_json::from_str(json); + assert!(avg.is_err()); + assert!(avg + .unwrap_err() + .to_string() + .contains("Failed to parse f64 from string: \"a\"")); + + // Disallow NaN + let json = r#"{ + "field": "score", + "missing": "NaN" + }"#; + let avg: Result = serde_json::from_str(json); + assert!(avg.is_err()); + assert!(avg.unwrap_err().to_string().contains("NaN")); + } +} diff --git a/src/aggregation/metric/count.rs b/src/aggregation/metric/count.rs index afdcfe55d..1868716b1 100644 --- a/src/aggregation/metric/count.rs +++ 
b/src/aggregation/metric/count.rs @@ -3,7 +3,7 @@ use std::fmt::Debug; use serde::{Deserialize, Serialize}; use super::IntermediateStats; - +use crate::aggregation::*; /// A single-value metric aggregation that counts the number of values that are /// extracted from the aggregated documents. /// See [super::SingleMetricResult] for return value. @@ -24,7 +24,7 @@ pub struct CountAggregation { /// By default they will be ignored but it is also possible to treat them as if they had a /// value. Examples in JSON format: /// { "field": "my_numbers", "missing": "10.0" } - #[serde(default)] + #[serde(default, deserialize_with = "deserialize_option_f64")] pub missing: Option, } diff --git a/src/aggregation/metric/max.rs b/src/aggregation/metric/max.rs index b1be96dfd..d3fc722a2 100644 --- a/src/aggregation/metric/max.rs +++ b/src/aggregation/metric/max.rs @@ -3,7 +3,7 @@ use std::fmt::Debug; use serde::{Deserialize, Serialize}; use super::IntermediateStats; - +use crate::aggregation::*; /// A single-value metric aggregation that computes the maximum of numeric values that are /// extracted from the aggregated documents. /// See [super::SingleMetricResult] for return value. @@ -24,7 +24,7 @@ pub struct MaxAggregation { /// By default they will be ignored but it is also possible to treat them as if they had a /// value. Examples in JSON format: /// { "field": "my_numbers", "missing": "10.0" } - #[serde(default)] + #[serde(default, deserialize_with = "deserialize_option_f64")] pub missing: Option, } diff --git a/src/aggregation/metric/min.rs b/src/aggregation/metric/min.rs index 45ffdcb9d..7e042995c 100644 --- a/src/aggregation/metric/min.rs +++ b/src/aggregation/metric/min.rs @@ -3,6 +3,7 @@ use std::fmt::Debug; use serde::{Deserialize, Serialize}; use super::IntermediateStats; +use crate::aggregation::*; /// A single-value metric aggregation that computes the minimum of numeric values that are /// extracted from the aggregated documents. 
@@ -24,7 +25,7 @@ pub struct MinAggregation { /// By default they will be ignored but it is also possible to treat them as if they had a /// value. Examples in JSON format: /// { "field": "my_numbers", "missing": "10.0" } - #[serde(default)] + #[serde(default, deserialize_with = "deserialize_option_f64")] pub missing: Option, } diff --git a/src/aggregation/metric/mod.rs b/src/aggregation/metric/mod.rs index 998793d2d..6da583a59 100644 --- a/src/aggregation/metric/mod.rs +++ b/src/aggregation/metric/mod.rs @@ -23,6 +23,10 @@ mod min; mod percentiles; mod stats; mod sum; +mod top_hits; + +use std::collections::HashMap; + pub use average::*; pub use count::*; pub use max::*; @@ -32,6 +36,9 @@ use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; pub use stats::*; pub use sum::*; +pub use top_hits::*; + +use crate::schema::OwnedValue; /// Single-metric aggregations use this common result structure. /// @@ -81,6 +88,28 @@ pub struct PercentilesMetricResult { pub values: PercentileValues, } +/// The top_hits metric results entry +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct TopHitsVecEntry { + /// The sort values of the document, depending on the sort criteria in the request. + pub sort: Vec>, + + /// Search results, for queries that include field retrieval requests + /// (`docvalue_fields`). + #[serde(rename = "docvalue_fields")] + #[serde(skip_serializing_if = "HashMap::is_empty")] + pub doc_value_fields: HashMap, +} + +/// The top_hits metric aggregation results a list of top hits by sort criteria. +/// +/// The main reason for wrapping it in `hits` is to match elasticsearch output structure. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct TopHitsMetricResult { + /// The result of the top_hits metric. 
+ pub hits: Vec, +} + #[cfg(test)] mod tests { use crate::aggregation::agg_req::Aggregations; diff --git a/src/aggregation/metric/percentiles.rs b/src/aggregation/metric/percentiles.rs index 7b66b1273..4a6bac3f0 100644 --- a/src/aggregation/metric/percentiles.rs +++ b/src/aggregation/metric/percentiles.rs @@ -1,6 +1,5 @@ use std::fmt::Debug; -use columnar::ColumnType; use serde::{Deserialize, Serialize}; use super::*; @@ -11,7 +10,7 @@ use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; -use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, AggregationError}; +use crate::aggregation::*; use crate::{DocId, TantivyError}; /// # Percentiles @@ -84,7 +83,11 @@ pub struct PercentilesAggregationReq { /// By default they will be ignored but it is also possible to treat them as if they had a /// value. Examples in JSON format: /// { "field": "my_numbers", "missing": "10.0" } - #[serde(skip_serializing_if = "Option::is_none", default)] + #[serde( + skip_serializing_if = "Option::is_none", + default, + deserialize_with = "deserialize_option_f64" + )] pub missing: Option, } fn default_percentiles() -> &'static [f64] { @@ -133,7 +136,6 @@ pub(crate) struct SegmentPercentilesCollector { field_type: ColumnType, pub(crate) percentiles: PercentilesCollector, pub(crate) accessor_idx: usize, - val_cache: Vec, missing: Option, } @@ -243,7 +245,6 @@ impl SegmentPercentilesCollector { field_type, percentiles: PercentilesCollector::new(), accessor_idx, - val_cache: Default::default(), missing, }) } diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index c2c7630ff..749ca69da 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -1,6 +1,3 @@ -use std::fmt::Debug; - -use columnar::ColumnType; use serde::{Deserialize, Serialize}; use super::*; @@ -11,7 
+8,7 @@ use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateMetricResult, }; use crate::aggregation::segment_agg_result::SegmentAggregationCollector; -use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64}; +use crate::aggregation::*; use crate::{DocId, TantivyError}; /// A multi-value metric aggregation that computes a collection of statistics on numeric values that @@ -35,7 +32,7 @@ pub struct StatsAggregation { /// By default they will be ignored but it is also possible to treat them as if they had a /// value. Examples in JSON format: /// { "field": "my_numbers", "missing": "10.0" } - #[serde(default)] + #[serde(default, deserialize_with = "deserialize_option_f64")] pub missing: Option, } @@ -980,6 +977,30 @@ mod tests { }) ); + // From string + let agg_req: Aggregations = serde_json::from_value(json!({ + "my_stats": { + "stats": { + "field": "json.partially_empty", + "missing": "0.0" + }, + } + })) + .unwrap(); + + let res = exec_request_with_query(agg_req, &index, None)?; + + assert_eq!( + res["my_stats"], + json!({ + "avg": 2.5, + "count": 4, + "max": 10.0, + "min": 0.0, + "sum": 10.0 + }) + ); + Ok(()) } diff --git a/src/aggregation/metric/sum.rs b/src/aggregation/metric/sum.rs index a455ae010..a1e2497fa 100644 --- a/src/aggregation/metric/sum.rs +++ b/src/aggregation/metric/sum.rs @@ -3,7 +3,7 @@ use std::fmt::Debug; use serde::{Deserialize, Serialize}; use super::IntermediateStats; - +use crate::aggregation::*; /// A single-value metric aggregation that sums up numeric values that are /// extracted from the aggregated documents. /// See [super::SingleMetricResult] for return value. @@ -24,7 +24,7 @@ pub struct SumAggregation { /// By default they will be ignored but it is also possible to treat them as if they had a /// value. 
Examples in JSON format: /// { "field": "my_numbers", "missing": "10.0" } - #[serde(default)] + #[serde(default, deserialize_with = "deserialize_option_f64")] pub missing: Option, } diff --git a/src/aggregation/metric/top_hits.rs b/src/aggregation/metric/top_hits.rs new file mode 100644 index 000000000..ee316cf6e --- /dev/null +++ b/src/aggregation/metric/top_hits.rs @@ -0,0 +1,897 @@ +use std::collections::HashMap; +use std::net::Ipv6Addr; + +use columnar::{ColumnarReader, DynamicColumn}; +use common::json_path_writer::JSON_PATH_SEGMENT_SEP_STR; +use common::DateTime; +use regex::Regex; +use serde::ser::SerializeMap; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use super::{TopHitsMetricResult, TopHitsVecEntry}; +use crate::aggregation::bucket::Order; +use crate::aggregation::intermediate_agg_result::{ + IntermediateAggregationResult, IntermediateMetricResult, +}; +use crate::aggregation::segment_agg_result::SegmentAggregationCollector; +use crate::aggregation::AggregationError; +use crate::collector::TopNComputer; +use crate::schema::OwnedValue; +use crate::{DocAddress, DocId, SegmentOrdinal}; + +/// # Top Hits +/// +/// The top hits aggregation is a useful tool to answer questions like: +/// - "What are the most recent posts by each author?" +/// - "What are the most popular items in each category?" +/// +/// It does so by keeping track of the most relevant document being aggregated, +/// in terms of a sort criterion that can consist of multiple fields and their +/// sort-orders (ascending or descending). +/// +/// `top_hits` should not be used as a top-level aggregation. It is intended to be +/// used as a sub-aggregation, inside a `terms` aggregation or a `filters` aggregation, +/// for example. +/// +/// Note that this aggregator does not return the actual document addresses, but +/// rather a list of the values of the fields that were requested to be retrieved. 
+/// These values can be specified in the `docvalue_fields` parameter, which can include +/// a list of fast fields to be retrieved. At the moment, only fast fields are supported +/// but it is possible that we support the `fields` parameter to retrieve any stored +/// field in the future. +/// +/// The following example demonstrates a request for the top_hits aggregation: +/// ```JSON +/// { +/// "aggs": { +/// "top_authors": { +/// "terms": { +/// "field": "author", +/// "size": 5 +/// } +/// }, +/// "aggs": { +/// "top_hits": { +/// "size": 2, +/// "from": 0, +/// "sort": [ +/// { "date": "desc" } +/// ], +/// "docvalue_fields": ["date", "title", "iden"] +/// } +/// } +/// } +/// ``` +/// +/// This request will return an object containing the top two documents, sorted +/// by the `date` field in descending order. You can also sort by multiple fields, which +/// helps to resolve ties. The aggregation object for each bucket will look like: +/// ```JSON +/// { +/// "hits": [ +/// { +/// "sort": [<time_u64>], +/// "docvalue_fields": { +/// "date": "<date_RFC3339>", +/// "title": "<title>", +/// "iden": "<iden>" +/// } +/// }, +/// { +/// "sort": [<time_u64>], +/// "docvalue_fields": { +/// "date": "<date_RFC3339>", +/// "title": "<title>", +/// "iden": "<iden>" +/// } +/// } +/// ] +/// } +/// ``` +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +pub struct TopHitsAggregation { + sort: Vec<KeyOrder>, + size: usize, + from: Option<usize>, + + #[serde(rename = "docvalue_fields")] + #[serde(default)] + doc_value_fields: Vec<String>, + + // Not supported + _source: Option<serde_json::Value>, + fields: Option<serde_json::Value>, + script_fields: Option<serde_json::Value>, + highlight: Option<serde_json::Value>, + explain: Option<serde_json::Value>, + version: Option<serde_json::Value>, +} + +#[derive(Debug, Clone, PartialEq, Default)] +struct KeyOrder { + field: String, + order: Order, +} + +impl Serialize for KeyOrder { + fn serialize<S: Serializer>(&self, serializer: S) -> 
Result<S::Ok, S::Error> { + let KeyOrder { field, order } = self; + let mut map = serializer.serialize_map(Some(1))?; + map.serialize_entry(field, order)?; + map.end() + } +} + +impl<'de> Deserialize<'de> for KeyOrder { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where D: Deserializer<'de> { + let mut key_order = <HashMap<String, Order>>::deserialize(deserializer)?.into_iter(); + let (field, order) = key_order.next().ok_or(serde::de::Error::custom( + "Expected exactly one key-value pair in sort parameter of top_hits, found none", + ))?; + if key_order.next().is_some() { + return Err(serde::de::Error::custom(format!( + "Expected exactly one key-value pair in sort parameter of top_hits, found {:?}", + key_order + ))); + } + Ok(Self { field, order }) + } +} + +// Transform a glob (`pattern*`, for example) into a regex::Regex (`^pattern.*$`) +fn globbed_string_to_regex(glob: &str) -> Result<Regex, crate::TantivyError> { + // Replace `*` glob with `.*` regex + let sanitized = format!("^{}$", regex::escape(glob).replace(r"\*", ".*")); + Regex::new(&sanitized.replace('*', ".*")).map_err(|e| { + crate::TantivyError::SchemaError(format!( + "Invalid regex '{}' in docvalue_fields: {}", + glob, e + )) + }) +} + +fn use_doc_value_fields_err(parameter: &str) -> crate::Result<()> { + Err(crate::TantivyError::AggregationError( + AggregationError::InvalidRequest(format!( + "The `{}` parameter is not supported, only `docvalue_fields` is supported in \ + `top_hits` aggregation", + parameter + )), + )) +} +fn unsupported_err(parameter: &str) -> crate::Result<()> { + Err(crate::TantivyError::AggregationError( + AggregationError::InvalidRequest(format!( + "The `{}` parameter is not supported in the `top_hits` aggregation", + parameter + )), + )) +} + +impl TopHitsAggregation { + /// Validate and resolve field retrieval parameters + pub fn validate_and_resolve_field_names( + &mut self, + reader: &ColumnarReader, + ) -> crate::Result<()> { + if self._source.is_some() { 
+ use_doc_value_fields_err("_source")?; + } + if self.fields.is_some() { + use_doc_value_fields_err("fields")?; + } + if self.script_fields.is_some() { + use_doc_value_fields_err("script_fields")?; + } + if self.explain.is_some() { + unsupported_err("explain")?; + } + if self.highlight.is_some() { + unsupported_err("highlight")?; + } + if self.version.is_some() { + unsupported_err("version")?; + } + + self.doc_value_fields = self + .doc_value_fields + .iter() + .map(|field| { + if !field.contains('*') + && reader + .iter_columns()? + .any(|(name, _)| name.as_str() == field) + { + return Ok(vec![field.to_owned()]); + } + + let pattern = globbed_string_to_regex(field)?; + let fields = reader + .iter_columns()? + .map(|(name, _)| { + // normalize path from internal fast field repr + name.replace(JSON_PATH_SEGMENT_SEP_STR, ".") + }) + .filter(|name| pattern.is_match(name)) + .collect::<Vec<_>>(); + assert!( + !fields.is_empty(), + "No fields matched the glob '{}' in docvalue_fields", + field + ); + Ok(fields) + }) + .collect::<crate::Result<Vec<_>>>()? + .into_iter() + .flatten() + .collect(); + + Ok(()) + } + + /// Return fields accessed by the aggregator, in order. + pub fn field_names(&self) -> Vec<&str> { + self.sort + .iter() + .map(|KeyOrder { field, .. }| field.as_str()) + .collect() + } + + /// Return fields accessed by the aggregator's value retrieval. 
+ pub fn value_field_names(&self) -> Vec<&str> { + self.doc_value_fields.iter().map(|s| s.as_str()).collect() + } + + fn get_document_field_data( + &self, + accessors: &HashMap<String, Vec<DynamicColumn>>, + doc_id: DocId, + ) -> HashMap<String, FastFieldValue> { + let doc_value_fields = self + .doc_value_fields + .iter() + .map(|field| { + let accessors = accessors + .get(field) + .unwrap_or_else(|| panic!("field '{}' not found in accessors", field)); + + let values: Vec<FastFieldValue> = accessors + .iter() + .flat_map(|accessor| match accessor { + DynamicColumn::U64(accessor) => accessor + .values_for_doc(doc_id) + .map(FastFieldValue::U64) + .collect::<Vec<_>>(), + DynamicColumn::I64(accessor) => accessor + .values_for_doc(doc_id) + .map(FastFieldValue::I64) + .collect::<Vec<_>>(), + DynamicColumn::F64(accessor) => accessor + .values_for_doc(doc_id) + .map(FastFieldValue::F64) + .collect::<Vec<_>>(), + DynamicColumn::Bytes(accessor) => accessor + .term_ords(doc_id) + .map(|term_ord| { + let mut buffer = vec![]; + assert!( + accessor + .ord_to_bytes(term_ord, &mut buffer) + .expect("could not read term dictionary"), + "term corresponding to term_ord does not exist" + ); + FastFieldValue::Bytes(buffer) + }) + .collect::<Vec<_>>(), + DynamicColumn::Str(accessor) => accessor + .term_ords(doc_id) + .map(|term_ord| { + let mut buffer = vec![]; + assert!( + accessor + .ord_to_bytes(term_ord, &mut buffer) + .expect("could not read term dictionary"), + "term corresponding to term_ord does not exist" + ); + FastFieldValue::Str(String::from_utf8(buffer).unwrap()) + }) + .collect::<Vec<_>>(), + DynamicColumn::Bool(accessor) => accessor + .values_for_doc(doc_id) + .map(FastFieldValue::Bool) + .collect::<Vec<_>>(), + DynamicColumn::IpAddr(accessor) => accessor + .values_for_doc(doc_id) + .map(FastFieldValue::IpAddr) + .collect::<Vec<_>>(), + DynamicColumn::DateTime(accessor) => accessor + .values_for_doc(doc_id) + .map(FastFieldValue::Date) + .collect::<Vec<_>>(), + }) + 
.collect(); + + (field.to_owned(), FastFieldValue::Array(values)) + }) + .collect(); + doc_value_fields + } +} + +/// A retrieved value from a fast field. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum FastFieldValue { + /// The str type is used for any text information. + Str(String), + /// Unsigned 64-bits Integer `u64` + U64(u64), + /// Signed 64-bits Integer `i64` + I64(i64), + /// 64-bits Float `f64` + F64(f64), + /// Bool value + Bool(bool), + /// Date/time with nanoseconds precision + Date(DateTime), + /// Arbitrarily sized byte array + Bytes(Vec<u8>), + /// IpV6 Address. Internally there is no IpV4, it needs to be converted to `Ipv6Addr`. + IpAddr(Ipv6Addr), + /// A list of values. + Array(Vec<Self>), +} + +impl From<FastFieldValue> for OwnedValue { + fn from(value: FastFieldValue) -> Self { + match value { + FastFieldValue::Str(s) => OwnedValue::Str(s), + FastFieldValue::U64(u) => OwnedValue::U64(u), + FastFieldValue::I64(i) => OwnedValue::I64(i), + FastFieldValue::F64(f) => OwnedValue::F64(f), + FastFieldValue::Bool(b) => OwnedValue::Bool(b), + FastFieldValue::Date(d) => OwnedValue::Date(d), + FastFieldValue::Bytes(b) => OwnedValue::Bytes(b), + FastFieldValue::IpAddr(ip) => OwnedValue::IpAddr(ip), + FastFieldValue::Array(a) => { + OwnedValue::Array(a.into_iter().map(OwnedValue::from).collect()) + } + } + } +} + +/// Holds a fast field value in its u64 representation, and the order in which it should be sorted. +#[derive(Clone, Serialize, Deserialize, Debug)] +struct DocValueAndOrder { + /// A fast field value in its u64 representation. 
+ value: Option<u64>, + /// Sort order for the value + order: Order, +} + +impl Ord for DocValueAndOrder { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + let invert = |cmp: std::cmp::Ordering| match self.order { + Order::Asc => cmp, + Order::Desc => cmp.reverse(), + }; + + match (self.value, other.value) { + (Some(self_value), Some(other_value)) => invert(self_value.cmp(&other_value)), + (Some(_), None) => std::cmp::Ordering::Greater, + (None, Some(_)) => std::cmp::Ordering::Less, + (None, None) => std::cmp::Ordering::Equal, + } + } +} + +impl PartialOrd for DocValueAndOrder { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + Some(self.cmp(other)) + } +} + +impl PartialEq for DocValueAndOrder { + fn eq(&self, other: &Self) -> bool { + self.value.cmp(&other.value) == std::cmp::Ordering::Equal + } +} + +impl Eq for DocValueAndOrder {} + +#[derive(Clone, Serialize, Deserialize, Debug)] +struct DocSortValuesAndFields { + sorts: Vec<DocValueAndOrder>, + + #[serde(rename = "docvalue_fields")] + #[serde(skip_serializing_if = "HashMap::is_empty")] + doc_value_fields: HashMap<String, FastFieldValue>, +} + +impl Ord for DocSortValuesAndFields { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + for (self_feature, other_feature) in self.sorts.iter().zip(other.sorts.iter()) { + let cmp = self_feature.cmp(other_feature); + if cmp != std::cmp::Ordering::Equal { + return cmp; + } + } + std::cmp::Ordering::Equal + } +} + +impl PartialOrd for DocSortValuesAndFields { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + Some(self.cmp(other)) + } +} + +impl PartialEq for DocSortValuesAndFields { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == std::cmp::Ordering::Equal + } +} + +impl Eq for DocSortValuesAndFields {} + +/// The TopHitsCollector used for collecting over segments and merging results. 
+#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct TopHitsTopNComputer { + req: TopHitsAggregation, + top_n: TopNComputer<DocSortValuesAndFields, DocAddress, false>, +} + +impl std::cmp::PartialEq for TopHitsTopNComputer { + fn eq(&self, _other: &Self) -> bool { + false + } +} + +impl TopHitsTopNComputer { + /// Create a new TopHitsCollector + pub fn new(req: TopHitsAggregation) -> Self { + Self { + top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)), + req, + } + } + + fn collect(&mut self, features: DocSortValuesAndFields, doc: DocAddress) { + self.top_n.push(features, doc); + } + + pub(crate) fn merge_fruits(&mut self, other_fruit: Self) -> crate::Result<()> { + for doc in other_fruit.top_n.into_vec() { + self.collect(doc.feature, doc.doc); + } + Ok(()) + } + + /// Finalize by converting self into the final result form + pub fn into_final_result(self) -> TopHitsMetricResult { + let mut hits: Vec<TopHitsVecEntry> = self + .top_n + .into_sorted_vec() + .into_iter() + .map(|doc| TopHitsVecEntry { + sort: doc.feature.sorts.iter().map(|f| f.value).collect(), + doc_value_fields: doc + .feature + .doc_value_fields + .into_iter() + .map(|(k, v)| (k, v.into())) + .collect(), + }) + .collect(); + + // Remove the first `from` elements + // Truncating from end would be more efficient, but we need to truncate from the front + // because `into_sorted_vec` gives us a descending order because of the inverted + // `Ord` semantics of the heap elements. 
+ hits.drain(..self.req.from.unwrap_or(0)); + TopHitsMetricResult { hits } + } +} + +#[derive(Clone, Debug)] +pub(crate) struct TopHitsSegmentCollector { + segment_ordinal: SegmentOrdinal, + accessor_idx: usize, + req: TopHitsAggregation, + top_n: TopNComputer<Vec<DocValueAndOrder>, DocAddress, false>, +} + +impl TopHitsSegmentCollector { + pub fn from_req( + req: &TopHitsAggregation, + accessor_idx: usize, + segment_ordinal: SegmentOrdinal, + ) -> Self { + Self { + req: req.clone(), + top_n: TopNComputer::new(req.size + req.from.unwrap_or(0)), + segment_ordinal, + accessor_idx, + } + } + fn into_top_hits_collector( + self, + value_accessors: &HashMap<String, Vec<DynamicColumn>>, + ) -> TopHitsTopNComputer { + let mut top_hits_computer = TopHitsTopNComputer::new(self.req.clone()); + let top_results = self.top_n.into_vec(); + + for res in top_results { + let doc_value_fields = self + .req + .get_document_field_data(value_accessors, res.doc.doc_id); + top_hits_computer.collect( + DocSortValuesAndFields { + sorts: res.feature, + doc_value_fields, + }, + res.doc, + ); + } + + top_hits_computer + } +} + +impl SegmentAggregationCollector for TopHitsSegmentCollector { + fn add_intermediate_aggregation_result( + self: Box<Self>, + agg_with_accessor: &crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor, + results: &mut crate::aggregation::intermediate_agg_result::IntermediateAggregationResults, + ) -> crate::Result<()> { + let name = agg_with_accessor.aggs.keys[self.accessor_idx].to_string(); + + let value_accessors = &agg_with_accessor.aggs.values[self.accessor_idx].value_accessors; + + let intermediate_result = + IntermediateMetricResult::TopHits(self.into_top_hits_collector(value_accessors)); + results.push( + name, + IntermediateAggregationResult::Metric(intermediate_result), + ) + } + + fn collect( + &mut self, + doc_id: crate::DocId, + agg_with_accessor: &mut crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor, + ) -> crate::Result<()> 
{ + let accessors = &agg_with_accessor.aggs.values[self.accessor_idx].accessors; + let sorts: Vec<DocValueAndOrder> = self + .req + .sort + .iter() + .enumerate() + .map(|(idx, KeyOrder { order, .. })| { + let order = *order; + let value = accessors + .get(idx) + .expect("could not find field in accessors") + .0 + .values_for_doc(doc_id) + .next(); + DocValueAndOrder { value, order } + }) + .collect(); + + self.top_n.push( + sorts, + DocAddress { + segment_ord: self.segment_ordinal, + doc_id, + }, + ); + Ok(()) + } + + fn collect_block( + &mut self, + docs: &[crate::DocId], + agg_with_accessor: &mut crate::aggregation::agg_req_with_accessor::AggregationsWithAccessor, + ) -> crate::Result<()> { + // TODO: Consider getting fields with the column block accessor. + for doc in docs { + self.collect(*doc, agg_with_accessor)?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use common::DateTime; + use pretty_assertions::assert_eq; + use serde_json::Value; + use time::macros::datetime; + + use super::{DocSortValuesAndFields, DocValueAndOrder, Order}; + use crate::aggregation::agg_req::Aggregations; + use crate::aggregation::agg_result::AggregationResults; + use crate::aggregation::bucket::tests::get_test_index_from_docs; + use crate::aggregation::tests::get_test_index_from_values; + use crate::aggregation::AggregationCollector; + use crate::collector::ComparableDoc; + use crate::query::AllQuery; + use crate::schema::OwnedValue; + + fn invert_order(cmp_feature: DocValueAndOrder) -> DocValueAndOrder { + let DocValueAndOrder { value, order } = cmp_feature; + let order = match order { + Order::Asc => Order::Desc, + Order::Desc => Order::Asc, + }; + DocValueAndOrder { value, order } + } + + fn collector_with_capacity(capacity: usize) -> super::TopHitsTopNComputer { + super::TopHitsTopNComputer { + top_n: super::TopNComputer::new(capacity), + req: Default::default(), + } + } + + fn invert_order_features(mut cmp_features: DocSortValuesAndFields) -> DocSortValuesAndFields { + 
cmp_features.sorts = cmp_features + .sorts + .into_iter() + .map(invert_order) + .collect::<Vec<_>>(); + cmp_features + } + + #[test] + fn test_comparable_doc_feature() -> crate::Result<()> { + let small = DocValueAndOrder { + value: Some(1), + order: Order::Asc, + }; + let big = DocValueAndOrder { + value: Some(2), + order: Order::Asc, + }; + let none = DocValueAndOrder { + value: None, + order: Order::Asc, + }; + + assert!(small < big); + assert!(none < small); + assert!(none < big); + + let small = invert_order(small); + let big = invert_order(big); + let none = invert_order(none); + + assert!(small > big); + assert!(none < small); + assert!(none < big); + + Ok(()) + } + + #[test] + fn test_comparable_doc_features() -> crate::Result<()> { + let features_1 = DocSortValuesAndFields { + sorts: vec![DocValueAndOrder { + value: Some(1), + order: Order::Asc, + }], + doc_value_fields: Default::default(), + }; + + let features_2 = DocSortValuesAndFields { + sorts: vec![DocValueAndOrder { + value: Some(2), + order: Order::Asc, + }], + doc_value_fields: Default::default(), + }; + + assert!(features_1 < features_2); + + assert!(invert_order_features(features_1.clone()) > invert_order_features(features_2)); + + Ok(()) + } + + #[test] + fn test_aggregation_top_hits_empty_index() -> crate::Result<()> { + let values = vec![]; + + let index = get_test_index_from_values(false, &values)?; + + let d: Aggregations = serde_json::from_value(json!({ + "top_hits_req": { + "top_hits": { + "size": 2, + "sort": [ + { "date": "desc" } + ], + "from": 0, + } + } + })) + .unwrap(); + + let collector = AggregationCollector::from_aggs(d, Default::default()); + + let reader = index.reader()?; + let searcher = reader.searcher(); + let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); + + let res: Value = serde_json::from_str( + &serde_json::to_string(&agg_res).expect("JSON serialization failed"), + ) + .expect("JSON parsing failed"); + + assert_eq!( + res, + json!({ + 
"top_hits_req": { + "hits": [] + } + }) + ); + + Ok(()) + } + + #[test] + fn test_top_hits_collector_single_feature() -> crate::Result<()> { + let docs = vec![ + ComparableDoc::<_, _, false> { + doc: crate::DocAddress { + segment_ord: 0, + doc_id: 0, + }, + feature: DocSortValuesAndFields { + sorts: vec![DocValueAndOrder { + value: Some(1), + order: Order::Asc, + }], + doc_value_fields: Default::default(), + }, + }, + ComparableDoc { + doc: crate::DocAddress { + segment_ord: 0, + doc_id: 2, + }, + feature: DocSortValuesAndFields { + sorts: vec![DocValueAndOrder { + value: Some(3), + order: Order::Asc, + }], + doc_value_fields: Default::default(), + }, + }, + ComparableDoc { + doc: crate::DocAddress { + segment_ord: 0, + doc_id: 1, + }, + feature: DocSortValuesAndFields { + sorts: vec![DocValueAndOrder { + value: Some(5), + order: Order::Asc, + }], + doc_value_fields: Default::default(), + }, + }, + ]; + + let mut collector = collector_with_capacity(3); + for doc in docs.clone() { + collector.collect(doc.feature, doc.doc); + } + + let res = collector.into_final_result(); + + assert_eq!( + res, + super::TopHitsMetricResult { + hits: vec![ + super::TopHitsVecEntry { + sort: vec![docs[0].feature.sorts[0].value], + doc_value_fields: Default::default(), + }, + super::TopHitsVecEntry { + sort: vec![docs[1].feature.sorts[0].value], + doc_value_fields: Default::default(), + }, + super::TopHitsVecEntry { + sort: vec![docs[2].feature.sorts[0].value], + doc_value_fields: Default::default(), + }, + ] + } + ); + + Ok(()) + } + + fn test_aggregation_top_hits(merge_segments: bool) -> crate::Result<()> { + let docs = vec![ + vec![ + r#"{ "date": "2015-01-02T00:00:00Z", "text": "bbb", "text2": "bbb", "mixed": { "dyn_arr": [1, "2"] } }"#, + r#"{ "date": "2017-06-15T00:00:00Z", "text": "ccc", "text2": "ddd", "mixed": { "dyn_arr": [3, "4"] } }"#, + ], + vec![ + r#"{ "text": "aaa", "text2": "bbb", "date": "2018-01-02T00:00:00Z", "mixed": { "dyn_arr": ["9", 8] } }"#, + r#"{ "text": 
"aaa", "text2": "bbb", "date": "2016-01-02T00:00:00Z", "mixed": { "dyn_arr": ["7", 6] } }"#, + ], + ]; + + let index = get_test_index_from_docs(merge_segments, &docs)?; + + let d: Aggregations = serde_json::from_value(json!({ + "top_hits_req": { + "top_hits": { + "size": 2, + "sort": [ + { "date": "desc" } + ], + "from": 1, + "docvalue_fields": [ + "date", + "tex*", + "mixed.*", + ], + } + } + }))?; + + let collector = AggregationCollector::from_aggs(d, Default::default()); + let reader = index.reader()?; + let searcher = reader.searcher(); + + let agg_res = + serde_json::to_value(searcher.search(&AllQuery, &collector).unwrap()).unwrap(); + + let date_2017 = datetime!(2017-06-15 00:00:00 UTC); + let date_2016 = datetime!(2016-01-02 00:00:00 UTC); + + assert_eq!( + agg_res["top_hits_req"], + json!({ + "hits": [ + { + "sort": [common::i64_to_u64(date_2017.unix_timestamp_nanos() as i64)], + "docvalue_fields": { + "date": [ OwnedValue::Date(DateTime::from_utc(date_2017)) ], + "text": [ "ccc" ], + "text2": [ "ddd" ], + "mixed.dyn_arr": [ 3, "4" ], + } + }, + { + "sort": [common::i64_to_u64(date_2016.unix_timestamp_nanos() as i64)], + "docvalue_fields": { + "date": [ OwnedValue::Date(DateTime::from_utc(date_2016)) ], + "text": [ "aaa" ], + "text2": [ "bbb" ], + "mixed.dyn_arr": [ 6, "7" ], + } + } + ] + }), + ); + + Ok(()) + } + + #[test] + fn test_aggregation_top_hits_single_segment() -> crate::Result<()> { + test_aggregation_top_hits(true) + } + + #[test] + fn test_aggregation_top_hits_multi_segment() -> crate::Result<()> { + test_aggregation_top_hits(false) + } +} diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index 9b6482546..fbb2925dd 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -145,6 +145,8 @@ mod agg_tests; mod agg_bench; +use core::fmt; + pub use agg_limits::AggregationLimits; pub use collector::{ AggregationCollector, AggregationSegmentCollector, DistributedAggregationCollector, @@ -154,7 +156,106 @@ use 
columnar::{ColumnType, MonotonicallyMappableToU64}; pub(crate) use date::format_date; pub use error::AggregationError; use itertools::Itertools; -use serde::{Deserialize, Serialize}; +use serde::de::{self, Visitor}; +use serde::{Deserialize, Deserializer, Serialize}; + +fn parse_str_into_f64<E: de::Error>(value: &str) -> Result<f64, E> { + let parsed = value.parse::<f64>().map_err(|_err| { + de::Error::custom(format!("Failed to parse f64 from string: {:?}", value)) + })?; + + // Check if the parsed value is NaN or infinity + if parsed.is_nan() || parsed.is_infinite() { + Err(de::Error::custom(format!( + "Value is not a valid f64 (NaN or Infinity): {:?}", + value + ))) + } else { + Ok(parsed) + } +} + +/// deserialize Option<f64> from string or float +pub(crate) fn deserialize_option_f64<'de, D>(deserializer: D) -> Result<Option<f64>, D::Error> +where D: Deserializer<'de> { + struct StringOrFloatVisitor; + + impl<'de> Visitor<'de> for StringOrFloatVisitor { + type Value = Option<f64>; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string or a float") + } + + fn visit_str<E>(self, value: &str) -> Result<Self::Value, E> + where E: de::Error { + parse_str_into_f64(value).map(Some) + } + + fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E> + where E: de::Error { + Ok(Some(value)) + } + + fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E> + where E: de::Error { + Ok(Some(value as f64)) + } + + fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E> + where E: de::Error { + Ok(Some(value as f64)) + } + + fn visit_none<E>(self) -> Result<Self::Value, E> + where E: de::Error { + Ok(None) + } + + fn visit_unit<E>(self) -> Result<Self::Value, E> + where E: de::Error { + Ok(None) + } + } + + deserializer.deserialize_any(StringOrFloatVisitor) +} + +/// deserialize f64 from string or float +pub(crate) fn deserialize_f64<'de, D>(deserializer: D) -> Result<f64, D::Error> +where D: Deserializer<'de> { + struct 
StringOrFloatVisitor; + + impl<'de> Visitor<'de> for StringOrFloatVisitor { + type Value = f64; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string or a float") + } + + fn visit_str<E>(self, value: &str) -> Result<Self::Value, E> + where E: de::Error { + parse_str_into_f64(value) + } + + fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E> + where E: de::Error { + Ok(value) + } + + fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E> + where E: de::Error { + Ok(value as f64) + } + + fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E> + where E: de::Error { + Ok(value as f64) + } + } + + deserializer.deserialize_any(StringOrFloatVisitor) +} /// Represents an associative array `(key => values)` in a very efficient manner. #[derive(PartialEq, Serialize, Deserialize)] @@ -281,6 +382,7 @@ pub(crate) fn f64_from_fastfield_u64(val: u64, field_type: &ColumnType) -> f64 { ColumnType::U64 => val as f64, ColumnType::I64 | ColumnType::DateTime => i64::from_u64(val) as f64, ColumnType::F64 => f64::from_u64(val), + ColumnType::Bool => val as f64, _ => { panic!("unexpected type {field_type:?}. 
This should not happen") } @@ -301,6 +403,7 @@ pub(crate) fn f64_to_fastfield_u64(val: f64, field_type: &ColumnType) -> Option< ColumnType::U64 => Some(val as u64), ColumnType::I64 | ColumnType::DateTime => Some((val as i64).to_u64()), ColumnType::F64 => Some(val.to_u64()), + ColumnType::Bool => Some(val as u64), _ => None, } } @@ -314,7 +417,6 @@ mod tests { use time::OffsetDateTime; use super::agg_req::Aggregations; - use super::segment_agg_result::AggregationLimits; use super::*; use crate::indexer::NoMergePolicy; use crate::query::{AllQuery, TermQuery}; diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index f39c7586d..f119b5d56 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -16,6 +16,7 @@ use super::metric::{ SumAggregation, }; use crate::aggregation::bucket::TermMissingAgg; +use crate::aggregation::metric::TopHitsSegmentCollector; use crate::aggregation::metric::{IntermediateExtendedStats, IntermediateInnerStatsCollector}; pub(crate) trait SegmentAggregationCollector: CollectorClone + Debug { @@ -169,6 +170,11 @@ pub(crate) fn build_single_agg_segment_collector( accessor_idx, )?, )), + TopHits(top_hits_req) => Ok(Box::new(TopHitsSegmentCollector::from_req( + top_hits_req, + accessor_idx, + req.segment_ordinal, + ))), } } diff --git a/src/collector/facet_collector.rs b/src/collector/facet_collector.rs index ee66a88a2..16759f3b2 100644 --- a/src/collector/facet_collector.rs +++ b/src/collector/facet_collector.rs @@ -1,7 +1,7 @@ use std::cmp::Ordering; use std::collections::{btree_map, BTreeMap, BTreeSet, BinaryHeap}; +use std::io; use std::ops::Bound; -use std::{io, u64, usize}; use crate::collector::{Collector, SegmentCollector}; use crate::fastfield::FacetReader; @@ -410,6 +410,7 @@ impl SegmentCollector for FacetSegmentCollector { /// Intermediary result of the `FacetCollector` that stores /// the facet counts for all the segments. 
+#[derive(Default, Clone)] pub struct FacetCounts { facet_counts: BTreeMap<Facet, u64>, } @@ -493,7 +494,7 @@ mod tests { use super::{FacetCollector, FacetCounts}; use crate::collector::facet_collector::compress_mapping; use crate::collector::Count; - use crate::core::Index; + use crate::index::Index; use crate::query::{AllQuery, QueryParser, TermQuery}; use crate::schema::{Facet, FacetOptions, IndexRecordOption, Schema, TantivyDocument}; use crate::{IndexWriter, Term}; diff --git a/src/collector/histogram_collector.rs b/src/collector/histogram_collector.rs index d5ca1b44f..51105e7b1 100644 --- a/src/collector/histogram_collector.rs +++ b/src/collector/histogram_collector.rs @@ -160,7 +160,7 @@ mod tests { use super::{add_vecs, HistogramCollector, HistogramComputer}; use crate::schema::{Schema, FAST}; use crate::time::{Date, Month}; - use crate::{doc, query, DateTime, Index}; + use crate::{query, DateTime, Index}; #[test] fn test_add_histograms_simple() { diff --git a/src/collector/mod.rs b/src/collector/mod.rs index 4d9b43d65..b78e02072 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -97,6 +97,7 @@ pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit}; mod top_collector; mod top_score_collector; +pub use self::top_collector::ComparableDoc; pub use self::top_score_collector::{TopDocs, TopNComputer}; mod custom_score_top_collector; @@ -273,6 +274,10 @@ pub trait SegmentCollector: 'static { fn collect(&mut self, doc: DocId, score: Score); /// The query pushes the scored document to the collector via this method. + /// This method is used when the collector does not require scoring. + /// + /// See [`COLLECT_BLOCK_BUFFER_LEN`](crate::COLLECT_BLOCK_BUFFER_LEN) for the + /// buffer size passed to the collector. 
fn collect_block(&mut self, docs: &[DocId]) { for doc in docs { self.collect(*doc, 0.0); diff --git a/src/collector/multi_collector.rs b/src/collector/multi_collector.rs index 4cbcadc24..da4d222a9 100644 --- a/src/collector/multi_collector.rs +++ b/src/collector/multi_collector.rs @@ -52,10 +52,16 @@ impl<TCollector: Collector> Collector for CollectorWrapper<TCollector> { impl SegmentCollector for Box<dyn BoxableSegmentCollector> { type Fruit = Box<dyn Fruit>; + #[inline] fn collect(&mut self, doc: u32, score: Score) { self.as_mut().collect(doc, score); } + #[inline] + fn collect_block(&mut self, docs: &[DocId]) { + self.as_mut().collect_block(docs); + } + fn harvest(self) -> Box<dyn Fruit> { BoxableSegmentCollector::harvest_from_box(self) } @@ -63,6 +69,11 @@ impl SegmentCollector for Box<dyn BoxableSegmentCollector> { pub trait BoxableSegmentCollector { fn collect(&mut self, doc: u32, score: Score); + fn collect_block(&mut self, docs: &[DocId]) { + for &doc in docs { + self.collect(doc, 0.0); + } + } fn harvest_from_box(self: Box<Self>) -> Box<dyn Fruit>; } @@ -71,9 +82,14 @@ pub struct SegmentCollectorWrapper<TSegmentCollector: SegmentCollector>(TSegment impl<TSegmentCollector: SegmentCollector> BoxableSegmentCollector for SegmentCollectorWrapper<TSegmentCollector> { + #[inline] fn collect(&mut self, doc: u32, score: Score) { self.0.collect(doc, score); } + #[inline] + fn collect_block(&mut self, docs: &[DocId]) { + self.0.collect_block(docs); + } fn harvest_from_box(self: Box<Self>) -> Box<dyn Fruit> { Box::new(self.0.harvest()) diff --git a/src/collector/tests.rs b/src/collector/tests.rs index 81924090a..7af7c6d8c 100644 --- a/src/collector/tests.rs +++ b/src/collector/tests.rs @@ -1,15 +1,11 @@ use columnar::{BytesColumn, Column}; use super::*; -use crate::collector::{Count, FilterCollector, TopDocs}; -use crate::core::SegmentReader; use crate::query::{AllQuery, QueryParser}; use crate::schema::{Schema, FAST, TEXT}; use 
crate::time::format_description::well_known::Rfc3339; use crate::time::OffsetDateTime; -use crate::{ - doc, DateTime, DocAddress, DocId, Index, Score, Searcher, SegmentOrdinal, TantivyDocument, -}; +use crate::{DateTime, DocAddress, Index, Searcher, TantivyDocument}; pub const TEST_COLLECTOR_WITH_SCORE: TestCollector = TestCollector { compute_score: true, diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index ddb78c7b1..5a07e4218 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -1,47 +1,58 @@ use std::cmp::Ordering; use std::marker::PhantomData; +use serde::{Deserialize, Serialize}; + use super::top_score_collector::TopNComputer; use crate::{DocAddress, DocId, SegmentOrdinal, SegmentReader}; /// Contains a feature (field, score, etc.) of a document along with the document address. /// -/// It has a custom implementation of `PartialOrd` that reverses the order. This is because the -/// default Rust heap is a max heap, whereas a min heap is needed. -/// -/// Additionally, it guarantees stable sorting: in case of a tie on the feature, the document +/// It guarantees stable sorting: in case of a tie on the feature, the document /// address is used. /// +/// The REVERSE_ORDER generic parameter controls whether the by-feature order +/// should be reversed, which is useful for achieving for example largest-first +/// semantics without having to wrap the feature in a `Reverse`. +/// /// WARNING: equality is not what you would expect here. /// Two elements are equal if their feature is equal, and regardless of whether `doc` /// is equal. This should be perfectly fine for this usage, but let's make sure this /// struct is never public. -pub(crate) struct ComparableDoc<T, D> { +#[derive(Clone, Default, Serialize, Deserialize)] +pub struct ComparableDoc<T, D, const REVERSE_ORDER: bool = false> { + /// The feature of the document. In practice, this is + /// is any type that implements `PartialOrd`. 
pub feature: T, + /// The document address. In practice, this is any + /// type that implements `PartialOrd`, and is guaranteed + /// to be unique for each document. pub doc: D, } -impl<T: std::fmt::Debug, D: std::fmt::Debug> std::fmt::Debug for ComparableDoc<T, D> { +impl<T: std::fmt::Debug, D: std::fmt::Debug, const R: bool> std::fmt::Debug + for ComparableDoc<T, D, R> +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ComparableDoc") + f.debug_struct(format!("ComparableDoc<_, _ {R}").as_str()) .field("feature", &self.feature) .field("doc", &self.doc) .finish() } } -impl<T: PartialOrd, D: PartialOrd> PartialOrd for ComparableDoc<T, D> { +impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialOrd for ComparableDoc<T, D, R> { fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) } } -impl<T: PartialOrd, D: PartialOrd> Ord for ComparableDoc<T, D> { +impl<T: PartialOrd, D: PartialOrd, const R: bool> Ord for ComparableDoc<T, D, R> { #[inline] fn cmp(&self, other: &Self) -> Ordering { - // Reversed to make BinaryHeap work as a min-heap - let by_feature = other + let by_feature = self .feature - .partial_cmp(&self.feature) + .partial_cmp(&other.feature) + .map(|ord| if R { ord.reverse() } else { ord }) .unwrap_or(Ordering::Equal); let lazy_by_doc_address = || self.doc.partial_cmp(&other.doc).unwrap_or(Ordering::Equal); @@ -53,13 +64,13 @@ impl<T: PartialOrd, D: PartialOrd> Ord for ComparableDoc<T, D> { } } -impl<T: PartialOrd, D: PartialOrd> PartialEq for ComparableDoc<T, D> { +impl<T: PartialOrd, D: PartialOrd, const R: bool> PartialEq for ComparableDoc<T, D, R> { fn eq(&self, other: &Self) -> bool { self.cmp(other) == Ordering::Equal } } -impl<T: PartialOrd, D: PartialOrd> Eq for ComparableDoc<T, D> {} +impl<T: PartialOrd, D: PartialOrd, const R: bool> Eq for ComparableDoc<T, D, R> {} pub(crate) struct TopCollector<T> { pub limit: usize, @@ -99,10 +110,10 @@ where T: PartialOrd + Clone if self.limit == 
0 { return Ok(Vec::new()); } - let mut top_collector = TopNComputer::new(self.limit + self.offset); + let mut top_collector: TopNComputer<_, _> = TopNComputer::new(self.limit + self.offset); for child_fruit in children { for (feature, doc) in child_fruit { - top_collector.push(ComparableDoc { feature, doc }); + top_collector.push(feature, doc); } } @@ -143,6 +154,8 @@ where T: PartialOrd + Clone /// The theoretical complexity for collecting the top `K` out of `n` documents /// is `O(n + K)`. pub(crate) struct TopSegmentCollector<T> { + /// We reverse the order of the feature in order to + /// have top-semantics instead of bottom semantics. topn_computer: TopNComputer<T, DocId>, segment_ord: u32, } @@ -180,7 +193,7 @@ impl<T: PartialOrd + Clone> TopSegmentCollector<T> { /// will compare the lowest scoring item with the given one and keep whichever is greater. #[inline] pub fn collect(&mut self, doc: DocId, feature: T) { - self.topn_computer.push(ComparableDoc { feature, doc }); + self.topn_computer.push(feature, doc); } } diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 484882046..415625bc1 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -3,6 +3,8 @@ use std::marker::PhantomData; use std::sync::Arc; use columnar::ColumnValues; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; use super::Collector; use crate::collector::custom_score_top_collector::CustomScoreTopCollector; @@ -309,7 +311,7 @@ impl TopDocs { /// /// To comfortably work with `u64`s, `i64`s, `f64`s, or `date`s, please refer to /// the [.order_by_fast_field(...)](TopDocs::order_by_fast_field) method. 
- fn order_by_u64_field( + pub fn order_by_u64_field( self, field: impl ToString, order: Order, @@ -663,7 +665,7 @@ impl Collector for TopDocs { reader: &SegmentReader, ) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> { let heap_len = self.0.limit + self.0.offset; - let mut top_n = TopNComputer::new(heap_len); + let mut top_n: TopNComputer<_, _> = TopNComputer::new(heap_len); if let Some(alive_bitset) = reader.alive_bitset() { let mut threshold = Score::MIN; @@ -672,21 +674,13 @@ impl Collector for TopDocs { if alive_bitset.is_deleted(doc) { return threshold; } - let doc = ComparableDoc { - feature: score, - doc, - }; - top_n.push(doc); + top_n.push(score, doc); threshold = top_n.threshold.unwrap_or(Score::MIN); threshold })?; } else { weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| { - let doc = ComparableDoc { - feature: score, - doc, - }; - top_n.push(doc); + top_n.push(score, doc); top_n.threshold.unwrap_or(Score::MIN) })?; } @@ -725,17 +719,78 @@ impl SegmentCollector for TopScoreSegmentCollector { /// Fast TopN Computation /// +/// Capacity of the vec is 2 * top_n. +/// The buffer is truncated to the top_n elements when it reaches the capacity of the Vec. +/// That means capacity has special meaning and should be carried over when cloning or serializing. +/// /// For TopN == 0, it will be relative expensive. 
-pub struct TopNComputer<Score, DocId> { - buffer: Vec<ComparableDoc<Score, DocId>>, +#[derive(Serialize, Deserialize)] +#[serde(from = "TopNComputerDeser<Score, D, REVERSE_ORDER>")] +pub struct TopNComputer<Score, D, const REVERSE_ORDER: bool = true> { + /// The buffer reverses sort order to get top-semantics instead of bottom-semantics + buffer: Vec<ComparableDoc<Score, D, REVERSE_ORDER>>, top_n: usize, pub(crate) threshold: Option<Score>, } -impl<Score, DocId> TopNComputer<Score, DocId> +impl<Score: std::fmt::Debug, D, const REVERSE_ORDER: bool> std::fmt::Debug + for TopNComputer<Score, D, REVERSE_ORDER> +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TopNComputer") + .field("buffer_len", &self.buffer.len()) + .field("top_n", &self.top_n) + .field("current_threshold", &self.threshold) + .finish() + } +} + +// Intermediate struct for TopNComputer for deserialization, to keep vec capacity +#[derive(Deserialize)] +struct TopNComputerDeser<Score, D, const REVERSE_ORDER: bool> { + buffer: Vec<ComparableDoc<Score, D, REVERSE_ORDER>>, + top_n: usize, + threshold: Option<Score>, +} + +// Custom clone to keep capacity +impl<Score: Clone, D: Clone, const REVERSE_ORDER: bool> Clone + for TopNComputer<Score, D, REVERSE_ORDER> +{ + fn clone(&self) -> Self { + let mut buffer_clone = Vec::with_capacity(self.buffer.capacity()); + buffer_clone.extend(self.buffer.iter().cloned()); + + TopNComputer { + buffer: buffer_clone, + top_n: self.top_n, + threshold: self.threshold.clone(), + } + } +} + +impl<Score, D, const R: bool> From<TopNComputerDeser<Score, D, R>> for TopNComputer<Score, D, R> { + fn from(mut value: TopNComputerDeser<Score, D, R>) -> Self { + let expected_cap = value.top_n.max(1) * 2; + let current_cap = value.buffer.capacity(); + if current_cap < expected_cap { + value.buffer.reserve_exact(expected_cap - current_cap); + } else { + value.buffer.shrink_to(expected_cap); + } + + TopNComputer { + buffer: value.buffer, + top_n: 
value.top_n, + threshold: value.threshold, + } + } +} + +impl<Score, D, const R: bool> TopNComputer<Score, D, R> where Score: PartialOrd + Clone, - DocId: Ord + Clone, + D: Serialize + DeserializeOwned + Ord + Clone, { /// Create a new `TopNComputer`. /// Internally it will allocate a buffer of size `2 * top_n`. @@ -748,10 +803,12 @@ where } } + /// Push a new document to the top n. + /// If the document is below the current threshold, it will be ignored. #[inline] - pub(crate) fn push(&mut self, doc: ComparableDoc<Score, DocId>) { + pub fn push(&mut self, feature: Score, doc: D) { if let Some(last_median) = self.threshold.clone() { - if doc.feature < last_median { + if feature < last_median { return; } } @@ -766,7 +823,7 @@ where let uninit = self.buffer.spare_capacity_mut(); // This cannot panic, because we truncate_median will at least remove one element, since // the min capacity is 2. - uninit[0].write(doc); + uninit[0].write(ComparableDoc { doc, feature }); // This is safe because it would panic in the line above unsafe { self.buffer.set_len(self.buffer.len() + 1); @@ -785,13 +842,24 @@ where median_score } - pub(crate) fn into_sorted_vec(mut self) -> Vec<ComparableDoc<Score, DocId>> { + /// Returns the top n elements in sorted order. + pub fn into_sorted_vec(mut self) -> Vec<ComparableDoc<Score, D, R>> { if self.buffer.len() > self.top_n { self.truncate_top_n(); } self.buffer.sort_unstable(); self.buffer } + + /// Returns the top n elements in stored order. + /// Useful if you do not need the elements in sorted order, + /// for example when merging the results of multiple segments. 
+ pub fn into_vec(mut self) -> Vec<ComparableDoc<Score, D, R>> { + if self.buffer.len() > self.top_n { + self.truncate_top_n(); + } + self.buffer + } } #[cfg(test)] @@ -825,49 +893,44 @@ mod tests { crate::assert_nearly_equals!(result.0, expected.0); } } + #[test] + fn test_topn_computer_serde() { + let computer: TopNComputer<u32, u32> = TopNComputer::new(1); + + let computer_ser = serde_json::to_string(&computer).unwrap(); + let mut computer: TopNComputer<u32, u32> = serde_json::from_str(&computer_ser).unwrap(); + + computer.push(1u32, 5u32); + computer.push(1u32, 0u32); + computer.push(1u32, 7u32); + + assert_eq!( + computer.into_sorted_vec(), + &[ComparableDoc { + feature: 1u32, + doc: 0u32, + },] + ); + } #[test] fn test_empty_topn_computer() { let mut computer: TopNComputer<u32, u32> = TopNComputer::new(0); - computer.push(ComparableDoc { - feature: 1u32, - doc: 1u32, - }); - computer.push(ComparableDoc { - feature: 1u32, - doc: 2u32, - }); - computer.push(ComparableDoc { - feature: 1u32, - doc: 3u32, - }); + computer.push(1u32, 1u32); + computer.push(1u32, 2u32); + computer.push(1u32, 3u32); assert!(computer.into_sorted_vec().is_empty()); } #[test] fn test_topn_computer() { let mut computer: TopNComputer<u32, u32> = TopNComputer::new(2); - computer.push(ComparableDoc { - feature: 1u32, - doc: 1u32, - }); - computer.push(ComparableDoc { - feature: 2u32, - doc: 2u32, - }); - computer.push(ComparableDoc { - feature: 3u32, - doc: 3u32, - }); - computer.push(ComparableDoc { - feature: 2u32, - doc: 4u32, - }); - computer.push(ComparableDoc { - feature: 1u32, - doc: 5u32, - }); + computer.push(1u32, 1u32); + computer.push(2u32, 2u32); + computer.push(3u32, 3u32); + computer.push(2u32, 4u32); + computer.push(1u32, 5u32); assert_eq!( computer.into_sorted_vec(), &[ @@ -889,10 +952,7 @@ mod tests { let mut computer: TopNComputer<u32, u32> = TopNComputer::new(top_n); for _ in 0..1 + top_n * 2 { - computer.push(ComparableDoc { - feature: 1u32, - doc: 1u32, - }); + 
computer.push(1u32, 1u32); } let _vals = computer.into_sorted_vec(); } diff --git a/src/core/json_utils.rs b/src/core/json_utils.rs index 09059ddbf..d7ac29ad7 100644 --- a/src/core/json_utils.rs +++ b/src/core/json_utils.rs @@ -1,12 +1,10 @@ -use columnar::MonotonicallyMappableToU64; +use common::json_path_writer::JSON_PATH_SEGMENT_SEP; use common::{replace_in_place, JsonPathWriter}; use rustc_hash::FxHashMap; -use crate::fastfield::FastValue; use crate::postings::{IndexingContext, IndexingPosition, PostingsWriter}; use crate::schema::document::{ReferenceValue, ReferenceValueLeaf, Value}; -use crate::schema::term::JSON_PATH_SEGMENT_SEP; -use crate::schema::{Field, Type, DATE_TIME_PRECISION_INDEXED}; +use crate::schema::{Field, Type}; use crate::time::format_description::well_known::Rfc3339; use crate::time::{OffsetDateTime, UtcOffset}; use crate::tokenizer::TextAnalyzer; @@ -256,71 +254,45 @@ fn index_json_value<'a, V: Value<'a>>( } } -// Tries to infer a JSON type from a string. -pub fn convert_to_fast_value_and_get_term( - json_term_writer: &mut JsonTermWriter, +/// Tries to infer a JSON type from a string and append it to the term. +/// +/// The term must be a JSON term containing only its JSON path (no value bytes yet).
+pub(crate) fn convert_to_fast_value_and_append_to_json_term( + mut term: Term, phrase: &str, ) -> Option<Term> { + assert_eq!( + term.value() + .as_json_value_bytes() + .expect("expecting a Term with a json type and json path") + .as_serialized() + .len(), + 0, + "JSON value bytes should be empty" + ); if let Ok(dt) = OffsetDateTime::parse(phrase, &Rfc3339) { let dt_utc = dt.to_offset(UtcOffset::UTC); - return Some(set_fastvalue_and_get_term( - json_term_writer, - DateTime::from_utc(dt_utc), - )); + term.append_type_and_fast_value(DateTime::from_utc(dt_utc)); + return Some(term); } if let Ok(i64_val) = str::parse::<i64>(phrase) { - return Some(set_fastvalue_and_get_term(json_term_writer, i64_val)); + term.append_type_and_fast_value(i64_val); + return Some(term); } if let Ok(u64_val) = str::parse::<u64>(phrase) { - return Some(set_fastvalue_and_get_term(json_term_writer, u64_val)); + term.append_type_and_fast_value(u64_val); + return Some(term); } if let Ok(f64_val) = str::parse::<f64>(phrase) { - return Some(set_fastvalue_and_get_term(json_term_writer, f64_val)); + term.append_type_and_fast_value(f64_val); + return Some(term); } if let Ok(bool_val) = str::parse::<bool>(phrase) { - return Some(set_fastvalue_and_get_term(json_term_writer, bool_val)); + term.append_type_and_fast_value(bool_val); + return Some(term); } None } -// helper function to generate a Term from a json fastvalue -pub(crate) fn set_fastvalue_and_get_term<T: FastValue>( - json_term_writer: &mut JsonTermWriter, - value: T, -) -> Term { - json_term_writer.set_fast_value(value); - json_term_writer.term().clone() -} - -// helper function to generate a list of terms with their positions from a textual json value -pub(crate) fn set_string_and_get_terms( - json_term_writer: &mut JsonTermWriter, - value: &str, - text_analyzer: &mut TextAnalyzer, -) -> Vec<(usize, Term)> { - let mut positions_and_terms = Vec::<(usize, Term)>::new(); - json_term_writer.close_path_and_set_type(Type::Str); - let 
term_num_bytes = json_term_writer.term_buffer.len_bytes(); - let mut token_stream = text_analyzer.token_stream(value); - token_stream.process(&mut |token| { - json_term_writer - .term_buffer - .truncate_value_bytes(term_num_bytes); - json_term_writer - .term_buffer - .append_bytes(token.text.as_bytes()); - positions_and_terms.push((token.position, json_term_writer.term().clone())); - }); - positions_and_terms -} - -/// Writes a value of a JSON field to a `Term`. -/// The Term format is as follows: -/// `[JSON_TYPE][JSON_PATH][JSON_END_OF_PATH][VALUE_BYTES]` -pub struct JsonTermWriter<'a> { - term_buffer: &'a mut Term, - path_stack: Vec<usize>, - expand_dots_enabled: bool, -} /// Splits a json path supplied to the query parser in such a way that /// `.` can be escaped. @@ -377,158 +349,68 @@ pub(crate) fn encode_column_name( path.into() } -impl<'a> JsonTermWriter<'a> { - pub fn from_field_and_json_path( - field: Field, - json_path: &str, - expand_dots_enabled: bool, - term_buffer: &'a mut Term, - ) -> Self { - term_buffer.set_field_and_type(field, Type::Json); - let mut json_term_writer = Self::wrap(term_buffer, expand_dots_enabled); - for segment in split_json_path(json_path) { - json_term_writer.push_path_segment(&segment); - } - json_term_writer +pub fn term_from_json_paths<'a>( + json_field: Field, + paths: impl Iterator<Item = &'a str>, + expand_dots_enabled: bool, +) -> Term { + let mut json_path = JsonPathWriter::with_expand_dots(expand_dots_enabled); + for path in paths { + json_path.push(path); } + json_path.set_end(); + let mut term = Term::with_type_and_field(Type::Json, json_field); - pub fn wrap(term_buffer: &'a mut Term, expand_dots_enabled: bool) -> Self { - term_buffer.clear_with_type(Type::Json); - let mut path_stack = Vec::with_capacity(10); - path_stack.push(0); - Self { - term_buffer, - path_stack, - expand_dots_enabled, - } - } - - fn trim_to_end_of_path(&mut self) { - let end_of_path = *self.path_stack.last().unwrap(); - 
self.term_buffer.truncate_value_bytes(end_of_path); - } - - pub fn close_path_and_set_type(&mut self, typ: Type) { - self.trim_to_end_of_path(); - self.term_buffer.set_json_path_end(); - self.term_buffer.append_bytes(&[typ.to_code()]); - } - - // TODO: Remove this function and use JsonPathWriter instead. - pub fn push_path_segment(&mut self, segment: &str) { - // the path stack should never be empty. - self.trim_to_end_of_path(); - - if self.path_stack.len() > 1 { - self.term_buffer.set_json_path_separator(); - } - let appended_segment = self.term_buffer.append_bytes(segment.as_bytes()); - if self.expand_dots_enabled { - // We need to replace `.` by JSON_PATH_SEGMENT_SEP. - replace_in_place(b'.', JSON_PATH_SEGMENT_SEP, appended_segment); - } - self.term_buffer.add_json_path_separator(); - self.path_stack.push(self.term_buffer.len_bytes()); - } - - pub fn pop_path_segment(&mut self) { - self.path_stack.pop(); - assert!(!self.path_stack.is_empty()); - self.trim_to_end_of_path(); - } - - /// Returns the json path of the term being currently built. 
- #[cfg(test)] - pub(crate) fn path(&self) -> &[u8] { - let end_of_path = self.path_stack.last().cloned().unwrap_or(1); - &self.term().serialized_value_bytes()[..end_of_path - 1] - } - - pub(crate) fn set_fast_value<T: FastValue>(&mut self, val: T) { - self.close_path_and_set_type(T::to_type()); - let value = if T::to_type() == Type::Date { - DateTime::from_u64(val.to_u64()) - .truncate(DATE_TIME_PRECISION_INDEXED) - .to_u64() - } else { - val.to_u64() - }; - self.term_buffer - .append_bytes(value.to_be_bytes().as_slice()); - } - - pub fn set_str(&mut self, text: &str) { - self.close_path_and_set_type(Type::Str); - self.term_buffer.append_bytes(text.as_bytes()); - } - - pub fn term(&self) -> &Term { - self.term_buffer - } + term.append_bytes(json_path.as_str().as_bytes()); + term } #[cfg(test)] mod tests { - use super::{split_json_path, JsonTermWriter}; - use crate::schema::{Field, Type}; - use crate::Term; + use super::split_json_path; + use crate::json_utils::term_from_json_paths; + use crate::schema::Field; #[test] fn test_json_writer() { let field = Field::from_field_id(1); - let mut term = Term::with_type_and_field(Type::Json, field); - let mut json_writer = JsonTermWriter::wrap(&mut term, false); - json_writer.push_path_segment("attributes"); - json_writer.push_path_segment("color"); - json_writer.set_str("red"); + + let mut term = term_from_json_paths(field, ["attributes", "color"].into_iter(), false); + term.append_type_and_str("red"); assert_eq!( - format!("{:?}", json_writer.term()), + format!("{:?}", term), "Term(field=1, type=Json, path=attributes.color, type=Str, \"red\")" ); - json_writer.set_str("blue"); - assert_eq!( - format!("{:?}", json_writer.term()), - "Term(field=1, type=Json, path=attributes.color, type=Str, \"blue\")" + + let mut term = term_from_json_paths( + field, + ["attributes", "dimensions", "width"].into_iter(), + false, ); - json_writer.pop_path_segment(); - json_writer.push_path_segment("dimensions"); - 
json_writer.push_path_segment("width"); - json_writer.set_fast_value(400i64); + term.append_type_and_fast_value(400i64); assert_eq!( - format!("{:?}", json_writer.term()), + format!("{:?}", term), "Term(field=1, type=Json, path=attributes.dimensions.width, type=I64, 400)" ); - json_writer.pop_path_segment(); - json_writer.push_path_segment("height"); - json_writer.set_fast_value(300i64); - assert_eq!( - format!("{:?}", json_writer.term()), - "Term(field=1, type=Json, path=attributes.dimensions.height, type=I64, 300)" - ); } #[test] fn test_string_term() { let field = Field::from_field_id(1); - let mut term = Term::with_type_and_field(Type::Json, field); - let mut json_writer = JsonTermWriter::wrap(&mut term, false); - json_writer.push_path_segment("color"); - json_writer.set_str("red"); - assert_eq!( - json_writer.term().serialized_term(), - b"\x00\x00\x00\x01jcolor\x00sred" - ) + let mut term = term_from_json_paths(field, ["color"].into_iter(), false); + term.append_type_and_str("red"); + + assert_eq!(term.serialized_term(), b"\x00\x00\x00\x01jcolor\x00sred") } #[test] fn test_i64_term() { let field = Field::from_field_id(1); - let mut term = Term::with_type_and_field(Type::Json, field); - let mut json_writer = JsonTermWriter::wrap(&mut term, false); - json_writer.push_path_segment("color"); - json_writer.set_fast_value(-4i64); + let mut term = term_from_json_paths(field, ["color"].into_iter(), false); + term.append_type_and_fast_value(-4i64); + assert_eq!( - json_writer.term().serialized_term(), + term.serialized_term(), b"\x00\x00\x00\x01jcolor\x00i\x7f\xff\xff\xff\xff\xff\xff\xfc" ) } @@ -536,12 +418,11 @@ mod tests { #[test] fn test_u64_term() { let field = Field::from_field_id(1); - let mut term = Term::with_type_and_field(Type::Json, field); - let mut json_writer = JsonTermWriter::wrap(&mut term, false); - json_writer.push_path_segment("color"); - json_writer.set_fast_value(4u64); + let mut term = term_from_json_paths(field, ["color"].into_iter(), false); + 
term.append_type_and_fast_value(4u64); + assert_eq!( - json_writer.term().serialized_term(), + term.serialized_term(), b"\x00\x00\x00\x01jcolor\x00u\x00\x00\x00\x00\x00\x00\x00\x04" ) } @@ -549,12 +430,10 @@ mod tests { #[test] fn test_f64_term() { let field = Field::from_field_id(1); - let mut term = Term::with_type_and_field(Type::Json, field); - let mut json_writer = JsonTermWriter::wrap(&mut term, false); - json_writer.push_path_segment("color"); - json_writer.set_fast_value(4.0f64); + let mut term = term_from_json_paths(field, ["color"].into_iter(), false); + term.append_type_and_fast_value(4.0f64); assert_eq!( - json_writer.term().serialized_term(), + term.serialized_term(), b"\x00\x00\x00\x01jcolor\x00f\xc0\x10\x00\x00\x00\x00\x00\x00" ) } @@ -562,90 +441,14 @@ mod tests { #[test] fn test_bool_term() { let field = Field::from_field_id(1); - let mut term = Term::with_type_and_field(Type::Json, field); - let mut json_writer = JsonTermWriter::wrap(&mut term, false); - json_writer.push_path_segment("color"); - json_writer.set_fast_value(true); + let mut term = term_from_json_paths(field, ["color"].into_iter(), false); + term.append_type_and_fast_value(true); assert_eq!( - json_writer.term().serialized_term(), + term.serialized_term(), b"\x00\x00\x00\x01jcolor\x00o\x00\x00\x00\x00\x00\x00\x00\x01" ) } - #[test] - fn test_push_after_set_path_segment() { - let field = Field::from_field_id(1); - let mut term = Term::with_type_and_field(Type::Json, field); - let mut json_writer = JsonTermWriter::wrap(&mut term, false); - json_writer.push_path_segment("attribute"); - json_writer.set_str("something"); - json_writer.push_path_segment("color"); - json_writer.set_str("red"); - assert_eq!( - json_writer.term().serialized_term(), - b"\x00\x00\x00\x01jattribute\x01color\x00sred" - ) - } - - #[test] - fn test_pop_segment() { - let field = Field::from_field_id(1); - let mut term = Term::with_type_and_field(Type::Json, field); - let mut json_writer = JsonTermWriter::wrap(&mut 
term, false); - json_writer.push_path_segment("color"); - json_writer.push_path_segment("hue"); - json_writer.pop_path_segment(); - json_writer.set_str("red"); - assert_eq!( - json_writer.term().serialized_term(), - b"\x00\x00\x00\x01jcolor\x00sred" - ) - } - - #[test] - fn test_json_writer_path() { - let field = Field::from_field_id(1); - let mut term = Term::with_type_and_field(Type::Json, field); - let mut json_writer = JsonTermWriter::wrap(&mut term, false); - json_writer.push_path_segment("color"); - assert_eq!(json_writer.path(), b"color"); - json_writer.push_path_segment("hue"); - assert_eq!(json_writer.path(), b"color\x01hue"); - json_writer.set_str("pink"); - assert_eq!(json_writer.path(), b"color\x01hue"); - } - - #[test] - fn test_json_path_expand_dots_disabled() { - let field = Field::from_field_id(1); - let mut term = Term::with_type_and_field(Type::Json, field); - let mut json_writer = JsonTermWriter::wrap(&mut term, false); - json_writer.push_path_segment("color.hue"); - assert_eq!(json_writer.path(), b"color.hue"); - } - - #[test] - fn test_json_path_expand_dots_enabled() { - let field = Field::from_field_id(1); - let mut term = Term::with_type_and_field(Type::Json, field); - let mut json_writer = JsonTermWriter::wrap(&mut term, true); - json_writer.push_path_segment("color.hue"); - assert_eq!(json_writer.path(), b"color\x01hue"); - } - - #[test] - fn test_json_path_expand_dots_enabled_pop_segment() { - let field = Field::from_field_id(1); - let mut term = Term::with_type_and_field(Type::Json, field); - let mut json_writer = JsonTermWriter::wrap(&mut term, true); - json_writer.push_path_segment("hello"); - assert_eq!(json_writer.path(), b"hello"); - json_writer.push_path_segment("color.hue"); - assert_eq!(json_writer.path(), b"hello\x01color\x01hue"); - json_writer.pop_path_segment(); - assert_eq!(json_writer.path(), b"hello"); - } - #[test] fn test_split_json_path_simple() { let json_path = split_json_path("titi.toto"); diff --git a/src/core/mod.rs 
b/src/core/mod.rs index 6a98f6fe0..db4ab2896 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -1,32 +1,14 @@ mod executor; -pub mod index; -mod index_meta; -mod inverted_index_reader; #[doc(hidden)] pub mod json_utils; pub mod searcher; -mod segment; -mod segment_component; -mod segment_id; -mod segment_reader; -mod single_segment_index_writer; use std::path::Path; use once_cell::sync::Lazy; pub use self::executor::Executor; -pub use self::index::{Index, IndexBuilder}; -pub use self::index_meta::{ - IndexMeta, IndexSettings, IndexSortByField, Order, SegmentMeta, SegmentMetaInventory, -}; -pub use self::inverted_index_reader::InvertedIndexReader; pub use self::searcher::{Searcher, SearcherGeneration}; -pub use self::segment::Segment; -pub use self::segment_component::SegmentComponent; -pub use self::segment_id::SegmentId; -pub use self::segment_reader::{merge_field_meta_data, FieldMetadata, SegmentReader}; -pub use self::single_segment_index_writer::SingleSegmentIndexWriter; /// The meta file contains all the information about the list of segments and the schema /// of the index. 
diff --git a/src/core/searcher.rs b/src/core/searcher.rs index 3f989696c..8c5dad3da 100644 --- a/src/core/searcher.rs +++ b/src/core/searcher.rs @@ -3,7 +3,8 @@ use std::sync::Arc; use std::{fmt, io}; use crate::collector::Collector; -use crate::core::{Executor, SegmentReader}; +use crate::core::Executor; +use crate::index::SegmentReader; use crate::query::{Bm25StatisticsProvider, EnableScoring, Query}; use crate::schema::document::DocumentDeserialize; use crate::schema::{Schema, Term}; diff --git a/src/core/tests.rs b/src/core/tests.rs index e215c31f4..210b359da 100644 --- a/src/core/tests.rs +++ b/src/core/tests.rs @@ -1,9 +1,9 @@ use crate::collector::Count; use crate::directory::{RamDirectory, WatchCallback}; use crate::indexer::{LogMergePolicy, NoMergePolicy}; -use crate::json_utils::JsonTermWriter; +use crate::json_utils::term_from_json_paths; use crate::query::TermQuery; -use crate::schema::{Field, IndexRecordOption, Schema, Type, INDEXED, STRING, TEXT}; +use crate::schema::{Field, IndexRecordOption, Schema, INDEXED, STRING, TEXT}; use crate::tokenizer::TokenizerManager; use crate::{ Directory, DocSet, Index, IndexBuilder, IndexReader, IndexSettings, IndexWriter, Postings, @@ -137,7 +137,6 @@ mod mmap_specific { use tempfile::TempDir; use super::*; - use crate::Directory; #[test] fn test_index_on_commit_reload_policy_mmap() -> crate::Result<()> { @@ -417,16 +416,12 @@ fn test_non_text_json_term_freq() { let searcher = reader.searcher(); let segment_reader = searcher.segment_reader(0u32); let inv_idx = segment_reader.inverted_index(field).unwrap(); - let mut term = Term::with_type_and_field(Type::Json, field); - let mut json_term_writer = JsonTermWriter::wrap(&mut term, false); - json_term_writer.push_path_segment("tenant_id"); - json_term_writer.close_path_and_set_type(Type::U64); - json_term_writer.set_fast_value(75u64); + + let mut term = term_from_json_paths(field, ["tenant_id"].iter().cloned(), false); + term.append_type_and_fast_value(75u64); + let 
postings = inv_idx - .read_postings( - &json_term_writer.term(), - IndexRecordOption::WithFreqsAndPositions, - ) + .read_postings(&term, IndexRecordOption::WithFreqsAndPositions) .unwrap() .unwrap(); assert_eq!(postings.doc(), 0); @@ -455,16 +450,12 @@ fn test_non_text_json_term_freq_bitpacked() { let searcher = reader.searcher(); let segment_reader = searcher.segment_reader(0u32); let inv_idx = segment_reader.inverted_index(field).unwrap(); - let mut term = Term::with_type_and_field(Type::Json, field); - let mut json_term_writer = JsonTermWriter::wrap(&mut term, false); - json_term_writer.push_path_segment("tenant_id"); - json_term_writer.close_path_and_set_type(Type::U64); - json_term_writer.set_fast_value(75u64); + + let mut term = term_from_json_paths(field, ["tenant_id"].iter().cloned(), false); + term.append_type_and_fast_value(75u64); + let mut postings = inv_idx - .read_postings( - &json_term_writer.term(), - IndexRecordOption::WithFreqsAndPositions, - ) + .read_postings(&term, IndexRecordOption::WithFreqsAndPositions) .unwrap() .unwrap(); assert_eq!(postings.doc(), 0); diff --git a/src/directory/composite_file.rs b/src/directory/composite_file.rs index d33b67b95..11d8929f1 100644 --- a/src/directory/composite_file.rs +++ b/src/directory/composite_file.rs @@ -1,6 +1,5 @@ use std::collections::HashMap; use std::io::{self, Read, Write}; -use std::iter::ExactSizeIterator; use std::ops::Range; use common::{BinarySerializable, CountingWriter, HasLen, VInt}; diff --git a/src/directory/directory.rs b/src/directory/directory.rs index 570307aeb..19df314d9 100644 --- a/src/directory/directory.rs +++ b/src/directory/directory.rs @@ -1,5 +1,4 @@ use std::io::Write; -use std::marker::{Send, Sync}; use std::path::{Path, PathBuf}; use std::sync::Arc; use std::time::Duration; @@ -40,6 +39,7 @@ impl RetryPolicy { /// The `DirectoryLock` is an object that represents a file lock. 
/// /// It is associated with a lock file, that gets deleted on `Drop.` +#[allow(dead_code)] pub struct DirectoryLock(Box<dyn Send + Sync + 'static>); struct DirectoryLockGuard { diff --git a/src/directory/mmap_directory.rs b/src/directory/mmap_directory.rs index 781fe7e20..f953f4689 100644 --- a/src/directory/mmap_directory.rs +++ b/src/directory/mmap_directory.rs @@ -479,6 +479,7 @@ impl Directory for MmapDirectory { let file: File = OpenOptions::new() .write(true) .create(true) //< if the file does not exist yet, create it. + .truncate(false) .open(full_path) .map_err(LockError::wrap_io_error)?; if lock.is_blocking { @@ -673,7 +674,7 @@ mod tests { let num_segments = reader.searcher().segment_readers().len(); assert!(num_segments <= 4); let num_components_except_deletes_and_tempstore = - crate::core::SegmentComponent::iterator().len() - 2; + crate::index::SegmentComponent::iterator().len() - 2; let max_num_mmapped = num_components_except_deletes_and_tempstore * num_segments; assert_eventually(|| { let num_mmapped = mmap_directory.get_cache_info().mmapped.len(); diff --git a/src/directory/ram_directory.rs b/src/directory/ram_directory.rs index a7a29b15c..cfd447c22 100644 --- a/src/directory/ram_directory.rs +++ b/src/directory/ram_directory.rs @@ -85,7 +85,7 @@ impl InnerDirectory { self.fs .get(path) .ok_or_else(|| OpenReadError::FileDoesNotExist(PathBuf::from(path))) - .map(Clone::clone) + .cloned() } fn delete(&mut self, path: &Path) -> result::Result<(), DeleteError> { diff --git a/src/directory/tests.rs b/src/directory/tests.rs index 2ef1868c2..a2c8473ce 100644 --- a/src/directory/tests.rs +++ b/src/directory/tests.rs @@ -1,6 +1,6 @@ use std::io::Write; use std::mem; -use std::path::{Path, PathBuf}; +use std::path::Path; use std::sync::atomic::Ordering::SeqCst; use std::sync::atomic::{AtomicBool, AtomicUsize}; use std::sync::Arc; diff --git a/src/directory/watch_event_router.rs b/src/directory/watch_event_router.rs index d47fc2d22..28fd83c46 100644 --- 
a/src/directory/watch_event_router.rs +++ b/src/directory/watch_event_router.rs @@ -32,6 +32,7 @@ pub struct WatchCallbackList { /// file change is detected. #[must_use = "This `WatchHandle` controls the lifetime of the watch and should therefore be used."] #[derive(Clone)] +#[allow(dead_code)] pub struct WatchHandle(Arc<WatchCallback>); impl WatchHandle { diff --git a/src/docset.rs b/src/docset.rs index 7f0b10c70..d04024d45 100644 --- a/src/docset.rs +++ b/src/docset.rs @@ -9,7 +9,10 @@ use crate::DocId; /// to compare `[u32; 4]`. pub const TERMINATED: DocId = i32::MAX as u32; -pub const BUFFER_LEN: usize = 64; +/// The collect_block method on `SegmentCollector` uses a buffer of this size. +/// Passed results to `collect_block` will not exceed this size and will be +/// exactly this size as long as we can fill the buffer. +pub const COLLECT_BLOCK_BUFFER_LEN: usize = 64; /// Represents an iterable set of sorted doc ids. pub trait DocSet: Send { @@ -61,7 +64,7 @@ pub trait DocSet: Send { /// This method is only here for specific high-performance /// use case where batching. The normal way to /// go through the `DocId`'s is to call `.advance()`. 
- fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize { + fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize { if self.doc() == TERMINATED { return 0; } @@ -151,7 +154,7 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> { unboxed.seek(target) } - fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize { + fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize { let unboxed: &mut TDocSet = self.borrow_mut(); unboxed.fill_buffer(buffer) } diff --git a/src/fastfield/mod.rs b/src/fastfield/mod.rs index 65e4c03bf..e0689650b 100644 --- a/src/fastfield/mod.rs +++ b/src/fastfield/mod.rs @@ -79,7 +79,7 @@ mod tests { use std::ops::{Range, RangeInclusive}; use std::path::Path; - use columnar::{Column, MonotonicallyMappableToU64, StrColumn}; + use columnar::StrColumn; use common::{ByteCount, HasLen, TerminatingWrite}; use once_cell::sync::Lazy; use rand::prelude::SliceRandom; @@ -131,7 +131,7 @@ mod tests { } let file = directory.open_read(path).unwrap(); - assert_eq!(file.len(), 93); + assert_eq!(file.len(), 80); let fast_field_readers = FastFieldReaders::open(file, SCHEMA.clone()).unwrap(); let column = fast_field_readers .u64("field") @@ -181,7 +181,7 @@ mod tests { write.terminate().unwrap(); } let file = directory.open_read(path).unwrap(); - assert_eq!(file.len(), 121); + assert_eq!(file.len(), 108); let fast_field_readers = FastFieldReaders::open(file, SCHEMA.clone()).unwrap(); let col = fast_field_readers .u64("field") @@ -214,7 +214,7 @@ mod tests { write.terminate().unwrap(); } let file = directory.open_read(path).unwrap(); - assert_eq!(file.len(), 94); + assert_eq!(file.len(), 81); let fast_field_readers = FastFieldReaders::open(file, SCHEMA.clone()).unwrap(); let fast_field_reader = fast_field_readers .u64("field") @@ -246,7 +246,7 @@ mod tests { write.terminate().unwrap(); } let file = directory.open_read(path).unwrap(); - assert_eq!(file.len(), 4489); + 
assert_eq!(file.len(), 4476); { let fast_field_readers = FastFieldReaders::open(file, SCHEMA.clone()).unwrap(); let col = fast_field_readers @@ -279,7 +279,7 @@ mod tests { write.terminate().unwrap(); } let file = directory.open_read(path).unwrap(); - assert_eq!(file.len(), 265); + assert_eq!(file.len(), 252); { let fast_field_readers = FastFieldReaders::open(file, schema).unwrap(); @@ -773,7 +773,7 @@ mod tests { write.terminate().unwrap(); } let file = directory.open_read(path).unwrap(); - assert_eq!(file.len(), 102); + assert_eq!(file.len(), 84); let fast_field_readers = FastFieldReaders::open(file, schema).unwrap(); let bool_col = fast_field_readers.bool("field_bool").unwrap(); assert_eq!(bool_col.first(0), Some(true)); @@ -805,7 +805,7 @@ mod tests { write.terminate().unwrap(); } let file = directory.open_read(path).unwrap(); - assert_eq!(file.len(), 114); + assert_eq!(file.len(), 96); let readers = FastFieldReaders::open(file, schema).unwrap(); let bool_col = readers.bool("field_bool").unwrap(); for i in 0..25 { @@ -830,7 +830,7 @@ mod tests { write.terminate().unwrap(); } let file = directory.open_read(path).unwrap(); - assert_eq!(file.len(), 104); + assert_eq!(file.len(), 86); let fastfield_readers = FastFieldReaders::open(file, schema).unwrap(); let col = fastfield_readers.bool("field_bool").unwrap(); assert_eq!(col.first(0), None); diff --git a/src/functional_test.rs b/src/functional_test.rs index f69136b2d..aac54fe87 100644 --- a/src/functional_test.rs +++ b/src/functional_test.rs @@ -1,3 +1,5 @@ +#![allow(deprecated)] // Remove with index sorting + use std::collections::HashSet; use rand::{thread_rng, Rng}; diff --git a/src/core/index.rs b/src/index/index.rs similarity index 95% rename from src/core/index.rs rename to src/index/index.rs index efdff505b..e02212cd5 100644 --- a/src/core/index.rs +++ b/src/index/index.rs @@ -6,24 +6,23 @@ use std::path::PathBuf; use std::sync::Arc; use super::segment::Segment; -use super::IndexSettings; -use 
crate::core::single_segment_index_writer::SingleSegmentIndexWriter; -use crate::core::{ - Executor, IndexMeta, SegmentId, SegmentMeta, SegmentMetaInventory, META_FILEPATH, -}; +use super::segment_reader::merge_field_meta_data; +use super::{FieldMetadata, IndexSettings}; +use crate::core::{Executor, META_FILEPATH}; use crate::directory::error::OpenReadError; #[cfg(feature = "mmap")] use crate::directory::MmapDirectory; use crate::directory::{Directory, ManagedDirectory, RamDirectory, INDEX_WRITER_LOCK}; use crate::error::{DataCorruption, TantivyError}; +use crate::index::{IndexMeta, SegmentId, SegmentMeta, SegmentMetaInventory}; use crate::indexer::index_writer::{MAX_NUM_THREAD, MEMORY_BUDGET_NUM_BYTES_MIN}; use crate::indexer::segment_updater::save_metas; -use crate::indexer::IndexWriter; +use crate::indexer::{IndexWriter, SingleSegmentIndexWriter}; use crate::reader::{IndexReader, IndexReaderBuilder}; use crate::schema::document::Document; -use crate::schema::{Field, FieldType, Schema}; +use crate::schema::{Field, FieldType, Schema, Type}; use crate::tokenizer::{TextAnalyzer, TokenizerManager}; -use crate::{merge_field_meta_data, FieldMetadata, SegmentReader}; +use crate::SegmentReader; fn load_metas( directory: &dyn Directory, @@ -84,7 +83,7 @@ fn save_new_metas( /// /// ``` /// use tantivy::schema::*; -/// use tantivy::{Index, IndexSettings, IndexSortByField, Order}; +/// use tantivy::{Index, IndexSettings}; /// /// let mut schema_builder = Schema::builder(); /// let id_field = schema_builder.add_text_field("id", STRING); @@ -97,10 +96,7 @@ fn save_new_metas( /// /// let schema = schema_builder.build(); /// let settings = IndexSettings{ -/// sort_by_field: Some(IndexSortByField{ -/// field: "number".to_string(), -/// order: Order::Asc -/// }), +/// docstore_blocksize: 100_000, /// ..Default::default() /// }; /// let index = Index::builder().schema(schema).settings(settings).create_in_ram(); @@ -252,6 +248,15 @@ impl IndexBuilder { sort_by_field.field ))); } + 
let supported_field_types = [Type::I64, Type::U64, Type::F64, Type::Date]; + let field_type = entry.field_type().value_type(); + if !supported_field_types.contains(&field_type) { + return Err(TantivyError::InvalidArgument(format!( + "Unsupported field type in sort_by_field: {:?}. Supported field types: \ + {:?} ", + field_type, supported_field_types, + ))); + } } Ok(()) } else { @@ -323,6 +328,15 @@ impl Index { Ok(()) } + /// Custom thread pool by a outer thread pool. + pub fn set_shared_multithread_executor( + &mut self, + shared_thread_pool: Arc<Executor>, + ) -> crate::Result<()> { + self.executor = shared_thread_pool.clone(); + Ok(()) + } + /// Replace the default single thread search executor pool /// by a thread pool with as many threads as there are CPUs on the system. pub fn set_default_multithread_executor(&mut self) -> crate::Result<()> { diff --git a/src/core/index_meta.rs b/src/index/index_meta.rs similarity index 98% rename from src/core/index_meta.rs rename to src/index/index_meta.rs index 0ed61e2a6..f478ac2fd 100644 --- a/src/core/index_meta.rs +++ b/src/index/index_meta.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use serde::{Deserialize, Serialize}; use super::SegmentComponent; -use crate::core::SegmentId; +use crate::index::SegmentId; use crate::schema::Schema; use crate::store::Compressor; use crate::{Inventory, Opstamp, TrackedObject}; @@ -19,7 +19,7 @@ struct DeleteMeta { } #[derive(Clone, Default)] -pub struct SegmentMetaInventory { +pub(crate) struct SegmentMetaInventory { inventory: Inventory<InnerSegmentMeta>, } @@ -288,6 +288,10 @@ impl Default for IndexSettings { /// Presorting documents can greatly improve performance /// in some scenarios, by applying top n /// optimizations. +#[deprecated( + since = "0.22.0", + note = "We plan to remove index sorting in `0.23`. If you need index sorting, please comment on the related issue https://github.com/quickwit-oss/tantivy/issues/2352 and explain your use case." 
+)] #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] pub struct IndexSortByField { /// The field to sort the documents by @@ -408,7 +412,7 @@ impl fmt::Debug for IndexMeta { mod tests { use super::IndexMeta; - use crate::core::index_meta::UntrackedIndexMeta; + use crate::index::index_meta::UntrackedIndexMeta; use crate::schema::{Schema, TEXT}; use crate::store::Compressor; #[cfg(feature = "zstd-compression")] diff --git a/src/core/inverted_index_reader.rs b/src/index/inverted_index_reader.rs similarity index 96% rename from src/core/inverted_index_reader.rs rename to src/index/inverted_index_reader.rs index 059ec988c..c5df591bf 100644 --- a/src/core/inverted_index_reader.rs +++ b/src/index/inverted_index_reader.rs @@ -1,12 +1,13 @@ use std::io; +use common::json_path_writer::JSON_END_OF_PATH; use common::BinarySerializable; use fnv::FnvHashSet; use crate::directory::FileSlice; use crate::positions::PositionReader; use crate::postings::{BlockSegmentPostings, SegmentPostings, TermInfo}; -use crate::schema::{IndexRecordOption, Term, Type, JSON_END_OF_PATH}; +use crate::schema::{IndexRecordOption, Term, Type}; use crate::termdict::TermDictionary; /// The inverted index reader is in charge of accessing @@ -266,7 +267,9 @@ impl InvertedIndexReader { /// Warmup a block postings given a `Term`. /// This method is for an advanced usage only. - pub async fn warm_postings(&self, term: &Term, with_positions: bool) -> io::Result<()> { + /// + /// returns a boolean, whether the term was found in the dictionary + pub async fn warm_postings(&self, term: &Term, with_positions: bool) -> io::Result<bool> { let term_info_opt: Option<TermInfo> = self.get_term_info_async(term).await?; if let Some(term_info) = term_info_opt { let postings = self @@ -280,23 +283,27 @@ impl InvertedIndexReader { } else { postings.await?; } + Ok(true) + } else { + Ok(false) } - Ok(()) } /// Warmup a block postings given a range of `Term`s. /// This method is for an advanced usage only. 
+ /// + /// returns a boolean, whether a term matching the range was found in the dictionary pub async fn warm_postings_range( &self, terms: impl std::ops::RangeBounds<Term>, limit: Option<u64>, with_positions: bool, - ) -> io::Result<()> { + ) -> io::Result<bool> { let mut term_info = self.get_term_range_async(terms, limit).await?; let Some(first_terminfo) = term_info.next() else { // no key matches, nothing more to load - return Ok(()); + return Ok(false); }; let last_terminfo = term_info.last().unwrap_or_else(|| first_terminfo.clone()); @@ -316,7 +323,7 @@ impl InvertedIndexReader { } else { postings.await?; } - Ok(()) + Ok(true) } /// Warmup the block postings for all terms. diff --git a/src/index/mod.rs b/src/index/mod.rs new file mode 100644 index 000000000..89e71d2c8 --- /dev/null +++ b/src/index/mod.rs @@ -0,0 +1,22 @@ +//! # Index Module +//! +//! The `index` module in Tantivy contains core components to read and write indexes. +//! +//! It contains `Index` and `Segment`, where a `Index` consists of one or more `Segment`s. 
+ +mod index; +mod index_meta; +mod inverted_index_reader; +mod segment; +mod segment_component; +mod segment_id; +mod segment_reader; + +pub use self::index::{Index, IndexBuilder}; +pub(crate) use self::index_meta::SegmentMetaInventory; +pub use self::index_meta::{IndexMeta, IndexSettings, IndexSortByField, Order, SegmentMeta}; +pub use self::inverted_index_reader::InvertedIndexReader; +pub use self::segment::Segment; +pub use self::segment_component::SegmentComponent; +pub use self::segment_id::SegmentId; +pub use self::segment_reader::{FieldMetadata, SegmentReader}; diff --git a/src/core/segment.rs b/src/index/segment.rs similarity index 98% rename from src/core/segment.rs rename to src/index/segment.rs index 21cf2d691..4c9382cb0 100644 --- a/src/core/segment.rs +++ b/src/index/segment.rs @@ -2,9 +2,9 @@ use std::fmt; use std::path::PathBuf; use super::SegmentComponent; -use crate::core::{Index, SegmentId, SegmentMeta}; use crate::directory::error::{OpenReadError, OpenWriteError}; use crate::directory::{Directory, FileSlice, WritePtr}; +use crate::index::{Index, SegmentId, SegmentMeta}; use crate::schema::Schema; use crate::Opstamp; diff --git a/src/core/segment_component.rs b/src/index/segment_component.rs similarity index 100% rename from src/core/segment_component.rs rename to src/index/segment_component.rs diff --git a/src/core/segment_id.rs b/src/index/segment_id.rs similarity index 99% rename from src/core/segment_id.rs rename to src/index/segment_id.rs index 5e2cf1b32..e66aa95a9 100644 --- a/src/core/segment_id.rs +++ b/src/index/segment_id.rs @@ -1,4 +1,4 @@ -use std::cmp::{Ord, Ordering}; +use std::cmp::Ordering; use std::error::Error; use std::fmt; use std::str::FromStr; diff --git a/src/core/segment_reader.rs b/src/index/segment_reader.rs similarity index 99% rename from src/core/segment_reader.rs rename to src/index/segment_reader.rs index cae1b537d..c86ee5906 100644 --- a/src/core/segment_reader.rs +++ b/src/index/segment_reader.rs @@ -6,11 +6,11 @@ 
use std::{fmt, io}; use fnv::FnvHashMap; use itertools::Itertools; -use crate::core::{InvertedIndexReader, Segment, SegmentComponent, SegmentId}; use crate::directory::{CompositeFile, FileSlice}; use crate::error::DataCorruption; use crate::fastfield::{intersect_alive_bitsets, AliveBitSet, FacetReader, FastFieldReaders}; use crate::fieldnorm::{FieldNormReader, FieldNormReaders}; +use crate::index::{InvertedIndexReader, Segment, SegmentComponent, SegmentId}; use crate::json_utils::json_path_sep_to_dot; use crate::schema::{Field, IndexRecordOption, Schema, Type}; use crate::space_usage::SegmentSpaceUsage; @@ -406,7 +406,7 @@ impl SegmentReader { } /// Returns an iterator that will iterate over the alive document ids - pub fn doc_ids_alive(&self) -> Box<dyn Iterator<Item = DocId> + '_> { + pub fn doc_ids_alive(&self) -> Box<dyn Iterator<Item = DocId> + Send + '_> { if let Some(alive_bitset) = &self.alive_bitset_opt { Box::new(alive_bitset.iter_alive()) } else { @@ -515,9 +515,9 @@ impl fmt::Debug for SegmentReader { #[cfg(test)] mod test { use super::*; - use crate::core::Index; - use crate::schema::{Schema, SchemaBuilder, Term, STORED, TEXT}; - use crate::{DocId, FieldMetadata, IndexWriter}; + use crate::index::Index; + use crate::schema::{SchemaBuilder, Term, STORED, TEXT}; + use crate::IndexWriter; #[test] fn test_merge_field_meta_data_same() { diff --git a/src/indexer/doc_id_mapping.rs b/src/indexer/doc_id_mapping.rs index 0fad45eb1..63460eda3 100644 --- a/src/indexer/doc_id_mapping.rs +++ b/src/indexer/doc_id_mapping.rs @@ -158,9 +158,8 @@ mod tests_indexsorting { use crate::indexer::doc_id_mapping::DocIdMapping; use crate::indexer::NoMergePolicy; use crate::query::QueryParser; - use crate::schema::document::Value; - use crate::schema::{Schema, *}; - use crate::{DocAddress, Index, IndexSettings, IndexSortByField, Order}; + use crate::schema::*; + use crate::{DocAddress, Index, IndexBuilder, IndexSettings, IndexSortByField, Order}; fn create_test_index( 
index_settings: Option<IndexSettings>, @@ -558,4 +557,28 @@ mod tests_indexsorting { &[2000, 8000, 3000] ); } + + #[test] + fn test_text_sort() -> crate::Result<()> { + let mut schema_builder = SchemaBuilder::new(); + schema_builder.add_text_field("id", STRING | FAST | STORED); + schema_builder.add_text_field("name", TEXT | STORED); + + let resp = IndexBuilder::new() + .schema(schema_builder.build()) + .settings(IndexSettings { + sort_by_field: Some(IndexSortByField { + field: "id".to_string(), + order: Order::Asc, + }), + ..Default::default() + }) + .create_in_ram(); + assert!(resp + .unwrap_err() + .to_string() + .contains("Unsupported field type")); + + Ok(()) + } } diff --git a/src/indexer/flat_map_with_buffer.rs b/src/indexer/flat_map_with_buffer.rs index 88b509cdb..9f2a1924e 100644 --- a/src/indexer/flat_map_with_buffer.rs +++ b/src/indexer/flat_map_with_buffer.rs @@ -22,6 +22,7 @@ where } } +#[allow(dead_code)] pub trait FlatMapWithBufferIter: Iterator { /// Function similar to `flat_map`, but allows reusing a shared `Vec`. 
fn flat_map_with_buffer<F, T>(self, fill_buffer: F) -> FlatMapWithBuffer<T, F, Self> diff --git a/src/indexer/index_writer.rs b/src/indexer/index_writer.rs index 2323806d1..c13853bc0 100644 --- a/src/indexer/index_writer.rs +++ b/src/indexer/index_writer.rs @@ -9,10 +9,10 @@ use smallvec::smallvec; use super::operation::{AddOperation, UserOperation}; use super::segment_updater::SegmentUpdater; use super::{AddBatch, AddBatchReceiver, AddBatchSender, PreparedCommit}; -use crate::core::{Index, Segment, SegmentComponent, SegmentId, SegmentMeta, SegmentReader}; use crate::directory::{DirectoryLock, GarbageCollectionResult, TerminatingWrite}; use crate::error::TantivyError; use crate::fastfield::write_alive_bitset; +use crate::index::{Index, Segment, SegmentComponent, SegmentId, SegmentMeta, SegmentReader}; use crate::indexer::delete_queue::{DeleteCursor, DeleteQueue}; use crate::indexer::doc_opstamp_mapping::DocToOpstampMapping; use crate::indexer::index_writer_status::IndexWriterStatus; @@ -806,7 +806,6 @@ mod tests { use columnar::{Cardinality, Column, MonotonicallyMappableToU128}; use itertools::Itertools; use proptest::prop_oneof; - use proptest::strategy::Strategy; use super::super::operation::UserOperation; use crate::collector::TopDocs; @@ -1651,6 +1650,7 @@ mod tests { force_end_merge: bool, ) -> crate::Result<Index> { let mut schema_builder = schema::Schema::builder(); + let json_field = schema_builder.add_json_field("json", FAST | TEXT | STORED); let ip_field = schema_builder.add_ip_addr_field("ip", FAST | INDEXED | STORED); let ips_field = schema_builder .add_ip_addr_field("ips", IpAddrOptions::default().set_fast().set_indexed()); @@ -1729,7 +1729,9 @@ mod tests { id_field=>id, ))?; } else { + let json = json!({"date1": format!("2022-{id}-01T00:00:01Z"), "date2": format!("{id}-05-01T00:00:01Z"), "id": id, "ip": ip.to_string()}); index_writer.add_document(doc!(id_field=>id, + json_field=>json, bytes_field => id.to_le_bytes().as_slice(), id_opt_field => id, 
ip_field => ip, diff --git a/src/indexer/log_merge_policy.rs b/src/indexer/log_merge_policy.rs index b7ee34dcc..726deb578 100644 --- a/src/indexer/log_merge_policy.rs +++ b/src/indexer/log_merge_policy.rs @@ -3,7 +3,7 @@ use std::cmp; use itertools::Itertools; use super::merge_policy::{MergeCandidate, MergePolicy}; -use crate::core::SegmentMeta; +use crate::index::SegmentMeta; const DEFAULT_LEVEL_LOG_SIZE: f64 = 0.75; const DEFAULT_MIN_LAYER_SIZE: u32 = 10_000; @@ -144,10 +144,9 @@ mod tests { use once_cell::sync::Lazy; use super::*; - use crate::core::{SegmentId, SegmentMeta, SegmentMetaInventory}; - use crate::indexer::merge_policy::MergePolicy; - use crate::schema; + use crate::index::SegmentMetaInventory; use crate::schema::INDEXED; + use crate::{schema, SegmentId}; static INVENTORY: Lazy<SegmentMetaInventory> = Lazy::new(SegmentMetaInventory::default); diff --git a/src/indexer/merge_policy.rs b/src/indexer/merge_policy.rs index 1e4503f97..4215caaac 100644 --- a/src/indexer/merge_policy.rs +++ b/src/indexer/merge_policy.rs @@ -1,7 +1,7 @@ use std::fmt::Debug; use std::marker; -use crate::core::{SegmentId, SegmentMeta}; +use crate::index::{SegmentId, SegmentMeta}; /// Set of segment suggested for a merge. #[derive(Debug, Clone)] @@ -39,7 +39,6 @@ impl MergePolicy for NoMergePolicy { pub mod tests { use super::*; - use crate::core::{SegmentId, SegmentMeta}; /// `MergePolicy` useful for test purposes. 
/// diff --git a/src/indexer/merger.rs b/src/indexer/merger.rs index 87bc4c8c8..4cc455713 100644 --- a/src/indexer/merger.rs +++ b/src/indexer/merger.rs @@ -8,12 +8,12 @@ use common::ReadOnlyBitSet; use itertools::Itertools; use measure_time::debug_time; -use crate::core::{Segment, SegmentReader}; use crate::directory::WritePtr; use crate::docset::{DocSet, TERMINATED}; use crate::error::DataCorruption; use crate::fastfield::{AliveBitSet, FastFieldNotAvailableError}; use crate::fieldnorm::{FieldNormReader, FieldNormReaders, FieldNormsSerializer, FieldNormsWriter}; +use crate::index::{Segment, SegmentReader}; use crate::indexer::doc_id_mapping::{MappingType, SegmentDocIdMapping}; use crate::indexer::SegmentSerializer; use crate::postings::{InvertedIndexSerializer, Postings, SegmentPostings}; @@ -576,7 +576,7 @@ impl IndexMerger { // // Overall the reliable way to know if we have actual frequencies loaded or not // is to check whether the actual decoded array is empty or not. - if has_term_freq != !postings.block_cursor.freqs().is_empty() { + if has_term_freq == postings.block_cursor.freqs().is_empty() { return Err(DataCorruption::comment_only( "Term freqs are inconsistent across segments", ) @@ -605,6 +605,10 @@ impl IndexMerger { segment_postings.positions(&mut positions_buffer); segment_postings.term_freq() } else { + // The positions_buffer may contain positions from the previous term + // Existence of positions depend on the value type in JSON fields. 
+ // https://github.com/quickwit-oss/tantivy/issues/2283 + positions_buffer.clear(); 0u32 }; @@ -790,7 +794,7 @@ mod tests { BytesFastFieldTestCollector, FastFieldTestCollector, TEST_COLLECTOR_WITH_SCORE, }; use crate::collector::{Count, FacetCollector}; - use crate::core::Index; + use crate::index::Index; use crate::query::{AllQuery, BooleanQuery, EnableScoring, Scorer, TermQuery}; use crate::schema::document::Value; use crate::schema::{ diff --git a/src/indexer/merger_sorted_index_test.rs b/src/indexer/merger_sorted_index_test.rs index d27eab4ed..3b256a634 100644 --- a/src/indexer/merger_sorted_index_test.rs +++ b/src/indexer/merger_sorted_index_test.rs @@ -1,8 +1,8 @@ #[cfg(test)] mod tests { use crate::collector::TopDocs; - use crate::core::Index; use crate::fastfield::AliveBitSet; + use crate::index::Index; use crate::query::QueryParser; use crate::schema::document::Value; use crate::schema::{ @@ -485,7 +485,7 @@ mod bench_sorted_index_merge { use test::{self, Bencher}; - use crate::core::Index; + use crate::index::Index; use crate::indexer::merger::IndexMerger; use crate::schema::{NumericOptions, Schema}; use crate::{IndexSettings, IndexSortByField, IndexWriter, Order}; diff --git a/src/indexer/mod.rs b/src/indexer/mod.rs index 204ce134b..692e2c108 100644 --- a/src/indexer/mod.rs +++ b/src/indexer/mod.rs @@ -25,6 +25,7 @@ mod segment_register; pub(crate) mod segment_serializer; pub(crate) mod segment_updater; pub(crate) mod segment_writer; +pub(crate) mod single_segment_index_writer; mod stamper; use crossbeam_channel as channel; @@ -34,13 +35,14 @@ pub use self::index_writer::IndexWriter; pub use self::log_merge_policy::LogMergePolicy; pub use self::merge_operation::MergeOperation; pub use self::merge_policy::{MergeCandidate, MergePolicy, NoMergePolicy}; +use self::operation::AddOperation; pub use self::operation::UserOperation; pub use self::prepared_commit::PreparedCommit; pub use self::segment_entry::SegmentEntry; pub(crate) use 
self::segment_serializer::SegmentSerializer; pub use self::segment_updater::{merge_filtered_segments, merge_indices}; pub use self::segment_writer::SegmentWriter; -use crate::indexer::operation::AddOperation; +pub use self::single_segment_index_writer::SingleSegmentIndexWriter; /// Alias for the default merge policy, which is the `LogMergePolicy`. pub type DefaultMergePolicy = LogMergePolicy; @@ -63,9 +65,10 @@ mod tests_mmap { use crate::aggregation::agg_result::AggregationResults; use crate::aggregation::AggregationCollector; use crate::collector::{Count, TopDocs}; + use crate::index::FieldMetadata; use crate::query::{AllQuery, QueryParser}; use crate::schema::{JsonObjectOptions, Schema, Type, FAST, INDEXED, STORED, TEXT}; - use crate::{FieldMetadata, Index, IndexWriter, Term}; + use crate::{Index, IndexWriter, Term}; #[test] fn test_advance_delete_bug() -> crate::Result<()> { @@ -141,6 +144,123 @@ mod tests_mmap { assert_eq!(num_docs, 256); } } + #[test] + fn test_json_field_null_byte() { + // Test when field name contains a zero byte, which has special meaning in tantivy. + // As a workaround, we convert the zero byte to the ASCII character '0'. + // https://github.com/quickwit-oss/tantivy/issues/2340 + // https://github.com/quickwit-oss/tantivy/issues/2193 + let field_name_in = "\u{0000}"; + let field_name_out = "0"; + test_json_field_name(field_name_in, field_name_out); + } + #[test] + fn test_json_field_1byte() { + // Test when field name contains a '1' byte, which has special meaning in tantivy. + // The 1 byte can be addressed as '1' byte or '.'. + let field_name_in = "\u{0001}"; + let field_name_out = "\u{0001}"; + test_json_field_name(field_name_in, field_name_out); + + // Test when field name contains a '1' byte, which has special meaning in tantivy. + let field_name_in = "\u{0001}"; + let field_name_out = "."; + test_json_field_name(field_name_in, field_name_out); + } + #[test] + fn test_json_field_dot() { + // Test when field name contains a '.' 
+ let field_name_in = "."; + let field_name_out = "."; + test_json_field_name(field_name_in, field_name_out); + } + fn test_json_field_name(field_name_in: &str, field_name_out: &str) { + let mut schema_builder = Schema::builder(); + + let options = JsonObjectOptions::from(TEXT | FAST).set_expand_dots_enabled(); + let field = schema_builder.add_json_field("json", options); + let index = Index::create_in_ram(schema_builder.build()); + let mut index_writer = index.writer_for_tests().unwrap(); + index_writer + .add_document(doc!(field=>json!({format!("{field_name_in}"): "test1"}))) + .unwrap(); + index_writer + .add_document(doc!(field=>json!({format!("a{field_name_in}"): "test2"}))) + .unwrap(); + index_writer + .add_document(doc!(field=>json!({format!("a{field_name_in}a"): "test3"}))) + .unwrap(); + index_writer + .add_document( + doc!(field=>json!({format!("a{field_name_in}a{field_name_in}"): "test4"})), + ) + .unwrap(); + index_writer + .add_document( + doc!(field=>json!({format!("a{field_name_in}.ab{field_name_in}"): "test5"})), + ) + .unwrap(); + index_writer + .add_document( + doc!(field=>json!({format!("a{field_name_in}"): json!({format!("a{field_name_in}"): "test6"}) })), + ) + .unwrap(); + index_writer + .add_document(doc!(field=>json!({format!("{field_name_in}a" ): "test7"}))) + .unwrap(); + + index_writer.commit().unwrap(); + let reader = index.reader().unwrap(); + let searcher = reader.searcher(); + let parse_query = QueryParser::for_index(&index, Vec::new()); + let test_query = |query_str: &str| { + let query = parse_query.parse_query(query_str).unwrap(); + let num_docs = searcher.search(&query, &Count).unwrap(); + assert_eq!(num_docs, 1, "{}", query_str); + }; + test_query(format!("json.{field_name_out}:test1").as_str()); + test_query(format!("json.a{field_name_out}:test2").as_str()); + test_query(format!("json.a{field_name_out}a:test3").as_str()); + test_query(format!("json.a{field_name_out}a{field_name_out}:test4").as_str()); + 
test_query(format!("json.a{field_name_out}.ab{field_name_out}:test5").as_str()); + test_query(format!("json.a{field_name_out}.a{field_name_out}:test6").as_str()); + test_query(format!("json.{field_name_out}a:test7").as_str()); + + let test_agg = |field_name: &str, expected: &str| { + let agg_req_str = json!( + { + "termagg": { + "terms": { + "field": field_name, + } + } + }); + + let agg_req: Aggregations = serde_json::from_value(agg_req_str).unwrap(); + let collector = AggregationCollector::from_aggs(agg_req, Default::default()); + let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); + let res = serde_json::to_value(agg_res).unwrap(); + assert_eq!(res["termagg"]["buckets"][0]["doc_count"], 1); + assert_eq!(res["termagg"]["buckets"][0]["key"], expected); + }; + + test_agg(format!("json.{field_name_out}").as_str(), "test1"); + test_agg(format!("json.a{field_name_out}").as_str(), "test2"); + test_agg(format!("json.a{field_name_out}a").as_str(), "test3"); + test_agg( + format!("json.a{field_name_out}a{field_name_out}").as_str(), + "test4", + ); + test_agg( + format!("json.a{field_name_out}.ab{field_name_out}").as_str(), + "test5", + ); + test_agg( + format!("json.a{field_name_out}.a{field_name_out}").as_str(), + "test6", + ); + test_agg(format!("json.{field_name_out}a").as_str(), "test7"); + } #[test] fn test_json_field_expand_dots_enabled_dot_escape_not_required() { @@ -403,11 +523,10 @@ mod tests_mmap { let searcher = reader.searcher(); - let fields_and_vals = vec![ - // Only way to address or it gets shadowed by `json.shadow` field + let fields_and_vals = [ ("json.shadow\u{1}val".to_string(), "a"), // Succeeds //("json.shadow.val".to_string(), "a"), // Fails - ("json.shadow.val".to_string(), "b"), // Succeeds + ("json.shadow.val".to_string(), "b"), ]; let query_parser = QueryParser::for_index(&index, vec![]); diff --git a/src/indexer/segment_entry.rs b/src/indexer/segment_entry.rs index 0e5002338..56fcf09b2 100644 --- 
a/src/indexer/segment_entry.rs +++ b/src/indexer/segment_entry.rs @@ -2,7 +2,7 @@ use std::fmt; use common::BitSet; -use crate::core::{SegmentId, SegmentMeta}; +use crate::index::{SegmentId, SegmentMeta}; use crate::indexer::delete_queue::DeleteCursor; /// A segment entry describes the state of diff --git a/src/indexer/segment_manager.rs b/src/indexer/segment_manager.rs index a65621514..810e10170 100644 --- a/src/indexer/segment_manager.rs +++ b/src/indexer/segment_manager.rs @@ -3,8 +3,8 @@ use std::fmt::{self, Debug, Formatter}; use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use super::segment_register::SegmentRegister; -use crate::core::{SegmentId, SegmentMeta}; use crate::error::TantivyError; +use crate::index::{SegmentId, SegmentMeta}; use crate::indexer::delete_queue::DeleteCursor; use crate::indexer::SegmentEntry; diff --git a/src/indexer/segment_register.rs b/src/indexer/segment_register.rs index 0068d598b..0e7046310 100644 --- a/src/indexer/segment_register.rs +++ b/src/indexer/segment_register.rs @@ -1,7 +1,7 @@ use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug, Display, Formatter}; -use crate::core::{SegmentId, SegmentMeta}; +use crate::index::{SegmentId, SegmentMeta}; use crate::indexer::delete_queue::DeleteCursor; use crate::indexer::segment_entry::SegmentEntry; @@ -103,7 +103,7 @@ impl SegmentRegister { #[cfg(test)] mod tests { use super::*; - use crate::core::{SegmentId, SegmentMetaInventory}; + use crate::index::SegmentMetaInventory; use crate::indexer::delete_queue::*; fn segment_ids(segment_register: &SegmentRegister) -> Vec<SegmentId> { diff --git a/src/indexer/segment_serializer.rs b/src/indexer/segment_serializer.rs index 6dcb442ff..cdb7a8ef5 100644 --- a/src/indexer/segment_serializer.rs +++ b/src/indexer/segment_serializer.rs @@ -1,8 +1,8 @@ use common::TerminatingWrite; -use crate::core::{Segment, SegmentComponent}; use crate::directory::WritePtr; use crate::fieldnorm::FieldNormsSerializer; +use 
crate::index::{Segment, SegmentComponent}; use crate::postings::InvertedIndexSerializer; use crate::store::StoreWriter; diff --git a/src/indexer/segment_updater.rs b/src/indexer/segment_updater.rs index e4d056d8c..12faba951 100644 --- a/src/indexer/segment_updater.rs +++ b/src/indexer/segment_updater.rs @@ -9,11 +9,10 @@ use std::sync::{Arc, RwLock}; use rayon::{ThreadPool, ThreadPoolBuilder}; use super::segment_manager::SegmentManager; -use crate::core::{ - Index, IndexMeta, IndexSettings, Segment, SegmentId, SegmentMeta, META_FILEPATH, -}; +use crate::core::META_FILEPATH; use crate::directory::{Directory, DirectoryClone, GarbageCollectionResult}; use crate::fastfield::AliveBitSet; +use crate::index::{Index, IndexMeta, IndexSettings, Segment, SegmentId, SegmentMeta}; use crate::indexer::delete_queue::DeleteCursor; use crate::indexer::index_writer::advance_deletes; use crate::indexer::merge_operation::MergeOperationInventory; diff --git a/src/indexer/segment_writer.rs b/src/indexer/segment_writer.rs index 1888f3b47..384a939e6 100644 --- a/src/indexer/segment_writer.rs +++ b/src/indexer/segment_writer.rs @@ -6,9 +6,9 @@ use tokenizer_api::BoxTokenStream; use super::doc_id_mapping::{get_doc_id_mapping_from_field, DocIdMapping}; use super::operation::AddOperation; use crate::core::json_utils::index_json_values; -use crate::core::Segment; use crate::fastfield::FastFieldsWriter; use crate::fieldnorm::{FieldNormReaders, FieldNormsWriter}; +use crate::index::Segment; use crate::indexer::segment_serializer::SegmentSerializer; use crate::postings::{ compute_table_memory_size, serialize_postings, IndexingContext, IndexingPosition, @@ -496,14 +496,14 @@ mod tests { use tempfile::TempDir; use crate::collector::{Count, TopDocs}; - use crate::core::json_utils::JsonTermWriter; use crate::directory::RamDirectory; + use crate::fastfield::FastValue; + use crate::json_utils::term_from_json_paths; use crate::postings::TermInfo; use crate::query::{PhraseQuery, QueryParser}; use 
crate::schema::document::Value; use crate::schema::{ - Document, IndexRecordOption, Schema, TextFieldIndexing, TextOptions, Type, STORED, STRING, - TEXT, + Document, IndexRecordOption, Schema, TextFieldIndexing, TextOptions, STORED, STRING, TEXT, }; use crate::store::{Compressor, StoreReader, StoreWriter}; use crate::time::format_description::well_known::Rfc3339; @@ -645,115 +645,117 @@ mod tests { let inv_idx = segment_reader.inverted_index(json_field).unwrap(); let term_dict = inv_idx.terms(); - let mut term = Term::with_type_and_field(Type::Json, json_field); let mut term_stream = term_dict.stream().unwrap(); - let mut json_term_writer = JsonTermWriter::wrap(&mut term, false); + let term_from_path = |paths: &[&str]| -> Term { + term_from_json_paths(json_field, paths.iter().cloned(), false) + }; - json_term_writer.push_path_segment("bool"); - json_term_writer.set_fast_value(true); + fn set_fast_val<T: FastValue>(val: T, mut term: Term) -> Term { + term.append_type_and_fast_value(val); + term + } + fn set_str(val: &str, mut term: Term) -> Term { + term.append_type_and_str(val); + term + } + + let term = term_from_path(&["bool"]); assert!(term_stream.advance()); assert_eq!( term_stream.key(), - json_term_writer.term().serialized_value_bytes() + set_fast_val(true, term).serialized_value_bytes() ); - json_term_writer.pop_path_segment(); - json_term_writer.push_path_segment("complexobject"); - json_term_writer.push_path_segment("field.with.dot"); - json_term_writer.set_fast_value(1i64); + let term = term_from_path(&["complexobject", "field.with.dot"]); + assert!(term_stream.advance()); assert_eq!( term_stream.key(), - json_term_writer.term().serialized_value_bytes() + set_fast_val(1i64, term).serialized_value_bytes() ); - json_term_writer.pop_path_segment(); - json_term_writer.pop_path_segment(); - json_term_writer.push_path_segment("date"); - json_term_writer.set_fast_value(DateTime::from_utc( - OffsetDateTime::parse("1985-04-12T23:20:50.52Z", &Rfc3339).unwrap(), - 
)); + // Date + let term = term_from_path(&["date"]); + assert!(term_stream.advance()); assert_eq!( term_stream.key(), - json_term_writer.term().serialized_value_bytes() + set_fast_val( + DateTime::from_utc( + OffsetDateTime::parse("1985-04-12T23:20:50.52Z", &Rfc3339).unwrap(), + ), + term + ) + .serialized_value_bytes() ); - json_term_writer.pop_path_segment(); - json_term_writer.push_path_segment("float"); - json_term_writer.set_fast_value(-0.2f64); + // Float + let term = term_from_path(&["float"]); assert!(term_stream.advance()); assert_eq!( term_stream.key(), - json_term_writer.term().serialized_value_bytes() + set_fast_val(-0.2f64, term).serialized_value_bytes() ); - json_term_writer.pop_path_segment(); - json_term_writer.push_path_segment("my_arr"); - json_term_writer.set_fast_value(2i64); + // Number In Array + let term = term_from_path(&["my_arr"]); assert!(term_stream.advance()); assert_eq!( term_stream.key(), - json_term_writer.term().serialized_value_bytes() + set_fast_val(2i64, term).serialized_value_bytes() ); - json_term_writer.set_fast_value(3i64); + let term = term_from_path(&["my_arr"]); assert!(term_stream.advance()); assert_eq!( term_stream.key(), - json_term_writer.term().serialized_value_bytes() + set_fast_val(3i64, term).serialized_value_bytes() ); - json_term_writer.set_fast_value(4i64); + let term = term_from_path(&["my_arr"]); assert!(term_stream.advance()); assert_eq!( term_stream.key(), - json_term_writer.term().serialized_value_bytes() + set_fast_val(4i64, term).serialized_value_bytes() ); - json_term_writer.push_path_segment("my_key"); - json_term_writer.set_str("tokens"); + // El in Array + let term = term_from_path(&["my_arr", "my_key"]); assert!(term_stream.advance()); assert_eq!( term_stream.key(), - json_term_writer.term().serialized_value_bytes() + set_str("tokens", term).serialized_value_bytes() ); - - json_term_writer.set_str("two"); + let term = term_from_path(&["my_arr", "my_key"]); assert!(term_stream.advance()); assert_eq!( 
term_stream.key(), - json_term_writer.term().serialized_value_bytes() + set_str("two", term).serialized_value_bytes() ); - json_term_writer.pop_path_segment(); - json_term_writer.pop_path_segment(); - json_term_writer.push_path_segment("signed"); - json_term_writer.set_fast_value(-2i64); + // Signed + let term = term_from_path(&["signed"]); assert!(term_stream.advance()); assert_eq!( term_stream.key(), - json_term_writer.term().serialized_value_bytes() + set_fast_val(-2i64, term).serialized_value_bytes() ); - json_term_writer.pop_path_segment(); - json_term_writer.push_path_segment("toto"); - json_term_writer.set_str("titi"); + let term = term_from_path(&["toto"]); assert!(term_stream.advance()); assert_eq!( term_stream.key(), - json_term_writer.term().serialized_value_bytes() + set_str("titi", term).serialized_value_bytes() ); - - json_term_writer.pop_path_segment(); - json_term_writer.push_path_segment("unsigned"); - json_term_writer.set_fast_value(1i64); + // Unsigned + let term = term_from_path(&["unsigned"]); assert!(term_stream.advance()); assert_eq!( term_stream.key(), - json_term_writer.term().serialized_value_bytes() + set_fast_val(1i64, term).serialized_value_bytes() ); + assert!(!term_stream.advance()); } @@ -774,14 +776,9 @@ mod tests { let searcher = reader.searcher(); let segment_reader = searcher.segment_reader(0u32); let inv_index = segment_reader.inverted_index(json_field).unwrap(); - let mut term = Term::with_type_and_field(Type::Json, json_field); - let mut json_term_writer = JsonTermWriter::wrap(&mut term, false); - json_term_writer.push_path_segment("mykey"); - json_term_writer.set_str("token"); - let term_info = inv_index - .get_term_info(json_term_writer.term()) - .unwrap() - .unwrap(); + let mut term = term_from_json_paths(json_field, ["mykey"].into_iter(), false); + term.append_type_and_str("token"); + let term_info = inv_index.get_term_info(&term).unwrap().unwrap(); assert_eq!( term_info, TermInfo { @@ -818,14 +815,9 @@ mod tests { let 
searcher = reader.searcher(); let segment_reader = searcher.segment_reader(0u32); let inv_index = segment_reader.inverted_index(json_field).unwrap(); - let mut term = Term::with_type_and_field(Type::Json, json_field); - let mut json_term_writer = JsonTermWriter::wrap(&mut term, false); - json_term_writer.push_path_segment("mykey"); - json_term_writer.set_str("two tokens"); - let term_info = inv_index - .get_term_info(json_term_writer.term()) - .unwrap() - .unwrap(); + let mut term = term_from_json_paths(json_field, ["mykey"].into_iter(), false); + term.append_type_and_str("two tokens"); + let term_info = inv_index.get_term_info(&term).unwrap().unwrap(); assert_eq!( term_info, TermInfo { @@ -863,22 +855,49 @@ mod tests { writer.commit().unwrap(); let reader = index.reader().unwrap(); let searcher = reader.searcher(); - let mut term = Term::with_type_and_field(Type::Json, json_field); - let mut json_term_writer = JsonTermWriter::wrap(&mut term, false); - json_term_writer.push_path_segment("mykey"); - json_term_writer.push_path_segment("field"); - json_term_writer.set_str("hello"); - let hello_term = json_term_writer.term().clone(); - json_term_writer.set_str("nothello"); - let nothello_term = json_term_writer.term().clone(); - json_term_writer.set_str("happy"); - let happy_term = json_term_writer.term().clone(); + + let term = term_from_json_paths(json_field, ["mykey", "field"].into_iter(), false); + + let mut hello_term = term.clone(); + hello_term.append_type_and_str("hello"); + + let mut nothello_term = term.clone(); + nothello_term.append_type_and_str("nothello"); + + let mut happy_term = term.clone(); + happy_term.append_type_and_str("happy"); + let phrase_query = PhraseQuery::new(vec![hello_term, happy_term.clone()]); assert_eq!(searcher.search(&phrase_query, &Count).unwrap(), 1); let phrase_query = PhraseQuery::new(vec![nothello_term, happy_term]); assert_eq!(searcher.search(&phrase_query, &Count).unwrap(), 0); } + #[test] + fn 
test_json_term_with_numeric_merge_panic_regression_bug_2283() { + // https://github.com/quickwit-oss/tantivy/issues/2283 + let mut schema_builder = Schema::builder(); + let json = schema_builder.add_json_field("json", TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut writer = index.writer_for_tests().unwrap(); + let doc = json!({"field": "a"}); + writer.add_document(doc!(json=>doc)).unwrap(); + writer.commit().unwrap(); + let doc = json!({"field": "a", "id": 1}); + writer.add_document(doc!(json=>doc.clone())).unwrap(); + writer.commit().unwrap(); + + // Force Merge + writer.wait_merging_threads().unwrap(); + let mut index_writer: IndexWriter = index.writer_for_tests().unwrap(); + let segment_ids = index + .searchable_segment_ids() + .expect("Searchable segments failed."); + index_writer.merge(&segment_ids).wait().unwrap(); + assert!(index_writer.wait_merging_threads().is_ok()); + } + #[test] fn test_bug_regression_1629_position_when_array_with_a_field_value_that_does_not_contain_any_token( ) { diff --git a/src/core/single_segment_index_writer.rs b/src/indexer/single_segment_index_writer.rs similarity index 100% rename from src/core/single_segment_index_writer.rs rename to src/indexer/single_segment_index_writer.rs diff --git a/src/lib.rs b/src/lib.rs index d06c5717a..b86a88414 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -178,6 +178,7 @@ pub use crate::future_result::FutureResult; pub type Result<T> = std::result::Result<T, TantivyError>; mod core; +#[allow(deprecated)] // Remove with index sorting pub mod indexer; #[allow(unused_doc_comments)] @@ -189,6 +190,8 @@ pub mod collector; pub mod directory; pub mod fastfield; pub mod fieldnorm; +#[allow(deprecated)] // Remove with index sorting +pub mod index; pub mod positions; pub mod postings; @@ -212,7 +215,7 @@ pub use common::{f64_to_u64, i64_to_u64, u64_to_f64, u64_to_i64, HasLen}; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; -pub use 
self::docset::{DocSet, TERMINATED}; +pub use self::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN, TERMINATED}; #[deprecated( since = "0.22.0", note = "Will be removed in tantivy 0.23. Use export from snippet module instead" @@ -220,21 +223,20 @@ pub use self::docset::{DocSet, TERMINATED}; pub use self::snippet::{Snippet, SnippetGenerator}; #[doc(hidden)] pub use crate::core::json_utils; -pub use crate::core::{ - merge_field_meta_data, Executor, FieldMetadata, Index, IndexBuilder, IndexMeta, IndexSettings, - IndexSortByField, InvertedIndexReader, Order, Searcher, SearcherGeneration, Segment, - SegmentComponent, SegmentId, SegmentMeta, SegmentReader, SingleSegmentIndexWriter, -}; +pub use crate::core::{Executor, Searcher, SearcherGeneration}; pub use crate::directory::Directory; -pub use crate::indexer::IndexWriter; +#[allow(deprecated)] // Remove with index sorting +pub use crate::index::{ + Index, IndexBuilder, IndexMeta, IndexSettings, IndexSortByField, InvertedIndexReader, Order, + Segment, SegmentComponent, SegmentId, SegmentMeta, SegmentReader, +}; #[deprecated( since = "0.22.0", note = "Will be removed in tantivy 0.23. Use export from indexer module instead" )] -pub use crate::indexer::{merge_filtered_segments, merge_indices, PreparedCommit}; +pub use crate::indexer::PreparedCommit; +pub use crate::indexer::{IndexWriter, SingleSegmentIndexWriter}; pub use crate::postings::Postings; -#[allow(deprecated)] -pub use crate::schema::DatePrecision; pub use crate::schema::{DateOptions, DateTimePrecision, Document, TantivyDocument, Term}; /// Index format version. 
@@ -253,7 +255,7 @@ pub struct Version { impl fmt::Debug for Version { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.to_string()) + fmt::Display::fmt(self, f) } } @@ -264,9 +266,10 @@ static VERSION: Lazy<Version> = Lazy::new(|| Version { index_format_version: INDEX_FORMAT_VERSION, }); -impl ToString for Version { - fn to_string(&self) -> String { - format!( +impl fmt::Display for Version { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, "tantivy v{}.{}.{}, index_format v{}", self.major, self.minor, self.patch, self.index_format_version ) @@ -338,7 +341,7 @@ impl DocAddress { /// /// The id used for the segment is actually an ordinal /// in the list of `Segment`s held by a `Searcher`. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub struct DocAddress { /// The segment ordinal id that identifies the segment /// hosting the document in the `Searcher` it is called from. 
@@ -386,11 +389,10 @@ pub mod tests { use time::OffsetDateTime; use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE; - use crate::core::SegmentReader; use crate::docset::{DocSet, TERMINATED}; + use crate::index::SegmentReader; use crate::merge_policy::NoMergePolicy; use crate::query::BooleanQuery; - use crate::schema::document::Value; use crate::schema::*; use crate::{DateTime, DocAddress, Index, IndexWriter, Postings, ReloadPolicy}; diff --git a/src/postings/block_segment_postings.rs b/src/postings/block_segment_postings.rs index 366809a95..d8d2735b7 100644 --- a/src/postings/block_segment_postings.rs +++ b/src/postings/block_segment_postings.rs @@ -383,8 +383,8 @@ mod tests { use common::HasLen; use super::BlockSegmentPostings; - use crate::core::Index; use crate::docset::{DocSet, TERMINATED}; + use crate::index::Index; use crate::postings::compression::COMPRESSION_BLOCK_SIZE; use crate::postings::postings::Postings; use crate::postings::SegmentPostings; diff --git a/src/postings/compression/mod.rs b/src/postings/compression/mod.rs index f8a8a3193..3928be51b 100644 --- a/src/postings/compression/mod.rs +++ b/src/postings/compression/mod.rs @@ -14,7 +14,6 @@ pub fn compressed_block_size(num_bits: u8) -> usize { pub struct BlockEncoder { bitpacker: BitPacker4x, pub output: [u8; COMPRESSED_BLOCK_MAX_SIZE], - pub output_len: usize, } impl Default for BlockEncoder { @@ -28,7 +27,6 @@ impl BlockEncoder { BlockEncoder { bitpacker: BitPacker4x::new(), output: [0u8; COMPRESSED_BLOCK_MAX_SIZE], - output_len: 0, } } diff --git a/src/postings/json_postings_writer.rs b/src/postings/json_postings_writer.rs index 9f0d8eb06..ed3d5c24f 100644 --- a/src/postings/json_postings_writer.rs +++ b/src/postings/json_postings_writer.rs @@ -1,5 +1,6 @@ use std::io; +use common::json_path_writer::JSON_END_OF_PATH; use stacker::Addr; use crate::indexer::doc_id_mapping::DocIdMapping; @@ -7,7 +8,7 @@ use crate::indexer::path_to_unordered_id::OrderedPathId; use 
crate::postings::postings_writer::SpecializedPostingsWriter; use crate::postings::recorder::{BufferLender, DocIdRecorder, Recorder}; use crate::postings::{FieldSerializer, IndexingContext, IndexingPosition, PostingsWriter}; -use crate::schema::{Field, Type, JSON_END_OF_PATH}; +use crate::schema::{Field, Type}; use crate::tokenizer::TokenStream; use crate::{DocId, Term}; @@ -67,10 +68,18 @@ impl<Rec: Recorder> PostingsWriter for JsonPostingsWriter<Rec> { ) -> io::Result<()> { let mut term_buffer = Term::with_capacity(48); let mut buffer_lender = BufferLender::default(); + term_buffer.clear_with_field_and_type(Type::Json, Field::from_field_id(0)); + let mut prev_term_id = u32::MAX; + let mut term_path_len = 0; // this will be set in the first iteration for (_field, path_id, term, addr) in term_addrs { - term_buffer.clear_with_field_and_type(Type::Json, Field::from_field_id(0)); - term_buffer.append_bytes(ordered_id_to_path[path_id.path_id() as usize].as_bytes()); - term_buffer.append_bytes(&[JSON_END_OF_PATH]); + if prev_term_id != path_id.path_id() { + term_buffer.truncate_value_bytes(0); + term_buffer.append_path(ordered_id_to_path[path_id.path_id() as usize].as_bytes()); + term_buffer.append_bytes(&[JSON_END_OF_PATH]); + term_path_len = term_buffer.len_bytes(); + prev_term_id = path_id.path_id(); + } + term_buffer.truncate_value_bytes(term_path_len); term_buffer.append_bytes(term); if let Some(json_value) = term_buffer.value().as_json_value_bytes() { let typ = json_value.typ(); diff --git a/src/postings/mod.rs b/src/postings/mod.rs index 32c4b7bd8..5fd90032d 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -42,9 +42,9 @@ pub mod tests { use std::mem; use super::{InvertedIndexSerializer, Postings}; - use crate::core::{Index, SegmentComponent, SegmentReader}; use crate::docset::{DocSet, TERMINATED}; use crate::fieldnorm::FieldNormReader; + use crate::index::{Index, SegmentComponent, SegmentReader}; use crate::indexer::operation::AddOperation; use 
crate::indexer::SegmentWriter; use crate::query::Scorer; diff --git a/src/postings/serializer.rs b/src/postings/serializer.rs index b9bf8f0d3..c0757f8fd 100644 --- a/src/postings/serializer.rs +++ b/src/postings/serializer.rs @@ -4,9 +4,9 @@ use std::io::{self, Write}; use common::{BinarySerializable, CountingWriter, VInt}; use super::TermInfo; -use crate::core::Segment; use crate::directory::{CompositeWrite, WritePtr}; use crate::fieldnorm::FieldNormReader; +use crate::index::Segment; use crate::positions::PositionSerializer; use crate::postings::compression::{BlockEncoder, VIntEncoder, COMPRESSION_BLOCK_SIZE}; use crate::postings::skip::SkipSerializer; diff --git a/src/postings/skip.rs b/src/postings/skip.rs index 1f5eb3577..fe5a8df88 100644 --- a/src/postings/skip.rs +++ b/src/postings/skip.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - use crate::directory::OwnedBytes; use crate::postings::compression::{compressed_block_size, COMPRESSION_BLOCK_SIZE}; use crate::query::Bm25Weight; diff --git a/src/postings/term_info.rs b/src/postings/term_info.rs index 4f3045d7f..94e640304 100644 --- a/src/postings/term_info.rs +++ b/src/postings/term_info.rs @@ -1,5 +1,4 @@ use std::io; -use std::iter::ExactSizeIterator; use std::ops::Range; use common::{BinarySerializable, FixedSize}; diff --git a/src/query/all_query.rs b/src/query/all_query.rs index fb88bf90c..149041b04 100644 --- a/src/query/all_query.rs +++ b/src/query/all_query.rs @@ -1,5 +1,5 @@ -use crate::core::SegmentReader; -use crate::docset::{DocSet, BUFFER_LEN, TERMINATED}; +use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN, TERMINATED}; +use crate::index::SegmentReader; use crate::query::boost_query::BoostScorer; use crate::query::explanation::does_not_match; use crate::query::{EnableScoring, Explanation, Query, Scorer, Weight}; @@ -54,7 +54,7 @@ impl DocSet for AllScorer { self.doc } - fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize { + fn fill_buffer(&mut self, buffer: &mut [DocId; 
COLLECT_BLOCK_BUFFER_LEN]) -> usize { if self.doc() == TERMINATED { return 0; } @@ -96,7 +96,7 @@ impl Scorer for AllScorer { #[cfg(test)] mod tests { use super::AllQuery; - use crate::docset::{DocSet, BUFFER_LEN, TERMINATED}; + use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN, TERMINATED}; use crate::query::{AllScorer, EnableScoring, Query}; use crate::schema::{Schema, TEXT}; use crate::{Index, IndexWriter}; @@ -162,16 +162,16 @@ mod tests { pub fn test_fill_buffer() { let mut postings = AllScorer { doc: 0u32, - max_doc: BUFFER_LEN as u32 * 2 + 9, + max_doc: COLLECT_BLOCK_BUFFER_LEN as u32 * 2 + 9, }; - let mut buffer = [0u32; BUFFER_LEN]; - assert_eq!(postings.fill_buffer(&mut buffer), BUFFER_LEN); - for i in 0u32..BUFFER_LEN as u32 { + let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN]; + assert_eq!(postings.fill_buffer(&mut buffer), COLLECT_BLOCK_BUFFER_LEN); + for i in 0u32..COLLECT_BLOCK_BUFFER_LEN as u32 { assert_eq!(buffer[i as usize], i); } - assert_eq!(postings.fill_buffer(&mut buffer), BUFFER_LEN); - for i in 0u32..BUFFER_LEN as u32 { - assert_eq!(buffer[i as usize], i + BUFFER_LEN as u32); + assert_eq!(postings.fill_buffer(&mut buffer), COLLECT_BLOCK_BUFFER_LEN); + for i in 0u32..COLLECT_BLOCK_BUFFER_LEN as u32 { + assert_eq!(buffer[i as usize], i + COLLECT_BLOCK_BUFFER_LEN as u32); } assert_eq!(postings.fill_buffer(&mut buffer), 9); } diff --git a/src/query/automaton_weight.rs b/src/query/automaton_weight.rs index 76c4e8286..ef675864b 100644 --- a/src/query/automaton_weight.rs +++ b/src/query/automaton_weight.rs @@ -5,7 +5,7 @@ use common::BitSet; use tantivy_fst::Automaton; use super::phrase_prefix_query::prefix_end; -use crate::core::SegmentReader; +use crate::index::SegmentReader; use crate::query::{BitSetDocSet, ConstScorer, Explanation, Scorer, Weight}; use crate::schema::{Field, IndexRecordOption}; use crate::termdict::{TermDictionary, TermStreamer}; diff --git a/src/query/boolean_query/boolean_weight.rs 
b/src/query/boolean_query/boolean_weight.rs index 4979ab1f0..ece6217d2 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; -use crate::core::SegmentReader; -use crate::docset::BUFFER_LEN; +use crate::docset::COLLECT_BLOCK_BUFFER_LEN; +use crate::index::SegmentReader; use crate::postings::FreqReadingOption; use crate::query::explanation::does_not_match; use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner}; @@ -228,7 +228,7 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin callback: &mut dyn FnMut(&[DocId]), ) -> crate::Result<()> { let scorer = self.complex_scorer(reader, 1.0, || DoNothingCombiner)?; - let mut buffer = [0u32; BUFFER_LEN]; + let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN]; match scorer { SpecializedScorer::TermUnion(term_scorers) => { diff --git a/src/query/boost_query.rs b/src/query/boost_query.rs index e7c25114f..4d2352d4d 100644 --- a/src/query/boost_query.rs +++ b/src/query/boost_query.rs @@ -1,6 +1,6 @@ use std::fmt; -use crate::docset::BUFFER_LEN; +use crate::docset::COLLECT_BLOCK_BUFFER_LEN; use crate::fastfield::AliveBitSet; use crate::query::{EnableScoring, Explanation, Query, Scorer, Weight}; use crate::{DocId, DocSet, Score, SegmentReader, Term}; @@ -105,7 +105,7 @@ impl<S: Scorer> DocSet for BoostScorer<S> { self.underlying.seek(target) } - fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize { + fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize { self.underlying.fill_buffer(buffer) } diff --git a/src/query/const_score_query.rs b/src/query/const_score_query.rs index 80f81fdfc..8f27b8285 100644 --- a/src/query/const_score_query.rs +++ b/src/query/const_score_query.rs @@ -1,6 +1,6 @@ use std::fmt; -use crate::docset::BUFFER_LEN; +use crate::docset::COLLECT_BLOCK_BUFFER_LEN; use crate::query::{EnableScoring, Explanation, Query, Scorer, Weight}; 
use crate::{DocId, DocSet, Score, SegmentReader, TantivyError, Term}; @@ -119,7 +119,7 @@ impl<TDocSet: DocSet> DocSet for ConstScorer<TDocSet> { self.docset.seek(target) } - fn fill_buffer(&mut self, buffer: &mut [DocId; BUFFER_LEN]) -> usize { + fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize { self.docset.fill_buffer(buffer) } diff --git a/src/query/exist_query.rs b/src/query/exist_query.rs index 7de8ee513..f028ebaa9 100644 --- a/src/query/exist_query.rs +++ b/src/query/exist_query.rs @@ -3,8 +3,8 @@ use core::fmt::Debug; use columnar::{ColumnIndex, DynamicColumn}; use super::{ConstScorer, EmptyScorer}; -use crate::core::SegmentReader; use crate::docset::{DocSet, TERMINATED}; +use crate::index::SegmentReader; use crate::query::explanation::does_not_match; use crate::query::{EnableScoring, Explanation, Query, Scorer, Weight}; use crate::{DocId, Score, TantivyError}; @@ -149,7 +149,7 @@ mod tests { use crate::query::exist_query::ExistsQuery; use crate::query::{BooleanQuery, RangeQuery}; use crate::schema::{Facet, FacetOptions, Schema, FAST, INDEXED, STRING, TEXT}; - use crate::{doc, Index, Searcher}; + use crate::{Index, Searcher}; #[test] fn test_exists_query_simple() -> crate::Result<()> { diff --git a/src/query/fuzzy_query.rs b/src/query/fuzzy_query.rs index 7b7379acd..a2e3f2a6b 100644 --- a/src/query/fuzzy_query.rs +++ b/src/query/fuzzy_query.rs @@ -84,7 +84,7 @@ pub struct FuzzyTermQuery { distance: u8, /// Should a transposition cost 1 or 2? 
transposition_cost_one: bool, - /// + /// is a starts with query prefix: bool, } diff --git a/src/query/phrase_prefix_query/phrase_prefix_weight.rs b/src/query/phrase_prefix_query/phrase_prefix_weight.rs index ab34bf2c9..866c3c2c5 100644 --- a/src/query/phrase_prefix_query/phrase_prefix_weight.rs +++ b/src/query/phrase_prefix_query/phrase_prefix_weight.rs @@ -1,6 +1,6 @@ use super::{prefix_end, PhrasePrefixScorer}; -use crate::core::SegmentReader; use crate::fieldnorm::FieldNormReader; +use crate::index::SegmentReader; use crate::postings::SegmentPostings; use crate::query::bm25::Bm25Weight; use crate::query::explanation::does_not_match; @@ -157,8 +157,8 @@ impl Weight for PhrasePrefixWeight { #[cfg(test)] mod tests { - use crate::core::Index; use crate::docset::TERMINATED; + use crate::index::Index; use crate::query::{EnableScoring, PhrasePrefixQuery, Query}; use crate::schema::{Schema, TEXT}; use crate::{DocSet, IndexWriter, Term}; diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs index 4809f46f8..7b8d3e007 100644 --- a/src/query/phrase_query/mod.rs +++ b/src/query/phrase_query/mod.rs @@ -14,7 +14,7 @@ pub mod tests { use super::*; use crate::collector::tests::{TEST_COLLECTOR_WITHOUT_SCORE, TEST_COLLECTOR_WITH_SCORE}; - use crate::core::Index; + use crate::index::Index; use crate::query::{EnableScoring, QueryParser, Weight}; use crate::schema::{Schema, Term, TEXT}; use crate::{assert_nearly_equals, DocAddress, DocId, IndexWriter, TERMINATED}; diff --git a/src/query/phrase_query/phrase_weight.rs b/src/query/phrase_query/phrase_weight.rs index 5b61eafb8..6e97bca7f 100644 --- a/src/query/phrase_query/phrase_weight.rs +++ b/src/query/phrase_query/phrase_weight.rs @@ -1,6 +1,6 @@ use super::PhraseScorer; -use crate::core::SegmentReader; use crate::fieldnorm::FieldNormReader; +use crate::index::SegmentReader; use crate::postings::SegmentPostings; use crate::query::bm25::Bm25Weight; use crate::query::explanation::does_not_match; diff --git 
a/src/query/query_parser/query_parser.rs b/src/query/query_parser/query_parser.rs index 6e7102b8f..aedd0c433 100644 --- a/src/query/query_parser/query_parser.rs +++ b/src/query/query_parser/query_parser.rs @@ -10,10 +10,10 @@ use query_grammar::{UserInputAst, UserInputBound, UserInputLeaf, UserInputLitera use rustc_hash::FxHashMap; use super::logical_ast::*; -use crate::core::json_utils::{ - convert_to_fast_value_and_get_term, set_string_and_get_terms, JsonTermWriter, +use crate::index::Index; +use crate::json_utils::{ + convert_to_fast_value_and_append_to_json_term, split_json_path, term_from_json_paths, }; -use crate::core::Index; use crate::query::range_query::{is_type_valid_for_fastfield_range_query, RangeQuery}; use crate::query::{ AllQuery, BooleanQuery, BoostQuery, EmptyQuery, FuzzyTermQuery, Occur, PhrasePrefixQuery, @@ -965,20 +965,33 @@ fn generate_literals_for_json_object( })?; let index_record_option = text_options.index_option(); let mut logical_literals = Vec::new(); - let mut term = Term::with_capacity(100); - let mut json_term_writer = JsonTermWriter::from_field_and_json_path( - field, - json_path, - json_options.is_expand_dots_enabled(), - &mut term, - ); - if let Some(term) = convert_to_fast_value_and_get_term(&mut json_term_writer, phrase) { + + let paths = split_json_path(json_path); + let get_term_with_path = || { + term_from_json_paths( + field, + paths.iter().map(|el| el.as_str()), + json_options.is_expand_dots_enabled(), + ) + }; + + // Try to convert the phrase to a fast value + if let Some(term) = convert_to_fast_value_and_append_to_json_term(get_term_with_path(), phrase) + { logical_literals.push(LogicalLiteral::Term(term)); } - let terms = set_string_and_get_terms(&mut json_term_writer, phrase, &mut text_analyzer); - drop(json_term_writer); - if terms.len() <= 1 { - for (_, term) in terms { + + // Try to tokenize the phrase and create Terms. 
+ let mut positions_and_terms = Vec::<(usize, Term)>::new(); + let mut token_stream = text_analyzer.token_stream(phrase); + token_stream.process(&mut |token| { + let mut term = get_term_with_path(); + term.append_type_and_str(&token.text); + positions_and_terms.push((token.position, term.clone())); + }); + + if positions_and_terms.len() <= 1 { + for (_, term) in positions_and_terms { logical_literals.push(LogicalLiteral::Term(term)); } return Ok(logical_literals); @@ -989,7 +1002,7 @@ fn generate_literals_for_json_object( )); } logical_literals.push(LogicalLiteral::Phrase { - terms, + terms: positions_and_terms, slop: 0, prefix: false, }); diff --git a/src/query/range_query/range_query.rs b/src/query/range_query/range_query.rs index 86ad4ac8f..ac2327c7a 100644 --- a/src/query/range_query/range_query.rs +++ b/src/query/range_query/range_query.rs @@ -7,8 +7,8 @@ use common::{BinarySerializable, BitSet}; use super::map_bound; use super::range_query_u64_fastfield::FastFieldRangeWeight; -use crate::core::SegmentReader; use crate::error::TantivyError; +use crate::index::SegmentReader; use crate::query::explanation::does_not_match; use crate::query::range_query::range_query_ip_fastfield::IPFastFieldRangeWeight; use crate::query::range_query::{is_type_valid_for_fastfield_range_query, map_bound_res}; @@ -477,7 +477,7 @@ mod tests { use crate::schema::{ Field, IntoIpv6Addr, Schema, TantivyDocument, FAST, INDEXED, STORED, TEXT, }; - use crate::{doc, Index, IndexWriter}; + use crate::{Index, IndexWriter}; #[test] fn test_range_query_simple() -> crate::Result<()> { diff --git a/src/query/regex_query.rs b/src/query/regex_query.rs index 50527980a..815832d31 100644 --- a/src/query/regex_query.rs +++ b/src/query/regex_query.rs @@ -63,7 +63,7 @@ impl RegexQuery { /// Creates a new RegexQuery from a given pattern pub fn from_pattern(regex_pattern: &str, field: Field) -> crate::Result<Self> { let regex = Regex::new(regex_pattern) - .map_err(|_| 
TantivyError::InvalidArgument(regex_pattern.to_string()))?; + .map_err(|err| TantivyError::InvalidArgument(format!("RegexQueryError: {err}")))?; Ok(RegexQuery::from_regex(regex, field)) } @@ -176,4 +176,16 @@ mod test { verify_regex_query(matching_one, matching_zero, reader); Ok(()) } + + #[test] + pub fn test_pattern_error() { + let (_reader, field) = build_test_index().unwrap(); + + match RegexQuery::from_pattern(r"(foo", field) { + Err(crate::TantivyError::InvalidArgument(msg)) => { + assert!(msg.contains("error: unclosed group")) + } + res => panic!("unexpected result: {:?}", res), + } + } } diff --git a/src/query/term_query/term_query.rs b/src/query/term_query/term_query.rs index 832d07895..fed4ca481 100644 --- a/src/query/term_query/term_query.rs +++ b/src/query/term_query/term_query.rs @@ -139,7 +139,7 @@ mod tests { use crate::collector::{Count, TopDocs}; use crate::query::{Query, QueryParser, TermQuery}; use crate::schema::{IndexRecordOption, IntoIpv6Addr, Schema, INDEXED, STORED}; - use crate::{doc, Index, IndexWriter, Term}; + use crate::{Index, IndexWriter, Term}; #[test] fn search_ip_test() { diff --git a/src/query/term_query/term_weight.rs b/src/query/term_query/term_weight.rs index 69064644e..a70c8ce8f 100644 --- a/src/query/term_query/term_weight.rs +++ b/src/query/term_query/term_weight.rs @@ -1,7 +1,7 @@ use super::term_scorer::TermScorer; -use crate::core::SegmentReader; -use crate::docset::{DocSet, BUFFER_LEN}; +use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN}; use crate::fieldnorm::FieldNormReader; +use crate::index::SegmentReader; use crate::postings::SegmentPostings; use crate::query::bm25::Bm25Weight; use crate::query::explanation::does_not_match; @@ -64,7 +64,7 @@ impl Weight for TermWeight { callback: &mut dyn FnMut(&[DocId]), ) -> crate::Result<()> { let mut scorer = self.specialized_scorer(reader, 1.0)?; - let mut buffer = [0u32; BUFFER_LEN]; + let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN]; for_each_docset_buffered(&mut scorer, 
&mut buffer, callback); Ok(()) } diff --git a/src/query/vec_docset.rs b/src/query/vec_docset.rs index e0a7b9f6b..5c87a71ce 100644 --- a/src/query/vec_docset.rs +++ b/src/query/vec_docset.rs @@ -53,8 +53,7 @@ impl HasLen for VecDocSet { pub mod tests { use super::*; - use crate::docset::{DocSet, BUFFER_LEN}; - use crate::DocId; + use crate::docset::COLLECT_BLOCK_BUFFER_LEN; #[test] pub fn test_vec_postings() { @@ -72,16 +71,16 @@ pub mod tests { #[test] pub fn test_fill_buffer() { - let doc_ids: Vec<DocId> = (1u32..=(BUFFER_LEN as u32 * 2 + 9)).collect(); + let doc_ids: Vec<DocId> = (1u32..=(COLLECT_BLOCK_BUFFER_LEN as u32 * 2 + 9)).collect(); let mut postings = VecDocSet::from(doc_ids); - let mut buffer = [0u32; BUFFER_LEN]; - assert_eq!(postings.fill_buffer(&mut buffer), BUFFER_LEN); - for i in 0u32..BUFFER_LEN as u32 { + let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN]; + assert_eq!(postings.fill_buffer(&mut buffer), COLLECT_BLOCK_BUFFER_LEN); + for i in 0u32..COLLECT_BLOCK_BUFFER_LEN as u32 { assert_eq!(buffer[i as usize], i + 1); } - assert_eq!(postings.fill_buffer(&mut buffer), BUFFER_LEN); - for i in 0u32..BUFFER_LEN as u32 { - assert_eq!(buffer[i as usize], i + 1 + BUFFER_LEN as u32); + assert_eq!(postings.fill_buffer(&mut buffer), COLLECT_BLOCK_BUFFER_LEN); + for i in 0u32..COLLECT_BLOCK_BUFFER_LEN as u32 { + assert_eq!(buffer[i as usize], i + 1 + COLLECT_BLOCK_BUFFER_LEN as u32); } assert_eq!(postings.fill_buffer(&mut buffer), 9); } diff --git a/src/query/weight.rs b/src/query/weight.rs index eea4d28a8..23ff55c04 100644 --- a/src/query/weight.rs +++ b/src/query/weight.rs @@ -1,6 +1,6 @@ use super::Scorer; -use crate::core::SegmentReader; -use crate::docset::BUFFER_LEN; +use crate::docset::COLLECT_BLOCK_BUFFER_LEN; +use crate::index::SegmentReader; use crate::query::Explanation; use crate::{DocId, DocSet, Score, TERMINATED}; @@ -22,7 +22,7 @@ pub(crate) fn for_each_scorer<TScorer: Scorer + ?Sized>( #[inline] pub(crate) fn for_each_docset_buffered<T: DocSet 
+ ?Sized>( docset: &mut T, - buffer: &mut [DocId; BUFFER_LEN], + buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN], mut callback: impl FnMut(&[DocId]), ) { loop { @@ -105,7 +105,7 @@ pub trait Weight: Send + Sync + 'static { ) -> crate::Result<()> { let mut docset = self.scorer(reader, 1.0)?; - let mut buffer = [0u32; BUFFER_LEN]; + let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN]; for_each_docset_buffered(&mut docset, &mut buffer, callback); Ok(()) } diff --git a/src/reader/mod.rs b/src/reader/mod.rs index 7c57580a1..39e4c6e00 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -1,6 +1,5 @@ mod warming; -use std::convert::TryInto; use std::sync::atomic::AtomicU64; use std::sync::{atomic, Arc, Weak}; diff --git a/src/schema/date_time_options.rs b/src/schema/date_time_options.rs index c6c03d795..44465a409 100644 --- a/src/schema/date_time_options.rs +++ b/src/schema/date_time_options.rs @@ -1,7 +1,5 @@ use std::ops::BitOr; -#[allow(deprecated)] -pub use common::DatePrecision; pub use common::DateTimePrecision; use serde::{Deserialize, Serialize}; diff --git a/src/schema/document/de.rs b/src/schema/document/de.rs index 706f3b768..657ad1430 100644 --- a/src/schema/document/de.rs +++ b/src/schema/document/de.rs @@ -160,7 +160,7 @@ pub enum ValueType { /// A dynamic object value. Object, /// A JSON object value. Deprecated. 
- #[deprecated] + #[deprecated(note = "We keep this for backwards compatibility, use Object instead")] JSONObject, } @@ -819,7 +819,6 @@ mod tests { use crate::schema::document::existing_type_impls::JsonObjectIter; use crate::schema::document::se::BinaryValueSerializer; use crate::schema::document::{ReferenceValue, ReferenceValueLeaf}; - use crate::schema::OwnedValue; fn serialize_value<'a>(value: ReferenceValue<'a, &'a serde_json::Value>) -> Vec<u8> { let mut writer = Vec::new(); @@ -889,7 +888,7 @@ mod tests { #[test] fn test_array_serialize() { - let elements = vec![serde_json::Value::Null, serde_json::Value::Null]; + let elements = [serde_json::Value::Null, serde_json::Value::Null]; let result = serialize_value(ReferenceValue::Array(elements.iter())); let value = deserialize_value(result); assert_eq!( @@ -900,7 +899,7 @@ mod tests { ]), ); - let elements = vec![ + let elements = [ serde_json::Value::String("Hello, world".into()), serde_json::Value::String("Some demo".into()), ]; @@ -914,12 +913,12 @@ mod tests { ]), ); - let elements = vec![]; + let elements = []; let result = serialize_value(ReferenceValue::Array(elements.iter())); let value = deserialize_value(result); assert_eq!(value, crate::schema::OwnedValue::Array(vec![])); - let elements = vec![ + let elements = [ serde_json::Value::Null, serde_json::Value::String("Hello, world".into()), serde_json::Value::Number(12345.into()), diff --git a/src/schema/document/default_document.rs b/src/schema/document/default_document.rs index eda44ee8e..fcf374dfe 100644 --- a/src/schema/document/default_document.rs +++ b/src/schema/document/default_document.rs @@ -256,7 +256,6 @@ impl DocParsingError { #[cfg(test)] mod tests { - use crate::schema::document::default_document::TantivyDocument; use crate::schema::*; #[test] diff --git a/src/schema/document/owned_value.rs b/src/schema/document/owned_value.rs index 3369dd979..3dc7a1f67 100644 --- a/src/schema/document/owned_value.rs +++ b/src/schema/document/owned_value.rs 
@@ -443,9 +443,7 @@ impl<'a> Iterator for ObjectMapIter<'a> { mod tests { use super::*; use crate::schema::{BytesOptions, Schema}; - use crate::time::format_description::well_known::Rfc3339; - use crate::time::OffsetDateTime; - use crate::{DateTime, Document, TantivyDocument}; + use crate::{Document, TantivyDocument}; #[test] fn test_parse_bytes_doc() { diff --git a/src/schema/document/se.rs b/src/schema/document/se.rs index 10e0657e0..8acffb36b 100644 --- a/src/schema/document/se.rs +++ b/src/schema/document/se.rs @@ -453,7 +453,7 @@ mod tests { #[test] fn test_array_serialize() { - let elements = vec![serde_json::Value::Null, serde_json::Value::Null]; + let elements = [serde_json::Value::Null, serde_json::Value::Null]; let result = serialize_value(ReferenceValue::Array(elements.iter())); let expected = binary_repr!( collection type_codes::ARRAY_CODE, @@ -466,7 +466,7 @@ mod tests { "Expected serialized value to match the binary representation" ); - let elements = vec![ + let elements = [ serde_json::Value::String("Hello, world".into()), serde_json::Value::String("Some demo".into()), ]; @@ -482,7 +482,7 @@ mod tests { "Expected serialized value to match the binary representation" ); - let elements = vec![]; + let elements = []; let result = serialize_value(ReferenceValue::Array(elements.iter())); let expected = binary_repr!( collection type_codes::ARRAY_CODE, @@ -493,7 +493,7 @@ mod tests { "Expected serialized value to match the binary representation" ); - let elements = vec![ + let elements = [ serde_json::Value::Null, serde_json::Value::String("Hello, world".into()), serde_json::Value::Number(12345.into()), diff --git a/src/schema/field_entry.rs b/src/schema/field_entry.rs index 9fa643ca0..8d2f9b230 100644 --- a/src/schema/field_entry.rs +++ b/src/schema/field_entry.rs @@ -136,7 +136,6 @@ impl FieldEntry { #[cfg(test)] mod tests { - use serde_json; use super::*; use crate::schema::{Schema, TextFieldIndexing, TEXT}; diff --git a/src/schema/mod.rs 
b/src/schema/mod.rs index ced9a4b8c..b4c3b037e 100644 --- a/src/schema/mod.rs +++ b/src/schema/mod.rs @@ -130,8 +130,6 @@ mod text_options; use columnar::ColumnType; pub use self::bytes_options::BytesOptions; -#[allow(deprecated)] -pub use self::date_time_options::DatePrecision; pub use self::date_time_options::{DateOptions, DateTimePrecision, DATE_TIME_PRECISION_INDEXED}; pub use self::document::{DocParsingError, Document, OwnedValue, TantivyDocument, Value}; pub(crate) use self::facet::FACET_SEP_BYTE; @@ -146,11 +144,9 @@ pub use self::index_record_option::IndexRecordOption; pub use self::ip_options::{IntoIpv6Addr, IpAddrOptions}; pub use self::json_object_options::JsonObjectOptions; pub use self::named_field_document::NamedFieldDocument; -#[allow(deprecated)] -pub use self::numeric_options::IntOptions; pub use self::numeric_options::NumericOptions; pub use self::schema::{Schema, SchemaBuilder}; -pub use self::term::{Term, ValueBytes, JSON_END_OF_PATH}; +pub use self::term::{Term, ValueBytes}; pub use self::text_options::{TextFieldIndexing, TextOptions, STRING, TEXT}; /// Validator for a potential `field_name`. diff --git a/src/schema/numeric_options.rs b/src/schema/numeric_options.rs index 432e5563e..db36b523e 100644 --- a/src/schema/numeric_options.rs +++ b/src/schema/numeric_options.rs @@ -5,10 +5,6 @@ use serde::{Deserialize, Serialize}; use super::flags::CoerceFlag; use crate::schema::flags::{FastFlag, IndexedFlag, SchemaFlagList, StoredFlag}; -#[deprecated(since = "0.17.0", note = "Use NumericOptions instead.")] -/// Deprecated use [`NumericOptions`] instead. -pub type IntOptions = NumericOptions; - /// Define how an `u64`, `i64`, or `f64` field should be handled by tantivy. 
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)] #[serde(from = "NumericOptionsDeser")] diff --git a/src/schema/schema.rs b/src/schema/schema.rs index 9fec25c05..d3215a37c 100644 --- a/src/schema/schema.rs +++ b/src/schema/schema.rs @@ -6,10 +6,8 @@ use serde::de::{SeqAccess, Visitor}; use serde::ser::SerializeSeq; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use super::ip_options::IpAddrOptions; use super::*; use crate::json_utils::split_json_path; -use crate::schema::bytes_options::BytesOptions; use crate::TantivyError; /// Tantivy has a very strict schema. @@ -421,9 +419,7 @@ mod tests { use matches::{assert_matches, matches}; use pretty_assertions::assert_eq; - use serde_json; - use crate::schema::document::Value; use crate::schema::field_type::ValueParsingError; use crate::schema::schema::DocParsingError::InvalidJson; use crate::schema::*; diff --git a/src/schema/term.rs b/src/schema/term.rs index db707e294..3ac5d0ac4 100644 --- a/src/schema/term.rs +++ b/src/schema/term.rs @@ -1,9 +1,9 @@ -use std::convert::TryInto; use std::hash::{Hash, Hasher}; use std::net::Ipv6Addr; use std::{fmt, str}; use columnar::{MonotonicallyMappableToU128, MonotonicallyMappableToU64}; +use common::json_path_writer::{JSON_END_OF_PATH, JSON_PATH_SEGMENT_SEP_STR}; use super::date_time_options::DATE_TIME_PRECISION_INDEXED; use super::Field; @@ -11,15 +11,6 @@ use crate::fastfield::FastValue; use crate::schema::{Facet, Type}; use crate::DateTime; -/// Separates the different segments of a json path. -pub const JSON_PATH_SEGMENT_SEP: u8 = 1u8; -pub const JSON_PATH_SEGMENT_SEP_STR: &str = - unsafe { std::str::from_utf8_unchecked(&[JSON_PATH_SEGMENT_SEP]) }; - -/// Separates the json path and the value in -/// a JSON term binary representation. -pub const JSON_END_OF_PATH: u8 = 0u8; - /// Term represents the value that the token can take. /// It's a serialized representation over different types. 
/// @@ -170,6 +161,10 @@ impl Term { self.set_bytes(val.to_u64().to_be_bytes().as_ref()); } + /// Append a type marker + fast value to a term. + /// This is used in JSON type to append a fast value after the path. + /// + /// It will not clear existing bytes. pub(crate) fn append_type_and_fast_value<T: FastValue>(&mut self, val: T) { self.0.push(T::to_type().to_code()); let value = if T::to_type() == Type::Date { @@ -182,6 +177,15 @@ impl Term { self.0.extend(value.to_be_bytes().as_ref()); } + /// Append a string type marker + string to a term. + /// This is used in JSON type to append a str after the path. + /// + /// It will not clear existing bytes. + pub(crate) fn append_type_and_str(&mut self, val: &str) { + self.0.push(Type::Str.to_code()); + self.0.extend(val.as_bytes().as_ref()); + } + /// Sets a `Ipv6Addr` value in the term. pub fn set_ip_addr(&mut self, val: Ipv6Addr) { self.set_bytes(val.to_u128().to_be_bytes().as_ref()); @@ -193,11 +197,6 @@ impl Term { self.0.extend(bytes); } - /// Set the texts only, keeping the field untouched. - pub fn set_text(&mut self, text: &str) { - self.set_bytes(text.as_bytes()); - } - /// Truncates the value bytes of the term. Value and field type stays the same. pub fn truncate_value_bytes(&mut self, len: usize) { self.0.truncate(len + TERM_METADATA_LENGTH); @@ -218,25 +217,21 @@ impl Term { &mut self.0[len_before..] } - /// Appends a JSON_PATH_SEGMENT_SEP to the term. - /// Only used for JSON type. + /// Appends json path bytes to the Term. + /// If the path contains 0 bytes, they are replaced by a "0" string. + /// The 0 byte is used to mark the end of the path. + /// + /// This function returns the segment that has just been added. #[inline] - pub fn add_json_path_separator(&mut self) { - self.0.push(JSON_PATH_SEGMENT_SEP); - } - /// Sets the current end to JSON_END_OF_PATH. - /// Only used for JSON type. 
- #[inline] - pub fn set_json_path_end(&mut self) { - let buffer_len = self.0.len(); - self.0[buffer_len - 1] = JSON_END_OF_PATH; - } - /// Sets the current end to JSON_PATH_SEGMENT_SEP. - /// Only used for JSON type. - #[inline] - pub fn set_json_path_separator(&mut self) { - let buffer_len = self.0.len(); - self.0[buffer_len - 1] = JSON_PATH_SEGMENT_SEP; + pub fn append_path(&mut self, bytes: &[u8]) -> &mut [u8] { + let len_before = self.0.len(); + if bytes.contains(&0u8) { + self.0 + .extend(bytes.iter().map(|&b| if b == 0 { b'0' } else { b })); + } else { + self.0.extend_from_slice(bytes); + } + &mut self.0[len_before..] } } diff --git a/src/snippet/mod.rs b/src/snippet/mod.rs index 6542df5a3..16edac043 100644 --- a/src/snippet/mod.rs +++ b/src/snippet/mod.rs @@ -743,11 +743,12 @@ Survey in 2016, 2017, and 2018."#; #[test] fn test_collapse_overlapped_ranges() { - assert_eq!(&collapse_overlapped_ranges(&[0..1, 2..3,]), &[0..1, 2..3]); - assert_eq!(&collapse_overlapped_ranges(&[0..1, 1..2,]), &[0..1, 1..2]); - assert_eq!(&collapse_overlapped_ranges(&[0..2, 1..2,]), &[0..2]); - assert_eq!(&collapse_overlapped_ranges(&[0..2, 1..3,]), &[0..3]); - assert_eq!(&collapse_overlapped_ranges(&[0..3, 1..2,]), &[0..3]); + #![allow(clippy::single_range_in_vec_init)] + assert_eq!(&collapse_overlapped_ranges(&[0..1, 2..3]), &[0..1, 2..3]); + assert_eq!(&collapse_overlapped_ranges(&[0..1, 1..2]), &[0..1, 1..2]); + assert_eq!(&collapse_overlapped_ranges(&[0..2, 1..2]), &[0..2]); + assert_eq!(&collapse_overlapped_ranges(&[0..2, 1..3]), &[0..3]); + assert_eq!(&collapse_overlapped_ranges(&[0..3, 1..2]), &[0..3]); } #[test] diff --git a/src/space_usage/mod.rs b/src/space_usage/mod.rs index 84fb074d0..466d67aae 100644 --- a/src/space_usage/mod.rs +++ b/src/space_usage/mod.rs @@ -290,7 +290,7 @@ impl FieldUsage { #[cfg(test)] mod test { - use crate::core::Index; + use crate::index::Index; use crate::schema::{Field, Schema, FAST, INDEXED, STORED, TEXT}; use 
crate::space_usage::PerFieldSpaceUsage; use crate::{IndexWriter, Term}; diff --git a/src/store/compression_lz4_block.rs b/src/store/compression_lz4_block.rs index 0464510b8..08ecd9e4f 100644 --- a/src/store/compression_lz4_block.rs +++ b/src/store/compression_lz4_block.rs @@ -1,4 +1,3 @@ -use core::convert::TryInto; use std::io::{self}; use std::mem; diff --git a/src/store/compressors.rs b/src/store/compressors.rs index 89205c99d..541855aa9 100644 --- a/src/store/compressors.rs +++ b/src/store/compressors.rs @@ -2,12 +2,6 @@ use std::io; use serde::{Deserialize, Deserializer, Serialize}; -pub trait StoreCompressor { - fn compress(&self, uncompressed: &[u8], compressed: &mut Vec<u8>) -> io::Result<()>; - fn decompress(&self, compressed: &[u8], decompressed: &mut Vec<u8>) -> io::Result<()>; - fn get_compressor_id() -> u8; -} - /// Compressor can be used on `IndexSettings` to choose /// the compressor used to compress the doc store. /// diff --git a/src/store/decompressors.rs b/src/store/decompressors.rs index 2c3173ae2..4d0319aca 100644 --- a/src/store/decompressors.rs +++ b/src/store/decompressors.rs @@ -4,12 +4,6 @@ use serde::{Deserialize, Serialize}; use super::Compressor; -pub trait StoreCompressor { - fn compress(&self, uncompressed: &[u8], compressed: &mut Vec<u8>) -> io::Result<()>; - fn decompress(&self, compressed: &[u8], decompressed: &mut Vec<u8>) -> io::Result<()>; - fn get_compressor_id() -> u8; -} - /// Decompressor is deserialized from the doc store footer, when opening an index. 
#[derive(Clone, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum Decompressor { @@ -86,7 +80,6 @@ impl Decompressor { #[cfg(test)] mod tests { use super::*; - use crate::store::Compressor; #[test] fn compressor_decompressor_id_test() { diff --git a/src/store/index/mod.rs b/src/store/index/mod.rs index 9e657b31b..13c252e92 100644 --- a/src/store/index/mod.rs +++ b/src/store/index/mod.rs @@ -41,7 +41,7 @@ mod tests { use std::io; - use proptest::strategy::{BoxedStrategy, Strategy}; + use proptest::prelude::*; use super::{SkipIndex, SkipIndexBuilder}; use crate::directory::OwnedBytes; @@ -227,8 +227,6 @@ mod tests { } } - use proptest::prelude::*; - proptest! { #![proptest_config(ProptestConfig::with_cases(20))] #[test] diff --git a/src/store/mod.rs b/src/store/mod.rs index 7fbd8c1e5..1cb1a1101 100644 --- a/src/store/mod.rs +++ b/src/store/mod.rs @@ -129,10 +129,7 @@ pub mod tests { ); } - for (_, doc) in store - .iter::<TantivyDocument>(Some(&alive_bitset)) - .enumerate() - { + for doc in store.iter::<TantivyDocument>(Some(&alive_bitset)) { let doc = doc?; let title_content = doc.get_first(field_title).unwrap().as_str().unwrap(); if !title_content.starts_with("Doc ") { diff --git a/src/store/reader.rs b/src/store/reader.rs index 16125a147..b7f243003 100644 --- a/src/store/reader.rs +++ b/src/store/reader.rs @@ -14,7 +14,7 @@ use super::Decompressor; use crate::directory::FileSlice; use crate::error::DataCorruption; use crate::fastfield::AliveBitSet; -use crate::schema::document::{BinaryDocumentDeserializer, Document, DocumentDeserialize}; +use crate::schema::document::{BinaryDocumentDeserializer, DocumentDeserialize}; use crate::space_usage::StoreSpaceUsage; use crate::store::index::Checkpoint; use crate::DocId; @@ -235,7 +235,7 @@ impl StoreReader { /// Iterator over all Documents in their order as they are stored in the doc store. /// Use this, if you want to extract all Documents from the doc store. 
/// The `alive_bitset` has to be forwarded from the `SegmentReader` or the results may be wrong. - pub fn iter<'a: 'b, 'b, D: Document + DocumentDeserialize>( + pub fn iter<'a: 'b, 'b, D: DocumentDeserialize>( &'b self, alive_bitset: Option<&'a AliveBitSet>, ) -> impl Iterator<Item = crate::Result<D>> + 'b { diff --git a/src/termdict/fst_termdict/term_info_store.rs b/src/termdict/fst_termdict/term_info_store.rs index 837136ff1..0ad3a9d35 100644 --- a/src/termdict/fst_termdict/term_info_store.rs +++ b/src/termdict/fst_termdict/term_info_store.rs @@ -288,7 +288,6 @@ impl TermInfoStoreWriter { #[cfg(test)] mod tests { - use common; use common::BinarySerializable; use tantivy_bitpacker::{compute_num_bits, BitPacker}; diff --git a/src/tokenizer/stemmer.rs b/src/tokenizer/stemmer.rs index 4c43b609a..f66dd2ecb 100644 --- a/src/tokenizer/stemmer.rs +++ b/src/tokenizer/stemmer.rs @@ -1,7 +1,7 @@ use std::borrow::Cow; use std::mem; -use rust_stemmers::{self, Algorithm}; +use rust_stemmers::Algorithm; use serde::{Deserialize, Serialize}; use super::{Token, TokenFilter, TokenStream, Tokenizer}; diff --git a/src/tokenizer/tokenized_string.rs b/src/tokenizer/tokenized_string.rs index 046a02c75..8fbf51f8c 100644 --- a/src/tokenizer/tokenized_string.rs +++ b/src/tokenizer/tokenized_string.rs @@ -95,7 +95,6 @@ impl TokenStream for PreTokenizedStream { mod tests { use super::*; - use crate::tokenizer::Token; #[test] fn test_tokenized_stream() { diff --git a/sstable/Cargo.toml b/sstable/Cargo.toml index 643d6b976..91d629229 100644 --- a/sstable/Cargo.toml +++ b/sstable/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tantivy-sstable" -version = "0.2.0" +version = "0.3.0" edition = "2021" license = "MIT" homepage = "https://github.com/quickwit-oss/tantivy" @@ -10,7 +10,8 @@ categories = ["database-implementations", "data-structures", "compression"] description = "sstables for tantivy" [dependencies] -common = {version= "0.6", path="../common", package="tantivy-common"} +common = {version= 
"0.7", path="../common", package="tantivy-common"} +tantivy-bitpacker = { version= "0.6", path="../bitpacker" } tantivy-fst = "0.5" # experimental gives us access to Decompressor::upper_bound zstd = { version = "0.13", features = ["experimental"] } diff --git a/sstable/README.md b/sstable/README.md index bec6d70f9..e5e3f5a49 100644 --- a/sstable/README.md +++ b/sstable/README.md @@ -89,33 +89,71 @@ Note: as the SSTable does not support redundant keys, there is no ambiguity betw ### SSTFooter ``` -+-------+-------+-----+-------------+---------+---------+ -| Block | Block | ... | IndexOffset | NumTerm | Version | -+-------+-------+-----+-------------+---------+---------+ -|----( # of blocks)---| ++-----+----------------+-------------+-------------+---------+---------+ +| Fst | BlockAddrStore | StoreOffset | IndexOffset | NumTerm | Version | ++-----+----------------+-------------+-------------+---------+---------+ ``` -- Block(SSTBlock): uses IndexValue for its Values format +- Fst(Fst): finite state transducer mapping keys to a block number +- BlockAddrStore(BlockAddrStore): store mapping a block number to its BlockAddr +- StoreOffset(u64): Offset to start of the BlockAddrStore. If zero, see the SingleBlockSStable section - IndexOffset(u64): Offset to the start of the SSTFooter - NumTerm(u64): number of terms in the sstable -- Version(u32): Currently equal to 2 +- Version(u32): Currently equal to 3 -### IndexValue -``` -+------------+----------+-------+-------+-----+ -| EntryCount | StartPos | Entry | Entry | ... 
| -+------------+----------+-------+-------+-----+ - |---( # of entries)---| -``` +### Fst -- EntryCount(VInt): number of entries -- StartPos(VInt): the start pos of the first (data) block referenced by this (index) block -- Entry (IndexEntry) +Fst is in the format of tantivy\_fst -### Entry -``` -+----------+--------------+ -| BlockLen | FirstOrdinal | -+----------+--------------+ -``` -- BlockLen(VInt): length of the block -- FirstOrdinal(VInt): ordinal of the first element in the given block +### BlockAddrStore + ++---------+-----------+-----------+-----+-----------+-----------+-----+ +| MetaLen | BlockMeta | BlockMeta | ... | BlockData | BlockData | ... | ++---------+-----------+-----------+-----+-----------+-----------+-----+ + |---------(N blocks)----------|---------(N blocks)----------| + +- MetaLen(u64): length of the BlockMeta section +- BlockMeta(BlockAddrBlockMetadata): metadata to seek through BlockData +- BlockData(CompactedBlockAddr): bitpacked per block metadata + +### BlockAddrBlockMetadata + ++--------+------------+--------------+------------+--------------+-------------------+-----------------+----------+ +| Offset | RangeStart | FirstOrdinal | RangeSlope | OrdinalSlope | FirstOrdinalNBits | RangeStartNBits | BlockLen | ++--------+------------+--------------+------------+--------------+-------------------+-----------------+----------+ + +- Offset(u64): offset of the corresponding BlockData in the datastream +- RangeStart(u64): the start position of the first block +- FirstOrdinal(u64): the first ordinal of the first block +- RangeSlope(u32): slope predicted for start range evolution (see computation in BlockData) +- OrdinalSlope(u64): slope predicted for first ordinal evolution (see computation in BlockData) +- FirstOrdinalNBits(u8): number of bits per ordinal in datastream (see computation in BlockData) +- RangeStartNBits(u8): number of bits per range start in datastream (see computation in BlockData) + +### BlockData + 
++-----------------+-------------------+---------------+ +| RangeStartDelta | FirstOrdinalDelta | FinalRangeEnd | ++-----------------+-------------------+---------------+ +|------(BlockLen repetitions)---------| + +- RangeStartDelta(var): RangeStartNBits *bits* of little-endian number. See below for decoding +- FirstOrdinalDelta(var): FirstOrdinalNBits *bits* of little-endian number. See below for decoding +- FinalRangeEnd(var): RangeStartNBits *bits* of integer. See below for decoding + +Converting a BlockData of index Index and a BlockAddrBlockMetadata to an actual block address is done as follows: +range\_prediction := RangeStart + Index * RangeSlope; +range\_derivation := RangeStartDelta - (1 << (RangeStartNBits-1)); +range\_start := range\_prediction + range\_derivation + +The same computation can be done for ordinal. + +Note that `range_derivation` can take a negative value. `RangeStartDelta` is just its translation to a positive range. + + +## SingleBlockSStable + +The format used for the index is meant to be compact; however, it has a constant cost of around 70 +bytes, which isn't negligible for a table containing very few keys. +To limit the impact of that constant cost, single-block sstables omit the Fst and BlockAddrStore from +their index. Instead, a block with first ordinal of 0, range start of 0 and range end of IndexOffset +is implicitly used for every operation. 
diff --git a/sstable/benches/ord_to_term.rs b/sstable/benches/ord_to_term.rs index 04c8835fb..9285af2e4 100644 --- a/sstable/benches/ord_to_term.rs +++ b/sstable/benches/ord_to_term.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use common::file_slice::FileSlice; use common::OwnedBytes; use criterion::{criterion_group, criterion_main, Criterion}; -use tantivy_sstable::{self, Dictionary, MonotonicU64SSTable}; +use tantivy_sstable::{Dictionary, MonotonicU64SSTable}; fn make_test_sstable(suffix: &str) -> FileSlice { let mut builder = Dictionary::<MonotonicU64SSTable>::builder(Vec::new()).unwrap(); @@ -40,6 +40,31 @@ pub fn criterion_benchmark(c: &mut Criterion) { assert!(dict.ord_to_term(19_000_000, &mut res).unwrap()); }) }); + c.bench_function("term_ord_suffix", |b| { + b.iter(|| { + assert_eq!( + dict.term_ord(b"prefix.00186A0.suffix").unwrap().unwrap(), + 100_000 + ); + assert_eq!( + dict.term_ord(b"prefix.121EAC0.suffix").unwrap().unwrap(), + 19_000_000 + ); + }) + }); + c.bench_function("open_and_term_ord_suffix", |b| { + b.iter(|| { + let dict = Dictionary::<MonotonicU64SSTable>::open(slice.clone()).unwrap(); + assert_eq!( + dict.term_ord(b"prefix.00186A0.suffix").unwrap().unwrap(), + 100_000 + ); + assert_eq!( + dict.term_ord(b"prefix.121EAC0.suffix").unwrap().unwrap(), + 19_000_000 + ); + }) + }); } { let slice = make_test_sstable(""); @@ -59,6 +84,25 @@ pub fn criterion_benchmark(c: &mut Criterion) { assert!(dict.ord_to_term(19_000_000, &mut res).unwrap()); }) }); + c.bench_function("term_ord", |b| { + b.iter(|| { + assert_eq!(dict.term_ord(b"prefix.00186A0").unwrap().unwrap(), 100_000); + assert_eq!( + dict.term_ord(b"prefix.121EAC0").unwrap().unwrap(), + 19_000_000 + ); + }) + }); + c.bench_function("open_and_term_ord", |b| { + b.iter(|| { + let dict = Dictionary::<MonotonicU64SSTable>::open(slice.clone()).unwrap(); + assert_eq!(dict.term_ord(b"prefix.00186A0").unwrap().unwrap(), 100_000); + assert_eq!( + dict.term_ord(b"prefix.121EAC0").unwrap().unwrap(), + 
19_000_000 + ); + }) + }); } } diff --git a/sstable/benches/stream_bench.rs b/sstable/benches/stream_bench.rs index 2b29a5e99..d8df433e9 100644 --- a/sstable/benches/stream_bench.rs +++ b/sstable/benches/stream_bench.rs @@ -5,7 +5,7 @@ use common::file_slice::FileSlice; use criterion::{criterion_group, criterion_main, Criterion}; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; -use tantivy_sstable::{self, Dictionary, MonotonicU64SSTable}; +use tantivy_sstable::{Dictionary, MonotonicU64SSTable}; const CHARSET: &'static [u8] = b"abcdefghij"; diff --git a/sstable/src/dictionary.rs b/sstable/src/dictionary.rs index 0eb5822d9..b4821fe24 100644 --- a/sstable/src/dictionary.rs +++ b/sstable/src/dictionary.rs @@ -9,8 +9,11 @@ use common::{BinarySerializable, OwnedBytes}; use tantivy_fst::automaton::AlwaysMatch; use tantivy_fst::Automaton; +use crate::sstable_index_v3::SSTableIndexV3Empty; use crate::streamer::{Streamer, StreamerBuilder}; -use crate::{BlockAddr, DeltaReader, Reader, SSTable, SSTableIndex, TermOrdinal, VoidSSTable}; +use crate::{ + BlockAddr, DeltaReader, Reader, SSTable, SSTableIndex, SSTableIndexV3, TermOrdinal, VoidSSTable, +}; /// An SSTable is a sorted map that associates sorted `&[u8]` keys /// to any kind of typed values. 
@@ -180,24 +183,41 @@ impl<TSSTable: SSTable> Dictionary<TSSTable> { pub fn open(term_dictionary_file: FileSlice) -> io::Result<Self> { let (main_slice, footer_len_slice) = term_dictionary_file.split_from_end(20); let mut footer_len_bytes: OwnedBytes = footer_len_slice.read_bytes()?; - let index_offset = u64::deserialize(&mut footer_len_bytes)?; let num_terms = u64::deserialize(&mut footer_len_bytes)?; let version = u32::deserialize(&mut footer_len_bytes)?; - if version != crate::SSTABLE_VERSION { - return Err(io::Error::new( - io::ErrorKind::Other, - format!( - "Unsuported sstable version, expected {version}, found {}", - crate::SSTABLE_VERSION, - ), - )); - } - let (sstable_slice, index_slice) = main_slice.split(index_offset as usize); let sstable_index_bytes = index_slice.read_bytes()?; - let sstable_index = SSTableIndex::load(sstable_index_bytes) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption"))?; + + let sstable_index = match version { + 2 => SSTableIndex::V2( + crate::sstable_index_v2::SSTableIndex::load(sstable_index_bytes).map_err(|_| { + io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption") + })?, + ), + 3 => { + let (sstable_index_bytes, mut footerv3_len_bytes) = sstable_index_bytes.rsplit(8); + let store_offset = u64::deserialize(&mut footerv3_len_bytes)?; + if store_offset != 0 { + SSTableIndex::V3( + SSTableIndexV3::load(sstable_index_bytes, store_offset).map_err(|_| { + io::Error::new(io::ErrorKind::InvalidData, "SSTable corruption") + })?, + ) + } else { + // if store_offset is zero, there is no index, so we build a pseudo-index + // assuming a single block of sstable covering everything. 
+ SSTableIndex::V3Empty(SSTableIndexV3Empty::load(index_offset as usize)) + } + } + _ => { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("Unsuported sstable version, expected one of [2, 3], found {version}"), + )) + } + }; + Ok(Dictionary { sstable_slice, sstable_index, diff --git a/sstable/src/lib.rs b/sstable/src/lib.rs index 09014b794..87b8f29d9 100644 --- a/sstable/src/lib.rs +++ b/sstable/src/lib.rs @@ -1,6 +1,5 @@ use std::io::{self, Write}; use std::ops::Range; -use std::usize; use merge::ValueMerger; @@ -10,8 +9,9 @@ pub mod merge; mod streamer; pub mod value; -mod sstable_index; -pub use sstable_index::{BlockAddr, SSTableIndex, SSTableIndexBuilder}; +mod sstable_index_v3; +pub use sstable_index_v3::{BlockAddr, SSTableIndex, SSTableIndexBuilder, SSTableIndexV3}; +mod sstable_index_v2; pub(crate) mod vint; pub use dictionary::Dictionary; pub use streamer::{Streamer, StreamerBuilder}; @@ -28,7 +28,7 @@ use crate::value::{RangeValueReader, RangeValueWriter}; pub type TermOrdinal = u64; const DEFAULT_KEY_CAPACITY: usize = 50; -const SSTABLE_VERSION: u32 = 2; +const SSTABLE_VERSION: u32 = 3; /// Given two byte string returns the length of /// the longest common prefix. 
@@ -304,7 +304,8 @@ where let offset = wrt.written_bytes(); - self.index_builder.serialize(&mut wrt)?; + let fst_len: u64 = self.index_builder.serialize(&mut wrt)?; + wrt.write_all(&fst_len.to_le_bytes())?; wrt.write_all(&offset.to_le_bytes())?; wrt.write_all(&self.num_terms.to_le_bytes())?; @@ -385,13 +386,10 @@ mod test { 16, 17, 33, 18, 19, 17, 20, // data block 0, 0, 0, 0, // no more block // index - 8, 0, 0, 0, // size of index block - 0, // compression - 1, 0, 12, 0, 32, 17, 20, // index block - 0, 0, 0, 0, // no more index block + 0, 0, 0, 0, 0, 0, 0, 0, // fst lenght 16, 0, 0, 0, 0, 0, 0, 0, // index start offset 3, 0, 0, 0, 0, 0, 0, 0, // num term - 2, 0, 0, 0, // version + 3, 0, 0, 0, // version ] ); let buffer = OwnedBytes::new(buffer); diff --git a/sstable/src/sstable_index.rs b/sstable/src/sstable_index.rs deleted file mode 100644 index 1ce85305c..000000000 --- a/sstable/src/sstable_index.rs +++ /dev/null @@ -1,266 +0,0 @@ -use std::io::{self, Write}; -use std::ops::Range; - -use common::OwnedBytes; - -use crate::{common_prefix_len, SSTable, SSTableDataCorruption, TermOrdinal}; - -#[derive(Default, Debug, Clone)] -pub struct SSTableIndex { - blocks: Vec<BlockMeta>, -} - -impl SSTableIndex { - /// Load an index from its binary representation - pub fn load(data: OwnedBytes) -> Result<SSTableIndex, SSTableDataCorruption> { - let mut reader = IndexSSTable::reader(data); - let mut blocks = Vec::new(); - - while reader.advance().map_err(|_| SSTableDataCorruption)? { - blocks.push(BlockMeta { - last_key_or_greater: reader.key().to_vec(), - block_addr: reader.value().clone(), - }); - } - - Ok(SSTableIndex { blocks }) - } - - /// Get the [`BlockAddr`] of the requested block. - pub(crate) fn get_block(&self, block_id: usize) -> Option<BlockAddr> { - self.blocks - .get(block_id) - .map(|block_meta| block_meta.block_addr.clone()) - } - - /// Get the block id of the block that would contain `key`. 
- /// - /// Returns None if `key` is lexicographically after the last key recorded. - pub(crate) fn locate_with_key(&self, key: &[u8]) -> Option<usize> { - let pos = self - .blocks - .binary_search_by_key(&key, |block| &block.last_key_or_greater); - match pos { - Ok(pos) => Some(pos), - Err(pos) => { - if pos < self.blocks.len() { - Some(pos) - } else { - // after end of last block: no block matches - None - } - } - } - } - - /// Get the [`BlockAddr`] of the block that would contain `key`. - /// - /// Returns None if `key` is lexicographically after the last key recorded. - pub fn get_block_with_key(&self, key: &[u8]) -> Option<BlockAddr> { - self.locate_with_key(key).and_then(|id| self.get_block(id)) - } - - pub(crate) fn locate_with_ord(&self, ord: TermOrdinal) -> usize { - let pos = self - .blocks - .binary_search_by_key(&ord, |block| block.block_addr.first_ordinal); - - match pos { - Ok(pos) => pos, - // Err(0) can't happen as the sstable starts with ordinal zero - Err(pos) => pos - 1, - } - } - - /// Get the [`BlockAddr`] of the block containing the `ord`-th term. - pub(crate) fn get_block_with_ord(&self, ord: TermOrdinal) -> BlockAddr { - // locate_with_ord always returns an index within range - self.get_block(self.locate_with_ord(ord)).unwrap() - } -} - -#[derive(Clone, Eq, PartialEq, Debug)] -pub struct BlockAddr { - pub byte_range: Range<usize>, - pub first_ordinal: u64, -} - -#[derive(Debug, Clone)] -pub(crate) struct BlockMeta { - /// Any byte string that is lexicographically greater or equal to - /// the last key in the block, - /// and yet strictly smaller than the first key in the next block. - pub last_key_or_greater: Vec<u8>, - pub block_addr: BlockAddr, -} - -#[derive(Default)] -pub struct SSTableIndexBuilder { - index: SSTableIndex, -} - -/// Given that left < right, -/// mutates `left into a shorter byte string left'` that -/// matches `left <= left' < right`. -fn find_shorter_str_in_between(left: &mut Vec<u8>, right: &[u8]) { - assert!(&left[..] 
< right); - let common_len = common_prefix_len(left, right); - if left.len() == common_len { - return; - } - // It is possible to do one character shorter in some case, - // but it is not worth the extra complexity - for pos in (common_len + 1)..left.len() { - if left[pos] != u8::MAX { - left[pos] += 1; - left.truncate(pos + 1); - return; - } - } -} - -impl SSTableIndexBuilder { - /// In order to make the index as light as possible, we - /// try to find a shorter alternative to the last key of the last block - /// that is still smaller than the next key. - pub(crate) fn shorten_last_block_key_given_next_key(&mut self, next_key: &[u8]) { - if let Some(last_block) = self.index.blocks.last_mut() { - find_shorter_str_in_between(&mut last_block.last_key_or_greater, next_key); - } - } - - pub fn add_block(&mut self, last_key: &[u8], byte_range: Range<usize>, first_ordinal: u64) { - self.index.blocks.push(BlockMeta { - last_key_or_greater: last_key.to_vec(), - block_addr: BlockAddr { - byte_range, - first_ordinal, - }, - }) - } - - pub fn serialize<W: std::io::Write>(&self, wrt: W) -> io::Result<()> { - // we can't use a plain writer as it would generate an index - let mut sstable_writer = IndexSSTable::delta_writer(wrt); - - // in tests, set a smaller block size to stress-test - #[cfg(test)] - sstable_writer.set_block_len(16); - - let mut previous_key = Vec::with_capacity(crate::DEFAULT_KEY_CAPACITY); - for block in self.index.blocks.iter() { - let keep_len = common_prefix_len(&previous_key, &block.last_key_or_greater); - - sstable_writer.write_suffix(keep_len, &block.last_key_or_greater[keep_len..]); - sstable_writer.write_value(&block.block_addr); - sstable_writer.flush_block_if_required()?; - - previous_key.clear(); - previous_key.extend_from_slice(&block.last_key_or_greater); - } - sstable_writer.flush_block()?; - sstable_writer.finish().write_all(&0u32.to_le_bytes())?; - Ok(()) - } -} - -/// SSTable representing an index -/// -/// `last_key_or_greater` is used as 
the key, the value contains the -/// length and first ordinal of each block. The start offset is implicitly -/// obtained from lengths. -struct IndexSSTable; - -impl SSTable for IndexSSTable { - type Value = BlockAddr; - - type ValueReader = crate::value::index::IndexValueReader; - - type ValueWriter = crate::value::index::IndexValueWriter; -} - -#[cfg(test)] -mod tests { - use common::OwnedBytes; - - use super::{BlockAddr, SSTableIndex, SSTableIndexBuilder}; - use crate::SSTableDataCorruption; - - #[test] - fn test_sstable_index() { - let mut sstable_builder = SSTableIndexBuilder::default(); - sstable_builder.add_block(b"aaa", 10..20, 0u64); - sstable_builder.add_block(b"bbbbbbb", 20..30, 5u64); - sstable_builder.add_block(b"ccc", 30..40, 10u64); - sstable_builder.add_block(b"dddd", 40..50, 15u64); - let mut buffer: Vec<u8> = Vec::new(); - sstable_builder.serialize(&mut buffer).unwrap(); - let buffer = OwnedBytes::new(buffer); - let sstable_index = SSTableIndex::load(buffer).unwrap(); - assert_eq!( - sstable_index.get_block_with_key(b"bbbde"), - Some(BlockAddr { - first_ordinal: 10u64, - byte_range: 30..40 - }) - ); - - assert_eq!(sstable_index.locate_with_key(b"aa").unwrap(), 0); - assert_eq!(sstable_index.locate_with_key(b"aaa").unwrap(), 0); - assert_eq!(sstable_index.locate_with_key(b"aab").unwrap(), 1); - assert_eq!(sstable_index.locate_with_key(b"ccc").unwrap(), 2); - assert!(sstable_index.locate_with_key(b"e").is_none()); - - assert_eq!(sstable_index.locate_with_ord(0), 0); - assert_eq!(sstable_index.locate_with_ord(1), 0); - assert_eq!(sstable_index.locate_with_ord(4), 0); - assert_eq!(sstable_index.locate_with_ord(5), 1); - assert_eq!(sstable_index.locate_with_ord(100), 3); - } - - #[test] - fn test_sstable_with_corrupted_data() { - let mut sstable_builder = SSTableIndexBuilder::default(); - sstable_builder.add_block(b"aaa", 10..20, 0u64); - sstable_builder.add_block(b"bbbbbbb", 20..30, 5u64); - sstable_builder.add_block(b"ccc", 30..40, 10u64); - 
sstable_builder.add_block(b"dddd", 40..50, 15u64); - let mut buffer: Vec<u8> = Vec::new(); - sstable_builder.serialize(&mut buffer).unwrap(); - buffer[2] = 9u8; - let buffer = OwnedBytes::new(buffer); - let data_corruption_err = SSTableIndex::load(buffer).err().unwrap(); - assert!(matches!(data_corruption_err, SSTableDataCorruption)); - } - - #[track_caller] - fn test_find_shorter_str_in_between_aux(left: &[u8], right: &[u8]) { - let mut left_buf = left.to_vec(); - super::find_shorter_str_in_between(&mut left_buf, right); - assert!(left_buf.len() <= left.len()); - assert!(left <= &left_buf); - assert!(&left_buf[..] < right); - } - - #[test] - fn test_find_shorter_str_in_between() { - test_find_shorter_str_in_between_aux(b"", b"hello"); - test_find_shorter_str_in_between_aux(b"abc", b"abcd"); - test_find_shorter_str_in_between_aux(b"abcd", b"abd"); - test_find_shorter_str_in_between_aux(&[0, 0, 0], &[1]); - test_find_shorter_str_in_between_aux(&[0, 0, 0], &[0, 0, 1]); - test_find_shorter_str_in_between_aux(&[0, 0, 255, 255, 255, 0u8], &[0, 1]); - } - - use proptest::prelude::*; - - proptest! 
{ - #![proptest_config(ProptestConfig::with_cases(100))] - #[test] - fn test_proptest_find_shorter_str(left in any::<Vec<u8>>(), right in any::<Vec<u8>>()) { - if left < right { - test_find_shorter_str_in_between_aux(&left, &right); - } - } - } -} diff --git a/sstable/src/sstable_index_v2.rs b/sstable/src/sstable_index_v2.rs new file mode 100644 index 000000000..d7c97c13a --- /dev/null +++ b/sstable/src/sstable_index_v2.rs @@ -0,0 +1,101 @@ +use common::OwnedBytes; + +use crate::{BlockAddr, SSTable, SSTableDataCorruption, TermOrdinal}; + +#[derive(Default, Debug, Clone)] +pub struct SSTableIndex { + blocks: Vec<BlockMeta>, +} + +impl SSTableIndex { + /// Load an index from its binary representation + pub fn load(data: OwnedBytes) -> Result<SSTableIndex, SSTableDataCorruption> { + let mut reader = IndexSSTable::reader(data); + let mut blocks = Vec::new(); + + while reader.advance().map_err(|_| SSTableDataCorruption)? { + blocks.push(BlockMeta { + last_key_or_greater: reader.key().to_vec(), + block_addr: reader.value().clone(), + }); + } + + Ok(SSTableIndex { blocks }) + } + + /// Get the [`BlockAddr`] of the requested block. + pub(crate) fn get_block(&self, block_id: usize) -> Option<BlockAddr> { + self.blocks + .get(block_id) + .map(|block_meta| block_meta.block_addr.clone()) + } + + /// Get the block id of the block that would contain `key`. + /// + /// Returns None if `key` is lexicographically after the last key recorded. + pub(crate) fn locate_with_key(&self, key: &[u8]) -> Option<usize> { + let pos = self + .blocks + .binary_search_by_key(&key, |block| &block.last_key_or_greater); + match pos { + Ok(pos) => Some(pos), + Err(pos) => { + if pos < self.blocks.len() { + Some(pos) + } else { + // after end of last block: no block matches + None + } + } + } + } + + /// Get the [`BlockAddr`] of the block that would contain `key`. + /// + /// Returns None if `key` is lexicographically after the last key recorded. 
+ pub fn get_block_with_key(&self, key: &[u8]) -> Option<BlockAddr> { + self.locate_with_key(key).and_then(|id| self.get_block(id)) + } + + pub(crate) fn locate_with_ord(&self, ord: TermOrdinal) -> usize { + let pos = self + .blocks + .binary_search_by_key(&ord, |block| block.block_addr.first_ordinal); + + match pos { + Ok(pos) => pos, + // Err(0) can't happen as the sstable starts with ordinal zero + Err(pos) => pos - 1, + } + } + + /// Get the [`BlockAddr`] of the block containing the `ord`-th term. + pub(crate) fn get_block_with_ord(&self, ord: TermOrdinal) -> BlockAddr { + // locate_with_ord always returns an index within range + self.get_block(self.locate_with_ord(ord)).unwrap() + } +} + +#[derive(Debug, Clone)] +pub(crate) struct BlockMeta { + /// Any byte string that is lexicographically greater or equal to + /// the last key in the block, + /// and yet strictly smaller than the first key in the next block. + pub last_key_or_greater: Vec<u8>, + pub block_addr: BlockAddr, +} + +/// SSTable representing an index +/// +/// `last_key_or_greater` is used as the key, the value contains the +/// length and first ordinal of each block. The start offset is implicitly +/// obtained from lengths. 
/// Block index of an sstable, in one of the supported layouts.
///
/// Every method of this enum simply forwards to the wrapped variant.
#[derive(Debug, Clone)]
pub enum SSTableIndex {
    /// Index implemented by `crate::sstable_index_v2`.
    V2(crate::sstable_index_v2::SSTableIndex),
    /// Index implemented by [`SSTableIndexV3`].
    V3(SSTableIndexV3),
    /// Index implemented by [`SSTableIndexV3Empty`], which answers every lookup
    /// with the same single block.
    V3Empty(SSTableIndexV3Empty),
}
+ pub fn get_block_with_key(&self, key: &[u8]) -> Option<BlockAddr> { + match self { + SSTableIndex::V2(v2_index) => v2_index.get_block_with_key(key), + SSTableIndex::V3(v3_index) => v3_index.get_block_with_key(key), + SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block_with_key(key), + } + } + + pub(crate) fn locate_with_ord(&self, ord: TermOrdinal) -> u64 { + match self { + SSTableIndex::V2(v2_index) => v2_index.locate_with_ord(ord) as u64, + SSTableIndex::V3(v3_index) => v3_index.locate_with_ord(ord), + SSTableIndex::V3Empty(v3_empty) => v3_empty.locate_with_ord(ord), + } + } + + /// Get the [`BlockAddr`] of the block containing the `ord`-th term. + pub(crate) fn get_block_with_ord(&self, ord: TermOrdinal) -> BlockAddr { + match self { + SSTableIndex::V2(v2_index) => v2_index.get_block_with_ord(ord), + SSTableIndex::V3(v3_index) => v3_index.get_block_with_ord(ord), + SSTableIndex::V3Empty(v3_empty) => v3_empty.get_block_with_ord(ord), + } + } +} + +#[derive(Debug, Clone)] +pub struct SSTableIndexV3 { + fst_index: Arc<Map<OwnedBytes>>, + block_addr_store: BlockAddrStore, +} + +impl SSTableIndexV3 { + /// Load an index from its binary representation + pub fn load( + data: OwnedBytes, + fst_length: u64, + ) -> Result<SSTableIndexV3, SSTableDataCorruption> { + let (fst_slice, block_addr_store_slice) = data.split(fst_length as usize); + let fst_index = Fst::new(fst_slice) + .map_err(|_| SSTableDataCorruption)? + .into(); + let block_addr_store = + BlockAddrStore::open(block_addr_store_slice).map_err(|_| SSTableDataCorruption)?; + + Ok(SSTableIndexV3 { + fst_index: Arc::new(fst_index), + block_addr_store, + }) + } + + /// Get the [`BlockAddr`] of the requested block. + pub(crate) fn get_block(&self, block_id: u64) -> Option<BlockAddr> { + self.block_addr_store.get(block_id) + } + + /// Get the block id of the block that would contain `key`. + /// + /// Returns None if `key` is lexicographically after the last key recorded. 
+ pub(crate) fn locate_with_key(&self, key: &[u8]) -> Option<u64> { + self.fst_index + .range() + .ge(key) + .into_stream() + .next() + .map(|(_key, id)| id) + } + + /// Get the [`BlockAddr`] of the block that would contain `key`. + /// + /// Returns None if `key` is lexicographically after the last key recorded. + pub fn get_block_with_key(&self, key: &[u8]) -> Option<BlockAddr> { + self.locate_with_key(key).and_then(|id| self.get_block(id)) + } + + pub(crate) fn locate_with_ord(&self, ord: TermOrdinal) -> u64 { + self.block_addr_store.binary_search_ord(ord).0 + } + + /// Get the [`BlockAddr`] of the block containing the `ord`-th term. + pub(crate) fn get_block_with_ord(&self, ord: TermOrdinal) -> BlockAddr { + self.block_addr_store.binary_search_ord(ord).1 + } +} + +#[derive(Debug, Clone)] +pub struct SSTableIndexV3Empty { + block_addr: BlockAddr, +} + +impl SSTableIndexV3Empty { + pub fn load(index_start_pos: usize) -> SSTableIndexV3Empty { + SSTableIndexV3Empty { + block_addr: BlockAddr { + first_ordinal: 0, + byte_range: 0..index_start_pos, + }, + } + } + + /// Get the [`BlockAddr`] of the requested block. + pub(crate) fn get_block(&self, _block_id: u64) -> Option<BlockAddr> { + Some(self.block_addr.clone()) + } + + /// Get the block id of the block that would contain `key`. + /// + /// Returns None if `key` is lexicographically after the last key recorded. + pub(crate) fn locate_with_key(&self, _key: &[u8]) -> Option<u64> { + Some(0) + } + + /// Get the [`BlockAddr`] of the block that would contain `key`. + /// + /// Returns None if `key` is lexicographically after the last key recorded. + pub fn get_block_with_key(&self, _key: &[u8]) -> Option<BlockAddr> { + Some(self.block_addr.clone()) + } + + pub(crate) fn locate_with_ord(&self, _ord: TermOrdinal) -> u64 { + 0 + } + + /// Get the [`BlockAddr`] of the block containing the `ord`-th term. 
+ pub(crate) fn get_block_with_ord(&self, _ord: TermOrdinal) -> BlockAddr { + self.block_addr.clone() + } +} +#[derive(Clone, Eq, PartialEq, Debug)] +pub struct BlockAddr { + pub first_ordinal: u64, + pub byte_range: Range<usize>, +} + +impl BlockAddr { + fn to_block_start(&self) -> BlockStartAddr { + BlockStartAddr { + first_ordinal: self.first_ordinal, + byte_range_start: self.byte_range.start, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct BlockStartAddr { + first_ordinal: u64, + byte_range_start: usize, +} + +impl BlockStartAddr { + fn to_block_addr(&self, byte_range_end: usize) -> BlockAddr { + BlockAddr { + first_ordinal: self.first_ordinal, + byte_range: self.byte_range_start..byte_range_end, + } + } +} + +#[derive(Debug, Clone)] +pub(crate) struct BlockMeta { + /// Any byte string that is lexicographically greater or equal to + /// the last key in the block, + /// and yet strictly smaller than the first key in the next block. + pub last_key_or_greater: Vec<u8>, + pub block_addr: BlockAddr, +} + +impl BinarySerializable for BlockStartAddr { + fn serialize<W: Write + ?Sized>(&self, writer: &mut W) -> io::Result<()> { + let start = self.byte_range_start as u64; + start.serialize(writer)?; + self.first_ordinal.serialize(writer) + } + + fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> { + let byte_range_start = u64::deserialize(reader)? as usize; + let first_ordinal = u64::deserialize(reader)?; + Ok(BlockStartAddr { + first_ordinal, + byte_range_start, + }) + } + + // Provided method + fn num_bytes(&self) -> u64 { + BlockStartAddr::SIZE_IN_BYTES as u64 + } +} + +impl FixedSize for BlockStartAddr { + const SIZE_IN_BYTES: usize = 2 * u64::SIZE_IN_BYTES; +} + +/// Given that left < right, +/// mutates `left into a shorter byte string left'` that +/// matches `left <= left' < right`. +fn find_shorter_str_in_between(left: &mut Vec<u8>, right: &[u8]) { + assert!(&left[..] 
< right); + let common_len = common_prefix_len(left, right); + if left.len() == common_len { + return; + } + // It is possible to do one character shorter in some case, + // but it is not worth the extra complexity + for pos in (common_len + 1)..left.len() { + if left[pos] != u8::MAX { + left[pos] += 1; + left.truncate(pos + 1); + return; + } + } +} + +#[derive(Default)] +pub struct SSTableIndexBuilder { + blocks: Vec<BlockMeta>, +} + +impl SSTableIndexBuilder { + /// In order to make the index as light as possible, we + /// try to find a shorter alternative to the last key of the last block + /// that is still smaller than the next key. + pub(crate) fn shorten_last_block_key_given_next_key(&mut self, next_key: &[u8]) { + if let Some(last_block) = self.blocks.last_mut() { + find_shorter_str_in_between(&mut last_block.last_key_or_greater, next_key); + } + } + + pub fn add_block(&mut self, last_key: &[u8], byte_range: Range<usize>, first_ordinal: u64) { + self.blocks.push(BlockMeta { + last_key_or_greater: last_key.to_vec(), + block_addr: BlockAddr { + byte_range, + first_ordinal, + }, + }) + } + + pub fn serialize<W: std::io::Write>(&self, wrt: W) -> io::Result<u64> { + if self.blocks.len() <= 1 { + return Ok(0); + } + let counting_writer = common::CountingWriter::wrap(wrt); + let mut map_builder = MapBuilder::new(counting_writer).map_err(fst_error_to_io_error)?; + for (i, block) in self.blocks.iter().enumerate() { + map_builder + .insert(&block.last_key_or_greater, i as u64) + .map_err(fst_error_to_io_error)?; + } + let counting_writer = map_builder.into_inner().map_err(fst_error_to_io_error)?; + let written_bytes = counting_writer.written_bytes(); + let mut wrt = counting_writer.finish(); + + let mut block_store_writer = BlockAddrStoreWriter::new(); + for block in &self.blocks { + block_store_writer.write_block_meta(block.block_addr.clone())?; + } + block_store_writer.serialize(&mut wrt)?; + + Ok(written_bytes) + } +} + +fn fst_error_to_io_error(error: 
/// Header describing one group of block addresses in the block address store.
///
/// The first address of the group is kept as `ref_block_addr`; the following ones
/// are bit-packed as the difference between their actual value and a linear
/// prediction (`slope * index`), offset by a shift.
#[derive(Debug)]
struct BlockAddrBlockMetadata {
    /// Byte offset, within the address buffer, at which the bit-packed entries of
    /// this group start.
    offset: u64,
    /// Start address (byte range start and first ordinal) of the first block of the
    /// group; the packed values are relative to it.
    ref_block_addr: BlockStartAddr,
    /// Predicted increase of the byte range start per entry.
    range_start_slope: u32,
    /// Predicted increase of the first ordinal per entry.
    first_ordinal_slope: u32,
    /// Number of bits used to pack each byte range start.
    range_start_nbits: u8,
    /// Number of bits used to pack each first ordinal.
    first_ordinal_nbits: u8,
    /// Number of block addresses in the group, not counting the reference one.
    block_len: u16,
    // these fields are computed on deserialization, and not stored
    /// `1 << (range_start_nbits - 1)`: added to a range start before packing and
    /// subtracted again when decoding.
    range_shift: i64,
    /// `1 << (first_ordinal_nbits - 1)`: added to a first ordinal before packing and
    /// subtracted again when decoding.
    ordinal_shift: i64,
}
    /// Searches this group for the block address matching `target_ord`.
    ///
    /// `data` is the address buffer starting at this group's `offset`.
    /// Returns the position of the block inside the group (0 being the reference
    /// block) together with its decoded [`BlockAddr`].
    ///
    /// NOTE(review): `target_ord` is presumably never below the reference block's
    /// first ordinal (the subtraction below is unsigned) — confirm against callers.
    fn bisect_for_ord(&self, data: &[u8], target_ord: TermOrdinal) -> (u64, BlockAddr) {
        // Packed ordinals are relative to the reference block, so the target is too.
        let inner_target_ord = target_ord - self.ref_block_addr.first_ordinal;
        let num_bits = self.num_bits() as usize;
        let range_start_nbits = self.range_start_nbits as usize;
        // Decodes the relative first ordinal of packed entry `index`. Each entry is
        // `num_bits` wide and stores the range start first, then the ordinal.
        let get_ord = |index| {
            extract_bits(
                data,
                num_bits * index as usize + range_start_nbits,
                self.first_ordinal_nbits,
            ) + self.first_ordinal_slope as u64 * (index + 1)
                - self.ordinal_shift as u64
        };

        let inner_offset = match binary_search(self.block_len as u64, |index| {
            get_ord(index).cmp(&inner_target_ord)
        }) {
            // Packed entry `index` describes the block at position `index + 1`,
            // position 0 being the reference block.
            Ok(inner_offset) => inner_offset + 1,
            // No exact match: the insertion point is used as the position as is.
            Err(inner_offset) => inner_offset,
        };
        // we can unwrap because inner_offset <= self.block_len
        (
            inner_offset,
            self.deserialize_block_addr(data, inner_offset as usize)
                .unwrap(),
        )
    }
+ let mut buf = [0u8; 8]; + let data_to_copy = &data[addr_byte..]; + let nbytes = data_to_copy.len(); + buf[..nbytes].copy_from_slice(data_to_copy); + u64::from_le_bytes(buf) + }; + let val_shifted_unmasked = val_unshifted_unmasked >> bit_shift; + let mask = (1u64 << u64::from(num_bits)) - 1; + val_shifted_unmasked & mask +} + +impl BinarySerializable for BlockAddrBlockMetadata { + fn serialize<W: Write + ?Sized>(&self, write: &mut W) -> io::Result<()> { + self.offset.serialize(write)?; + self.ref_block_addr.serialize(write)?; + self.range_start_slope.serialize(write)?; + self.first_ordinal_slope.serialize(write)?; + write.write_all(&[self.first_ordinal_nbits, self.range_start_nbits])?; + self.block_len.serialize(write)?; + self.num_bits(); + Ok(()) + } + + fn deserialize<R: Read>(reader: &mut R) -> io::Result<Self> { + let offset = u64::deserialize(reader)?; + let ref_block_addr = BlockStartAddr::deserialize(reader)?; + let range_start_slope = u32::deserialize(reader)?; + let first_ordinal_slope = u32::deserialize(reader)?; + let mut buffer = [0u8; 2]; + reader.read_exact(&mut buffer)?; + let first_ordinal_nbits = buffer[0]; + let range_start_nbits = buffer[1]; + let block_len = u16::deserialize(reader)?; + Ok(BlockAddrBlockMetadata { + offset, + ref_block_addr, + range_start_slope, + first_ordinal_slope, + range_start_nbits, + first_ordinal_nbits, + block_len, + range_shift: 1 << (range_start_nbits - 1), + ordinal_shift: 1 << (first_ordinal_nbits - 1), + }) + } +} + +impl FixedSize for BlockAddrBlockMetadata { + const SIZE_IN_BYTES: usize = u64::SIZE_IN_BYTES + + BlockStartAddr::SIZE_IN_BYTES + + 2 * u32::SIZE_IN_BYTES + + 2 * u8::SIZE_IN_BYTES + + u16::SIZE_IN_BYTES; +} + +#[derive(Debug, Clone)] +struct BlockAddrStore { + block_meta_bytes: OwnedBytes, + addr_bytes: OwnedBytes, +} + +impl BlockAddrStore { + fn open(term_info_store_file: OwnedBytes) -> io::Result<BlockAddrStore> { + let (mut len_slice, main_slice) = term_info_store_file.split(8); + let len = 
/// Binary search over the index range `0..max`.
///
/// `cmp_fn(i)` tells how the element at index `i` compares to the searched value.
/// Returns `Ok(i)` for an index comparing `Equal`, otherwise `Err(i)` where `i` is
/// the index at which the searched value would have to be inserted.
fn binary_search(max: u64, cmp_fn: impl Fn(u64) -> std::cmp::Ordering) -> Result<u64, u64> {
    use std::cmp::Ordering;

    let (mut lower, mut upper) = (0u64, max);
    while lower < upper {
        let probe = lower + (upper - lower) / 2;
        match cmp_fn(probe) {
            Ordering::Less => lower = probe + 1,
            Ordering::Greater => upper = probe,
            Ordering::Equal => return Ok(probe),
        }
    }
    Err(lower)
}
+ range_start_slope, + first_ordinal_slope, + range_start_nbits, + first_ordinal_nbits, + block_len: self.block_addrs.len() as u16 - 1, + range_shift, + ordinal_shift, + }; + block_addr_block_meta.serialize(&mut self.buffer_block_metas)?; + + let mut bit_packer = BitPacker::new(); + + for (i, block_addr) in self.block_addrs.iter().enumerate().skip(1) { + let range_pred = (range_start_slope as usize * i) as i64; + bit_packer.write( + (block_addr.byte_range.start as i64 - range_pred + range_shift) as u64, + range_start_nbits, + &mut self.buffer_addrs, + )?; + let first_ordinal_pred = (first_ordinal_slope as u64 * i as u64) as i64; + bit_packer.write( + (block_addr.first_ordinal as i64 - first_ordinal_pred + ordinal_shift) as u64, + first_ordinal_nbits, + &mut self.buffer_addrs, + )?; + } + + let range_pred = (range_start_slope as usize * self.block_addrs.len()) as i64; + bit_packer.write( + (last_block_addr.byte_range.end as i64 - range_pred + range_shift) as u64, + range_start_nbits, + &mut self.buffer_addrs, + )?; + bit_packer.flush(&mut self.buffer_addrs)?; + + self.block_addrs.clear(); + Ok(()) + } + + fn write_block_meta(&mut self, block_addr: BlockAddr) -> io::Result<()> { + self.block_addrs.push(block_addr); + if self.block_addrs.len() >= STORE_BLOCK_LEN { + self.flush_block()?; + } + Ok(()) + } + + fn serialize<W: std::io::Write>(&mut self, wrt: &mut W) -> io::Result<()> { + self.flush_block()?; + let len = self.buffer_block_metas.len() as u64; + len.serialize(wrt)?; + wrt.write_all(&self.buffer_block_metas)?; + wrt.write_all(&self.buffer_addrs)?; + Ok(()) + } +} + +/// Given an iterator over (index, value), returns the slope, and number of bits needed to +/// represente the error to a prediction made by this slope. +/// +/// The iterator may be empty, but all indexes in it must be non-zero. 
+fn find_best_slope(elements: impl Iterator<Item = (usize, u64)> + Clone) -> (u32, u8) { + let slope_iterator = elements.clone(); + let derivation_iterator = elements; + + let mut min_slope_idx = 1; + let mut min_slope_val = 0; + let mut min_slope = u32::MAX; + let mut max_slope_idx = 1; + let mut max_slope_val = 0; + let mut max_slope = 0; + for (index, value) in slope_iterator { + let slope = (value / index as u64) as u32; + if slope <= min_slope { + min_slope = slope; + min_slope_idx = index; + min_slope_val = value; + } + if slope >= max_slope { + max_slope = slope; + max_slope_idx = index; + max_slope_val = value; + } + } + + // above is an heuristic giving the "highest" and "lowest" point. It's imperfect in that in that + // a point that appear earlier might have a high slope derivation, but a smaller absolute + // derivation than a latter point. + // The actual best values can be obtained by using the symplex method, but the improvement is + // likely minimal, and computation is way more complexe. + // + // Assuming these point are the furthest up and down, we find the slope that would cause the + // same positive derivation for the highest as negative derivation for the lowest. + // A is the optimal slope. B is the derivation to the guess + // + // 0 = min_slope_val - min_slope_idx * A - B + // 0 = max_slope_val - max_slope_idx * A + B + // + // 0 = min_slope_val + max_slope_val - (min_slope_idx + max_slope_idx) * A + // (min_slope_val + max_slope_val) / (min_slope_idx + max_slope_idx) = A + // + // we actually add some correcting factor to have proper rounding, not truncation. + + let denominator = (min_slope_idx + max_slope_idx) as u64; + let final_slope = ((min_slope_val + max_slope_val + denominator / 2) / denominator) as u32; + + // we don't solve for B because our choice of point is suboptimal, so it's actually a lower + // bound and we need to iterate to find the actual worst value. 
+ + let max_derivation: u64 = derivation_iterator + .map(|(index, value)| (value as i64 - final_slope as i64 * index as i64).unsigned_abs()) + .max() + .unwrap_or(0); + + (final_slope, compute_num_bits(max_derivation) + 1) +} + +#[cfg(test)] +mod tests { + use common::OwnedBytes; + + use super::{BlockAddr, SSTableIndexBuilder, SSTableIndexV3}; + use crate::SSTableDataCorruption; + + #[test] + fn test_sstable_index() { + let mut sstable_builder = SSTableIndexBuilder::default(); + sstable_builder.add_block(b"aaa", 10..20, 0u64); + sstable_builder.add_block(b"bbbbbbb", 20..30, 5u64); + sstable_builder.add_block(b"ccc", 30..40, 10u64); + sstable_builder.add_block(b"dddd", 40..50, 15u64); + let mut buffer: Vec<u8> = Vec::new(); + let fst_len = sstable_builder.serialize(&mut buffer).unwrap(); + let buffer = OwnedBytes::new(buffer); + let sstable_index = SSTableIndexV3::load(buffer, fst_len).unwrap(); + assert_eq!( + sstable_index.get_block_with_key(b"bbbde"), + Some(BlockAddr { + first_ordinal: 10u64, + byte_range: 30..40 + }) + ); + + assert_eq!(sstable_index.locate_with_key(b"aa").unwrap(), 0); + assert_eq!(sstable_index.locate_with_key(b"aaa").unwrap(), 0); + assert_eq!(sstable_index.locate_with_key(b"aab").unwrap(), 1); + assert_eq!(sstable_index.locate_with_key(b"ccc").unwrap(), 2); + assert!(sstable_index.locate_with_key(b"e").is_none()); + + assert_eq!(sstable_index.locate_with_ord(0), 0); + assert_eq!(sstable_index.locate_with_ord(1), 0); + assert_eq!(sstable_index.locate_with_ord(4), 0); + assert_eq!(sstable_index.locate_with_ord(5), 1); + assert_eq!(sstable_index.locate_with_ord(100), 3); + } + + #[test] + fn test_sstable_with_corrupted_data() { + let mut sstable_builder = SSTableIndexBuilder::default(); + sstable_builder.add_block(b"aaa", 10..20, 0u64); + sstable_builder.add_block(b"bbbbbbb", 20..30, 5u64); + sstable_builder.add_block(b"ccc", 30..40, 10u64); + sstable_builder.add_block(b"dddd", 40..50, 15u64); + let mut buffer: Vec<u8> = Vec::new(); + let 
fst_len = sstable_builder.serialize(&mut buffer).unwrap(); + buffer[2] = 9u8; + let buffer = OwnedBytes::new(buffer); + let data_corruption_err = SSTableIndexV3::load(buffer, fst_len).err().unwrap(); + assert!(matches!(data_corruption_err, SSTableDataCorruption)); + } + + #[track_caller] + fn test_find_shorter_str_in_between_aux(left: &[u8], right: &[u8]) { + let mut left_buf = left.to_vec(); + super::find_shorter_str_in_between(&mut left_buf, right); + assert!(left_buf.len() <= left.len()); + assert!(left <= &left_buf); + assert!(&left_buf[..] < right); + } + + #[test] + fn test_find_shorter_str_in_between() { + test_find_shorter_str_in_between_aux(b"", b"hello"); + test_find_shorter_str_in_between_aux(b"abc", b"abcd"); + test_find_shorter_str_in_between_aux(b"abcd", b"abd"); + test_find_shorter_str_in_between_aux(&[0, 0, 0], &[1]); + test_find_shorter_str_in_between_aux(&[0, 0, 0], &[0, 0, 1]); + test_find_shorter_str_in_between_aux(&[0, 0, 255, 255, 255, 0u8], &[0, 1]); + } + + use proptest::prelude::*; + + proptest! 
[package]
name = "fuzz_test"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
ahash = "0.8.7"
rand = "0.8.5"
rand_distr = "0.4.3"
# Must stay compatible with the version declared in `../Cargo.toml` (now 0.3.0):
# Cargo checks a path dependency against its `version` requirement.
tantivy-stacker = { version = "0.3.0", path = ".." }

[workspace]
{ self.pages.len() * PAGE_SIZE } + /// Returns the number of bytes allocated in the arena. + pub fn len(&self) -> usize { + self.pages.len().saturating_sub(1) * PAGE_SIZE + self.pages.last().unwrap().len + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + #[inline] pub fn write_at<Item: Copy + 'static>(&mut self, addr: Addr, val: Item) { let dest = self.slice_mut(addr, std::mem::size_of::<Item>()); @@ -189,6 +198,11 @@ struct Page { impl Page { fn new(page_id: usize) -> Page { + // We use 32-bits addresses. + // - 20 bits for the in-page addressing + // - 12 bits for the page id. + // This limits us to 2^12 - 1=4095 for the page id. + assert!(page_id < 4096); Page { page_id, len: 0, @@ -238,6 +252,7 @@ impl Page { mod tests { use super::MemoryArena; + use crate::memory_arena::PAGE_SIZE; #[test] fn test_arena_allocate_slice() { @@ -255,6 +270,31 @@ mod tests { assert_eq!(arena.slice(addr_b, b.len()), b); } + #[test] + fn test_arena_allocate_end_of_page() { + let mut arena = MemoryArena::default(); + + // A big block + let len_a = PAGE_SIZE - 2; + let addr_a = arena.allocate_space(len_a); + *arena.slice_mut(addr_a, len_a).last_mut().unwrap() = 1; + + // Single bytes + let addr_b = arena.allocate_space(1); + arena.slice_mut(addr_b, 1)[0] = 2; + + let addr_c = arena.allocate_space(1); + arena.slice_mut(addr_c, 1)[0] = 3; + + let addr_d = arena.allocate_space(1); + arena.slice_mut(addr_d, 1)[0] = 4; + + assert_eq!(arena.slice(addr_a, len_a)[len_a - 1], 1); + assert_eq!(arena.slice(addr_b, 1)[0], 2); + assert_eq!(arena.slice(addr_c, 1)[0], 3); + assert_eq!(arena.slice(addr_d, 1)[0], 4); + } + #[derive(Clone, Copy, Debug, Eq, PartialEq)] struct MyTest { pub a: usize, diff --git a/stacker/src/shared_arena_hashmap.rs b/stacker/src/shared_arena_hashmap.rs index 0dbae3dfd..0e50315d4 100644 --- a/stacker/src/shared_arena_hashmap.rs +++ b/stacker/src/shared_arena_hashmap.rs @@ -295,6 +295,8 @@ impl SharedArenaHashMap { /// will be in charge of returning a default 
value. /// If the key already as an associated value, then it will be passed /// `Some(previous_value)`. + /// + /// The key will be truncated to u16::MAX bytes. #[inline] pub fn mutate_or_create<V>( &mut self, @@ -308,6 +310,8 @@ impl SharedArenaHashMap { if self.is_saturated() { self.resize(); } + // Limit the key size to u16::MAX + let key = &key[..std::cmp::min(key.len(), u16::MAX as usize)]; let hash = self.get_hash(key); let mut probe = self.probe(hash); let mut bucket = probe.next_probe(); @@ -379,6 +383,36 @@ mod tests { } assert_eq!(vanilla_hash_map.len(), 2); } + + #[test] + fn test_long_key_truncation() { + // Keys longer than u16::MAX are truncated. + let mut memory_arena = MemoryArena::default(); + let mut hash_map: SharedArenaHashMap = SharedArenaHashMap::default(); + let key1 = (0..u16::MAX as usize).map(|i| i as u8).collect::<Vec<_>>(); + hash_map.mutate_or_create(&key1, &mut memory_arena, |opt_val: Option<u32>| { + assert_eq!(opt_val, None); + 4u32 + }); + // Due to truncation, this key is the same as key1 + let key2 = (0..u16::MAX as usize + 1) + .map(|i| i as u8) + .collect::<Vec<_>>(); + hash_map.mutate_or_create(&key2, &mut memory_arena, |opt_val: Option<u32>| { + assert_eq!(opt_val, Some(4)); + 3u32 + }); + let mut vanilla_hash_map = HashMap::new(); + let iter_values = hash_map.iter(&memory_arena); + for (key, addr) in iter_values { + let val: u32 = memory_arena.read(addr); + vanilla_hash_map.insert(key.to_owned(), val); + assert_eq!(key.len(), key1[..].len()); + assert_eq!(key, &key1[..]) + } + assert_eq!(vanilla_hash_map.len(), 1); // Both map to the same key + } + #[test] fn test_empty_hashmap() { let memory_arena = MemoryArena::default(); diff --git a/tokenizer-api/Cargo.toml b/tokenizer-api/Cargo.toml index e8e47589f..0ebcbb89a 100644 --- a/tokenizer-api/Cargo.toml +++ b/tokenizer-api/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tantivy-tokenizer-api" -version = "0.2.0" +version = "0.3.0" license = "MIT" edition = "2021" description = 
"Tokenizer API of tantivy"