From 5dae6e6bbc5ea9638e5e53161627e951639870e6 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Sun, 18 Feb 2018 10:28:43 +0900 Subject: [PATCH] Downcast `TermScorer` for intersection when all legs are TermScorers --- Cargo.toml | 1 + src/lib.rs | 3 +++ src/query/boolean_query/boolean_weight.rs | 25 +++++++++++++++++++++-- src/query/scorer.rs | 5 ++++- 4 files changed, 31 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2936fa760..ab12a7ae7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,7 @@ error-chain = "0.8" owning_ref = "0.3" stable_deref_trait = "1.0.0" rust-stemmers = "0.1.0" +downcast = "0.9" [target.'cfg(windows)'.dependencies] winapi = "0.2" diff --git a/src/lib.rs b/src/lib.rs index 78466f126..239ed5b52 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -167,6 +167,9 @@ extern crate test; extern crate tinysegmenter; +#[macro_use] +extern crate downcast; + #[cfg(test)] mod functional_test; diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index 5f71d024f..6db34a170 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -4,6 +4,9 @@ use postings::{Intersection, Union}; use std::collections::HashMap; use query::EmptyScorer; use query::Scorer; +use downcast::Downcast; +use query::term_query::TermScorer; +use std::borrow::Borrow; use query::Exclude; use query::Occur; use query::RequiredOptionalScorer; @@ -59,8 +62,26 @@ impl BooleanWeight { if scorers.len() == 1 { scorers.into_iter().next().unwrap() } else { - let scorer: Box = box Intersection::from(scorers); - scorer + if scorers + .iter() + .all(|scorer| { + let scorer_ref:&Scorer = scorer.borrow(); + Downcast::::is_type(scorer_ref) + }) { + let scorers: Vec = scorers.into_iter() + .map(|scorer| { + *Downcast::::downcast(scorer) + .expect("downcasting should not have failed, we\ + checked in advance that the type were correct.") + + }) + .collect(); + let scorer: Box = box Intersection::from(scorers); + scorer + } else { + let scorer: Box = box Intersection::from(scorers); + scorer + } } }); diff --git a/src/query/scorer.rs b/src/query/scorer.rs index cbe04db38..82f0c3a27 100644 --- a/src/query/scorer.rs +++ b/src/query/scorer.rs @@ -5,11 +5,12 @@ use collector::Collector; use postings::SkipResult; use common::BitSet; use std::ops::DerefMut; +use downcast; /// Scored set of documents matching a query within a specific segment. /// /// See [`Query`](./trait.Query.html). -pub trait Scorer: DocSet + 'static { +pub trait Scorer: downcast::Any + DocSet + 'static { /// Returns the score. /// /// This method will perform a bit of computation and is not cached. @@ -24,6 +25,8 @@ pub trait Scorer: DocSet + 'static { } } +downcast!(Scorer); + impl<'a> Scorer for Box { fn score(&mut self) -> Score { self.deref_mut().score()