feat: impl vector index query (#7564)

* feat: impl vector index query

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* feat: remove VectorSearchRule and merge it into scan hint rule

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* refactor: vector search hint

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* test: join and subquery

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* fix: clippy when feature disabled

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* fix: push hint only when column is non-nullable or an explicit IS NOT NULL filter exists

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* fix: transformed = true

Co-authored-by: Yingwen <realevenyag@gmail.com>
Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* chore: remove adapter vector hint

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* chore: revert transformed

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

---------

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>
Co-authored-by: Yingwen <realevenyag@gmail.com>
This commit is contained in:
dennis zhuang
2026-01-28 11:40:56 +08:00
committed by GitHub
parent 124478f577
commit 238bc4fa2c
13 changed files with 1208 additions and 16 deletions

View File

@@ -73,6 +73,8 @@ jobs:
save-if: ${{ github.ref == 'refs/heads/main' }}
- name: Run cargo check
run: cargo check --locked --workspace --all-targets
- name: Run cargo check (all features)
run: cargo check --locked --workspace --all-targets --all-features
toml:
if: ${{ github.repository == 'GreptimeTeam/greptimedb' }}

View File

@@ -18,7 +18,7 @@ default = [
]
enterprise = ["common-meta/enterprise", "frontend/enterprise", "meta-srv/enterprise"]
tokio-console = ["common-telemetry/tokio-console"]
vector_index = ["mito2/vector_index"]
vector_index = ["mito2/vector_index", "query/vector_index"]
[lints]
workspace = true

View File

@@ -13,7 +13,7 @@
// limitations under the License.
mod convert;
mod distance;
pub mod distance;
mod elem_avg;
mod elem_product;
mod elem_sum;
@@ -33,6 +33,7 @@ use std::borrow::Cow;
use datafusion_common::{DataFusionError, Result, ScalarValue, utils};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use datatypes::arrow::array::new_empty_array;
use crate::function_registry::FunctionRegistry;
use crate::scalars::vector::impl_conv::as_veclit;
@@ -128,6 +129,11 @@ where
}
let len = ensure_same_length(&[arg0, arg1])?;
if len == 0 {
return Ok(ColumnarValue::Array(new_empty_array(
args.return_field.data_type(),
)));
}
let mut results = Vec::with_capacity(len);
for i in 0..len {
let v0 = try_get_scalar_value!(arg0, i);
@@ -155,6 +161,11 @@ where
}
let len = ensure_same_length(&[arg0, arg1])?;
if len == 0 {
return Ok(ColumnarValue::Array(new_empty_array(
args.return_field.data_type(),
)));
}
let mut results = Vec::with_capacity(len);
match (arg0, arg1) {
@@ -210,6 +221,11 @@ where
};
let len = arg0.len();
if len == 0 {
return Ok(ColumnarValue::Array(new_empty_array(
args.return_field.data_type(),
)));
}
let mut results = Vec::with_capacity(len);
for i in 0..len {
let v = ScalarValue::try_from_array(arg0, i)?;

View File

@@ -16,6 +16,10 @@ mod cos;
mod dot;
mod l2sq;
/// Registered name of the cosine-distance function.
pub const VEC_COS_DISTANCE: &str = "vec_cos_distance";
/// Registered name of the squared-L2-distance function.
pub const VEC_L2SQ_DISTANCE: &str = "vec_l2sq_distance";
/// Registered name of the dot-product function.
pub const VEC_DOT_PRODUCT: &str = "vec_dot_product";
use std::borrow::Cow;
use std::fmt::Display;
@@ -109,9 +113,9 @@ macro_rules! define_distance_function {
};
}
define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos);
define_distance_function!(L2SqDistanceFunction, "vec_l2sq_distance", l2sq::l2sq);
define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot);
define_distance_function!(CosDistanceFunction, VEC_COS_DISTANCE, cos::cos);
define_distance_function!(L2SqDistanceFunction, VEC_L2SQ_DISTANCE, l2sq::l2sq);
define_distance_function!(DotProductFunction, VEC_DOT_PRODUCT, dot::dot);
#[cfg(test)]
mod tests {

View File

@@ -165,6 +165,8 @@ pub(crate) struct ScanMetricsSet {
rows_bloom_filtered: usize,
/// Number of rows filtered by vector index.
rows_vector_filtered: usize,
/// Number of rows selected by vector index.
rows_vector_selected: usize,
/// Number of rows filtered by precise filter.
rows_precise_filtered: usize,
/// Number of index result cache hits for fulltext index.
@@ -291,6 +293,7 @@ impl fmt::Debug for ScanMetricsSet {
rows_inverted_filtered,
rows_bloom_filtered,
rows_vector_filtered,
rows_vector_selected,
rows_precise_filtered,
fulltext_index_cache_hit,
fulltext_index_cache_miss,
@@ -384,6 +387,9 @@ impl fmt::Debug for ScanMetricsSet {
if *rows_vector_filtered > 0 {
write!(f, ", \"rows_vector_filtered\":{rows_vector_filtered}")?;
}
if *rows_vector_selected > 0 {
write!(f, ", \"rows_vector_selected\":{rows_vector_selected}")?;
}
if *rows_precise_filtered > 0 {
write!(f, ", \"rows_precise_filtered\":{rows_precise_filtered}")?;
}
@@ -600,6 +606,7 @@ impl ScanMetricsSet {
rows_inverted_filtered,
rows_bloom_filtered,
rows_vector_filtered,
rows_vector_selected,
rows_precise_filtered,
fulltext_index_cache_hit,
fulltext_index_cache_miss,
@@ -636,6 +643,7 @@ impl ScanMetricsSet {
self.rows_inverted_filtered += *rows_inverted_filtered;
self.rows_bloom_filtered += *rows_bloom_filtered;
self.rows_vector_filtered += *rows_vector_filtered;
self.rows_vector_selected += *rows_vector_selected;
self.rows_precise_filtered += *rows_precise_filtered;
self.fulltext_index_cache_hit += *fulltext_index_cache_hit;

View File

@@ -979,6 +979,7 @@ impl ParquetReaderBuilder {
return;
}
};
metrics.rows_vector_selected += selection.row_count();
apply_selection_and_update_metrics(output, &selection, metrics, INDEX_TYPE_VECTOR);
}
@@ -1229,6 +1230,8 @@ pub(crate) struct ReaderFilterMetrics {
pub(crate) rows_bloom_filtered: usize,
/// Number of rows filtered by vector index.
pub(crate) rows_vector_filtered: usize,
/// Number of rows selected by vector index.
pub(crate) rows_vector_selected: usize,
/// Number of rows filtered by precise filter.
pub(crate) rows_precise_filtered: usize,
@@ -1268,6 +1271,7 @@ impl ReaderFilterMetrics {
self.rows_inverted_filtered += other.rows_inverted_filtered;
self.rows_bloom_filtered += other.rows_bloom_filtered;
self.rows_vector_filtered += other.rows_vector_filtered;
self.rows_vector_selected += other.rows_vector_selected;
self.rows_precise_filtered += other.rows_precise_filtered;
self.fulltext_index_cache_hit += other.fulltext_index_cache_hit;

View File

@@ -9,6 +9,7 @@ workspace = true
[features]
enterprise = []
vector_index = []
[dependencies]
ahash.workspace = true

View File

@@ -162,6 +162,188 @@ impl Stream for EmptyStream {
}
}
#[cfg(feature = "vector_index")]
mod vector_search_tests {
    //! Tests that the distributed-plan analyzer keeps a vector-distance
    //! `Sort` in the child (per-region) plan and adds a `MergeSort` above
    //! the merge scan, so TopK vector search stays pushed down.

    use std::sync::Arc;

    use common_function::function::Function;
    use common_function::scalars::udf::create_udf;
    use datafusion_expr::expr::ScalarFunction;
    use datafusion_expr::{Expr, LogicalPlanBuilder, Signature, Volatility, col, lit};
    use datatypes::schema::{ColumnSchema, SchemaBuilder};
    use store_api::storage::ConcreteDataType;
    use table::metadata::{FilterPushDownType, TableInfoBuilder, TableMeta, TableType};
    use table::table::adapter::DfTableProviderAdapter;
    use table::{Table, TableRef};

    use super::*;
    use crate::dist_plan::MergeScanLogicalPlan;

    /// Stub UDF carrying a vector-distance function name. It exists only so
    /// plans can reference the function; invoking it is an error by design.
    struct TestVectorFunction {
        name: &'static str,
        signature: Signature,
    }

    impl TestVectorFunction {
        /// Creates a stub accepting any two arguments (column + query vector).
        fn new(name: &'static str) -> Self {
            Self {
                name,
                signature: Signature::any(2, Volatility::Immutable),
            }
        }
    }

    impl std::fmt::Display for TestVectorFunction {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.name)
        }
    }

    impl Function for TestVectorFunction {
        fn name(&self) -> &str {
            self.name
        }

        fn return_type(
            &self,
            _input_types: &[datatypes::arrow::datatypes::DataType],
        ) -> datafusion_common::Result<datatypes::arrow::datatypes::DataType> {
            Ok(datatypes::arrow::datatypes::DataType::Float32)
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        /// These tests only analyze plans; execution must never reach the UDF.
        fn invoke_with_args(
            &self,
            _args: datafusion_expr::ScalarFunctionArgs,
        ) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
            Err(datafusion_common::DataFusionError::Execution(
                "test udf should not be invoked".to_string(),
            ))
        }
    }

    /// Builds a test table `t` with tag `k0`, time index `ts`, and a
    /// non-nullable 2-dimension vector column `v`.
    fn build_vector_table(table_id: TableId) -> TableRef {
        let schema = {
            let columns = vec![
                ColumnSchema::new("k0", ConcreteDataType::string_datatype(), true),
                ColumnSchema::new(
                    "ts",
                    ConcreteDataType::timestamp_millisecond_datatype(),
                    false,
                )
                .with_time_index(true),
                ColumnSchema::new("v", ConcreteDataType::vector_datatype(2), false),
            ];
            Arc::new(
                SchemaBuilder::try_from_columns(columns)
                    .unwrap()
                    .build()
                    .unwrap(),
            )
        };
        let table_meta = TableMeta {
            schema: schema.clone(),
            primary_key_indices: vec![0],
            value_indices: vec![2],
            engine: "test_engine".to_string(),
            next_column_id: 3,
            options: Default::default(),
            created_on: Default::default(),
            updated_on: Default::default(),
            partition_key_indices: vec![0],
            column_ids: vec![0, 1, 2],
        };
        let table_info = TableInfoBuilder::default()
            .table_id(table_id)
            .name("t".to_string())
            .catalog_name(DEFAULT_CATALOG_NAME)
            .schema_name(DEFAULT_SCHEMA_NAME)
            .table_version(0)
            .table_type(TableType::Base)
            .meta(table_meta)
            .build()
            .unwrap();
        let data_source = Arc::new(TestDataSource::new(schema));
        Arc::new(Table::new(
            Arc::new(table_info),
            FilterPushDownType::Unsupported,
            data_source,
        ))
    }

    /// `vec_l2sq_distance(v, '[1.0, 2.0]')` expression used as the ORDER BY key.
    fn vector_distance_expr() -> Expr {
        let udf = create_udf(Arc::new(TestVectorFunction::new("vec_l2sq_distance")));
        Expr::ScalarFunction(ScalarFunction::new_udf(
            Arc::new(udf),
            vec![
                col("v"),
                lit(ScalarValue::Utf8(Some("[1.0, 2.0]".to_string()))),
            ],
        ))
    }

    /// ORDER BY distance + LIMIT: the child plan keeps the Sort, and a
    /// MergeSort plus MergeScan appear in the rewritten plan.
    #[test]
    fn vector_search_rewrite_keeps_sort_in_child_plan() {
        init_default_ut_logging();

        let table = build_vector_table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(table),
        )));
        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .sort(vec![vector_distance_expr().sort(true, false)])
            .unwrap()
            .limit(0, Some(5))
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let plan_str = result.to_string();
        assert!(plan_str.contains("MergeSort: vec_l2sq_distance"));
        assert!(plan_str.contains("Sort: vec_l2sq_distance"));
        assert!(plan_str.contains(MergeScanLogicalPlan::name()));
    }

    /// Same as above but with an extra filter; the filter must survive the
    /// rewrite alongside the Sort/MergeSort pair.
    #[test]
    fn vector_search_rewrite_with_filter_keeps_sort_in_child_plan() {
        init_default_ut_logging();

        let table = build_vector_table(0);
        let table_source = Arc::new(DefaultTableSource::new(Arc::new(
            DfTableProviderAdapter::new(table),
        )));
        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .filter(col("k0").eq(lit("hello")))
            .unwrap()
            .sort(vec![vector_distance_expr().sort(true, false)])
            .unwrap()
            .limit(0, Some(5))
            .unwrap()
            .build()
            .unwrap();

        let config = ConfigOptions::default();
        let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
        let plan_str = result.to_string();
        assert!(plan_str.contains("MergeSort: vec_l2sq_distance"));
        assert!(plan_str.contains("Sort: vec_l2sq_distance"));
        assert!(plan_str.contains("Filter: t.k0 = Utf8(\"hello\")"));
        assert!(plan_str.contains(MergeScanLogicalPlan::name()));
    }
}
fn try_encode_decode_substrait(plan: &LogicalPlan, state: SessionState) {
let sub_plan_bytes = substrait::DFLogicalSubstraitConvertor
.encode(plan, crate::query_engine::DefaultSerializer)

View File

@@ -16,9 +16,17 @@ use std::collections::HashSet;
use std::sync::Arc;
use common_function::aggrs::aggr_wrapper::{StateMergeHelper, is_all_aggr_exprs_steppable};
#[cfg(feature = "vector_index")]
use common_function::scalars::vector::distance::{
VEC_COS_DISTANCE, VEC_DOT_PRODUCT, VEC_L2SQ_DISTANCE,
};
use common_telemetry::debug;
use datafusion::error::Result as DfResult;
#[cfg(feature = "vector_index")]
use datafusion_common::DataFusionError;
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
#[cfg(feature = "vector_index")]
use datafusion_expr::Sort;
use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode};
use promql::extension_plan::{
EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
@@ -29,6 +37,37 @@ use crate::dist_plan::MergeScanLogicalPlan;
use crate::dist_plan::analyzer::AliasMapping;
use crate::dist_plan::merge_sort::{MergeSortLogicalPlan, merge_sort_transformer};
/// Returns true when `sort` orders by exactly one expression and that
/// expression is a call to a supported vector distance function
/// (`vec_l2sq_distance`, `vec_cos_distance`, or `vec_dot_product`).
#[cfg(feature = "vector_index")]
fn is_vector_sort(sort: &Sort) -> bool {
    // Exactly one sort key is required; bail out otherwise.
    let [sort_expr] = sort.expr.as_slice() else {
        return false;
    };
    match &sort_expr.expr {
        // Compare case-insensitively against the known distance function names.
        Expr::ScalarFunction(func) => matches!(
            func.name().to_lowercase().as_str(),
            VEC_L2SQ_DISTANCE | VEC_COS_DISTANCE | VEC_DOT_PRODUCT
        ),
        _ => false,
    }
}
/// Transformer for vector TopK sorts: keeps the `Sort` in the child
/// (per-region) plan and inserts a `MergeSort` as an extra parent so the
/// frontend merges the already-sorted per-region streams.
#[cfg(feature = "vector_index")]
fn vector_sort_transformer(plan: &LogicalPlan) -> DfResult<TransformerAction> {
    if let LogicalPlan::Sort(sort) = plan {
        // MergeSort reuses the sort's input, keys, and fetch limit.
        let merge_sort =
            MergeSortLogicalPlan::new(sort.input.clone(), sort.expr.clone(), sort.fetch)
                .into_logical_plan();
        Ok(TransformerAction {
            extra_parent_plans: vec![merge_sort],
            new_child_plan: Some(LogicalPlan::Sort(sort.clone())),
        })
    } else {
        // Caller contract: this transformer is only registered for Sort nodes.
        Err(DataFusionError::Internal(format!(
            "vector_sort_transformer expects Sort, got {plan}"
        )))
    }
}
pub struct StepTransformAction {
extra_parent_plans: Vec<LogicalPlan>,
new_child_plan: Option<LogicalPlan>,
@@ -150,14 +189,19 @@ impl Categorizer {
// commutativity is needed under this situation.
Commutativity::ConditionalCommutative(None)
}
LogicalPlan::Sort(_) => {
LogicalPlan::Sort(_sort) => {
if partition_cols.is_empty() {
return Ok(Commutativity::Commutative);
}
// sort plan needs to consider column priority
// Change Sort to MergeSort which assumes the input streams are already sorted hence can be more efficient
// We should ensure the number of partition is not smaller than the number of region at present. Otherwise this would result in incorrect output.
// Change Sort to MergeSort which assumes the input streams are already sorted hence can be more efficient.
#[cfg(feature = "vector_index")]
if is_vector_sort(_sort) {
return Ok(Commutativity::TransformedCommutative {
transformer: Some(Arc::new(vector_sort_transformer)),
});
}
Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer)))
}
LogicalPlan::Join(_) => Commutativity::NonCommutative,

View File

@@ -35,7 +35,9 @@ use session::context::{QueryContext, QueryContextRef};
use snafu::ResultExt;
use store_api::metadata::RegionMetadataRef;
use store_api::region_engine::RegionEngineRef;
use store_api::storage::{RegionId, ScanRequest, TimeSeriesDistribution, TimeSeriesRowSelector};
use store_api::storage::{
RegionId, ScanRequest, TimeSeriesDistribution, TimeSeriesRowSelector, VectorSearchRequest,
};
use table::TableRef;
use table::metadata::{TableId, TableInfoRef};
use table::table::scan::RegionScanExec;
@@ -256,6 +258,14 @@ impl DummyTableProvider {
self.scan_request.lock().unwrap().series_row_selector = Some(selector);
}
/// Sets the vector search hint on the underlying [`ScanRequest`].
pub fn with_vector_search_hint(&self, hint: VectorSearchRequest) {
    self.scan_request.lock().unwrap().vector_search = Some(hint);
}
/// Returns a clone of the vector search hint currently set on the
/// underlying [`ScanRequest`], if any.
pub fn get_vector_search_hint(&self) -> Option<VectorSearchRequest> {
    self.scan_request.lock().unwrap().vector_search.clone()
}
pub fn with_sequence(&self, sequence: u64) {
self.scan_request.lock().unwrap().memtable_max_sequence = Some(sequence);
}

View File

@@ -28,6 +28,10 @@ use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
use store_api::storage::{TimeSeriesDistribution, TimeSeriesRowSelector};
use crate::dummy_catalog::DummyTableProvider;
#[cfg(feature = "vector_index")]
mod vector_search;
#[cfg(feature = "vector_index")]
use vector_search::VectorSearchState;
/// This rule will traverse the plan to collect necessary hints for leaf
/// table scan node and set them in [`ScanRequest`]. Hints include:
@@ -59,13 +63,16 @@ impl ScanHintRule {
let _ = plan.visit(&mut visitor)?;
if visitor.need_rewrite() {
plan.transform_down(&|plan| Self::set_hints(plan, &visitor))
plan.transform_down(&mut |plan| Self::set_hints(plan, &mut visitor))
} else {
Ok(Transformed::no(plan))
}
}
fn set_hints(plan: LogicalPlan, visitor: &ScanHintVisitor) -> Result<Transformed<LogicalPlan>> {
fn set_hints(
plan: LogicalPlan,
visitor: &mut ScanHintVisitor,
) -> Result<Transformed<LogicalPlan>> {
match &plan {
LogicalPlan::TableScan(table_scan) => {
let mut transformed = false;
@@ -94,6 +101,13 @@ impl ScanHintRule {
);
}
#[cfg(feature = "vector_index")]
if let Some(vector_request) = visitor
.vector_search
.take_vector_request_from_dummy(adapter, &table_scan.table_name)
{
adapter.with_vector_search_hint(vector_request);
}
transformed = true;
}
}
@@ -221,15 +235,29 @@ struct ScanHintVisitor {
/// This field stores saved `group_by` columns when all aggregate functions are `last_value`
/// and the `order_by` column which should be time index.
ts_row_selector: Option<(HashSet<Column>, Column)>,
#[cfg(feature = "vector_index")]
vector_search: VectorSearchState,
}
impl TreeNodeVisitor<'_> for ScanHintVisitor {
type Node = LogicalPlan;
fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
#[cfg(feature = "vector_index")]
if let LogicalPlan::Limit(limit) = node {
// Track LIMIT so vector hint k can be derived within the same input chain.
self.vector_search.on_limit_enter(limit);
}
// Get order requirement from sort plan
if let LogicalPlan::Sort(sort) = node {
self.order_expr = Some(sort.expr.clone());
#[cfg(feature = "vector_index")]
{
// Capture vector ORDER BY and TopK hints from sort nodes.
self.vector_search.on_sort_enter(sort);
}
}
// Get time series row selector from aggr plan
@@ -294,12 +322,17 @@ impl TreeNodeVisitor<'_> for ScanHintVisitor {
}
}
if self.ts_row_selector.is_some()
&& (matches!(node, LogicalPlan::Subquery(_)) || node.inputs().len() > 1)
{
// Avoid carrying vector hints across branching inputs (join/subquery) to prevent
// pruning results before global ordering is applied.
let is_branching = matches!(node, LogicalPlan::Subquery(_)) || node.inputs().len() > 1;
if is_branching && self.ts_row_selector.is_some() {
// clean previous time series selector hint when encounter subqueries or join
self.ts_row_selector = None;
}
#[cfg(feature = "vector_index")]
if is_branching {
self.vector_search.on_branching_enter();
}
if let LogicalPlan::Filter(filter) = node
&& let Some(group_by_exprs) = &self.ts_row_selector
@@ -312,13 +345,56 @@ impl TreeNodeVisitor<'_> for ScanHintVisitor {
}
}
#[cfg(feature = "vector_index")]
if let LogicalPlan::Filter(filter) = node {
self.vector_search.on_filter_enter(&filter.predicate);
}
#[cfg(feature = "vector_index")]
if let LogicalPlan::TableScan(table_scan) = node {
// Record vector hints at leaf scans after scope checks.
self.vector_search.on_table_scan(table_scan);
}
Ok(TreeNodeRecursion::Continue)
}
fn f_up(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
    // Unwind the vector-search scopes that were opened in `f_down` for this
    // node, restoring the enclosing chain's state.
    #[cfg(feature = "vector_index")]
    match _node {
        LogicalPlan::Limit(_) => {
            self.vector_search.on_limit_exit();
        }
        LogicalPlan::Sort(_) => {
            self.vector_search.on_sort_exit();
        }
        LogicalPlan::Filter(_) => {
            self.vector_search.on_filter_exit();
        }
        LogicalPlan::Subquery(_) => {
            self.vector_search.on_branching_exit();
        }
        // Multi-input nodes (e.g. joins) were treated as branching on the way
        // down, so close the branch scope here as well.
        _ if _node.inputs().len() > 1 => {
            self.vector_search.on_branching_exit();
        }
        _ => {}
    }

    Ok(TreeNodeRecursion::Continue)
}
}
impl ScanHintVisitor {
    /// Whether any hint was collected during traversal, i.e. whether the plan
    /// needs a rewrite pass to push the hints into table scans.
    fn need_rewrite(&self) -> bool {
        let base = self.order_expr.is_some() || self.ts_row_selector.is_some();
        // With the vector index enabled, pending vector hints also trigger a rewrite.
        #[cfg(feature = "vector_index")]
        {
            base || self.vector_search.need_rewrite()
        }
        #[cfg(not(feature = "vector_index"))]
        {
            base
        }
    }
}

View File

@@ -0,0 +1,837 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::{HashMap, HashSet, VecDeque};
use common_function::scalars::vector::distance::{
VEC_COS_DISTANCE, VEC_DOT_PRODUCT, VEC_L2SQ_DISTANCE,
};
use common_telemetry::debug;
use datafusion_common::ScalarValue;
use datafusion_expr::logical_plan::FetchType;
use datafusion_expr::utils::split_conjunction;
use datafusion_expr::{Expr, SortExpr};
use datafusion_sql::TableReference;
use datatypes::types::parse_string_to_vector_type_value;
use store_api::storage::{VectorDistanceMetric, VectorSearchRequest};
use crate::dummy_catalog::DummyTableProvider;
/// Tracks vector search hints while traversing the logical plan.
///
/// Vector search requests are emitted only when:
/// - ORDER BY uses a supported vector distance function and its direction matches the metric.
/// - A LIMIT (or Sort.fetch) is present to derive k.
/// - The hint stays within a single input chain (not across join/subquery branches).
/// - The target column is non-nullable, or an explicit IS NOT NULL filter exists.
#[derive(Default)]
pub(crate) struct VectorSearchState {
    /// Distance info extracted from the nearest enclosing `Sort`, if any.
    current_distance: Option<VectorDistanceInfo>,
    /// Limit info from the nearest enclosing `Limit` (or `Sort.fetch`), if any.
    current_limit: Option<VectorLimitInfo>,
    /// Saved `current_distance` values for enclosing sort/branch scopes.
    distance_stack: Vec<Option<VectorDistanceInfo>>,
    /// Saved `current_limit` values for enclosing limit/sort/branch scopes.
    limit_stack: Vec<Option<VectorLimitInfo>>,
    /// Columns proven non-null by `IS NOT NULL` filters in the current chain.
    non_null_columns: HashSet<ColumnKey>,
    /// Saved `non_null_columns` sets for enclosing filter/branch scopes.
    non_null_stack: Vec<HashSet<ColumnKey>>,
    /// Hints recorded per table, consumed in order by matching table scans.
    vector_hints: HashMap<TableReference, VecDeque<VectorHintEntry>>,
}
/// A vector distance call extracted from an ORDER BY expression.
#[derive(Clone)]
struct VectorDistanceInfo {
    /// Table qualifier of the vector column, when the column was qualified.
    table_reference: Option<TableReference>,
    /// Name of the vector column (first argument of the distance function).
    column_name: String,
    /// Query vector parsed from the literal second argument.
    query_vector: Vec<f32>,
    /// Distance metric implied by the distance function's name.
    metric: VectorDistanceMetric,
}
/// Pagination extracted from a `Limit` node (or `Sort.fetch`), used to derive
/// the TopK `k` for a vector search.
#[derive(Clone)]
struct VectorLimitInfo {
    fetch: usize,
    skip: usize,
}

impl VectorLimitInfo {
    /// Effective TopK: the fetched row count plus any skipped (OFFSET) rows,
    /// since skipped rows must still be produced before being dropped.
    /// Returns `None` if the sum overflows `usize`.
    fn k(&self) -> Option<usize> {
        let Self { fetch, skip } = self;
        skip.checked_add(*fetch)
    }
}
/// A fully derived hint (distance + limit) recorded at a matching table scan.
#[derive(Clone)]
struct VectorHintEntry {
    /// The distance function call driving the search.
    distance: VectorDistanceInfo,
    /// The limit used to derive `k`.
    limit: VectorLimitInfo,
    /// Whether an explicit `IS NOT NULL` filter covered the vector column.
    non_null_constraint: bool,
}
/// Identifies a column (optional table qualifier + name) for non-null tracking.
#[derive(Clone, Hash, PartialEq, Eq)]
struct ColumnKey {
    table_reference: Option<TableReference>,
    column_name: String,
}
impl VectorSearchState {
    /// Whether any vector hint was recorded, i.e. the plan needs rewriting.
    pub(crate) fn need_rewrite(&self) -> bool {
        !self.vector_hints.is_empty()
    }

    /// Opens a branching scope (join/subquery), saving and clearing the
    /// current chain's distance, limit, and non-null state.
    pub(crate) fn on_branching_enter(&mut self) {
        // Clear per-branch state so hints are only derived within a single input chain.
        self.distance_stack.push(self.current_distance.take());
        self.limit_stack.push(self.current_limit.take());
        self.non_null_stack
            .push(std::mem::take(&mut self.non_null_columns));
    }

    /// Closes a branching scope, restoring the saved chain state.
    pub(crate) fn on_branching_exit(&mut self) {
        // Restore the prior chain state after leaving the branch.
        if let Some(previous) = self.limit_stack.pop() {
            self.current_limit = previous;
        }
        if let Some(previous) = self.distance_stack.pop() {
            self.current_distance = previous;
        }
        if let Some(previous) = self.non_null_stack.pop() {
            self.non_null_columns = previous;
        }
    }

    /// Opens a `Limit` scope: saves the outer limit and installs this node's.
    pub(crate) fn on_limit_enter(&mut self, limit: &datafusion_expr::logical_plan::Limit) {
        self.limit_stack.push(self.current_limit.take());
        self.current_limit = Self::extract_limit_info(limit);
    }

    /// Closes a `Limit` scope, restoring the outer limit.
    pub(crate) fn on_limit_exit(&mut self) {
        if let Some(previous) = self.limit_stack.pop() {
            self.current_limit = previous;
        }
    }

    /// Opens a `Sort` scope: installs this sort's distance info and, when no
    /// outer LIMIT exists, derives a limit from `Sort.fetch`.
    pub(crate) fn on_sort_enter(&mut self, sort: &datafusion_expr::logical_plan::Sort) {
        // Distance is scoped to the nearest sort, while limit may be inherited from parents.
        self.distance_stack.push(self.current_distance.take());
        self.limit_stack.push(self.current_limit.clone());

        let distance = Self::extract_distance_from_sort(sort);
        self.current_distance = distance.clone();
        // Sort.fetch is a TopK limit, so we can infer k when no LIMIT is present.
        if self.current_limit.is_none() {
            self.current_limit = distance
                .as_ref()
                .and_then(|_| Self::extract_limit_from_sort(sort));
        }
    }

    /// Closes a `Sort` scope, restoring the saved limit and distance.
    pub(crate) fn on_sort_exit(&mut self) {
        if let Some(previous) = self.limit_stack.pop() {
            self.current_limit = previous;
        }
        if let Some(previous) = self.distance_stack.pop() {
            self.current_distance = previous;
        }
    }

    /// Records a vector hint for a table scan leaf if the current scope holds
    /// a matching distance + limit.
    pub(crate) fn on_table_scan(&mut self, table_scan: &datafusion_expr::logical_plan::TableScan) {
        self.record_vector_hint(table_scan);
    }

    /// Opens a `Filter` scope and collects columns proven non-null by
    /// top-level `IS NOT NULL` conjuncts of the predicate.
    pub(crate) fn on_filter_enter(&mut self, predicate: &Expr) {
        self.non_null_stack.push(self.non_null_columns.clone());
        for expr in split_conjunction(predicate) {
            if let Expr::IsNotNull(inner) = expr
                && let Expr::Column(col) = inner.as_ref()
            {
                self.non_null_columns.insert(ColumnKey {
                    table_reference: col.relation.clone(),
                    column_name: col.name.clone(),
                });
            }
        }
        // TODO: detect non-null constraints from more complex predicates (casts/functions).
    }

    /// Closes a `Filter` scope, restoring the previous non-null column set.
    pub(crate) fn on_filter_exit(&mut self) {
        if let Some(previous) = self.non_null_stack.pop() {
            self.non_null_columns = previous;
        }
    }

    /// Pops the next hint recorded for `table_name` and, if it is valid for
    /// the provider's metadata, builds a [`VectorSearchRequest`] from it.
    pub(crate) fn take_vector_request_from_dummy(
        &mut self,
        provider: &DummyTableProvider,
        table_name: &TableReference,
    ) -> Option<VectorSearchRequest> {
        let hint = self.take_vector_hint(table_name)?;
        self.build_vector_request_from_dummy(provider, table_name, &hint)
    }

    /// Validates a hint against the provider's region metadata and converts it
    /// into a [`VectorSearchRequest`]. Returns `None` when `k` overflows, the
    /// table does not match, the column is unknown, or the column is nullable
    /// without an `IS NOT NULL` filter.
    fn build_vector_request_from_dummy(
        &self,
        provider: &DummyTableProvider,
        table_name: &TableReference,
        hint: &VectorHintEntry,
    ) -> Option<VectorSearchRequest> {
        let info = &hint.distance;
        let k = hint.limit.k()?;
        if let Some(ref hint_table) = info.table_reference
            && table_name != hint_table
        {
            return None;
        }

        let metadata = provider.region_metadata();
        let column = metadata.column_by_name(&info.column_name)?;
        if column.column_schema.is_nullable() && !hint.non_null_constraint {
            debug!(
                "Skip vector hint: column '{}' is nullable without IS NOT NULL filter",
                info.column_name
            );
            return None;
        }

        Some(VectorSearchRequest {
            column_id: column.column_id,
            query_vector: info.query_vector.clone(),
            k,
            metric: info.metric,
        })
    }

    /// Extracts distance info from `distance_fn(column, literal_vector)`.
    /// Returns `None` for any other expression shape.
    fn extract_distance_info(expr: &Expr) -> Option<VectorDistanceInfo> {
        let Expr::ScalarFunction(func) = expr else {
            return None;
        };

        // Map the (case-insensitive) function name to its metric.
        let func_name = func.name().to_lowercase();
        let metric = match func_name.as_str() {
            VEC_L2SQ_DISTANCE => VectorDistanceMetric::L2sq,
            VEC_COS_DISTANCE => VectorDistanceMetric::Cosine,
            VEC_DOT_PRODUCT => VectorDistanceMetric::InnerProduct,
            _ => return None,
        };

        if func.args.len() != 2 {
            return None;
        }

        // First argument must be a plain column reference.
        let (table_reference, column_name) = match &func.args[0] {
            Expr::Column(col) => (col.relation.clone(), col.name.clone()),
            _ => return None,
        };

        // Second argument must be a literal query vector.
        let query_vector = Self::extract_query_vector(&func.args[1])?;

        Some(VectorDistanceInfo {
            table_reference,
            column_name,
            query_vector,
            metric,
        })
    }

    /// Extracts distance info from a single-key `Sort` whose direction matches
    /// the metric: ascending for L2sq/cosine, descending for inner product.
    fn extract_distance_from_sort(
        sort: &datafusion_expr::logical_plan::Sort,
    ) -> Option<VectorDistanceInfo> {
        if sort.expr.len() != 1 {
            debug!(
                "Skip vector hint: Sort has {} expressions, expected 1",
                sort.expr.len()
            );
            return None;
        }

        let sort_expr: &SortExpr = &sort.expr[0];
        let info = Self::extract_distance_info(&sort_expr.expr)?;
        // Inner product ranks larger values higher, so it needs a descending sort.
        let expected_asc = info.metric != VectorDistanceMetric::InnerProduct;
        if sort_expr.asc == expected_asc {
            Some(info)
        } else {
            None
        }
    }

    /// Extracts literal fetch/skip from a `Limit`; `None` for expression-based
    /// or absent fetch values.
    fn extract_limit_info(limit: &datafusion_expr::logical_plan::Limit) -> Option<VectorLimitInfo> {
        let fetch = match limit.get_fetch_type().ok()? {
            FetchType::Literal(fetch) => fetch?,
            FetchType::UnsupportedExpr => return None,
        };
        let skip = match limit.get_skip_type().ok()? {
            datafusion_expr::logical_plan::SkipType::Literal(skip) => skip,
            datafusion_expr::logical_plan::SkipType::UnsupportedExpr => return None,
        };
        Some(VectorLimitInfo { fetch, skip })
    }

    /// Derives a limit from `Sort.fetch` (TopK), with no skip.
    fn extract_limit_from_sort(
        sort: &datafusion_expr::logical_plan::Sort,
    ) -> Option<VectorLimitInfo> {
        let fetch = sort.fetch?;
        Some(VectorLimitInfo { fetch, skip: 0 })
    }

    /// Records a hint for `table_scan` when the current scope carries both a
    /// limit and a distance targeting this table.
    fn record_vector_hint(&mut self, table_scan: &datafusion_expr::logical_plan::TableScan) {
        let Some(limit) = self.current_limit.as_ref() else {
            return;
        };
        let Some(distance) = self.current_distance.as_ref() else {
            return;
        };
        // Only emit hints when distance+limit are present and the table matches the sort target.
        if let Some(ref hint_table) = distance.table_reference
            && hint_table != &table_scan.table_name
        {
            return;
        }

        let non_null_constraint =
            self.is_column_non_null(&distance.table_reference, &distance.column_name);

        self.vector_hints
            .entry(table_scan.table_name.clone())
            .or_default()
            .push_back(VectorHintEntry {
                distance: distance.clone(),
                limit: limit.clone(),
                non_null_constraint,
            });
    }

    /// Pops the oldest pending hint for `table_name`, dropping the queue once empty.
    fn take_vector_hint(&mut self, table_name: &TableReference) -> Option<VectorHintEntry> {
        let hints = self.vector_hints.get_mut(table_name)?;
        let hint = hints.pop_front();
        if hints.is_empty() {
            self.vector_hints.remove(table_name);
        }
        hint
    }

    /// Whether the column was proven non-null by an `IS NOT NULL` filter in
    /// the current chain.
    fn is_column_non_null(
        &self,
        table_reference: &Option<TableReference>,
        column_name: &str,
    ) -> bool {
        self.non_null_columns.contains(&ColumnKey {
            table_reference: table_reference.clone(),
            column_name: column_name.to_string(),
        })
    }

    /// Parses a literal query vector: UTF-8 literals as JSON-style arrays,
    /// binary literals as packed little-endian f32s. `None` otherwise.
    fn extract_query_vector(expr: &Expr) -> Option<Vec<f32>> {
        match expr {
            Expr::Literal(scalar, _) => match scalar {
                ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => {
                    Self::parse_json_vector(s)
                }
                ScalarValue::Binary(Some(bytes)) | ScalarValue::LargeBinary(Some(bytes)) => {
                    Self::parse_binary_vector(bytes)
                }
                _ => None,
            },
            _ => None,
        }
    }

    /// Parses a bracketed string like `"[1.0, 2.0]"` into a vector; rejects
    /// empty vectors and anything not bracket-delimited.
    fn parse_json_vector(s: &str) -> Option<Vec<f32>> {
        let trimmed = s.trim();
        if !trimmed.starts_with('[') || !trimmed.ends_with(']') {
            return None;
        }
        parse_string_to_vector_type_value(trimmed, None)
            .ok()
            .and_then(|bytes| if bytes.is_empty() { None } else { Some(bytes) })
            .and_then(|bytes| Self::parse_binary_vector(&bytes))
    }

    /// Decodes packed little-endian f32s; length must be a non-zero multiple of 4.
    fn parse_binary_vector(bytes: &[u8]) -> Option<Vec<f32>> {
        if bytes.is_empty() || !bytes.len().is_multiple_of(4) {
            return None;
        }
        Some(
            bytes
                .chunks(4)
                .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                .collect(),
        )
    }
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use api::v1::SemanticType;
use common_function::function::Function;
use common_function::scalars::udf::create_udf;
use common_function::scalars::vector::distance::{VEC_DOT_PRODUCT, VEC_L2SQ_DISTANCE};
use datafusion::datasource::DefaultTableSource;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{Column, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::logical_plan::JoinType;
use datafusion_expr::{
Expr, LogicalPlan, LogicalPlanBuilder, Signature, Subquery, Volatility, col, lit,
};
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
use datatypes::schema::ColumnSchema;
use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder};
use store_api::storage::{ConcreteDataType, VectorDistanceMetric};
use super::VectorSearchState;
use crate::dummy_catalog::DummyTableProvider;
use crate::optimizer::scan_hint::ScanHintRule;
use crate::optimizer::test_util::MetaRegionEngine;
struct TestVectorFunction {
name: &'static str,
signature: Signature,
}
impl TestVectorFunction {
fn new(name: &'static str) -> Self {
Self {
name,
signature: Signature::any(2, Volatility::Immutable),
}
}
}
impl std::fmt::Display for TestVectorFunction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.name)
}
}
impl Function for TestVectorFunction {
fn name(&self) -> &str {
self.name
}
fn return_type(
&self,
_input_types: &[datatypes::arrow::datatypes::DataType],
) -> Result<datatypes::arrow::datatypes::DataType> {
Ok(datatypes::arrow::datatypes::DataType::Float32)
}
fn signature(&self) -> &Signature {
&self.signature
}
fn invoke_with_args(
&self,
_args: datafusion_expr::ScalarFunctionArgs,
) -> Result<ColumnarValue> {
Err(DataFusionError::Execution(
"test udf should not be invoked".to_string(),
))
}
}
fn vec_distance_expr(function_name: &'static str) -> Expr {
let udf = create_udf(Arc::new(TestVectorFunction::new(function_name)));
Expr::ScalarFunction(ScalarFunction::new_udf(
Arc::new(udf),
vec![
col("v"),
lit(ScalarValue::Utf8(Some("[1.0, 2.0]".to_string()))),
],
))
}
fn build_dummy_provider(column_id: u32) -> Arc<DummyTableProvider> {
build_dummy_provider_with_nullable(column_id, false)
}
/// Builds a dummy table provider with schema
/// `(k0 string tag, ts timestamp, v vector(2))`; the vector field `v` gets
/// the requested `column_id` and nullability.
fn build_dummy_provider_with_nullable(
    column_id: u32,
    nullable_vector: bool,
) -> Arc<DummyTableProvider> {
    // Small helper so each column reads as a single line below.
    let column = |name: &str, data_type, nullable, semantic_type, column_id| ColumnMetadata {
        column_schema: ColumnSchema::new(name, data_type, nullable),
        semantic_type,
        column_id,
    };

    let mut builder = RegionMetadataBuilder::new(0.into());
    builder
        .push_column_metadata(column(
            "k0",
            ConcreteDataType::string_datatype(),
            true,
            SemanticType::Tag,
            1,
        ))
        .push_column_metadata(column(
            "ts",
            ConcreteDataType::timestamp_millisecond_datatype(),
            false,
            SemanticType::Timestamp,
            2,
        ))
        .push_column_metadata(column(
            "v",
            ConcreteDataType::vector_datatype(2),
            nullable_vector,
            SemanticType::Field,
            column_id,
        ))
        .primary_key(vec![1]);

    let metadata = Arc::new(builder.build().unwrap());
    let engine = Arc::new(MetaRegionEngine::with_metadata(metadata.clone()));
    Arc::new(DummyTableProvider::new(0.into(), engine, metadata))
}
#[test]
fn test_parse_json_vector() {
    // Well-formed JSON float arrays parse into `Vec<f32>`.
    let accepted = [
        ("[1.0, 2.0, 3.0]", vec![1.0, 2.0, 3.0]),
        ("[1.5, -2.3, 0.0]", vec![1.5, -2.3, 0.0]),
    ];
    for (input, expected) in accepted {
        assert_eq!(
            VectorSearchState::parse_json_vector(input),
            Some(expected),
            "input: {input:?}"
        );
    }
    // Malformed, empty, or non-numeric inputs are rejected.
    for input in ["invalid", "[]", "", "[", "[1.0, abc]"] {
        assert_eq!(
            VectorSearchState::parse_json_vector(input),
            None,
            "input: {input:?}"
        );
    }
}
#[test]
fn test_parse_binary_vector() {
    // Eight bytes = two little-endian f32 values.
    let bytes: Vec<u8> = [1.0f32, 2.0]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    assert_eq!(
        VectorSearchState::parse_binary_vector(&bytes),
        Some(vec![1.0, 2.0])
    );
    // A length that is not a multiple of 4 cannot encode f32 elements.
    assert_eq!(VectorSearchState::parse_binary_vector(&[1, 2, 3]), None);
}
#[test]
fn test_dummy_provider_vector_hint() {
    let provider = build_dummy_provider(10);
    let source = Arc::new(DefaultTableSource::new(provider.clone()));
    // `ORDER BY vec_l2sq_distance(v, ...) ASC LIMIT 5` is the canonical
    // top-k pattern the rule should turn into a vector search hint.
    let dist = vec_distance_expr(VEC_L2SQ_DISTANCE);
    let plan = LogicalPlanBuilder::scan_with_filters("t", source, None, vec![])
        .unwrap()
        .sort(vec![dist.sort(true, false)])
        .unwrap()
        .limit(0, Some(5))
        .unwrap()
        .build()
        .unwrap();

    let ctx = OptimizerContext::default();
    let _ = ScanHintRule.rewrite(plan, &ctx).unwrap();

    let hint = provider.get_vector_search_hint().expect("hint must be set");
    assert_eq!(hint.column_id, 10);
    assert_eq!(hint.k, 5);
    assert_eq!(hint.metric, VectorDistanceMetric::L2sq);
    assert_eq!(hint.query_vector, vec![1.0, 2.0]);
}
#[test]
fn test_limit_offset_for_vector_hint() {
    let provider = build_dummy_provider(10);
    let source = Arc::new(DefaultTableSource::new(provider.clone()));
    let dist = vec_distance_expr(VEC_L2SQ_DISTANCE);
    // `LIMIT 10 OFFSET 5`: the scan must still return enough rows for the
    // offset to be applied afterwards.
    let plan = LogicalPlanBuilder::scan_with_filters("t", source, None, vec![])
        .unwrap()
        .sort(vec![dist.sort(true, false)])
        .unwrap()
        .limit(5, Some(10))
        .unwrap()
        .build()
        .unwrap();

    let ctx = OptimizerContext::default();
    let _ = ScanHintRule.rewrite(plan, &ctx).unwrap();

    // k = offset (5) + fetch (10).
    assert_eq!(provider.get_vector_search_hint().unwrap().k, 15);
}
#[test]
fn test_inner_product_sort_direction() {
    let provider = build_dummy_provider(10);
    let source = Arc::new(DefaultTableSource::new(provider.clone()));
    let dist = vec_distance_expr(VEC_DOT_PRODUCT);
    let ctx = OptimizerContext::default();

    // Ascending order on dot product is not a similarity search, so no
    // hint should be produced.
    let asc_plan = LogicalPlanBuilder::scan_with_filters("t", source.clone(), None, vec![])
        .unwrap()
        .sort(vec![dist.clone().sort(true, false)])
        .unwrap()
        .limit(0, Some(3))
        .unwrap()
        .build()
        .unwrap();
    let _ = ScanHintRule.rewrite(asc_plan, &ctx).unwrap();
    assert!(provider.get_vector_search_hint().is_none());

    // Descending order is the inner-product similarity pattern.
    let desc_plan = LogicalPlanBuilder::scan_with_filters("t", source, None, vec![])
        .unwrap()
        .sort(vec![dist.sort(false, false)])
        .unwrap()
        .limit(0, Some(3))
        .unwrap()
        .build()
        .unwrap();
    let _ = ScanHintRule.rewrite(desc_plan, &ctx).unwrap();
    assert_eq!(
        provider.get_vector_search_hint().unwrap().metric,
        VectorDistanceMetric::InnerProduct
    );
}
#[test]
fn test_no_limit_clause() {
    let provider = build_dummy_provider(10);
    let source = Arc::new(DefaultTableSource::new(provider.clone()));
    let dist = vec_distance_expr(VEC_L2SQ_DISTANCE);
    // Without a LIMIT there is no k, so the top-k pattern is incomplete
    // and no hint should be pushed.
    let plan = LogicalPlanBuilder::scan_with_filters("t", source, None, vec![])
        .unwrap()
        .sort(vec![dist.sort(true, false)])
        .unwrap()
        .build()
        .unwrap();

    let _ = ScanHintRule
        .rewrite(plan, &OptimizerContext::default())
        .unwrap();
    assert!(provider.get_vector_search_hint().is_none());
}
#[test]
fn test_nullable_vector_requires_is_not_null_filter() {
    // The hint is pushed only when the vector column is non-nullable or an
    // explicit `IS NOT NULL` filter exists; a nullable column without such
    // a filter must not produce a hint.
    let provider = build_dummy_provider_with_nullable(10, true);
    let source = Arc::new(DefaultTableSource::new(provider.clone()));
    let dist = vec_distance_expr(VEC_L2SQ_DISTANCE);
    let plan = LogicalPlanBuilder::scan_with_filters("t", source, None, vec![])
        .unwrap()
        .sort(vec![dist.sort(true, false)])
        .unwrap()
        .limit(0, Some(5))
        .unwrap()
        .build()
        .unwrap();

    let _ = ScanHintRule
        .rewrite(plan, &OptimizerContext::default())
        .unwrap();
    assert!(provider.get_vector_search_hint().is_none());
}
#[test]
fn test_nullable_vector_with_is_not_null_filter() {
    let provider = build_dummy_provider_with_nullable(10, true);
    let source = Arc::new(DefaultTableSource::new(provider.clone()));
    let dist = vec_distance_expr(VEC_L2SQ_DISTANCE);
    // An explicit `v IS NOT NULL` filter makes the nullable column
    // eligible for the hint.
    let plan = LogicalPlanBuilder::scan_with_filters("t", source, None, vec![])
        .unwrap()
        .filter(col("v").is_not_null())
        .unwrap()
        .sort(vec![dist.sort(true, false)])
        .unwrap()
        .limit(0, Some(5))
        .unwrap()
        .build()
        .unwrap();

    let _ = ScanHintRule
        .rewrite(plan, &OptimizerContext::default())
        .unwrap();

    let hint = provider.get_vector_search_hint().expect("hint must be set");
    assert_eq!(hint.column_id, 10);
    assert_eq!(hint.k, 5);
}
#[test]
fn test_sort_fetch_for_vector_hint() {
    let provider = build_dummy_provider(10);
    let source = Arc::new(DefaultTableSource::new(provider.clone()));
    let dist = vec_distance_expr(VEC_L2SQ_DISTANCE);
    // A fetch attached directly to the Sort node (no separate Limit node)
    // also supplies k.
    let plan = LogicalPlanBuilder::scan_with_filters("t", source, None, vec![])
        .unwrap()
        .sort_with_limit(vec![dist.sort(true, false)], Some(4))
        .unwrap()
        .build()
        .unwrap();

    let _ = ScanHintRule
        .rewrite(plan, &OptimizerContext::default())
        .unwrap();
    assert_eq!(provider.get_vector_search_hint().unwrap().k, 4);
}
#[test]
fn test_limit_scoped_to_sort_branch() {
    let t1_provider = build_dummy_provider(10);
    let t2_provider = build_dummy_provider(20);
    let t1_source = Arc::new(DefaultTableSource::new(t1_provider.clone()));
    let t2_source = Arc::new(DefaultTableSource::new(t2_provider.clone()));
    let dist = vec_distance_expr(VEC_L2SQ_DISTANCE);

    // t1 carries only the Limit, t2 carries only the Sort; the limit lives
    // in a different join branch, so neither side forms a complete top-k
    // pattern and no hint may be pushed.
    let limited_scan = LogicalPlanBuilder::scan_with_filters("t1", t1_source, None, vec![])
        .unwrap()
        .limit(0, Some(5))
        .unwrap()
        .build()
        .unwrap();
    let sorted_scan = LogicalPlanBuilder::scan_with_filters("t2", t2_source, None, vec![])
        .unwrap()
        .sort(vec![dist.sort(true, false)])
        .unwrap()
        .build()
        .unwrap();

    let join_keys: (Vec<Column>, Vec<Column>) = (vec![], vec![]);
    let plan = LogicalPlanBuilder::from(limited_scan)
        .join(sorted_scan, JoinType::Inner, join_keys, None)
        .unwrap()
        .build()
        .unwrap();

    let _ = ScanHintRule
        .rewrite(plan, &OptimizerContext::default())
        .unwrap();
    assert!(t1_provider.get_vector_search_hint().is_none());
    assert!(t2_provider.get_vector_search_hint().is_none());
}
fn vec_distance_expr_qualified(
function_name: &'static str,
table_name: &str,
column_name: &str,
) -> Expr {
use datafusion_common::Column;
let udf = create_udf(Arc::new(TestVectorFunction::new(function_name)));
let qualified_col = Expr::Column(Column::new(Some(table_name.to_string()), column_name));
Expr::ScalarFunction(ScalarFunction::new_udf(
Arc::new(udf),
vec![
qualified_col,
lit(ScalarValue::Utf8(Some("[1.0, 2.0]".to_string()))),
],
))
}
#[test]
fn test_no_vector_hint_above_join() {
    let t1_provider = build_dummy_provider(10);
    let t2_provider = build_dummy_provider(20);
    let t1_source = Arc::new(DefaultTableSource::new(t1_provider.clone()));
    let t2_source = Arc::new(DefaultTableSource::new(t2_provider.clone()));

    let t1_scan = LogicalPlanBuilder::scan_with_filters("t1", t1_source, None, vec![])
        .unwrap()
        .build()
        .unwrap();
    let t2_scan = LogicalPlanBuilder::scan_with_filters("t2", t2_source, None, vec![])
        .unwrap()
        .build()
        .unwrap();
    let join_keys: (Vec<Column>, Vec<Column>) = (vec![], vec![]);
    let joined = LogicalPlanBuilder::from(t1_scan)
        .join(t2_scan, JoinType::Inner, join_keys, None)
        .unwrap()
        .build()
        .unwrap();

    // The Sort/Limit sits above the Join, not directly above a scan, so
    // the hint must not be pushed to either table.
    let dist = vec_distance_expr_qualified(VEC_L2SQ_DISTANCE, "t1", "v");
    let plan = LogicalPlanBuilder::from(joined)
        .sort(vec![dist.sort(true, false)])
        .unwrap()
        .limit(0, Some(5))
        .unwrap()
        .build()
        .unwrap();

    let _ = ScanHintRule
        .rewrite(plan, &OptimizerContext::default())
        .unwrap();
    assert!(t1_provider.get_vector_search_hint().is_none());
    assert!(t2_provider.get_vector_search_hint().is_none());
}
#[test]
fn test_no_vector_hint_above_subquery() {
    let provider = build_dummy_provider(10);
    let source = Arc::new(DefaultTableSource::new(provider.clone()));
    let scan = LogicalPlanBuilder::scan_with_filters("t", source, None, vec![])
        .unwrap()
        .build()
        .unwrap();

    // Wrap the scan in a Subquery node: the Sort/Limit above it must not
    // push a hint through the subquery boundary.
    let subquery = LogicalPlan::Subquery(Subquery {
        subquery: Arc::new(scan),
        outer_ref_columns: vec![],
        spans: Default::default(),
    });

    let dist = vec_distance_expr(VEC_L2SQ_DISTANCE);
    let plan = LogicalPlanBuilder::from(subquery)
        .sort(vec![dist.sort(true, false)])
        .unwrap()
        .limit(0, Some(5))
        .unwrap()
        .build()
        .unwrap();

    let _ = ScanHintRule
        .rewrite(plan, &OptimizerContext::default())
        .unwrap();
    assert!(provider.get_vector_search_hint().is_none());
}
#[test]
fn test_qualified_column_scopes_hint_to_correct_table() {
    let t1_provider = build_dummy_provider(10);
    let t2_provider = build_dummy_provider(20);
    let t1_source = Arc::new(DefaultTableSource::new(t1_provider.clone()));
    let t2_source = Arc::new(DefaultTableSource::new(t2_provider.clone()));

    // The distance expression explicitly references `t2.v`.
    let dist = vec_distance_expr_qualified(VEC_L2SQ_DISTANCE, "t2", "v");
    let t1_plan = LogicalPlanBuilder::scan_with_filters("t1", t1_source, None, vec![])
        .unwrap()
        .sort(vec![dist.clone().sort(true, false)])
        .unwrap()
        .limit(0, Some(5))
        .unwrap()
        .build()
        .unwrap();
    let t2_plan = LogicalPlanBuilder::scan_with_filters("t2", t2_source, None, vec![])
        .unwrap()
        .sort(vec![dist.sort(true, false)])
        .unwrap()
        .limit(0, Some(5))
        .unwrap()
        .build()
        .unwrap();

    let ctx = OptimizerContext::default();
    // Scanning `t1` with a `t2.v` expression must not produce a hint...
    let _ = ScanHintRule.rewrite(t1_plan, &ctx).unwrap();
    assert!(t1_provider.get_vector_search_hint().is_none());
    // ...while scanning `t2` does.
    let _ = ScanHintRule.rewrite(t2_plan, &ctx).unwrap();
    let hint = t2_provider.get_vector_search_hint().unwrap();
    assert_eq!(hint.column_id, 20);
    assert_eq!(hint.k, 5);
}
}

View File

@@ -26,7 +26,7 @@ use datafusion_expr::TableProviderFilterPushDown as DfTableProviderFilterPushDow
use datafusion_expr::expr::Expr;
use datafusion_physical_expr::PhysicalSortExpr;
use datafusion_physical_expr::expressions::Column;
use store_api::storage::ScanRequest;
use store_api::storage::{ScanRequest, VectorSearchRequest};
use crate::table::{TableRef, TableType};
@@ -52,6 +52,14 @@ impl DfTableProviderAdapter {
self.scan_req.lock().unwrap().output_ordering = Some(order_opts.to_vec());
}
pub fn with_vector_search_hint(&self, hint: VectorSearchRequest) {
self.scan_req.lock().unwrap().vector_search = Some(hint);
}
pub fn get_vector_search_hint(&self) -> Option<VectorSearchRequest> {
self.scan_req.lock().unwrap().vector_search.clone()
}
#[cfg(feature = "testing")]
pub fn get_scan_req(&self) -> ScanRequest {
self.scan_req.lock().unwrap().clone()