Compare commits

..

8 Commits

Author SHA1 Message Date
Pascal Seitz
0bdec77410 pub method on Term 2026-02-24 13:31:51 +01:00
Pascal Seitz
1a1c29c785 allow Searcher to be constructed without index 2026-02-17 17:56:33 +01:00
Pascal Seitz
8a16afa2f1 add to_json->Value method 2026-02-17 13:21:23 +01:00
Pascal Seitz
e841cebba4 convert StoreReader to trait
this will remove the DocumentDeserialize (maybe added later in a
different form)
2026-02-16 17:33:49 +01:00
Pascal Seitz
05f255b757 add async methods for quickwit 2026-02-16 10:32:32 +01:00
Pascal Seitz
e6318e1591 add comments, remove fieldnorms 2026-02-12 14:13:33 +01:00
Pascal Seitz
70bb97231b remove fieldnorms_readers 2026-02-11 19:40:45 +01:00
Paul Masurel
6038455761 First stab at tantivy's codec
Convert SegmentReader, InvertedIndexReader and postinglists to traits.
Add special functions to pushdown certain performance methods to keep
them strictly typed.

We rely on a ObjectSafeCodec contraption to avoid the proliferation of generics.
That object's point is to make sure we can build TermScorer with a concrete
codec specific type before reboxing it. (same thing for PhraseScorer).

fix performance regression: fix incorrect scorer cast for buffered union
bock wand
2026-02-11 15:11:29 +01:00
179 changed files with 3039 additions and 12047 deletions

View File

@@ -1,87 +0,0 @@
---
name: update-changelog
description: Update CHANGELOG.md with merged PRs since the last changelog update, categorized by type
---
# Update Changelog
This skill updates CHANGELOG.md with merged PRs that aren't already listed.
## Step 1: Determine the changelog scope
Read `CHANGELOG.md` to identify the current unreleased version section at the top (e.g., `Tantivy 0.26 (Unreleased)`).
Collect all PR numbers already mentioned in the unreleased section by extracting `#NNNN` references.
## Step 2: Find merged PRs not yet in the changelog
Use `gh` to list recently merged PRs from the upstream repo:
```bash
gh pr list --repo quickwit-oss/tantivy --state merged --limit 100 --json number,title,author,labels,mergedAt
```
Filter out any PRs whose number already appears in the unreleased section of the changelog.
## Step 3: Consolidate related PRs
Before categorizing, group PRs that belong to the same logical change. This is critical for producing a clean changelog. Use PR descriptions, titles, cross-references, and the files touched to identify relationships.
**Merge follow-up PRs into the original:**
- If a PR is a bugfix, refinement, or follow-up to another PR in the same unreleased cycle, combine them into a single changelog entry with multiple `[#N](url)` links.
- Also consolidate PRs that touch the same feature area even if not explicitly linked — e.g., a PR fixing an edge case in a new API should be folded into the entry for the PR that introduced that API.
**Filter out bugfixes on unreleased features:**
- If a bugfix PR fixes something introduced by another PR in the **same unreleased version**, it must NOT appear as a separate Bugfixes entry. Instead, silently fold it into the original feature/improvement entry. The changelog should describe the final shipped state, not the development history.
- To detect this: check if the bugfix PR references or reverts changes from another PR in the same release cycle, or if it touches code that was newly added (not present in the previous release).
## Step 4: Review the actual code diff
**Do not rely on PR titles or descriptions alone.** For every candidate PR, run `gh pr diff <number> --repo quickwit-oss/tantivy` and read the actual changes. PR titles are often misleading — the diff is the source of truth.
**What to look for in the diff:**
- Does it change observable behavior, public API surface, or performance characteristics?
- Is the change something a user of the library would notice or need to know about?
- Could the change break existing code (API changes, removed features)?
**Skip PRs where the diff reveals the change is not meaningful enough for the changelog** — e.g., cosmetic renames, trivial visibility tweaks, test-only changes, etc.
## Step 5: Categorize each PR group
For each PR (or consolidated group) that survived the diff review, determine its category:
- **Bugfixes** — fixes to behavior that existed in the **previous release**. NOT fixes to features introduced in this release cycle.
- **Features/Improvements** — new features, API additions, new options, improvements that change user-facing behavior or add new capabilities.
- **Performance** — optimizations, speed improvements, memory reductions. **If a PR adds new API whose primary purpose is enabling a performance optimization, categorize it as Performance, not Features.** The deciding question is: does a user benefit from this because of new functionality, or because things got faster/leaner? For example, a new trait method that exists solely to enable cheaper intersection ordering is Performance, not a Feature.
If a PR doesn't clearly fit any category (e.g., CI-only changes, internal refactors with no user-facing impact, dependency bumps with no behavior change), skip it — not everything belongs in the changelog.
When unclear, use your best judgment or ask the user.
## Step 6: Format entries
Each entry must follow this exact format:
```
- Description [#NUMBER](https://github.com/quickwit-oss/tantivy/pull/NUMBER)(@author)
```
Rules:
- The description should be concise and describe the user-facing change (not the implementation). Describe the final shipped state, not the incremental development steps.
- Use sub-categories with bold headers when multiple entries relate to the same area (e.g., `- **Aggregation**` with indented entries beneath). Follow the existing grouping style in the changelog.
- Author is the GitHub username from the PR, prefixed with `@`. For consolidated entries, include all contributing authors.
- For consolidated PRs, list all PR links in a single entry: `[#100](url) [#110](url)` (see existing entries for examples).
## Step 7: Present changes to the user
Show the user the proposed changelog entries grouped by category **before** editing the file. Ask for confirmation or adjustments.
## Step 8: Update CHANGELOG.md
Insert the new entries into the appropriate sections of the unreleased version block. If a section doesn't exist yet, create it following the order: Bugfixes, Features/Improvements, Performance.
Append new entries at the end of each section (before the next section header or version header).
## Step 9: Verify
Read back the updated unreleased section and display it to the user for final review.

View File

@@ -6,8 +6,6 @@ updates:
interval: daily
time: "20:00"
open-pull-requests-limit: 10
cooldown:
default-days: 2
- package-ecosystem: "github-actions"
directory: "/"
@@ -15,5 +13,3 @@ updates:
interval: daily
time: "20:00"
open-pull-requests-limit: 10
cooldown:
default-days: 2

View File

@@ -4,9 +4,6 @@ on:
push:
branches: [main]
permissions:
contents: read
# Ensures that we cancel running jobs for the same PR / same workflow.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -15,20 +12,16 @@ concurrency:
jobs:
coverage:
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- uses: actions/checkout@v4
- name: Install Rust
run: rustup toolchain install nightly-2025-12-01 --profile minimal --component llvm-tools-preview
- uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1
- uses: taiki-e/install-action@e4b3a0453201addddc06d3a72db90326aad87084 # cargo-llvm-cov
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@cargo-llvm-cov
- name: Generate code coverage
run: cargo +nightly-2025-12-01 llvm-cov --all-features --workspace --doctests --lcov --output-path lcov.info
- name: Upload coverage to Codecov
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
uses: codecov/codecov-action@v3
continue-on-error: true
with:
token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos

View File

@@ -8,9 +8,6 @@ env:
CARGO_TERM_COLOR: always
NUM_FUNCTIONAL_TEST_ITERATIONS: 20000
permissions:
contents: read
# Ensures that we cancel running jobs for the same PR / same workflow.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -21,13 +18,10 @@ jobs:
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- uses: actions/checkout@v4
- name: Install stable
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
uses: actions-rs/toolchain@v1
with:
toolchain: stable
profile: minimal

View File

@@ -1,49 +0,0 @@
name: OpenSSF Scorecard
on:
schedule:
- cron: '0 0 * * 0'
push:
branches:
- main
permissions:
contents: read
jobs:
analysis:
name: Scorecards analysis
runs-on: ubuntu-latest
permissions:
# Needed to upload the results to code-scanning dashboard.
security-events: write
# Needed to publish results
id-token: write
steps:
- name: 'Checkout code'
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with:
persist-credentials: false
- name: 'Run analysis'
uses: ossf/scorecard-action@4eaacf0543bb3f2c246792bd56e8cdeffafb205a # v2.4.3
with:
results_file: results.sarif
results_format: sarif
repo_token: ${{ secrets.GITHUB_TOKEN }}
publish_results: true
# Upload the results as artifacts.
- name: 'Upload artifact'
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: SARIF file
path: results.sarif
retention-days: 5
# Upload the results to GitHub's code scanning dashboard.
- name: 'Upload to code-scanning'
uses: github/codeql-action/upload-sarif@87557b9c84dde89fdd9b10e88954ac2f4248e463 # v4.36.1
with:
sarif_file: results.sarif

View File

@@ -9,9 +9,6 @@ on:
env:
CARGO_TERM_COLOR: always
permissions:
contents: read
# Ensures that we cancel running jobs for the same PR / same workflow.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -22,27 +19,23 @@ jobs:
runs-on: ubuntu-latest
permissions:
contents: read
checks: write
steps:
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- uses: actions/checkout@v4
- name: Install nightly
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
uses: actions-rs/toolchain@v1
with:
toolchain: nightly
profile: minimal
components: rustfmt
- name: Install stable
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
uses: actions-rs/toolchain@v1
with:
toolchain: stable
profile: minimal
components: clippy
- uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1
- uses: Swatinem/rust-cache@v2
- name: Check Formatting
run: cargo +nightly fmt --all -- --check
@@ -54,7 +47,7 @@ jobs:
- name: Check Bench Compilation
run: cargo +nightly bench --no-run --profile=dev --all-features
- uses: actions-rs/clippy-check@b5b5f21f4797c02da247df37026fcd0a5024aa4d # v1.0.7
- uses: actions-rs/clippy-check@v1
with:
toolchain: stable
token: ${{ secrets.GITHUB_TOKEN }}
@@ -64,9 +57,6 @@ jobs:
runs-on: ubuntu-latest
permissions:
contents: read
strategy:
matrix:
features:
@@ -77,17 +67,17 @@ jobs:
name: test-${{ matrix.features.label}}
steps:
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
- uses: actions/checkout@v4
- name: Install stable
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af # v1.0.7
uses: actions-rs/toolchain@v1
with:
toolchain: stable
profile: minimal
override: true
- uses: taiki-e/install-action@56cc9adf3a3e2c23eafb56e8acaf9d0373cb845a # nextest
- uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1
- uses: taiki-e/install-action@nextest
- uses: Swatinem/rust-cache@v2
- name: Run tests
run: |

View File

@@ -1,58 +1,3 @@
Tantivy 0.26.1
================================
## Performance
- Fix quadratic runtime in nested term and composite aggregations: memory accounting scanned all parent buckets on every collect instead of just the current parent (@PSeitz @fulmicoton)
Tantivy 0.26 (Unreleased)
================================
## Bugfixes
- Align float query coercion during search with the columnar coercion rules [#2692](https://github.com/quickwit-oss/tantivy/pull/2692)(@fulmicoton)
- Fix lenient elastic range queries with trailing closing parentheses [#2816](https://github.com/quickwit-oss/tantivy/pull/2816)(@evance-br)
- Fix intersection `seek()` advancing below current doc id [#2812](https://github.com/quickwit-oss/tantivy/pull/2812)(@fulmicoton)
- Fix phrase query prefixed with `*` [#2751](https://github.com/quickwit-oss/tantivy/pull/2751)(@Darkheir)
- Fix `vint` buffer overflow during index creation [#2778](https://github.com/quickwit-oss/tantivy/pull/2778)(@rebasedming)
- Fix integer overflow in `ExpUnrolledLinkedList` for large datasets [#2735](https://github.com/quickwit-oss/tantivy/pull/2735)(@mdashti)
- Fix integer overflow in segment sorting and merge policy truncation [#2846](https://github.com/quickwit-oss/tantivy/pull/2846)(@anaslimem)
- Fix merging of intermediate aggregation results [#2719](https://github.com/quickwit-oss/tantivy/pull/2719)(@PSeitz)
- Fix deduplicate doc counts in term aggregation for multi-valued fields [#2854](https://github.com/quickwit-oss/tantivy/pull/2854)(@nuri-yoo)
## Features/Improvements
- **Aggregation**
- Add filter aggregation [#2711](https://github.com/quickwit-oss/tantivy/pull/2711)(@mdashti)
- Add include/exclude filtering for term aggregations [#2717](https://github.com/quickwit-oss/tantivy/pull/2717)(@PSeitz)
- Add public accessors for intermediate aggregation results [#2829](https://github.com/quickwit-oss/tantivy/pull/2829)(@congx4)
- Replace HyperLogLog++ with Apache DataSketches HLL for cardinality aggregation [#2837](https://github.com/quickwit-oss/tantivy/pull/2837) [#2842](https://github.com/quickwit-oss/tantivy/pull/2842)(@congx4)
- Add composite aggregation [#2856](https://github.com/quickwit-oss/tantivy/pull/2856)(@fulmicoton)
- **Fast Fields**
- Add fast field fallback for `TermQuery` when the field is not indexed [#2693](https://github.com/quickwit-oss/tantivy/pull/2693)(@PSeitz-dd)
- Add fast field support for `Bytes` values [#2830](https://github.com/quickwit-oss/tantivy/pull/2830)(@mdashti)
- **Query Parser**
- Add support for regexes in the query grammar [#2677](https://github.com/quickwit-oss/tantivy/pull/2677) [#2818](https://github.com/quickwit-oss/tantivy/pull/2818)(@Darkheir)
- Deduplicate queries in query parser [#2698](https://github.com/quickwit-oss/tantivy/pull/2698)(@PSeitz-dd)
- Add erased `SortKeyComputer` for sorting on column types unknown until runtime [#2770](https://github.com/quickwit-oss/tantivy/pull/2770) [#2790](https://github.com/quickwit-oss/tantivy/pull/2790)(@stuhood @PSeitz)
- Add natural-order-with-none-highest support in `TopDocs::order_by` [#2780](https://github.com/quickwit-oss/tantivy/pull/2780)(@stuhood)
- Move stemming behing `stemmer` feature flag [#2791](https://github.com/quickwit-oss/tantivy/pull/2791)(@fulmicoton)
- Make `DeleteMeta`, `AddOperation`, `advance_deletes`, `with_max_doc`, `serializer` module, and `delete_queue` public [#2762](https://github.com/quickwit-oss/tantivy/pull/2762) [#2765](https://github.com/quickwit-oss/tantivy/pull/2765) [#2766](https://github.com/quickwit-oss/tantivy/pull/2766) [#2835](https://github.com/quickwit-oss/tantivy/pull/2835)(@philippemnoel @PSeitz)
- Make `Language` hashable [#2763](https://github.com/quickwit-oss/tantivy/pull/2763)(@philippemnoel)
- Improve `space_usage` reporting for JSON fields and columnar data [#2761](https://github.com/quickwit-oss/tantivy/pull/2761)(@PSeitz-dd)
- Split `Term` into `Term` and `IndexingTerm` [#2744](https://github.com/quickwit-oss/tantivy/pull/2744) [#2750](https://github.com/quickwit-oss/tantivy/pull/2750)(@PSeitz-dd @PSeitz)
## Performance
- **Aggregation**
- Large speed up and memory reduction for nested high cardinality aggregations by using one collector per request instead of one per bucket, and adding `PagedTermMap` for faster medium cardinality term aggregations [#2715](https://github.com/quickwit-oss/tantivy/pull/2715) [#2759](https://github.com/quickwit-oss/tantivy/pull/2759)(@PSeitz @PSeitz-dd)
- Optimize low-cardinality term aggregations by using a `Vec` instead of a `HashMap` [#2740](https://github.com/quickwit-oss/tantivy/pull/2740)(@fulmicoton-dd)
- Optimize `ExistsQuery` for a high number of dynamic columns [#2694](https://github.com/quickwit-oss/tantivy/pull/2694)(@PSeitz-dd)
- Add lazy scorers to stop score evaluation early when a doc won't reach the top-K threshold [#2726](https://github.com/quickwit-oss/tantivy/pull/2726) [#2777](https://github.com/quickwit-oss/tantivy/pull/2777)(@fulmicoton @stuhood)
- Add `DocSet::cost()` and use it to order scorers in intersections [#2707](https://github.com/quickwit-oss/tantivy/pull/2707)(@PSeitz)
- Add `collect_block` support for collector wrappers [#2727](https://github.com/quickwit-oss/tantivy/pull/2727)(@stuhood)
- Optimize saturated posting lists by replacing them with `AllScorer` in boolean queries [#2745](https://github.com/quickwit-oss/tantivy/pull/2745) [#2760](https://github.com/quickwit-oss/tantivy/pull/2760) [#2774](https://github.com/quickwit-oss/tantivy/pull/2774)(@fulmicoton @mdashti @trinity-1686a)
- Add `seek_danger` on `DocSet` for more efficient intersections [#2538](https://github.com/quickwit-oss/tantivy/pull/2538) [#2810](https://github.com/quickwit-oss/tantivy/pull/2810)(@PSeitz @stuhood @fulmicoton)
- Skip column traversal in `RangeDocSet` when query range does not overlap with column bounds [#2783](https://github.com/quickwit-oss/tantivy/pull/2783)(@ChangRui-Ryan)
- Speed up exclude queries by supporting multiple excluded `DocSet`s without intermediate union [#2825](https://github.com/quickwit-oss/tantivy/pull/2825)(@PSeitz)
- Improve union performance for non-score unions with `fill_buffer` and optimized `TinySet` [#2863](https://github.com/quickwit-oss/tantivy/pull/2863)(@PSeitz)
Tantivy 0.25
================================

View File

@@ -11,7 +11,7 @@ repository = "https://github.com/quickwit-oss/tantivy"
readme = "README.md"
keywords = ["search", "information", "retrieval"]
edition = "2021"
rust-version = "1.86"
rust-version = "1.85"
exclude = ["benches/*.json", "benches/*.txt"]
[dependencies]
@@ -27,7 +27,7 @@ regex = { version = "1.5.5", default-features = false, features = [
aho-corasick = "1.0"
tantivy-fst = "0.5"
memmap2 = { version = "0.9.0", optional = true }
lz4_flex = { version = "0.13", default-features = false, optional = true }
lz4_flex = { version = "0.12", default-features = false, optional = true }
zstd = { version = "0.13", optional = true, default-features = false }
tempfile = { version = "3.12.0", optional = true }
log = "0.4.16"
@@ -47,7 +47,7 @@ rustc-hash = "2.0.0"
thiserror = "2.0.1"
htmlescape = "0.3.1"
fail = { version = "0.5.0", optional = true }
time = { version = "0.3.47", features = ["serde-well-known"] }
time = { version = "0.3.35", features = ["serde-well-known"] }
smallvec = "1.8.0"
rayon = "1.5.2"
lru = "0.16.3"
@@ -57,15 +57,15 @@ measure_time = "0.9.0"
arc-swap = "1.5.0"
bon = "3.3.1"
columnar = { version = "0.7", path = "./columnar", package = "tantivy-columnar" }
sstable = { version = "0.7", path = "./sstable", package = "tantivy-sstable", optional = true }
stacker = { version = "0.7", path = "./stacker", package = "tantivy-stacker" }
query-grammar = { version = "0.26.0", path = "./query-grammar", package = "tantivy-query-grammar" }
tantivy-bitpacker = { version = "0.10", path = "./bitpacker" }
common = { version = "0.11", path = "./common/", package = "tantivy-common" }
tokenizer-api = { version = "0.7", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
sketches-ddsketch = { version = "0.4", features = ["use_serde"] }
datasketches = { version = "0.3.0", features = ["hll"] }
columnar = { version = "0.6", path = "./columnar", package = "tantivy-columnar" }
sstable = { version = "0.6", path = "./sstable", package = "tantivy-sstable", optional = true }
stacker = { version = "0.6", path = "./stacker", package = "tantivy-stacker" }
query-grammar = { version = "0.25.0", path = "./query-grammar", package = "tantivy-query-grammar" }
tantivy-bitpacker = { version = "0.9", path = "./bitpacker" }
common = { version = "0.10", path = "./common/", package = "tantivy-common" }
tokenizer-api = { version = "0.6", path = "./tokenizer-api", package = "tantivy-tokenizer-api" }
sketches-ddsketch = { version = "0.3.0", features = ["use_serde"] }
hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
futures-util = { version = "0.3.28", optional = true }
futures-channel = { version = "0.3.28", optional = true }
fnv = "1.0.7"
@@ -75,7 +75,7 @@ typetag = "0.2.21"
winapi = "0.3.9"
[dev-dependencies]
binggan = "0.17.0"
binggan = "0.14.2"
rand = "0.9"
maplit = "1.0.2"
matches = "0.1.9"
@@ -86,13 +86,13 @@ futures = "0.3.21"
paste = "1.0.11"
more-asserts = "0.3.1"
rand_distr = "0.5"
time = { version = "0.3.47", features = ["serde-well-known", "macros"] }
time = { version = "0.3.10", features = ["serde-well-known", "macros"] }
postcard = { version = "1.0.4", features = [
"use-std",
], default-features = false }
[target.'cfg(not(windows))'.dev-dependencies]
criterion = { version = "0.8", default-features = false }
criterion = { version = "0.5", default-features = false }
[dev-dependencies.fail]
version = "0.5.0"
@@ -202,10 +202,3 @@ harness = false
name = "regex_all_terms"
harness = false
[[bench]]
name = "query_parser_nested"
harness = false
[[bench]]
name = "intersection_bench"
harness = false

View File

@@ -1,7 +1,6 @@
[![Docs](https://docs.rs/tantivy/badge.svg)](https://docs.rs/crate/tantivy/)
[![Build Status](https://github.com/quickwit-oss/tantivy/actions/workflows/test.yml/badge.svg)](https://github.com/quickwit-oss/tantivy/actions/workflows/test.yml)
[![codecov](https://codecov.io/gh/quickwit-oss/tantivy/branch/main/graph/badge.svg)](https://codecov.io/gh/quickwit-oss/tantivy)
[![OpenSSF Scorecard](https://api.scorecard.dev/projects/github.com/quickwit-oss/tantivy/badge)](https://scorecard.dev/viewer/?uri=github.com/quickwit-oss/tantivy)
[![Join the chat at https://discord.gg/MT27AG5EVE](https://shields.io/discord/908281611840282624?label=chat%20on%20discord)](https://discord.gg/MT27AG5EVE)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Crates.io](https://img.shields.io/crates/v/tantivy.svg)](https://crates.io/crates/tantivy)

View File

@@ -10,7 +10,7 @@ use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::{AllQuery, TermQuery};
use tantivy::schema::{IndexRecordOption, Schema, TextFieldIndexing, FAST, STRING};
use tantivy::{doc, DateTime, Index, Term};
use tantivy::{doc, Index, Term};
#[global_allocator]
pub static GLOBAL: &PeakMemAlloc<std::alloc::System> = &INSTRUMENTED_SYSTEM;
@@ -63,8 +63,6 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, terms_all_unique_with_avg_sub_agg);
register!(group, terms_many_with_avg_sub_agg);
register!(group, terms_status_with_avg_sub_agg);
register!(group, terms_status_with_terms_zipf_1000_sub_agg);
register!(group, terms_zipf_1000_with_terms_status_sub_agg);
register!(group, terms_status_with_histogram);
register!(group, terms_zipf_1000);
register!(group, terms_zipf_1000_with_histogram);
@@ -72,19 +70,8 @@ fn bench_agg(mut group: InputGroup<Index>) {
register!(group, terms_many_json_mixed_type_with_avg_sub_agg);
register!(group, composite_term_many_page_1000);
register!(group, composite_term_many_page_1000_with_avg_sub_agg);
register!(group, composite_term_few);
register!(group, composite_histogram);
register!(group, composite_histogram_calendar);
register!(group, cardinality_agg);
register!(group, cardinality_agg_high_card);
register!(group, cardinality_agg_low_card);
register!(group, terms_status_with_cardinality_agg);
register!(group, terms_100_buckets_with_cardinality_agg);
register!(group, terms_many_with_single_term_order_by_card);
register!(group, terms_many_with_single_term_2_order_by_card);
register!(group, range_agg);
register!(group, range_agg_with_avg_sub_agg);
@@ -172,52 +159,10 @@ fn cardinality_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
// Full-scan cardinality on a near-1M-cardinality string field.
// Hits the dense (PagedBitset) path: every doc has a unique term,
// so the bucket promotes from FxHashSet shortly into the scan.
fn cardinality_agg_high_card(index: &Index) {
let agg_req = json!({
"cardinality": {
"cardinality": {
"field": "text_all_unique_terms"
},
}
});
execute_agg(index, agg_req);
}
// Full-scan cardinality on a tiny-cardinality string field (7 distinct
// values). Stays on the FxHashSet path — the promotion threshold is
// never crossed. Validates no regression on the sparse path.
fn cardinality_agg_low_card(index: &Index) {
let agg_req = json!({
"cardinality": {
"cardinality": {
"field": "text_few_terms_status"
},
}
});
execute_agg(index, agg_req);
}
fn terms_status_with_cardinality_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"cardinality": {
"cardinality": {
"field": "text_few_terms_status"
},
}
}
},
});
execute_agg(index, agg_req);
}
fn terms_100_buckets_with_cardinality_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf", "size": 100 },
"aggs": {
"cardinality": {
"cardinality": {
@@ -230,58 +175,6 @@ fn terms_100_buckets_with_cardinality_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn terms_many_with_single_term_order_by_card(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_many_terms" },
"aggs": {
"nested_terms": {
"terms": {
"field": "single_term",
"order": { "cardinality": "desc" }
},
"aggs": {
"cardinality": {
"cardinality": { "field": "text_few_terms" }
}
}
}
}
},
});
execute_agg(index, agg_req);
}
// Two-level terms ordered by cardinality at each level: a high-card outer terms
// (text_many_terms) ordered by a cardinality sub-agg, with a nested low-card terms
// (text_few_terms_status) also ordered by a cardinality sub-agg, plus an avg.
fn terms_many_with_single_term_2_order_by_card(index: &Index) {
let agg_req = json!({
"by_ip": {
"terms": {
"field": "text_many_terms",
"order": { "card_few_terms": "desc" }
},
"aggs": {
"card_few_terms": {
"cardinality": { "field": "text_few_terms" }
},
"nested_terms": {
"terms": {
"field": " single_term",
"order": { "distinct_path2": "desc" }
},
"aggs": {
"avg_botscore": { "avg": { "field": "score" } },
"distinct_path2": { "cardinality": { "field": "text_few_terms" } }
}
}
}
}
});
execute_agg(index, agg_req);
}
fn terms_7(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_few_terms_status" } },
@@ -354,30 +247,6 @@ fn terms_all_unique_with_avg_sub_agg(index: &Index) {
});
execute_agg(index, agg_req);
}
fn terms_status_with_terms_zipf_1000_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_few_terms_status" },
"aggs": {
"nested_terms": { "terms": { "field": "text_1000_terms_zipf" } }
}
}
});
execute_agg(index, agg_req);
}
fn terms_zipf_1000_with_terms_status_sub_agg(index: &Index) {
let agg_req = json!({
"my_texts": {
"terms": { "field": "text_1000_terms_zipf" },
"aggs": {
"nested_terms": { "terms": { "field": "text_few_terms_status" } }
}
}
});
execute_agg(index, agg_req);
}
fn terms_status_with_histogram(index: &Index) {
let agg_req = json!({
"my_texts": {
@@ -445,75 +314,6 @@ fn terms_many_json_mixed_type_with_avg_sub_agg(index: &Index) {
execute_agg(index, agg_req);
}
fn composite_term_few(index: &Index) {
let agg_req = json!({
"my_ctf": {
"composite": {
"sources": [
{ "text_few_terms": { "terms": { "field": "text_few_terms" } } }
],
"size": 1000
}
},
});
execute_agg(index, agg_req);
}
fn composite_term_many_page_1000(index: &Index) {
let agg_req = json!({
"my_ctmp1000": {
"composite": {
"sources": [
{ "text_many_terms": { "terms": { "field": "text_many_terms" } } }
],
"size": 1000
}
},
});
execute_agg(index, agg_req);
}
fn composite_term_many_page_1000_with_avg_sub_agg(index: &Index) {
let agg_req = json!({
"my_ctmp1000wasa": {
"composite": {
"sources": [
{ "text_many_terms": { "terms": { "field": "text_many_terms" } } }
],
"size": 1000,
},
"aggs": {
"average_f64": { "avg": { "field": "score_f64" } }
}
},
});
execute_agg(index, agg_req);
}
fn composite_histogram(index: &Index) {
let agg_req = json!({
"my_ch": {
"composite": {
"sources": [
{ "f64_histogram": { "histogram": { "field": "score_f64", "interval": 1 } } }
],
"size": 1000
}
},
});
execute_agg(index, agg_req);
}
fn composite_histogram_calendar(index: &Index) {
let agg_req = json!({
"my_chc": {
"composite": {
"sources": [
{ "time_histogram": { "date_histogram": { "field": "timestamp", "calendar_interval": "month" } } }
],
"size": 1000
}
},
});
execute_agg(index, agg_req);
}
fn execute_agg(index: &Index, agg_req: serde_json::Value) {
let agg_req: Aggregations = serde_json::from_value(agg_req).unwrap();
let collector = get_collector(agg_req);
@@ -691,13 +491,11 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
TextFieldIndexing::default().set_index_option(IndexRecordOption::WithFreqs),
)
.set_stored();
let text_field = schema_builder.add_text_field("text", text_fieldtype.clone());
let single_term = schema_builder.add_text_field("single_term", FAST);
let text_field = schema_builder.add_text_field("text", text_fieldtype);
let json_field = schema_builder.add_json_field("json", FAST);
let text_field_all_unique_terms =
schema_builder.add_text_field("text_all_unique_terms", STRING | FAST);
let text_field_many_terms = schema_builder.add_text_field("text_many_terms", STRING | FAST);
let text_field_few_terms = schema_builder.add_text_field("text_few_terms", STRING | FAST);
let text_field_few_terms_status =
schema_builder.add_text_field("text_few_terms_status", STRING | FAST);
let text_field_1000_terms_zipf =
@@ -706,7 +504,6 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
let score_field = schema_builder.add_u64_field("score", score_fieldtype.clone());
let score_field_f64 = schema_builder.add_f64_field("score_f64", score_fieldtype.clone());
let score_field_i64 = schema_builder.add_i64_field("score_i64", score_fieldtype);
let date_field = schema_builder.add_date_field("timestamp", FAST);
// use tmp dir
let index = if reuse_index {
Index::create_in_dir("agg_bench", schema_builder.build())?
@@ -726,7 +523,6 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
let log_level_distribution =
WeightedIndex::new(status_field_data.iter().map(|item| item.1)).unwrap();
let few_terms_data = ["INFO", "ERROR", "WARN", "DEBUG"];
let lg_norm = rand_distr::LogNormal::new(2.996f64, 0.979f64).unwrap();
let many_terms_data = (0..150_000)
@@ -756,16 +552,12 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
index_writer.add_document(doc!(
json_field => json!({"mixed_type": 10.0}),
json_field => json!({"mixed_type": 10.0}),
single_term => "single_term",
single_term => "single_term",
text_field => "cool",
text_field => "cool",
text_field_all_unique_terms => "cool",
text_field_all_unique_terms => "coolo",
text_field_many_terms => "cool",
text_field_many_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms => "cool",
text_field_few_terms_status => log_level_sample_a,
text_field_few_terms_status => log_level_sample_b,
text_field_1000_terms_zipf => term_1000_a.as_str(),
@@ -792,18 +584,15 @@ fn get_test_index_bench(cardinality: Cardinality) -> tantivy::Result<Index> {
json!({"mixed_type": many_terms_data.choose(&mut rng).unwrap().to_string()})
};
index_writer.add_document(doc!(
single_term => "single_term",
text_field => "cool",
json_field => json,
text_field_all_unique_terms => format!("unique_term_{}", rng.random::<u64>()),
text_field_many_terms => many_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms => few_terms_data.choose(&mut rng).unwrap().to_string(),
text_field_few_terms_status => status_field_data[log_level_distribution.sample(&mut rng)].0,
text_field_1000_terms_zipf => terms_1000[zipf_1000.sample(&mut rng) as usize - 1].as_str(),
score_field => val as u64,
score_field_f64 => lg_norm.sample(&mut rng),
score_field_i64 => val as i64,
date_field => DateTime::from_timestamp_millis((val * 1_000_000.) as i64),
))?;
if cardinality == Cardinality::OptionalSparse {
for _ in 0..20 {

View File

@@ -22,7 +22,7 @@ use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::sort_key::SortByStaticFastValue;
use tantivy::collector::{Collector, Count, TopDocs};
use tantivy::query::QueryParser;
use tantivy::query::{Query, QueryParser};
use tantivy::schema::{Schema, FAST, TEXT};
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher};
@@ -38,7 +38,7 @@ struct BenchIndex {
/// return two BenchIndex views:
/// - single_field: QueryParser defaults to only "body"
/// - multi_field: QueryParser defaults to ["title", "body"]
fn build_index(num_docs: usize, terms: &[(&str, f32)]) -> (BenchIndex, BenchIndex) {
fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (BenchIndex, BenchIndex) {
// Unified schema (two text fields)
let mut schema_builder = Schema::builder();
let f_title = schema_builder.add_text_field("title", TEXT);
@@ -55,17 +55,32 @@ fn build_index(num_docs: usize, terms: &[(&str, f32)]) -> (BenchIndex, BenchInde
{
let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap();
for _ in 0..num_docs {
let has_a = rng.random_bool(p_a as f64);
let has_b = rng.random_bool(p_b as f64);
let has_c = rng.random_bool(p_c as f64);
let score = rng.random_range(0u64..100u64);
let score2 = rng.random_range(0u64..100_000u64);
let mut title_tokens: Vec<&str> = Vec::new();
let mut body_tokens: Vec<&str> = Vec::new();
for &(tok, prob) in terms {
if rng.random_bool(prob as f64) {
if rng.random_bool(0.1) {
title_tokens.push(tok);
} else {
body_tokens.push(tok);
}
if has_a {
if rng.random_bool(0.1) {
title_tokens.push("a");
} else {
body_tokens.push("a");
}
}
if has_b {
if rng.random_bool(0.1) {
title_tokens.push("b");
} else {
body_tokens.push("b");
}
}
if has_c {
if rng.random_bool(0.1) {
title_tokens.push("c");
} else {
body_tokens.push("c");
}
}
if title_tokens.is_empty() && body_tokens.is_empty() {
@@ -95,97 +110,59 @@ fn build_index(num_docs: usize, terms: &[(&str, f32)]) -> (BenchIndex, BenchInde
let qp_single = QueryParser::for_index(&index, vec![f_body]);
let qp_multi = QueryParser::for_index(&index, vec![f_title, f_body]);
let only_title = BenchIndex {
let single_view = BenchIndex {
index: index.clone(),
searcher: searcher.clone(),
query_parser: qp_single,
};
let title_and_body = BenchIndex {
let multi_view = BenchIndex {
index,
searcher,
query_parser: qp_multi,
};
(only_title, title_and_body)
}
fn format_pct(p: f32) -> String {
let pct = (p as f64) * 100.0;
let rounded = (pct * 1_000_000.0).round() / 1_000_000.0;
if rounded.fract() <= 0.001 {
format!("{}%", rounded as u64)
} else {
format!("{}%", rounded)
}
}
fn query_label(query_str: &str, term_pcts: &[(&str, String)]) -> String {
let mut label = query_str.to_string();
for (term, pct) in term_pcts {
label = label.replace(term, pct);
}
label.replace(' ', "_")
(single_view, multi_view)
}
fn main() {
// terms with varying selectivity, ordered from rarest to most common.
// With 1M docs, we expect:
// a: 0.01% (100), b: 1% (10k), c: 5% (50k), d: 15% (150k), e: 30% (300k)
let num_docs = 1_000_000;
let terms: &[(&str, f32)] = &[
("a", 0.0001),
("b", 0.01),
("c", 0.05),
("d", 0.15),
("e", 0.30),
// Prepare corpora with varying selectivity. Build one index per corpus
// and derive two views (single-field vs multi-field) from it.
let scenarios = vec![
(
"N=1M, p(a)=5%, p(b)=1%, p(c)=15%".to_string(),
1_000_000,
0.05,
0.01,
0.15,
),
(
"N=1M, p(a)=1%, p(b)=1%, p(c)=15%".to_string(),
1_000_000,
0.01,
0.01,
0.15,
),
];
let queries: &[(&str, &[&str])] = &[
(
"only_union",
&["c OR b", "c OR b OR d", "c OR e", "e OR a"] as &[&str],
),
(
"only_intersection",
&["+c +b", "+c +b +d", "+c +e", "+e +a"] as &[&str],
),
(
"union_intersection",
&["+c +(b OR d)", "+e +(c OR a)", "+(c OR b) +(d OR e)"] as &[&str],
),
];
let queries = &["a", "+a +b", "+a +b +c", "a OR b", "a OR b OR c"];
let mut runner = BenchRunner::new();
let (only_title, title_and_body) = build_index(num_docs, terms);
let term_pcts: Vec<(&str, String)> = terms
.iter()
.map(|&(term, p)| (term, format_pct(p)))
.collect();
for (label, n, pa, pb, pc) in scenarios {
let (single_view, multi_view) = build_shared_indices(n, pa, pb, pc);
for (view_name, bench_index) in [
("single_field", only_title),
("multi_field", title_and_body),
] {
for (category_name, category_queries) in queries {
for query_str in *category_queries {
let mut group = runner.new_group();
let query_label = query_label(query_str, &term_pcts);
group.set_name(format!("{}_{}_{}", view_name, category_name, query_label));
for (view_name, bench_index) in [("single_field", single_view), ("multi_field", multi_view)]
{
// Single-field group: default field is body only
let mut group = runner.new_group();
group.set_name(format!("{}{}", view_name, label));
for query_str in queries {
add_bench_task(&mut group, &bench_index, query_str, Count, "count");
add_bench_task(
&mut group,
&bench_index,
query_str,
TopDocs::with_limit(10).order_by_score(),
"top10_inv_idx",
"top10",
);
add_bench_task(
&mut group,
&bench_index,
query_str,
(Count, TopDocs::with_limit(10).order_by_score()),
"count+top10",
);
add_bench_task(
&mut group,
&bench_index,
@@ -203,47 +180,39 @@ fn main() {
)),
"top10_by_2ff",
);
group.run();
}
group.run();
}
}
}
trait FruitCount {
fn count(&self) -> usize;
}
impl FruitCount for usize {
fn count(&self) -> usize {
*self
}
}
impl<T> FruitCount for Vec<T> {
fn count(&self) -> usize {
self.len()
}
}
impl<A: FruitCount, B> FruitCount for (A, B) {
fn count(&self) -> usize {
self.0.count()
}
}
fn add_bench_task<C: Collector + 'static>(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query_str: &str,
collector: C,
collector_name: &str,
) where
C::Fruit: FruitCount,
{
) {
let task_name = format!("{}_{}", query_str.replace(" ", "_"), collector_name);
let query = bench_index.query_parser.parse_query(query_str).unwrap();
let searcher = bench_index.searcher.clone();
bench_group.register(collector_name.to_string(), move |_| {
black_box(searcher.search(&query, &collector).unwrap().count())
});
let search_task = SearchTask {
searcher: bench_index.searcher.clone(),
collector,
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
struct SearchTask<C: Collector> {
searcher: Searcher,
collector: C,
query: Box<dyn Query>,
}
impl<C: Collector> SearchTask<C> {
#[inline(never)]
pub fn run(&self) -> usize {
self.searcher.search(&self.query, &self.collector).unwrap();
1
}
}

View File

@@ -1,149 +0,0 @@
// Benchmarks top-K intersection of term scorers (block_wand_intersection).
//
// What's measured:
// - Conjunctive queries (+a +b, +a +b +c) with top-10 by score
// - Varying doc-frequency balance between terms (balanced, skewed, very skewed)
// - Realistic term frequencies (geometric distribution, mostly low)
// - 1M-doc single segment
//
// Run with: cargo bench --bench intersection_bench
use binggan::{black_box, BenchRunner};
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::TopDocs;
use tantivy::query::QueryParser;
use tantivy::schema::{Schema, TEXT};
use tantivy::{doc, Index, ReloadPolicy, Searcher};
const NUM_DOCS: usize = 1_000_000;
struct BenchIndex {
searcher: Searcher,
query_parser: QueryParser,
}
/// Generate term frequency from a geometric-like distribution.
/// Most values are 1, a few are 2-3, rarely higher.
/// p controls the decay: higher p → more weight on tf=1.
fn random_term_freq(rng: &mut StdRng, p: f64) -> u32 {
let mut tf = 1u32;
while tf < 10 && rng.random_bool(1.0 - p) {
tf += 1;
}
tf
}
/// Build an index with three terms (a, b, c) with given doc-frequency probabilities.
/// Each term occurrence has a realistic term frequency (geometric distribution).
/// Field length is padded with filler tokens to create varied fieldnorms.
fn build_index(p_a: f64, p_b: f64, p_c: f64) -> BenchIndex {
let mut schema_builder = Schema::builder();
let body = schema_builder.add_text_field("body", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut rng = StdRng::from_seed([42u8; 32]);
{
let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap();
for _ in 0..NUM_DOCS {
let mut tokens: Vec<String> = Vec::new();
if rng.random_bool(p_a) {
let tf = random_term_freq(&mut rng, 0.7);
for _ in 0..tf {
tokens.push("aaa".to_string());
}
}
if rng.random_bool(p_b) {
let tf = random_term_freq(&mut rng, 0.7);
for _ in 0..tf {
tokens.push("bbb".to_string());
}
}
if rng.random_bool(p_c) {
let tf = random_term_freq(&mut rng, 0.7);
for _ in 0..tf {
tokens.push("ccc".to_string());
}
}
// Pad with filler to create varied field lengths (5-30 tokens).
let filler_count = rng.random_range(5u32..30u32);
for _ in 0..filler_count {
tokens.push("filler".to_string());
}
let text = tokens.join(" ");
writer.add_document(doc!(body => text)).unwrap();
}
writer.commit().unwrap();
}
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::Manual)
.try_into()
.unwrap();
let searcher = reader.searcher();
let query_parser = QueryParser::for_index(&index, vec![body]);
BenchIndex {
searcher,
query_parser,
}
}
fn main() {
// Scenarios: (label, p_a, p_b, p_c)
//
// "balanced": all terms ~10% → intersection ~1% of docs
// "skewed": one common (50%), one rare (2%) → intersection ~1%
// "very_skewed": one very common (80%), one very rare (0.5%) → intersection ~0.4%
// "three_balanced": three terms ~20% each → intersection ~0.8%
// "three_skewed": 50% / 10% / 2% → intersection ~0.1%
let scenarios: Vec<(&str, f64, f64, f64)> = vec![
("balanced_10%_10%", 0.10, 0.10, 0.0),
("skewed_50%_2%", 0.50, 0.02, 0.0),
("very_skewed_80%_0.5%", 0.80, 0.005, 0.0),
("three_balanced_20%_20%_20%", 0.20, 0.20, 0.20),
("three_skewed_50%_10%_2%", 0.50, 0.10, 0.02),
];
let mut runner = BenchRunner::new();
for (label, p_a, p_b, p_c) in &scenarios {
let bench_index = build_index(*p_a, *p_b, *p_c);
let mut group = runner.new_group();
group.set_name(format!("intersection — {label}"));
// Two-term intersection
if *p_a > 0.0 && *p_b > 0.0 {
let query_str = "+aaa +bbb";
let query = bench_index.query_parser.parse_query(query_str).unwrap();
let searcher = bench_index.searcher.clone();
group.register(format!("{query_str} top10"), move |_| {
let collector = TopDocs::with_limit(10).order_by_score();
black_box(searcher.search(&query, &collector).unwrap());
1usize
});
}
// Three-term intersection
if *p_c > 0.0 {
let query_str = "+aaa +bbb +ccc";
let query = bench_index.query_parser.parse_query(query_str).unwrap();
let searcher = bench_index.searcher.clone();
group.register(format!("{query_str} top10"), move |_| {
let collector = TopDocs::with_limit(10).order_by_score();
black_box(searcher.search(&query, &collector).unwrap());
1usize
});
}
group.run();
}
}

View File

@@ -1,35 +0,0 @@
// Benchmark for the query grammar parsing deeply nested queries.
//
// Regression guard for https://github.com/quickwit-oss/tantivy/issues/2498:
// at depth 20/21 the old parser took 0.87 s / 1.72 s respectively because
// `ast()` retried `occur_leaf` on backtrack, giving O(2^n) time. With the
// fix parsing is linear and completes in microseconds.
//
// Run with: `cargo bench --bench query_parser_nested`.
use binggan::{black_box, BenchRunner};
use tantivy::query_grammar::parse_query;
fn nested_query(depth: usize, leading_plus: bool) -> String {
let leading = "(".repeat(depth);
let trailing = ")".repeat(depth);
let prefix = if leading_plus { "+" } else { "" };
format!("{prefix}{leading}title:test{trailing}")
}
fn main() {
let mut runner = BenchRunner::new();
for depth in [20, 21] {
for leading_plus in [false, true] {
let query = nested_query(depth, leading_plus);
let label = format!(
"parse_nested_depth_{depth}_{}",
if leading_plus { "plus" } else { "plain" },
);
runner.bench_function(&label, move |_| {
black_box(parse_query(black_box(&query)).unwrap());
});
}
}
}

View File

@@ -17,7 +17,6 @@ use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::{Count, DocSetCollector};
use tantivy::query::RangeQuery;
use tantivy::schema::document::TantivyDocument;
use tantivy::schema::{Schema, Value, FAST, STORED, STRING};
use tantivy::{doc, Index, ReloadPolicy, Searcher, Term};
@@ -406,7 +405,7 @@ impl FetchAllStringsFromDocTask {
for doc_address in docs {
// Get the document from the doc store (row store access)
if let Ok(doc) = self.searcher.doc::<TantivyDocument>(doc_address) {
if let Ok(doc) = self.searcher.doc(doc_address) {
// Extract string values from the stored field
if let Some(field_value) = doc.get_first(str_stored_field) {
if let Some(text) = field_value.as_value().as_str() {

View File

@@ -1,6 +1,6 @@
[package]
name = "tantivy-bitpacker"
version = "0.10.0"
version = "0.9.0"
edition = "2024"
authors = ["Paul Masurel <paul.masurel@gmail.com>"]
license = "MIT"
@@ -18,10 +18,5 @@ homepage = "https://github.com/quickwit-oss/tantivy"
bitpacking = { version = "0.9.2", default-features = false, features = ["bitpacker1x"] }
[dev-dependencies]
binggan = "0.17.0"
rand = "0.9"
proptest = "1"
[[bench]]
name = "bench"
harness = false

View File

@@ -1,110 +1,65 @@
use std::cell::RefCell;
#![feature(test)]
use binggan::{BenchRunner, black_box};
use rand::rng;
use rand::seq::IteratorRandom;
use tantivy_bitpacker::{BitPacker, BitUnpacker, BlockedBitpacker};
extern crate test;
fn create_bitpacked_data(bit_width: u8, num_els: u32) -> Vec<u8> {
let mut bitpacker = BitPacker::new();
let mut buffer = Vec::new();
for _ in 0..num_els {
bitpacker.write(0u64, bit_width, &mut buffer).unwrap();
bitpacker.flush(&mut buffer).unwrap();
}
buffer
}
#[cfg(test)]
mod tests {
use rand::rng;
use rand::seq::IteratorRandom;
use tantivy_bitpacker::{BitPacker, BitUnpacker, BlockedBitpacker};
use test::Bencher;
const N: usize = 100_000;
const MAX_VAL: u64 = 1_000;
const BIT_WIDTH: u8 = 10; // 2^10 = 1024 > MAX_VAL
fn create_packed_data() -> (BitUnpacker, Vec<u8>) {
let mut bitpacker = BitPacker::new();
let mut data = Vec::new();
for i in 0..N as u64 {
let val = i * MAX_VAL / N as u64;
bitpacker.write(val, BIT_WIDTH, &mut data).unwrap();
}
bitpacker.close(&mut data).unwrap();
(BitUnpacker::new(BIT_WIDTH), data)
}
fn bench_bitpacking() {
let mut runner = BenchRunner::new();
let bit_width = 3;
let num_els = 1_000_000u32;
let bit_unpacker = BitUnpacker::new(bit_width);
let data = create_bitpacked_data(bit_width, num_els);
let idxs: Vec<u32> = (0..num_els).choose_multiple(&mut rng(), 100_000);
runner.bench_function("bitpacking_read", move |_| {
let mut out = 0u64;
for &idx in &idxs {
out = out.wrapping_add(bit_unpacker.get(idx, &data[..]));
#[inline(never)]
fn create_bitpacked_data(bit_width: u8, num_els: u32) -> Vec<u8> {
let mut bitpacker = BitPacker::new();
let mut buffer = Vec::new();
for _ in 0..num_els {
// the values do not matter.
bitpacker.write(0u64, bit_width, &mut buffer).unwrap();
bitpacker.flush(&mut buffer).unwrap();
}
black_box(out);
});
}
fn bench_blocked_bitpacker() {
let mut runner = BenchRunner::new();
let mut blocked_bitpacker = BlockedBitpacker::new();
for val in 0..=21500 {
blocked_bitpacker.add(val * val);
buffer
}
runner.bench_function("blockedbitp_read", move |_| {
let mut out = 0u64;
for val in 0..=21500 {
out = out.wrapping_add(blocked_bitpacker.get(val));
}
black_box(out);
});
runner.bench_function("blockedbitp_create", |_| {
#[bench]
fn bench_bitpacking_read(b: &mut Bencher) {
let bit_width = 3;
let num_els = 1_000_000u32;
let bit_unpacker = BitUnpacker::new(bit_width);
let data = create_bitpacked_data(bit_width, num_els);
let idxs: Vec<u32> = (0..num_els).choose_multiple(&mut rng(), 100_000);
b.iter(|| {
let mut out = 0u64;
for &idx in &idxs {
out = out.wrapping_add(bit_unpacker.get(idx, &data[..]));
}
out
});
}
#[bench]
fn bench_blockedbitp_read(b: &mut Bencher) {
let mut blocked_bitpacker = BlockedBitpacker::new();
for val in 0..=21500 {
blocked_bitpacker.add(val * val);
}
black_box(blocked_bitpacker);
});
}
fn bench_filter_vec() {
let mut runner = BenchRunner::new();
let (unpacker, data) = create_packed_data();
let positions = RefCell::new(Vec::with_capacity(N));
runner.bench_function("filter_vec_dense", move |_| {
unpacker.get_ids_for_value_range(
250..=750,
0..N as u32,
&data,
&mut positions.borrow_mut(),
);
black_box(positions.borrow().len());
});
let (unpacker, data) = create_packed_data();
let positions = RefCell::new(Vec::with_capacity(N));
runner.bench_function("filter_vec_sparse", move |_| {
unpacker.get_ids_for_value_range(0..=50, 0..N as u32, &data, &mut positions.borrow_mut());
black_box(positions.borrow().len());
});
let (unpacker, data) = create_packed_data();
let positions = RefCell::new(Vec::with_capacity(N));
runner.bench_function("filter_vec_full", move |_| {
unpacker.get_ids_for_value_range(
0..=MAX_VAL,
0..N as u32,
&data,
&mut positions.borrow_mut(),
);
black_box(positions.borrow().len());
});
}
fn main() {
bench_bitpacking();
bench_blocked_bitpacker();
bench_filter_vec();
b.iter(|| {
let mut out = 0u64;
for val in 0..=21500 {
out = out.wrapping_add(blocked_bitpacker.get(val));
}
out
});
}
#[bench]
fn bench_blockedbitp_create(b: &mut Bencher) {
b.iter(|| {
let mut blocked_bitpacker = BlockedBitpacker::new();
for val in 0..=21500 {
blocked_bitpacker.add(val * val);
}
blocked_bitpacker
});
}
}

View File

@@ -1,17 +1,8 @@
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
use std::arch::is_aarch64_feature_detected;
use std::ops::RangeInclusive;
#[cfg(target_arch = "x86_64")]
mod avx2;
#[cfg(target_arch = "aarch64")]
mod neon;
// SVE intrinsics are not exposed on aarch64-apple-darwin.
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
mod sve;
mod scalar;
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
@@ -19,10 +10,6 @@ mod scalar;
enum FilterImplPerInstructionSet {
#[cfg(target_arch = "x86_64")]
AVX2 = 0u8,
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
SVE = 3u8,
#[cfg(target_arch = "aarch64")]
Neon = 2u8,
Scalar = 1u8,
}
@@ -32,57 +19,29 @@ impl FilterImplPerInstructionSet {
match *self {
#[cfg(target_arch = "x86_64")]
FilterImplPerInstructionSet::AVX2 => is_x86_feature_detected!("avx2"),
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
FilterImplPerInstructionSet::SVE => is_aarch64_feature_detected!("sve"),
// TIL Neon is required on aarch 64.
#[cfg(target_arch = "aarch64")]
FilterImplPerInstructionSet::Neon => true,
FilterImplPerInstructionSet::Scalar => true,
}
}
}
// List of available implementations in preferred order.
// List of available implementation in preferred order.
#[cfg(target_arch = "x86_64")]
const IMPLS: [FilterImplPerInstructionSet; 2] = [
FilterImplPerInstructionSet::AVX2,
FilterImplPerInstructionSet::Scalar,
];
// Non-Apple aarch64: try SVE, NEON, Scalar.
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
const IMPLS: [FilterImplPerInstructionSet; 3] = [
FilterImplPerInstructionSet::SVE,
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
];
// Apple aarch64 (M-series): SVE not available; use NEON or Scalar.
#[cfg(all(target_arch = "aarch64", target_vendor = "apple"))]
const IMPLS: [FilterImplPerInstructionSet; 2] = [
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
];
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
#[cfg(not(target_arch = "x86_64"))]
const IMPLS: [FilterImplPerInstructionSet; 1] = [FilterImplPerInstructionSet::Scalar];
impl FilterImplPerInstructionSet {
#[inline]
#[allow(unused_variables)]
#[allow(unused_variables)] // on non-x86_64, code is unused.
fn from(code: u8) -> FilterImplPerInstructionSet {
#[cfg(target_arch = "x86_64")]
if code == FilterImplPerInstructionSet::AVX2 as u8 {
return FilterImplPerInstructionSet::AVX2;
}
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
if code == FilterImplPerInstructionSet::SVE as u8 {
return FilterImplPerInstructionSet::SVE;
}
#[cfg(target_arch = "aarch64")]
if code == FilterImplPerInstructionSet::Neon as u8 {
return FilterImplPerInstructionSet::Neon;
}
FilterImplPerInstructionSet::Scalar
}
@@ -91,13 +50,6 @@ impl FilterImplPerInstructionSet {
match self {
#[cfg(target_arch = "x86_64")]
FilterImplPerInstructionSet::AVX2 => avx2::filter_vec_in_place(range, offset, output),
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
// SAFETY: SVE availability was verified by is_available() before selecting this impl.
FilterImplPerInstructionSet::SVE => unsafe {
sve::filter_vec_in_place(range, offset, output)
},
#[cfg(target_arch = "aarch64")]
FilterImplPerInstructionSet::Neon => neon::filter_vec_in_place(range, offset, output),
FilterImplPerInstructionSet::Scalar => {
scalar::filter_vec_in_place(range, offset, output)
}
@@ -105,12 +57,6 @@ impl FilterImplPerInstructionSet {
}
}
fn available_impls() -> impl Iterator<Item = FilterImplPerInstructionSet> {
IMPLS
.into_iter()
.filter(FilterImplPerInstructionSet::is_available)
}
#[inline]
fn get_best_available_instruction_set() -> FilterImplPerInstructionSet {
use std::sync::atomic::{AtomicU8, Ordering};
@@ -118,7 +64,10 @@ fn get_best_available_instruction_set() -> FilterImplPerInstructionSet {
let instruction_set_byte: u8 = INSTRUCTION_SET_BYTE.load(Ordering::Relaxed);
if instruction_set_byte == u8::MAX {
// Let's initialize the instruction set and cache it.
let instruction_set = available_impls().next().unwrap();
let instruction_set = IMPLS
.into_iter()
.find(FilterImplPerInstructionSet::is_available)
.unwrap();
INSTRUCTION_SET_BYTE.store(instruction_set as u8, Ordering::Relaxed);
return instruction_set;
}
@@ -131,12 +80,12 @@ pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut
#[cfg(test)]
mod tests {
use proptest::strategy::Strategy;
use super::*;
#[test]
fn test_get_best_available_instruction_set() {
// This does not test much unfortunately.
// We just make sure the function returns without crashing and returns the same result.
let instruction_set = get_best_available_instruction_set();
assert_eq!(get_best_available_instruction_set(), instruction_set);
}
@@ -153,31 +102,6 @@ mod tests {
}
}
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
#[test]
fn test_instruction_set_to_code_from_code() {
for instruction_set in [
FilterImplPerInstructionSet::SVE,
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
] {
let code = instruction_set as u8;
assert_eq!(instruction_set, FilterImplPerInstructionSet::from(code));
}
}
#[cfg(all(target_arch = "aarch64", target_vendor = "apple"))]
#[test]
fn test_instruction_set_to_code_from_code() {
for instruction_set in [
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
] {
let code = instruction_set as u8;
assert_eq!(instruction_set, FilterImplPerInstructionSet::from(code));
}
}
fn test_filter_impl_empty_aux(filter_impl: FilterImplPerInstructionSet) {
let mut output = vec![];
filter_impl.filter_vec_in_place(0..=u32::MAX, 0, &mut output);
@@ -202,20 +126,11 @@ mod tests {
assert_eq!(&output, &[1, 3, 4, 5, 6, 7, 8]);
}
fn test_filter_impl_empty_range_aux(filter_impl: FilterImplPerInstructionSet) {
// start > end: RangeInclusive::contains always returns false; output must be empty.
// The SVE path's wrapping_sub would otherwise produce a huge range_width.
let mut output = vec![3, 2, 1, 5, 11, 2, 5, 10, 2];
filter_impl.filter_vec_in_place(10..=5, 0, &mut output);
assert_eq!(&output, &[]);
}
fn test_filter_impl_test_suite(filter_impl: FilterImplPerInstructionSet) {
test_filter_impl_empty_aux(filter_impl);
test_filter_impl_simple_aux(filter_impl);
test_filter_impl_simple_aux_shifted(filter_impl);
test_filter_impl_simple_outside_i32_range(filter_impl);
test_filter_impl_empty_range_aux(filter_impl);
}
#[test]
@@ -226,60 +141,25 @@ mod tests {
}
}
#[test]
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
fn test_filter_implementation_sve() {
if FilterImplPerInstructionSet::SVE.is_available() {
test_filter_impl_test_suite(FilterImplPerInstructionSet::SVE);
}
}
#[test]
#[cfg(target_arch = "aarch64")]
fn test_filter_implementation_neon() {
test_filter_impl_test_suite(FilterImplPerInstructionSet::Neon);
}
#[test]
fn test_filter_implementation_scalar() {
test_filter_impl_test_suite(FilterImplPerInstructionSet::Scalar);
}
fn max_val_strategy() -> impl proptest::strategy::Strategy<Value = u32> {
proptest::prop_oneof![
0u32..10u32,
255u32..258u32,
proptest::prelude::Just(1u32 << 25),
proptest::prelude::Just(u32::MAX - 1),
proptest::prelude::Just(u32::MAX),
]
}
fn vals_strategy() -> impl proptest::strategy::Strategy<Value = Vec<u32>> {
proptest::prop_oneof![
proptest::collection::vec(proptest::prelude::any::<u32>(), 0..300),
max_val_strategy()
.prop_flat_map(|max_val| { proptest::collection::vec(0..=max_val, 0..300) })
]
}
#[cfg(target_arch = "x86_64")]
proptest::proptest! {
#[test]
fn test_filter_compare_scalar_and_impls_impl_proptest(
start in 0u32..400u32,
end in 0u32..400u32,
fn test_filter_compare_scalar_and_avx2_impl_proptest(
start in proptest::prelude::any::<u32>(),
end in proptest::prelude::any::<u32>(),
offset in 0u32..2u32,
vals in vals_strategy()) {
for implementation in available_impls() {
if implementation == FilterImplPerInstructionSet::Scalar {
continue;
}
let mut impl_output = vals.clone();
let mut scalar_output = vals.clone();
implementation.filter_vec_in_place(start..=end, offset, &mut impl_output);
FilterImplPerInstructionSet::Scalar.filter_vec_in_place(start..=end, offset, &mut scalar_output);
assert_eq!(&impl_output, &scalar_output);
}
mut vals in proptest::collection::vec(0..u32::MAX, 0..30)) {
if FilterImplPerInstructionSet::AVX2.is_available() {
let mut vals_clone = vals.clone();
FilterImplPerInstructionSet::AVX2.filter_vec_in_place(start..=end, offset, &mut vals);
FilterImplPerInstructionSet::Scalar.filter_vec_in_place(start..=end, offset, &mut vals_clone);
assert_eq!(&vals, &vals_clone);
}
}
}
}

View File

@@ -1,118 +0,0 @@
use std::arch::aarch64::*;
use std::ops::RangeInclusive;
const NUM_LANES: usize = 4;
// Compacts matching lanes to the front using a byte-level shuffle.
// `mask` is a 4-bit value: bit k=1 means lane k should appear in the output.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn compact(data: uint32x4_t, mask: u8) -> uint32x4_t {
unsafe {
// SAFETY: mask is always in [0, 15] by construction (max sum of [1,2,4,8]).
// BYTE_SHUFFLE_TABLE has 16 entries, so this is always in bounds.
let shuffle = BYTE_SHUFFLE_TABLE.get_unchecked(mask as usize);
let shuffle_vec = vld1q_u8(shuffle.as_ptr());
vreinterpretq_u32_u8(vqtbl1q_u8(vreinterpretq_u8_u32(data), shuffle_vec))
}
}
// Safe (not unsafe) because NEON is mandatory on aarch64: no runtime feature check needed.
#[inline(never)]
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
let num_words = output.len() / NUM_LANES;
let mut output_len = unsafe {
filter_vec_neon_aux(
output.as_ptr(),
range.clone(),
output.as_mut_ptr(),
offset,
num_words,
)
};
let remainder_start = num_words * NUM_LANES;
for i in remainder_start..output.len() {
let val = output[i];
output[output_len] = offset + i as u32;
output_len += if range.contains(&val) { 1 } else { 0 };
}
output.truncate(output_len);
}
#[target_feature(enable = "neon")]
unsafe fn filter_vec_neon_aux(
input: *const u32,
range: RangeInclusive<u32>,
output: *mut u32,
offset: u32,
num_words: usize,
) -> usize {
unsafe {
let mut input = input;
let mut output_tail = output;
let range_start_simd = vdupq_n_u32(*range.start());
let range_end_simd = vdupq_n_u32(*range.end());
let mut ids = vld1q_u32([offset, offset + 1, offset + 2, offset + 3].as_ptr());
let shift = vdupq_n_u32(NUM_LANES as u32);
let bit_weights = vld1q_u32([1u32, 2, 4, 8].as_ptr());
for _ in 0..num_words {
let word = vld1q_u32(input);
// Unsigned compares: CMHS (compare higher or same) tests `word >= start`
// and `end >= word`. ANDing both gives the inside-range mask directly,
// which is cheaper than computing `outside` and then negating.
let ge_start = vcgeq_u32(word, range_start_simd);
let le_end = vcleq_u32(word, range_end_simd);
// inside[k] = 0xFFFFFFFF if val[k] is in range, 0 otherwise.
let inside = vandq_u32(ge_start, le_end);
// Build the 4-bit mask: AND bit_weights with the inside lane mask, so each
// inside lane contributes its bit_weight (1, 2, 4, or 8). Summing yields the
// 4-bit mask in one addv.
let inside_bits = vandq_u32(bit_weights, inside);
let mask = vaddvq_u32(inside_bits) as u8;
// mask is mathematically bounded: max value is 1+2+4+8=15 (all lanes match)
debug_assert!(mask <= 15, "mask must fit in 4 bits: {}", mask);
// Count of matching lanes = popcount(mask). Derives the count directly from
// the mask instead of running a parallel SIMD reduction over `outside`.
let added_len = mask.count_ones() as usize;
// Safe because mask is guaranteed to be in [0, 15]
let filtered_ids = compact(ids, mask);
vst1q_u32(output_tail, filtered_ids);
output_tail = output_tail.add(added_len);
ids = vaddq_u32(ids, shift);
input = input.add(NUM_LANES);
}
output_tail.offset_from(output) as usize
}
}
// Byte shuffle patterns to compact matching lanes to the front of the vector.
// Index is a 4-bit mask: bit k=1 means lane k (bytes 4k..4k+3) is in-range.
// The j-th set bit determines which input lane goes to output position j.
const BYTE_SHUFFLE_TABLE: [[u8; 16]; 16] = [
[
16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
], // 0b0000: none
[0, 1, 2, 3, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], // 0b0001: lane 0
[4, 5, 6, 7, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], // 0b0010: lane 1
[0, 1, 2, 3, 4, 5, 6, 7, 16, 16, 16, 16, 16, 16, 16, 16], // 0b0011: lanes 0,1
[8, 9, 10, 11, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], // 0b0100: lane 2
[0, 1, 2, 3, 8, 9, 10, 11, 16, 16, 16, 16, 16, 16, 16, 16], // 0b0101: lanes 0,2
[4, 5, 6, 7, 8, 9, 10, 11, 16, 16, 16, 16, 16, 16, 16, 16], // 0b0110: lanes 1,2
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 16, 16, 16], // 0b0111: lanes 0,1,2
[
12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
], // 0b1000: lane 3
[0, 1, 2, 3, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16], // 0b1001: lanes 0,3
[4, 5, 6, 7, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16], // 0b1010: lanes 1,3
[0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 16, 16, 16, 16], // 0b1011: lanes 0,1,3
[8, 9, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16], // 0b1100: lanes 2,3
[0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16], // 0b1101: lanes 0,2,3
[4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16], // 0b1110: lanes 1,2,3
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], // 0b1111: all lanes
];

View File

@@ -1,260 +0,0 @@
use std::ops::RangeInclusive;
// SVE vector length (in u32 lanes) is not a compile-time constant; query at runtime.
// Safe to call only when SVE is confirmed available via is_aarch64_feature_detected!("sve").
#[target_feature(enable = "sve")]
unsafe fn num_lanes() -> usize {
let vl: usize;
unsafe {
core::arch::asm!(
"cntw {vl}",
vl = out(reg) vl,
options(nostack, nomem, preserves_flags),
);
}
vl
}
// SAFETY: caller must ensure SVE is available (checked via is_aarch64_feature_detected!("sve")).
// Unlike NEON, SVE is optional on aarch64 and not guaranteed by the target architecture.
pub unsafe fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
if range.start() > range.end() {
output.clear();
return;
}
let vl = unsafe { num_lanes() };
let num_words = output.len() / vl;
let range_start = *range.start();
// Unsigned subtraction trick: val ∈ [lo, hi] ↔ (val - lo) ≤ᵤ (hi - lo).
// Values below lo wrap around to large u32, so the single unsigned ≤ excludes them.
let range_width = range.end().wrapping_sub(range_start);
let mut output_len = unsafe {
filter_vec_sve_aux(
output.as_ptr(),
range_start,
range_width,
output.as_mut_ptr(),
offset,
num_words,
vl,
)
};
let remainder_start = num_words * vl;
for i in remainder_start..output.len() {
let val = output[i];
output[output_len] = offset + i as u32;
output_len += if range.contains(&val) { 1 } else { 0 };
}
output.truncate(output_len);
}
// Register allocation for the asm! blocks:
// z0 ids_a (index vector for first half of each pair, advances by step2 each iter)
// z1 range_width broadcast
// z2 range_start broadcast
// z3 step2 broadcast (2 * vl)
// z4 ids_b (index vector for second half, = ids_a + step, advances by step2)
// z5 scratch: loaded word_a, then compacted_a
// z6 scratch: loaded word_b, then compacted_b
// p0 all-true predicate (ptrue p0.s)
// p1 in-range mask for word_a
// p2 in-range mask for word_b
#[target_feature(enable = "sve")]
unsafe fn filter_vec_sve_aux(
input: *const u32,
range_start: u32,
range_width: u32,
output: *mut u32,
offset: u32,
num_words: usize,
vl: usize,
) -> usize {
let num_pairs = num_words / 2;
let mut input_ptr = input;
let mut output_tail = output;
if num_pairs > 0 {
unsafe {
// We rely on asm! because the SVE intrinsics are not available in stable Rust.
// The code that follows was generated by Rustc nightly based on the intrinsics version
// at the bottom of this file.
core::arch::asm!(
// --- Setup ---
// All-true predicate for 32-bit lanes.
"ptrue p0.s",
// ids_a = [offset, offset+1, offset+2, ...]
"index z0.s, {offset:w}, #1",
// Broadcast scalars into SVE vectors.
"mov z1.s, {range_width:w}",
"mov z2.s, {range_start:w}",
// vl_gpr = number of 32-bit lanes (cntw).
"cntw {vl_gpr}",
// step2_bytes will first hold 2*vl (for the step2 vector), then 2*VL in bytes.
"lsl {step2_bytes}, {vl_gpr}, #1",
// z4 = step = [vl, vl, ...]; will become ids_b after the add below.
"mov z4.s, {vl_gpr:w}",
// z3 = step2 = [2*vl, 2*vl, ...], used to advance both id vectors each iter.
"mov z3.s, {step2_bytes:w}",
// Repurpose step2_bytes to hold the byte stride for advancing the input pointer
// by two full SVE vectors per iteration.
"rdvl {step2_bytes}, #2",
// ids_b = ids_a + step = [offset+vl, offset+vl+1, ...]
"add z4.s, z0.s, z4.s",
// --- Main loop: process two SVE vectors (ids_a and ids_b) per iteration ---
"0:",
// Load two consecutive SVE vectors from input.
"ld1w {{z5.s}}, p0/z, [{input}]",
"ld1w {{z6.s}}, p0/z, [{input}, #1, mul vl]",
// Advance input pointer by 2 * VL bytes.
"add {input}, {input}, {step2_bytes}",
// Unsigned shift: subtract range_start so in-range check becomes a single cmpu ≤.
"sub z5.s, z5.s, z2.s",
"sub z6.s, z6.s, z2.s",
// in_range: shifted value ≤ range_width (unsigned, so values below lo also fail).
"cmphs p1.s, p0/z, z1.s, z5.s",
"cmphs p2.s, p0/z, z1.s, z6.s",
// Count matching lanes; both cntp calls have independent inputs for OOO parallelism.
"cntp {cnt_a}, p0, p1.s",
"compact z5.s, p1, z0.s",
"compact z6.s, p2, z4.s",
"cntp {cnt_b}, p0, p2.s",
// Advance id vectors for the next iteration.
"add z0.s, z0.s, z3.s",
"add z4.s, z4.s, z3.s",
// Store compacted ids. Only the first cnt_a / cnt_b slots are valid; the rest
// will be overwritten by subsequent iterations before the final truncate.
"str z5, [{out}]",
"st1w {{z6.s}}, p0, [{out}, {cnt_a}, lsl #2]",
"add {out}, {out}, {cnt_a}, lsl #2",
"add {out}, {out}, {cnt_b}, lsl #2",
"subs {pairs}, {pairs}, #1",
"b.ne 0b",
// --- Operands ---
input = inout(reg) input_ptr,
out = inout(reg) output_tail,
pairs = inout(reg) num_pairs => _,
offset = in(reg) offset,
range_start = in(reg) range_start,
range_width = in(reg) range_width,
vl_gpr = out(reg) _,
step2_bytes = out(reg) _,
cnt_a = out(reg) _,
cnt_b = out(reg) _,
out("p0") _, out("p1") _, out("p2") _,
out("v0") _, out("v1") _, out("v2") _, out("v3") _,
out("v4") _, out("v5") _, out("v6") _,
options(nostack),
);
}
}
// Handle an odd trailing vector.
if num_words % 2 == 1 {
// ids_a for the odd word starts at offset + num_pairs * 2 * vl.
// input_ptr was advanced by the main loop and now points at the odd word.
let odd_offset =
offset.wrapping_add((num_pairs as u32).wrapping_mul(2).wrapping_mul(vl as u32));
unsafe {
core::arch::asm!(
"ptrue p0.s",
"index z0.s, {odd_offset:w}, #1",
"mov z1.s, {range_width:w}",
"mov z2.s, {range_start:w}",
"ld1w {{z3.s}}, p0/z, [{input}]",
"sub z3.s, z3.s, z2.s",
"cmphs p1.s, p0/z, z1.s, z3.s",
"cntp {cnt}, p0, p1.s",
"compact z0.s, p1, z0.s",
"str z0, [{out}]",
"add {out}, {out}, {cnt}, lsl #2",
odd_offset = in(reg) odd_offset,
range_width = in(reg) range_width,
range_start = in(reg) range_start,
input = in(reg) input_ptr,
out = inout(reg) output_tail,
cnt = out(reg) _,
out("p0") _, out("p1") _,
out("v0") _, out("v1") _, out("v2") _, out("v3") _,
options(nostack),
);
}
}
unsafe { output_tail.offset_from(output) as usize }
}
// SVE implements with intrinsics.
//
// #[target_feature(enable = "sve")]
// unsafe fn filter_vec_sve_aux(
// input: *const u32,
// range_start: u32,
// range_width: u32,
// output: *mut u32,
// offset: u32,
// num_words: usize,
// vl: usize,
// ) -> usize {
// unsafe {
// let all_true = svptrue_b32();
// let range_start_simd = svdup_n_u32(range_start);
// let range_width_simd = svdup_n_u32(range_width);
// // ids_a covers [offset .. offset+vl), ids_b covers the next vl ids.
// // Keeping them separate breaks the loop-carried dependency through ids so
// // both compact/cntp chains are fully independent within each unrolled body.
// let mut ids_a = svindex_u32(offset, 1);
// let step = svdup_n_u32(vl as u32);
// let step2 = svdup_n_u32(2 * vl as u32);
// let mut ids_b = svadd_u32_x(all_true, ids_a, step);
// let mut input = input;
// let mut output_tail = output;
// // Unrolled ×2: both cntp calls have independent inputs and execute in parallel.
// // The two output_tail updates are sequential but together cost 4+1+1=6 cy per
// // pair vs 5+5=10 cy for two scalar iterations, breaking the cntp latency chain.
// let num_pairs = num_words / 2;
// for _ in 0..num_pairs {
// let word_a = svld1_u32(all_true, input);
// let word_b = svld1_u32(all_true, input.add(vl));
// let shifted_a = svsub_u32_x(all_true, word_a, range_start_simd);
// let shifted_b = svsub_u32_x(all_true, word_b, range_start_simd);
// let in_range_a = svcmple_u32(all_true, shifted_a, range_width_simd);
// let in_range_b = svcmple_u32(all_true, shifted_b, range_width_simd);
// let compacted_a = svcompact_u32(in_range_a, ids_a);
// let compacted_b = svcompact_u32(in_range_b, ids_b);
// // cntp_a and cntp_b have independent inputs: OOO engine issues them in parallel.
// let added_len_a = svcntp_b32(all_true, in_range_a) as usize;
// let added_len_b = svcntp_b32(all_true, in_range_b) as usize;
// // Write the full vector — only the first added_len slots are valid.
// // Subsequent iterations overwrite the trailing zeros before truncate.
// svst1_u32(all_true, output_tail, compacted_a);
// output_tail = output_tail.add(added_len_a);
// svst1_u32(all_true, output_tail, compacted_b);
// output_tail = output_tail.add(added_len_b);
// ids_a = svadd_u32_x(all_true, ids_a, step2);
// ids_b = svadd_u32_x(all_true, ids_b, step2);
// input = input.add(2 * vl);
// }
// // Handle an odd trailing word.
// if num_words % 2 == 1 {
// let word = svld1_u32(all_true, input);
// let shifted = svsub_u32_x(all_true, word, range_start_simd);
// let in_range = svcmple_u32(all_true, shifted, range_width_simd);
// let added_len = svcntp_b32(all_true, in_range) as usize;
// let compacted_ids = svcompact_u32(in_range, ids_a);
// svst1_u32(all_true, output_tail, compacted_ids);
// output_tail = output_tail.add(added_len);
// }
// output_tail.offset_from(output) as usize
// }
// }

View File

@@ -1,6 +1,6 @@
[package]
name = "tantivy-columnar"
version = "0.7.0"
version = "0.6.0"
edition = "2024"
license = "MIT"
homepage = "https://github.com/quickwit-oss/tantivy"
@@ -12,10 +12,10 @@ categories = ["database-implementations", "data-structures", "compression"]
itertools = "0.14.0"
fastdivide = "0.4.0"
stacker = { version= "0.7", path = "../stacker", package="tantivy-stacker"}
sstable = { version= "0.7", path = "../sstable", package = "tantivy-sstable" }
common = { version= "0.11", path = "../common", package = "tantivy-common" }
tantivy-bitpacker = { version= "0.10", path = "../bitpacker/" }
stacker = { version= "0.6", path = "../stacker", package="tantivy-stacker"}
sstable = { version= "0.6", path = "../sstable", package = "tantivy-sstable" }
common = { version= "0.10", path = "../common", package = "tantivy-common" }
tantivy-bitpacker = { version= "0.9", path = "../bitpacker/" }
serde = "1.0.152"
downcast-rs = "2.0.1"
@@ -23,7 +23,7 @@ downcast-rs = "2.0.1"
proptest = "1"
more-asserts = "0.3.1"
rand = "0.9"
binggan = "0.17.0"
binggan = "0.14.0"
[[bench]]
name = "bench_merge"

View File

@@ -33,14 +33,14 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
&mut self,
docs: &[u32],
accessor: &Column<T>,
missing_opt: Option<T>,
missing: Option<T>,
) {
self.fetch_block(docs, accessor);
// no missing values
if accessor.index.get_cardinality().is_full() {
return;
}
let Some(missing) = missing_opt else {
let Some(missing) = missing else {
return;
};
@@ -58,78 +58,6 @@ impl<T: PartialOrd + Copy + std::fmt::Debug + Send + Sync + 'static + Default>
}
}
/// Like `fetch_block_with_missing`, but deduplicates (doc_id, value) pairs
/// so that each unique value per document is returned only once.
///
/// This is necessary for correct document counting in aggregations,
/// where multi-valued fields can produce duplicate entries that inflate counts.
#[inline]
pub fn fetch_block_with_missing_unique_per_doc(
&mut self,
docs: &[u32],
accessor: &Column<T>,
missing: Option<T>,
) where
T: Ord,
{
self.fetch_block_with_missing(docs, accessor, missing);
if accessor.index.get_cardinality().is_multivalue() {
self.dedup_docid_val_pairs();
}
}
/// Removes duplicate (doc_id, value) pairs from the caches.
///
/// After `fetch_block`, entries are sorted by doc_id, but values within
/// the same doc may not be sorted (e.g. `(0,1), (0,2), (0,1)`).
/// We group consecutive entries by doc_id, sort values within each group
/// if it has more than 2 elements, then deduplicate adjacent pairs.
///
/// Skips entirely if no doc_id appears more than once in the block.
fn dedup_docid_val_pairs(&mut self)
where T: Ord {
if self.docid_cache.len() <= 1 {
return;
}
// Quick check: if no consecutive doc_ids are equal, no dedup needed.
let has_multivalue = self.docid_cache.windows(2).any(|w| w[0] == w[1]);
if !has_multivalue {
return;
}
// Sort values within each doc_id group so duplicates become adjacent.
let mut start = 0;
while start < self.docid_cache.len() {
let doc = self.docid_cache[start];
let mut end = start + 1;
while end < self.docid_cache.len() && self.docid_cache[end] == doc {
end += 1;
}
if end - start > 2 {
self.val_cache[start..end].sort();
}
start = end;
}
// Now duplicates are adjacent — deduplicate in place.
let mut write = 0;
for read in 1..self.docid_cache.len() {
if self.docid_cache[read] != self.docid_cache[write]
|| self.val_cache[read] != self.val_cache[write]
{
write += 1;
if write != read {
self.docid_cache[write] = self.docid_cache[read];
self.val_cache[write] = self.val_cache[read];
}
}
}
let new_len = write + 1;
self.docid_cache.truncate(new_len);
self.val_cache.truncate(new_len);
}
#[inline]
pub fn iter_vals(&self) -> impl Iterator<Item = T> + '_ {
self.val_cache.iter().cloned()
@@ -191,7 +119,6 @@ where F: FnMut(u32) {
}
#[cfg(test)]
#[allow(clippy::field_reassign_with_default)]
mod tests {
use super::*;
@@ -236,56 +163,4 @@ mod tests {
assert_eq!(missing_docs, vec![1, 2, 3, 4, 5]);
}
#[test]
fn test_dedup_docid_val_pairs_consecutive() {
let mut accessor = ColumnBlockAccessor::<u64>::default();
accessor.docid_cache = vec![0, 0, 2, 3];
accessor.val_cache = vec![10, 10, 10, 10];
accessor.dedup_docid_val_pairs();
assert_eq!(accessor.docid_cache, vec![0, 2, 3]);
assert_eq!(accessor.val_cache, vec![10, 10, 10]);
}
#[test]
fn test_dedup_docid_val_pairs_non_consecutive() {
// (0,1), (0,2), (0,1) — duplicate value not adjacent
let mut accessor = ColumnBlockAccessor::<u64>::default();
accessor.docid_cache = vec![0, 0, 0];
accessor.val_cache = vec![1, 2, 1];
accessor.dedup_docid_val_pairs();
assert_eq!(accessor.docid_cache, vec![0, 0]);
assert_eq!(accessor.val_cache, vec![1, 2]);
}
#[test]
fn test_dedup_docid_val_pairs_multi_doc() {
// doc 0: values [3, 1, 3], doc 1: values [5, 5]
let mut accessor = ColumnBlockAccessor::<u64>::default();
accessor.docid_cache = vec![0, 0, 0, 1, 1];
accessor.val_cache = vec![3, 1, 3, 5, 5];
accessor.dedup_docid_val_pairs();
assert_eq!(accessor.docid_cache, vec![0, 0, 1]);
assert_eq!(accessor.val_cache, vec![1, 3, 5]);
}
#[test]
fn test_dedup_docid_val_pairs_no_duplicates() {
let mut accessor = ColumnBlockAccessor::<u64>::default();
accessor.docid_cache = vec![0, 0, 1];
accessor.val_cache = vec![1, 2, 3];
accessor.dedup_docid_val_pairs();
assert_eq!(accessor.docid_cache, vec![0, 0, 1]);
assert_eq!(accessor.val_cache, vec![1, 2, 3]);
}
#[test]
fn test_dedup_docid_val_pairs_single_element() {
let mut accessor = ColumnBlockAccessor::<u64>::default();
accessor.docid_cache = vec![0];
accessor.val_cache = vec![1];
accessor.dedup_docid_val_pairs();
assert_eq!(accessor.docid_cache, vec![0]);
assert_eq!(accessor.val_cache, vec![1]);
}
}

View File

@@ -31,7 +31,7 @@ pub use u64_based::{
serialize_and_load_u64_based_column_values, serialize_u64_based_column_values,
};
pub use u128_based::{
CompactHit, CompactSpaceU64Accessor, open_u128_as_compact_u64, open_u128_mapped,
CompactSpaceU64Accessor, open_u128_as_compact_u64, open_u128_mapped,
serialize_column_values_u128,
};
pub use vec_column::VecColumn;

View File

@@ -292,19 +292,6 @@ impl BinarySerializable for IPCodecParams {
}
}
/// Represents the result of looking up a u128 value in the compact space.
///
/// If a value is outside the compact space, the next compact value is returned.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompactHit {
/// The value exists in the compact space
Exact(u32),
/// The value does not exist in the compact space, but the next higher value does
Next(u32),
/// The value is greater than the maximum compact value
AfterLast,
}
/// Exposes the compact space compressed values as u64.
///
/// This allows faster access to the values, as u64 is faster to work with than u128.
@@ -322,11 +309,6 @@ impl CompactSpaceU64Accessor {
pub fn compact_to_u128(&self, compact: u32) -> u128 {
self.0.compact_to_u128(compact)
}
/// Finds the next compact space value for a given u128 value.
pub fn u128_to_next_compact(&self, value: u128) -> CompactHit {
self.0.u128_to_next_compact(value)
}
}
impl ColumnValues<u64> for CompactSpaceU64Accessor {
@@ -459,21 +441,6 @@ impl CompactSpaceDecompressor {
self.params.compact_space.u128_to_compact(value)
}
/// Finds the next compact space value for a given u128 value.
pub fn u128_to_next_compact(&self, value: u128) -> CompactHit {
match self.u128_to_compact(value) {
Ok(compact) => CompactHit::Exact(compact),
Err(pos) => {
if pos >= self.params.compact_space.ranges_mapping.len() {
CompactHit::AfterLast
} else {
let next_range = &self.params.compact_space.ranges_mapping[pos];
CompactHit::Next(next_range.compact_start)
}
}
}
}
fn compact_to_u128(&self, compact: u32) -> u128 {
self.params.compact_space.compact_to_u128(compact)
}
@@ -856,41 +823,6 @@ mod tests {
let _data = test_aux_vals(vals);
}
#[test]
fn test_u128_to_next_compact() {
let vals = &[100u128, 200u128, 1_000_000_000u128, 1_000_000_100u128];
let mut data = test_aux_vals(vals);
let _header = U128Header::deserialize(&mut data);
let decomp = CompactSpaceDecompressor::open(data).unwrap();
// Test value that's already in a range
let compact_100 = decomp.u128_to_compact(100).unwrap();
assert_eq!(
decomp.u128_to_next_compact(100),
CompactHit::Exact(compact_100)
);
// Test value between two ranges
let compact_million = decomp.u128_to_compact(1_000_000_000).unwrap();
assert_eq!(
decomp.u128_to_next_compact(250),
CompactHit::Next(compact_million)
);
// Test value before the first range
assert_eq!(
decomp.u128_to_next_compact(50),
CompactHit::Next(compact_100)
);
// Test value after the last range
assert_eq!(
decomp.u128_to_next_compact(10_000_000_000),
CompactHit::AfterLast
);
}
use proptest::prelude::*;
fn num_strategy() -> impl Strategy<Value = u128> {

View File

@@ -7,7 +7,7 @@ mod compact_space;
use common::{BinarySerializable, OwnedBytes, VInt};
pub use compact_space::{
CompactHit, CompactSpaceCompressor, CompactSpaceDecompressor, CompactSpaceU64Accessor,
CompactSpaceCompressor, CompactSpaceDecompressor, CompactSpaceU64Accessor,
};
use crate::column_values::monotonic_map_column;

View File

@@ -59,7 +59,7 @@ pub struct RowAddr {
pub row_id: RowId,
}
pub use sstable::{Dictionary, TermOrdHit};
pub use sstable::Dictionary;
pub type Streamer<'a> = sstable::Streamer<'a, VoidSSTable>;
pub use common::DateTime;

View File

@@ -1,6 +1,6 @@
[package]
name = "tantivy-common"
version = "0.11.0"
version = "0.10.0"
authors = ["Paul Masurel <paul@quickwit.io>", "Pascal Seitz <pascal@quickwit.io>"]
license = "MIT"
edition = "2024"
@@ -15,10 +15,11 @@ repository = "https://github.com/quickwit-oss/tantivy"
byteorder = "1.4.3"
ownedbytes = { version= "0.9", path="../ownedbytes" }
async-trait = "0.1"
time = { version = "0.3.47", features = ["serde-well-known"] }
time = { version = "0.3.10", features = ["serde-well-known"] }
serde = { version = "1.0.136", features = ["derive"] }
[dev-dependencies]
binggan = "0.17.0"
binggan = "0.14.0"
proptest = "1.0.0"
rand = "0.9"

View File

@@ -47,9 +47,6 @@ impl TinySet {
TinySet(val)
}
/// An empty `TinySet` constant.
pub const EMPTY: TinySet = TinySet(0u64);
/// Returns an empty `TinySet`.
#[inline]
pub fn empty() -> TinySet {
@@ -156,22 +153,7 @@ impl TinySet {
None
} else {
let lowest = self.0.trailing_zeros();
// Kernighan's trick: `n &= n - 1` clears the lowest set bit
// without depending on `lowest`. This lets the CPU execute
// `trailing_zeros` and the bit-clear in parallel instead of
// serializing them.
//
// The previous form `self.0 ^= 1 << lowest` needs the result of
// `trailing_zeros` before it can shift, creating a dependency chain:
// ARM64: rbit → clz → lsl → eor
// x86: tzcnt → btc
//
// With Kernighan's trick the clear path is independent of the count:
// ARM64: sub → and (trailing_zeros runs in parallel)
// x86: blsr (tzcnt runs in parallel)
//
// https://godbolt.org/z/fnfrP1T5f
self.0 &= self.0 - 1;
self.0 ^= TinySet::singleton(lowest).0;
Some(lowest)
}
}

View File

@@ -121,7 +121,7 @@ pub struct FileSlice {
impl fmt::Debug for FileSlice {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "FileSlice({:?}, {:?})", self.data, self.range)
write!(f, "FileSlice({:?}, {:?})", &self.data, self.range)
}
}

View File

@@ -70,7 +70,7 @@ impl Collector for StatsCollector {
fn for_segment(
&self,
_segment_local_id: u32,
segment_reader: &SegmentReader,
segment_reader: &dyn SegmentReader,
) -> tantivy::Result<StatsSegmentCollector> {
let fast_field_reader = segment_reader.fast_fields().u64(&self.field)?;
Ok(StatsSegmentCollector {

View File

@@ -60,7 +60,7 @@ fn main() -> tantivy::Result<()> {
let count_docs = searcher.search(&*query, &TopDocs::with_limit(4).order_by_score())?;
assert_eq!(count_docs.len(), 1);
for (_score, doc_address) in count_docs {
let retrieved_doc = searcher.doc::<TantivyDocument>(doc_address)?;
let retrieved_doc = searcher.doc(doc_address)?;
assert!(retrieved_doc
.get_first(occurred_at)
.unwrap()

View File

@@ -65,7 +65,7 @@ fn main() -> tantivy::Result<()> {
);
let top_docs_by_custom_score =
// Call TopDocs with a custom tweak score
TopDocs::with_limit(2).tweak_score(move |segment_reader: &SegmentReader| {
TopDocs::with_limit(2).tweak_score(move |segment_reader: &dyn SegmentReader| {
let ingredient_reader = segment_reader.facet_reader("ingredient").unwrap();
let facet_dict = ingredient_reader.facet_dict();
@@ -91,7 +91,7 @@ fn main() -> tantivy::Result<()> {
.iter()
.map(|(_, doc_id)| {
searcher
.doc::<TantivyDocument>(*doc_id)
.doc(*doc_id)
.unwrap()
.get_first(title)
.and_then(|v| v.as_str().map(|el| el.to_string()))

View File

@@ -67,7 +67,7 @@ fn main() -> Result<()> {
let mut titles = top_docs
.into_iter()
.map(|(_score, doc_address)| {
let doc = searcher.doc::<TantivyDocument>(doc_address)?;
let doc = searcher.doc(doc_address)?;
let title = doc
.get_first(title)
.and_then(|v| v.as_str())

View File

@@ -55,7 +55,7 @@ fn main() -> tantivy::Result<()> {
let snippet_generator = SnippetGenerator::create(&searcher, &*query, body)?;
for (score, doc_address) in top_docs {
let doc = searcher.doc::<TantivyDocument>(doc_address)?;
let doc = searcher.doc(doc_address)?;
let snippet = snippet_generator.snippet_from_doc(&doc);
println!("Document score {score}:");
println!("title: {}", doc.get_first(title).unwrap().as_str().unwrap());

View File

@@ -43,7 +43,7 @@ impl DynamicPriceColumn {
}
}
pub fn price_for_segment(&self, segment_reader: &SegmentReader) -> Option<Arc<Vec<Price>>> {
pub fn price_for_segment(&self, segment_reader: &dyn SegmentReader) -> Option<Arc<Vec<Price>>> {
let segment_key = (segment_reader.segment_id(), segment_reader.delete_opstamp());
self.price_cache.read().unwrap().get(&segment_key).cloned()
}
@@ -157,7 +157,7 @@ fn main() -> tantivy::Result<()> {
let query = query_parser.parse_query("cooking")?;
let searcher = reader.searcher();
let score_by_price = move |segment_reader: &SegmentReader| {
let score_by_price = move |segment_reader: &dyn SegmentReader| {
let price = price_dynamic_column
.price_for_segment(segment_reader)
.unwrap();

View File

@@ -1,6 +1,6 @@
[package]
name = "tantivy-query-grammar"
version = "0.26.0"
version = "0.25.0"
authors = ["Paul Masurel <paul.masurel@gmail.com>"]
license = "MIT"
categories = ["database-implementations", "data-structures"]

View File

@@ -1045,43 +1045,18 @@ fn operand_leaf(inp: &str) -> IResult<&str, (Option<BinaryOperand>, Option<Occur
}
fn ast(inp: &str) -> IResult<&str, UserInputAst> {
// Parse `occur_leaf` once, then conditionally extend into a boolean
// expression. The previous implementation used `alt((boolean_expr,
// single_leaf))` which, when the input was a single leaf with no
// following operand, would parse `occur_leaf` once for `boolean_expr`,
// fail at `multispace1`, backtrack, then re-parse `occur_leaf` for
// `single_leaf`. With recursively-nested groups like `(+(+(+a)))`, that
// doubling at every level produced O(2^n) parse time. Parsing once and
// peeking ahead for the operand keeps it O(n).
delimited(
multispace0,
|inp| {
let (rest, first) = occur_leaf(inp)?;
// Only fall back on `Err::Error` (recoverable), mirroring
// `alt`'s behaviour. `Err::Failure` and `Err::Incomplete`
// must propagate so cut points and streaming needs are not
// accidentally swallowed if they are ever introduced in the
// operand parsers.
match preceded(multispace1, many1(operand_leaf))(rest) {
Ok((rest, more)) => {
let combined = aggregate_binary_expressions(first, more)
.map_err(|_| nom::Err::Error(Error::new(inp, ErrorKind::MapRes)))?;
Ok((rest, combined))
}
Err(nom::Err::Error(_)) => {
let (occur, ast) = first;
let single = if occur == Some(Occur::MustNot) {
ast.unary(Occur::MustNot)
} else {
ast
};
Ok((rest, single))
}
Err(e) => Err(e),
}
},
multispace0,
)(inp)
let boolean_expr = map_res(
separated_pair(occur_leaf, multispace1, many1(operand_leaf)),
|(left, right)| aggregate_binary_expressions(left, right),
);
let single_leaf = map(occur_leaf, |(occur, ast)| {
if occur == Some(Occur::MustNot) {
ast.unary(Occur::MustNot)
} else {
ast
}
});
delimited(multispace0, alt((boolean_expr, single_leaf)), multispace0)(inp)
}
fn ast_infallible(inp: &str) -> JResult<&str, UserInputAst> {
@@ -1916,23 +1891,4 @@ mod test {
r#"(+"field":'happy tax payer' +"other_field":1)"#,
);
}
// Regression test for https://github.com/quickwit-oss/tantivy/issues/2498:
// deeply nested parenthesized queries used to take O(2^n) time because the
// top-level `ast()` parser tried `boolean_expr` first and re-parsed the
// inner `occur_leaf` when it backtracked to `single_leaf`. Depth 60 would
// take ~10^18 operations under the regression; with the fix it parses
// instantly. We use `test_parse_query_to_ast_helper` so this test would
// never finish if the regression returned.
#[test]
fn test_parse_deeply_nested_query() {
let depth = 60;
let leading: String = "(".repeat(depth);
let trailing: String = ")".repeat(depth);
let query = format!("{leading}title:test{trailing}");
test_parse_query_to_ast_helper(&query, r#""title":test"#);
let query_with_plus = format!("+{leading}title:test{trailing}");
test_parse_query_to_ast_helper(&query_with_plus, r#""title":test"#);
}
}

View File

@@ -57,7 +57,7 @@ pub(crate) fn get_numeric_or_date_column_types() -> &'static [ColumnType] {
/// Get fast field reader or empty as default.
pub(crate) fn get_ff_reader(
reader: &SegmentReader,
reader: &dyn SegmentReader,
field_name: &str,
allowed_column_types: Option<&[ColumnType]>,
) -> crate::Result<(columnar::Column<u64>, ColumnType)> {
@@ -74,7 +74,7 @@ pub(crate) fn get_ff_reader(
}
pub(crate) fn get_dynamic_columns(
reader: &SegmentReader,
reader: &dyn SegmentReader,
field_name: &str,
) -> crate::Result<Vec<columnar::DynamicColumn>> {
let ff_fields = reader.fast_fields().dynamic_column_handles(field_name)?;
@@ -90,7 +90,7 @@ pub(crate) fn get_dynamic_columns(
///
/// Is guaranteed to return at least one column.
pub(crate) fn get_all_ff_reader_or_empty(
reader: &SegmentReader,
reader: &dyn SegmentReader,
field_name: &str,
allowed_column_types: Option<&[ColumnType]>,
fallback_type: ColumnType,

View File

@@ -10,18 +10,17 @@ use crate::aggregation::accessor_helpers::{
};
use crate::aggregation::agg_req::{Aggregation, AggregationVariants, Aggregations};
use crate::aggregation::bucket::{
build_segment_filter_collector, build_segment_range_collector, CompositeAggReqData,
CompositeAggregation, CompositeSourceAccessors, FilterAggReqData, HistogramAggReqData,
HistogramBounds, IncludeExcludeParam, MissingTermAggReqData, RangeAggReqData,
SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
build_segment_filter_collector, build_segment_range_collector, FilterAggReqData,
HistogramAggReqData, HistogramBounds, IncludeExcludeParam, MissingTermAggReqData,
RangeAggReqData, SegmentHistogramCollector, TermMissingAgg, TermsAggReqData, TermsAggregation,
TermsAggregationInternal,
};
use crate::aggregation::metric::{
build_segment_stats_collector, AverageAggregation, CardinalityAggReqData,
CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation, MaxAggregation,
MetricAggReqData, MinAggregation, SegmentCardinalityCollector, SegmentExtendedStatsCollector,
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TermOrdSet,
TopHitsAggReqData, TopHitsSegmentCollector, BITSET_MAX_TERM_ORD,
SegmentPercentilesCollector, StatsAggregation, StatsType, SumAggregation, TopHitsAggReqData,
TopHitsSegmentCollector,
};
use crate::aggregation::segment_agg_result::{
GenericSegmentAggregationResultsCollector, SegmentAggregationCollector,
@@ -74,12 +73,6 @@ impl AggregationsSegmentCtx {
self.per_request.filter_req_data.push(Some(Box::new(data)));
self.per_request.filter_req_data.len() - 1
}
pub(crate) fn push_composite_req_data(&mut self, data: CompositeAggReqData) -> usize {
self.per_request
.composite_req_data
.push(Some(Box::new(data)));
self.per_request.composite_req_data.len() - 1
}
#[inline]
pub(crate) fn get_term_req_data(&self, idx: usize) -> &TermsAggReqData {
@@ -115,12 +108,6 @@ impl AggregationsSegmentCtx {
.as_deref()
.expect("range_req_data slot is empty (taken)")
}
#[inline]
pub(crate) fn get_composite_req_data(&self, idx: usize) -> &CompositeAggReqData {
self.per_request.composite_req_data[idx]
.as_deref()
.expect("composite_req_data slot is empty (taken)")
}
// ---------- mutable getters ----------
@@ -194,25 +181,6 @@ impl AggregationsSegmentCtx {
debug_assert!(self.per_request.filter_req_data[idx].is_none());
self.per_request.filter_req_data[idx] = Some(value);
}
/// Move out the Composite request at `idx`.
#[inline]
pub(crate) fn take_composite_req_data(&mut self, idx: usize) -> Box<CompositeAggReqData> {
self.per_request.composite_req_data[idx]
.take()
.expect("composite_req_data slot is empty (taken)")
}
/// Put back a Composite request into an empty slot at `idx`.
#[inline]
pub(crate) fn put_back_composite_req_data(
&mut self,
idx: usize,
value: Box<CompositeAggReqData>,
) {
debug_assert!(self.per_request.composite_req_data[idx].is_none());
self.per_request.composite_req_data[idx] = Some(value);
}
}
/// Each type of aggregation has its own request data struct. This struct holds
@@ -240,8 +208,6 @@ pub struct PerRequestAggSegCtx {
pub top_hits_req_data: Vec<TopHitsAggReqData>,
/// MissingTermAggReqData contains the request data for a missing term aggregation.
pub missing_term_req_data: Vec<MissingTermAggReqData>,
/// CompositeAggReqData contains the request data for a composite aggregation.
pub composite_req_data: Vec<Option<Box<CompositeAggReqData>>>,
/// Request tree used to build collectors.
pub agg_tree: Vec<AggRefNode>,
@@ -289,11 +255,6 @@ impl PerRequestAggSegCtx {
.iter()
.map(|t| t.get_memory_consumption())
.sum::<usize>()
+ self
.composite_req_data
.iter()
.map(|b| b.as_ref().map(|d| d.get_memory_consumption()).unwrap_or(0))
.sum::<usize>()
+ self.agg_tree.len() * std::mem::size_of::<AggRefNode>()
}
@@ -330,11 +291,6 @@ impl PerRequestAggSegCtx {
.expect("filter_req_data slot is empty (taken)")
.name
.as_str(),
AggKind::Composite => self.composite_req_data[idx]
.as_deref()
.expect("composite_req_data slot is empty (taken)")
.name
.as_str(),
}
}
@@ -413,38 +369,12 @@ pub(crate) fn build_segment_agg_collector(
}
AggKind::Cardinality => {
let req_data = &mut req.get_cardinality_req_data_mut(node.idx_in_req_data);
// For str columns, choose the per-bucket entries representation
// based on the segment's column.max_value():
// * small (< BITSET_MAX_TERM_ORD): `BitSet`, pre-allocated, no promotion machinery.
// * large: `TermOrdSet` (sparse FxHashSet that promotes to a paged bitset).
// For non-str columns the `entries` field is unused (values go
// straight into the HLL sketch); we still pick `TermOrdSet`
// because its empty Sparse(FxHashSet) costs nothing.
let is_str = req_data.column_type == ColumnType::Str;
let max_term_ord_inclusive = if is_str {
req_data.accessor.max_value()
} else {
0
};
let collector: Box<dyn SegmentAggregationCollector> =
if is_str && max_term_ord_inclusive < BITSET_MAX_TERM_ORD {
Box::new(SegmentCardinalityCollector::<BitSet>::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
max_term_ord_inclusive,
))
} else {
Box::new(SegmentCardinalityCollector::<TermOrdSet>::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
max_term_ord_inclusive,
))
};
Ok(collector)
Ok(Box::new(SegmentCardinalityCollector::from_req(
req_data.column_type,
node.idx_in_req_data,
req_data.accessor.clone(),
req_data.missing_value_for_accessor,
)))
}
AggKind::StatsKind(stats_type) => {
let req_data = &mut req.per_request.stats_metric_req_data[node.idx_in_req_data];
@@ -487,11 +417,6 @@ pub(crate) fn build_segment_agg_collector(
)?)),
AggKind::Range => Ok(build_segment_range_collector(req, node)?),
AggKind::Filter => build_segment_filter_collector(req, node),
AggKind::Composite => Ok(Box::new(
crate::aggregation::bucket::SegmentCompositeCollector::from_req_and_validate(
req, node,
)?,
)),
}
}
@@ -522,7 +447,6 @@ pub enum AggKind {
DateHistogram,
Range,
Filter,
Composite,
}
impl AggKind {
@@ -538,7 +462,6 @@ impl AggKind {
AggKind::DateHistogram => "DateHistogram",
AggKind::Range => "Range",
AggKind::Filter => "Filter",
AggKind::Composite => "Composite",
}
}
}
@@ -546,7 +469,7 @@ impl AggKind {
/// Build AggregationsData by walking the request tree.
pub(crate) fn build_aggregations_data_from_req(
aggs: &Aggregations,
reader: &SegmentReader,
reader: &dyn SegmentReader,
segment_ordinal: SegmentOrdinal,
context: AggContextParams,
) -> crate::Result<AggregationsSegmentCtx> {
@@ -566,7 +489,7 @@ pub(crate) fn build_aggregations_data_from_req(
fn build_nodes(
agg_name: &str,
req: &Aggregation,
reader: &SegmentReader,
reader: &dyn SegmentReader,
segment_ordinal: SegmentOrdinal,
data: &mut AggregationsSegmentCtx,
is_top_level: bool,
@@ -786,14 +709,6 @@ fn build_nodes(
children,
}])
}
AggregationVariants::Composite(composite_req) => Ok(vec![build_composite_node(
agg_name,
reader,
segment_ordinal,
data,
&req.sub_aggregation,
composite_req,
)?]),
AggregationVariants::Filter(filter_req) => {
// Build the query and evaluator upfront
let schema = reader.schema();
@@ -813,7 +728,7 @@ fn build_nodes(
let idx_in_req_data = data.push_filter_req_data(FilterAggReqData {
name: agg_name.to_string(),
req: filter_req.clone(),
segment_reader: reader.clone(),
segment_reader: reader.clone_arc(),
evaluator,
matching_docs_buffer,
is_top_level,
@@ -828,38 +743,9 @@ fn build_nodes(
}
}
fn build_composite_node(
agg_name: &str,
reader: &SegmentReader,
_segment_ordinal: SegmentOrdinal,
data: &mut AggregationsSegmentCtx,
sub_aggs: &Aggregations,
req: &CompositeAggregation,
) -> crate::Result<AggRefNode> {
let mut composite_accessors = Vec::with_capacity(req.sources.len());
for source in &req.sources {
let source_after_key_opt = req.after.get(source.name()).map(|k| &k.0);
let source_accessor =
CompositeSourceAccessors::build_for_source(reader, source, source_after_key_opt)?;
composite_accessors.push(source_accessor);
}
let agg = CompositeAggReqData {
name: agg_name.to_string(),
req: req.clone(),
composite_accessors,
};
let idx = data.push_composite_req_data(agg);
let children = build_children(sub_aggs, reader, _segment_ordinal, data)?;
Ok(AggRefNode {
kind: AggKind::Composite,
idx_in_req_data: idx,
children,
})
}
fn build_children(
aggs: &Aggregations,
reader: &SegmentReader,
reader: &dyn SegmentReader,
segment_ordinal: SegmentOrdinal,
data: &mut AggregationsSegmentCtx,
) -> crate::Result<Vec<AggRefNode>> {
@@ -878,7 +764,7 @@ fn build_children(
}
fn get_term_agg_accessors(
reader: &SegmentReader,
reader: &dyn SegmentReader,
field_name: &str,
missing: &Option<Key>,
) -> crate::Result<Vec<(Column<u64>, ColumnType)>> {
@@ -931,7 +817,7 @@ fn build_terms_or_cardinality_nodes(
agg_name: &str,
field_name: &str,
missing: &Option<Key>,
reader: &SegmentReader,
reader: &dyn SegmentReader,
segment_ordinal: SegmentOrdinal,
data: &mut AggregationsSegmentCtx,
sub_aggs: &Aggregations,
@@ -1011,12 +897,8 @@ fn build_terms_or_cardinality_nodes(
let str_col = str_dict_column
.as_ref()
.expect("str_dict_column must exist for string column");
allowed_term_ids = build_allowed_term_ids_for_str(
str_col,
&req.include,
&req.exclude,
missing.is_some(),
)?;
allowed_term_ids =
build_allowed_term_ids_for_str(str_col, &req.include, &req.exclude)?;
};
let idx_in_req_data = data.push_term_req_data(TermsAggReqData {
accessor,
@@ -1032,20 +914,10 @@ fn build_terms_or_cardinality_nodes(
(idx_in_req_data, AggKind::Terms)
}
TermsOrCardinalityRequest::Cardinality(ref req) => {
// `str_dict_column` is computed once per field; for JSON paths
// with mixed types it's `Some` even on the numeric req_data.
// Cardinality only consults it for the str column path, so
// gate by column_type to avoid driving non-str collectors
// through the coupon-cache path.
let str_dict_column_for_req = if column_type == ColumnType::Str {
str_dict_column.clone()
} else {
None
};
let idx_in_req_data = data.push_cardinality_req_data(CardinalityAggReqData {
accessor,
column_type,
str_dict_column: str_dict_column_for_req,
str_dict_column: str_dict_column.clone(),
missing_value_for_accessor,
name: agg_name.to_string(),
req: req.clone(),
@@ -1065,21 +937,16 @@ fn build_terms_or_cardinality_nodes(
/// Builds a single BitSet of allowed term ordinals for a string dictionary column according to
/// include/exclude parameters.
///
/// When `reserve_missing_sentinel` is true, the bitset will have 1 additional slot for the missing
/// term ordinal
fn build_allowed_term_ids_for_str(
str_col: &StrColumn,
include: &Option<IncludeExcludeParam>,
exclude: &Option<IncludeExcludeParam>,
reserve_missing_sentinel: bool,
) -> crate::Result<Option<BitSet>> {
let mut allowed: Option<BitSet> = None;
let missing_sentinel_adjustment = if reserve_missing_sentinel { 1 } else { 0 };
let allowed_capacity = str_col.dictionary().num_terms() as u32 + missing_sentinel_adjustment;
let num_terms = str_col.dictionary().num_terms() as u32;
if let Some(include) = include {
// add matches
allowed = Some(BitSet::with_max_value(allowed_capacity));
allowed = Some(BitSet::with_max_value(num_terms));
let allowed = allowed.as_mut().unwrap();
for_each_matching_term_ord(str_col, include, |ord| allowed.insert(ord))?;
};
@@ -1087,7 +954,7 @@ fn build_allowed_term_ids_for_str(
if let Some(exclude) = exclude {
if allowed.is_none() {
// Start with all terms allowed
allowed = Some(BitSet::with_max_value_and_full(allowed_capacity));
allowed = Some(BitSet::with_max_value_and_full(num_terms));
}
let allowed = allowed.as_mut().unwrap();
for_each_matching_term_ord(str_col, exclude, |ord| allowed.remove(ord))?;

View File

@@ -32,8 +32,8 @@ use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use super::bucket::{
CompositeAggregation, DateHistogramAggregationReq, FilterAggregation, HistogramAggregation,
RangeAggregation, TermsAggregation,
DateHistogramAggregationReq, FilterAggregation, HistogramAggregation, RangeAggregation,
TermsAggregation,
};
use super::metric::{
AverageAggregation, CardinalityAggregationReq, CountAggregation, ExtendedStatsAggregation,
@@ -115,71 +115,6 @@ pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet<String> {
fast_field_names
}
/// Validates that all fields referenced in the aggregation request exist in the schema
/// and are configured as fast fields.
///
/// This is a convenience function for upfront validation before executing aggregations.
/// Returns an error if any field doesn't exist or is not a fast field.
///
/// Validation is intentionally opt-in rather than baked into aggregation execution: the
/// default lenient behavior (returning empty results for missing fields) supports
/// schema evolution and federated queries where the same request runs against segments
/// or indices with different schemas.
///
/// # Example
/// ```
/// use tantivy::aggregation::agg_req::{Aggregations, validate_aggregation_fields_exist};
/// use tantivy::schema::{Schema, FAST};
/// use tantivy::Index;
///
/// # fn main() -> tantivy::Result<()> {
/// // Create a simple index
/// let mut schema_builder = Schema::builder();
/// schema_builder.add_f64_field("price", FAST);
/// let schema = schema_builder.build();
/// let index = Index::create_in_ram(schema);
///
/// // Parse aggregation request
/// let agg_req: Aggregations = serde_json::from_str(r#"{
/// "avg_price": { "avg": { "field": "price" } }
/// }"#)?;
///
/// let reader = index.reader()?;
/// let searcher = reader.searcher();
///
/// // Validate fields before executing
/// for segment_reader in searcher.segment_readers() {
/// validate_aggregation_fields_exist(&agg_req, segment_reader)?;
/// }
/// # Ok(())
/// # }
/// ```
pub fn validate_aggregation_fields_exist(
aggs: &Aggregations,
reader: &crate::SegmentReader,
) -> crate::Result<()> {
let field_names = get_fast_field_names(aggs);
let schema = reader.schema();
for field_name in field_names {
// Check if the field is either directly in the schema or could be part of a json field
// present in the schema, and verify it's a fast field.
if let Some((field, _path)) = schema.find_field(&field_name) {
let field_type = schema.get_field_entry(field).field_type();
if !field_type.is_fast() {
return Err(crate::TantivyError::SchemaError(format!(
"Field '{}' is not a fast field. Aggregations require fast fields.",
field_name
)));
}
} else {
return Err(crate::TantivyError::FieldNotFound(field_name));
}
}
Ok(())
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// All aggregation types.
pub enum AggregationVariants {
@@ -199,9 +134,6 @@ pub enum AggregationVariants {
/// Filter documents into a single bucket.
#[serde(rename = "filter")]
Filter(FilterAggregation),
/// Multi-dimensional, paginable bucket aggregation.
#[serde(rename = "composite")]
Composite(CompositeAggregation),
// Metric aggregation types
/// Computes the average of the extracted values.
@@ -248,11 +180,6 @@ impl AggregationVariants {
AggregationVariants::Histogram(histogram) => vec![histogram.field.as_str()],
AggregationVariants::DateHistogram(histogram) => vec![histogram.field.as_str()],
AggregationVariants::Filter(filter) => filter.get_fast_field_names(),
AggregationVariants::Composite(composite) => composite
.sources
.iter()
.map(|source| source.field())
.collect(),
AggregationVariants::Average(avg) => vec![avg.field_name()],
AggregationVariants::Count(count) => vec![count.field_name()],
AggregationVariants::Max(max) => vec![max.field_name()],
@@ -287,12 +214,6 @@ impl AggregationVariants {
_ => None,
}
}
pub(crate) fn as_composite(&self) -> Option<&CompositeAggregation> {
match &self {
AggregationVariants::Composite(composite) => Some(composite),
_ => None,
}
}
pub(crate) fn as_percentile(&self) -> Option<&PercentilesAggregationReq> {
match &self {
AggregationVariants::Percentiles(percentile_req) => Some(percentile_req),

View File

@@ -9,12 +9,10 @@ use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use super::bucket::GetDocCount;
use super::intermediate_agg_result::CompositeIntermediateKey;
use super::metric::{
ExtendedStats, PercentilesMetricResult, SingleMetricResult, Stats, TopHitsMetricResult,
};
use super::{AggregationError, Key};
use crate::aggregation::bucket::AfterKey;
use crate::TantivyError;
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
@@ -160,14 +158,6 @@ pub enum BucketResult {
},
/// This is the filter result - a single bucket with sub-aggregations
Filter(FilterBucketResult),
/// This is the composite result
Composite {
/// The buckets
buckets: Vec<CompositeBucketEntry>,
/// The key to start after when paginating
#[serde(skip_serializing_if = "FxHashMap::is_empty")]
after_key: FxHashMap<String, AfterKey>,
},
}
impl BucketResult {
@@ -189,9 +179,6 @@ impl BucketResult {
// Only count sub-aggregation buckets
filter_result.sub_aggregations.get_bucket_count()
}
BucketResult::Composite { buckets, .. } => {
buckets.iter().map(|bucket| bucket.get_bucket_count()).sum()
}
}
}
}
@@ -208,8 +195,7 @@ pub enum BucketEntries<T> {
}
impl<T> BucketEntries<T> {
/// Iterate over all bucket entries.
pub fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = &'a T> + 'a> {
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = &'a T> + 'a> {
match self {
BucketEntries::Vec(vec) => Box::new(vec.iter()),
BucketEntries::HashMap(map) => Box::new(map.values()),
@@ -351,87 +337,3 @@ pub struct FilterBucketResult {
#[serde(flatten)]
pub sub_aggregations: AggregationResults,
}
/// Note the type information loss compared to `CompositeIntermediateKey`.
/// Pagination is performed using `AfterKey`, which encodes type information.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CompositeKey {
/// Boolean key
Bool(bool),
/// String key
Str(String),
/// `i64` key
I64(i64),
/// `u64` key
U64(u64),
/// `f64` key
F64(f64),
/// Null key
Null,
}
impl Eq for CompositeKey {}
impl std::hash::Hash for CompositeKey {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
match self {
Self::Bool(val) => val.hash(state),
Self::Str(text) => text.hash(state),
Self::F64(val) => val.to_bits().hash(state),
Self::U64(val) => val.hash(state),
Self::I64(val) => val.hash(state),
Self::Null => {}
}
}
}
impl PartialEq for CompositeKey {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Bool(l), Self::Bool(r)) => l == r,
(Self::Str(l), Self::Str(r)) => l == r,
(Self::F64(l), Self::F64(r)) => l.to_bits() == r.to_bits(),
(Self::I64(l), Self::I64(r)) => l == r,
(Self::U64(l), Self::U64(r)) => l == r,
(Self::Null, Self::Null) => true,
_ => false,
}
}
}
impl From<CompositeIntermediateKey> for CompositeKey {
fn from(value: CompositeIntermediateKey) -> Self {
match value {
CompositeIntermediateKey::Str(s) => Self::Str(s),
CompositeIntermediateKey::IpAddr(s) => {
if let Some(ip) = s.to_ipv4_mapped() {
Self::Str(ip.to_string())
} else {
Self::Str(s.to_string())
}
}
CompositeIntermediateKey::F64(f) => Self::F64(f),
CompositeIntermediateKey::Bool(f) => Self::Bool(f),
CompositeIntermediateKey::U64(f) => Self::U64(f),
CompositeIntermediateKey::I64(f) => Self::I64(f),
CompositeIntermediateKey::DateTime(f) => Self::I64(f / 1_000_000), // ns to ms
CompositeIntermediateKey::Null => Self::Null,
}
}
}
/// Composite bucket entry with a multi-dimensional key.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CompositeBucketEntry {
/// The identifier of the bucket.
pub key: FxHashMap<String, CompositeKey>,
/// Number of documents in the bucket.
pub doc_count: u64,
#[serde(flatten)]
/// Sub-aggregations in this bucket.
pub sub_aggregation: AggregationResults,
}
impl CompositeBucketEntry {
pub(crate) fn get_bucket_count(&self) -> u64 {
1 + self.sub_aggregation.get_bucket_count()
}
}

View File

@@ -1436,46 +1436,3 @@ fn test_aggregation_on_json_object_mixed_numerical_segments() {
)
);
}
#[test]
fn test_aggregation_field_validation_helper() {
// Test the standalone validation helper function for field validation
let index = get_test_index_2_segments(false).unwrap();
let reader = index.reader().unwrap();
let searcher = reader.searcher();
let segment_reader = searcher.segment_reader(0);
// Test with invalid field
let agg_req: Aggregations = serde_json::from_str(
r#"{
"avg_test": {
"avg": { "field": "nonexistent_field" }
}
}"#,
)
.unwrap();
let result =
crate::aggregation::agg_req::validate_aggregation_fields_exist(&agg_req, segment_reader);
assert!(result.is_err());
match result {
Err(crate::TantivyError::FieldNotFound(field_name)) => {
assert_eq!(field_name, "nonexistent_field");
}
_ => panic!("Expected FieldNotFound error, got: {:?}", result),
}
// Test with valid field
let agg_req: Aggregations = serde_json::from_str(
r#"{
"avg_test": {
"avg": { "field": "score" }
}
}"#,
)
.unwrap();
let result =
crate::aggregation::agg_req::validate_aggregation_fields_exist(&agg_req, segment_reader);
assert!(result.is_ok());
}

View File

@@ -1,518 +0,0 @@
use std::net::Ipv6Addr;
use columnar::column_values::{CompactHit, CompactSpaceU64Accessor};
use columnar::{Column, ColumnType, MonotonicallyMappableToU64, StrColumn, TermOrdHit};
use crate::aggregation::accessor_helpers::get_numeric_or_date_column_types;
use crate::aggregation::bucket::composite::numeric_types::num_proj;
use crate::aggregation::bucket::composite::numeric_types::num_proj::ProjectedNumber;
use crate::aggregation::bucket::composite::ToTypePaginationOrder;
use crate::aggregation::bucket::{
parse_into_milliseconds, CalendarInterval, CompositeAggregation, CompositeAggregationSource,
MissingOrder, Order,
};
use crate::aggregation::intermediate_agg_result::CompositeIntermediateKey;
use crate::{SegmentReader, TantivyError};
/// Contains all information required by the SegmentCompositeCollector to perform the
/// composite aggregation on a segment.
pub struct CompositeAggReqData {
/// The name of the aggregation.
pub name: String,
/// The normalized term aggregation request.
pub req: CompositeAggregation,
/// Accessors for each source, each source can have multiple accessors (columns).
pub composite_accessors: Vec<CompositeSourceAccessors>,
}
impl CompositeAggReqData {
/// Estimate the memory consumption of this struct in bytes.
pub fn get_memory_consumption(&self) -> usize {
std::mem::size_of::<Self>()
+ self.composite_accessors.len() * std::mem::size_of::<CompositeSourceAccessors>()
}
}
/// Accessors for a single column in a composite source.
pub struct CompositeAccessor {
/// The fast field column
pub column: Column<u64>,
/// The column type
pub column_type: ColumnType,
/// Term dictionary if the column type is Str
///
/// Only used by term sources
pub str_dict_column: Option<StrColumn>,
/// Parsed date interval for date histogram sources
pub date_histogram_interval: PrecomputedDateInterval,
}
/// Accessors to all the columns that belong to the field of a composite source.
pub struct CompositeSourceAccessors {
/// The accessors for this source
pub accessors: Vec<CompositeAccessor>,
/// The key after which to start collecting results. Applies to the first
/// column of the source.
pub after_key: PrecomputedAfterKey,
/// The column index the after_key applies to. The after_key only applies to
/// one column. Columns before should be skipped. Columns after should be
/// kept without comparison to the after_key.
pub after_key_accessor_idx: usize,
/// Whether to skip missing values because of the after_key. Skipping only
/// applies if the value for previous columns were exactly equal to the
/// corresponding after keys (is_on_after_key).
pub skip_missing: bool,
/// The after key was set to null to indicate that the last collected key
/// was a missing value.
pub is_after_key_explicit_missing: bool,
}
impl CompositeSourceAccessors {
/// Creates a new set of accessors for the composite source.
///
/// Precomputes some values to make collection faster.
pub fn build_for_source(
reader: &SegmentReader,
source: &CompositeAggregationSource,
// First option is None when no after key was set in the query, the
// second option is None when the after key was set but its value for
// this source was set to `null`
source_after_key_opt: Option<&CompositeIntermediateKey>,
) -> crate::Result<Self> {
let is_after_key_explicit_missing = source_after_key_opt
.map(|after_key| matches!(after_key, CompositeIntermediateKey::Null))
.unwrap_or(false);
let mut skip_missing = false;
if let Some(CompositeIntermediateKey::Null) = source_after_key_opt {
if !source.missing_bucket() {
return Err(TantivyError::InvalidArgument(
"the 'after' key for a source cannot be null when 'missing_bucket' is false"
.to_string(),
));
}
} else if source_after_key_opt.is_some() {
// if missing buckets come first and we have a non null after key, we skip missing
if MissingOrder::First == source.missing_order() {
skip_missing = true;
}
if MissingOrder::Default == source.missing_order() && Order::Asc == source.order() {
skip_missing = true;
}
};
match source {
CompositeAggregationSource::Terms(source) => {
let allowed_column_types = [
ColumnType::I64,
ColumnType::U64,
ColumnType::F64,
ColumnType::Str,
ColumnType::DateTime,
ColumnType::Bool,
ColumnType::IpAddr,
// ColumnType::Bytes Unsupported
];
let mut columns_and_types = reader
.fast_fields()
.u64_lenient_for_type_all(Some(&allowed_column_types), &source.field)?;
// Sort columns by their pagination order and determine which to skip
columns_and_types.sort_by_key(|(_, col_type): &(Column, ColumnType)| {
col_type.column_pagination_order()
});
if source.order == Order::Desc {
columns_and_types.reverse();
}
let after_key_accessor_idx = find_first_column_to_collect(
&columns_and_types,
source_after_key_opt,
source.missing_order,
source.order,
)?;
let source_collectors: Vec<CompositeAccessor> = columns_and_types
.into_iter()
.map(|(column, column_type)| {
Ok(CompositeAccessor {
column,
column_type,
str_dict_column: reader.fast_fields().str(&source.field)?,
date_histogram_interval: PrecomputedDateInterval::NotApplicable,
})
})
.collect::<crate::Result<_>>()?;
let after_key = if let Some(first_col) =
source_collectors.get(after_key_accessor_idx)
{
match source_after_key_opt {
Some(after_key) => PrecomputedAfterKey::precompute(
first_col,
after_key,
&source.field,
source.missing_order,
source.order,
)?,
None => {
precompute_missing_after_key(false, source.missing_order, source.order)
}
}
} else {
// if no columns, we don't care about the after_key
PrecomputedAfterKey::Next(0)
};
Ok(CompositeSourceAccessors {
accessors: source_collectors,
is_after_key_explicit_missing,
skip_missing,
after_key,
after_key_accessor_idx,
})
}
CompositeAggregationSource::Histogram(source) => {
let column_and_types: Vec<(Column, ColumnType)> =
reader.fast_fields().u64_lenient_for_type_all(
Some(get_numeric_or_date_column_types()),
&source.field,
)?;
let source_collectors: Vec<CompositeAccessor> = column_and_types
.into_iter()
.map(|(column, column_type)| {
Ok(CompositeAccessor {
column,
column_type,
str_dict_column: None,
date_histogram_interval: PrecomputedDateInterval::NotApplicable,
})
})
.collect::<crate::Result<_>>()?;
let after_key = match source_after_key_opt {
Some(CompositeIntermediateKey::F64(key)) => {
let normalized_key = *key / source.interval;
num_proj::f64_to_i64(normalized_key).into()
}
Some(CompositeIntermediateKey::Null) => {
precompute_missing_after_key(true, source.missing_order, source.order)
}
None => precompute_missing_after_key(true, source.missing_order, source.order),
_ => {
return Err(crate::TantivyError::InvalidArgument(
"After key type invalid for interval composite source".to_string(),
));
}
};
Ok(CompositeSourceAccessors {
accessors: source_collectors,
is_after_key_explicit_missing,
skip_missing,
after_key,
after_key_accessor_idx: 0,
})
}
CompositeAggregationSource::DateHistogram(source) => {
let column_and_types = reader
.fast_fields()
.u64_lenient_for_type_all(Some(&[ColumnType::DateTime]), &source.field)?;
let date_histogram_interval =
PrecomputedDateInterval::from_date_histogram_source_intervals(
&source.fixed_interval,
source.calendar_interval,
)?;
let source_collectors: Vec<CompositeAccessor> = column_and_types
.into_iter()
.map(|(column, column_type)| {
Ok(CompositeAccessor {
column,
column_type,
str_dict_column: None,
date_histogram_interval,
})
})
.collect::<crate::Result<_>>()?;
let after_key = match source_after_key_opt {
Some(CompositeIntermediateKey::DateTime(key)) => {
PrecomputedAfterKey::Exact(key.to_u64())
}
Some(CompositeIntermediateKey::Null) => {
precompute_missing_after_key(true, source.missing_order, source.order)
}
None => precompute_missing_after_key(true, source.missing_order, source.order),
_ => {
return Err(crate::TantivyError::InvalidArgument(
"After key type invalid for interval composite source".to_string(),
));
}
};
Ok(CompositeSourceAccessors {
accessors: source_collectors,
is_after_key_explicit_missing,
skip_missing,
after_key,
after_key_accessor_idx: 0,
})
}
}
}
}
/// Finds the index of the first column we should start collecting from to
/// resume the pagination from the after_key.
fn find_first_column_to_collect<T>(
sorted_columns: &[(T, ColumnType)],
after_key_opt: Option<&CompositeIntermediateKey>,
missing_order: MissingOrder,
order: Order,
) -> crate::Result<usize> {
let after_key = match after_key_opt {
None => return Ok(0), // No pagination, start from beginning
Some(key) => key,
};
// Handle null after_key (we were on a missing value last time)
if matches!(after_key, CompositeIntermediateKey::Null) {
return match (missing_order, order) {
// Missing values come first, so all columns remain
(MissingOrder::First, _) | (MissingOrder::Default, Order::Asc) => Ok(0),
// Missing values come last, so all columns are done
(MissingOrder::Last, _) | (MissingOrder::Default, Order::Desc) => {
Ok(sorted_columns.len())
}
};
}
// Find the first column whose type order matches or follows the after_key's
// type in the pagination sequence
let after_key_column_order = after_key.column_pagination_order();
for (idx, (_, col_type)) in sorted_columns.iter().enumerate() {
let col_order = col_type.column_pagination_order();
let is_first_to_collect = match order {
Order::Asc => col_order >= after_key_column_order,
Order::Desc => col_order <= after_key_column_order,
};
if is_first_to_collect {
return Ok(idx);
}
}
// All columns are before the after_key, nothing left to collect
Ok(sorted_columns.len())
}
fn precompute_missing_after_key(
is_after_key_explicit_missing: bool,
missing_order: MissingOrder,
order: Order,
) -> PrecomputedAfterKey {
let after_last = PrecomputedAfterKey::AfterLast;
let before_first = PrecomputedAfterKey::Next(0);
match (is_after_key_explicit_missing, missing_order, order) {
(true, MissingOrder::First, Order::Asc) => before_first,
(true, MissingOrder::First, Order::Desc) => after_last,
(true, MissingOrder::Last, Order::Asc) => after_last,
(true, MissingOrder::Last, Order::Desc) => before_first,
(true, MissingOrder::Default, Order::Asc) => before_first,
(true, MissingOrder::Default, Order::Desc) => after_last,
(false, _, Order::Asc) => before_first,
(false, _, Order::Desc) => after_last,
}
}
/// A parsed representation of the date interval for date histogram sources
#[derive(Clone, Copy, Debug)]
pub enum PrecomputedDateInterval {
/// This is not a date histogram source
NotApplicable,
/// Source was configured with a fixed interval
FixedNanoseconds(i64),
/// Source was configured with a calendar interval
Calendar(CalendarInterval),
}
impl PrecomputedDateInterval {
/// Validates the date histogram source interval fields and parses a date interval from them.
pub fn from_date_histogram_source_intervals(
fixed_interval: &Option<String>,
calendar_interval: Option<CalendarInterval>,
) -> crate::Result<Self> {
match (fixed_interval, calendar_interval) {
(Some(_), Some(_)) | (None, None) => Err(TantivyError::InvalidArgument(
"date histogram source must one and only one of fixed_interval or \
calendar_interval set"
.to_string(),
)),
(Some(fixed_interval), None) => {
let fixed_interval_ms = parse_into_milliseconds(fixed_interval)?;
Ok(PrecomputedDateInterval::FixedNanoseconds(
fixed_interval_ms * 1_000_000,
))
}
(None, Some(calendar_interval)) => {
Ok(PrecomputedDateInterval::Calendar(calendar_interval))
}
}
}
}
/// The after key projected to the u64 column space
///
/// Some column types (term, IP) might not have an exact representation of the
/// specified after key
#[derive(Debug)]
pub enum PrecomputedAfterKey {
/// The after key could be exactly represented in the column space.
Exact(u64),
/// The after key could not be exactly represented exactly represented, so
/// this is the next closest one.
Next(u64),
/// The after key could not be represented in the column space, it is
/// greater than all value
AfterLast,
}
impl From<CompactHit> for PrecomputedAfterKey {
fn from(hit: CompactHit) -> Self {
match hit {
CompactHit::Exact(ord) => PrecomputedAfterKey::Exact(ord as u64),
CompactHit::Next(ord) => PrecomputedAfterKey::Next(ord as u64),
CompactHit::AfterLast => PrecomputedAfterKey::AfterLast,
}
}
}
impl From<TermOrdHit> for PrecomputedAfterKey {
fn from(hit: TermOrdHit) -> Self {
match hit {
TermOrdHit::Exact(ord) => PrecomputedAfterKey::Exact(ord),
// TermOrdHit represents AfterLast as Next(u64::MAX), we keep it as is
TermOrdHit::Next(ord) => PrecomputedAfterKey::Next(ord),
}
}
}
impl<T: MonotonicallyMappableToU64> From<ProjectedNumber<T>> for PrecomputedAfterKey {
fn from(num: ProjectedNumber<T>) -> Self {
match num {
ProjectedNumber::Exact(number) => PrecomputedAfterKey::Exact(number.to_u64()),
ProjectedNumber::Next(number) => PrecomputedAfterKey::Next(number.to_u64()),
ProjectedNumber::AfterLast => PrecomputedAfterKey::AfterLast,
}
}
}
// /!\ These operators only makes sense if both values are in the same column space
impl PrecomputedAfterKey {
pub fn equals(&self, column_value: u64) -> bool {
match self {
PrecomputedAfterKey::Exact(v) => *v == column_value,
PrecomputedAfterKey::Next(_) => false,
PrecomputedAfterKey::AfterLast => false,
}
}
pub fn gt(&self, column_value: u64) -> bool {
match self {
PrecomputedAfterKey::Exact(v) => *v > column_value,
PrecomputedAfterKey::Next(v) => *v > column_value,
PrecomputedAfterKey::AfterLast => true,
}
}
pub fn lt(&self, column_value: u64) -> bool {
match self {
PrecomputedAfterKey::Exact(v) => *v < column_value,
// a value equal to the next is greater than the after key
PrecomputedAfterKey::Next(v) => *v <= column_value,
PrecomputedAfterKey::AfterLast => false,
}
}
fn precompute_ip_addr(column: &Column<u64>, key: &Ipv6Addr) -> crate::Result<Self> {
let compact_space_accessor = column
.values
.clone()
.downcast_arc::<CompactSpaceU64Accessor>()
.map_err(|_| {
TantivyError::AggregationError(crate::aggregation::AggregationError::InternalError(
"type mismatch: could not downcast to CompactSpaceU64Accessor".to_string(),
))
})?;
let ip_u128 = key.to_bits();
let ip_next_compact = compact_space_accessor.u128_to_next_compact(ip_u128);
Ok(ip_next_compact.into())
}
fn precompute_term_ord(
str_dict_column: &Option<StrColumn>,
key: &str,
field: &str,
) -> crate::Result<Self> {
let dict = str_dict_column
.as_ref()
.expect("dictionary missing for str accessor")
.dictionary();
let next_ord = dict.term_ord_or_next(key).map_err(|_| {
TantivyError::InvalidArgument(format!(
"failed to lookup after_key '{}' for field '{}'",
key, field
))
})?;
Ok(next_ord.into())
}
/// Projects the after key into the column space of the given accessor.
///
/// The computed after key will not take care of skipping entire columns
/// when the after key type is ordered after the accessor's type, that
/// should be performed earlier.
pub fn precompute(
composite_accessor: &CompositeAccessor,
source_after_key: &CompositeIntermediateKey,
field: &str,
missing_order: MissingOrder,
order: Order,
) -> crate::Result<Self> {
use CompositeIntermediateKey as CIKey;
let precomputed_key = match (composite_accessor.column_type, source_after_key) {
(ColumnType::Bytes, _) => panic!("unsupported"),
// null after key
(_, CIKey::Null) => precompute_missing_after_key(false, missing_order, order),
// numerical
(ColumnType::I64, CIKey::I64(k)) => PrecomputedAfterKey::Exact(k.to_u64()),
(ColumnType::I64, CIKey::U64(k)) => num_proj::u64_to_i64(*k).into(),
(ColumnType::I64, CIKey::F64(k)) => num_proj::f64_to_i64(*k).into(),
(ColumnType::U64, CIKey::I64(k)) => num_proj::i64_to_u64(*k).into(),
(ColumnType::U64, CIKey::U64(k)) => PrecomputedAfterKey::Exact(*k),
(ColumnType::U64, CIKey::F64(k)) => num_proj::f64_to_u64(*k).into(),
(ColumnType::F64, CIKey::I64(k)) => num_proj::i64_to_f64(*k).into(),
(ColumnType::F64, CIKey::U64(k)) => num_proj::u64_to_f64(*k).into(),
(ColumnType::F64, CIKey::F64(k)) => PrecomputedAfterKey::Exact(k.to_u64()),
// boolean
(ColumnType::Bool, CIKey::Bool(key)) => PrecomputedAfterKey::Exact(key.to_u64()),
// string
(ColumnType::Str, CIKey::Str(key)) => PrecomputedAfterKey::precompute_term_ord(
&composite_accessor.str_dict_column,
key,
field,
)?,
// date time
(ColumnType::DateTime, CIKey::DateTime(key)) => {
PrecomputedAfterKey::Exact(key.to_u64())
}
// ip address
(ColumnType::IpAddr, CIKey::IpAddr(key)) => {
PrecomputedAfterKey::precompute_ip_addr(&composite_accessor.column, key)?
}
// assume the column's type is ordered after the after_key's type
_ => PrecomputedAfterKey::keep_all(order),
};
Ok(precomputed_key)
}
fn keep_all(order: Order) -> Self {
match order {
Order::Asc => PrecomputedAfterKey::Next(0),
Order::Desc => PrecomputedAfterKey::Next(u64::MAX),
}
}
}

View File

@@ -1,136 +0,0 @@
use time::convert::{Day, Nanosecond};
use time::{Time, UtcDateTime};
const NS_IN_DAY: i64 = Nanosecond::per_t::<i128>(Day) as i64;
/// Computes the timestamp in nanoseconds corresponding to the beginning of the
/// year (January 1st at midnight UTC).
pub(super) fn try_year_bucket(timestamp_ns: i64) -> crate::Result<i64> {
year_bucket_using_time_crate(timestamp_ns).map_err(|e| {
crate::TantivyError::InvalidArgument(format!(
"Failed to compute year bucket for timestamp {}: {e}",
timestamp_ns
))
})
}
/// Computes the timestamp in nanoseconds corresponding to the beginning of the
/// month (1st at midnight UTC).
pub(super) fn try_month_bucket(timestamp_ns: i64) -> crate::Result<i64> {
month_bucket_using_time_crate(timestamp_ns).map_err(|e| {
crate::TantivyError::InvalidArgument(format!(
"Failed to compute month bucket for timestamp {}: {e}",
timestamp_ns
))
})
}
/// Computes the timestamp in nanoseconds corresponding to the beginning of the
/// week (Monday at midnight UTC).
pub(super) fn week_bucket(timestamp_ns: i64) -> i64 {
// 1970-01-01 was a Thursday (weekday = 4)
let days_since_epoch = timestamp_ns.div_euclid(NS_IN_DAY);
// Find the weekday: 0=Monday, ..., 6=Sunday
let weekday = (days_since_epoch + 3).rem_euclid(7);
let monday_days_since_epoch = days_since_epoch - weekday;
monday_days_since_epoch * NS_IN_DAY
}
fn year_bucket_using_time_crate(timestamp_ns: i64) -> Result<i64, time::Error> {
let timestamp_ns = UtcDateTime::from_unix_timestamp_nanos(timestamp_ns as i128)?
.replace_ordinal(1)?
.replace_time(Time::MIDNIGHT)
.unix_timestamp_nanos();
Ok(timestamp_ns as i64)
}
fn month_bucket_using_time_crate(timestamp_ns: i64) -> Result<i64, time::Error> {
let timestamp_ns = UtcDateTime::from_unix_timestamp_nanos(timestamp_ns as i128)?
.replace_day(1)?
.replace_time(Time::MIDNIGHT)
.unix_timestamp_nanos();
Ok(timestamp_ns as i64)
}
#[cfg(test)]
mod tests {
use time::format_description::well_known::Iso8601;
use time::UtcDateTime;
use super::*;
fn ts_ns(iso: &str) -> i64 {
UtcDateTime::parse(iso, &Iso8601::DEFAULT)
.unwrap()
.unix_timestamp_nanos() as i64
}
#[test]
fn test_year_bucket() {
let ts = ts_ns("1970-01-01T00:00:00Z");
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1970-01-01T00:00:00Z"));
let ts = ts_ns("1970-06-01T10:00:01.010Z");
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1970-01-01T00:00:00Z"));
let ts = ts_ns("2008-12-31T23:59:59.999999999Z"); // leap year
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("2008-01-01T00:00:00Z"));
let ts = ts_ns("2008-01-01T00:00:00Z"); // leap year
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("2008-01-01T00:00:00Z"));
let ts = ts_ns("2010-12-31T23:59:59.999999999Z");
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("2010-01-01T00:00:00Z"));
let ts = ts_ns("1972-06-01T00:10:00Z");
let res = try_year_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1972-01-01T00:00:00Z"));
}
#[test]
fn test_month_bucket() {
let ts = ts_ns("1970-01-15T00:00:00Z");
let res = try_month_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1970-01-01T00:00:00Z"));
let ts = ts_ns("1970-02-01T00:00:00Z");
let res = try_month_bucket(ts).unwrap();
assert_eq!(res, ts_ns("1970-02-01T00:00:00Z"));
let ts = ts_ns("2000-01-31T23:59:59.999999999Z");
let res = try_month_bucket(ts).unwrap();
assert_eq!(res, ts_ns("2000-01-01T00:00:00Z"));
}
#[test]
fn test_week_bucket() {
let ts = ts_ns("1970-01-05T00:00:00Z"); // Monday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));
let ts = ts_ns("1970-01-05T23:59:59Z"); // Monday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));
let ts = ts_ns("1970-01-07T01:13:00Z"); // Wednesday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));
let ts = ts_ns("1970-01-11T23:59:59.999999999Z"); // Sunday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1970-01-05T00:00:00Z"));
let ts = ts_ns("2025-10-16T10:41:59.010Z"); // Thursday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("2025-10-13T00:00:00Z"));
let ts = ts_ns("1970-01-01T00:00:00Z"); // Thursday
let res = week_bucket(ts);
assert_eq!(res, ts_ns("1969-12-29T00:00:00Z")); // Negative
}
}

View File

@@ -1,660 +0,0 @@
use std::fmt::Debug;
use std::mem;
use std::net::Ipv6Addr;
use columnar::column_values::CompactSpaceU64Accessor;
use columnar::{
Column, ColumnType, Dictionary, MonotonicallyMappableToU128, MonotonicallyMappableToU64,
NumericalValue, StrColumn,
};
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::composite::accessors::{
CompositeAccessor, CompositeAggReqData, PrecomputedDateInterval,
};
use crate::aggregation::bucket::composite::calendar_interval;
use crate::aggregation::bucket::composite::map::{DynArrayHeapMap, MAX_DYN_ARRAY_SIZE};
use crate::aggregation::bucket::{
CalendarInterval, CompositeAggregationSource, MissingOrder, Order,
};
use crate::aggregation::buffered_sub_aggs::{BufferedSubAggs, HighCardSubAggBuffer};
use crate::aggregation::intermediate_agg_result::{
CompositeIntermediateKey, IntermediateAggregationResult, IntermediateAggregationResults,
IntermediateBucketResult, IntermediateCompositeBucketEntry, IntermediateCompositeBucketResult,
};
use crate::aggregation::segment_agg_result::{BucketIdProvider, SegmentAggregationCollector};
use crate::aggregation::BucketId;
use crate::TantivyError;
#[derive(Clone, Debug)]
struct CompositeBucketCollector {
count: u32,
bucket_id: BucketId,
}
/// Compact sortable representation of a single source value within a composite key.
///
/// The struct encodes both the column identity and the fast field value in a way
/// that preserves the desired sort order via the derived `Ord` implementation
/// (fields are compared top-to-bottom: `sort_key` first, then `encoded_value`).
///
/// ## `sort_key` encoding
/// - `0` — missing value, sorted first
/// - `1..=254` — present value; the original accessor index is `sort_key - 1`
/// - `u8::MAX` (255) — missing value, sorted last
///
/// ## `encoded_value` encoding
/// - `0` when the field is missing
/// - The raw u64 fast-field representation when order is ascending
/// - Bitwise NOT of the raw u64 when order is descending
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
struct InternalValueRepr {
/// Column index biased by +1 (so 0 and u8::MAX are reserved for missing sentinels).
sort_key: u8,
/// Fast field value, possibly bit-flipped for descending order.
encoded_value: u64,
}
impl InternalValueRepr {
#[inline]
fn new_term(raw: u64, accessor_idx: u8, order: Order) -> Self {
let encoded_value = match order {
Order::Asc => raw,
Order::Desc => !raw,
};
InternalValueRepr {
sort_key: accessor_idx + 1,
encoded_value,
}
}
/// For histogram sources the column index is irrelevant (always 1).
#[inline]
fn new_histogram(raw: u64, order: Order) -> Self {
let encoded_value = match order {
Order::Asc => raw,
Order::Desc => !raw,
};
InternalValueRepr {
sort_key: 1,
encoded_value,
}
}
#[inline]
fn new_missing(order: Order, missing_order: MissingOrder) -> Self {
let sort_key = match (missing_order, order) {
(MissingOrder::First, _) | (MissingOrder::Default, Order::Asc) => 0,
(MissingOrder::Last, _) | (MissingOrder::Default, Order::Desc) => u8::MAX,
};
InternalValueRepr {
sort_key,
encoded_value: 0,
}
}
/// Decode back to `(accessor_idx, raw_value)`.
/// Returns `None` when the value represents a missing field.
#[inline]
fn decode(self, order: Order) -> Option<(u8, u64)> {
if self.sort_key == 0 || self.sort_key == u8::MAX {
return None;
}
let raw = match order {
Order::Asc => self.encoded_value,
Order::Desc => !self.encoded_value,
};
Some((self.sort_key - 1, raw))
}
}
/// The collector puts values from the fast field into the correct buckets and
/// does a conversion to the correct datatype.
#[derive(Debug)]
pub struct SegmentCompositeCollector {
/// One DynArrayHeapMap per parent bucket.
parent_buckets: Vec<DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>>,
accessor_idx: usize,
sub_agg: Option<BufferedSubAggs<HighCardSubAggBuffer>>,
bucket_id_provider: BucketIdProvider,
/// Number of sources, needed when creating new DynArrayHeapMaps.
num_sources: usize,
}
impl SegmentAggregationCollector for SegmentCompositeCollector {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
results: &mut IntermediateAggregationResults,
parent_bucket_id: BucketId,
) -> crate::Result<()> {
let name = agg_data
.get_composite_req_data(self.accessor_idx)
.name
.clone();
let buckets = self.add_intermediate_bucket_result(agg_data, parent_bucket_id)?;
results.push(
name,
IntermediateAggregationResult::Bucket(IntermediateBucketResult::Composite { buckets }),
)?;
Ok(())
}
fn collect(
&mut self,
parent_bucket_id: BucketId,
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mem_pre = self.get_memory_consumption(parent_bucket_id);
let composite_agg_data = agg_data.take_composite_req_data(self.accessor_idx);
for doc in docs {
let mut visitor = CompositeKeyVisitor {
doc_id: *doc,
composite_agg_data: &composite_agg_data,
buckets: &mut self.parent_buckets[parent_bucket_id as usize],
sub_agg: &mut self.sub_agg,
bucket_id_provider: &mut self.bucket_id_provider,
sub_level_values: SmallVec::new(),
};
visitor.visit(0, true)?;
}
agg_data.put_back_composite_req_data(self.accessor_idx, composite_agg_data);
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.check_flush_local(agg_data)?;
}
let mem_delta = self.get_memory_consumption(parent_bucket_id) - mem_pre;
if mem_delta > 0 {
agg_data.context.limits.add_memory_consumed(mem_delta)?;
}
Ok(())
}
fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg.flush(agg_data)?;
}
Ok(())
}
fn prepare_max_bucket(
&mut self,
max_bucket: BucketId,
_agg_data: &AggregationsSegmentCtx,
) -> crate::Result<()> {
let required_len = max_bucket as usize + 1;
while self.parent_buckets.len() < required_len {
let map = DynArrayHeapMap::try_new(self.num_sources)?;
self.parent_buckets.push(map);
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Composite is a multi-bucket agg with no single value to extract.
None
}
}
impl SegmentCompositeCollector {
fn get_memory_consumption(&self, parent_bucket_id: BucketId) -> u64 {
self.parent_buckets[parent_bucket_id as usize].memory_consumption()
}
pub(crate) fn from_req_and_validate(
req_data: &mut AggregationsSegmentCtx,
node: &AggRefNode,
) -> crate::Result<Self> {
validate_req(req_data, node.idx_in_req_data)?;
let has_sub_aggregations = !node.children.is_empty();
let sub_agg = if has_sub_aggregations {
let sub_agg_collector = build_segment_agg_collectors(req_data, &node.children)?;
Some(BufferedSubAggs::new(sub_agg_collector))
} else {
None
};
let composite_req_data = req_data.get_composite_req_data(node.idx_in_req_data);
let num_sources = composite_req_data.req.sources.len();
Ok(SegmentCompositeCollector {
parent_buckets: vec![DynArrayHeapMap::try_new(num_sources)?],
accessor_idx: node.idx_in_req_data,
sub_agg,
bucket_id_provider: BucketIdProvider::default(),
num_sources,
})
}
#[inline]
fn add_intermediate_bucket_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
parent_bucket_id: BucketId,
) -> crate::Result<IntermediateCompositeBucketResult> {
let empty_map = DynArrayHeapMap::try_new(self.num_sources)?;
let heap_map = mem::replace(
&mut self.parent_buckets[parent_bucket_id as usize],
empty_map,
);
let mut dict: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry> =
Default::default();
dict.reserve(heap_map.size());
let composite_data = agg_data.get_composite_req_data(self.accessor_idx);
for (key_internal_repr, agg) in heap_map.into_iter() {
let key = resolve_key(&key_internal_repr, composite_data)?;
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_agg) = &mut self.sub_agg {
sub_agg
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
agg.bucket_id,
)?;
}
dict.insert(
key,
IntermediateCompositeBucketEntry {
doc_count: agg.count,
sub_aggregation: sub_aggregation_res,
},
);
}
Ok(IntermediateCompositeBucketResult {
entries: dict,
target_size: composite_data.req.size,
orders: composite_data
.req
.sources
.iter()
.map(|source| match source {
CompositeAggregationSource::Terms(t) => (t.order, t.missing_order),
CompositeAggregationSource::Histogram(h) => (h.order, h.missing_order),
CompositeAggregationSource::DateHistogram(d) => (d.order, d.missing_order),
})
.collect(),
})
}
}
fn validate_req(req_data: &mut AggregationsSegmentCtx, accessor_idx: usize) -> crate::Result<()> {
let composite_data = req_data.get_composite_req_data(accessor_idx);
let req = &composite_data.req;
if req.sources.is_empty() {
return Err(TantivyError::InvalidArgument(
"composite aggregation must have at least one source".to_string(),
));
}
if req.size == 0 {
return Err(TantivyError::InvalidArgument(
"composite aggregation 'size' must be > 0".to_string(),
));
}
if composite_data.composite_accessors.len() > MAX_DYN_ARRAY_SIZE {
return Err(TantivyError::InvalidArgument(format!(
"composite aggregation source supports maximum {MAX_DYN_ARRAY_SIZE} sources",
)));
}
let column_types_for_sources = composite_data.composite_accessors.iter().map(|item| {
item.accessors
.iter()
.map(|a| a.column_type)
.collect::<Vec<_>>()
});
for column_types in column_types_for_sources {
if column_types.contains(&ColumnType::Bytes) {
return Err(TantivyError::InvalidArgument(
"composite aggregation does not support 'bytes' field type".to_string(),
));
}
}
Ok(())
}
fn collect_bucket_with_limit(
doc_id: crate::DocId,
limit_num_buckets: usize,
buckets: &mut DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>,
key: &[InternalValueRepr],
sub_agg: &mut Option<BufferedSubAggs<HighCardSubAggBuffer>>,
bucket_id_provider: &mut BucketIdProvider,
) {
let mut record_in_bucket = |bucket: &mut CompositeBucketCollector| {
bucket.count += 1;
if let Some(sub_agg) = sub_agg {
sub_agg.push(bucket.bucket_id, doc_id);
}
};
// We still have room for buckets, just insert
if buckets.size() < limit_num_buckets {
let bucket = buckets.get_or_insert_with(key, || CompositeBucketCollector {
count: 0,
bucket_id: bucket_id_provider.next_bucket_id(),
});
record_in_bucket(bucket);
return;
}
// Map is full, but we can still update the bucket if it already exists
if let Some(bucket) = buckets.get_mut(key) {
record_in_bucket(bucket);
return;
}
// Check if the item qualifies to enter the top-k, and evict the highest if it does
if let Some(highest_key) = buckets.peek_highest() {
if key < highest_key {
buckets.evict_highest();
let bucket = buckets.get_or_insert_with(key, || CompositeBucketCollector {
count: 0,
bucket_id: bucket_id_provider.next_bucket_id(),
});
record_in_bucket(bucket);
}
}
}
/// Converts the composite key from its internal column space representation
/// (segment specific) into its intermediate form.
fn resolve_key(
internal_key: &[InternalValueRepr],
agg_data: &CompositeAggReqData,
) -> crate::Result<Vec<CompositeIntermediateKey>> {
internal_key
.iter()
.enumerate()
.map(|(idx, val)| {
resolve_internal_value_repr(
*val,
&agg_data.req.sources[idx],
&agg_data.composite_accessors[idx].accessors,
)
})
.collect()
}
fn resolve_internal_value_repr(
internal_value_repr: InternalValueRepr,
source: &CompositeAggregationSource,
composite_accessors: &[CompositeAccessor],
) -> crate::Result<CompositeIntermediateKey> {
let decoded_value_opt = match source {
CompositeAggregationSource::Terms(source) => internal_value_repr.decode(source.order),
CompositeAggregationSource::Histogram(source) => internal_value_repr.decode(source.order),
CompositeAggregationSource::DateHistogram(source) => {
internal_value_repr.decode(source.order)
}
};
let Some((decoded_accessor_idx, val)) = decoded_value_opt else {
return Ok(CompositeIntermediateKey::Null);
};
let key = match source {
CompositeAggregationSource::Terms(_) => {
let CompositeAccessor {
column_type,
str_dict_column,
column,
..
} = &composite_accessors[decoded_accessor_idx as usize];
resolve_term(val, column_type, str_dict_column, column)?
}
CompositeAggregationSource::Histogram(source) => {
CompositeIntermediateKey::F64(i64::from_u64(val) as f64 * source.interval)
}
CompositeAggregationSource::DateHistogram(_) => {
CompositeIntermediateKey::DateTime(i64::from_u64(val))
}
};
Ok(key)
}
fn resolve_term(
val: u64,
column_type: &ColumnType,
str_dict_column: &Option<StrColumn>,
column: &Column,
) -> crate::Result<CompositeIntermediateKey> {
let key = if *column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let term_dict = str_dict_column
.as_ref()
.map(|el| el.dictionary())
.unwrap_or_else(|| &fallback_dict);
let mut buffer = Vec::new();
term_dict.ord_to_term(val, &mut buffer)?;
CompositeIntermediateKey::Str(
String::from_utf8(buffer.to_vec()).expect("could not convert to String"),
)
} else if *column_type == ColumnType::DateTime {
let val = i64::from_u64(val);
CompositeIntermediateKey::DateTime(val)
} else if *column_type == ColumnType::Bool {
let val = bool::from_u64(val);
CompositeIntermediateKey::Bool(val)
} else if *column_type == ColumnType::IpAddr {
let compact_space_accessor = column
.values
.clone()
.downcast_arc::<CompactSpaceU64Accessor>()
.map_err(|_| {
TantivyError::AggregationError(crate::aggregation::AggregationError::InternalError(
"Type mismatch: Could not downcast to CompactSpaceU64Accessor".to_string(),
))
})?;
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
let val = Ipv6Addr::from_u128(val);
CompositeIntermediateKey::IpAddr(val)
} else if *column_type == ColumnType::U64 {
CompositeIntermediateKey::U64(val)
} else if *column_type == ColumnType::I64 {
CompositeIntermediateKey::I64(i64::from_u64(val))
} else {
let val = f64::from_u64(val);
let val: NumericalValue = val.into();
match val.normalize() {
NumericalValue::U64(val) => CompositeIntermediateKey::U64(val),
NumericalValue::I64(val) => CompositeIntermediateKey::I64(val),
NumericalValue::F64(val) => CompositeIntermediateKey::F64(val),
}
};
Ok(key)
}
/// Browse through the cardinal product obtained by the different values of the doc composite key
/// sources.
///
/// For each of those tuple-key, that are after the limit key, we call collect_bucket_with_limit.
struct CompositeKeyVisitor<'a> {
doc_id: crate::DocId,
composite_agg_data: &'a CompositeAggReqData,
buckets: &'a mut DynArrayHeapMap<InternalValueRepr, CompositeBucketCollector>,
sub_agg: &'a mut Option<BufferedSubAggs<HighCardSubAggBuffer>>,
bucket_id_provider: &'a mut BucketIdProvider,
sub_level_values: SmallVec<[InternalValueRepr; MAX_DYN_ARRAY_SIZE]>,
}
impl CompositeKeyVisitor<'_> {
/// Depth-first walk of the accessors to build the composite key combinations
/// and update the buckets.
///
/// `source_idx` is the current source index in the recursion.
/// `is_on_after_key` tracks whether we still need to consider the after_key
/// for pruning at this level and below.
fn visit(&mut self, source_idx: usize, is_on_after_key: bool) -> crate::Result<()> {
if source_idx == self.composite_agg_data.req.sources.len() {
if !is_on_after_key {
collect_bucket_with_limit(
self.doc_id,
self.composite_agg_data.req.size as usize,
self.buckets,
&self.sub_level_values,
self.sub_agg,
self.bucket_id_provider,
);
}
return Ok(());
}
let current_level_accessors = &self.composite_agg_data.composite_accessors[source_idx];
let current_level_source = &self.composite_agg_data.req.sources[source_idx];
let mut missing = true;
for (accessor_idx, accessor) in current_level_accessors.accessors.iter().enumerate() {
let values = accessor.column.values_for_doc(self.doc_id);
for value in values {
missing = false;
match current_level_source {
CompositeAggregationSource::Terms(_) => {
let preceeds_after_key_type =
accessor_idx < current_level_accessors.after_key_accessor_idx;
if is_on_after_key && preceeds_after_key_type {
break;
}
let matches_after_key_type =
accessor_idx == current_level_accessors.after_key_accessor_idx;
if matches_after_key_type && is_on_after_key {
let should_skip = match current_level_source.order() {
Order::Asc => current_level_accessors.after_key.gt(value),
Order::Desc => current_level_accessors.after_key.lt(value),
};
if should_skip {
continue;
}
}
self.sub_level_values.push(InternalValueRepr::new_term(
value,
accessor_idx as u8,
current_level_source.order(),
));
let still_on_after_key = matches_after_key_type
&& current_level_accessors.after_key.equals(value);
self.visit(source_idx + 1, is_on_after_key && still_on_after_key)?;
self.sub_level_values.pop();
}
CompositeAggregationSource::Histogram(source) => {
let float_value = match accessor.column_type {
ColumnType::U64 => value as f64,
ColumnType::I64 => i64::from_u64(value) as f64,
ColumnType::DateTime => i64::from_u64(value) as f64 / 1_000_000.,
ColumnType::F64 => f64::from_u64(value),
_ => {
panic!(
"unexpected type {:?}. This should not happen",
accessor.column_type
)
}
};
let bucket_index = (float_value / source.interval).floor() as i64;
let bucket_value = i64::to_u64(bucket_index);
if is_on_after_key {
let should_skip = match current_level_source.order() {
Order::Asc => current_level_accessors.after_key.gt(bucket_value),
Order::Desc => current_level_accessors.after_key.lt(bucket_value),
};
if should_skip {
continue;
}
}
self.sub_level_values.push(InternalValueRepr::new_histogram(
bucket_value,
current_level_source.order(),
));
let still_on_after_key =
current_level_accessors.after_key.equals(bucket_value);
self.visit(source_idx + 1, is_on_after_key && still_on_after_key)?;
self.sub_level_values.pop();
}
CompositeAggregationSource::DateHistogram(_) => {
let value_ns = match accessor.column_type {
ColumnType::DateTime => i64::from_u64(value),
_ => {
panic!(
"unexpected type {:?}. This should not happen",
accessor.column_type
)
}
};
let bucket_index = match accessor.date_histogram_interval {
PrecomputedDateInterval::FixedNanoseconds(fixed_interval_ns) => {
(value_ns / fixed_interval_ns) * fixed_interval_ns
}
PrecomputedDateInterval::Calendar(CalendarInterval::Year) => {
calendar_interval::try_year_bucket(value_ns)?
}
PrecomputedDateInterval::Calendar(CalendarInterval::Month) => {
calendar_interval::try_month_bucket(value_ns)?
}
PrecomputedDateInterval::Calendar(CalendarInterval::Week) => {
calendar_interval::week_bucket(value_ns)
}
PrecomputedDateInterval::NotApplicable => {
panic!("interval not precomputed for date histogram source")
}
};
let bucket_value = i64::to_u64(bucket_index);
if is_on_after_key {
let should_skip = match current_level_source.order() {
Order::Asc => current_level_accessors.after_key.gt(bucket_value),
Order::Desc => current_level_accessors.after_key.lt(bucket_value),
};
if should_skip {
continue;
}
}
self.sub_level_values.push(InternalValueRepr::new_histogram(
bucket_value,
current_level_source.order(),
));
let still_on_after_key =
current_level_accessors.after_key.equals(bucket_value);
self.visit(source_idx + 1, is_on_after_key && still_on_after_key)?;
self.sub_level_values.pop();
}
};
}
}
if missing && current_level_source.missing_bucket() {
if is_on_after_key && current_level_accessors.skip_missing {
return Ok(());
}
self.sub_level_values.push(InternalValueRepr::new_missing(
current_level_source.order(),
current_level_source.missing_order(),
));
self.visit(
source_idx + 1,
is_on_after_key && current_level_accessors.is_after_key_explicit_missing,
)?;
self.sub_level_values.pop();
}
Ok(())
}
}

View File

@@ -1,329 +0,0 @@
use std::collections::BinaryHeap;
use std::fmt::Debug;
use std::hash::Hash;
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use crate::TantivyError;
/// Map backed by a hash map for fast access and a binary heap to track the
/// highest key. The key is an array of fixed size S.
#[derive(Clone, Debug)]
struct ArrayHeapMap<K: Ord, V, const S: usize> {
pub(crate) buckets: FxHashMap<[K; S], V>,
pub(crate) heap: BinaryHeap<[K; S]>,
}
impl<K: Ord, V, const S: usize> Default for ArrayHeapMap<K, V, S> {
fn default() -> Self {
ArrayHeapMap {
buckets: FxHashMap::default(),
heap: BinaryHeap::default(),
}
}
}
impl<K: Eq + Hash + Clone + Ord, V, const S: usize> ArrayHeapMap<K, V, S> {
/// Panics if the length of `key` is not S.
fn get_or_insert_with<F: FnOnce() -> V>(&mut self, key: &[K], f: F) -> &mut V {
let key_array: &[K; S] = key.try_into().expect("Key length mismatch");
self.buckets.entry(key_array.clone()).or_insert_with(|| {
self.heap.push(key_array.clone());
f()
})
}
/// Panics if the length of `key` is not S.
fn get_mut(&mut self, key: &[K]) -> Option<&mut V> {
let key_array: &[K; S] = key.try_into().expect("Key length mismatch");
self.buckets.get_mut(key_array)
}
fn peek_highest(&self) -> Option<&[K]> {
self.heap.peek().map(|k_array| k_array.as_slice())
}
fn evict_highest(&mut self) {
if let Some(highest) = self.heap.pop() {
self.buckets.remove(&highest);
}
}
fn memory_consumption(&self) -> u64 {
let key_size = std::mem::size_of::<[K; S]>();
let map_size = (key_size + std::mem::size_of::<V>()) * self.buckets.capacity();
let heap_size = key_size * self.heap.capacity();
(map_size + heap_size) as u64
}
}
impl<K: Copy + Ord + Clone + 'static, V: 'static, const S: usize> ArrayHeapMap<K, V, S> {
fn into_iter(self) -> Box<dyn Iterator<Item = (SmallVec<[K; MAX_DYN_ARRAY_SIZE]>, V)>> {
Box::new(
self.buckets
.into_iter()
.map(|(k, v)| (SmallVec::from_slice(&k), v)),
)
}
}
pub(super) const MAX_DYN_ARRAY_SIZE: usize = 16;
const MAX_DYN_ARRAY_SIZE_PLUS_ONE: usize = MAX_DYN_ARRAY_SIZE + 1;
/// A map optimized for memory footprint, fast access and efficient eviction of
/// the highest key.
///
/// Keys are inlined arrays of size 1 to [MAX_DYN_ARRAY_SIZE] but for a given
/// instance the key size is fixed. This allows to avoid heap allocations for the
/// keys.
#[derive(Clone, Debug)]
pub(super) struct DynArrayHeapMap<K: Ord, V>(DynArrayHeapMapInner<K, V>);
/// Wrapper around ArrayHeapMap to dynamically dispatch on the array size.
#[derive(Clone, Debug)]
enum DynArrayHeapMapInner<K: Ord, V> {
Dim1(ArrayHeapMap<K, V, 1>),
Dim2(ArrayHeapMap<K, V, 2>),
Dim3(ArrayHeapMap<K, V, 3>),
Dim4(ArrayHeapMap<K, V, 4>),
Dim5(ArrayHeapMap<K, V, 5>),
Dim6(ArrayHeapMap<K, V, 6>),
Dim7(ArrayHeapMap<K, V, 7>),
Dim8(ArrayHeapMap<K, V, 8>),
Dim9(ArrayHeapMap<K, V, 9>),
Dim10(ArrayHeapMap<K, V, 10>),
Dim11(ArrayHeapMap<K, V, 11>),
Dim12(ArrayHeapMap<K, V, 12>),
Dim13(ArrayHeapMap<K, V, 13>),
Dim14(ArrayHeapMap<K, V, 14>),
Dim15(ArrayHeapMap<K, V, 15>),
Dim16(ArrayHeapMap<K, V, 16>),
}
impl<K: Ord, V> DynArrayHeapMap<K, V> {
/// Creates a new heap map with dynamic array keys of size `key_dimension`.
pub(super) fn try_new(key_dimension: usize) -> crate::Result<Self> {
let inner = match key_dimension {
0 => {
return Err(TantivyError::InvalidArgument(
"DynArrayHeapMap dimension must be at least 1".to_string(),
))
}
1 => DynArrayHeapMapInner::Dim1(ArrayHeapMap::default()),
2 => DynArrayHeapMapInner::Dim2(ArrayHeapMap::default()),
3 => DynArrayHeapMapInner::Dim3(ArrayHeapMap::default()),
4 => DynArrayHeapMapInner::Dim4(ArrayHeapMap::default()),
5 => DynArrayHeapMapInner::Dim5(ArrayHeapMap::default()),
6 => DynArrayHeapMapInner::Dim6(ArrayHeapMap::default()),
7 => DynArrayHeapMapInner::Dim7(ArrayHeapMap::default()),
8 => DynArrayHeapMapInner::Dim8(ArrayHeapMap::default()),
9 => DynArrayHeapMapInner::Dim9(ArrayHeapMap::default()),
10 => DynArrayHeapMapInner::Dim10(ArrayHeapMap::default()),
11 => DynArrayHeapMapInner::Dim11(ArrayHeapMap::default()),
12 => DynArrayHeapMapInner::Dim12(ArrayHeapMap::default()),
13 => DynArrayHeapMapInner::Dim13(ArrayHeapMap::default()),
14 => DynArrayHeapMapInner::Dim14(ArrayHeapMap::default()),
15 => DynArrayHeapMapInner::Dim15(ArrayHeapMap::default()),
16 => DynArrayHeapMapInner::Dim16(ArrayHeapMap::default()),
MAX_DYN_ARRAY_SIZE_PLUS_ONE.. => {
return Err(TantivyError::InvalidArgument(format!(
"DynArrayHeapMap supports maximum {MAX_DYN_ARRAY_SIZE} dimensions, got \
{key_dimension}",
)))
}
};
Ok(DynArrayHeapMap(inner))
}
/// Number of elements in the map. This is not the dimension of the keys.
pub(super) fn size(&self) -> usize {
match &self.0 {
DynArrayHeapMapInner::Dim1(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim2(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim3(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim4(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim5(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim6(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim7(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim8(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim9(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim10(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim11(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim12(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim13(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim14(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim15(map) => map.buckets.len(),
DynArrayHeapMapInner::Dim16(map) => map.buckets.len(),
}
}
}
impl<K: Ord + Hash + Clone, V> DynArrayHeapMap<K, V> {
/// Get a mutable reference to the value corresponding to `key` or inserts a new
/// value created by calling `f`.
///
/// Panics if the length of `key` does not match the key dimension of the map.
pub(super) fn get_or_insert_with<F: FnOnce() -> V>(&mut self, key: &[K], f: F) -> &mut V {
match &mut self.0 {
DynArrayHeapMapInner::Dim1(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim2(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim3(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim4(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim5(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim6(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim7(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim8(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim9(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim10(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim11(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim12(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim13(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim14(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim15(map) => map.get_or_insert_with(key, f),
DynArrayHeapMapInner::Dim16(map) => map.get_or_insert_with(key, f),
}
}
/// Returns a mutable reference to the value corresponding to `key`.
///
/// Panics if the length of `key` does not match the key dimension of the map.
pub fn get_mut(&mut self, key: &[K]) -> Option<&mut V> {
match &mut self.0 {
DynArrayHeapMapInner::Dim1(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim2(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim3(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim4(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim5(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim6(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim7(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim8(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim9(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim10(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim11(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim12(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim13(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim14(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim15(map) => map.get_mut(key),
DynArrayHeapMapInner::Dim16(map) => map.get_mut(key),
}
}
/// Returns a reference to the highest key in the map.
pub(super) fn peek_highest(&self) -> Option<&[K]> {
match &self.0 {
DynArrayHeapMapInner::Dim1(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim2(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim3(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim4(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim5(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim6(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim7(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim8(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim9(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim10(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim11(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim12(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim13(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim14(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim15(map) => map.peek_highest(),
DynArrayHeapMapInner::Dim16(map) => map.peek_highest(),
}
}
/// Removes the entry with the highest key from the map.
pub(super) fn evict_highest(&mut self) {
match &mut self.0 {
DynArrayHeapMapInner::Dim1(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim2(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim3(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim4(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim5(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim6(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim7(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim8(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim9(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim10(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim11(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim12(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim13(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim14(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim15(map) => map.evict_highest(),
DynArrayHeapMapInner::Dim16(map) => map.evict_highest(),
}
}
pub(crate) fn memory_consumption(&self) -> u64 {
match &self.0 {
DynArrayHeapMapInner::Dim1(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim2(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim3(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim4(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim5(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim6(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim7(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim8(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim9(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim10(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim11(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim12(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim13(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim14(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim15(map) => map.memory_consumption(),
DynArrayHeapMapInner::Dim16(map) => map.memory_consumption(),
}
}
}
impl<K: Ord + Clone + Copy + 'static, V: 'static> DynArrayHeapMap<K, V> {
/// Turns this map into an iterator over key-value pairs.
pub fn into_iter(self) -> impl Iterator<Item = (SmallVec<[K; MAX_DYN_ARRAY_SIZE]>, V)> {
match self.0 {
DynArrayHeapMapInner::Dim1(map) => map.into_iter(),
DynArrayHeapMapInner::Dim2(map) => map.into_iter(),
DynArrayHeapMapInner::Dim3(map) => map.into_iter(),
DynArrayHeapMapInner::Dim4(map) => map.into_iter(),
DynArrayHeapMapInner::Dim5(map) => map.into_iter(),
DynArrayHeapMapInner::Dim6(map) => map.into_iter(),
DynArrayHeapMapInner::Dim7(map) => map.into_iter(),
DynArrayHeapMapInner::Dim8(map) => map.into_iter(),
DynArrayHeapMapInner::Dim9(map) => map.into_iter(),
DynArrayHeapMapInner::Dim10(map) => map.into_iter(),
DynArrayHeapMapInner::Dim11(map) => map.into_iter(),
DynArrayHeapMapInner::Dim12(map) => map.into_iter(),
DynArrayHeapMapInner::Dim13(map) => map.into_iter(),
DynArrayHeapMapInner::Dim14(map) => map.into_iter(),
DynArrayHeapMapInner::Dim15(map) => map.into_iter(),
DynArrayHeapMapInner::Dim16(map) => map.into_iter(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dyn_array_heap_map() {
let mut map = DynArrayHeapMap::<u32, &str>::try_new(2).unwrap();
// insert
let key1 = [1u32, 2u32];
let key2 = [2u32, 1u32];
map.get_or_insert_with(&key1, || "a");
map.get_or_insert_with(&key2, || "b");
assert_eq!(map.size(), 2);
// evict highest
assert_eq!(map.peek_highest(), Some(&key2[..]));
map.evict_highest();
assert_eq!(map.size(), 1);
assert_eq!(map.peek_highest(), Some(&key1[..]));
// into_iter
let mut iter = map.into_iter();
let (k, v) = iter.next().unwrap();
assert_eq!(k.as_slice(), &key1);
assert_eq!(v, "a");
assert_eq!(iter.next(), None);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,460 +0,0 @@
/// This module helps comparing numerical values of different types (i64, u64
/// and f64).
pub(super) mod num_cmp {
use std::cmp::Ordering;
use crate::TantivyError;
pub fn cmp_i64_f64(left_i: i64, right_f: f64) -> crate::Result<Ordering> {
if right_f.is_nan() {
return Err(TantivyError::InvalidArgument(
"NaN comparison is not supported".to_string(),
));
}
// If right_f is < i64::MIN then left_i > right_f (i64::MIN=-2^63 can be
// exactly represented as f64)
if right_f < i64::MIN as f64 {
return Ok(Ordering::Greater);
}
// If right_f is >= i64::MAX then left_i < right_f (i64::MAX=2^63-1 cannot
// be exactly represented as f64)
if right_f >= i64::MAX as f64 {
return Ok(Ordering::Less);
}
// Now right_f is in (i64::MIN, i64::MAX), so `right_f as i64` is
// well-defined (truncation toward 0)
let right_as_i = right_f as i64;
let result = match left_i.cmp(&right_as_i) {
Ordering::Less => Ordering::Less,
Ordering::Greater => Ordering::Greater,
Ordering::Equal => {
// they have the same integer part, compare the fraction
let rem = right_f - (right_as_i as f64);
if rem == 0.0 {
Ordering::Equal
} else if right_f > 0.0 {
Ordering::Less
} else {
Ordering::Greater
}
}
};
Ok(result)
}
pub fn cmp_u64_f64(left_u: u64, right_f: f64) -> crate::Result<Ordering> {
if right_f.is_nan() {
return Err(TantivyError::InvalidArgument(
"NaN comparison is not supported".to_string(),
));
}
// Negative floats are always less than any u64 >= 0
if right_f < 0.0 {
return Ok(Ordering::Greater);
}
// If right_f is >= u64::MAX then left_u < right_f (u64::MAX=2^64-1 cannot be exactly)
let max_as_f = u64::MAX as f64;
if right_f > max_as_f {
return Ok(Ordering::Less);
}
// Now right_f is in (0, u64::MAX), so `right_f as u64` is well-defined
// (truncation toward 0)
let right_as_u = right_f as u64;
let result = match left_u.cmp(&right_as_u) {
Ordering::Less => Ordering::Less,
Ordering::Greater => Ordering::Greater,
Ordering::Equal => {
// they have the same integer part, compare the fraction
let rem = right_f - (right_as_u as f64);
if rem == 0.0 {
Ordering::Equal
} else {
Ordering::Less
}
}
};
Ok(result)
}
pub fn cmp_i64_u64(left_i: i64, right_u: u64) -> Ordering {
if left_i < 0 {
Ordering::Less
} else {
let left_as_u = left_i as u64;
left_as_u.cmp(&right_u)
}
}
}
/// This module helps projecting numerical values to other numerical types.
/// When the target value space cannot exactly represent the source value, the
/// next representable value is returned (or AfterLast if the source value is
/// larger than the largest representable value).
///
/// All functions in this module assume that f64 values are not NaN.
pub(super) mod num_proj {
#[derive(Debug, PartialEq)]
pub enum ProjectedNumber<T> {
Exact(T),
Next(T),
AfterLast,
}
pub fn i64_to_u64(value: i64) -> ProjectedNumber<u64> {
if value < 0 {
ProjectedNumber::Next(0)
} else {
ProjectedNumber::Exact(value as u64)
}
}
pub fn u64_to_i64(value: u64) -> ProjectedNumber<i64> {
if value > i64::MAX as u64 {
ProjectedNumber::AfterLast
} else {
ProjectedNumber::Exact(value as i64)
}
}
pub fn f64_to_u64(value: f64) -> ProjectedNumber<u64> {
if value < 0.0 {
ProjectedNumber::Next(0)
} else if value > u64::MAX as f64 {
ProjectedNumber::AfterLast
} else if value.fract() == 0.0 {
ProjectedNumber::Exact(value as u64)
} else {
// casting f64 to u64 truncates toward zero
ProjectedNumber::Next(value as u64 + 1)
}
}
pub fn f64_to_i64(value: f64) -> ProjectedNumber<i64> {
if value < (i64::MIN as f64) {
ProjectedNumber::Next(i64::MIN)
} else if value >= (i64::MAX as f64) {
ProjectedNumber::AfterLast
} else if value.fract() == 0.0 {
ProjectedNumber::Exact(value as i64)
} else if value > 0.0 {
// casting f64 to i64 truncates toward zero
ProjectedNumber::Next(value as i64 + 1)
} else {
ProjectedNumber::Next(value as i64)
}
}
pub fn i64_to_f64(value: i64) -> ProjectedNumber<f64> {
let value_f = value as f64;
let k_roundtrip = value_f as i64;
if k_roundtrip == value {
// between -2^53 and 2^53 all i64 are exactly represented as f64
ProjectedNumber::Exact(value_f)
} else {
// for very large/small i64 values, it is approximated to the closest f64
if k_roundtrip > value {
ProjectedNumber::Next(value_f)
} else {
ProjectedNumber::Next(value_f.next_up())
}
}
}
pub fn u64_to_f64(value: u64) -> ProjectedNumber<f64> {
let value_f = value as f64;
let k_roundtrip = value_f as u64;
if k_roundtrip == value {
// between 0 and 2^53 all u64 are exactly represented as f64
ProjectedNumber::Exact(value_f)
} else if k_roundtrip > value {
ProjectedNumber::Next(value_f)
} else {
ProjectedNumber::Next(value_f.next_up())
}
}
}
#[cfg(test)]
mod num_cmp_tests {
use std::cmp::Ordering;
use super::num_cmp::*;
#[test]
fn test_cmp_u64_f64() {
// Basic comparisons
assert_eq!(cmp_u64_f64(5, 5.0).unwrap(), Ordering::Equal);
assert_eq!(cmp_u64_f64(5, 6.0).unwrap(), Ordering::Less);
assert_eq!(cmp_u64_f64(6, 5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_u64_f64(0, 0.0).unwrap(), Ordering::Equal);
assert_eq!(cmp_u64_f64(0, 0.1).unwrap(), Ordering::Less);
// Negative float values should always be less than any u64
assert_eq!(cmp_u64_f64(0, -0.1).unwrap(), Ordering::Greater);
assert_eq!(cmp_u64_f64(5, -5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_u64_f64(u64::MAX, -1e20).unwrap(), Ordering::Greater);
// Tests with extreme values
assert_eq!(cmp_u64_f64(u64::MAX, 1e20).unwrap(), Ordering::Less);
// Precision edge cases: large u64 that loses precision when converted to f64
// => 2^54, exactly represented as f64
let large_f64 = 18_014_398_509_481_984.0;
let large_u64 = 18_014_398_509_481_984;
// prove that large_u64 is exactly represented as f64
assert_eq!(large_u64 as f64, large_f64);
assert_eq!(cmp_u64_f64(large_u64, large_f64).unwrap(), Ordering::Equal);
// => (2^54 + 1) cannot be exactly represented in f64
let large_u64_plus_1 = 18_014_398_509_481_985;
// prove that it is represented as f64 by large_f64
assert_eq!(large_u64_plus_1 as f64, large_f64);
assert_eq!(
cmp_u64_f64(large_u64_plus_1, large_f64).unwrap(),
Ordering::Greater
);
// => (2^54 - 1) cannot be exactly represented in f64
let large_u64_minus_1 = 18_014_398_509_481_983;
// prove that it is also represented as f64 by large_f64
assert_eq!(large_u64_minus_1 as f64, large_f64);
assert_eq!(
cmp_u64_f64(large_u64_minus_1, large_f64).unwrap(),
Ordering::Less
);
// NaN comparison results in an error
assert!(cmp_u64_f64(0, f64::NAN).is_err());
}
#[test]
fn test_cmp_i64_f64() {
// Basic comparisons
assert_eq!(cmp_i64_f64(5, 5.0).unwrap(), Ordering::Equal);
assert_eq!(cmp_i64_f64(5, 6.0).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(6, 5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_i64_f64(-5, -5.0).unwrap(), Ordering::Equal);
assert_eq!(cmp_i64_f64(-5, -4.0).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(-4, -5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_i64_f64(-5, 5.0).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(5, -5.0).unwrap(), Ordering::Greater);
assert_eq!(cmp_i64_f64(0, -0.1).unwrap(), Ordering::Greater);
assert_eq!(cmp_i64_f64(0, 0.1).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(-1, -0.5).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(-1, 0.0).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(0, 0.0).unwrap(), Ordering::Equal);
// Tests with extreme values
assert_eq!(cmp_i64_f64(i64::MAX, 1e20).unwrap(), Ordering::Less);
assert_eq!(cmp_i64_f64(i64::MIN, -1e20).unwrap(), Ordering::Greater);
// Precision edge cases: large i64 that loses precision when converted to f64
// => 2^54, exactly represented as f64
let large_f64 = 18_014_398_509_481_984.0;
let large_i64 = 18_014_398_509_481_984;
// prove that large_i64 is exactly represented as f64
assert_eq!(large_i64 as f64, large_f64);
assert_eq!(cmp_i64_f64(large_i64, large_f64).unwrap(), Ordering::Equal);
// => (1_i64 << 54) + 1 cannot be exactly represented in f64
let large_i64_plus_1 = 18_014_398_509_481_985;
// prove that it is represented as f64 by large_f64
assert_eq!(large_i64_plus_1 as f64, large_f64);
assert_eq!(
cmp_i64_f64(large_i64_plus_1, large_f64).unwrap(),
Ordering::Greater
);
// => (1_i64 << 54) - 1 cannot be exactly represented in f64
let large_i64_minus_1 = 18_014_398_509_481_983;
// prove that it is also represented as f64 by large_f64
assert_eq!(large_i64_minus_1 as f64, large_f64);
assert_eq!(
cmp_i64_f64(large_i64_minus_1, large_f64).unwrap(),
Ordering::Less
);
// Same precision edge case but with negative values
// => -2^54, exactly represented as f64
let large_neg_f64 = -18_014_398_509_481_984.0;
let large_neg_i64 = -18_014_398_509_481_984;
// prove that large_neg_i64 is exactly represented as f64
assert_eq!(large_neg_i64 as f64, large_neg_f64);
assert_eq!(
cmp_i64_f64(large_neg_i64, large_neg_f64).unwrap(),
Ordering::Equal
);
// => (-2^54 + 1) cannot be exactly represented in f64
let large_neg_i64_plus_1 = -18_014_398_509_481_985;
// prove that it is represented as f64 by large_neg_f64
assert_eq!(large_neg_i64_plus_1 as f64, large_neg_f64);
assert_eq!(
cmp_i64_f64(large_neg_i64_plus_1, large_neg_f64).unwrap(),
Ordering::Less
);
// => (-2^54 - 1) cannot be exactly represented in f64
let large_neg_i64_minus_1 = -18_014_398_509_481_983;
// prove that it is also represented as f64 by large_neg_f64
assert_eq!(large_neg_i64_minus_1 as f64, large_neg_f64);
assert_eq!(
cmp_i64_f64(large_neg_i64_minus_1, large_neg_f64).unwrap(),
Ordering::Greater
);
// NaN comparison results in an error
assert!(cmp_i64_f64(0, f64::NAN).is_err());
}
#[test]
fn test_cmp_i64_u64() {
// Test with negative i64 values (should always be less than any u64)
assert_eq!(cmp_i64_u64(-1, 0), Ordering::Less);
assert_eq!(cmp_i64_u64(i64::MIN, 0), Ordering::Less);
assert_eq!(cmp_i64_u64(i64::MIN, u64::MAX), Ordering::Less);
// Test with positive i64 values
assert_eq!(cmp_i64_u64(0, 0), Ordering::Equal);
assert_eq!(cmp_i64_u64(1, 0), Ordering::Greater);
assert_eq!(cmp_i64_u64(1, 1), Ordering::Equal);
assert_eq!(cmp_i64_u64(0, 1), Ordering::Less);
assert_eq!(cmp_i64_u64(5, 10), Ordering::Less);
assert_eq!(cmp_i64_u64(10, 5), Ordering::Greater);
// Test with values near i64::MAX and u64 conversion
assert_eq!(cmp_i64_u64(i64::MAX, i64::MAX as u64), Ordering::Equal);
assert_eq!(cmp_i64_u64(i64::MAX, (i64::MAX as u64) + 1), Ordering::Less);
assert_eq!(cmp_i64_u64(i64::MAX, u64::MAX), Ordering::Less);
}
}
#[cfg(test)]
mod num_proj_tests {
use super::num_proj::{self, ProjectedNumber};
#[test]
fn test_i64_to_u64() {
assert_eq!(num_proj::i64_to_u64(-1), ProjectedNumber::Next(0));
assert_eq!(num_proj::i64_to_u64(i64::MIN), ProjectedNumber::Next(0));
assert_eq!(num_proj::i64_to_u64(0), ProjectedNumber::Exact(0));
assert_eq!(num_proj::i64_to_u64(42), ProjectedNumber::Exact(42));
assert_eq!(
num_proj::i64_to_u64(i64::MAX),
ProjectedNumber::Exact(i64::MAX as u64)
);
}
#[test]
fn test_u64_to_i64() {
assert_eq!(num_proj::u64_to_i64(0), ProjectedNumber::Exact(0));
assert_eq!(num_proj::u64_to_i64(42), ProjectedNumber::Exact(42));
assert_eq!(
num_proj::u64_to_i64(i64::MAX as u64),
ProjectedNumber::Exact(i64::MAX)
);
assert_eq!(
num_proj::u64_to_i64((i64::MAX as u64) + 1),
ProjectedNumber::AfterLast
);
assert_eq!(num_proj::u64_to_i64(u64::MAX), ProjectedNumber::AfterLast);
}
#[test]
fn test_f64_to_u64() {
assert_eq!(num_proj::f64_to_u64(-1e25), ProjectedNumber::Next(0));
assert_eq!(num_proj::f64_to_u64(-0.1), ProjectedNumber::Next(0));
assert_eq!(num_proj::f64_to_u64(1e20), ProjectedNumber::AfterLast);
assert_eq!(
num_proj::f64_to_u64(f64::INFINITY),
ProjectedNumber::AfterLast
);
assert_eq!(num_proj::f64_to_u64(0.0), ProjectedNumber::Exact(0));
assert_eq!(num_proj::f64_to_u64(42.0), ProjectedNumber::Exact(42));
assert_eq!(num_proj::f64_to_u64(0.5), ProjectedNumber::Next(1));
assert_eq!(num_proj::f64_to_u64(42.1), ProjectedNumber::Next(43));
}
#[test]
fn test_f64_to_i64() {
assert_eq!(num_proj::f64_to_i64(-1e20), ProjectedNumber::Next(i64::MIN));
assert_eq!(
num_proj::f64_to_i64(f64::NEG_INFINITY),
ProjectedNumber::Next(i64::MIN)
);
assert_eq!(num_proj::f64_to_i64(1e20), ProjectedNumber::AfterLast);
assert_eq!(
num_proj::f64_to_i64(f64::INFINITY),
ProjectedNumber::AfterLast
);
assert_eq!(num_proj::f64_to_i64(0.0), ProjectedNumber::Exact(0));
assert_eq!(num_proj::f64_to_i64(42.0), ProjectedNumber::Exact(42));
assert_eq!(num_proj::f64_to_i64(-42.0), ProjectedNumber::Exact(-42));
assert_eq!(num_proj::f64_to_i64(0.5), ProjectedNumber::Next(1));
assert_eq!(num_proj::f64_to_i64(42.1), ProjectedNumber::Next(43));
assert_eq!(num_proj::f64_to_i64(-0.5), ProjectedNumber::Next(0));
assert_eq!(num_proj::f64_to_i64(-42.1), ProjectedNumber::Next(-42));
}
#[test]
fn test_i64_to_f64() {
assert_eq!(num_proj::i64_to_f64(0), ProjectedNumber::Exact(0.0));
assert_eq!(num_proj::i64_to_f64(42), ProjectedNumber::Exact(42.0));
assert_eq!(num_proj::i64_to_f64(-42), ProjectedNumber::Exact(-42.0));
let max_exact = 9_007_199_254_740_992; // 2^53
assert_eq!(
num_proj::i64_to_f64(max_exact),
ProjectedNumber::Exact(max_exact as f64)
);
// Test values that cannot be exactly represented as f64 (integers above 2^53)
let large_i64 = 9_007_199_254_740_993; // 2^53 + 1
let closest_f64 = 9_007_199_254_740_992.0;
assert_eq!(large_i64 as f64, closest_f64);
if let ProjectedNumber::Next(val) = num_proj::i64_to_f64(large_i64) {
// Verify that the returned float is different from the direct cast
assert!(val > closest_f64);
assert!(val - closest_f64 < 2. * f64::EPSILON * closest_f64);
} else {
panic!("Expected ProjectedNumber::Next for large_i64");
}
// Test with very large negative value
let large_neg_i64 = -9_007_199_254_740_993; // -(2^53 + 1)
let closest_neg_f64 = -9_007_199_254_740_992.0;
assert_eq!(large_neg_i64 as f64, closest_neg_f64);
if let ProjectedNumber::Next(val) = num_proj::i64_to_f64(large_neg_i64) {
// Verify that the returned float is the closest representable f64
assert_eq!(val, closest_neg_f64);
} else {
panic!("Expected ProjectedNumber::Next for large_neg_i64");
}
}
#[test]
fn test_u64_to_f64() {
assert_eq!(num_proj::u64_to_f64(0), ProjectedNumber::Exact(0.0));
assert_eq!(num_proj::u64_to_f64(42), ProjectedNumber::Exact(42.0));
// Test the largest u64 value that can be exactly represented as f64 (2^53)
let max_exact = 9_007_199_254_740_992; // 2^53
assert_eq!(
num_proj::u64_to_f64(max_exact),
ProjectedNumber::Exact(max_exact as f64)
);
// Test values that cannot be exactly represented as f64 (integers above 2^53)
let large_u64 = 9_007_199_254_740_993; // 2^53 + 1
let closest_f64 = 9_007_199_254_740_992.0;
assert_eq!(large_u64 as f64, closest_f64);
if let ProjectedNumber::Next(val) = num_proj::u64_to_f64(large_u64) {
// Verify that the returned float is different from the direct cast
assert!(val > closest_f64);
assert!(val - closest_f64 < 2. * f64::EPSILON * closest_f64);
} else {
panic!("Expected ProjectedNumber::Next for large_u64");
}
}
}

View File

@@ -1,4 +1,5 @@
use std::fmt::Debug;
use std::sync::Arc;
use common::BitSet;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
@@ -6,8 +7,8 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::buffered_sub_aggs::{
BufferedSubAggs, HighCardSubAggBuffer, LowCardSubAggBuffer, SubAggBuffer,
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardSubAggCache, SubAggCache,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
@@ -402,7 +403,7 @@ pub struct FilterAggReqData {
/// The filter aggregation
pub req: FilterAggregation,
/// The segment reader
pub segment_reader: SegmentReader,
pub segment_reader: Arc<dyn SegmentReader>,
/// Document evaluator for the filter query (precomputed BitSet)
/// This is built once when the request data is created
pub evaluator: DocumentQueryEvaluator,
@@ -416,7 +417,7 @@ impl FilterAggReqData {
pub(crate) fn get_memory_consumption(&self) -> usize {
// Estimate: name + segment reader reference + bitset + buffer capacity
self.name.len()
+ std::mem::size_of::<SegmentReader>()
+ std::mem::size_of::<Arc<dyn SegmentReader>>()
+ self.evaluator.bitset.len() / 8 // BitSet memory (bits to bytes)
+ self.matching_docs_buffer.capacity() * std::mem::size_of::<DocId>()
+ std::mem::size_of::<bool>()
@@ -438,7 +439,7 @@ impl DocumentQueryEvaluator {
pub(crate) fn new(
query: Box<dyn Query>,
schema: Schema,
segment_reader: &SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<Self> {
let max_doc = segment_reader.max_doc();
@@ -503,17 +504,17 @@ struct DocCount {
}
/// Segment collector for filter aggregation
pub struct SegmentFilterCollector<B: SubAggBuffer> {
pub struct SegmentFilterCollector<C: SubAggCache> {
/// Document counts per parent bucket
parent_buckets: Vec<DocCount>,
/// Sub-aggregation collectors
sub_aggregations: Option<BufferedSubAggs<B>>,
sub_aggregations: Option<CachedSubAggs<C>>,
bucket_id_provider: BucketIdProvider,
/// Accessor index for this filter aggregation (to access FilterAggReqData)
accessor_idx: usize,
}
impl<B: SubAggBuffer> SegmentFilterCollector<B> {
impl<C: SubAggCache> SegmentFilterCollector<C> {
/// Create a new filter segment collector following the new agg_data pattern
pub(crate) fn from_req_and_validate(
req: &mut AggregationsSegmentCtx,
@@ -525,7 +526,7 @@ impl<B: SubAggBuffer> SegmentFilterCollector<B> {
} else {
None
};
let sub_agg_collector = sub_agg_collector.map(BufferedSubAggs::new);
let sub_agg_collector = sub_agg_collector.map(CachedSubAggs::new);
Ok(SegmentFilterCollector {
parent_buckets: Vec::new(),
@@ -547,16 +548,16 @@ pub(crate) fn build_segment_filter_collector(
if is_top_level {
Ok(Box::new(
SegmentFilterCollector::<LowCardSubAggBuffer>::from_req_and_validate(req, node)?,
SegmentFilterCollector::<LowCardSubAggCache>::from_req_and_validate(req, node)?,
))
} else {
Ok(Box::new(
SegmentFilterCollector::<HighCardSubAggBuffer>::from_req_and_validate(req, node)?,
SegmentFilterCollector::<HighCardSubAggCache>::from_req_and_validate(req, node)?,
))
}
}
impl<B: SubAggBuffer> Debug for SegmentFilterCollector<B> {
impl<C: SubAggCache> Debug for SegmentFilterCollector<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentFilterCollector")
.field("buckets", &self.parent_buckets)
@@ -566,7 +567,7 @@ impl<B: SubAggBuffer> Debug for SegmentFilterCollector<B> {
}
}
impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentFilterCollector<B> {
impl<C: SubAggCache> SegmentAggregationCollector for SegmentFilterCollector<C> {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
@@ -674,17 +675,6 @@ impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentFilterCollector<B>
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// TODO: forward into the inner `sub_agg` for nested order paths (`filter.metric`).
None
}
}
/// Intermediate result for filter aggregation

View File

@@ -207,7 +207,7 @@ fn parse_offset_into_milliseconds(input: &str) -> Result<i64, AggregationError>
}
}
pub(crate) fn parse_into_milliseconds(input: &str) -> Result<i64, AggregationError> {
fn parse_into_milliseconds(input: &str) -> Result<i64, AggregationError> {
let split_boundary = input
.as_bytes()
.iter()

View File

@@ -10,7 +10,7 @@ use crate::aggregation::agg_data::{
};
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::agg_result::BucketEntry;
use crate::aggregation::buffered_sub_aggs::{BufferedSubAggs, HighCardBufferedSubAggs};
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateHistogramBucketEntry,
@@ -258,7 +258,7 @@ pub(crate) struct SegmentHistogramBucketEntry {
impl SegmentHistogramBucketEntry {
pub(crate) fn into_intermediate_bucket_entry(
self,
sub_aggregation: &mut Option<HighCardBufferedSubAggs>,
sub_aggregation: &mut Option<HighCardCachedSubAggs>,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateHistogramBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
@@ -283,11 +283,6 @@ impl SegmentHistogramBucketEntry {
struct HistogramBuckets {
pub buckets: FxHashMap<i64, SegmentHistogramBucketEntry>,
}
impl HistogramBuckets {
fn memory_consumption(&self) -> u64 {
self.buckets.capacity() as u64 * std::mem::size_of::<SegmentHistogramBucketEntry>() as u64
}
}
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
@@ -296,7 +291,7 @@ pub struct SegmentHistogramCollector {
/// The buckets containing the aggregation data.
/// One Histogram bucket per parent bucket id.
parent_buckets: Vec<HistogramBuckets>,
sub_agg: Option<HighCardBufferedSubAggs>,
sub_agg: Option<HighCardCachedSubAggs>,
accessor_idx: usize,
bucket_id_provider: BucketIdProvider,
}
@@ -329,7 +324,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let req = agg_data.take_histogram_req_data(self.accessor_idx);
let mem_pre = self.get_memory_consumption(parent_bucket_id);
let mem_pre = self.get_memory_consumption();
let buckets = &mut self.parent_buckets[parent_bucket_id as usize].buckets;
let bounds = req.bounds;
@@ -363,9 +358,12 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
}
agg_data.put_back_histogram_req_data(self.accessor_idx, req);
let mem_delta = self.get_memory_consumption(parent_bucket_id) - mem_pre;
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
agg_data.context.limits.add_memory_consumed(mem_delta)?;
agg_data
.context
.limits
.add_memory_consumed(mem_delta as u64)?;
}
if let Some(sub_agg) = &mut self.sub_agg {
@@ -394,24 +392,14 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Histogram is a multi-bucket agg with no single value to extract.
None
}
}
impl SegmentHistogramCollector {
fn get_memory_consumption(&self, parent_bucket_id: BucketId) -> u64 {
self.parent_buckets[parent_bucket_id as usize].memory_consumption()
fn get_memory_consumption(&self) -> usize {
let self_mem = std::mem::size_of::<Self>();
let buckets_mem = self.parent_buckets.len() * std::mem::size_of::<HistogramBuckets>();
self_mem + buckets_mem
}
/// Converts the collector result into a intermediate bucket result.
fn add_intermediate_bucket_result(
&mut self,
@@ -456,7 +444,7 @@ impl SegmentHistogramCollector {
max: f64::MAX,
});
req_data.offset = req_data.req.offset.unwrap_or(0.0);
let sub_agg = sub_agg.map(BufferedSubAggs::new);
let sub_agg = sub_agg.map(CachedSubAggs::new);
Ok(Self {
parent_buckets: Default::default(),

View File

@@ -22,7 +22,6 @@
//! - [Range](RangeAggregation)
//! - [Terms](TermsAggregation)
mod composite;
mod filter;
mod histogram;
mod range;
@@ -32,7 +31,6 @@ mod term_missing_agg;
use std::collections::HashMap;
use std::fmt;
pub use composite::*;
pub use filter::*;
pub use histogram::*;
pub use range::*;

View File

@@ -9,9 +9,8 @@ use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::agg_limits::AggregationLimitsGuard;
use crate::aggregation::buffered_sub_aggs::{
BufferedSubAggs, HighCardSubAggBuffer, LowCardBufferedSubAggs, LowCardSubAggBuffer,
SubAggBuffer,
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardCachedSubAggs, LowCardSubAggCache, SubAggCache,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
@@ -156,13 +155,13 @@ pub(crate) struct SegmentRangeAndBucketEntry {
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
pub struct SegmentRangeCollector<B: SubAggBuffer> {
pub struct SegmentRangeCollector<C: SubAggCache> {
/// The buckets containing the aggregation data.
/// One for each ParentBucketId
parent_buckets: Vec<Vec<SegmentRangeAndBucketEntry>>,
column_type: ColumnType,
pub(crate) accessor_idx: usize,
sub_agg: Option<BufferedSubAggs<B>>,
sub_agg: Option<CachedSubAggs<C>>,
/// Here things get a bit weird. We need to assign unique bucket ids across all
/// parent buckets. So we keep track of the next available bucket id here.
/// This allows a kind of flattening of the bucket ids across all parent buckets.
@@ -179,7 +178,7 @@ pub struct SegmentRangeCollector<B: SubAggBuffer> {
limits: AggregationLimitsGuard,
}
impl<B: SubAggBuffer> Debug for SegmentRangeCollector<B> {
impl<C: SubAggCache> Debug for SegmentRangeCollector<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SegmentRangeCollector")
.field("parent_buckets_len", &self.parent_buckets.len())
@@ -230,7 +229,7 @@ impl SegmentRangeBucketEntry {
}
}
impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentRangeCollector<B> {
impl<C: SubAggCache> SegmentAggregationCollector for SegmentRangeCollector<C> {
fn add_intermediate_aggregation_result(
&mut self,
agg_data: &AggregationsSegmentCtx,
@@ -328,17 +327,6 @@ impl<B: SubAggBuffer> SegmentAggregationCollector for SegmentRangeCollector<B> {
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Range is a multi-bucket agg with no single value to extract.
None
}
}
/// Build a concrete `SegmentRangeCollector` with either a Vec- or HashMap-backed
/// bucket storage, depending on the column type and aggregation level.
@@ -362,8 +350,8 @@ pub(crate) fn build_segment_range_collector(
};
if is_low_card {
Ok(Box::new(SegmentRangeCollector::<LowCardSubAggBuffer> {
sub_agg: sub_agg.map(LowCardBufferedSubAggs::new),
Ok(Box::new(SegmentRangeCollector::<LowCardSubAggCache> {
sub_agg: sub_agg.map(LowCardCachedSubAggs::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
@@ -371,8 +359,8 @@ pub(crate) fn build_segment_range_collector(
limits: agg_data.context.limits.clone(),
}))
} else {
Ok(Box::new(SegmentRangeCollector::<HighCardSubAggBuffer> {
sub_agg: sub_agg.map(BufferedSubAggs::new),
Ok(Box::new(SegmentRangeCollector::<HighCardSubAggCache> {
sub_agg: sub_agg.map(CachedSubAggs::new),
column_type: field_type,
accessor_idx,
parent_buckets: Vec::new(),
@@ -382,7 +370,7 @@ pub(crate) fn build_segment_range_collector(
}
}
impl<B: SubAggBuffer> SegmentRangeCollector<B> {
impl<C: SubAggCache> SegmentRangeCollector<C> {
pub(crate) fn create_new_buckets(
&mut self,
agg_data: &AggregationsSegmentCtx,
@@ -566,7 +554,7 @@ mod tests {
pub fn get_collector_from_ranges(
ranges: Vec<RangeAggregationRange>,
field_type: ColumnType,
) -> SegmentRangeCollector<HighCardSubAggBuffer> {
) -> SegmentRangeCollector<HighCardSubAggCache> {
let req = RangeAggregation {
field: "dummy".to_string(),
ranges,

View File

@@ -1,4 +1,5 @@
use std::fmt::Debug;
use std::io;
use std::net::Ipv6Addr;
use columnar::column_values::CompactSpaceU64Accessor;
@@ -16,9 +17,8 @@ use crate::aggregation::agg_data::{
};
use crate::aggregation::agg_limits::MemoryConsumption;
use crate::aggregation::agg_req::Aggregations;
use crate::aggregation::buffered_sub_aggs::{
BufferedSubAggs, HighCardSubAggBuffer, LowCardBufferedSubAggs, LowCardSubAggBuffer,
SubAggBuffer,
use crate::aggregation::cached_sub_aggs::{
CachedSubAggs, HighCardSubAggCache, LowCardCachedSubAggs, LowCardSubAggCache, SubAggCache,
};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
@@ -352,15 +352,19 @@ pub(crate) fn build_segment_term_collector(
)));
}
// Validate that the referenced sub-aggregation exists when ordering by one.
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric sub_aggregations"
))
})?;
// Validate sub aggregation exists when ordering by sub-aggregation.
{
if let OrderTarget::SubAggregation(sub_agg_name) = &terms_req_data.req.order.target {
let (agg_name, _agg_property) = get_agg_name_and_property(sub_agg_name);
node.get_sub_agg(agg_name, &req_data.per_request)
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"could not find aggregation with name {agg_name} in metric \
sub_aggregations"
))
})?;
}
}
// Build sub-aggregation blueprint if there are children.
@@ -387,7 +391,7 @@ pub(crate) fn build_segment_term_collector(
// Decide which bucket storage is best suited for this aggregation.
if is_top_level && max_term_id < MAX_NUM_TERMS_FOR_VEC && !has_sub_aggregations {
let term_buckets = VecTermBucketsNoAgg::new(max_term_id + 1, &mut bucket_id_provider);
let collector: SegmentTermCollector<_, HighCardSubAggBuffer> = SegmentTermCollector {
let collector: SegmentTermCollector<_, HighCardSubAggCache> = SegmentTermCollector {
parent_buckets: vec![term_buckets],
sub_agg: None,
bucket_id_provider,
@@ -397,8 +401,8 @@ pub(crate) fn build_segment_term_collector(
Ok(Box::new(collector))
} else if is_top_level && max_term_id < MAX_NUM_TERMS_FOR_VEC {
let term_buckets = VecTermBuckets::new(max_term_id + 1, &mut bucket_id_provider);
let sub_agg = sub_agg_collector.map(LowCardBufferedSubAggs::new);
let collector: SegmentTermCollector<_, LowCardSubAggBuffer> = SegmentTermCollector {
let sub_agg = sub_agg_collector.map(LowCardCachedSubAggs::new);
let collector: SegmentTermCollector<_, LowCardSubAggCache> = SegmentTermCollector {
parent_buckets: vec![term_buckets],
sub_agg,
bucket_id_provider,
@@ -410,8 +414,8 @@ pub(crate) fn build_segment_term_collector(
let term_buckets: PagedTermMap =
PagedTermMap::new(max_term_id + 1, &mut bucket_id_provider);
// Build sub-aggregation blueprint (flat pairs)
let sub_agg = sub_agg_collector.map(BufferedSubAggs::new);
let collector: SegmentTermCollector<PagedTermMap, HighCardSubAggBuffer> =
let sub_agg = sub_agg_collector.map(CachedSubAggs::new);
let collector: SegmentTermCollector<PagedTermMap, HighCardSubAggCache> =
SegmentTermCollector {
parent_buckets: vec![term_buckets],
sub_agg,
@@ -423,8 +427,8 @@ pub(crate) fn build_segment_term_collector(
} else {
let term_buckets: HashMapTermBuckets = HashMapTermBuckets::default();
// Build sub-aggregation blueprint (flat pairs)
let sub_agg = sub_agg_collector.map(BufferedSubAggs::new);
let collector: SegmentTermCollector<HashMapTermBuckets, HighCardSubAggBuffer> =
let sub_agg = sub_agg_collector.map(CachedSubAggs::new);
let collector: SegmentTermCollector<HashMapTermBuckets, HighCardSubAggCache> =
SegmentTermCollector {
parent_buckets: vec![term_buckets],
sub_agg,
@@ -754,10 +758,10 @@ impl TermAggregationMap for VecTermBuckets {
/// The collector puts values from the fast field into the correct buckets and does a conversion to
/// the correct datatype.
#[derive(Debug)]
struct SegmentTermCollector<TermMap: TermAggregationMap, B: SubAggBuffer> {
struct SegmentTermCollector<TermMap: TermAggregationMap, C: SubAggCache> {
/// The buckets containing the aggregation data.
parent_buckets: Vec<TermMap>,
sub_agg: Option<BufferedSubAggs<B>>,
sub_agg: Option<CachedSubAggs<C>>,
bucket_id_provider: BucketIdProvider,
max_term_id: u64,
terms_req_data: TermsAggReqData,
@@ -768,8 +772,8 @@ pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) {
(agg_name, agg_property)
}
impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
for SegmentTermCollector<TermMap, B>
impl<TermMap: TermAggregationMap, C: SubAggCache> SegmentAggregationCollector
for SegmentTermCollector<TermMap, C>
{
fn add_intermediate_aggregation_result(
&mut self,
@@ -786,14 +790,8 @@ impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
let term_req = &self.terms_req_data;
let name = term_req.name.clone();
let bucket = Self::into_intermediate_bucket_result(
term_req,
self.sub_agg
.as_mut()
.map(BufferedSubAggs::get_sub_agg_collector),
bucket,
agg_data,
)?;
let bucket =
Self::into_intermediate_bucket_result(term_req, &mut self.sub_agg, bucket, agg_data)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
}
@@ -805,17 +803,15 @@ impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
docs: &[crate::DocId],
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
let mem_pre = self.get_memory_consumption(parent_bucket_id);
let mem_pre = self.get_memory_consumption();
let req_data = &mut self.terms_req_data;
agg_data
.column_block_accessor
.fetch_block_with_missing_unique_per_doc(
docs,
&req_data.accessor,
req_data.missing_value_for_accessor,
);
agg_data.column_block_accessor.fetch_block_with_missing(
docs,
&req_data.accessor,
req_data.missing_value_for_accessor,
);
if let Some(sub_agg) = &mut self.sub_agg {
let term_buckets = &mut self.parent_buckets[parent_bucket_id as usize];
@@ -849,7 +845,7 @@ impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
}
}
let mem_delta = self.get_memory_consumption(parent_bucket_id) - mem_pre;
let mem_delta = self.get_memory_consumption() - mem_pre;
if mem_delta > 0 {
agg_data
.context
@@ -883,17 +879,6 @@ impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// Terms is a multi-bucket agg with no single value to extract.
None
}
}
/// Missing value are represented as a sentinel value in the column.
@@ -920,53 +905,30 @@ fn extract_missing_value<T>(
Some((key, bucket))
}
fn reborrow_opt_collector<'a>(
opt: &'a mut Option<&mut dyn SegmentAggregationCollector>,
) -> Option<&'a mut dyn SegmentAggregationCollector> {
match opt {
Some(inner) => Some(*inner),
None => None,
}
}
fn into_intermediate_bucket_entry(
bucket: Bucket,
sub_agg_collector: Option<&mut dyn SegmentAggregationCollector>,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateTermBucketEntry> {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
if let Some(sub_agg_collector) = sub_agg_collector {
sub_agg_collector.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
bucket.bucket_id,
)?;
}
Ok(IntermediateTermBucketEntry {
doc_count: bucket.count,
sub_aggregation: sub_aggregation_res,
})
}
impl<TermMap, B> SegmentTermCollector<TermMap, B>
impl<TermMap, C> SegmentTermCollector<TermMap, C>
where
TermMap: TermAggregationMap,
B: SubAggBuffer,
C: SubAggCache,
{
#[inline]
fn get_memory_consumption(&self, parent_bucket_id: BucketId) -> usize {
self.parent_buckets[parent_bucket_id as usize].get_memory_consumption()
fn get_memory_consumption(&self) -> usize {
self.parent_buckets
.iter()
.map(|b| b.get_memory_consumption())
.sum()
}
#[inline]
pub(crate) fn into_intermediate_bucket_result(
term_req: &TermsAggReqData,
mut sub_agg_collector: Option<&mut dyn SegmentAggregationCollector>,
sub_agg: &mut Option<CachedSubAggs<C>>,
term_buckets: TermMap,
agg_data: &AggregationsSegmentCtx,
) -> crate::Result<IntermediateBucketResult> {
let mut entries: Vec<(u64, Bucket)> = term_buckets.into_vec();
let order_by_sub_aggregation =
matches!(term_req.req.order.target, OrderTarget::SubAggregation(_));
match &term_req.req.order.target {
OrderTarget::Key => {
// We rely on the fact, that term ordinals match the order of the strings
@@ -978,37 +940,10 @@ where
entries.sort_unstable_by_key(|bucket| bucket.0);
}
}
OrderTarget::SubAggregation(sub_agg_path) => {
// Peek segment-level metric values, sort, then fall through to
// `cut_off_buckets`. Like Elasticsearch, we always cut off when ordering
// by a sub-agg: top-K results are approximate and may differ from the
// global ordering, especially for non-monotonic metrics like avg/min.
let coll = sub_agg_collector.as_deref().ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"Could not find sub-aggregation collector for path {sub_agg_path}"
))
})?;
let (agg_name, agg_prop) = get_agg_name_and_property(sub_agg_path);
// Fetch values up-front; otherwise sort would re-compute per comparison
let mut keyed: Vec<(f64, (u64, Bucket))> = entries
.into_iter()
.map(|bucket| {
let metric_value = coll
.compute_metric_value(bucket.1.bucket_id, agg_name, agg_prop, agg_data)
.unwrap_or(0.0);
(metric_value, bucket)
})
.collect();
if term_req.req.order.order == Order::Desc {
keyed.sort_unstable_by(|a, b| {
b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
});
} else {
keyed.sort_unstable_by(|a, b| {
a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
});
}
entries = keyed.into_iter().map(|(_, e)| e).collect();
OrderTarget::SubAggregation(_name) => {
// don't sort and cut off since it's hard to make assumptions on the quality of the
// results when cutting off du to unknown nature of the sub_aggregation (possible
// to check).
}
OrderTarget::Count => {
if term_req.req.order.order == Order::Desc {
@@ -1019,12 +954,40 @@ where
}
}
let (term_doc_count_before_cutoff, sum_other_doc_count) =
cut_off_buckets(&mut entries, term_req.req.segment_size as usize);
let (term_doc_count_before_cutoff, sum_other_doc_count) = if order_by_sub_aggregation {
(0, 0)
} else {
cut_off_buckets(&mut entries, term_req.req.segment_size as usize)
};
let mut dict: FxHashMap<IntermediateKey, IntermediateTermBucketEntry> = Default::default();
dict.reserve(entries.len());
let into_intermediate_bucket_entry =
|bucket: Bucket,
sub_agg: &mut Option<CachedSubAggs<C>>|
-> crate::Result<IntermediateTermBucketEntry> {
if let Some(sub_agg) = sub_agg {
let mut sub_aggregation_res = IntermediateAggregationResults::default();
sub_agg
.get_sub_agg_collector()
.add_intermediate_aggregation_result(
agg_data,
&mut sub_aggregation_res,
bucket.bucket_id,
)?;
Ok(IntermediateTermBucketEntry {
doc_count: bucket.count,
sub_aggregation: sub_aggregation_res,
})
} else {
Ok(IntermediateTermBucketEntry {
doc_count: bucket.count,
sub_aggregation: Default::default(),
})
}
};
if term_req.column_type == ColumnType::Str {
let fallback_dict = Dictionary::empty();
let term_dict = term_req
@@ -1035,11 +998,7 @@ where
if let Some((intermediate_key, bucket)) = extract_missing_value(&mut entries, term_req)
{
let intermediate_entry = into_intermediate_bucket_entry(
bucket,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
let intermediate_entry = into_intermediate_bucket_entry(bucket, sub_agg)?;
dict.insert(intermediate_key, intermediate_entry);
}
@@ -1047,28 +1006,19 @@ where
entries.sort_unstable_by_key(|bucket| bucket.0);
let (term_ids, buckets): (Vec<u64>, Vec<Bucket>) = entries.into_iter().unzip();
let mut buckets_it = buckets.into_iter();
let intermediate_entries: Vec<IntermediateTermBucketEntry> = buckets
.into_iter()
.map(|bucket| {
into_intermediate_bucket_entry(
bucket,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)
})
.collect::<crate::Result<_>>()?;
let mut intermediate_entry_it = intermediate_entries.into_iter();
term_dict.sorted_ords_to_term_cb(&term_ids[..], |term| {
let intermediate_entry = intermediate_entry_it.next().unwrap();
term_dict.sorted_ords_to_term_cb(term_ids.into_iter(), |term| {
let bucket = buckets_it.next().unwrap();
let intermediate_entry =
into_intermediate_bucket_entry(bucket, sub_agg).map_err(io::Error::other)?;
dict.insert(
IntermediateKey::Str(
String::from_utf8(term.to_vec()).expect("could not convert to String"),
),
intermediate_entry,
);
Ok(())
})?;
if term_req.req.min_doc_count == 0 {
@@ -1103,22 +1053,14 @@ where
}
} else if term_req.column_type == ColumnType::DateTime {
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(
doc_count,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?;
let val = i64::from_u64(val);
let date = format_date(val)?;
dict.insert(IntermediateKey::Str(date), intermediate_entry);
}
} else if term_req.column_type == ColumnType::Bool {
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(
doc_count,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?;
let val = bool::from_u64(val);
dict.insert(IntermediateKey::Bool(val), intermediate_entry);
}
@@ -1138,22 +1080,14 @@ where
})?;
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(
doc_count,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?;
let val: u128 = compact_space_accessor.compact_to_u128(val as u32);
let val = Ipv6Addr::from_u128(val);
dict.insert(IntermediateKey::IpAddr(val), intermediate_entry);
}
} else {
for (val, doc_count) in entries {
let intermediate_entry = into_intermediate_bucket_entry(
doc_count,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
let intermediate_entry = into_intermediate_bucket_entry(doc_count, sub_agg)?;
if term_req.column_type == ColumnType::U64 {
dict.insert(IntermediateKey::U64(val), intermediate_entry);
} else if term_req.column_type == ColumnType::I64 {
@@ -1187,13 +1121,13 @@ where
}
}
impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentTermCollector<TermMap, B> {
impl<TermMap: TermAggregationMap, C: SubAggCache> SegmentTermCollector<TermMap, C> {
#[inline]
fn collect_terms_with_docs(
iter: impl Iterator<Item = (crate::DocId, u64)>,
term_buckets: &mut TermMap,
bucket_id_provider: &mut BucketIdProvider,
sub_agg: &mut BufferedSubAggs<B>,
sub_agg: &mut CachedSubAggs<C>,
) {
for (doc, term_id) in iter {
let bucket_id = term_buckets.term_entry(term_id, bucket_id_provider);
@@ -1266,7 +1200,7 @@ mod tests {
use crate::aggregation::{AggregationLimitsGuard, DistributedAggregationCollector};
use crate::indexer::NoMergePolicy;
use crate::query::AllQuery;
use crate::schema::{IntoIpv6Addr, Schema, FAST, INDEXED, STRING, TEXT};
use crate::schema::{IntoIpv6Addr, Schema, FAST, STRING};
use crate::{Index, IndexWriter};
#[test]
@@ -1795,263 +1729,6 @@ mod tests {
Ok(())
}
#[test]
fn terms_aggregation_order_by_cardinality_desc_single_segment() -> crate::Result<()> {
terms_aggregation_order_by_cardinality_desc(true)
}
#[test]
fn terms_aggregation_order_by_cardinality_desc_multi_segment() -> crate::Result<()> {
terms_aggregation_order_by_cardinality_desc(false)
}
fn terms_aggregation_order_by_cardinality_desc(merge_segments: bool) -> crate::Result<()> {
// Distinct score values per bucket key: A→5, B→1, C→3.
// Order by cardinality desc must yield A, C, B.
let segment_and_terms = vec![vec![
(1.0, "A".to_string()),
(2.0, "A".to_string()),
(3.0, "A".to_string()),
(4.0, "A".to_string()),
(5.0, "A".to_string()),
(1.0, "B".to_string()),
(1.0, "B".to_string()),
(1.0, "B".to_string()),
(1.0, "C".to_string()),
(2.0, "C".to_string()),
(3.0, "C".to_string()),
]];
let index = get_test_index_from_values_and_terms(merge_segments, &segment_and_terms)?;
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "card": "desc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][0]["card"]["value"], 5.0);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][1]["card"]["value"], 3.0);
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][2]["card"]["value"], 1.0);
// Asc engages the segment-cutoff path too (monotonic-safe: discarded buckets had
// local card >= cutoff, so merged card >= cutoff and they cannot be globally smallest).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "card": "asc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "A");
// size=2 with desc engages the segment cutoff: must keep top-2 by cardinality (A, C),
// and `sum_other_doc_count` reflects the dropped B (3 docs).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"size": 2,
"order": { "card": "desc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"].as_array().unwrap().len(), 2);
// size=2 with asc engages the segment cutoff: must keep bottom-2 by cardinality (B, C).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"size": 2,
"order": { "card": "asc" }
},
"aggs": {
"card": { "cardinality": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"].as_array().unwrap().len(), 2);
Ok(())
}
#[test]
fn terms_aggregation_order_by_sum_single_segment() -> crate::Result<()> {
terms_aggregation_order_by_sum(true)
}
#[test]
fn terms_aggregation_order_by_sum_multi_segment() -> crate::Result<()> {
terms_aggregation_order_by_sum(false)
}
fn terms_aggregation_order_by_sum(merge_segments: bool) -> crate::Result<()> {
// Per-bucket sums on the U64 `score` column (non-negative => sum is monotonic):
// A → 1+2+3+4+5 = 15, B → 1+1+1 = 3, C → 1+2+3 = 6.
let segment_and_terms = vec![
vec![
(1.0, "A".to_string()),
(2.0, "A".to_string()),
(3.0, "A".to_string()),
(1.0, "B".to_string()),
(1.0, "C".to_string()),
],
vec![
(4.0, "A".to_string()),
(5.0, "A".to_string()),
(1.0, "B".to_string()),
(1.0, "B".to_string()),
(2.0, "C".to_string()),
(3.0, "C".to_string()),
],
];
let index = get_test_index_from_values_and_terms(merge_segments, &segment_and_terms)?;
// Desc on a Sum metric engages the fast path (column is U64).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "total": "desc" }
},
"aggs": {
"total": { "sum": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][0]["total"]["value"], 15.0);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][1]["total"]["value"], 6.0);
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][2]["total"]["value"], 3.0);
// Asc engages the fast path too — discarded buckets had local sum >= cutoff,
// and merged sum >= local (non-negative addends), so they cannot be globally smallest.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "total": "asc" }
},
"aggs": {
"total": { "sum": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "B");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "A");
// size=2 desc with cutoff: top-2 by sum (A, C).
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"size": 2,
"order": { "total": "desc" }
},
"aggs": {
"total": { "sum": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"].as_array().unwrap().len(), 2);
// Stats sub-property: ordering by `mystats.sum` on a U64 column also engages.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "mystats.sum": "desc" }
},
"aggs": {
"mystats": { "stats": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
// Sum on a signed column (I64) takes the same cutoff path. Results may be
// approximate near the boundary on adversarial data, but for this dataset the
// top-K is unambiguous.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "total": "desc" }
},
"aggs": {
"total": { "sum": { "field": "score_i64" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
// Order by extended_stats sub-property exercises compute_metric_value on the
// ExtendedStats collector. A→max=5, B→max=1, C→max=3, so desc by max → A, C, B.
let agg_req: Aggregations = serde_json::from_value(json!({
"my_texts": {
"terms": {
"field": "string_id",
"order": { "ext.max": "desc" }
},
"aggs": {
"ext": { "extended_stats": { "field": "score" } }
}
}
}))
.unwrap();
let res = exec_request(agg_req, &index)?;
assert_eq!(res["my_texts"]["buckets"][0]["key"], "A");
assert_eq!(res["my_texts"]["buckets"][1]["key"], "C");
assert_eq!(res["my_texts"]["buckets"][2]["key"], "B");
Ok(())
}
#[test]
fn terms_aggregation_test_order_key_single_segment() -> crate::Result<()> {
terms_aggregation_test_order_key_merge_segment(true)
@@ -2670,7 +2347,7 @@ mod tests {
// text field
assert_eq!(res["my_texts"]["buckets"][0]["key"], "Hello Hello");
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 4);
assert_eq!(res["my_texts"]["buckets"][0]["doc_count"], 5);
assert_eq!(res["my_texts"]["buckets"][1]["key"], "Empty");
assert_eq!(res["my_texts"]["buckets"][1]["doc_count"], 2);
assert_eq!(
@@ -2679,7 +2356,7 @@ mod tests {
);
// text field with number as missing fallback
assert_eq!(res["my_texts2"]["buckets"][0]["key"], "Hello Hello");
assert_eq!(res["my_texts2"]["buckets"][0]["doc_count"], 4);
assert_eq!(res["my_texts2"]["buckets"][0]["doc_count"], 5);
assert_eq!(res["my_texts2"]["buckets"][1]["key"], 1337.0);
assert_eq!(res["my_texts2"]["buckets"][1]["doc_count"], 2);
assert_eq!(
@@ -2693,7 +2370,7 @@ mod tests {
assert_eq!(res["my_ids"]["buckets"][0]["key"], 1337.0);
assert_eq!(res["my_ids"]["buckets"][0]["doc_count"], 4);
assert_eq!(res["my_ids"]["buckets"][1]["key"], 1.0);
assert_eq!(res["my_ids"]["buckets"][1]["doc_count"], 2);
assert_eq!(res["my_ids"]["buckets"][1]["doc_count"], 3);
assert_eq!(res["my_ids"]["buckets"][2]["key"], serde_json::Value::Null);
Ok(())
@@ -3217,101 +2894,4 @@ mod tests {
Ok(())
}
fn prep_index_with_n_unique_terms_plus_one_null(n: u64) -> crate::Result<Index> {
let mut schema_builder = Schema::builder();
let id_field = schema_builder.add_u64_field("id", INDEXED);
let title_field = schema_builder.add_text_field("title", TEXT | FAST);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
// set to one thread to guarantee all docs end up in the same segment
let mut writer = index.writer_with_num_threads(1, 50_000_000)?;
writer.add_document(doc!(
id_field => 0u64,
))?;
for i in 1u64..=n {
let title = format!("foo{i}");
writer.add_document(doc!(
id_field => i,
title_field => title,
))?;
}
writer.commit()?;
Ok(index)
}
#[test]
fn null_bitset_bounds_check_regression() -> crate::Result<()> {
// include cases
for i in 0..=4 {
let index = prep_index_with_n_unique_terms_plus_one_null(i * 64)?;
let normal_req: Aggregations = serde_json::from_value(json!({
"my_bool": {
"terms": {
"field": "title",
"missing": "__NULL__",
"size": 1000,
}
}
}))?;
let include_req: Aggregations = serde_json::from_value(json!({
"my_bool": {
"terms": {
"field": "title",
"include": "foo(.*)",
"missing": "__NULL__",
"size": 1000,
}
}
}))?;
let exclude_req: Aggregations = serde_json::from_value(json!({
"my_bool": {
"terms": {
"field": "title",
"exclude": "foo(.*)",
"missing": "__NULL__",
"size": 1000,
}
}
}))?;
let normal_res = exec_request(normal_req, &index)?;
let normal_buckets = normal_res["my_bool"]["buckets"].as_array().unwrap();
assert_eq!(
normal_buckets.len(),
(i * 64) as usize + 1,
"The normal request should return all 'foo' buckets, plus the missing term bucket",
);
let include_res = exec_request(include_req, &index)?;
eprintln!("include_res: {include_res:?}");
let include_buckets = include_res["my_bool"]["buckets"].as_array().unwrap();
assert_eq!(
include_buckets.len(),
(i * 64) as usize,
"The include request should return all 'foo' buckets, and not the missing term \
bucket",
);
assert!(include_buckets
.iter()
.all(|b| b["key"].as_str().unwrap().starts_with("foo")));
let exclude_res = exec_request(exclude_req, &index)?;
let exclude_buckets = exclude_res["my_bool"]["buckets"].as_array().unwrap();
if i != 0 {
// TODO: Remove this if after fixing exclude + missing bug
assert_eq!(
exclude_buckets.len(),
1,
"The exclude request should exclude all 'foo' buckets, and only the missing \
term bucket",
);
assert_eq!(exclude_buckets[0]["key"], "__NULL__");
}
}
Ok(())
}
}

View File

@@ -5,7 +5,7 @@ use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
};
use crate::aggregation::bucket::term_agg::TermsAggregation;
use crate::aggregation::buffered_sub_aggs::{BufferedSubAggs, HighCardBufferedSubAggs};
use crate::aggregation::cached_sub_aggs::{CachedSubAggs, HighCardCachedSubAggs};
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult,
IntermediateKey, IntermediateTermBucketEntry, IntermediateTermBucketResult,
@@ -47,7 +47,7 @@ struct MissingCount {
#[derive(Default, Debug)]
pub struct TermMissingAgg {
accessor_idx: usize,
sub_agg: Option<HighCardBufferedSubAggs>,
sub_agg: Option<HighCardCachedSubAggs>,
/// Idx = parent bucket id, Value = missing count for that bucket
missing_count_per_bucket: Vec<MissingCount>,
bucket_id_provider: BucketIdProvider,
@@ -66,7 +66,7 @@ impl TermMissingAgg {
None
};
let sub_agg = sub_agg.map(BufferedSubAggs::new);
let sub_agg = sub_agg.map(CachedSubAggs::new);
let bucket_id_provider = BucketIdProvider::default();
Ok(Self {
@@ -177,17 +177,6 @@ impl SegmentAggregationCollector for TermMissingAgg {
}
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// TODO: forward to `sub_agg` for nested order paths (`missing_agg>metric`).
None
}
}
#[cfg(test)]

View File

@@ -6,7 +6,7 @@ use crate::aggregation::bucket::MAX_NUM_TERMS_FOR_VEC;
use crate::aggregation::BucketId;
use crate::DocId;
/// A buffer for sub-aggregations, storing doc ids per bucket id.
/// A cache for sub-aggregations, storing doc ids per bucket id.
/// Depending on the cardinality of the parent aggregation, we use different
/// storage strategies.
///
@@ -24,21 +24,21 @@ use crate::DocId;
/// aggregations.
/// What this datastructure does in general is to group docs by bucket id.
#[derive(Debug)]
pub(crate) struct BufferedSubAggs<B: SubAggBuffer> {
buffer: B,
pub(crate) struct CachedSubAggs<C: SubAggCache> {
cache: C,
sub_agg_collector: Box<dyn SegmentAggregationCollector>,
num_docs: usize,
}
pub type LowCardBufferedSubAggs = BufferedSubAggs<LowCardSubAggBuffer>;
pub type HighCardBufferedSubAggs = BufferedSubAggs<HighCardSubAggBuffer>;
pub type LowCardCachedSubAggs = CachedSubAggs<LowCardSubAggCache>;
pub type HighCardCachedSubAggs = CachedSubAggs<HighCardSubAggCache>;
const FLUSH_THRESHOLD: usize = 2048;
/// A trait for buffering sub-aggregation doc ids per bucket id.
/// A trait for caching sub-aggregation doc ids per bucket id.
/// Different implementations can be used depending on the cardinality
/// of the parent aggregation.
pub trait SubAggBuffer: Debug {
pub trait SubAggCache: Debug {
fn new() -> Self;
fn push(&mut self, bucket_id: BucketId, doc_id: DocId);
fn flush_local(
@@ -49,22 +49,22 @@ pub trait SubAggBuffer: Debug {
) -> crate::Result<()>;
}
impl<Backend: SubAggBuffer + Debug> BufferedSubAggs<Backend> {
impl<Backend: SubAggCache + Debug> CachedSubAggs<Backend> {
pub fn new(sub_agg: Box<dyn SegmentAggregationCollector>) -> Self {
Self {
buffer: Backend::new(),
cache: Backend::new(),
sub_agg_collector: sub_agg,
num_docs: 0,
}
}
pub fn get_sub_agg_collector(&mut self) -> &mut dyn SegmentAggregationCollector {
&mut *self.sub_agg_collector
pub fn get_sub_agg_collector(&mut self) -> &mut Box<dyn SegmentAggregationCollector> {
&mut self.sub_agg_collector
}
#[inline]
pub fn push(&mut self, bucket_id: BucketId, doc_id: DocId) {
self.buffer.push(bucket_id, doc_id);
self.cache.push(bucket_id, doc_id);
self.num_docs += 1;
}
@@ -75,7 +75,7 @@ impl<Backend: SubAggBuffer + Debug> BufferedSubAggs<Backend> {
agg_data: &mut AggregationsSegmentCtx,
) -> crate::Result<()> {
if self.num_docs >= FLUSH_THRESHOLD {
self.buffer
self.cache
.flush_local(&mut self.sub_agg_collector, agg_data, false)?;
self.num_docs = 0;
}
@@ -85,7 +85,7 @@ impl<Backend: SubAggBuffer + Debug> BufferedSubAggs<Backend> {
/// Note: this _does_ flush the sub aggregations.
pub fn flush(&mut self, agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
if self.num_docs != 0 {
self.buffer
self.cache
.flush_local(&mut self.sub_agg_collector, agg_data, true)?;
self.num_docs = 0;
}
@@ -94,11 +94,11 @@ impl<Backend: SubAggBuffer + Debug> BufferedSubAggs<Backend> {
}
}
/// Number of partitions for high cardinality sub-aggregation buffer.
/// Number of partitions for high cardinality sub-aggregation cache.
const NUM_PARTITIONS: usize = 16;
#[derive(Debug)]
pub(crate) struct HighCardSubAggBuffer {
pub(crate) struct HighCardSubAggCache {
/// This weird partitioning is used to do some cheap grouping on the bucket ids.
/// bucket ids are dense, e.g. when we don't detect the cardinality as low cardinality,
/// but there are just 16 bucket ids, each bucket id will go to its own partition.
@@ -108,7 +108,7 @@ pub(crate) struct HighCardSubAggBuffer {
partitions: Box<[PartitionEntry; NUM_PARTITIONS]>,
}
impl HighCardSubAggBuffer {
impl HighCardSubAggCache {
#[inline]
fn clear(&mut self) {
for partition in self.partitions.iter_mut() {
@@ -131,7 +131,7 @@ impl PartitionEntry {
}
}
impl SubAggBuffer for HighCardSubAggBuffer {
impl SubAggCache for HighCardSubAggCache {
fn new() -> Self {
Self {
partitions: Box::new(core::array::from_fn(|_| PartitionEntry::default())),
@@ -173,14 +173,14 @@ impl SubAggBuffer for HighCardSubAggBuffer {
}
#[derive(Debug)]
pub(crate) struct LowCardSubAggBuffer {
/// Buffer doc ids per bucket for sub-aggregations.
pub(crate) struct LowCardSubAggCache {
/// Cache doc ids per bucket for sub-aggregations.
///
/// The outer Vec is indexed by BucketId.
per_bucket_docs: Vec<Vec<DocId>>,
}
impl LowCardSubAggBuffer {
impl LowCardSubAggCache {
#[inline]
fn clear(&mut self) {
for v in &mut self.per_bucket_docs {
@@ -189,7 +189,7 @@ impl LowCardSubAggBuffer {
}
}
impl SubAggBuffer for LowCardSubAggBuffer {
impl SubAggCache for LowCardSubAggCache {
fn new() -> Self {
Self {
per_bucket_docs: Vec::new(),

View File

@@ -1,6 +1,6 @@
use super::agg_req::Aggregations;
use super::agg_result::AggregationResults;
use super::buffered_sub_aggs::LowCardBufferedSubAggs;
use super::cached_sub_aggs::LowCardCachedSubAggs;
use super::intermediate_agg_result::IntermediateAggregationResults;
use super::AggContextParams;
// group buffering strategy is chosen explicitly by callers; no need to hash-group on the fly.
@@ -66,7 +66,7 @@ impl Collector for DistributedAggregationCollector {
fn for_segment(
&self,
segment_local_id: crate::SegmentOrdinal,
reader: &crate::SegmentReader,
reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
AggregationSegmentCollector::from_agg_req_and_reader(
&self.agg,
@@ -96,7 +96,7 @@ impl Collector for AggregationCollector {
fn for_segment(
&self,
segment_local_id: crate::SegmentOrdinal,
reader: &crate::SegmentReader,
reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
AggregationSegmentCollector::from_agg_req_and_reader(
&self.agg,
@@ -136,7 +136,7 @@ fn merge_fruits(
/// `AggregationSegmentCollector` does the aggregation collection on a segment.
pub struct AggregationSegmentCollector {
aggs_with_accessor: AggregationsSegmentCtx,
agg_collector: LowCardBufferedSubAggs,
agg_collector: LowCardCachedSubAggs,
error: Option<TantivyError>,
}
@@ -145,14 +145,14 @@ impl AggregationSegmentCollector {
/// reader. Also includes validation, e.g. checking field types and existence.
pub fn from_agg_req_and_reader(
agg: &Aggregations,
reader: &SegmentReader,
reader: &dyn SegmentReader,
segment_ordinal: SegmentOrdinal,
context: &AggContextParams,
) -> crate::Result<Self> {
let mut agg_data =
build_aggregations_data_from_req(agg, reader, segment_ordinal, context.clone())?;
let mut result =
LowCardBufferedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?);
LowCardCachedSubAggs::new(build_segment_agg_collectors_root(&mut agg_data)?);
result
.get_sub_agg_collector()
.prepare_max_bucket(0, &agg_data)?; // prepare for bucket zero

View File

@@ -15,9 +15,8 @@ use serde::{Deserialize, Serialize};
use super::agg_req::{Aggregation, AggregationVariants, Aggregations};
use super::agg_result::{AggregationResult, BucketResult, MetricResult, RangeBucketEntry};
use super::bucket::{
composite_intermediate_key_ordering, cut_off_buckets, get_agg_name_and_property,
intermediate_histogram_buckets_to_final_buckets, CompositeAggregation, GetDocCount,
MissingOrder, Order, OrderTarget, RangeAggregation, TermsAggregation,
cut_off_buckets, get_agg_name_and_property, intermediate_histogram_buckets_to_final_buckets,
GetDocCount, Order, OrderTarget, RangeAggregation, TermsAggregation,
};
use super::metric::{
IntermediateAverage, IntermediateCount, IntermediateExtendedStats, IntermediateMax,
@@ -26,7 +25,7 @@ use super::metric::{
use super::segment_agg_result::AggregationLimitsGuard;
use super::{format_date, AggregationError, Key, SerializedKey};
use crate::aggregation::agg_result::{
AggregationResults, BucketEntries, BucketEntry, CompositeBucketEntry, FilterBucketResult,
AggregationResults, BucketEntries, BucketEntry, FilterBucketResult,
};
use crate::aggregation::bucket::TermsAggregationInternal;
use crate::aggregation::metric::CardinalityCollector;
@@ -281,11 +280,6 @@ pub(crate) fn empty_from_req(req: &Aggregation) -> IntermediateAggregationResult
doc_count: 0,
sub_aggregations: IntermediateAggregationResults::default(),
}),
Composite(_) => {
IntermediateAggregationResult::Bucket(IntermediateBucketResult::Composite {
buckets: IntermediateCompositeBucketResult::default(),
})
}
}
}
@@ -479,11 +473,6 @@ pub enum IntermediateBucketResult {
/// Sub-aggregation results
sub_aggregations: IntermediateAggregationResults,
},
/// Composite aggregation
Composite {
/// The composite buckets
buckets: IntermediateCompositeBucketResult,
},
}
impl IntermediateBucketResult {
@@ -579,13 +568,6 @@ impl IntermediateBucketResult {
sub_aggregations: final_sub_aggregations,
}))
}
IntermediateBucketResult::Composite { buckets } => {
let composite_req = req
.agg
.as_composite()
.expect("unexpected aggregation, expected composite aggregation");
buckets.into_final_result(composite_req, req.sub_aggregation(), limits)
}
}
}
@@ -652,16 +634,6 @@ impl IntermediateBucketResult {
*doc_count_left += doc_count_right;
sub_aggs_left.merge_fruits(sub_aggs_right)?;
}
(
IntermediateBucketResult::Composite {
buckets: composite_left,
},
IntermediateBucketResult::Composite {
buckets: composite_right,
},
) => {
composite_left.merge_fruits(composite_right)?;
}
(IntermediateBucketResult::Range(_), _) => {
panic!("try merge on different types")
}
@@ -674,9 +646,6 @@ impl IntermediateBucketResult {
(IntermediateBucketResult::Filter { .. }, _) => {
panic!("try merge on different types")
}
(IntermediateBucketResult::Composite { .. }, _) => {
panic!("try merge on different types")
}
}
Ok(())
}
@@ -945,172 +914,6 @@ impl MergeFruits for IntermediateHistogramBucketEntry {
}
}
/// Entry for the composite bucket.
pub type IntermediateCompositeBucketEntry = IntermediateTermBucketEntry;
/// The fully typed key for composite aggregation
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum CompositeIntermediateKey {
/// Bool key
Bool(bool),
/// String key
Str(String),
/// Float key
F64(f64),
/// Signed integer key
I64(i64),
/// Unsigned integer key
U64(u64),
/// DateTime key, nanoseconds since epoch
DateTime(i64),
/// IP Address key
IpAddr(Ipv6Addr),
/// Missing value key
Null,
}
impl Eq for CompositeIntermediateKey {}
impl std::hash::Hash for CompositeIntermediateKey {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
match self {
CompositeIntermediateKey::Bool(val) => val.hash(state),
CompositeIntermediateKey::Str(text) => text.hash(state),
CompositeIntermediateKey::F64(val) => val.to_bits().hash(state),
CompositeIntermediateKey::U64(val) => val.hash(state),
CompositeIntermediateKey::I64(val) => val.hash(state),
CompositeIntermediateKey::DateTime(val) => val.hash(state),
CompositeIntermediateKey::IpAddr(val) => val.hash(state),
CompositeIntermediateKey::Null => {}
}
}
}
/// Composite aggregation page.
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IntermediateCompositeBucketResult {
pub(crate) entries: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry>,
pub(crate) target_size: u32,
pub(crate) orders: Vec<(Order, MissingOrder)>,
}
impl IntermediateCompositeBucketResult {
pub(crate) fn into_final_result(
self,
req: &CompositeAggregation,
sub_aggregation_req: &Aggregations,
limits: &mut AggregationLimitsGuard,
) -> crate::Result<BucketResult> {
let trimmed_entry_vec =
trim_composite_buckets(self.entries, &self.orders, self.target_size)?;
let after_key = trimmed_entry_vec
.last()
.map(|bucket| {
let (intermediate_key, _entry) = bucket;
intermediate_key
.iter()
.enumerate()
.map(|(idx, intermediate_key)| {
let source = &req.sources[idx];
(source.name().to_string(), intermediate_key.clone().into())
})
.collect()
})
.unwrap_or_default();
let buckets = trimmed_entry_vec
.into_iter()
.map(|(intermediate_key, entry)| {
let key = intermediate_key
.into_iter()
.enumerate()
.map(|(idx, intermediate_key)| {
let source = &req.sources[idx];
(source.name().to_string(), intermediate_key.into())
})
.collect();
Ok(CompositeBucketEntry {
key,
doc_count: entry.doc_count as u64,
sub_aggregation: entry
.sub_aggregation
.into_final_result_internal(sub_aggregation_req, limits)?,
})
})
.collect::<crate::Result<Vec<_>>>()?;
Ok(BucketResult::Composite { after_key, buckets })
}
fn merge_fruits(&mut self, other: IntermediateCompositeBucketResult) -> crate::Result<()> {
merge_maps(&mut self.entries, other.entries)?;
if self.entries.len() as u32 > 2 * self.target_size {
self.trim()?;
}
Ok(())
}
/// Trim the composite buckets to the target size, according to the ordering.
pub(crate) fn trim(&mut self) -> crate::Result<()> {
if self.entries.len() as u32 <= self.target_size {
return Ok(());
}
let sorted_entries = trim_composite_buckets(
std::mem::take(&mut self.entries),
&self.orders,
self.target_size,
)?;
self.entries = sorted_entries.into_iter().collect();
Ok(())
}
}
fn trim_composite_buckets(
entries: FxHashMap<Vec<CompositeIntermediateKey>, IntermediateCompositeBucketEntry>,
orders: &[(Order, MissingOrder)],
target_size: u32,
) -> crate::Result<
Vec<(
Vec<CompositeIntermediateKey>,
IntermediateCompositeBucketEntry,
)>,
> {
let mut entries: Vec<_> = entries.into_iter().collect();
let mut sort_error: Option<TantivyError> = None;
entries.sort_by(|(left_key, _), (right_key, _)| {
if sort_error.is_some() {
return Ordering::Equal;
}
for idx in 0..orders.len() {
match composite_intermediate_key_ordering(
&left_key[idx],
&right_key[idx],
orders[idx].0,
orders[idx].1,
) {
Ok(ordering) if ordering != Ordering::Equal => return ordering,
Ok(_) => continue,
Err(err) => {
sort_error = Some(err);
break;
}
}
}
Ordering::Equal
});
if let Some(err) = sort_error {
return Err(err);
}
entries.truncate(target_size as usize);
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;

File diff suppressed because it is too large Load Diff

View File

@@ -399,26 +399,6 @@ impl SegmentAggregationCollector for SegmentExtendedStatsCollector {
}
Ok(())
}
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
if self.name != sub_agg_name {
return None;
}
let extended = self.buckets.get(bucket_id as usize)?;
// Finalize is a pure read of accumulators — calling it here for the cutoff sort
// doesn't disturb the eventual intermediate result.
extended
.finalize()
.get_value(sub_agg_property)
.ok()
.flatten()
}
}
#[cfg(test)]

View File

@@ -107,9 +107,10 @@ pub enum PercentileValues {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// The entry when requesting percentiles with keyed: false
pub struct PercentileValuesVecEntry {
/// The percentile key (e.g. 1.0, 5.0, 25.0).
/// Percentile
pub key: f64,
/// The percentile value. `NaN` when there are no values.
/// Value at the percentile
pub value: f64,
}

View File

@@ -222,12 +222,6 @@ impl PercentilesCollector {
self.sketch.add(val);
}
/// Encode the underlying DDSketch to Java-compatible binary format
/// for cross-language serialization with Java consumers.
pub fn to_sketch_bytes(&self) -> Vec<u8> {
self.sketch.to_java_bytes()
}
pub(crate) fn merge_fruits(&mut self, right: PercentilesCollector) -> crate::Result<()> {
self.sketch.merge(&right.sketch).map_err(|err| {
TantivyError::AggregationError(AggregationError::InternalError(format!(
@@ -312,26 +306,6 @@ impl SegmentAggregationCollector for SegmentPercentilesCollector {
}
Ok(())
}
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
if agg_data.get_metric_req_data(self.accessor_idx).name != sub_agg_name {
return None;
}
let percentile: f64 = sub_agg_property.parse().ok()?;
if !(0.0..=100.0).contains(&percentile) {
return None;
}
let bucket = self.buckets.get(bucket_id as usize)?;
// DDSketch.quantile is a pure read; calling it here for the cutoff sort does
// not affect the intermediate state used for the final result.
bucket.sketch.quantile(percentile / 100.0).ok().flatten()
}
}
#[cfg(test)]
@@ -351,7 +325,7 @@ mod tests {
use crate::aggregation::AggregationCollector;
use crate::query::AllQuery;
use crate::schema::{Schema, FAST};
use crate::{assert_nearly_equals, Index};
use crate::Index;
#[test]
fn test_aggregation_percentiles_empty_index() -> crate::Result<()> {
@@ -634,16 +608,12 @@ mod tests {
let res = exec_request_with_query(agg_req, &index, None)?;
assert_eq!(res["range_with_stats"]["buckets"][0]["doc_count"], 3);
assert_nearly_equals!(
res["range_with_stats"]["buckets"][0]["percentiles"]["values"]["1.0"]
.as_f64()
.unwrap(),
assert_eq!(
res["range_with_stats"]["buckets"][0]["percentiles"]["values"]["1.0"],
5.0028295751107414
);
assert_nearly_equals!(
res["range_with_stats"]["buckets"][0]["percentiles"]["values"]["99.0"]
.as_f64()
.unwrap(),
assert_eq!(
res["range_with_stats"]["buckets"][0]["percentiles"]["values"]["99.0"],
10.07469668951144
);
@@ -689,14 +659,8 @@ mod tests {
let res = exec_request_with_query(agg_req, &index, None)?;
assert_nearly_equals!(
res["percentiles"]["values"]["1.0"].as_f64().unwrap(),
5.0028295751107414
);
assert_nearly_equals!(
res["percentiles"]["values"]["99.0"].as_f64().unwrap(),
10.07469668951144
);
assert_eq!(res["percentiles"]["values"]["1.0"], 5.0028295751107414);
assert_eq!(res["percentiles"]["values"]["99.0"], 10.07469668951144);
Ok(())
}

View File

@@ -321,40 +321,6 @@ impl<const COLUMN_TYPE_ID: u8> SegmentAggregationCollector
}
Ok(())
}
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
if self.name != sub_agg_name {
return None;
}
let stats = self.buckets.get(bucket_id as usize)?;
// The property depends on what we're collecting:
// - StatsType::Stats exposes count/sum/min/max/avg via dotted property.
// - Single-value kinds (Sum/Count/Min/Max/Average) expect an empty property and return
// the value they were configured to collect.
let prop = match self.collecting_for {
StatsType::Stats if !sub_agg_property.is_empty() => sub_agg_property,
StatsType::Sum if sub_agg_property.is_empty() => "sum",
StatsType::Count if sub_agg_property.is_empty() => "count",
StatsType::Max if sub_agg_property.is_empty() => "max",
StatsType::Min if sub_agg_property.is_empty() => "min",
StatsType::Average if sub_agg_property.is_empty() => "avg",
_ => return None,
};
match prop {
"count" => Some(stats.count as f64),
"sum" => Some(stats.sum),
"min" if stats.count > 0 => Some(stats.min),
"max" if stats.count > 0 => Some(stats.max),
"avg" if stats.count > 0 => Some(stats.sum / stats.count as f64),
_ => None,
}
}
}
#[inline]

View File

@@ -644,17 +644,6 @@ impl SegmentAggregationCollector for TopHitsSegmentCollector {
);
Ok(())
}
fn compute_metric_value(
&self,
_bucket_id: BucketId,
_sub_agg_name: &str,
_sub_agg_property: &str,
_agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
// top_hits is not a numeric metric and cannot be used as an order target.
None
}
}
#[cfg(test)]

View File

@@ -133,7 +133,7 @@ mod agg_limits;
pub mod agg_req;
pub mod agg_result;
pub mod bucket;
pub(crate) mod buffered_sub_aggs;
pub(crate) mod cached_sub_aggs;
mod collector;
mod date;
mod error;

View File

@@ -76,31 +76,6 @@ pub trait SegmentAggregationCollector: Debug {
fn flush(&mut self, _agg_data: &mut AggregationsSegmentCtx) -> crate::Result<()> {
Ok(())
}
/// Compute the segment-level metric value of the named direct-child metric for `bucket_id`.
///
/// Used by parent term aggs that order by a sub-aggregation: the parent sorts on
/// this value and cuts off at segment time, matching the approximation tradeoff
/// Elasticsearch makes for any sub-agg ordering.
///
/// `sub_agg_property` is the dotted suffix (e.g. `"sum"` in `mystats.sum`); empty when
/// the metric is a single-value kind such as cardinality.
///
/// Returns `None` only on name mismatch, unknown property, or empty bucket. Implementations
/// may finalize their per-bucket state (e.g. compute a percentile from a sketch); calls
/// must be idempotent so the final intermediate result is unaffected.
///
/// No default impl on purpose: every collector must decide explicitly whether it
/// produces a metric value, forwards into children (single-bucket aggs), or rejects
/// the lookup. A silent `None` default would let a parent term agg's cutoff sort all
/// buckets to the same key and drop arbitrary winners.
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
agg_data: &AggregationsSegmentCtx,
) -> Option<f64>;
}
#[derive(Default)]
@@ -162,21 +137,4 @@ impl SegmentAggregationCollector for GenericSegmentAggregationResultsCollector {
}
Ok(())
}
fn compute_metric_value(
&self,
bucket_id: BucketId,
sub_agg_name: &str,
sub_agg_property: &str,
agg_data: &AggregationsSegmentCtx,
) -> Option<f64> {
for agg in &self.aggs {
if let Some(value) =
agg.compute_metric_value(bucket_id, sub_agg_name, sub_agg_property, agg_data)
{
return Some(value);
}
}
None
}
}

View File

@@ -1,34 +1,29 @@
/// Codec specific to postings data.
pub mod postings;
/// Codec specific to positions data.
pub mod positions;
/// Standard tantivy codec. This is the codec you use by default.
pub mod standard;
use std::io;
use std::sync::Arc;
pub use standard::StandardCodec;
use crate::codec::positions::PositionsCodec;
use crate::codec::postings::PostingsCodec;
use crate::fieldnorm::FieldNormReader;
use crate::postings::{Postings, TermInfo};
use crate::directory::Directory;
use crate::fastfield::AliveBitSet;
use crate::query::score_combiner::DoNothingCombiner;
use crate::query::term_query::TermScorer;
use crate::query::{box_scorer, Bm25Weight, BufferedUnionScorer, Scorer, SumCombiner};
use crate::schema::IndexRecordOption;
use crate::{DocId, InvertedIndexReader, Score};
use crate::query::{box_scorer, BufferedUnionScorer, Scorer, SumCombiner};
use crate::schema::Schema;
use crate::{DocId, Score, SegmentMeta, SegmentReader, TantivySegmentReader};
/// Codecs describes how data is layed out on disk.
///
/// For the moment, only postings codec can be custom.
pub trait Codec: Clone + std::fmt::Debug + Send + Sync + 'static {
/// The specific postings codec used by this codec.
/// The specific postings type used by this codec.
type PostingsCodec: PostingsCodec;
/// The specific positions codec used by this codec.
type PositionsCodec: PositionsCodec;
/// ID of the codec. It should be unique to your codec.
/// Make it human-readable, descriptive, short and unique.
const ID: &'static str;
@@ -42,60 +37,45 @@ pub trait Codec: Clone + std::fmt::Debug + Send + Sync + 'static {
/// Returns the postings codec.
fn postings_codec(&self) -> &Self::PostingsCodec;
/// Returns the positions codec.
fn positions_codec(&self) -> &Self::PositionsCodec;
/// Loads postings using the codec's concrete postings type.
fn load_postings_typed(
&self,
reader: &dyn crate::index::InvertedIndexReader,
term_info: &crate::postings::TermInfo,
option: crate::schema::IndexRecordOption,
) -> std::io::Result<<Self::PostingsCodec as crate::codec::postings::PostingsCodec>::Postings>
{
let postings_data = reader.read_raw_postings_data(term_info, option)?;
self.postings_codec()
.load_postings(term_info.doc_freq, postings_data)
}
/// Opens a segment reader using this codec.
///
/// Override this if your codec uses a custom segment reader implementation.
fn open_segment_reader(
&self,
directory: &dyn Directory,
segment_meta: &SegmentMeta,
schema: Schema,
custom_bitset: Option<AliveBitSet>,
) -> crate::Result<Arc<dyn SegmentReader>> {
let codec: Arc<dyn ObjectSafeCodec> = Arc::new(self.clone());
let reader = TantivySegmentReader::open_with_custom_alive_set_from_directory(
directory,
segment_meta,
schema,
codec,
custom_bitset,
)?;
Ok(Arc::new(reader))
}
}
/// Object-safe codec is a Codec that can be used in a trait object.
///
/// The point of it is to offer a way to use a codec without a proliferation of generics.
pub trait ObjectSafeCodec: 'static + Send + Sync {
/// Loads a type-erased Postings object for the given term.
///
/// If the schema used to build the index did not provide enough
/// information to match the requested `option`, a Postings is still
/// returned in a best-effort manner.
fn load_postings_type_erased(
&self,
term_info: &TermInfo,
option: IndexRecordOption,
inverted_index_reader: &InvertedIndexReader,
) -> io::Result<Box<dyn Postings>>;
/// Loads a type-erased TermScorer object for the given term.
///
/// If the schema used to build the index did not provide enough
/// information to match the requested `option`, a TermScorer is still
/// returned in a best-effort manner.
///
/// The point of this contraption is that the return TermScorer is backed,
/// not by Box<dyn Postings> but by the codec's concrete Postings type.
fn load_term_scorer_type_erased(
&self,
term_info: &TermInfo,
option: IndexRecordOption,
inverted_index_reader: &InvertedIndexReader,
fieldnorm_reader: FieldNormReader,
similarity_weight: Bm25Weight,
) -> io::Result<Box<dyn Scorer>>;
/// Loads a type-erased PhraseScorer object for the given term.
///
/// If the schema used to build the index did not provide enough
/// information to match the requested `option`, a TermScorer is still
/// returned in a best-effort manner.
///
/// The point of this contraption is that the return PhraseScorer is backed,
/// not by Box<dyn Postings> but by the codec's concrete Postings type.
fn new_phrase_scorer_type_erased(
&self,
term_infos: &[(usize, TermInfo)],
similarity_weight: Option<Bm25Weight>,
fieldnorm_reader: FieldNormReader,
slop: u32,
inverted_index_reader: &InvertedIndexReader,
) -> io::Result<Box<dyn Scorer>>;
/// Performs a for_each_pruning operation on the given scorer.
///
/// The function will go through matching documents and call the callback
@@ -124,53 +104,6 @@ pub trait ObjectSafeCodec: 'static + Send + Sync {
}
impl<TCodec: Codec> ObjectSafeCodec for TCodec {
fn load_postings_type_erased(
&self,
term_info: &TermInfo,
option: IndexRecordOption,
inverted_index_reader: &InvertedIndexReader,
) -> io::Result<Box<dyn Postings>> {
let postings = inverted_index_reader
.read_postings_from_terminfo_specialized(term_info, option, self)?;
Ok(Box::new(postings))
}
fn load_term_scorer_type_erased(
&self,
term_info: &TermInfo,
option: IndexRecordOption,
inverted_index_reader: &InvertedIndexReader,
fieldnorm_reader: FieldNormReader,
similarity_weight: Bm25Weight,
) -> io::Result<Box<dyn Scorer>> {
let scorer = inverted_index_reader.new_term_scorer_specialized(
term_info,
option,
fieldnorm_reader,
similarity_weight,
self,
)?;
Ok(box_scorer(scorer))
}
fn new_phrase_scorer_type_erased(
&self,
term_infos: &[(usize, TermInfo)],
similarity_weight: Option<Bm25Weight>,
fieldnorm_reader: FieldNormReader,
slop: u32,
inverted_index_reader: &InvertedIndexReader,
) -> io::Result<Box<dyn Scorer>> {
let scorer = inverted_index_reader.new_phrase_scorer_type_specialized(
term_infos,
similarity_weight,
fieldnorm_reader,
slop,
self,
)?;
Ok(box_scorer(scorer))
}
fn build_union_scorer_with_sum_combiner(
&self,
scorers: Vec<Box<dyn Scorer>>,

View File

@@ -1,49 +0,0 @@
use std::io;
use common::OwnedBytes;
/// Codec for the positions file.
pub trait PositionsCodec: Send + Sync + 'static {
/// The serializer type created by this codec.
type Serializer<W: io::Write>: PositionsSerializer<W>;
/// The reader type created by this codec.
type Reader: PositionsReader;
/// Creates a new positions serializer writing into `writer`.
fn new_serializer<W: io::Write>(&self, writer: W) -> Self::Serializer<W>;
/// Opens a positions reader from the given raw byte slice.
fn open_reader(&self, data: OwnedBytes) -> io::Result<Self::Reader>;
}
/// Serializes delta-encoded positions for all terms in a field.
///
/// A single serializer is reused across all terms. Clients must call
/// `close_term` after each term, then `close` once when the field is done.
pub trait PositionsSerializer<W: io::Write> {
/// Returns the total number of bytes written since this serializer was created.
fn written_bytes(&self) -> u64;
/// Appends delta-encoded positions for the current document.
fn write_positions_delta(&mut self, positions_delta: &[u32]);
/// Finalizes and flushes positions data for the current term.
fn close_term(&mut self) -> io::Result<()>;
/// Flushes the underlying writer. Must be called once after all terms are done.
fn close(self) -> io::Result<()>;
}
/// Reads delta-encoded positions from a byte slice.
pub trait PositionsReader: Send + 'static {
/// Fills `output` with delta-encoded positions starting at `offset`.
///
/// Hidden contract: offset values should be non-decreasing for best performance;
/// passing a lower offset resets internal state and incurs extra work.
fn read(&mut self, offset: u64, output: &mut [u32]);
/// Returns a heap-allocated clone of this reader.
///
/// Needed to clone `SegmentPostings`, which owns a boxed reader.
fn clone_box(&self) -> Box<dyn PositionsReader>;
}

View File

@@ -51,7 +51,7 @@ fn block_max_was_too_low_advance_one_scorer<TPostings: PostingsWithBlockMax>(
scorers: &mut [TermScorerWithMaxScore<TPostings>],
pivot_len: usize,
) {
debug_assert!(scorers.iter().map(|scorer| scorer.doc()).is_sorted());
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
let mut scorer_to_seek = pivot_len - 1;
let mut global_max_score = scorers[scorer_to_seek].max_score;
let mut doc_to_seek_after = scorers[scorer_to_seek].last_doc_in_block();
@@ -77,7 +77,7 @@ fn block_max_was_too_low_advance_one_scorer<TPostings: PostingsWithBlockMax>(
scorers[scorer_to_seek].seek(doc_to_seek_after);
restore_ordering(scorers, scorer_to_seek);
debug_assert!(scorers.iter().map(|scorer| scorer.doc()).is_sorted());
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
}
// Given a list of term_scorers and a `ord` and assuming that `term_scorers[ord]` is sorted
@@ -94,7 +94,7 @@ fn restore_ordering<TPostings: PostingsWithBlockMax>(
}
term_scorers.swap(i, i - 1);
}
debug_assert!(term_scorers.iter().map(|scorer| scorer.doc()).is_sorted());
debug_assert!(is_sorted(term_scorers.iter().map(|scorer| scorer.doc())));
}
// Attempts to advance all term_scorers between `&term_scorers[0..before_len]` to the pivot.
@@ -158,21 +158,17 @@ pub fn block_wand<TPostings: PostingsWithBlockMax>(
mut threshold: Score,
callback: &mut dyn FnMut(u32, Score) -> Score,
) {
scorers.retain(|scorer| scorer.doc() < TERMINATED);
if scorers.len() == 1 {
let scorer = scorers.pop().unwrap();
return block_wand_single_scorer(scorer, threshold, callback);
}
let mut scorers: Vec<TermScorerWithMaxScore<TPostings>> = scorers
.iter_mut()
.map(TermScorerWithMaxScore::from)
.collect();
// At this point we need to ensure that the scorers are sorted!
scorers.sort_by_key(|scorer| scorer.doc());
// At this point we need to ensure that the scorers are sorted!
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
while let Some((before_pivot_len, pivot_len, pivot_doc)) =
find_pivot_doc(&scorers[..], threshold)
{
debug_assert!(scorers.iter().map(|scorer| scorer.doc()).is_sorted());
debug_assert!(is_sorted(scorers.iter().map(|scorer| scorer.doc())));
debug_assert_ne!(pivot_doc, TERMINATED);
debug_assert!(before_pivot_len < pivot_len);
@@ -298,6 +294,18 @@ impl<TPostings: PostingsWithBlockMax> DerefMut for TermScorerWithMaxScore<'_, TP
}
}
fn is_sorted<I: Iterator<Item = DocId>>(mut it: I) -> bool {
if let Some(first) = it.next() {
let mut prev = first;
for doc in it {
if doc < prev {
return false;
}
prev = doc;
}
}
true
}
#[cfg(test)]
mod tests {
use std::cmp::Ordering;

View File

@@ -1,51 +1,25 @@
use std::io;
/// Block-max WAND algorithm.
pub mod block_wand;
use std::io;
use common::OwnedBytes;
use crate::codec::positions::PositionsReader;
use crate::fieldnorm::FieldNormReader;
use crate::postings::Postings;
use crate::query::{Bm25Weight, Scorer};
use crate::schema::IndexRecordOption;
use crate::{DocId, Score};
/// Postings codec.
/// Postings codec (read path).
pub trait PostingsCodec: Send + Sync + 'static {
/// Serializer type for the postings codec.
type PostingsSerializer: PostingsSerializer;
/// Postings type for the postings codec.
type Postings: Postings + Clone;
/// Creates a new postings serializer.
fn new_serializer(
&self,
avg_fieldnorm: Score,
mode: IndexRecordOption,
fieldnorm_reader: Option<FieldNormReader>,
) -> Self::PostingsSerializer;
/// Loads postings
///
/// Record option is the option that was passed at indexing time.
/// Requested option is the option that is requested.
///
/// For instance, we may have term_freq in the posting list
/// but we can skip decompressing as we read the posting list.
///
/// If record option does not support the requested option,
/// this method does NOT return an error and will in fact restrict
/// requested_option to what is available.
///
/// `position_reader` is `Some` iff `requested_option` includes positions.
/// It is already opened by the caller via the codec's `PositionsCodec`.
/// Load postings from raw bytes and metadata.
fn load_postings(
&self,
doc_freq: u32,
postings_data: OwnedBytes,
record_option: IndexRecordOption,
requested_option: IndexRecordOption,
position_reader: Option<Box<dyn PositionsReader>>,
postings_data: RawPostingsData,
) -> io::Result<Self::Postings>;
/// If your codec supports different ways to accelerate `for_each_pruning` that's
@@ -67,55 +41,17 @@ pub trait PostingsCodec: Send + Sync + 'static {
}
}
/// A postings serializer is a listener that is in charge of serializing postings
///
/// IO is done only once per postings, once all of the data has been received.
/// A serializer will therefore contain internal buffers.
///
/// A serializer is created once and recycled for all postings.
///
/// Clients should use PostingsSerializer as follows.
/// ```text
/// // First postings list
/// serializer.new_term(2, true);
/// serializer.write_doc(2, 1);
/// serializer.write_doc(6, 2);
/// serializer.close_term(3, &mut wrt)?;
/// // Second postings list
/// serializer.new_term(1, true);
/// serializer.write_doc(3, 1);
/// serializer.close_term(1, &mut wrt)?;
/// ```
pub trait PostingsSerializer {
/// The term_doc_freq here is the number of documents
/// in the postings lists.
///
/// It can be used to compute the idf that will be used for the
/// blockmax parameters.
///
/// If not available (e.g. if we do not collect `term_frequencies`
/// blockwand is disabled), the term_doc_freq passed will be set 0.
fn new_term(&mut self, term_doc_freq: u32, record_term_freq: bool);
/// Codec-specific per-term payload.
///
/// It is supplied right after `new_term` and before any `write_doc`, so the
/// codec can let it influence how the postings list is encoded.
///
/// Hidden contract: `new_term` MUST reset any per-term payload state to its
/// default. This method is only called for terms that actually have a
/// payload registered, so a codec cannot rely on it being called for every
/// term.
///
/// The default implementation ignores the payload.
fn set_term_payload(&mut self, _payload: &dyn std::any::Any) {}
/// Records a new document id for the current term.
/// The serializer may ignore it.
fn write_doc(&mut self, doc_id: DocId, term_freq: u32);
/// Closes the current term and writes the postings list associated.
fn close_term(&mut self, doc_freq: u32, wrt: &mut impl io::Write) -> io::Result<()>;
/// Raw postings bytes and metadata read from storage.
#[derive(Debug, Clone)]
pub struct RawPostingsData {
/// Raw postings bytes for the term.
pub postings_data: OwnedBytes,
/// Raw positions bytes for the term, if positions are available.
pub positions_data: Option<OwnedBytes>,
/// Record option of the indexed field.
pub record_option: IndexRecordOption,
/// Effective record option after downgrading to the indexed field capability.
pub effective_option: IndexRecordOption,
}
/// A light complement interface to Postings to allow block-max wand acceleration.

View File

@@ -1,22 +1,17 @@
use serde::{Deserialize, Serialize};
use crate::codec::standard::positions::StandardPositionsCodec;
use crate::codec::standard::postings::StandardPostingsCodec;
use crate::codec::Codec;
/// Tantivy's default postings codec.
pub mod postings;
/// Tantivy's default positions codec.
pub mod positions;
/// Tantivy's default codec.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct StandardCodec;
impl Codec for StandardCodec {
type PostingsCodec = StandardPostingsCodec;
type PositionsCodec = StandardPositionsCodec;
const ID: &'static str = "tantivy-default";
@@ -37,8 +32,4 @@ impl Codec for StandardCodec {
fn postings_codec(&self) -> &Self::PostingsCodec {
&StandardPostingsCodec
}
fn positions_codec(&self) -> &Self::PositionsCodec {
&StandardPositionsCodec
}
}

View File

@@ -1,50 +0,0 @@
use std::io;
use common::OwnedBytes;
use crate::codec::positions::{PositionsCodec, PositionsReader, PositionsSerializer};
use crate::positions::{PositionReader, PositionSerializer};
/// The default positions codec for tantivy.
pub struct StandardPositionsCodec;
impl PositionsCodec for StandardPositionsCodec {
type Serializer<W: io::Write> = PositionSerializer<W>;
type Reader = PositionReader;
fn new_serializer<W: io::Write>(&self, writer: W) -> Self::Serializer<W> {
PositionSerializer::new(writer)
}
fn open_reader(&self, data: OwnedBytes) -> io::Result<Self::Reader> {
PositionReader::open(data)
}
}
impl<W: io::Write> PositionsSerializer<W> for PositionSerializer<W> {
fn written_bytes(&self) -> u64 {
PositionSerializer::written_bytes(self)
}
fn write_positions_delta(&mut self, positions_delta: &[u32]) {
PositionSerializer::write_positions_delta(self, positions_delta);
}
fn close_term(&mut self) -> io::Result<()> {
PositionSerializer::close_term(self)
}
fn close(self) -> io::Result<()> {
PositionSerializer::close(self)
}
}
impl PositionsReader for PositionReader {
fn read(&mut self, offset: u64, output: &mut [u32]) {
PositionReader::read(self, offset, output);
}
fn clone_box(&self) -> Box<dyn PositionsReader> {
Box::new(self.clone())
}
}

View File

@@ -1,50 +0,0 @@
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::DocId;
pub struct Block {
doc_ids: [DocId; COMPRESSION_BLOCK_SIZE],
term_freqs: [u32; COMPRESSION_BLOCK_SIZE],
len: usize,
}
impl Block {
pub fn new() -> Self {
Block {
doc_ids: [0u32; COMPRESSION_BLOCK_SIZE],
term_freqs: [0u32; COMPRESSION_BLOCK_SIZE],
len: 0,
}
}
pub fn doc_ids(&self) -> &[DocId] {
&self.doc_ids[..self.len]
}
pub fn term_freqs(&self) -> &[u32] {
&self.term_freqs[..self.len]
}
pub fn clear(&mut self) {
self.len = 0;
}
pub fn append_doc(&mut self, doc: DocId, term_freq: u32) {
let len = self.len;
self.doc_ids[len] = doc;
self.term_freqs[len] = term_freq;
self.len = len + 1;
}
pub fn is_full(&self) -> bool {
self.len == COMPRESSION_BLOCK_SIZE
}
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn last_doc(&self) -> DocId {
assert_eq!(self.len, COMPRESSION_BLOCK_SIZE);
self.doc_ids[COMPRESSION_BLOCK_SIZE - 1]
}
}

View File

@@ -2,10 +2,10 @@ use std::io;
use common::{OwnedBytes, VInt};
use crate::codec::standard::postings::skip::{BlockInfo, SkipReader};
use crate::codec::standard::postings::FreqReadingOption;
use crate::fieldnorm::FieldNormReader;
use crate::postings::compression::{BlockDecoder, VIntDecoder as _, COMPRESSION_BLOCK_SIZE};
use crate::postings::skip::{BlockInfo, SkipReader};
use crate::query::Bm25Weight;
use crate::schema::IndexRecordOption;
use crate::{DocId, Score, TERMINATED};
@@ -337,18 +337,17 @@ mod tests {
use common::OwnedBytes;
use super::BlockSegmentPostings;
use crate::codec::postings::PostingsSerializer;
use crate::codec::standard::postings::segment_postings::SegmentPostings;
use crate::codec::standard::postings::StandardPostingsSerializer;
use crate::docset::{DocSet, TERMINATED};
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::postings::serializer::PostingsSerializer;
use crate::schema::IndexRecordOption;
#[cfg(test)]
fn build_block_postings(docs: &[u32]) -> BlockSegmentPostings {
let doc_freq = docs.len() as u32;
let mut postings_serializer =
StandardPostingsSerializer::new(1.0f32, IndexRecordOption::Basic, None);
PostingsSerializer::new(1.0f32, IndexRecordOption::Basic, None);
postings_serializer.new_term(docs.len() as u32, false);
for doc in docs {
postings_serializer.write_doc(*doc, 1u32);

View File

@@ -1,24 +1,20 @@
use std::io;
use crate::codec::positions::PositionsReader;
use common::BitSet;
use crate::codec::postings::block_wand::{block_wand, block_wand_single_scorer};
use crate::codec::postings::PostingsCodec;
use crate::codec::postings::{PostingsCodec, RawPostingsData};
use crate::codec::standard::postings::block_segment_postings::BlockSegmentPostings;
pub use crate::codec::standard::postings::segment_postings::SegmentPostings;
use crate::fieldnorm::FieldNormReader;
use crate::positions::PositionReader;
use crate::query::term_query::TermScorer;
use crate::query::{BufferedUnionScorer, Scorer, SumCombiner};
use crate::schema::IndexRecordOption;
use crate::{DocSet as _, Score, TERMINATED};
mod block;
mod block_segment_postings;
mod segment_postings;
mod skip;
mod standard_postings_serializer;
pub use segment_postings::SegmentPostings as StandardPostings;
pub use standard_postings_serializer::StandardPostingsSerializer;
/// The default postings codec for tantivy.
pub struct StandardPostingsCodec;
@@ -32,34 +28,14 @@ pub(crate) enum FreqReadingOption {
}
impl PostingsCodec for StandardPostingsCodec {
type PostingsSerializer = StandardPostingsSerializer;
type Postings = SegmentPostings;
fn new_serializer(
&self,
avg_fieldnorm: Score,
mode: IndexRecordOption,
fieldnorm_reader: Option<FieldNormReader>,
) -> Self::PostingsSerializer {
StandardPostingsSerializer::new(avg_fieldnorm, mode, fieldnorm_reader)
}
fn load_postings(
&self,
doc_freq: u32,
postings_data: common::OwnedBytes,
record_option: IndexRecordOption,
requested_option: IndexRecordOption,
position_reader: Option<Box<dyn PositionsReader>>,
postings_data: RawPostingsData,
) -> io::Result<Self::Postings> {
// Rationalize record_option/requested_option.
let requested_option = requested_option.downgrade(record_option);
let block_segment_postings =
BlockSegmentPostings::open(doc_freq, postings_data, record_option, requested_option)?;
Ok(SegmentPostings::from_block_postings(
block_segment_postings,
position_reader,
))
load_postings_from_raw_data(doc_freq, postings_data)
}
fn try_accelerated_for_each_pruning(
@@ -75,14 +51,7 @@ impl PostingsCodec for StandardPostingsCodec {
Err(scorer) => scorer,
};
let mut union_scorer =
scorer.downcast::<BufferedUnionScorer<Box<dyn Scorer>, SumCombiner>>()?;
if !union_scorer
.scorers()
.iter()
.all(|scorer| scorer.is::<TermScorer<Self::Postings>>())
{
return Err(union_scorer);
}
scorer.downcast::<BufferedUnionScorer<TermScorer<Self::Postings>, SumCombiner>>()?;
let doc = union_scorer.doc();
if doc == TERMINATED {
return Ok(());
@@ -91,31 +60,69 @@ impl PostingsCodec for StandardPostingsCodec {
if score > threshold {
threshold = callback(doc, score);
}
let boxed_scorers: Vec<Box<dyn Scorer>> = union_scorer.into_scorers();
let scorers: Vec<TermScorer<Self::Postings>> = boxed_scorers
.into_iter()
.map(|scorer| {
*scorer.downcast::<TermScorer<Self::Postings>>().ok().expect(
"Downcast failed despite the fact we already checked the type was correct",
)
})
.collect();
let scorers: Vec<TermScorer<Self::Postings>> = union_scorer.into_scorers();
block_wand(scorers, threshold, callback);
Ok(())
}
}
pub(crate) fn load_postings_from_raw_data(
doc_freq: u32,
postings_data: RawPostingsData,
) -> io::Result<SegmentPostings> {
let RawPostingsData {
postings_data,
positions_data: positions_data_opt,
record_option,
effective_option,
} = postings_data;
let requested_option = effective_option;
let block_segment_postings =
BlockSegmentPostings::open(doc_freq, postings_data, record_option, requested_option)?;
let position_reader = positions_data_opt.map(PositionReader::open).transpose()?;
Ok(SegmentPostings::from_block_postings(
block_segment_postings,
position_reader,
))
}
pub(crate) fn fill_bitset_from_raw_data(
doc_freq: u32,
postings_data: RawPostingsData,
doc_bitset: &mut BitSet,
) -> io::Result<()> {
let RawPostingsData {
postings_data,
record_option,
effective_option,
..
} = postings_data;
let mut block_postings =
BlockSegmentPostings::open(doc_freq, postings_data, record_option, effective_option)?;
loop {
let docs = block_postings.docs();
if docs.is_empty() {
break;
}
for &doc in docs {
doc_bitset.insert(doc);
}
block_postings.advance();
}
Ok(())
}
#[cfg(test)]
mod tests {
use common::OwnedBytes;
use super::*;
use crate::codec::postings::PostingsSerializer as _;
use crate::postings::serializer::PostingsSerializer;
use crate::postings::Postings as _;
use crate::schema::IndexRecordOption;
fn test_segment_postings_tf_aux(num_docs: u32, include_term_freq: bool) -> SegmentPostings {
let mut postings_serializer =
StandardPostingsCodec.new_serializer(1.0f32, IndexRecordOption::WithFreqs, None);
PostingsSerializer::new(1.0f32, IndexRecordOption::WithFreqs, None);
let mut buffer = Vec::new();
postings_serializer.new_term(num_docs, include_term_freq);
for i in 0..num_docs {
@@ -124,15 +131,16 @@ mod tests {
postings_serializer
.close_term(num_docs, &mut buffer)
.unwrap();
StandardPostingsCodec
.load_postings(
num_docs,
OwnedBytes::new(buffer),
IndexRecordOption::WithFreqs,
IndexRecordOption::WithFreqs,
None,
)
.unwrap()
load_postings_from_raw_data(
num_docs,
RawPostingsData {
postings_data: OwnedBytes::new(buffer),
positions_data: None,
record_option: IndexRecordOption::WithFreqs,
effective_option: IndexRecordOption::WithFreqs,
},
)
.unwrap()
}
#[test]

View File

@@ -1,10 +1,10 @@
use common::BitSet;
use super::BlockSegmentPostings;
use crate::codec::positions::PositionsReader;
use crate::codec::postings::PostingsWithBlockMax;
use crate::docset::DocSet;
use crate::fieldnorm::FieldNormReader;
use crate::positions::PositionReader;
use crate::postings::compression::COMPRESSION_BLOCK_SIZE;
use crate::postings::{DocFreq, Postings};
use crate::query::Bm25Weight;
@@ -15,20 +15,11 @@ use crate::{DocId, Score};
///
/// As we iterate through the `SegmentPostings`, the frequencies are optionally decoded.
/// Positions on the other hand, are optionally entirely decoded upfront.
#[derive(Clone)]
pub struct SegmentPostings {
pub(crate) block_cursor: BlockSegmentPostings,
cur: usize,
position_reader: Option<Box<dyn PositionsReader>>,
}
impl Clone for SegmentPostings {
fn clone(&self) -> Self {
SegmentPostings {
block_cursor: self.block_cursor.clone(),
cur: self.cur,
position_reader: self.position_reader.as_ref().map(|r| r.clone_box()),
}
}
position_reader: Option<PositionReader>,
}
impl SegmentPostings {
@@ -56,14 +47,10 @@ impl SegmentPostings {
use crate::schema::IndexRecordOption;
let mut buffer = Vec::new();
{
use crate::codec::postings::PostingsSerializer;
use crate::postings::serializer::PostingsSerializer;
let mut postings_serializer =
crate::codec::standard::postings::StandardPostingsSerializer::new(
0.0,
IndexRecordOption::Basic,
None,
);
PostingsSerializer::new(0.0, IndexRecordOption::Basic, None);
postings_serializer.new_term(docs.len() as u32, false);
for &doc in docs {
postings_serializer.write_doc(doc, 1u32);
@@ -90,9 +77,8 @@ impl SegmentPostings {
) -> SegmentPostings {
use common::OwnedBytes;
use crate::codec::postings::PostingsSerializer as _;
use crate::codec::standard::postings::StandardPostingsSerializer;
use crate::fieldnorm::FieldNormReader;
use crate::postings::serializer::PostingsSerializer;
use crate::schema::IndexRecordOption;
use crate::Score;
let mut buffer: Vec<u8> = Vec::new();
@@ -109,7 +95,7 @@ impl SegmentPostings {
total_num_tokens as Score / fieldnorms.len() as Score
})
.unwrap_or(0.0);
let mut postings_serializer = StandardPostingsSerializer::new(
let mut postings_serializer = PostingsSerializer::new(
average_field_norm,
IndexRecordOption::WithFreqs,
fieldnorm_reader,
@@ -138,7 +124,7 @@ impl SegmentPostings {
/// * `freq_handler` - the freq handler is in charge of decoding frequencies and/or positions
pub(crate) fn from_block_postings(
segment_block_postings: BlockSegmentPostings,
position_reader: Option<Box<dyn PositionsReader>>,
position_reader: Option<PositionReader>,
) -> SegmentPostings {
SegmentPostings {
block_cursor: segment_block_postings,
@@ -278,6 +264,7 @@ impl Postings for SegmentPostings {
}
impl PostingsWithBlockMax for SegmentPostings {
#[inline]
fn seek_block_max(
&mut self,
target_doc: crate::DocId,
@@ -289,6 +276,7 @@ impl PostingsWithBlockMax for SegmentPostings {
.block_max_score(fieldnorm_reader, similarity_weight)
}
#[inline]
fn last_doc_in_block(&self) -> crate::DocId {
self.block_cursor.skip_reader().last_doc_in_block()
}

View File

@@ -1,184 +0,0 @@
use std::cmp::Ordering;
use std::io::{self, Write as _};
use common::{BinarySerializable as _, VInt};
use crate::codec::postings::PostingsSerializer;
use crate::codec::standard::postings::block::Block;
use crate::codec::standard::postings::skip::SkipSerializer;
use crate::fieldnorm::FieldNormReader;
use crate::postings::compression::{BlockEncoder, VIntEncoder as _, COMPRESSION_BLOCK_SIZE};
use crate::query::Bm25Weight;
use crate::schema::IndexRecordOption;
use crate::{DocId, Score};
/// Serializer object for tantivy's default postings format.
pub struct StandardPostingsSerializer {
last_doc_id_encoded: u32,
block_encoder: BlockEncoder,
block: Box<Block>,
postings_write: Vec<u8>,
skip_write: SkipSerializer,
mode: IndexRecordOption,
fieldnorm_reader: Option<FieldNormReader>,
bm25_weight: Option<Bm25Weight>,
avg_fieldnorm: Score, /* Average number of term in the field for that segment.
* this value is used to compute the block wand information. */
term_has_freq: bool,
}
impl StandardPostingsSerializer {
pub(crate) fn new(
avg_fieldnorm: Score,
mode: IndexRecordOption,
fieldnorm_reader: Option<FieldNormReader>,
) -> StandardPostingsSerializer {
Self {
last_doc_id_encoded: 0,
block_encoder: BlockEncoder::new(),
block: Box::new(Block::new()),
postings_write: Vec::new(),
skip_write: SkipSerializer::new(),
mode,
fieldnorm_reader,
bm25_weight: None,
avg_fieldnorm,
term_has_freq: false,
}
}
}
impl PostingsSerializer for StandardPostingsSerializer {
fn new_term(&mut self, term_doc_freq: u32, record_term_freq: bool) {
self.clear();
self.term_has_freq = self.mode.has_freq() && record_term_freq;
if !self.term_has_freq {
return;
}
let num_docs_in_segment: u64 =
if let Some(fieldnorm_reader) = self.fieldnorm_reader.as_ref() {
fieldnorm_reader.num_docs() as u64
} else {
return;
};
if num_docs_in_segment == 0 {
return;
}
self.bm25_weight = Some(Bm25Weight::for_one_term_without_explain(
term_doc_freq as u64,
num_docs_in_segment,
self.avg_fieldnorm,
));
}
fn write_doc(&mut self, doc_id: DocId, term_freq: u32) {
self.block.append_doc(doc_id, term_freq);
if self.block.is_full() {
self.write_block();
}
}
fn close_term(&mut self, doc_freq: u32, output_write: &mut impl io::Write) -> io::Result<()> {
if !self.block.is_empty() {
// we have doc ids waiting to be written
// this happens when the number of doc ids is
// not a perfect multiple of our block size.
//
// In that case, the remaining part is encoded
// using variable int encoding.
{
let block_encoded = self
.block_encoder
.compress_vint_sorted(self.block.doc_ids(), self.last_doc_id_encoded);
self.postings_write.write_all(block_encoded)?;
}
// ... Idem for term frequencies
if self.term_has_freq {
let block_encoded = self
.block_encoder
.compress_vint_unsorted(self.block.term_freqs());
self.postings_write.write_all(block_encoded)?;
}
self.block.clear();
}
if doc_freq >= COMPRESSION_BLOCK_SIZE as u32 {
let skip_data = self.skip_write.data();
VInt(skip_data.len() as u64).serialize(output_write)?;
output_write.write_all(skip_data)?;
}
output_write.write_all(&self.postings_write[..])?;
self.skip_write.clear();
self.postings_write.clear();
self.bm25_weight = None;
Ok(())
}
}
impl StandardPostingsSerializer {
fn clear(&mut self) {
self.bm25_weight = None;
self.block.clear();
self.last_doc_id_encoded = 0;
}
fn write_block(&mut self) {
{
// encode the doc ids
let (num_bits, block_encoded): (u8, &[u8]) = self
.block_encoder
.compress_block_sorted(self.block.doc_ids(), self.last_doc_id_encoded);
self.last_doc_id_encoded = self.block.last_doc();
self.skip_write
.write_doc(self.last_doc_id_encoded, num_bits);
// last el block 0, offset block 1,
self.postings_write.extend(block_encoded);
}
if self.term_has_freq {
let (num_bits, block_encoded): (u8, &[u8]) = self
.block_encoder
.compress_block_unsorted(self.block.term_freqs(), true);
self.postings_write.extend(block_encoded);
self.skip_write.write_term_freq(num_bits);
if self.mode.has_positions() {
// We serialize the sum of term freqs within the skip information
// in order to navigate through positions.
let sum_freq = self.block.term_freqs().iter().cloned().sum();
self.skip_write.write_total_term_freq(sum_freq);
}
let mut blockwand_params = (0u8, 0u32);
if let Some(bm25_weight) = self.bm25_weight.as_ref() {
if let Some(fieldnorm_reader) = self.fieldnorm_reader.as_ref() {
let docs = self.block.doc_ids().iter().cloned();
let term_freqs = self.block.term_freqs().iter().cloned();
let fieldnorms = docs.map(|doc| fieldnorm_reader.fieldnorm_id(doc));
blockwand_params = fieldnorms
.zip(term_freqs)
.max_by(
|(left_fieldnorm_id, left_term_freq),
(right_fieldnorm_id, right_term_freq)| {
let left_score =
bm25_weight.tf_factor(*left_fieldnorm_id, *left_term_freq);
let right_score =
bm25_weight.tf_factor(*right_fieldnorm_id, *right_term_freq);
left_score
.partial_cmp(&right_score)
.unwrap_or(Ordering::Equal)
},
)
.unwrap();
}
}
let (fieldnorm_id, term_freq) = blockwand_params;
self.skip_write.write_blockwand_max(fieldnorm_id, term_freq);
}
self.block.clear();
}
}

View File

@@ -1,6 +1,5 @@
use super::Collector;
use crate::collector::SegmentCollector;
use crate::query::Weight;
use crate::{DocId, Score, SegmentOrdinal, SegmentReader};
/// `CountCollector` collector only counts how many
@@ -44,7 +43,7 @@ impl Collector for Count {
fn for_segment(
&self,
_: SegmentOrdinal,
_: &SegmentReader,
_: &dyn SegmentReader,
) -> crate::Result<SegmentCountCollector> {
Ok(SegmentCountCollector::default())
}
@@ -56,15 +55,6 @@ impl Collector for Count {
fn merge_fruits(&self, segment_counts: Vec<usize>) -> crate::Result<usize> {
Ok(segment_counts.into_iter().sum())
}
fn collect_segment(
&self,
weight: &dyn Weight,
_segment_ord: u32,
reader: &SegmentReader,
) -> crate::Result<usize> {
Ok(weight.count(reader)? as usize)
}
}
#[derive(Default)]

View File

@@ -1,7 +1,7 @@
use std::collections::HashSet;
use super::{Collector, SegmentCollector};
use crate::{DocAddress, DocId, Score};
use crate::{DocAddress, DocId, Score, SegmentReader};
/// Collectors that returns the set of DocAddress that matches the query.
///
@@ -15,7 +15,7 @@ impl Collector for DocSetCollector {
fn for_segment(
&self,
segment_local_id: crate::SegmentOrdinal,
_segment: &crate::SegmentReader,
_segment: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
Ok(DocSetChildCollector {
segment_local_id,

View File

@@ -265,7 +265,7 @@ impl Collector for FacetCollector {
fn for_segment(
&self,
_: SegmentOrdinal,
reader: &SegmentReader,
reader: &dyn SegmentReader,
) -> crate::Result<FacetSegmentCollector> {
let facet_reader = reader.facet_reader(&self.field_name)?;
let facet_dict = facet_reader.facet_dict();
@@ -389,13 +389,6 @@ impl SegmentCollector for FacetSegmentCollector {
}
let mut facet = vec![];
let (facet_ord, facet_depth) = self.unique_facet_ords[collapsed_facet_ord];
// u64::MAX is used as a sentinel for unmapped ordinals (e.g. when a
// document has the exact registered facet, not a child of it).
// Passing it to ord_to_term would resolve to the last dictionary
// entry and produce a spurious facet from an unrelated branch.
if facet_ord == u64::MAX {
continue;
}
// TODO handle errors.
if facet_dict.ord_to_term(facet_ord, &mut facet).is_ok() {
if let Some((end_collapsed_facet, _)) = facet
@@ -821,63 +814,6 @@ mod tests {
assert!(!super::is_child_facet(&b"foo\0bar"[..], &b"foo"[..]));
assert!(!super::is_child_facet(&b"foo"[..], &b"foobar\0baz"[..]));
}
// Regression test for https://github.com/quickwit-oss/tantivy/issues/2494
// When a document has the exact registered facet path (not just a child),
// harvest() must not turn the unmapped sentinel into a spurious root entry.
#[test]
fn test_facet_collector_wrong_root() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let facet_field = schema_builder.add_facet_field("facet", FacetOptions::default());
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let mut index_writer: IndexWriter = index.writer_for_tests()?;
let facets: Vec<&str> = vec![
"/science-fiction/asimov",
"/science-fiction/clarke",
"/science-fiction/dick",
"/science-fiction/herbert",
"/science-fiction/orwell",
// This exact match on the registered facet is the bug trigger:
// its ordinal maps to the sentinel (u64::MAX, 0) in the collapse
// mapping, which without the fix resolves to an unrelated term.
"/fantasy/epic-fantasy",
"/fantasy/epic-fantasy/tolkien",
"/fantasy/epic-fantasy/martin",
];
for facet_str in &facets {
index_writer.add_document(doc!(
facet_field => Facet::from(*facet_str)
))?;
}
index_writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
let term = Term::from_facet(facet_field, &Facet::from("/fantasy/epic-fantasy"));
let query = TermQuery::new(term, IndexRecordOption::Basic);
let mut facet_collector = FacetCollector::for_field("facet");
facet_collector.add_facet("/fantasy/epic-fantasy");
let counts: FacetCounts = searcher.search(&query, &facet_collector)?;
let result: Vec<(String, u64)> = counts
.get("/")
.map(|(facet, count)| (facet.to_string(), count))
.collect();
// Only children of /fantasy/epic-fantasy should appear, not /science-fiction
assert_eq!(
result,
vec![
("/fantasy/epic-fantasy/martin".to_string(), 1),
("/fantasy/epic-fantasy/tolkien".to_string(), 1),
]
);
Ok(())
}
}
#[cfg(all(test, feature = "unstable"))]

View File

@@ -113,7 +113,7 @@ where
fn for_segment(
&self,
segment_local_id: u32,
segment_reader: &SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let column_opt = segment_reader.fast_fields().column_opt(&self.field)?;
@@ -287,7 +287,7 @@ where
fn for_segment(
&self,
segment_local_id: u32,
segment_reader: &SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let column_opt = segment_reader.fast_fields().bytes(&self.field)?;

View File

@@ -6,7 +6,7 @@ use fastdivide::DividerU64;
use crate::collector::{Collector, SegmentCollector};
use crate::fastfield::{FastFieldNotAvailableError, FastValue};
use crate::schema::Type;
use crate::{DocId, Score};
use crate::{DocId, Score, SegmentReader};
/// Histogram builds an histogram of the values of a fastfield for the
/// collected DocSet.
@@ -110,7 +110,7 @@ impl Collector for HistogramCollector {
fn for_segment(
&self,
_segment_local_id: crate::SegmentOrdinal,
segment: &crate::SegmentReader,
segment: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let column_opt = segment.fast_fields().u64_lenient(&self.field)?;
let (column, _column_type) = column_opt.ok_or_else(|| FastFieldNotAvailableError {

View File

@@ -156,7 +156,7 @@ pub trait Collector: Sync + Send {
fn for_segment(
&self,
segment_local_id: SegmentOrdinal,
segment: &SegmentReader,
segment: &dyn SegmentReader,
) -> crate::Result<Self::Child>;
/// Returns true iff the collector requires to compute scores for documents.
@@ -174,7 +174,7 @@ pub trait Collector: Sync + Send {
&self,
weight: &dyn Weight,
segment_ord: u32,
reader: &SegmentReader,
reader: &dyn SegmentReader,
) -> crate::Result<<Self::Child as SegmentCollector>::Fruit> {
let with_scoring = self.requires_scoring();
let mut segment_collector = self.for_segment(segment_ord, reader)?;
@@ -186,7 +186,7 @@ pub trait Collector: Sync + Send {
pub(crate) fn default_collect_segment_impl<TSegmentCollector: SegmentCollector>(
segment_collector: &mut TSegmentCollector,
weight: &dyn Weight,
reader: &SegmentReader,
reader: &dyn SegmentReader,
with_scoring: bool,
) -> crate::Result<()> {
match (reader.alive_bitset(), with_scoring) {
@@ -255,7 +255,7 @@ impl<TCollector: Collector> Collector for Option<TCollector> {
fn for_segment(
&self,
segment_local_id: SegmentOrdinal,
segment: &SegmentReader,
segment: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
Ok(if let Some(inner) = self {
let inner_segment_collector = inner.for_segment(segment_local_id, segment)?;
@@ -336,7 +336,7 @@ where
fn for_segment(
&self,
segment_local_id: u32,
segment: &SegmentReader,
segment: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let left = self.0.for_segment(segment_local_id, segment)?;
let right = self.1.for_segment(segment_local_id, segment)?;
@@ -407,7 +407,7 @@ where
fn for_segment(
&self,
segment_local_id: u32,
segment: &SegmentReader,
segment: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let one = self.0.for_segment(segment_local_id, segment)?;
let two = self.1.for_segment(segment_local_id, segment)?;
@@ -487,7 +487,7 @@ where
fn for_segment(
&self,
segment_local_id: u32,
segment: &SegmentReader,
segment: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let one = self.0.for_segment(segment_local_id, segment)?;
let two = self.1.for_segment(segment_local_id, segment)?;

View File

@@ -24,7 +24,7 @@ impl<TCollector: Collector> Collector for CollectorWrapper<TCollector> {
fn for_segment(
&self,
segment_local_id: u32,
reader: &SegmentReader,
reader: &dyn SegmentReader,
) -> crate::Result<Box<dyn BoxableSegmentCollector>> {
let child = self.0.for_segment(segment_local_id, reader)?;
Ok(Box::new(SegmentCollectorWrapper(child)))
@@ -209,7 +209,7 @@ impl Collector for MultiCollector<'_> {
fn for_segment(
&self,
segment_local_id: SegmentOrdinal,
segment: &SegmentReader,
segment: &dyn SegmentReader,
) -> crate::Result<MultiCollectorChild> {
let children = self
.collector_wrappers

View File

@@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::schema::{OwnedValue, Schema};
use crate::{DocId, Order, Score};
use crate::{DocId, Order, Score, SegmentReader};
fn compare_owned_value<const NULLS_FIRST: bool>(lhs: &OwnedValue, rhs: &OwnedValue) -> Ordering {
match (lhs, rhs) {
@@ -430,7 +430,7 @@ where
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let child = self.0.segment_sort_key_computer(segment_reader)?;
Ok(SegmentSortKeyComputerWithComparator {
@@ -468,7 +468,7 @@ where
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let child = self.0.segment_sort_key_computer(segment_reader)?;
Ok(SegmentSortKeyComputerWithComparator {

View File

@@ -32,7 +32,7 @@ impl SortKeyComputer for SortByBytes {
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
segment_reader: &dyn crate::SegmentReader,
) -> crate::Result<Self::Child> {
let bytes_column_opt = segment_reader.fast_fields().bytes(&self.column_name)?;
Ok(ByBytesColumnSegmentSortKeyComputer { bytes_column_opt })

View File

@@ -6,7 +6,7 @@ use crate::collector::sort_key::{
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::fastfield::FastFieldNotAvailableError;
use crate::schema::OwnedValue;
use crate::{DateTime, DocId, Score};
use crate::{DateTime, DocId, Score, SegmentReader};
/// Sort by the boxed / OwnedValue representation of either a fast field, or of the score.
///
@@ -86,7 +86,7 @@ impl SortKeyComputer for SortByErasedType {
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let inner: Box<dyn ErasedSegmentSortKeyComputer> = match self {
Self::Field(column_name) => {

View File

@@ -1,9 +1,6 @@
use std::cmp::{Ordering, Reverse};
use std::collections::BinaryHeap;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::{DocAddress, DocId, Score};
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer, TopNComputer};
use crate::{DocAddress, DocId, Score, SegmentReader};
/// Sort by similarity score.
#[derive(Clone, Debug, Copy)]
@@ -22,27 +19,25 @@ impl SortKeyComputer for SortBySimilarityScore {
fn segment_sort_key_computer(
&self,
_segment_reader: &crate::SegmentReader,
_segment_reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
Ok(SortBySimilarityScore)
}
// Sorting by score is special in that it allows for the Block-Wand optimization.
//
// We use a BinaryHeap (TopNHeap) instead of TopNComputer here so that the
// threshold is always the exact K-th best score. TopNComputer only updates its
// threshold every K docs (at truncation), giving Block-WAND a stale bound.
fn collect_segment_top_k(
&self,
k: usize,
weight: &dyn crate::query::Weight,
reader: &crate::SegmentReader,
reader: &dyn SegmentReader,
segment_ord: u32,
) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
let mut top_n = TopNHeap::new(k);
let mut top_n: TopNComputer<Score, DocId, Self::Comparator> =
TopNComputer::new_with_comparator(k, self.comparator());
if let Some(alive_bitset) = reader.alive_bitset() {
let mut threshold = Score::MIN;
top_n.threshold = Some(threshold);
weight.for_each_pruning(Score::MIN, reader, &mut |doc, score| {
if alive_bitset.is_deleted(doc) {
return threshold;
@@ -61,7 +56,7 @@ impl SortKeyComputer for SortBySimilarityScore {
Ok(top_n
.into_vec()
.into_iter()
.map(|(score, doc)| (score, DocAddress::new(segment_ord, doc)))
.map(|cid| (cid.sort_key, DocAddress::new(segment_ord, cid.doc)))
.collect())
}
}
@@ -80,204 +75,3 @@ impl SegmentSortKeyComputer for SortBySimilarityScore {
score
}
}
/// Min-heap entry: higher score = greater, lower doc wins ties.
struct ScoreHeapEntry {
score: Score,
doc: DocId,
}
impl Eq for ScoreHeapEntry {}
impl PartialEq for ScoreHeapEntry {
fn eq(&self, other: &Self) -> bool {
self.cmp(other) == Ordering::Equal
}
}
impl PartialOrd for ScoreHeapEntry {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for ScoreHeapEntry {
fn cmp(&self, other: &Self) -> Ordering {
self.score
.partial_cmp(&other.score)
.unwrap_or(Ordering::Equal)
.then_with(|| other.doc.cmp(&self.doc))
}
}
/// Heap-based top-K for score collection. O(log K) per insert, but the threshold
/// is always tight, so Block-WAND prunes better than with [`TopNComputer`]'s
/// buffer/median approach.
///
/// Like [`TopNComputer`], items must arrive in ascending doc order, and equal
/// scores are rejected (strict `>`) so that lower doc IDs win ties.
///
/// [`TopNComputer`]: crate::collector::TopNComputer
struct TopNHeap {
heap: BinaryHeap<Reverse<ScoreHeapEntry>>,
top_n: usize,
threshold: Option<Score>,
}
impl TopNHeap {
fn new(top_n: usize) -> Self {
TopNHeap {
heap: BinaryHeap::with_capacity(top_n),
top_n,
threshold: None,
}
}
#[inline]
fn push(&mut self, score: Score, doc: DocId) {
if self.heap.len() < self.top_n {
self.heap.push(Reverse(ScoreHeapEntry { score, doc }));
if self.heap.len() == self.top_n {
self.threshold = self.heap.peek().map(|Reverse(entry)| entry.score);
}
} else if let Some(threshold) = self.threshold {
if score > threshold {
// peek_mut + assign is a single sift-down, vs pop + push = two sifts.
if let Some(mut min) = self.heap.peek_mut() {
*min = Reverse(ScoreHeapEntry { score, doc });
}
self.threshold = self.heap.peek().map(|Reverse(entry)| entry.score);
}
}
}
fn into_vec(self) -> Vec<(Score, DocId)> {
self.heap
.into_vec()
.into_iter()
.map(|Reverse(entry)| (entry.score, entry.doc))
.collect()
}
}
#[cfg(test)]
mod tests {
use proptest::prelude::*;
use super::*;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::TopNComputer;
#[test]
fn test_top_n_heap_zero_capacity() {
let mut heap = TopNHeap::new(0);
heap.push(1.0, 0);
heap.push(2.0, 1);
assert!(heap.into_vec().is_empty());
}
#[test]
fn test_top_n_heap_basic() {
let mut heap = TopNHeap::new(2);
heap.push(1.0, 0);
heap.push(3.0, 1);
heap.push(2.0, 2);
let mut results = heap.into_vec();
results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap().then_with(|| a.1.cmp(&b.1)));
assert_eq!(results, vec![(3.0, 1), (2.0, 2)]);
}
#[test]
fn test_top_n_heap_threshold_always_accurate() {
let mut heap = TopNHeap::new(2);
assert_eq!(heap.threshold, None);
heap.push(1.0, 0);
assert_eq!(heap.threshold, None);
heap.push(3.0, 1);
assert_eq!(heap.threshold, Some(1.0));
heap.push(2.0, 2); // evicts 1.0
assert_eq!(heap.threshold, Some(2.0));
heap.push(4.0, 3); // evicts 2.0
assert_eq!(heap.threshold, Some(3.0));
}
#[test]
fn test_top_n_heap_tiebreaking_lower_doc_wins() {
let mut heap = TopNHeap::new(2);
heap.push(5.0, 0);
heap.push(5.0, 1);
heap.push(5.0, 2); // rejected: not strictly > threshold
let mut results = heap.into_vec();
results.sort_by_key(|&(_, doc)| doc);
assert_eq!(results, vec![(5.0, 0), (5.0, 1)]);
}
#[test]
fn test_top_n_heap_single_element() {
let mut heap = TopNHeap::new(1);
heap.push(1.0, 0);
assert_eq!(heap.threshold, Some(1.0));
heap.push(0.5, 1); // rejected
heap.push(2.0, 2); // accepted
assert_eq!(heap.threshold, Some(2.0));
let results = heap.into_vec();
assert_eq!(results, vec![(2.0, 2)]);
}
#[test]
fn test_top_n_heap_under_capacity() {
let mut heap = TopNHeap::new(5);
heap.push(3.0, 0);
heap.push(1.0, 1);
heap.push(2.0, 2);
// Only 3 elements, capacity is 5 — all should be kept
assert_eq!(heap.threshold, None);
let mut results = heap.into_vec();
results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap().then_with(|| a.1.cmp(&b.1)));
assert_eq!(results, vec![(3.0, 0), (2.0, 2), (1.0, 1)]);
}
proptest! {
#[test]
fn test_top_n_heap_matches_top_n_computer(
limit in 0..20_usize,
mut docs in proptest::collection::vec((0..1000_u32, 0..1000_u32), 0..200_usize),
) {
// Both require ascending doc order.
docs.sort_by_key(|(_, doc_id)| *doc_id);
docs.dedup_by_key(|(_, doc_id)| *doc_id);
let mut heap = TopNHeap::new(limit);
let mut computer: TopNComputer<Score, DocId, NaturalComparator> =
TopNComputer::new_with_comparator(limit, NaturalComparator);
for &(score_u32, doc) in &docs {
let score = score_u32 as Score;
heap.push(score, doc);
computer.push(score, doc);
}
let mut heap_results = heap.into_vec();
heap_results.sort_by(|a, b| {
b.0.partial_cmp(&a.0).unwrap().then_with(|| a.1.cmp(&b.1))
});
let computer_results: Vec<(Score, DocId)> = computer
.into_sorted_vec()
.into_iter()
.map(|cd| (cd.sort_key, cd.doc))
.collect();
prop_assert_eq!(heap_results, computer_results);
}
}
}

View File

@@ -52,7 +52,7 @@ impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
if schema_type != T::to_type() {
return Err(crate::TantivyError::SchemaError(format!(
"Field `{}` is of type {schema_type:?}, not of the type {:?}.",
self.field,
&self.field,
T::to_type()
)));
}
@@ -61,7 +61,7 @@ impl<T: FastValue> SortKeyComputer for SortByStaticFastValue<T> {
fn segment_sort_key_computer(
&self,
segment_reader: &SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let sort_column_opt = segment_reader.fast_fields().u64_lenient(&self.field)?;
let (sort_column, _sort_column_type) =

View File

@@ -3,7 +3,7 @@ use columnar::StrColumn;
use crate::collector::sort_key::NaturalComparator;
use crate::collector::{SegmentSortKeyComputer, SortKeyComputer};
use crate::termdict::TermOrdinal;
use crate::{DocId, Score};
use crate::{DocId, Score, SegmentReader};
/// Sort by the first value of a string column.
///
@@ -35,7 +35,7 @@ impl SortKeyComputer for SortByString {
fn segment_sort_key_computer(
&self,
segment_reader: &crate::SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
let str_column_opt = segment_reader.fast_fields().str(&self.column_name)?;
Ok(ByStringColumnSegmentSortKeyComputer { str_column_opt })

View File

@@ -119,7 +119,7 @@ pub trait SortKeyComputer: Sync {
&self,
k: usize,
weight: &dyn crate::query::Weight,
reader: &crate::SegmentReader,
reader: &dyn SegmentReader,
segment_ord: u32,
) -> crate::Result<Vec<(Self::SortKey, DocAddress)>> {
let with_scoring = self.requires_scoring();
@@ -135,7 +135,7 @@ pub trait SortKeyComputer: Sync {
}
/// Builds a child sort key computer for a specific segment.
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child>;
fn segment_sort_key_computer(&self, segment_reader: &dyn SegmentReader) -> Result<Self::Child>;
}
impl<HeadSortKeyComputer, TailSortKeyComputer> SortKeyComputer
@@ -156,7 +156,7 @@ where
(self.0.comparator(), self.1.comparator())
}
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
fn segment_sort_key_computer(&self, segment_reader: &dyn SegmentReader) -> Result<Self::Child> {
Ok((
self.0.segment_sort_key_computer(segment_reader)?,
self.1.segment_sort_key_computer(segment_reader)?,
@@ -357,7 +357,7 @@ where
)
}
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
fn segment_sort_key_computer(&self, segment_reader: &dyn SegmentReader) -> Result<Self::Child> {
let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?;
let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?;
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
@@ -420,7 +420,7 @@ where
SortKeyComputer4::Comparator,
);
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
fn segment_sort_key_computer(&self, segment_reader: &dyn SegmentReader) -> Result<Self::Child> {
let sort_key_computer1 = self.0.segment_sort_key_computer(segment_reader)?;
let sort_key_computer2 = self.1.segment_sort_key_computer(segment_reader)?;
let sort_key_computer3 = self.2.segment_sort_key_computer(segment_reader)?;
@@ -454,7 +454,7 @@ where
impl<F, SegmentF, TSortKey> SortKeyComputer for F
where
F: 'static + Send + Sync + Fn(&SegmentReader) -> SegmentF,
F: 'static + Send + Sync + Fn(&dyn SegmentReader) -> SegmentF,
SegmentF: 'static + FnMut(DocId) -> TSortKey,
TSortKey: 'static + PartialOrd + Clone + Send + Sync + std::fmt::Debug,
{
@@ -462,7 +462,7 @@ where
type Child = SegmentF;
type Comparator = NaturalComparator;
fn segment_sort_key_computer(&self, segment_reader: &SegmentReader) -> Result<Self::Child> {
fn segment_sort_key_computer(&self, segment_reader: &dyn SegmentReader) -> Result<Self::Child> {
Ok((self)(segment_reader))
}
}
@@ -509,10 +509,10 @@ mod tests {
#[test]
fn test_lazy_score_computer() {
let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32;
let score_computer_primary = |_segment_reader: &dyn SegmentReader| |_doc: DocId| 200u32;
let call_count = Arc::new(AtomicUsize::new(0));
let call_count_clone = call_count.clone();
let score_computer_secondary = move |_segment_reader: &SegmentReader| {
let score_computer_secondary = move |_segment_reader: &dyn SegmentReader| {
let call_count_new_clone = call_count_clone.clone();
move |_doc: DocId| {
call_count_new_clone.fetch_add(1, AtomicOrdering::SeqCst);
@@ -572,10 +572,10 @@ mod tests {
#[test]
fn test_lazy_score_computer_dynamic_ordering() {
let score_computer_primary = |_segment_reader: &SegmentReader| |_doc: DocId| 200u32;
let score_computer_primary = |_segment_reader: &dyn SegmentReader| |_doc: DocId| 200u32;
let call_count = Arc::new(AtomicUsize::new(0));
let call_count_clone = call_count.clone();
let score_computer_secondary = move |_segment_reader: &SegmentReader| {
let score_computer_secondary = move |_segment_reader: &dyn SegmentReader| {
let call_count_new_clone = call_count_clone.clone();
move |_doc: DocId| {
call_count_new_clone.fetch_add(1, AtomicOrdering::SeqCst);

View File

@@ -32,7 +32,11 @@ where TSortKeyComputer: SortKeyComputer + Send + Sync + 'static
self.sort_key_computer.check_schema(schema)
}
fn for_segment(&self, segment_ord: u32, segment_reader: &SegmentReader) -> Result<Self::Child> {
fn for_segment(
&self,
segment_ord: u32,
segment_reader: &dyn SegmentReader,
) -> Result<Self::Child> {
let segment_sort_key_computer = self
.sort_key_computer
.segment_sort_key_computer(segment_reader)?;
@@ -63,7 +67,7 @@ where TSortKeyComputer: SortKeyComputer + Send + Sync + 'static
&self,
weight: &dyn Weight,
segment_ord: u32,
reader: &SegmentReader,
reader: &dyn SegmentReader,
) -> crate::Result<Vec<(TSortKeyComputer::SortKey, DocAddress)>> {
let k = self.doc_range.end;
let docs = self

View File

@@ -5,7 +5,7 @@ use crate::query::{AllQuery, QueryParser};
use crate::schema::{Schema, FAST, TEXT};
use crate::time::format_description::well_known::Rfc3339;
use crate::time::OffsetDateTime;
use crate::{DateTime, DocAddress, Index, Searcher, TantivyDocument};
use crate::{DateTime, DocAddress, Index, Searcher, SegmentReader, TantivyDocument};
pub const TEST_COLLECTOR_WITH_SCORE: TestCollector = TestCollector {
compute_score: true,
@@ -109,7 +109,7 @@ impl Collector for TestCollector {
fn for_segment(
&self,
segment_id: SegmentOrdinal,
_reader: &SegmentReader,
_reader: &dyn SegmentReader,
) -> crate::Result<TestSegmentCollector> {
Ok(TestSegmentCollector {
segment_id,
@@ -180,7 +180,7 @@ impl Collector for FastFieldTestCollector {
fn for_segment(
&self,
_: SegmentOrdinal,
segment_reader: &SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<FastFieldSegmentCollector> {
let reader = segment_reader
.fast_fields()
@@ -243,7 +243,7 @@ impl Collector for BytesFastFieldTestCollector {
fn for_segment(
&self,
_segment_local_id: u32,
segment_reader: &SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<BytesFastFieldSegmentCollector> {
let column_opt = segment_reader.fast_fields().bytes(&self.field)?;
Ok(BytesFastFieldSegmentCollector {

View File

@@ -393,7 +393,7 @@ impl TopDocs {
/// // This is where we build our collector with our custom score.
/// let top_docs_by_custom_score = TopDocs
/// ::with_limit(10)
/// .tweak_score(move |segment_reader: &SegmentReader| {
/// .tweak_score(move |segment_reader: &dyn SegmentReader| {
/// // The argument is a function that returns our scoring
/// // function.
/// //
@@ -442,7 +442,7 @@ pub struct TweakScoreFn<F>(F);
impl<F, TTweakScoreSortKeyFn, TSortKey> SortKeyComputer for TweakScoreFn<F>
where
F: 'static + Send + Sync + Fn(&SegmentReader) -> TTweakScoreSortKeyFn,
F: 'static + Send + Sync + Fn(&dyn SegmentReader) -> TTweakScoreSortKeyFn,
TTweakScoreSortKeyFn: 'static + Fn(DocId, Score) -> TSortKey,
TweakScoreSegmentSortKeyComputer<TTweakScoreSortKeyFn>:
SegmentSortKeyComputer<SortKey = TSortKey, SegmentSortKey = TSortKey>,
@@ -458,7 +458,7 @@ where
fn segment_sort_key_computer(
&self,
segment_reader: &SegmentReader,
segment_reader: &dyn SegmentReader,
) -> crate::Result<Self::Child> {
Ok({
TweakScoreSegmentSortKeyComputer {
@@ -513,9 +513,7 @@ pub struct TopNComputer<Score, D, C> {
/// The buffer reverses sort order to get top-semantics instead of bottom-semantics
buffer: Vec<ComparableDoc<Score, D>>,
top_n: usize,
/// The current threshold for pruning. Documents with scores at or below
/// this value are skipped by `push()`. Updated when the buffer is truncated.
pub threshold: Option<Score>,
pub(crate) threshold: Option<Score>,
comparator: C,
}
@@ -1527,7 +1525,7 @@ mod tests {
let text_query = query_parser.parse_query("droopy tax")?;
let collector = TopDocs::with_limit(2)
.and_offset(1)
.order_by(move |_segment_reader: &SegmentReader| move |doc: DocId| doc);
.order_by(move |_segment_reader: &dyn SegmentReader| move |doc: DocId| doc);
let score_docs: Vec<(u32, DocAddress)> =
index.reader()?.searcher().search(&text_query, &collector)?;
assert_eq!(
@@ -1545,7 +1543,7 @@ mod tests {
let text_query = query_parser.parse_query("droopy tax").unwrap();
let collector = TopDocs::with_limit(2)
.and_offset(1)
.order_by(move |_segment_reader: &SegmentReader| move |doc: DocId| doc);
.order_by(move |_segment_reader: &dyn SegmentReader| move |doc: DocId| doc);
let score_docs: Vec<(u32, DocAddress)> = index
.reader()
.unwrap()

View File

@@ -52,8 +52,7 @@ use crate::{DateTime, DocId, Term};
/// We can therefore afford working with a map that is not imperfect. It is fine if several
/// path map to the same index position as long as the probability is relatively low.
#[derive(Default)]
#[doc(hidden)]
pub struct IndexingPositionsPerPath {
pub(crate) struct IndexingPositionsPerPath {
positions_per_path: FxHashMap<u32, IndexingPosition>,
}
@@ -105,8 +104,7 @@ fn index_json_object<'a, V: Value<'a>>(
}
#[expect(clippy::too_many_arguments)]
#[doc(hidden)]
pub fn index_json_value<'a, V: Value<'a>>(
pub(crate) fn index_json_value<'a, V: Value<'a>>(
doc: DocId,
json_value: V,
text_analyzer: &mut TextAnalyzer,

View File

@@ -8,7 +8,7 @@ use std::path::Path;
use once_cell::sync::Lazy;
pub use self::executor::Executor;
pub use self::searcher::{Searcher, SearcherGeneration};
pub use self::searcher::{Searcher, SearcherContext, SearcherGeneration};
/// The meta file contains all the information about the list of segments and the schema
/// of the index.

View File

@@ -4,13 +4,13 @@ use std::{fmt, io};
use crate::collector::Collector;
use crate::core::Executor;
use crate::index::{SegmentId, SegmentReader};
use crate::index::{Index, SegmentId, SegmentReader};
use crate::query::{Bm25StatisticsProvider, EnableScoring, Query};
use crate::schema::document::DocumentDeserialize;
use crate::schema::{Schema, Term};
use crate::schema::{Field, FieldType, Schema, TantivyDocument, Term};
use crate::space_usage::SearcherSpaceUsage;
use crate::store::{CacheStats, StoreReader};
use crate::{DocAddress, Index, Opstamp, TrackedObject};
use crate::store::{CacheStats, StoreReader, DOCSTORE_CACHE_CAPACITY};
use crate::tokenizer::{TextAnalyzer, TokenizerManager};
use crate::{DocAddress, Inventory, Opstamp, TantivyError, TrackedObject};
/// Identifies the searcher generation accessed by a [`Searcher`].
///
@@ -36,7 +36,7 @@ pub struct SearcherGeneration {
impl SearcherGeneration {
pub(crate) fn from_segment_readers(
segment_readers: &[SegmentReader],
segment_readers: &[Arc<dyn SegmentReader>],
generation_id: u64,
) -> Self {
let mut segment_id_to_del_opstamp = BTreeMap::new();
@@ -61,6 +61,103 @@ impl SearcherGeneration {
}
}
/// Search-time context required by a [`Searcher`].
#[derive(Clone)]
pub struct SearcherContext {
schema: Schema,
executor: Executor,
tokenizers: TokenizerManager,
fast_field_tokenizers: TokenizerManager,
}
impl SearcherContext {
/// Creates a context from explicit search-time components.
pub fn new(
schema: Schema,
executor: Executor,
tokenizers: TokenizerManager,
fast_field_tokenizers: TokenizerManager,
) -> SearcherContext {
SearcherContext {
schema,
executor,
tokenizers,
fast_field_tokenizers,
}
}
/// Creates a context from an index.
pub fn from_index<C: crate::codec::Codec>(index: &Index<C>) -> SearcherContext {
SearcherContext::new(
index.schema(),
index.search_executor().clone(),
index.tokenizers().clone(),
index.fast_field_tokenizer().clone(),
)
}
/// Access the schema associated with this context.
pub fn schema(&self) -> &Schema {
&self.schema
}
/// Access the executor associated with this context.
pub fn search_executor(&self) -> &Executor {
&self.executor
}
/// Access the tokenizer manager associated with this context.
pub fn tokenizers(&self) -> &TokenizerManager {
&self.tokenizers
}
/// Access the fast field tokenizer manager associated with this context.
pub fn fast_field_tokenizer(&self) -> &TokenizerManager {
&self.fast_field_tokenizers
}
/// Get the tokenizer associated with a specific field.
pub fn tokenizer_for_field(&self, field: Field) -> crate::Result<TextAnalyzer> {
let field_entry = self.schema.get_field_entry(field);
let field_type = field_entry.field_type();
let indexing_options_opt = match field_type {
FieldType::JsonObject(options) => options.get_text_indexing_options(),
FieldType::Str(options) => options.get_indexing_options(),
_ => {
return Err(TantivyError::SchemaError(format!(
"{:?} is not a text field.",
field_entry.name()
)))
}
};
let indexing_options = indexing_options_opt.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"No indexing options set for field {field_entry:?}"
))
})?;
self.tokenizers
.get(indexing_options.tokenizer())
.ok_or_else(|| {
TantivyError::InvalidArgument(format!(
"No Tokenizer found for field {field_entry:?}"
))
})
}
}
impl<C: crate::codec::Codec> From<&Index<C>> for SearcherContext {
fn from(index: &Index<C>) -> Self {
SearcherContext::from_index(index)
}
}
impl<C: crate::codec::Codec> From<Index<C>> for SearcherContext {
fn from(index: Index<C>) -> Self {
SearcherContext::from(&index)
}
}
/// Holds a list of `SegmentReader`s ready for search.
///
/// It guarantees that the `Segment` will not be removed before
@@ -71,9 +168,66 @@ pub struct Searcher {
}
impl Searcher {
/// Returns the `Index` associated with the `Searcher`
pub fn index(&self) -> &Index {
&self.inner.index
/// Creates a `Searcher` from an arbitrary list of segment readers.
///
/// This is useful when segment readers are not opened from
/// `IndexReader` / `meta.json` (e.g. external segment sources).
/// The generated [`SearcherGeneration`] uses `generation_id = 0`.
pub fn from_segment_readers<Ctx: Into<SearcherContext>>(
context: Ctx,
segment_readers: Vec<Arc<dyn SegmentReader>>,
) -> crate::Result<Searcher> {
Self::from_segment_readers_with_generation_id(context, segment_readers, 0)
}
/// Same as [`Searcher::from_segment_readers`] but allows setting
/// a custom generation id.
pub fn from_segment_readers_with_generation_id<Ctx: Into<SearcherContext>>(
context: Ctx,
segment_readers: Vec<Arc<dyn SegmentReader>>,
generation_id: u64,
) -> crate::Result<Searcher> {
let context = context.into();
let generation = SearcherGeneration::from_segment_readers(&segment_readers, generation_id);
let tracked_generation = Inventory::default().track(generation);
let inner = SearcherInner::new(
context,
segment_readers,
tracked_generation,
DOCSTORE_CACHE_CAPACITY,
)?;
Ok(Arc::new(inner).into())
}
/// Returns the search context associated with the `Searcher`.
pub fn context(&self) -> &SearcherContext {
&self.inner.context
}
/// Deprecated alias for [`Searcher::context`].
#[deprecated(note = "use Searcher::context()")]
pub fn index(&self) -> &SearcherContext {
self.context()
}
/// Access the search executor associated with this searcher.
pub fn search_executor(&self) -> &Executor {
self.context().search_executor()
}
/// Access the tokenizer manager associated with this searcher.
pub fn tokenizers(&self) -> &TokenizerManager {
self.context().tokenizers()
}
/// Access the fast field tokenizer manager associated with this searcher.
pub fn fast_field_tokenizer(&self) -> &TokenizerManager {
self.context().fast_field_tokenizer()
}
/// Get the tokenizer associated with a specific field.
pub fn tokenizer_for_field(&self, field: Field) -> crate::Result<TextAnalyzer> {
self.context().tokenizer_for_field(field)
}
/// [`SearcherGeneration`] which identifies the version of the snapshot held by this `Searcher`.
@@ -85,7 +239,7 @@ impl Searcher {
///
/// The searcher uses the segment ordinal to route the
/// request to the right `Segment`.
pub fn doc<D: DocumentDeserialize>(&self, doc_address: DocAddress) -> crate::Result<D> {
pub fn doc(&self, doc_address: DocAddress) -> crate::Result<TantivyDocument> {
let store_reader = &self.inner.store_readers[doc_address.segment_ord as usize];
store_reader.get(doc_address.doc_id)
}
@@ -105,18 +259,15 @@ impl Searcher {
/// Fetches a document in an asynchronous manner.
#[cfg(feature = "quickwit")]
pub async fn doc_async<D: DocumentDeserialize>(
&self,
doc_address: DocAddress,
) -> crate::Result<D> {
let executor = self.inner.index.search_executor();
pub async fn doc_async(&self, doc_address: DocAddress) -> crate::Result<TantivyDocument> {
let executor = self.search_executor();
let store_reader = &self.inner.store_readers[doc_address.segment_ord as usize];
store_reader.get_async(doc_address.doc_id, executor).await
}
/// Access the schema associated with the index of this searcher.
pub fn schema(&self) -> &Schema {
&self.inner.schema
self.context().schema()
}
/// Returns the overall number of documents in the index.
@@ -154,13 +305,13 @@ impl Searcher {
}
/// Return the list of segment readers
pub fn segment_readers(&self) -> &[SegmentReader] {
pub fn segment_readers(&self) -> &[Arc<dyn SegmentReader>] {
&self.inner.segment_readers
}
/// Returns the segment_reader associated with the given segment_ord
pub fn segment_reader(&self, segment_ord: u32) -> &SegmentReader {
&self.inner.segment_readers[segment_ord as usize]
pub fn segment_reader(&self, segment_ord: u32) -> &dyn SegmentReader {
self.inner.segment_readers[segment_ord as usize].as_ref()
}
/// Runs a query on the segment readers wrapped by the searcher.
@@ -201,7 +352,7 @@ impl Searcher {
} else {
EnableScoring::disabled_from_searcher(self)
};
let executor = self.inner.index.search_executor();
let executor = self.search_executor();
self.search_with_executor(query, collector, executor, enabled_scoring)
}
@@ -229,7 +380,11 @@ impl Searcher {
let segment_readers = self.segment_readers();
let fruits = executor.map(
|(segment_ord, segment_reader)| {
collector.collect_segment(weight.as_ref(), segment_ord as u32, segment_reader)
collector.collect_segment(
weight.as_ref(),
segment_ord as u32,
segment_reader.as_ref(),
)
},
segment_readers.iter().enumerate(),
)?;
@@ -257,19 +412,17 @@ impl From<Arc<SearcherInner>> for Searcher {
/// It guarantees that the `Segment` will not be removed before
/// the destruction of the `Searcher`.
pub(crate) struct SearcherInner {
schema: Schema,
index: Index,
segment_readers: Vec<SegmentReader>,
store_readers: Vec<StoreReader>,
context: SearcherContext,
segment_readers: Vec<Arc<dyn SegmentReader>>,
store_readers: Vec<Box<dyn StoreReader>>,
generation: TrackedObject<SearcherGeneration>,
}
impl SearcherInner {
/// Creates a new `Searcher`
pub(crate) fn new(
schema: Schema,
index: Index,
segment_readers: Vec<SegmentReader>,
context: SearcherContext,
segment_readers: Vec<Arc<dyn SegmentReader>>,
generation: TrackedObject<SearcherGeneration>,
doc_store_cache_num_blocks: usize,
) -> io::Result<SearcherInner> {
@@ -281,14 +434,13 @@ impl SearcherInner {
generation.segments(),
"Set of segments referenced by this Searcher and its SearcherGeneration must match"
);
let store_readers: Vec<StoreReader> = segment_readers
let store_readers: Vec<Box<dyn StoreReader>> = segment_readers
.iter()
.map(|segment_reader| segment_reader.get_store_reader(doc_store_cache_num_blocks))
.collect::<io::Result<Vec<_>>>()?;
Ok(SearcherInner {
schema,
index,
context,
segment_readers,
store_readers,
generation,
@@ -301,7 +453,7 @@ impl fmt::Debug for Searcher {
let segment_ids = self
.segment_readers()
.iter()
.map(SegmentReader::segment_id)
.map(|segment_reader| segment_reader.segment_id())
.collect::<Vec<_>>();
write!(f, "Searcher({segment_ids:?})")
}

View File

@@ -7,8 +7,8 @@ use crate::query::TermQuery;
use crate::schema::{Field, IndexRecordOption, Schema, INDEXED, STRING, TEXT};
use crate::tokenizer::TokenizerManager;
use crate::{
Directory, DocSet, Index, IndexBuilder, IndexReader, IndexSettings, IndexWriter, ReloadPolicy,
TantivyDocument, Term,
Directory, DocSet, Executor, Index, IndexBuilder, IndexReader, IndexSettings, IndexWriter,
ReloadPolicy, Searcher, SearcherContext, TantivyDocument, Term,
};
#[test]
@@ -300,6 +300,40 @@ fn test_single_segment_index_writer() -> crate::Result<()> {
Ok(())
}
#[test]
fn test_searcher_from_external_segment_readers() -> crate::Result<()> {
let mut schema_builder = Schema::builder();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema.clone());
let mut writer: IndexWriter = index.writer_for_tests()?;
writer.add_document(doc!(text_field => "hello"))?;
writer.add_document(doc!(text_field => "hello"))?;
writer.commit()?;
let reader = index.reader()?;
let searcher = reader.searcher();
let segment_readers = searcher.segment_readers().to_vec();
let context = SearcherContext::new(
schema,
Executor::single_thread(),
TokenizerManager::default(),
TokenizerManager::default(),
);
let custom_searcher =
Searcher::from_segment_readers_with_generation_id(context, segment_readers, 42)?;
let term_query = TermQuery::new(
Term::from_field_text(text_field, "hello"),
IndexRecordOption::Basic,
);
let count = custom_searcher.search(&term_query, &Count)?;
assert_eq!(count, 2);
assert_eq!(custom_searcher.generation().generation_id(), 42);
assert_eq!(custom_searcher.segment_readers().len(), 1);
Ok(())
}
#[test]
fn test_merging_segment_update_docfreq() {
let mut schema_builder = Schema::builder();

View File

@@ -167,7 +167,9 @@ impl CompositeFile {
.map(|byte_range| self.data.slice(byte_range.clone()))
}
/// Returns the space usage per field in this composite file.
/// Returns per-field byte usage for all slices stored in this composite file.
///
/// The provided `schema` is used to resolve field ids into field names.
pub fn space_usage(&self, schema: &Schema) -> PerFieldSpaceUsage {
let mut fields = Vec::new();
for (&field_addr, byte_range) in &self.offsets_index {

View File

@@ -1,6 +1,7 @@
use std::borrow::BorrowMut;
use std::ops::{Deref as _, DerefMut as _};
use common::{BitSet, TinySet};
use common::BitSet;
use crate::fastfield::AliveBitSet;
use crate::DocId;
@@ -16,12 +17,6 @@ pub const TERMINATED: DocId = i32::MAX as u32;
/// exactly this size as long as we can fill the buffer.
pub const COLLECT_BLOCK_BUFFER_LEN: usize = 64;
/// Number of `TinySet` (64-bit) buckets in a block used by [`DocSet::fill_bitset_block`].
pub const BLOCK_NUM_TINYBITSETS: usize = 16;
/// Number of doc IDs covered by one block: `BLOCK_NUM_TINYBITSETS * 64 = 1024`.
pub const BLOCK_WINDOW: u32 = BLOCK_NUM_TINYBITSETS as u32 * 64;
/// Represents an iterable set of sorted doc ids.
pub trait DocSet: Send {
/// Goes to the next element.
@@ -181,31 +176,6 @@ pub trait DocSet: Send {
self.size_hint() as u64
}
/// Fills a bitmask representing which documents in `[min_doc, min_doc + BLOCK_WINDOW)` are
/// present in this docset.
///
/// The window is divided into `BLOCK_NUM_TINYBITSETS` buckets of 64 docs each.
/// Returns the next doc `>= min_doc + BLOCK_WINDOW`, or `TERMINATED` if exhausted.
fn fill_bitset_block(
&mut self,
min_doc: DocId,
mask: &mut [TinySet; BLOCK_NUM_TINYBITSETS],
) -> DocId {
self.seek(min_doc);
let horizon = min_doc + BLOCK_WINDOW;
loop {
let doc = self.doc();
if doc >= horizon {
return doc;
}
let delta = doc - min_doc;
mask[(delta / 64) as usize].insert_mut(delta % 64);
if self.advance() == TERMINATED {
return TERMINATED;
}
}
}
/// Returns the number documents matching.
/// Calling this method consumes the `DocSet`.
fn count(&mut self, alive_bitset: &AliveBitSet) -> u32 {
@@ -260,18 +230,6 @@ impl DocSet for &mut dyn DocSet {
(**self).seek_danger(target)
}
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {
(**self).fill_buffer(buffer)
}
fn fill_bitset_block(
&mut self,
min_doc: DocId,
mask: &mut [TinySet; BLOCK_NUM_TINYBITSETS],
) -> DocId {
(**self).fill_bitset_block(min_doc, mask)
}
fn doc(&self) -> u32 {
(**self).doc()
}
@@ -308,8 +266,10 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
self.deref_mut().seek(target)
}
#[inline]
fn seek_danger(&mut self, target: DocId) -> SeekDangerResult {
self.deref_mut().seek_danger(target)
let unboxed: &mut TDocSet = self.borrow_mut();
unboxed.seek_danger(target)
}
#[inline]
@@ -317,15 +277,6 @@ impl<TDocSet: DocSet + ?Sized> DocSet for Box<TDocSet> {
self.deref_mut().fill_buffer(buffer)
}
fn fill_bitset_block(
&mut self,
min_doc: DocId,
mask: &mut [TinySet; BLOCK_NUM_TINYBITSETS],
) -> DocId {
let unboxed: &mut TDocSet = &mut **self;
unboxed.fill_bitset_block(min_doc, mask)
}
#[inline]
fn doc(&self) -> DocId {
self.deref().doc()

View File

@@ -84,9 +84,7 @@ mod tests {
let mut facet = Facet::default();
facet_reader.facet_from_ord(0, &mut facet).unwrap();
assert_eq!(facet.to_path_string(), "/a/b");
let doc = searcher
.doc::<TantivyDocument>(DocAddress::new(0u32, 0u32))
.unwrap();
let doc = searcher.doc(DocAddress::new(0u32, 0u32)).unwrap();
let value = doc
.get_first(facet_field)
.and_then(|v| v.as_value().as_facet());
@@ -145,7 +143,7 @@ mod tests {
let mut facet_ords = Vec::new();
facet_ords.extend(facet_reader.facet_ords(0u32));
assert_eq!(&facet_ords, &[0u64]);
let doc = searcher.doc::<TantivyDocument>(DocAddress::new(0u32, 0u32))?;
let doc = searcher.doc(DocAddress::new(0u32, 0u32))?;
let value: Option<Facet> = doc
.get_first(facet_field)
.and_then(|v| v.as_facet())

Some files were not shown because too many files have changed in this diff Show More