diff --git a/.github/workflows/develop.yml b/.github/workflows/develop.yml index 416c813e8d..1b35761318 100644 --- a/.github/workflows/develop.yml +++ b/.github/workflows/develop.yml @@ -73,6 +73,8 @@ jobs: save-if: ${{ github.ref == 'refs/heads/main' }} - name: Run cargo check run: cargo check --locked --workspace --all-targets + - name: Run cargo check (all features) + run: cargo check --locked --workspace --all-targets --all-features toml: if: ${{ github.repository == 'GreptimeTeam/greptimedb' }} diff --git a/src/cmd/Cargo.toml b/src/cmd/Cargo.toml index a70f164997..a7ae19e337 100644 --- a/src/cmd/Cargo.toml +++ b/src/cmd/Cargo.toml @@ -18,7 +18,7 @@ default = [ ] enterprise = ["common-meta/enterprise", "frontend/enterprise", "meta-srv/enterprise"] tokio-console = ["common-telemetry/tokio-console"] -vector_index = ["mito2/vector_index"] +vector_index = ["mito2/vector_index", "query/vector_index"] [lints] workspace = true diff --git a/src/common/function/src/scalars/vector.rs b/src/common/function/src/scalars/vector.rs index f265cfe53a..968231aa0a 100644 --- a/src/common/function/src/scalars/vector.rs +++ b/src/common/function/src/scalars/vector.rs @@ -13,7 +13,7 @@ // limitations under the License. 
mod convert; -mod distance; +pub mod distance; mod elem_avg; mod elem_product; mod elem_sum; @@ -33,6 +33,7 @@ use std::borrow::Cow; use datafusion_common::{DataFusionError, Result, ScalarValue, utils}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datatypes::arrow::array::new_empty_array; use crate::function_registry::FunctionRegistry; use crate::scalars::vector::impl_conv::as_veclit; @@ -128,6 +129,11 @@ where } let len = ensure_same_length(&[arg0, arg1])?; + if len == 0 { + return Ok(ColumnarValue::Array(new_empty_array( + args.return_field.data_type(), + ))); + } let mut results = Vec::with_capacity(len); for i in 0..len { let v0 = try_get_scalar_value!(arg0, i); @@ -155,6 +161,11 @@ where } let len = ensure_same_length(&[arg0, arg1])?; + if len == 0 { + return Ok(ColumnarValue::Array(new_empty_array( + args.return_field.data_type(), + ))); + } let mut results = Vec::with_capacity(len); match (arg0, arg1) { @@ -210,6 +221,11 @@ where }; let len = arg0.len(); + if len == 0 { + return Ok(ColumnarValue::Array(new_empty_array( + args.return_field.data_type(), + ))); + } let mut results = Vec::with_capacity(len); for i in 0..len { let v = ScalarValue::try_from_array(arg0, i)?; diff --git a/src/common/function/src/scalars/vector/distance.rs b/src/common/function/src/scalars/vector/distance.rs index ab063fe7b2..bfe0524a62 100644 --- a/src/common/function/src/scalars/vector/distance.rs +++ b/src/common/function/src/scalars/vector/distance.rs @@ -16,6 +16,10 @@ mod cos; mod dot; mod l2sq; +pub const VEC_COS_DISTANCE: &str = "vec_cos_distance"; +pub const VEC_L2SQ_DISTANCE: &str = "vec_l2sq_distance"; +pub const VEC_DOT_PRODUCT: &str = "vec_dot_product"; + use std::borrow::Cow; use std::fmt::Display; @@ -109,9 +113,9 @@ macro_rules! 
define_distance_function { }; } -define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos); -define_distance_function!(L2SqDistanceFunction, "vec_l2sq_distance", l2sq::l2sq); -define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot); +define_distance_function!(CosDistanceFunction, VEC_COS_DISTANCE, cos::cos); +define_distance_function!(L2SqDistanceFunction, VEC_L2SQ_DISTANCE, l2sq::l2sq); +define_distance_function!(DotProductFunction, VEC_DOT_PRODUCT, dot::dot); #[cfg(test)] mod tests { diff --git a/src/mito2/src/read/scan_util.rs b/src/mito2/src/read/scan_util.rs index 06f5c1392c..0022ff7549 100644 --- a/src/mito2/src/read/scan_util.rs +++ b/src/mito2/src/read/scan_util.rs @@ -165,6 +165,8 @@ pub(crate) struct ScanMetricsSet { rows_bloom_filtered: usize, /// Number of rows filtered by vector index. rows_vector_filtered: usize, + /// Number of rows selected by vector index. + rows_vector_selected: usize, /// Number of rows filtered by precise filter. rows_precise_filtered: usize, /// Number of index result cache hits for fulltext index. 
@@ -291,6 +293,7 @@ impl fmt::Debug for ScanMetricsSet { rows_inverted_filtered, rows_bloom_filtered, rows_vector_filtered, + rows_vector_selected, rows_precise_filtered, fulltext_index_cache_hit, fulltext_index_cache_miss, @@ -384,6 +387,9 @@ impl fmt::Debug for ScanMetricsSet { if *rows_vector_filtered > 0 { write!(f, ", \"rows_vector_filtered\":{rows_vector_filtered}")?; } + if *rows_vector_selected > 0 { + write!(f, ", \"rows_vector_selected\":{rows_vector_selected}")?; + } if *rows_precise_filtered > 0 { write!(f, ", \"rows_precise_filtered\":{rows_precise_filtered}")?; } @@ -600,6 +606,7 @@ impl ScanMetricsSet { rows_inverted_filtered, rows_bloom_filtered, rows_vector_filtered, + rows_vector_selected, rows_precise_filtered, fulltext_index_cache_hit, fulltext_index_cache_miss, @@ -636,6 +643,7 @@ impl ScanMetricsSet { self.rows_inverted_filtered += *rows_inverted_filtered; self.rows_bloom_filtered += *rows_bloom_filtered; self.rows_vector_filtered += *rows_vector_filtered; + self.rows_vector_selected += *rows_vector_selected; self.rows_precise_filtered += *rows_precise_filtered; self.fulltext_index_cache_hit += *fulltext_index_cache_hit; diff --git a/src/mito2/src/sst/parquet/reader.rs b/src/mito2/src/sst/parquet/reader.rs index b726d3294f..c221ddffc1 100644 --- a/src/mito2/src/sst/parquet/reader.rs +++ b/src/mito2/src/sst/parquet/reader.rs @@ -979,6 +979,7 @@ impl ParquetReaderBuilder { return; } }; + metrics.rows_vector_selected += selection.row_count(); apply_selection_and_update_metrics(output, &selection, metrics, INDEX_TYPE_VECTOR); } @@ -1229,6 +1230,8 @@ pub(crate) struct ReaderFilterMetrics { pub(crate) rows_bloom_filtered: usize, /// Number of rows filtered by vector index. pub(crate) rows_vector_filtered: usize, + /// Number of rows selected by vector index. + pub(crate) rows_vector_selected: usize, /// Number of rows filtered by precise filter. 
pub(crate) rows_precise_filtered: usize, @@ -1268,6 +1271,7 @@ impl ReaderFilterMetrics { self.rows_inverted_filtered += other.rows_inverted_filtered; self.rows_bloom_filtered += other.rows_bloom_filtered; self.rows_vector_filtered += other.rows_vector_filtered; + self.rows_vector_selected += other.rows_vector_selected; self.rows_precise_filtered += other.rows_precise_filtered; self.fulltext_index_cache_hit += other.fulltext_index_cache_hit; diff --git a/src/query/Cargo.toml b/src/query/Cargo.toml index 0d19450b4f..523b49a3b1 100644 --- a/src/query/Cargo.toml +++ b/src/query/Cargo.toml @@ -9,6 +9,7 @@ workspace = true [features] enterprise = [] +vector_index = [] [dependencies] ahash.workspace = true diff --git a/src/query/src/dist_plan/analyzer/test.rs b/src/query/src/dist_plan/analyzer/test.rs index 7d4578f3dd..47848c9ceb 100644 --- a/src/query/src/dist_plan/analyzer/test.rs +++ b/src/query/src/dist_plan/analyzer/test.rs @@ -162,6 +162,188 @@ impl Stream for EmptyStream { } } +#[cfg(feature = "vector_index")] +mod vector_search_tests { + use std::sync::Arc; + + use common_function::function::Function; + use common_function::scalars::udf::create_udf; + use datafusion_expr::expr::ScalarFunction; + use datafusion_expr::{Expr, LogicalPlanBuilder, Signature, Volatility, col, lit}; + use datatypes::schema::{ColumnSchema, SchemaBuilder}; + use store_api::storage::ConcreteDataType; + use table::metadata::{FilterPushDownType, TableInfoBuilder, TableMeta, TableType}; + use table::table::adapter::DfTableProviderAdapter; + use table::{Table, TableRef}; + + use super::*; + use crate::dist_plan::MergeScanLogicalPlan; + + struct TestVectorFunction { + name: &'static str, + signature: Signature, + } + + impl TestVectorFunction { + fn new(name: &'static str) -> Self { + Self { + name, + signature: Signature::any(2, Volatility::Immutable), + } + } + } + + impl std::fmt::Display for TestVectorFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + 
write!(f, "{}", self.name) + } + } + + impl Function for TestVectorFunction { + fn name(&self) -> &str { + self.name + } + + fn return_type( + &self, + _input_types: &[datatypes::arrow::datatypes::DataType], + ) -> datafusion_common::Result { + Ok(datatypes::arrow::datatypes::DataType::Float32) + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn invoke_with_args( + &self, + _args: datafusion_expr::ScalarFunctionArgs, + ) -> datafusion_common::Result { + Err(datafusion_common::DataFusionError::Execution( + "test udf should not be invoked".to_string(), + )) + } + } + + fn build_vector_table(table_id: TableId) -> TableRef { + let schema = { + let columns = vec![ + ColumnSchema::new("k0", ConcreteDataType::string_datatype(), true), + ColumnSchema::new( + "ts", + ConcreteDataType::timestamp_millisecond_datatype(), + false, + ) + .with_time_index(true), + ColumnSchema::new("v", ConcreteDataType::vector_datatype(2), false), + ]; + Arc::new( + SchemaBuilder::try_from_columns(columns) + .unwrap() + .build() + .unwrap(), + ) + }; + + let table_meta = TableMeta { + schema: schema.clone(), + primary_key_indices: vec![0], + value_indices: vec![2], + engine: "test_engine".to_string(), + next_column_id: 3, + options: Default::default(), + created_on: Default::default(), + updated_on: Default::default(), + partition_key_indices: vec![0], + column_ids: vec![0, 1, 2], + }; + + let table_info = TableInfoBuilder::default() + .table_id(table_id) + .name("t".to_string()) + .catalog_name(DEFAULT_CATALOG_NAME) + .schema_name(DEFAULT_SCHEMA_NAME) + .table_version(0) + .table_type(TableType::Base) + .meta(table_meta) + .build() + .unwrap(); + + let data_source = Arc::new(TestDataSource::new(schema)); + Arc::new(Table::new( + Arc::new(table_info), + FilterPushDownType::Unsupported, + data_source, + )) + } + + fn vector_distance_expr() -> Expr { + let udf = create_udf(Arc::new(TestVectorFunction::new("vec_l2sq_distance"))); + Expr::ScalarFunction(ScalarFunction::new_udf( 
+ Arc::new(udf), + vec![ + col("v"), + lit(ScalarValue::Utf8(Some("[1.0, 2.0]".to_string()))), + ], + )) + } + + #[test] + fn vector_search_rewrite_keeps_sort_in_child_plan() { + init_default_ut_logging(); + let table = build_vector_table(0); + let table_source = Arc::new(DefaultTableSource::new(Arc::new( + DfTableProviderAdapter::new(table), + ))); + + let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![]) + .unwrap() + .sort(vec![vector_distance_expr().sort(true, false)]) + .unwrap() + .limit(0, Some(5)) + .unwrap() + .build() + .unwrap(); + + let config = ConfigOptions::default(); + let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap(); + + let plan_str = result.to_string(); + assert!(plan_str.contains("MergeSort: vec_l2sq_distance")); + assert!(plan_str.contains("Sort: vec_l2sq_distance")); + assert!(plan_str.contains(MergeScanLogicalPlan::name())); + } + + #[test] + fn vector_search_rewrite_with_filter_keeps_sort_in_child_plan() { + init_default_ut_logging(); + let table = build_vector_table(0); + let table_source = Arc::new(DefaultTableSource::new(Arc::new( + DfTableProviderAdapter::new(table), + ))); + + let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![]) + .unwrap() + .filter(col("k0").eq(lit("hello"))) + .unwrap() + .sort(vec![vector_distance_expr().sort(true, false)]) + .unwrap() + .limit(0, Some(5)) + .unwrap() + .build() + .unwrap(); + + let config = ConfigOptions::default(); + let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap(); + + let plan_str = result.to_string(); + assert!(plan_str.contains("MergeSort: vec_l2sq_distance")); + assert!(plan_str.contains("Sort: vec_l2sq_distance")); + assert!(plan_str.contains("Filter: t.k0 = Utf8(\"hello\")")); + assert!(plan_str.contains(MergeScanLogicalPlan::name())); + } +} + fn try_encode_decode_substrait(plan: &LogicalPlan, state: SessionState) { let sub_plan_bytes = substrait::DFLogicalSubstraitConvertor .encode(plan, 
crate::query_engine::DefaultSerializer) diff --git a/src/query/src/dist_plan/commutativity.rs b/src/query/src/dist_plan/commutativity.rs index 6ecad902f7..4aa0076bf1 100644 --- a/src/query/src/dist_plan/commutativity.rs +++ b/src/query/src/dist_plan/commutativity.rs @@ -16,9 +16,17 @@ use std::collections::HashSet; use std::sync::Arc; use common_function::aggrs::aggr_wrapper::{StateMergeHelper, is_all_aggr_exprs_steppable}; +#[cfg(feature = "vector_index")] +use common_function::scalars::vector::distance::{ + VEC_COS_DISTANCE, VEC_DOT_PRODUCT, VEC_L2SQ_DISTANCE, +}; use common_telemetry::debug; use datafusion::error::Result as DfResult; +#[cfg(feature = "vector_index")] +use datafusion_common::DataFusionError; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +#[cfg(feature = "vector_index")] +use datafusion_expr::Sort; use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode}; use promql::extension_plan::{ EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize, @@ -29,6 +37,37 @@ use crate::dist_plan::MergeScanLogicalPlan; use crate::dist_plan::analyzer::AliasMapping; use crate::dist_plan::merge_sort::{MergeSortLogicalPlan, merge_sort_transformer}; +#[cfg(feature = "vector_index")] +fn is_vector_sort(sort: &Sort) -> bool { + if sort.expr.len() != 1 { + return false; + } + let sort_expr = &sort.expr[0].expr; + let Expr::ScalarFunction(func) = sort_expr else { + return false; + }; + matches!( + func.name().to_lowercase().as_str(), + VEC_L2SQ_DISTANCE | VEC_COS_DISTANCE | VEC_DOT_PRODUCT + ) +} + +#[cfg(feature = "vector_index")] +fn vector_sort_transformer(plan: &LogicalPlan) -> DfResult { + let LogicalPlan::Sort(sort) = plan else { + return Err(DataFusionError::Internal(format!( + "vector_sort_transformer expects Sort, got {plan}" + ))); + }; + Ok(TransformerAction { + extra_parent_plans: vec![ + MergeSortLogicalPlan::new(sort.input.clone(), sort.expr.clone(), sort.fetch) + .into_logical_plan(), + ], + new_child_plan: 
Some(LogicalPlan::Sort(sort.clone())), + }) +} + pub struct StepTransformAction { extra_parent_plans: Vec, new_child_plan: Option, @@ -150,14 +189,19 @@ impl Categorizer { // commutativity is needed under this situation. Commutativity::ConditionalCommutative(None) } - LogicalPlan::Sort(_) => { + LogicalPlan::Sort(_sort) => { if partition_cols.is_empty() { return Ok(Commutativity::Commutative); } // sort plan needs to consider column priority - // Change Sort to MergeSort which assumes the input streams are already sorted hence can be more efficient - // We should ensure the number of partition is not smaller than the number of region at present. Otherwise this would result in incorrect output. + // Change Sort to MergeSort which assumes the input streams are already sorted hence can be more efficient. + #[cfg(feature = "vector_index")] + if is_vector_sort(_sort) { + return Ok(Commutativity::TransformedCommutative { + transformer: Some(Arc::new(vector_sort_transformer)), + }); + } Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer))) } LogicalPlan::Join(_) => Commutativity::NonCommutative, diff --git a/src/query/src/dummy_catalog.rs b/src/query/src/dummy_catalog.rs index 907b5e8c99..239cf7cea8 100644 --- a/src/query/src/dummy_catalog.rs +++ b/src/query/src/dummy_catalog.rs @@ -35,7 +35,9 @@ use session::context::{QueryContext, QueryContextRef}; use snafu::ResultExt; use store_api::metadata::RegionMetadataRef; use store_api::region_engine::RegionEngineRef; -use store_api::storage::{RegionId, ScanRequest, TimeSeriesDistribution, TimeSeriesRowSelector}; +use store_api::storage::{ + RegionId, ScanRequest, TimeSeriesDistribution, TimeSeriesRowSelector, VectorSearchRequest, +}; use table::TableRef; use table::metadata::{TableId, TableInfoRef}; use table::table::scan::RegionScanExec; @@ -256,6 +258,14 @@ impl DummyTableProvider { self.scan_request.lock().unwrap().series_row_selector = Some(selector); } + pub fn with_vector_search_hint(&self, hint: 
VectorSearchRequest) { + self.scan_request.lock().unwrap().vector_search = Some(hint); + } + + pub fn get_vector_search_hint(&self) -> Option { + self.scan_request.lock().unwrap().vector_search.clone() + } + pub fn with_sequence(&self, sequence: u64) { self.scan_request.lock().unwrap().memtable_max_sequence = Some(sequence); } diff --git a/src/query/src/optimizer/scan_hint.rs b/src/query/src/optimizer/scan_hint.rs index 1d128d5a68..c06c6b7812 100644 --- a/src/query/src/optimizer/scan_hint.rs +++ b/src/query/src/optimizer/scan_hint.rs @@ -28,6 +28,10 @@ use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME; use store_api::storage::{TimeSeriesDistribution, TimeSeriesRowSelector}; use crate::dummy_catalog::DummyTableProvider; +#[cfg(feature = "vector_index")] +mod vector_search; +#[cfg(feature = "vector_index")] +use vector_search::VectorSearchState; /// This rule will traverse the plan to collect necessary hints for leaf /// table scan node and set them in [`ScanRequest`]. Hints include: @@ -59,13 +63,16 @@ impl ScanHintRule { let _ = plan.visit(&mut visitor)?; if visitor.need_rewrite() { - plan.transform_down(&|plan| Self::set_hints(plan, &visitor)) + plan.transform_down(&mut |plan| Self::set_hints(plan, &mut visitor)) } else { Ok(Transformed::no(plan)) } } - fn set_hints(plan: LogicalPlan, visitor: &ScanHintVisitor) -> Result> { + fn set_hints( + plan: LogicalPlan, + visitor: &mut ScanHintVisitor, + ) -> Result> { match &plan { LogicalPlan::TableScan(table_scan) => { let mut transformed = false; @@ -94,6 +101,13 @@ impl ScanHintRule { ); } + #[cfg(feature = "vector_index")] + if let Some(vector_request) = visitor + .vector_search + .take_vector_request_from_dummy(adapter, &table_scan.table_name) + { + adapter.with_vector_search_hint(vector_request); + } transformed = true; } } @@ -221,15 +235,29 @@ struct ScanHintVisitor { /// This field stores saved `group_by` columns when all aggregate functions are `last_value` /// and the `order_by` column which 
should be time index. ts_row_selector: Option<(HashSet, Column)>, + #[cfg(feature = "vector_index")] + vector_search: VectorSearchState, } impl TreeNodeVisitor<'_> for ScanHintVisitor { type Node = LogicalPlan; fn f_down(&mut self, node: &Self::Node) -> Result { + #[cfg(feature = "vector_index")] + if let LogicalPlan::Limit(limit) = node { + // Track LIMIT so vector hint k can be derived within the same input chain. + self.vector_search.on_limit_enter(limit); + } + // Get order requirement from sort plan if let LogicalPlan::Sort(sort) = node { self.order_expr = Some(sort.expr.clone()); + + #[cfg(feature = "vector_index")] + { + // Capture vector ORDER BY and TopK hints from sort nodes. + self.vector_search.on_sort_enter(sort); + } } // Get time series row selector from aggr plan @@ -294,12 +322,17 @@ impl TreeNodeVisitor<'_> for ScanHintVisitor { } } - if self.ts_row_selector.is_some() - && (matches!(node, LogicalPlan::Subquery(_)) || node.inputs().len() > 1) - { + // Avoid carrying vector hints across branching inputs (join/subquery) to prevent + // pruning results before global ordering is applied. + let is_branching = matches!(node, LogicalPlan::Subquery(_)) || node.inputs().len() > 1; + if is_branching && self.ts_row_selector.is_some() { // clean previous time series selector hint when encounter subqueries or join self.ts_row_selector = None; } + #[cfg(feature = "vector_index")] + if is_branching { + self.vector_search.on_branching_enter(); + } if let LogicalPlan::Filter(filter) = node && let Some(group_by_exprs) = &self.ts_row_selector @@ -312,13 +345,56 @@ impl TreeNodeVisitor<'_> for ScanHintVisitor { } } + #[cfg(feature = "vector_index")] + if let LogicalPlan::Filter(filter) = node { + self.vector_search.on_filter_enter(&filter.predicate); + } + + #[cfg(feature = "vector_index")] + if let LogicalPlan::TableScan(table_scan) = node { + // Record vector hints at leaf scans after scope checks. 
+ self.vector_search.on_table_scan(table_scan); + } + + Ok(TreeNodeRecursion::Continue) + } + + fn f_up(&mut self, _node: &Self::Node) -> Result { + #[cfg(feature = "vector_index")] + match _node { + LogicalPlan::Limit(_) => { + self.vector_search.on_limit_exit(); + } + LogicalPlan::Sort(_) => { + self.vector_search.on_sort_exit(); + } + LogicalPlan::Filter(_) => { + self.vector_search.on_filter_exit(); + } + LogicalPlan::Subquery(_) => { + self.vector_search.on_branching_exit(); + } + _ if _node.inputs().len() > 1 => { + self.vector_search.on_branching_exit(); + } + _ => {} + } + Ok(TreeNodeRecursion::Continue) } } impl ScanHintVisitor { fn need_rewrite(&self) -> bool { - self.order_expr.is_some() || self.ts_row_selector.is_some() + let base = self.order_expr.is_some() || self.ts_row_selector.is_some(); + #[cfg(feature = "vector_index")] + { + base || self.vector_search.need_rewrite() + } + #[cfg(not(feature = "vector_index"))] + { + base + } } } diff --git a/src/query/src/optimizer/scan_hint/vector_search.rs b/src/query/src/optimizer/scan_hint/vector_search.rs new file mode 100644 index 0000000000..44f76066a1 --- /dev/null +++ b/src/query/src/optimizer/scan_hint/vector_search.rs @@ -0,0 +1,837 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::collections::{HashMap, HashSet, VecDeque}; + +use common_function::scalars::vector::distance::{ + VEC_COS_DISTANCE, VEC_DOT_PRODUCT, VEC_L2SQ_DISTANCE, +}; +use common_telemetry::debug; +use datafusion_common::ScalarValue; +use datafusion_expr::logical_plan::FetchType; +use datafusion_expr::utils::split_conjunction; +use datafusion_expr::{Expr, SortExpr}; +use datafusion_sql::TableReference; +use datatypes::types::parse_string_to_vector_type_value; +use store_api::storage::{VectorDistanceMetric, VectorSearchRequest}; + +use crate::dummy_catalog::DummyTableProvider; + +/// Tracks vector search hints while traversing the logical plan. +/// +/// Vector search requests are emitted only when: +/// - ORDER BY uses a supported vector distance function and its direction matches the metric. +/// - A LIMIT (or Sort.fetch) is present to derive k. +/// - The hint stays within a single input chain (not across join/subquery branches). +/// - The target column is non-nullable, or an explicit IS NOT NULL filter exists. 
+#[derive(Default)] +pub(crate) struct VectorSearchState { + current_distance: Option, + current_limit: Option, + distance_stack: Vec>, + limit_stack: Vec>, + non_null_columns: HashSet, + non_null_stack: Vec>, + vector_hints: HashMap>, +} + +#[derive(Clone)] +struct VectorDistanceInfo { + table_reference: Option, + column_name: String, + query_vector: Vec, + metric: VectorDistanceMetric, +} + +#[derive(Clone)] +struct VectorLimitInfo { + fetch: usize, + skip: usize, +} + +impl VectorLimitInfo { + fn k(&self) -> Option { + self.fetch.checked_add(self.skip) + } +} + +#[derive(Clone)] +struct VectorHintEntry { + distance: VectorDistanceInfo, + limit: VectorLimitInfo, + non_null_constraint: bool, +} + +#[derive(Clone, Hash, PartialEq, Eq)] +struct ColumnKey { + table_reference: Option, + column_name: String, +} + +impl VectorSearchState { + pub(crate) fn need_rewrite(&self) -> bool { + !self.vector_hints.is_empty() + } + + pub(crate) fn on_branching_enter(&mut self) { + // Clear per-branch state so hints are only derived within a single input chain. + self.distance_stack.push(self.current_distance.take()); + self.limit_stack.push(self.current_limit.take()); + self.non_null_stack + .push(std::mem::take(&mut self.non_null_columns)); + } + + pub(crate) fn on_branching_exit(&mut self) { + // Restore the prior chain state after leaving the branch. 
+ if let Some(previous) = self.limit_stack.pop() { + self.current_limit = previous; + } + if let Some(previous) = self.distance_stack.pop() { + self.current_distance = previous; + } + if let Some(previous) = self.non_null_stack.pop() { + self.non_null_columns = previous; + } + } + + pub(crate) fn on_limit_enter(&mut self, limit: &datafusion_expr::logical_plan::Limit) { + self.limit_stack.push(self.current_limit.take()); + self.current_limit = Self::extract_limit_info(limit); + } + + pub(crate) fn on_limit_exit(&mut self) { + if let Some(previous) = self.limit_stack.pop() { + self.current_limit = previous; + } + } + + pub(crate) fn on_sort_enter(&mut self, sort: &datafusion_expr::logical_plan::Sort) { + // Distance is scoped to the nearest sort, while limit may be inherited from parents. + self.distance_stack.push(self.current_distance.take()); + self.limit_stack.push(self.current_limit.clone()); + let distance = Self::extract_distance_from_sort(sort); + self.current_distance = distance.clone(); + // Sort.fetch is a TopK limit, so we can infer k when no LIMIT is present. 
+ if self.current_limit.is_none() { + self.current_limit = distance + .as_ref() + .and_then(|_| Self::extract_limit_from_sort(sort)); + } + } + + pub(crate) fn on_sort_exit(&mut self) { + if let Some(previous) = self.limit_stack.pop() { + self.current_limit = previous; + } + if let Some(previous) = self.distance_stack.pop() { + self.current_distance = previous; + } + } + + pub(crate) fn on_table_scan(&mut self, table_scan: &datafusion_expr::logical_plan::TableScan) { + self.record_vector_hint(table_scan); + } + + pub(crate) fn on_filter_enter(&mut self, predicate: &Expr) { + self.non_null_stack.push(self.non_null_columns.clone()); + for expr in split_conjunction(predicate) { + if let Expr::IsNotNull(inner) = expr + && let Expr::Column(col) = inner.as_ref() + { + self.non_null_columns.insert(ColumnKey { + table_reference: col.relation.clone(), + column_name: col.name.clone(), + }); + } + } + // TODO: detect non-null constraints from more complex predicates (casts/functions). + } + + pub(crate) fn on_filter_exit(&mut self) { + if let Some(previous) = self.non_null_stack.pop() { + self.non_null_columns = previous; + } + } + + pub(crate) fn take_vector_request_from_dummy( + &mut self, + provider: &DummyTableProvider, + table_name: &TableReference, + ) -> Option { + let hint = self.take_vector_hint(table_name)?; + self.build_vector_request_from_dummy(provider, table_name, &hint) + } + + fn build_vector_request_from_dummy( + &self, + provider: &DummyTableProvider, + table_name: &TableReference, + hint: &VectorHintEntry, + ) -> Option { + let info = &hint.distance; + let k = hint.limit.k()?; + + if let Some(ref hint_table) = info.table_reference + && table_name != hint_table + { + return None; + } + + let metadata = provider.region_metadata(); + let column = metadata.column_by_name(&info.column_name)?; + if column.column_schema.is_nullable() && !hint.non_null_constraint { + debug!( + "Skip vector hint: column '{}' is nullable without IS NOT NULL filter", + 
info.column_name + ); + return None; + } + + Some(VectorSearchRequest { + column_id: column.column_id, + query_vector: info.query_vector.clone(), + k, + metric: info.metric, + }) + } + + fn extract_distance_info(expr: &Expr) -> Option { + let Expr::ScalarFunction(func) = expr else { + return None; + }; + + let func_name = func.name().to_lowercase(); + let metric = match func_name.as_str() { + VEC_L2SQ_DISTANCE => VectorDistanceMetric::L2sq, + VEC_COS_DISTANCE => VectorDistanceMetric::Cosine, + VEC_DOT_PRODUCT => VectorDistanceMetric::InnerProduct, + _ => return None, + }; + + if func.args.len() != 2 { + return None; + } + + let (table_reference, column_name) = match &func.args[0] { + Expr::Column(col) => (col.relation.clone(), col.name.clone()), + _ => return None, + }; + + let query_vector = Self::extract_query_vector(&func.args[1])?; + + Some(VectorDistanceInfo { + table_reference, + column_name, + query_vector, + metric, + }) + } + + fn extract_distance_from_sort( + sort: &datafusion_expr::logical_plan::Sort, + ) -> Option { + if sort.expr.len() != 1 { + debug!( + "Skip vector hint: Sort has {} expressions, expected 1", + sort.expr.len() + ); + return None; + } + let sort_expr: &SortExpr = &sort.expr[0]; + let info = Self::extract_distance_info(&sort_expr.expr)?; + let expected_asc = info.metric != VectorDistanceMetric::InnerProduct; + if sort_expr.asc == expected_asc { + Some(info) + } else { + None + } + } + + fn extract_limit_info(limit: &datafusion_expr::logical_plan::Limit) -> Option { + let fetch = match limit.get_fetch_type().ok()? { + FetchType::Literal(fetch) => fetch?, + FetchType::UnsupportedExpr => return None, + }; + let skip = match limit.get_skip_type().ok()? 
{ + datafusion_expr::logical_plan::SkipType::Literal(skip) => skip, + datafusion_expr::logical_plan::SkipType::UnsupportedExpr => return None, + }; + Some(VectorLimitInfo { fetch, skip }) + } + + fn extract_limit_from_sort( + sort: &datafusion_expr::logical_plan::Sort, + ) -> Option { + let fetch = sort.fetch?; + Some(VectorLimitInfo { fetch, skip: 0 }) + } + + fn record_vector_hint(&mut self, table_scan: &datafusion_expr::logical_plan::TableScan) { + let Some(limit) = self.current_limit.as_ref() else { + return; + }; + let Some(distance) = self.current_distance.as_ref() else { + return; + }; + // Only emit hints when distance+limit are present and the table matches the sort target. + if let Some(ref hint_table) = distance.table_reference + && hint_table != &table_scan.table_name + { + return; + } + + let non_null_constraint = + self.is_column_non_null(&distance.table_reference, &distance.column_name); + self.vector_hints + .entry(table_scan.table_name.clone()) + .or_default() + .push_back(VectorHintEntry { + distance: distance.clone(), + limit: limit.clone(), + non_null_constraint, + }); + } + + fn take_vector_hint(&mut self, table_name: &TableReference) -> Option { + let hints = self.vector_hints.get_mut(table_name)?; + let hint = hints.pop_front(); + if hints.is_empty() { + self.vector_hints.remove(table_name); + } + hint + } + + fn is_column_non_null( + &self, + table_reference: &Option, + column_name: &str, + ) -> bool { + self.non_null_columns.contains(&ColumnKey { + table_reference: table_reference.clone(), + column_name: column_name.to_string(), + }) + } + + fn extract_query_vector(expr: &Expr) -> Option> { + match expr { + Expr::Literal(scalar, _) => match scalar { + ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => { + Self::parse_json_vector(s) + } + ScalarValue::Binary(Some(bytes)) | ScalarValue::LargeBinary(Some(bytes)) => { + Self::parse_binary_vector(bytes) + } + _ => None, + }, + _ => None, + } + } + + fn parse_json_vector(s: &str) -> 
Option> { + let trimmed = s.trim(); + if !trimmed.starts_with('[') || !trimmed.ends_with(']') { + return None; + } + parse_string_to_vector_type_value(trimmed, None) + .ok() + .and_then(|bytes| if bytes.is_empty() { None } else { Some(bytes) }) + .and_then(|bytes| Self::parse_binary_vector(&bytes)) + } + + fn parse_binary_vector(bytes: &[u8]) -> Option> { + if bytes.is_empty() || !bytes.len().is_multiple_of(4) { + return None; + } + Some( + bytes + .chunks(4) + .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + .collect(), + ) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use api::v1::SemanticType; + use common_function::function::Function; + use common_function::scalars::udf::create_udf; + use common_function::scalars::vector::distance::{VEC_DOT_PRODUCT, VEC_L2SQ_DISTANCE}; + use datafusion::datasource::DefaultTableSource; + use datafusion::logical_expr::ColumnarValue; + use datafusion_common::{Column, DataFusionError, Result, ScalarValue}; + use datafusion_expr::expr::ScalarFunction; + use datafusion_expr::logical_plan::JoinType; + use datafusion_expr::{ + Expr, LogicalPlan, LogicalPlanBuilder, Signature, Subquery, Volatility, col, lit, + }; + use datafusion_optimizer::{OptimizerContext, OptimizerRule}; + use datatypes::schema::ColumnSchema; + use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder}; + use store_api::storage::{ConcreteDataType, VectorDistanceMetric}; + + use super::VectorSearchState; + use crate::dummy_catalog::DummyTableProvider; + use crate::optimizer::scan_hint::ScanHintRule; + use crate::optimizer::test_util::MetaRegionEngine; + + struct TestVectorFunction { + name: &'static str, + signature: Signature, + } + + impl TestVectorFunction { + fn new(name: &'static str) -> Self { + Self { + name, + signature: Signature::any(2, Volatility::Immutable), + } + } + } + + impl std::fmt::Display for TestVectorFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, 
"{}", self.name) + } + } + + impl Function for TestVectorFunction { + fn name(&self) -> &str { + self.name + } + + fn return_type( + &self, + _input_types: &[datatypes::arrow::datatypes::DataType], + ) -> Result { + Ok(datatypes::arrow::datatypes::DataType::Float32) + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn invoke_with_args( + &self, + _args: datafusion_expr::ScalarFunctionArgs, + ) -> Result { + Err(DataFusionError::Execution( + "test udf should not be invoked".to_string(), + )) + } + } + + fn vec_distance_expr(function_name: &'static str) -> Expr { + let udf = create_udf(Arc::new(TestVectorFunction::new(function_name))); + Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf), + vec![ + col("v"), + lit(ScalarValue::Utf8(Some("[1.0, 2.0]".to_string()))), + ], + )) + } + + fn build_dummy_provider(column_id: u32) -> Arc { + build_dummy_provider_with_nullable(column_id, false) + } + + fn build_dummy_provider_with_nullable( + column_id: u32, + nullable_vector: bool, + ) -> Arc { + let mut builder = RegionMetadataBuilder::new(0.into()); + builder + .push_column_metadata(ColumnMetadata { + column_schema: ColumnSchema::new("k0", ConcreteDataType::string_datatype(), true), + semantic_type: SemanticType::Tag, + column_id: 1, + }) + .push_column_metadata(ColumnMetadata { + column_schema: ColumnSchema::new( + "ts", + ConcreteDataType::timestamp_millisecond_datatype(), + false, + ), + semantic_type: SemanticType::Timestamp, + column_id: 2, + }) + .push_column_metadata(ColumnMetadata { + column_schema: ColumnSchema::new( + "v", + ConcreteDataType::vector_datatype(2), + nullable_vector, + ), + semantic_type: SemanticType::Field, + column_id, + }) + .primary_key(vec![1]); + let metadata = Arc::new(builder.build().unwrap()); + let engine = Arc::new(MetaRegionEngine::with_metadata(metadata.clone())); + Arc::new(DummyTableProvider::new(0.into(), engine, metadata)) + } + + #[test] + fn test_parse_json_vector() { + assert_eq!( + 
/// Checks `parse_binary_vector` on a well-formed buffer and on the inputs it
/// must reject.
#[test]
fn test_parse_binary_vector() {
    // Two f32 values packed back to back in little-endian order.
    let v1: f32 = 1.0;
    let v2: f32 = 2.0;
    let mut bytes = Vec::new();
    bytes.extend_from_slice(&v1.to_le_bytes());
    bytes.extend_from_slice(&v2.to_le_bytes());

    let result = VectorSearchState::parse_binary_vector(&bytes);
    assert_eq!(result, Some(vec![1.0, 2.0]));

    // Length that is not a multiple of four bytes.
    let result = VectorSearchState::parse_binary_vector(&[1, 2, 3]);
    assert_eq!(result, None);

    // An empty buffer carries no vector at all.
    let result = VectorSearchState::parse_binary_vector(&[]);
    assert_eq!(result, None);

    // One complete value followed by a stray byte.
    bytes.truncate(4);
    bytes.push(0);
    let result = VectorSearchState::parse_binary_vector(&bytes);
    assert_eq!(result, None);
}
/// Sorting by the dot product in ascending order must not leave a hint on the
/// provider; sorting in descending order must leave an `InnerProduct` hint.
#[test]
fn test_inner_product_sort_direction() {
    let provider = build_dummy_provider(10);
    let source = Arc::new(DefaultTableSource::new(provider.clone()));
    let dot_product = vec_distance_expr(VEC_DOT_PRODUCT);
    let context = OptimizerContext::default();

    // First pass: `sort(true, ..)`, i.e. ascending.
    let ascending_plan = LogicalPlanBuilder::scan_with_filters("t", source.clone(), None, vec![])
        .unwrap()
        .sort(vec![dot_product.clone().sort(true, false)])
        .unwrap()
        .limit(0, Some(3))
        .unwrap()
        .build()
        .unwrap();
    let _ = ScanHintRule.rewrite(ascending_plan, &context).unwrap();
    assert!(provider.get_vector_search_hint().is_none());

    // Second pass: `sort(false, ..)`, i.e. descending.
    let descending_plan = LogicalPlanBuilder::scan_with_filters("t", source, None, vec![])
        .unwrap()
        .sort(vec![dot_product.sort(false, false)])
        .unwrap()
        .limit(0, Some(3))
        .unwrap()
        .build()
        .unwrap();
    let _ = ScanHintRule.rewrite(descending_plan, &context).unwrap();

    let hint = provider.get_vector_search_hint().unwrap();
    assert_eq!(hint.metric, VectorDistanceMetric::InnerProduct);
}
/// With a nullable vector column, a `v IS NOT NULL` filter below the
/// sort + limit lets the rule record a hint on the provider.
#[test]
fn test_nullable_vector_with_is_not_null_filter() {
    let provider = build_dummy_provider_with_nullable(10, true);
    let source = Arc::new(DefaultTableSource::new(provider.clone()));
    let order_by = vec_distance_expr(VEC_L2SQ_DISTANCE).sort(true, false);

    // scan -> filter(v IS NOT NULL) -> sort(distance) -> limit 5
    let scan = LogicalPlanBuilder::scan_with_filters("t", source, None, vec![]).unwrap();
    let filtered = scan.filter(col("v").is_not_null()).unwrap();
    let sorted = filtered.sort(vec![order_by]).unwrap();
    let plan = sorted.limit(0, Some(5)).unwrap().build().unwrap();

    let context = OptimizerContext::default();
    let _ = ScanHintRule.rewrite(plan, &context).unwrap();

    let hint = provider.get_vector_search_hint().unwrap();
    assert_eq!(hint.column_id, 10);
    assert_eq!(hint.k, 5);
}
#[test] + fn test_limit_scoped_to_sort_branch() { + let t1_provider = build_dummy_provider(10); + let t2_provider = build_dummy_provider(20); + let t1_source = Arc::new(DefaultTableSource::new(t1_provider.clone())); + let t2_source = Arc::new(DefaultTableSource::new(t2_provider.clone())); + let expr = vec_distance_expr(VEC_L2SQ_DISTANCE); + + let left = LogicalPlanBuilder::scan_with_filters("t1", t1_source, None, vec![]) + .unwrap() + .limit(0, Some(5)) + .unwrap() + .build() + .unwrap(); + + let right = LogicalPlanBuilder::scan_with_filters("t2", t2_source, None, vec![]) + .unwrap() + .sort(vec![expr.sort(true, false)]) + .unwrap() + .build() + .unwrap(); + + let join_keys: (Vec, Vec) = (vec![], vec![]); + let plan = LogicalPlanBuilder::from(left) + .join(right, JoinType::Inner, join_keys, None) + .unwrap() + .build() + .unwrap(); + + let context = OptimizerContext::default(); + let _ = ScanHintRule.rewrite(plan, &context).unwrap(); + + assert!(t1_provider.get_vector_search_hint().is_none()); + assert!(t2_provider.get_vector_search_hint().is_none()); + } + + fn vec_distance_expr_qualified( + function_name: &'static str, + table_name: &str, + column_name: &str, + ) -> Expr { + use datafusion_common::Column; + + let udf = create_udf(Arc::new(TestVectorFunction::new(function_name))); + let qualified_col = Expr::Column(Column::new(Some(table_name.to_string()), column_name)); + Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf), + vec![ + qualified_col, + lit(ScalarValue::Utf8(Some("[1.0, 2.0]".to_string()))), + ], + )) + } + + #[test] + fn test_no_vector_hint_above_join() { + let t1_provider = build_dummy_provider(10); + let t2_provider = build_dummy_provider(20); + let t1_source = Arc::new(DefaultTableSource::new(t1_provider.clone())); + let t2_source = Arc::new(DefaultTableSource::new(t2_provider.clone())); + + let left = LogicalPlanBuilder::scan_with_filters("t1", t1_source, None, vec![]) + .unwrap() + .build() + .unwrap(); + + let right = 
/// A sort + limit placed above a `Subquery` node must not leave a hint on the
/// provider scanned inside that subquery.
#[test]
fn test_no_vector_hint_above_subquery() {
    let provider = build_dummy_provider(10);
    let source = Arc::new(DefaultTableSource::new(provider.clone()));

    // Wrap a plain scan of "t" into a subquery node.
    let inner_scan = LogicalPlanBuilder::scan_with_filters("t", source, None, vec![])
        .unwrap()
        .build()
        .unwrap();
    let wrapped = LogicalPlan::Subquery(Subquery {
        subquery: Arc::new(inner_scan),
        outer_ref_columns: vec![],
        spans: Default::default(),
    });

    // The distance sort and the limit sit above the subquery.
    let order_by = vec_distance_expr(VEC_L2SQ_DISTANCE).sort(true, false);
    let plan = LogicalPlanBuilder::from(wrapped)
        .sort(vec![order_by])
        .unwrap()
        .limit(0, Some(5))
        .unwrap()
        .build()
        .unwrap();

    let context = OptimizerContext::default();
    let _ = ScanHintRule.rewrite(plan, &context).unwrap();

    assert!(provider.get_vector_search_hint().is_none());
}
LogicalPlanBuilder::scan_with_filters("t1", t1_source, None, vec![]) + .unwrap() + .sort(vec![expr.clone().sort(true, false)]) + .unwrap() + .limit(0, Some(5)) + .unwrap() + .build() + .unwrap(); + + let t2_plan = LogicalPlanBuilder::scan_with_filters("t2", t2_source, None, vec![]) + .unwrap() + .sort(vec![expr.sort(true, false)]) + .unwrap() + .limit(0, Some(5)) + .unwrap() + .build() + .unwrap(); + + let context = OptimizerContext::default(); + + let _ = ScanHintRule.rewrite(t1_plan, &context).unwrap(); + assert!(t1_provider.get_vector_search_hint().is_none()); + + let _ = ScanHintRule.rewrite(t2_plan, &context).unwrap(); + let hint = t2_provider.get_vector_search_hint().unwrap(); + assert_eq!(hint.column_id, 20); + assert_eq!(hint.k, 5); + } +} diff --git a/src/table/src/table/adapter.rs b/src/table/src/table/adapter.rs index b341768bf3..e33bf53c84 100644 --- a/src/table/src/table/adapter.rs +++ b/src/table/src/table/adapter.rs @@ -26,7 +26,7 @@ use datafusion_expr::TableProviderFilterPushDown as DfTableProviderFilterPushDow use datafusion_expr::expr::Expr; use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_expr::expressions::Column; -use store_api::storage::ScanRequest; +use store_api::storage::{ScanRequest, VectorSearchRequest}; use crate::table::{TableRef, TableType}; @@ -52,6 +52,14 @@ impl DfTableProviderAdapter { self.scan_req.lock().unwrap().output_ordering = Some(order_opts.to_vec()); } + pub fn with_vector_search_hint(&self, hint: VectorSearchRequest) { + self.scan_req.lock().unwrap().vector_search = Some(hint); + } + + pub fn get_vector_search_hint(&self) -> Option { + self.scan_req.lock().unwrap().vector_search.clone() + } + #[cfg(feature = "testing")] pub fn get_scan_req(&self) -> ScanRequest { self.scan_req.lock().unwrap().clone()