diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 081b62145..a1067aba4 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -2,11 +2,13 @@ use std::fmt; use std::marker::PhantomData; use std::sync::Arc; -use columnar::ColumnValues; +use columnar::{ColumnValues, StrColumn}; use serde::{Deserialize, Serialize}; use super::Collector; -use crate::collector::custom_score_top_collector::CustomScoreTopCollector; +use crate::collector::custom_score_top_collector::{ + CustomScoreTopCollector, CustomScoreTopSegmentCollector, +}; use crate::collector::top_collector::{ComparableDoc, TopCollector, TopSegmentCollector}; use crate::collector::tweak_score_top_collector::TweakedScoreTopCollector; use crate::collector::{ @@ -14,6 +16,7 @@ use crate::collector::{ }; use crate::fastfield::{FastFieldNotAvailableError, FastValue}; use crate::query::Weight; +use crate::termdict::TermOrdinal; use crate::{DocAddress, DocId, Order, Score, SegmentOrdinal, SegmentReader, TantivyError}; struct FastFieldConvertCollector< @@ -83,6 +86,163 @@ where } } +struct StringConvertCollector { + pub collector: CustomScoreTopCollector, + pub field: String, + order: Order, + limit: usize, + offset: usize, +} + +impl Collector for StringConvertCollector { + type Fruit = Vec<(String, DocAddress)>; + + type Child = StringConvertSegmentCollector; + + fn for_segment( + &self, + segment_local_id: crate::SegmentOrdinal, + segment: &SegmentReader, + ) -> crate::Result { + let schema = segment.schema(); + let field = schema.get_field(&self.field)?; + let field_entry = schema.get_field_entry(field); + if !field_entry.is_fast() { + return Err(TantivyError::SchemaError(format!( + "Field {:?} is not a fast field.", + field_entry.name() + ))); + } + let requested_type = crate::schema::Type::Str; + let schema_type = field_entry.field_type().value_type(); + if schema_type != requested_type { + return Err(TantivyError::SchemaError(format!( + "Field {:?} is of type {schema_type:?}!={requested_type:?}", + field_entry.name() + ))); + } + let ff = segment + .fast_fields() + .str(&self.field)? + .expect("ff should be a str field"); + Ok(StringConvertSegmentCollector { + collector: self.collector.for_segment(segment_local_id, segment)?, + ff, + order: self.order.clone(), + }) + } + + fn requires_scoring(&self) -> bool { + self.collector.requires_scoring() + } + + fn merge_fruits( + &self, + child_fruits: Vec<::Fruit>, + ) -> crate::Result { + if self.limit == 0 { + return Ok(Vec::new()); + } + if self.order.is_desc() { + let mut top_collector: TopNComputer<_, _, true> = + TopNComputer::new(self.limit + self.offset); + for child_fruit in child_fruits { + for (feature, doc) in child_fruit { + top_collector.push(feature, doc); + } + } + Ok(top_collector + .into_sorted_vec() + .into_iter() + .skip(self.offset) + .map(|cdoc| (cdoc.feature, cdoc.doc)) + .collect()) + } else { + let mut top_collector: TopNComputer<_, _, false> = + TopNComputer::new(self.limit + self.offset); + for child_fruit in child_fruits { + for (feature, doc) in child_fruit { + top_collector.push(feature, doc); + } + } + + Ok(top_collector + .into_sorted_vec() + .into_iter() + .skip(self.offset) + .map(|cdoc| (cdoc.feature, cdoc.doc)) + .collect()) + } + } +} + +struct StringConvertSegmentCollector { + pub collector: CustomScoreTopSegmentCollector, + ff: StrColumn, + order: Order, +} + +impl SegmentCollector for StringConvertSegmentCollector { + type Fruit = Vec<(String, DocAddress)>; + + fn collect(&mut self, doc: DocId, score: Score) { + self.collector.collect(doc, score); + } + + fn harvest(self) -> Vec<(String, DocAddress)> { + let top_ordinals: Vec<(TermOrdinal, DocAddress)> = self.collector.harvest(); + + // Collect terms. + let mut terms: Vec = Vec::with_capacity(top_ordinals.len()); + let result = if self.order.is_asc() { + self.ff.dictionary().sorted_ords_to_term_cb( + top_ordinals.iter().map(|(term_ord, _)| u64::MAX - term_ord), + |term| { + terms.push( + std::str::from_utf8(term) + .expect("Failed to decode term as unicode") + .to_owned(), + ); + Ok(()) + }, + ) + } else { + self.ff.dictionary().sorted_ords_to_term_cb( + top_ordinals.iter().rev().map(|(term_ord, _)| *term_ord), + |term| { + terms.push( + std::str::from_utf8(term) + .expect("Failed to decode term as unicode") + .to_owned(), + ); + Ok(()) + }, + ) + }; + + assert!( + result.expect("Failed to read terms from term dictionary"), + "Not all terms were matched in segment." + ); + + // Zip them back with their docs. + if self.order.is_asc() { + terms + .into_iter() + .zip(top_ordinals) + .map(|(term, (_, doc))| (term, doc)) + .collect() + } else { + terms + .into_iter() + .rev() + .zip(top_ordinals) + .map(|(term, (_, doc))| (term, doc)) + .collect() + } + } +} + /// The `TopDocs` collector keeps track of the top `K` documents /// sorted by their score. /// @@ -410,6 +570,30 @@ impl TopDocs { } } + /// Like `order_by_fast_field`, but for a `String` fast field. + pub fn order_by_string_fast_field( + self, + fast_field: impl ToString, + order: Order, + ) -> impl Collector> { + let limit = self.0.limit; + let offset = self.0.offset; + let u64_collector = CustomScoreTopCollector::new( + ScorerByField { + field: fast_field.to_string(), + order: order.clone(), + }, + self.0.into_tscore(), + ); + StringConvertCollector { + collector: u64_collector, + field: fast_field.to_string(), + order, + limit, + offset, + } + } + /// Ranks the documents using a custom score. /// /// This method offers a convenient way to tweak or replace @@ -1214,6 +1398,94 @@ mod tests { Ok(()) } + #[test] + fn test_top_field_collector_string() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + let city = schema_builder.add_text_field("city", TEXT | FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + let mut index_writer = index.writer_for_tests()?; + index_writer.add_document(doc!( + city => "austin", + ))?; + index_writer.add_document(doc!( + city => "greenville", + ))?; + index_writer.add_document(doc!( + city => "tokyo", + ))?; + index_writer.commit()?; + + fn query( + index: &Index, + order: Order, + limit: usize, + offset: usize, + ) -> crate::Result> { + let searcher = index.reader()?.searcher(); + let top_collector = TopDocs::with_limit(limit) + .and_offset(offset) + .order_by_string_fast_field("city", order); + searcher.search(&AllQuery, &top_collector) + } + + assert_eq!( + &query(&index, Order::Desc, 3, 0)?, + &[ + ("tokyo".to_owned(), DocAddress::new(0, 2)), + ("greenville".to_owned(), DocAddress::new(0, 1)), + ("austin".to_owned(), DocAddress::new(0, 0)), + ] + ); + + assert_eq!( + &query(&index, Order::Desc, 2, 0)?, + &[ + ("tokyo".to_owned(), DocAddress::new(0, 2)), + ("greenville".to_owned(), DocAddress::new(0, 1)), + ] + ); + + assert_eq!(&query(&index, Order::Desc, 3, 3)?, &[]); + + assert_eq!( + &query(&index, Order::Desc, 2, 1)?, + &[ + ("greenville".to_owned(), DocAddress::new(0, 1)), + ("austin".to_owned(), DocAddress::new(0, 0)), + ] + ); + + assert_eq!( + &query(&index, Order::Asc, 3, 0)?, + &[ + ("austin".to_owned(), DocAddress::new(0, 0)), + ("greenville".to_owned(), DocAddress::new(0, 1)), + ("tokyo".to_owned(), DocAddress::new(0, 2)), + ] + ); + + assert_eq!( + &query(&index, Order::Asc, 2, 1)?, + &[ + ("greenville".to_owned(), DocAddress::new(0, 1)), + ("tokyo".to_owned(), DocAddress::new(0, 2)), + ] + ); + + assert_eq!( + &query(&index, Order::Asc, 2, 0)?, + &[ + ("austin".to_owned(), DocAddress::new(0, 0)), + ("greenville".to_owned(), DocAddress::new(0, 1)), + ] + ); + + assert_eq!(&query(&index, Order::Asc, 3, 3)?, &[]); + + Ok(()) + } + #[test] #[should_panic] fn test_field_does_not_exist() {