From c0452766e66e3f146d8c2a5469b71854342d63ee Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Mon, 30 Dec 2024 14:28:13 +0100 Subject: [PATCH] add mixed AND OR test, fix buffered_union --- src/docset.rs | 14 +++--- src/query/mod.rs | 74 ++++++++++++++++++++++++++++++- src/query/union/buffered_union.rs | 33 ++++++++++---- 3 files changed, 105 insertions(+), 16 deletions(-) diff --git a/src/docset.rs b/src/docset.rs index 6fe3ec9c5..01ea1125a 100644 --- a/src/docset.rs +++ b/src/docset.rs @@ -53,14 +53,16 @@ pub trait DocSet: Send { /// Seeks to the target if possible and returns true if the target is in the DocSet. /// - /// DocSets that already have an efficient `seek` method don't need to implement `seek_exact`. - /// All wrapper DocSets should forward `seek_exact` to the underlying DocSet. + /// DocSets that already have an efficient `seek` method don't need to implement + /// `seek_into_the_danger_zone`. All wrapper DocSets should forward + /// `seek_into_the_danger_zone` to the underlying DocSet. /// /// ## API Behaviour - /// If `seek_exact` is returning true, a call to `doc()` has to return target. - /// If `seek_exact` is returning false, a call to `doc()` may return any doc between - /// the last doc that matched and target or a doc that is a valid next hit after target. - /// The DocSet is considered to be in an invalid state until `seek_exact` returns true again. + /// If `seek_into_the_danger_zone` is returning true, a call to `doc()` has to return target. + /// If `seek_into_the_danger_zone` is returning false, a call to `doc()` may return any doc + /// between the last doc that matched and target or a doc that is a valid next hit after + /// target. The DocSet is considered to be in an invalid state until + /// `seek_into_the_danger_zone` returns true again. /// /// `target` needs to be equal or larger than `doc` when in a valid state. /// diff --git a/src/query/mod.rs b/src/query/mod.rs index d609a0402..478f26efb 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -70,9 +70,81 @@ pub use self::weight::Weight; #[cfg(test)] mod tests { + use crate::collector::TopDocs; + use crate::query::phrase_query::tests::create_index; use crate::query::QueryParser; use crate::schema::{Schema, TEXT}; - use crate::{Index, Term}; + use crate::{DocAddress, Index, Term}; + + #[test] + pub fn test_mixed_intersection_and_union() -> crate::Result<()> { + let index = create_index(&["a b", "a c", "a b c", "b"])?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader()?.searcher(); + + let do_search = |term: &str| { + let query = QueryParser::for_index(&index, vec![text_field]) + .parse_query(term) + .unwrap(); + let top_docs: Vec<(f32, DocAddress)> = + searcher.search(&query, &TopDocs::with_limit(10)).unwrap(); + + top_docs.iter().map(|el| el.1.doc_id).collect::>() + }; + + assert_eq!(do_search("a AND b"), vec![0, 2]); + assert_eq!(do_search("(a OR b) AND C"), vec![2, 1]); + // The intersection code has special code for more than 2 intersections + // left, right + others + // The will place the union in the "others" insersection to that seek_into_the_danger_zone + // is called + assert_eq!( + do_search("(a OR b) AND (c OR a) AND (b OR c)"), + vec![2, 1, 0] + ); + + Ok(()) + } + + #[test] + pub fn test_mixed_intersection_and_union_with_skip() -> crate::Result<()> { + // Test 4096 skip in BufferedUnionScorer + let mut data: Vec<&str> = Vec::new(); + data.push("a b"); + let zz_data = vec!["z z"; 5000]; + data.extend_from_slice(&zz_data); + data.extend_from_slice(&["a c"]); + data.extend_from_slice(&zz_data); + data.extend_from_slice(&["a b c", "b"]); + let index = create_index(&data)?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader()?.searcher(); + + let do_search = |term: &str| { + let query = QueryParser::for_index(&index, vec![text_field]) + .parse_query(term) + .unwrap(); + let top_docs: Vec<(f32, DocAddress)> = + searcher.search(&query, &TopDocs::with_limit(10)).unwrap(); + + top_docs.iter().map(|el| el.1.doc_id).collect::>() + }; + + assert_eq!(do_search("a AND b"), vec![0, 10002]); + assert_eq!(do_search("(a OR b) AND C"), vec![10002, 5001]); + // The intersection code has special code for more than 2 intersections + // left, right + others + // The will place the union in the "others" insersection to that seek_into_the_danger_zone + // is called + assert_eq!( + do_search("(a OR b) AND (c OR a) AND (b OR c)"), + vec![10002, 5001, 0] + ); + + Ok(()) + } #[test] fn test_query_terms() { diff --git a/src/query/union/buffered_union.rs b/src/query/union/buffered_union.rs index 50dde16a7..b9a0bdedc 100644 --- a/src/query/union/buffered_union.rs +++ b/src/query/union/buffered_union.rs @@ -15,7 +15,7 @@ const HORIZON: u32 = 64u32 * 64u32; // This function is similar except that it does is not unstable, and // it does not keep the original vector ordering. // -// Also, it does not "yield" any elements. +// Elements are dropped and not yielded. fn unordered_drain_filter(v: &mut Vec, mut predicate: P) where P: FnMut(&mut T) -> bool { let mut i = 0; @@ -141,6 +141,11 @@ impl BufferedUnionScorer bool { + let gap = target - self.offset; + gap < HORIZON + } } impl DocSet for BufferedUnionScorer @@ -216,15 +221,25 @@ where } fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { - let is_hit = self - .docsets - .iter_mut() - .all(|docset| docset.seek_into_the_danger_zone(target)); - // The API requires the DocSet to be in a valid state when `seek_exact` returns true. - if is_hit { - self.seek(target); + if self.is_in_horizon(target) { + // Our value is within the buffered horizon and the docset may already have been + // processed and removed, so we need to use seek, which uses the regular advance. + self.seek(target) == target + } else { + // The docsets are not in the buffered range, so we can use seek_into_the_danger_zone + // of the underlying docsets + let is_hit = self + .docsets + .iter_mut() + .any(|docset| docset.seek_into_the_danger_zone(target)); + + // The API requires the DocSet to be in a valid state when `seek_into_the_danger_zone` + // returns true. + if is_hit { + self.seek(target); + } + is_hit } - is_hit } fn doc(&self) -> DocId {