This commit is contained in:
Paul Masurel
2019-01-29 11:45:30 +01:00
parent 6a547b0b5f
commit 4c93b096eb
6 changed files with 34 additions and 35 deletions

View File

@@ -36,7 +36,8 @@ impl<TCollector: Collector> Collector for CollectorWrapper<TCollector> {
let typed_fruit: Vec<TCollector::Fruit> = children
.into_iter()
.map(|untyped_fruit| {
untyped_fruit.downcast::<TCollector::Fruit>()
untyped_fruit
.downcast::<TCollector::Fruit>()
.map(|boxed_but_typed| *boxed_but_typed)
.map_err(|_| {
TantivyError::InvalidArgument("Failed to cast child fruit.".to_string())
@@ -87,7 +88,10 @@ pub struct FruitHandle<TFruit: Fruit> {
impl<TFruit: Fruit> FruitHandle<TFruit> {
pub fn extract(self, fruits: &mut MultiFruit) -> TFruit {
let boxed_fruit = fruits.sub_fruits[self.pos].take().expect("");
*boxed_fruit.downcast::<TFruit>().map_err(|_| ()).expect("Failed to downcast collector fruit.")
*boxed_fruit
.downcast::<TFruit>()
.map_err(|_| ())
.expect("Failed to downcast collector fruit.")
}
}

View File

@@ -1,4 +1,3 @@
use common::BitSet;
use common::HasLen;
use common::{BinarySerializable, VInt};
@@ -135,18 +134,18 @@ fn binary_search(arr: &[u32], mut base: usize, mut len: usize, target: u32) -> u
let mid = base + half;
let pivot = *unsafe { arr.get_unchecked(mid) };
base = if pivot > target { base } else { mid };
// Unfortunately, rustc does not compiles this to a conditional mov.
// since rustc 1.25.
//
// See https://github.com/rust-lang/rust/issues/53823 for detail
//
// unsafe {
// let pivot: u32 = *arr.get_unchecked(mid);
// asm!("cmpl $2, $1\ncmovge $3, $0"
// : "+r"(base)
// : "r"(target), "r"(pivot), "r"(mid))
// ;
// }
// Unfortunately, rustc does not compiles this to a conditional mov.
// since rustc 1.25.
//
// See https://github.com/rust-lang/rust/issues/53823 for detail
//
// unsafe {
// let pivot: u32 = *arr.get_unchecked(mid);
// asm!("cmpl $2, $1\ncmovge $3, $0"
// : "+r"(base)
// : "r"(target), "r"(pivot), "r"(mid))
// ;
// }
len -= half;
}
base + ((*unsafe { arr.get_unchecked(base) } < target) as usize)
@@ -155,12 +154,12 @@ fn binary_search(arr: &[u32], mut base: usize, mut len: usize, target: u32) -> u
fn exponential_search(arr: &[u32], target: u32) -> (usize, usize) {
let end = arr.len();
let mut begin = 0;
for &pivot in &[1,3,7,15,31,63] {
for &pivot in &[1, 3, 7, 15, 31, 63] {
if pivot >= end {
break;
}
if arr[pivot] > target {
return (begin, pivot);
return (begin, pivot);
}
begin = pivot;
}
@@ -642,6 +641,7 @@ impl<'b> Streamer<'b> for BlockSegmentPostings {
#[cfg(test)]
mod tests {
use super::binary_search;
use super::exponential_search;
use super::search_within_block;
use super::BlockSegmentPostings;
@@ -657,12 +657,11 @@ mod tests {
use schema::INT_INDEXED;
use DocId;
use SkipResult;
use super::binary_search;
#[test]
fn test_binary_search() {
let len: usize = 50;
let arr: Vec<u32> = (0..len).map(|el| 1u32 + (el as u32)*2).collect();
let arr: Vec<u32> = (0..len).map(|el| 1u32 + (el as u32) * 2).collect();
for target in 1..*arr.last().unwrap() {
let res = binary_search(&arr[..], 0, len, target);
if res > 0 {
@@ -702,10 +701,10 @@ mod tests {
#[test]
fn test_exponentiel_search() {
assert_eq!(exponential_search(&[1, 2],0), (0, 1));
assert_eq!(exponential_search(&[1, 2], 1 ), (0, 1));
assert_eq!(exponential_search(&[1, 2], 0), (0, 1));
assert_eq!(exponential_search(&[1, 2], 1), (0, 1));
assert_eq!(
exponential_search(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], 7 ),
exponential_search(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], 7),
(3, 7)
);
}

View File

@@ -1,5 +1,4 @@
use core::SegmentReader;
use downcast_rs::Downcast;
use query::intersect_scorers;
use query::score_combiner::{DoNothingCombiner, ScoreCombiner, SumWithCoordsCombiner};
use query::term_query::TermScorer;
@@ -23,13 +22,11 @@ where
}
{
let is_all_term_queries = scorers.iter().all(|scorer| {
scorer.is::<TermScorer>()
});
let is_all_term_queries = scorers.iter().all(|scorer| scorer.is::<TermScorer>());
if is_all_term_queries {
let scorers: Vec<TermScorer> = scorers
.into_iter()
.map(|scorer| *(scorer.downcast::<TermScorer>().map_err(|_| ()).unwrap() ))
.map(|scorer| *(scorer.downcast::<TermScorer>().map_err(|_| ()).unwrap()))
.collect();
let scorer: Box<Scorer> = Box::new(Union::<TermScorer, TScoreCombiner>::from(scorers));
return scorer;

View File

@@ -103,7 +103,8 @@ mod tests {
let query = query_parser.parse_query("+a b").unwrap();
let weight = query.weight(&searcher, true).unwrap();
let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
assert!(scorer.is::<RequiredOptionalScorer<Box<Scorer>, Box<Scorer>, SumWithCoordsCombiner>>());
assert!(scorer
.is::<RequiredOptionalScorer<Box<Scorer>, Box<Scorer>, SumWithCoordsCombiner>>());
}
{
let query = query_parser.parse_query("+a b").unwrap();
@@ -111,8 +112,7 @@ mod tests {
let scorer = weight.scorer(searcher.segment_reader(0u32)).unwrap();
assert!(scorer.is::<TermScorer>());
}
}
}
#[test]
pub fn test_boolean_query() {

View File

@@ -1,9 +1,9 @@
use docset::{DocSet, SkipResult};
use query::term_query::TermScorer;
use query::EmptyScorer;
use query::Scorer;
use DocId;
use Score;
use query::term_query::TermScorer;
/// Returns the intersection scorer.
///
@@ -24,9 +24,9 @@ pub fn intersect_scorers(mut scorers: Vec<Box<Scorer>>) -> Box<Scorer> {
(Some(single_docset), None) => single_docset,
(Some(left), Some(right)) => {
{
let all_term_scorers = [&left, &right].iter().all(|&scorer| {
scorer.is::<TermScorer>()
});
let all_term_scorers = [&left, &right]
.iter()
.all(|&scorer| scorer.is::<TermScorer>());
if all_term_scorers {
let left = *(left.downcast::<TermScorer>().map_err(|_| ()).unwrap());
let right = *(right.downcast::<TermScorer>().map_err(|_| ()).unwrap());

View File

@@ -25,7 +25,6 @@ pub trait Scorer: downcast_rs::Downcast + DocSet + 'static {
impl_downcast!(Scorer);
impl Scorer for Box<Scorer> {
fn score(&mut self) -> Score {
self.deref_mut().score()