mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-14 12:00:40 +00:00
feat: impl vector index query (#7564)
* feat: impl vector index query Signed-off-by: Dennis Zhuang <killme2008@gmail.com> * feat: remove VectorSearchRule and merge it into scan hint rule Signed-off-by: Dennis Zhuang <killme2008@gmail.com> * refactor: vector search hint Signed-off-by: Dennis Zhuang <killme2008@gmail.com> * test: join and subquery Signed-off-by: Dennis Zhuang <killme2008@gmail.com> * fix: clippy when feature disabled Signed-off-by: Dennis Zhuang <killme2008@gmail.com> * fix: push hint only when column is non-nullable or an explicit IS NOT NULL filter exists Signed-off-by: Dennis Zhuang <killme2008@gmail.com> * fix: transformed = true Co-authored-by: Yingwen <realevenyag@gmail.com> Signed-off-by: Dennis Zhuang <killme2008@gmail.com> * chore: remove adpater vector hint Signed-off-by: Dennis Zhuang <killme2008@gmail.com> * chore: revert transformed Signed-off-by: Dennis Zhuang <killme2008@gmail.com> --------- Signed-off-by: Dennis Zhuang <killme2008@gmail.com> Co-authored-by: Yingwen <realevenyag@gmail.com>
This commit is contained in:
2
.github/workflows/develop.yml
vendored
2
.github/workflows/develop.yml
vendored
@@ -73,6 +73,8 @@ jobs:
|
||||
save-if: ${{ github.ref == 'refs/heads/main' }}
|
||||
- name: Run cargo check
|
||||
run: cargo check --locked --workspace --all-targets
|
||||
- name: Run cargo check (all features)
|
||||
run: cargo check --locked --workspace --all-targets --all-features
|
||||
|
||||
toml:
|
||||
if: ${{ github.repository == 'GreptimeTeam/greptimedb' }}
|
||||
|
||||
@@ -18,7 +18,7 @@ default = [
|
||||
]
|
||||
enterprise = ["common-meta/enterprise", "frontend/enterprise", "meta-srv/enterprise"]
|
||||
tokio-console = ["common-telemetry/tokio-console"]
|
||||
vector_index = ["mito2/vector_index"]
|
||||
vector_index = ["mito2/vector_index", "query/vector_index"]
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
mod convert;
|
||||
mod distance;
|
||||
pub mod distance;
|
||||
mod elem_avg;
|
||||
mod elem_product;
|
||||
mod elem_sum;
|
||||
@@ -33,6 +33,7 @@ use std::borrow::Cow;
|
||||
|
||||
use datafusion_common::{DataFusionError, Result, ScalarValue, utils};
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
|
||||
use datatypes::arrow::array::new_empty_array;
|
||||
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
use crate::scalars::vector::impl_conv::as_veclit;
|
||||
@@ -128,6 +129,11 @@ where
|
||||
}
|
||||
|
||||
let len = ensure_same_length(&[arg0, arg1])?;
|
||||
if len == 0 {
|
||||
return Ok(ColumnarValue::Array(new_empty_array(
|
||||
args.return_field.data_type(),
|
||||
)));
|
||||
}
|
||||
let mut results = Vec::with_capacity(len);
|
||||
for i in 0..len {
|
||||
let v0 = try_get_scalar_value!(arg0, i);
|
||||
@@ -155,6 +161,11 @@ where
|
||||
}
|
||||
|
||||
let len = ensure_same_length(&[arg0, arg1])?;
|
||||
if len == 0 {
|
||||
return Ok(ColumnarValue::Array(new_empty_array(
|
||||
args.return_field.data_type(),
|
||||
)));
|
||||
}
|
||||
let mut results = Vec::with_capacity(len);
|
||||
|
||||
match (arg0, arg1) {
|
||||
@@ -210,6 +221,11 @@ where
|
||||
};
|
||||
|
||||
let len = arg0.len();
|
||||
if len == 0 {
|
||||
return Ok(ColumnarValue::Array(new_empty_array(
|
||||
args.return_field.data_type(),
|
||||
)));
|
||||
}
|
||||
let mut results = Vec::with_capacity(len);
|
||||
for i in 0..len {
|
||||
let v = ScalarValue::try_from_array(arg0, i)?;
|
||||
|
||||
@@ -16,6 +16,10 @@ mod cos;
|
||||
mod dot;
|
||||
mod l2sq;
|
||||
|
||||
pub const VEC_COS_DISTANCE: &str = "vec_cos_distance";
|
||||
pub const VEC_L2SQ_DISTANCE: &str = "vec_l2sq_distance";
|
||||
pub const VEC_DOT_PRODUCT: &str = "vec_dot_product";
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
|
||||
@@ -109,9 +113,9 @@ macro_rules! define_distance_function {
|
||||
};
|
||||
}
|
||||
|
||||
define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos);
|
||||
define_distance_function!(L2SqDistanceFunction, "vec_l2sq_distance", l2sq::l2sq);
|
||||
define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot);
|
||||
define_distance_function!(CosDistanceFunction, VEC_COS_DISTANCE, cos::cos);
|
||||
define_distance_function!(L2SqDistanceFunction, VEC_L2SQ_DISTANCE, l2sq::l2sq);
|
||||
define_distance_function!(DotProductFunction, VEC_DOT_PRODUCT, dot::dot);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
@@ -165,6 +165,8 @@ pub(crate) struct ScanMetricsSet {
|
||||
rows_bloom_filtered: usize,
|
||||
/// Number of rows filtered by vector index.
|
||||
rows_vector_filtered: usize,
|
||||
/// Number of rows selected by vector index.
|
||||
rows_vector_selected: usize,
|
||||
/// Number of rows filtered by precise filter.
|
||||
rows_precise_filtered: usize,
|
||||
/// Number of index result cache hits for fulltext index.
|
||||
@@ -291,6 +293,7 @@ impl fmt::Debug for ScanMetricsSet {
|
||||
rows_inverted_filtered,
|
||||
rows_bloom_filtered,
|
||||
rows_vector_filtered,
|
||||
rows_vector_selected,
|
||||
rows_precise_filtered,
|
||||
fulltext_index_cache_hit,
|
||||
fulltext_index_cache_miss,
|
||||
@@ -384,6 +387,9 @@ impl fmt::Debug for ScanMetricsSet {
|
||||
if *rows_vector_filtered > 0 {
|
||||
write!(f, ", \"rows_vector_filtered\":{rows_vector_filtered}")?;
|
||||
}
|
||||
if *rows_vector_selected > 0 {
|
||||
write!(f, ", \"rows_vector_selected\":{rows_vector_selected}")?;
|
||||
}
|
||||
if *rows_precise_filtered > 0 {
|
||||
write!(f, ", \"rows_precise_filtered\":{rows_precise_filtered}")?;
|
||||
}
|
||||
@@ -600,6 +606,7 @@ impl ScanMetricsSet {
|
||||
rows_inverted_filtered,
|
||||
rows_bloom_filtered,
|
||||
rows_vector_filtered,
|
||||
rows_vector_selected,
|
||||
rows_precise_filtered,
|
||||
fulltext_index_cache_hit,
|
||||
fulltext_index_cache_miss,
|
||||
@@ -636,6 +643,7 @@ impl ScanMetricsSet {
|
||||
self.rows_inverted_filtered += *rows_inverted_filtered;
|
||||
self.rows_bloom_filtered += *rows_bloom_filtered;
|
||||
self.rows_vector_filtered += *rows_vector_filtered;
|
||||
self.rows_vector_selected += *rows_vector_selected;
|
||||
self.rows_precise_filtered += *rows_precise_filtered;
|
||||
|
||||
self.fulltext_index_cache_hit += *fulltext_index_cache_hit;
|
||||
|
||||
@@ -979,6 +979,7 @@ impl ParquetReaderBuilder {
|
||||
return;
|
||||
}
|
||||
};
|
||||
metrics.rows_vector_selected += selection.row_count();
|
||||
apply_selection_and_update_metrics(output, &selection, metrics, INDEX_TYPE_VECTOR);
|
||||
}
|
||||
|
||||
@@ -1229,6 +1230,8 @@ pub(crate) struct ReaderFilterMetrics {
|
||||
pub(crate) rows_bloom_filtered: usize,
|
||||
/// Number of rows filtered by vector index.
|
||||
pub(crate) rows_vector_filtered: usize,
|
||||
/// Number of rows selected by vector index.
|
||||
pub(crate) rows_vector_selected: usize,
|
||||
/// Number of rows filtered by precise filter.
|
||||
pub(crate) rows_precise_filtered: usize,
|
||||
|
||||
@@ -1268,6 +1271,7 @@ impl ReaderFilterMetrics {
|
||||
self.rows_inverted_filtered += other.rows_inverted_filtered;
|
||||
self.rows_bloom_filtered += other.rows_bloom_filtered;
|
||||
self.rows_vector_filtered += other.rows_vector_filtered;
|
||||
self.rows_vector_selected += other.rows_vector_selected;
|
||||
self.rows_precise_filtered += other.rows_precise_filtered;
|
||||
|
||||
self.fulltext_index_cache_hit += other.fulltext_index_cache_hit;
|
||||
|
||||
@@ -9,6 +9,7 @@ workspace = true
|
||||
|
||||
[features]
|
||||
enterprise = []
|
||||
vector_index = []
|
||||
|
||||
[dependencies]
|
||||
ahash.workspace = true
|
||||
|
||||
@@ -162,6 +162,188 @@ impl Stream for EmptyStream {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "vector_index")]
|
||||
mod vector_search_tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_function::function::Function;
|
||||
use common_function::scalars::udf::create_udf;
|
||||
use datafusion_expr::expr::ScalarFunction;
|
||||
use datafusion_expr::{Expr, LogicalPlanBuilder, Signature, Volatility, col, lit};
|
||||
use datatypes::schema::{ColumnSchema, SchemaBuilder};
|
||||
use store_api::storage::ConcreteDataType;
|
||||
use table::metadata::{FilterPushDownType, TableInfoBuilder, TableMeta, TableType};
|
||||
use table::table::adapter::DfTableProviderAdapter;
|
||||
use table::{Table, TableRef};
|
||||
|
||||
use super::*;
|
||||
use crate::dist_plan::MergeScanLogicalPlan;
|
||||
|
||||
struct TestVectorFunction {
|
||||
name: &'static str,
|
||||
signature: Signature,
|
||||
}
|
||||
|
||||
impl TestVectorFunction {
|
||||
fn new(name: &'static str) -> Self {
|
||||
Self {
|
||||
name,
|
||||
signature: Signature::any(2, Volatility::Immutable),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TestVectorFunction {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.name)
|
||||
}
|
||||
}
|
||||
|
||||
impl Function for TestVectorFunction {
|
||||
fn name(&self) -> &str {
|
||||
self.name
|
||||
}
|
||||
|
||||
fn return_type(
|
||||
&self,
|
||||
_input_types: &[datatypes::arrow::datatypes::DataType],
|
||||
) -> datafusion_common::Result<datatypes::arrow::datatypes::DataType> {
|
||||
Ok(datatypes::arrow::datatypes::DataType::Float32)
|
||||
}
|
||||
|
||||
fn signature(&self) -> &Signature {
|
||||
&self.signature
|
||||
}
|
||||
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
_args: datafusion_expr::ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
|
||||
Err(datafusion_common::DataFusionError::Execution(
|
||||
"test udf should not be invoked".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn build_vector_table(table_id: TableId) -> TableRef {
|
||||
let schema = {
|
||||
let columns = vec![
|
||||
ColumnSchema::new("k0", ConcreteDataType::string_datatype(), true),
|
||||
ColumnSchema::new(
|
||||
"ts",
|
||||
ConcreteDataType::timestamp_millisecond_datatype(),
|
||||
false,
|
||||
)
|
||||
.with_time_index(true),
|
||||
ColumnSchema::new("v", ConcreteDataType::vector_datatype(2), false),
|
||||
];
|
||||
Arc::new(
|
||||
SchemaBuilder::try_from_columns(columns)
|
||||
.unwrap()
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
};
|
||||
|
||||
let table_meta = TableMeta {
|
||||
schema: schema.clone(),
|
||||
primary_key_indices: vec![0],
|
||||
value_indices: vec![2],
|
||||
engine: "test_engine".to_string(),
|
||||
next_column_id: 3,
|
||||
options: Default::default(),
|
||||
created_on: Default::default(),
|
||||
updated_on: Default::default(),
|
||||
partition_key_indices: vec![0],
|
||||
column_ids: vec![0, 1, 2],
|
||||
};
|
||||
|
||||
let table_info = TableInfoBuilder::default()
|
||||
.table_id(table_id)
|
||||
.name("t".to_string())
|
||||
.catalog_name(DEFAULT_CATALOG_NAME)
|
||||
.schema_name(DEFAULT_SCHEMA_NAME)
|
||||
.table_version(0)
|
||||
.table_type(TableType::Base)
|
||||
.meta(table_meta)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let data_source = Arc::new(TestDataSource::new(schema));
|
||||
Arc::new(Table::new(
|
||||
Arc::new(table_info),
|
||||
FilterPushDownType::Unsupported,
|
||||
data_source,
|
||||
))
|
||||
}
|
||||
|
||||
fn vector_distance_expr() -> Expr {
|
||||
let udf = create_udf(Arc::new(TestVectorFunction::new("vec_l2sq_distance")));
|
||||
Expr::ScalarFunction(ScalarFunction::new_udf(
|
||||
Arc::new(udf),
|
||||
vec![
|
||||
col("v"),
|
||||
lit(ScalarValue::Utf8(Some("[1.0, 2.0]".to_string()))),
|
||||
],
|
||||
))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vector_search_rewrite_keeps_sort_in_child_plan() {
|
||||
init_default_ut_logging();
|
||||
let table = build_vector_table(0);
|
||||
let table_source = Arc::new(DefaultTableSource::new(Arc::new(
|
||||
DfTableProviderAdapter::new(table),
|
||||
)));
|
||||
|
||||
let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
|
||||
.unwrap()
|
||||
.sort(vec![vector_distance_expr().sort(true, false)])
|
||||
.unwrap()
|
||||
.limit(0, Some(5))
|
||||
.unwrap()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let config = ConfigOptions::default();
|
||||
let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
|
||||
|
||||
let plan_str = result.to_string();
|
||||
assert!(plan_str.contains("MergeSort: vec_l2sq_distance"));
|
||||
assert!(plan_str.contains("Sort: vec_l2sq_distance"));
|
||||
assert!(plan_str.contains(MergeScanLogicalPlan::name()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vector_search_rewrite_with_filter_keeps_sort_in_child_plan() {
|
||||
init_default_ut_logging();
|
||||
let table = build_vector_table(0);
|
||||
let table_source = Arc::new(DefaultTableSource::new(Arc::new(
|
||||
DfTableProviderAdapter::new(table),
|
||||
)));
|
||||
|
||||
let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
|
||||
.unwrap()
|
||||
.filter(col("k0").eq(lit("hello")))
|
||||
.unwrap()
|
||||
.sort(vec![vector_distance_expr().sort(true, false)])
|
||||
.unwrap()
|
||||
.limit(0, Some(5))
|
||||
.unwrap()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let config = ConfigOptions::default();
|
||||
let result = DistPlannerAnalyzer {}.analyze(plan, &config).unwrap();
|
||||
|
||||
let plan_str = result.to_string();
|
||||
assert!(plan_str.contains("MergeSort: vec_l2sq_distance"));
|
||||
assert!(plan_str.contains("Sort: vec_l2sq_distance"));
|
||||
assert!(plan_str.contains("Filter: t.k0 = Utf8(\"hello\")"));
|
||||
assert!(plan_str.contains(MergeScanLogicalPlan::name()));
|
||||
}
|
||||
}
|
||||
|
||||
fn try_encode_decode_substrait(plan: &LogicalPlan, state: SessionState) {
|
||||
let sub_plan_bytes = substrait::DFLogicalSubstraitConvertor
|
||||
.encode(plan, crate::query_engine::DefaultSerializer)
|
||||
|
||||
@@ -16,9 +16,17 @@ use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_function::aggrs::aggr_wrapper::{StateMergeHelper, is_all_aggr_exprs_steppable};
|
||||
#[cfg(feature = "vector_index")]
|
||||
use common_function::scalars::vector::distance::{
|
||||
VEC_COS_DISTANCE, VEC_DOT_PRODUCT, VEC_L2SQ_DISTANCE,
|
||||
};
|
||||
use common_telemetry::debug;
|
||||
use datafusion::error::Result as DfResult;
|
||||
#[cfg(feature = "vector_index")]
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
|
||||
#[cfg(feature = "vector_index")]
|
||||
use datafusion_expr::Sort;
|
||||
use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode};
|
||||
use promql::extension_plan::{
|
||||
EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
|
||||
@@ -29,6 +37,37 @@ use crate::dist_plan::MergeScanLogicalPlan;
|
||||
use crate::dist_plan::analyzer::AliasMapping;
|
||||
use crate::dist_plan::merge_sort::{MergeSortLogicalPlan, merge_sort_transformer};
|
||||
|
||||
#[cfg(feature = "vector_index")]
|
||||
fn is_vector_sort(sort: &Sort) -> bool {
|
||||
if sort.expr.len() != 1 {
|
||||
return false;
|
||||
}
|
||||
let sort_expr = &sort.expr[0].expr;
|
||||
let Expr::ScalarFunction(func) = sort_expr else {
|
||||
return false;
|
||||
};
|
||||
matches!(
|
||||
func.name().to_lowercase().as_str(),
|
||||
VEC_L2SQ_DISTANCE | VEC_COS_DISTANCE | VEC_DOT_PRODUCT
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(feature = "vector_index")]
|
||||
fn vector_sort_transformer(plan: &LogicalPlan) -> DfResult<TransformerAction> {
|
||||
let LogicalPlan::Sort(sort) = plan else {
|
||||
return Err(DataFusionError::Internal(format!(
|
||||
"vector_sort_transformer expects Sort, got {plan}"
|
||||
)));
|
||||
};
|
||||
Ok(TransformerAction {
|
||||
extra_parent_plans: vec![
|
||||
MergeSortLogicalPlan::new(sort.input.clone(), sort.expr.clone(), sort.fetch)
|
||||
.into_logical_plan(),
|
||||
],
|
||||
new_child_plan: Some(LogicalPlan::Sort(sort.clone())),
|
||||
})
|
||||
}
|
||||
|
||||
pub struct StepTransformAction {
|
||||
extra_parent_plans: Vec<LogicalPlan>,
|
||||
new_child_plan: Option<LogicalPlan>,
|
||||
@@ -150,14 +189,19 @@ impl Categorizer {
|
||||
// commutativity is needed under this situation.
|
||||
Commutativity::ConditionalCommutative(None)
|
||||
}
|
||||
LogicalPlan::Sort(_) => {
|
||||
LogicalPlan::Sort(_sort) => {
|
||||
if partition_cols.is_empty() {
|
||||
return Ok(Commutativity::Commutative);
|
||||
}
|
||||
|
||||
// sort plan needs to consider column priority
|
||||
// Change Sort to MergeSort which assumes the input streams are already sorted hence can be more efficient
|
||||
// We should ensure the number of partition is not smaller than the number of region at present. Otherwise this would result in incorrect output.
|
||||
// Change Sort to MergeSort which assumes the input streams are already sorted hence can be more efficient.
|
||||
#[cfg(feature = "vector_index")]
|
||||
if is_vector_sort(_sort) {
|
||||
return Ok(Commutativity::TransformedCommutative {
|
||||
transformer: Some(Arc::new(vector_sort_transformer)),
|
||||
});
|
||||
}
|
||||
Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer)))
|
||||
}
|
||||
LogicalPlan::Join(_) => Commutativity::NonCommutative,
|
||||
|
||||
@@ -35,7 +35,9 @@ use session::context::{QueryContext, QueryContextRef};
|
||||
use snafu::ResultExt;
|
||||
use store_api::metadata::RegionMetadataRef;
|
||||
use store_api::region_engine::RegionEngineRef;
|
||||
use store_api::storage::{RegionId, ScanRequest, TimeSeriesDistribution, TimeSeriesRowSelector};
|
||||
use store_api::storage::{
|
||||
RegionId, ScanRequest, TimeSeriesDistribution, TimeSeriesRowSelector, VectorSearchRequest,
|
||||
};
|
||||
use table::TableRef;
|
||||
use table::metadata::{TableId, TableInfoRef};
|
||||
use table::table::scan::RegionScanExec;
|
||||
@@ -256,6 +258,14 @@ impl DummyTableProvider {
|
||||
self.scan_request.lock().unwrap().series_row_selector = Some(selector);
|
||||
}
|
||||
|
||||
pub fn with_vector_search_hint(&self, hint: VectorSearchRequest) {
|
||||
self.scan_request.lock().unwrap().vector_search = Some(hint);
|
||||
}
|
||||
|
||||
pub fn get_vector_search_hint(&self) -> Option<VectorSearchRequest> {
|
||||
self.scan_request.lock().unwrap().vector_search.clone()
|
||||
}
|
||||
|
||||
pub fn with_sequence(&self, sequence: u64) {
|
||||
self.scan_request.lock().unwrap().memtable_max_sequence = Some(sequence);
|
||||
}
|
||||
|
||||
@@ -28,6 +28,10 @@ use store_api::metric_engine_consts::DATA_SCHEMA_TSID_COLUMN_NAME;
|
||||
use store_api::storage::{TimeSeriesDistribution, TimeSeriesRowSelector};
|
||||
|
||||
use crate::dummy_catalog::DummyTableProvider;
|
||||
#[cfg(feature = "vector_index")]
|
||||
mod vector_search;
|
||||
#[cfg(feature = "vector_index")]
|
||||
use vector_search::VectorSearchState;
|
||||
|
||||
/// This rule will traverse the plan to collect necessary hints for leaf
|
||||
/// table scan node and set them in [`ScanRequest`]. Hints include:
|
||||
@@ -59,13 +63,16 @@ impl ScanHintRule {
|
||||
let _ = plan.visit(&mut visitor)?;
|
||||
|
||||
if visitor.need_rewrite() {
|
||||
plan.transform_down(&|plan| Self::set_hints(plan, &visitor))
|
||||
plan.transform_down(&mut |plan| Self::set_hints(plan, &mut visitor))
|
||||
} else {
|
||||
Ok(Transformed::no(plan))
|
||||
}
|
||||
}
|
||||
|
||||
fn set_hints(plan: LogicalPlan, visitor: &ScanHintVisitor) -> Result<Transformed<LogicalPlan>> {
|
||||
fn set_hints(
|
||||
plan: LogicalPlan,
|
||||
visitor: &mut ScanHintVisitor,
|
||||
) -> Result<Transformed<LogicalPlan>> {
|
||||
match &plan {
|
||||
LogicalPlan::TableScan(table_scan) => {
|
||||
let mut transformed = false;
|
||||
@@ -94,6 +101,13 @@ impl ScanHintRule {
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "vector_index")]
|
||||
if let Some(vector_request) = visitor
|
||||
.vector_search
|
||||
.take_vector_request_from_dummy(adapter, &table_scan.table_name)
|
||||
{
|
||||
adapter.with_vector_search_hint(vector_request);
|
||||
}
|
||||
transformed = true;
|
||||
}
|
||||
}
|
||||
@@ -221,15 +235,29 @@ struct ScanHintVisitor {
|
||||
/// This field stores saved `group_by` columns when all aggregate functions are `last_value`
|
||||
/// and the `order_by` column which should be time index.
|
||||
ts_row_selector: Option<(HashSet<Column>, Column)>,
|
||||
#[cfg(feature = "vector_index")]
|
||||
vector_search: VectorSearchState,
|
||||
}
|
||||
|
||||
impl TreeNodeVisitor<'_> for ScanHintVisitor {
|
||||
type Node = LogicalPlan;
|
||||
|
||||
fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
|
||||
#[cfg(feature = "vector_index")]
|
||||
if let LogicalPlan::Limit(limit) = node {
|
||||
// Track LIMIT so vector hint k can be derived within the same input chain.
|
||||
self.vector_search.on_limit_enter(limit);
|
||||
}
|
||||
|
||||
// Get order requirement from sort plan
|
||||
if let LogicalPlan::Sort(sort) = node {
|
||||
self.order_expr = Some(sort.expr.clone());
|
||||
|
||||
#[cfg(feature = "vector_index")]
|
||||
{
|
||||
// Capture vector ORDER BY and TopK hints from sort nodes.
|
||||
self.vector_search.on_sort_enter(sort);
|
||||
}
|
||||
}
|
||||
|
||||
// Get time series row selector from aggr plan
|
||||
@@ -294,12 +322,17 @@ impl TreeNodeVisitor<'_> for ScanHintVisitor {
|
||||
}
|
||||
}
|
||||
|
||||
if self.ts_row_selector.is_some()
|
||||
&& (matches!(node, LogicalPlan::Subquery(_)) || node.inputs().len() > 1)
|
||||
{
|
||||
// Avoid carrying vector hints across branching inputs (join/subquery) to prevent
|
||||
// pruning results before global ordering is applied.
|
||||
let is_branching = matches!(node, LogicalPlan::Subquery(_)) || node.inputs().len() > 1;
|
||||
if is_branching && self.ts_row_selector.is_some() {
|
||||
// clean previous time series selector hint when encounter subqueries or join
|
||||
self.ts_row_selector = None;
|
||||
}
|
||||
#[cfg(feature = "vector_index")]
|
||||
if is_branching {
|
||||
self.vector_search.on_branching_enter();
|
||||
}
|
||||
|
||||
if let LogicalPlan::Filter(filter) = node
|
||||
&& let Some(group_by_exprs) = &self.ts_row_selector
|
||||
@@ -312,13 +345,56 @@ impl TreeNodeVisitor<'_> for ScanHintVisitor {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "vector_index")]
|
||||
if let LogicalPlan::Filter(filter) = node {
|
||||
self.vector_search.on_filter_enter(&filter.predicate);
|
||||
}
|
||||
|
||||
#[cfg(feature = "vector_index")]
|
||||
if let LogicalPlan::TableScan(table_scan) = node {
|
||||
// Record vector hints at leaf scans after scope checks.
|
||||
self.vector_search.on_table_scan(table_scan);
|
||||
}
|
||||
|
||||
Ok(TreeNodeRecursion::Continue)
|
||||
}
|
||||
|
||||
fn f_up(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
|
||||
#[cfg(feature = "vector_index")]
|
||||
match _node {
|
||||
LogicalPlan::Limit(_) => {
|
||||
self.vector_search.on_limit_exit();
|
||||
}
|
||||
LogicalPlan::Sort(_) => {
|
||||
self.vector_search.on_sort_exit();
|
||||
}
|
||||
LogicalPlan::Filter(_) => {
|
||||
self.vector_search.on_filter_exit();
|
||||
}
|
||||
LogicalPlan::Subquery(_) => {
|
||||
self.vector_search.on_branching_exit();
|
||||
}
|
||||
_ if _node.inputs().len() > 1 => {
|
||||
self.vector_search.on_branching_exit();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(TreeNodeRecursion::Continue)
|
||||
}
|
||||
}
|
||||
|
||||
impl ScanHintVisitor {
|
||||
fn need_rewrite(&self) -> bool {
|
||||
self.order_expr.is_some() || self.ts_row_selector.is_some()
|
||||
let base = self.order_expr.is_some() || self.ts_row_selector.is_some();
|
||||
#[cfg(feature = "vector_index")]
|
||||
{
|
||||
base || self.vector_search.need_rewrite()
|
||||
}
|
||||
#[cfg(not(feature = "vector_index"))]
|
||||
{
|
||||
base
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
837
src/query/src/optimizer/scan_hint/vector_search.rs
Normal file
837
src/query/src/optimizer/scan_hint/vector_search.rs
Normal file
@@ -0,0 +1,837 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::{HashMap, HashSet, VecDeque};
|
||||
|
||||
use common_function::scalars::vector::distance::{
|
||||
VEC_COS_DISTANCE, VEC_DOT_PRODUCT, VEC_L2SQ_DISTANCE,
|
||||
};
|
||||
use common_telemetry::debug;
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_expr::logical_plan::FetchType;
|
||||
use datafusion_expr::utils::split_conjunction;
|
||||
use datafusion_expr::{Expr, SortExpr};
|
||||
use datafusion_sql::TableReference;
|
||||
use datatypes::types::parse_string_to_vector_type_value;
|
||||
use store_api::storage::{VectorDistanceMetric, VectorSearchRequest};
|
||||
|
||||
use crate::dummy_catalog::DummyTableProvider;
|
||||
|
||||
/// Tracks vector search hints while traversing the logical plan.
|
||||
///
|
||||
/// Vector search requests are emitted only when:
|
||||
/// - ORDER BY uses a supported vector distance function and its direction matches the metric.
|
||||
/// - A LIMIT (or Sort.fetch) is present to derive k.
|
||||
/// - The hint stays within a single input chain (not across join/subquery branches).
|
||||
/// - The target column is non-nullable, or an explicit IS NOT NULL filter exists.
|
||||
#[derive(Default)]
|
||||
pub(crate) struct VectorSearchState {
|
||||
current_distance: Option<VectorDistanceInfo>,
|
||||
current_limit: Option<VectorLimitInfo>,
|
||||
distance_stack: Vec<Option<VectorDistanceInfo>>,
|
||||
limit_stack: Vec<Option<VectorLimitInfo>>,
|
||||
non_null_columns: HashSet<ColumnKey>,
|
||||
non_null_stack: Vec<HashSet<ColumnKey>>,
|
||||
vector_hints: HashMap<TableReference, VecDeque<VectorHintEntry>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct VectorDistanceInfo {
|
||||
table_reference: Option<TableReference>,
|
||||
column_name: String,
|
||||
query_vector: Vec<f32>,
|
||||
metric: VectorDistanceMetric,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct VectorLimitInfo {
|
||||
fetch: usize,
|
||||
skip: usize,
|
||||
}
|
||||
|
||||
impl VectorLimitInfo {
|
||||
fn k(&self) -> Option<usize> {
|
||||
self.fetch.checked_add(self.skip)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct VectorHintEntry {
|
||||
distance: VectorDistanceInfo,
|
||||
limit: VectorLimitInfo,
|
||||
non_null_constraint: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Hash, PartialEq, Eq)]
|
||||
struct ColumnKey {
|
||||
table_reference: Option<TableReference>,
|
||||
column_name: String,
|
||||
}
|
||||
|
||||
impl VectorSearchState {
|
||||
pub(crate) fn need_rewrite(&self) -> bool {
|
||||
!self.vector_hints.is_empty()
|
||||
}
|
||||
|
||||
pub(crate) fn on_branching_enter(&mut self) {
|
||||
// Clear per-branch state so hints are only derived within a single input chain.
|
||||
self.distance_stack.push(self.current_distance.take());
|
||||
self.limit_stack.push(self.current_limit.take());
|
||||
self.non_null_stack
|
||||
.push(std::mem::take(&mut self.non_null_columns));
|
||||
}
|
||||
|
||||
pub(crate) fn on_branching_exit(&mut self) {
|
||||
// Restore the prior chain state after leaving the branch.
|
||||
if let Some(previous) = self.limit_stack.pop() {
|
||||
self.current_limit = previous;
|
||||
}
|
||||
if let Some(previous) = self.distance_stack.pop() {
|
||||
self.current_distance = previous;
|
||||
}
|
||||
if let Some(previous) = self.non_null_stack.pop() {
|
||||
self.non_null_columns = previous;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn on_limit_enter(&mut self, limit: &datafusion_expr::logical_plan::Limit) {
|
||||
self.limit_stack.push(self.current_limit.take());
|
||||
self.current_limit = Self::extract_limit_info(limit);
|
||||
}
|
||||
|
||||
pub(crate) fn on_limit_exit(&mut self) {
|
||||
if let Some(previous) = self.limit_stack.pop() {
|
||||
self.current_limit = previous;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn on_sort_enter(&mut self, sort: &datafusion_expr::logical_plan::Sort) {
|
||||
// Distance is scoped to the nearest sort, while limit may be inherited from parents.
|
||||
self.distance_stack.push(self.current_distance.take());
|
||||
self.limit_stack.push(self.current_limit.clone());
|
||||
let distance = Self::extract_distance_from_sort(sort);
|
||||
self.current_distance = distance.clone();
|
||||
// Sort.fetch is a TopK limit, so we can infer k when no LIMIT is present.
|
||||
if self.current_limit.is_none() {
|
||||
self.current_limit = distance
|
||||
.as_ref()
|
||||
.and_then(|_| Self::extract_limit_from_sort(sort));
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn on_sort_exit(&mut self) {
|
||||
if let Some(previous) = self.limit_stack.pop() {
|
||||
self.current_limit = previous;
|
||||
}
|
||||
if let Some(previous) = self.distance_stack.pop() {
|
||||
self.current_distance = previous;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn on_table_scan(&mut self, table_scan: &datafusion_expr::logical_plan::TableScan) {
|
||||
self.record_vector_hint(table_scan);
|
||||
}
|
||||
|
||||
pub(crate) fn on_filter_enter(&mut self, predicate: &Expr) {
|
||||
self.non_null_stack.push(self.non_null_columns.clone());
|
||||
for expr in split_conjunction(predicate) {
|
||||
if let Expr::IsNotNull(inner) = expr
|
||||
&& let Expr::Column(col) = inner.as_ref()
|
||||
{
|
||||
self.non_null_columns.insert(ColumnKey {
|
||||
table_reference: col.relation.clone(),
|
||||
column_name: col.name.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
// TODO: detect non-null constraints from more complex predicates (casts/functions).
|
||||
}
|
||||
|
||||
pub(crate) fn on_filter_exit(&mut self) {
|
||||
if let Some(previous) = self.non_null_stack.pop() {
|
||||
self.non_null_columns = previous;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn take_vector_request_from_dummy(
|
||||
&mut self,
|
||||
provider: &DummyTableProvider,
|
||||
table_name: &TableReference,
|
||||
) -> Option<VectorSearchRequest> {
|
||||
let hint = self.take_vector_hint(table_name)?;
|
||||
self.build_vector_request_from_dummy(provider, table_name, &hint)
|
||||
}
|
||||
|
||||
fn build_vector_request_from_dummy(
|
||||
&self,
|
||||
provider: &DummyTableProvider,
|
||||
table_name: &TableReference,
|
||||
hint: &VectorHintEntry,
|
||||
) -> Option<VectorSearchRequest> {
|
||||
let info = &hint.distance;
|
||||
let k = hint.limit.k()?;
|
||||
|
||||
if let Some(ref hint_table) = info.table_reference
|
||||
&& table_name != hint_table
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let metadata = provider.region_metadata();
|
||||
let column = metadata.column_by_name(&info.column_name)?;
|
||||
if column.column_schema.is_nullable() && !hint.non_null_constraint {
|
||||
debug!(
|
||||
"Skip vector hint: column '{}' is nullable without IS NOT NULL filter",
|
||||
info.column_name
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(VectorSearchRequest {
|
||||
column_id: column.column_id,
|
||||
query_vector: info.query_vector.clone(),
|
||||
k,
|
||||
metric: info.metric,
|
||||
})
|
||||
}
|
||||
|
||||
fn extract_distance_info(expr: &Expr) -> Option<VectorDistanceInfo> {
|
||||
let Expr::ScalarFunction(func) = expr else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let func_name = func.name().to_lowercase();
|
||||
let metric = match func_name.as_str() {
|
||||
VEC_L2SQ_DISTANCE => VectorDistanceMetric::L2sq,
|
||||
VEC_COS_DISTANCE => VectorDistanceMetric::Cosine,
|
||||
VEC_DOT_PRODUCT => VectorDistanceMetric::InnerProduct,
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
if func.args.len() != 2 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let (table_reference, column_name) = match &func.args[0] {
|
||||
Expr::Column(col) => (col.relation.clone(), col.name.clone()),
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
let query_vector = Self::extract_query_vector(&func.args[1])?;
|
||||
|
||||
Some(VectorDistanceInfo {
|
||||
table_reference,
|
||||
column_name,
|
||||
query_vector,
|
||||
metric,
|
||||
})
|
||||
}
|
||||
|
||||
fn extract_distance_from_sort(
|
||||
sort: &datafusion_expr::logical_plan::Sort,
|
||||
) -> Option<VectorDistanceInfo> {
|
||||
if sort.expr.len() != 1 {
|
||||
debug!(
|
||||
"Skip vector hint: Sort has {} expressions, expected 1",
|
||||
sort.expr.len()
|
||||
);
|
||||
return None;
|
||||
}
|
||||
let sort_expr: &SortExpr = &sort.expr[0];
|
||||
let info = Self::extract_distance_info(&sort_expr.expr)?;
|
||||
let expected_asc = info.metric != VectorDistanceMetric::InnerProduct;
|
||||
if sort_expr.asc == expected_asc {
|
||||
Some(info)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_limit_info(limit: &datafusion_expr::logical_plan::Limit) -> Option<VectorLimitInfo> {
|
||||
let fetch = match limit.get_fetch_type().ok()? {
|
||||
FetchType::Literal(fetch) => fetch?,
|
||||
FetchType::UnsupportedExpr => return None,
|
||||
};
|
||||
let skip = match limit.get_skip_type().ok()? {
|
||||
datafusion_expr::logical_plan::SkipType::Literal(skip) => skip,
|
||||
datafusion_expr::logical_plan::SkipType::UnsupportedExpr => return None,
|
||||
};
|
||||
Some(VectorLimitInfo { fetch, skip })
|
||||
}
|
||||
|
||||
fn extract_limit_from_sort(
|
||||
sort: &datafusion_expr::logical_plan::Sort,
|
||||
) -> Option<VectorLimitInfo> {
|
||||
let fetch = sort.fetch?;
|
||||
Some(VectorLimitInfo { fetch, skip: 0 })
|
||||
}
|
||||
|
||||
fn record_vector_hint(&mut self, table_scan: &datafusion_expr::logical_plan::TableScan) {
|
||||
let Some(limit) = self.current_limit.as_ref() else {
|
||||
return;
|
||||
};
|
||||
let Some(distance) = self.current_distance.as_ref() else {
|
||||
return;
|
||||
};
|
||||
// Only emit hints when distance+limit are present and the table matches the sort target.
|
||||
if let Some(ref hint_table) = distance.table_reference
|
||||
&& hint_table != &table_scan.table_name
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let non_null_constraint =
|
||||
self.is_column_non_null(&distance.table_reference, &distance.column_name);
|
||||
self.vector_hints
|
||||
.entry(table_scan.table_name.clone())
|
||||
.or_default()
|
||||
.push_back(VectorHintEntry {
|
||||
distance: distance.clone(),
|
||||
limit: limit.clone(),
|
||||
non_null_constraint,
|
||||
});
|
||||
}
|
||||
|
||||
/// Pops the oldest queued hint for `table_name`, removing the table's entry
/// entirely once its queue is drained so the map does not accumulate
/// empty queues.
fn take_vector_hint(&mut self, table_name: &TableReference) -> Option<VectorHintEntry> {
    let queue = self.vector_hints.get_mut(table_name)?;
    let taken = queue.pop_front();
    if queue.is_empty() {
        self.vector_hints.remove(table_name);
    }
    taken
}
|
||||
|
||||
fn is_column_non_null(
|
||||
&self,
|
||||
table_reference: &Option<TableReference>,
|
||||
column_name: &str,
|
||||
) -> bool {
|
||||
self.non_null_columns.contains(&ColumnKey {
|
||||
table_reference: table_reference.clone(),
|
||||
column_name: column_name.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn extract_query_vector(expr: &Expr) -> Option<Vec<f32>> {
|
||||
match expr {
|
||||
Expr::Literal(scalar, _) => match scalar {
|
||||
ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => {
|
||||
Self::parse_json_vector(s)
|
||||
}
|
||||
ScalarValue::Binary(Some(bytes)) | ScalarValue::LargeBinary(Some(bytes)) => {
|
||||
Self::parse_binary_vector(bytes)
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_json_vector(s: &str) -> Option<Vec<f32>> {
|
||||
let trimmed = s.trim();
|
||||
if !trimmed.starts_with('[') || !trimmed.ends_with(']') {
|
||||
return None;
|
||||
}
|
||||
parse_string_to_vector_type_value(trimmed, None)
|
||||
.ok()
|
||||
.and_then(|bytes| if bytes.is_empty() { None } else { Some(bytes) })
|
||||
.and_then(|bytes| Self::parse_binary_vector(&bytes))
|
||||
}
|
||||
|
||||
/// Decodes packed little-endian f32 bytes into a vector. Returns `None` for
/// empty input or a length that is not a multiple of four.
fn parse_binary_vector(bytes: &[u8]) -> Option<Vec<f32>> {
    if bytes.is_empty() || bytes.len() % 4 != 0 {
        return None;
    }
    let values = bytes
        .chunks_exact(4)
        .map(|chunk| f32::from_le_bytes(chunk.try_into().expect("chunk is 4 bytes")))
        .collect();
    Some(values)
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use api::v1::SemanticType;
    use common_function::function::Function;
    use common_function::scalars::udf::create_udf;
    use common_function::scalars::vector::distance::{VEC_DOT_PRODUCT, VEC_L2SQ_DISTANCE};
    use datafusion::datasource::DefaultTableSource;
    use datafusion::logical_expr::ColumnarValue;
    use datafusion_common::{Column, DataFusionError, Result, ScalarValue};
    use datafusion_expr::expr::ScalarFunction;
    use datafusion_expr::logical_plan::JoinType;
    use datafusion_expr::{
        Expr, LogicalPlan, LogicalPlanBuilder, Signature, Subquery, Volatility, col, lit,
    };
    use datafusion_optimizer::{OptimizerContext, OptimizerRule};
    use datatypes::schema::ColumnSchema;
    use store_api::metadata::{ColumnMetadata, RegionMetadataBuilder};
    use store_api::storage::{ConcreteDataType, VectorDistanceMetric};

    use super::VectorSearchState;
    use crate::dummy_catalog::DummyTableProvider;
    use crate::optimizer::scan_hint::ScanHintRule;
    use crate::optimizer::test_util::MetaRegionEngine;

    // Minimal scalar-UDF stub standing in for a vector distance function.
    // The optimizer only inspects its name/signature; it is never executed.
    struct TestVectorFunction {
        name: &'static str,
        signature: Signature,
    }

    impl TestVectorFunction {
        // Builds a stub with the given distance-function name and an
        // "any two arguments" signature matching the real distance UDFs.
        fn new(name: &'static str) -> Self {
            Self {
                name,
                signature: Signature::any(2, Volatility::Immutable),
            }
        }
    }

    impl std::fmt::Display for TestVectorFunction {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.name)
        }
    }

    impl Function for TestVectorFunction {
        fn name(&self) -> &str {
            self.name
        }

        fn return_type(
            &self,
            _input_types: &[datatypes::arrow::datatypes::DataType],
        ) -> Result<datatypes::arrow::datatypes::DataType> {
            Ok(datatypes::arrow::datatypes::DataType::Float32)
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        // The hint extraction is purely static, so actually invoking the UDF
        // would indicate a test bug — fail loudly.
        fn invoke_with_args(
            &self,
            _args: datafusion_expr::ScalarFunctionArgs,
        ) -> Result<ColumnarValue> {
            Err(DataFusionError::Execution(
                "test udf should not be invoked".to_string(),
            ))
        }
    }

    // `vec_*_distance(v, '[1.0, 2.0]')` with an unqualified column reference.
    fn vec_distance_expr(function_name: &'static str) -> Expr {
        let udf = create_udf(Arc::new(TestVectorFunction::new(function_name)));
        Expr::ScalarFunction(ScalarFunction::new_udf(
            Arc::new(udf),
            vec![
                col("v"),
                lit(ScalarValue::Utf8(Some("[1.0, 2.0]".to_string()))),
            ],
        ))
    }

    // Provider with a non-nullable vector column `v` carrying `column_id`.
    fn build_dummy_provider(column_id: u32) -> Arc<DummyTableProvider> {
        build_dummy_provider_with_nullable(column_id, false)
    }

    // Provider over a (k0 tag, ts, v vector) region; `nullable_vector`
    // controls whether `v` is nullable, which gates hint emission.
    fn build_dummy_provider_with_nullable(
        column_id: u32,
        nullable_vector: bool,
    ) -> Arc<DummyTableProvider> {
        let mut builder = RegionMetadataBuilder::new(0.into());
        builder
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new("k0", ConcreteDataType::string_datatype(), true),
                semantic_type: SemanticType::Tag,
                column_id: 1,
            })
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new(
                    "ts",
                    ConcreteDataType::timestamp_millisecond_datatype(),
                    false,
                ),
                semantic_type: SemanticType::Timestamp,
                column_id: 2,
            })
            .push_column_metadata(ColumnMetadata {
                column_schema: ColumnSchema::new(
                    "v",
                    ConcreteDataType::vector_datatype(2),
                    nullable_vector,
                ),
                semantic_type: SemanticType::Field,
                column_id,
            })
            .primary_key(vec![1]);
        let metadata = Arc::new(builder.build().unwrap());
        let engine = Arc::new(MetaRegionEngine::with_metadata(metadata.clone()));
        Arc::new(DummyTableProvider::new(0.into(), engine, metadata))
    }

    #[test]
    fn test_parse_json_vector() {
        assert_eq!(
            VectorSearchState::parse_json_vector("[1.0, 2.0, 3.0]"),
            Some(vec![1.0, 2.0, 3.0])
        );
        assert_eq!(
            VectorSearchState::parse_json_vector("[1.5, -2.3, 0.0]"),
            Some(vec![1.5, -2.3, 0.0])
        );
        // Malformed, empty, or unbracketed inputs are all rejected.
        assert_eq!(VectorSearchState::parse_json_vector("invalid"), None);
        assert_eq!(VectorSearchState::parse_json_vector("[]"), None);
        assert_eq!(VectorSearchState::parse_json_vector(""), None);
        assert_eq!(VectorSearchState::parse_json_vector("["), None);
        assert_eq!(VectorSearchState::parse_json_vector("[1.0, abc]"), None);
    }

    #[test]
    fn test_parse_binary_vector() {
        let v1: f32 = 1.0;
        let v2: f32 = 2.0;
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&v1.to_le_bytes());
        bytes.extend_from_slice(&v2.to_le_bytes());

        let result = VectorSearchState::parse_binary_vector(&bytes);
        assert_eq!(result, Some(vec![1.0, 2.0]));

        // Length not a multiple of 4 is rejected.
        let result = VectorSearchState::parse_binary_vector(&[1, 2, 3]);
        assert_eq!(result, None);
    }

    // Happy path: scan -> sort by distance (ASC) -> limit yields a hint.
    #[test]
    fn test_dummy_provider_vector_hint() {
        let dummy_provider = build_dummy_provider(10);
        let table_source = Arc::new(DefaultTableSource::new(dummy_provider.clone()));
        let expr = vec_distance_expr(VEC_L2SQ_DISTANCE);
        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .sort(vec![expr.sort(true, false)])
            .unwrap()
            .limit(0, Some(5))
            .unwrap()
            .build()
            .unwrap();

        let context = OptimizerContext::default();
        let _ = ScanHintRule.rewrite(plan, &context).unwrap();

        let hint = dummy_provider.get_vector_search_hint().unwrap();
        assert_eq!(hint.column_id, 10);
        assert_eq!(hint.k, 5);
        assert_eq!(hint.metric, VectorDistanceMetric::L2sq);
        assert_eq!(hint.query_vector, vec![1.0, 2.0]);
    }

    // OFFSET widens the search: k = skip + fetch so the skipped rows still
    // come from the top of the index result.
    #[test]
    fn test_limit_offset_for_vector_hint() {
        let dummy_provider = build_dummy_provider(10);
        let table_source = Arc::new(DefaultTableSource::new(dummy_provider.clone()));
        let expr = vec_distance_expr(VEC_L2SQ_DISTANCE);
        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .sort(vec![expr.sort(true, false)])
            .unwrap()
            .limit(5, Some(10))
            .unwrap()
            .build()
            .unwrap();

        let context = OptimizerContext::default();
        let _ = ScanHintRule.rewrite(plan, &context).unwrap();

        let hint = dummy_provider.get_vector_search_hint().unwrap();
        assert_eq!(hint.k, 15);
    }

    // Inner product ranks "larger is closer", so only DESC sorts qualify.
    #[test]
    fn test_inner_product_sort_direction() {
        let dummy_provider = build_dummy_provider(10);
        let table_source = Arc::new(DefaultTableSource::new(dummy_provider.clone()));
        let expr = vec_distance_expr(VEC_DOT_PRODUCT);
        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source.clone(), None, vec![])
            .unwrap()
            .sort(vec![expr.clone().sort(true, false)])
            .unwrap()
            .limit(0, Some(3))
            .unwrap()
            .build()
            .unwrap();

        let context = OptimizerContext::default();
        let _ = ScanHintRule.rewrite(plan, &context).unwrap();
        assert!(dummy_provider.get_vector_search_hint().is_none());

        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .sort(vec![expr.sort(false, false)])
            .unwrap()
            .limit(0, Some(3))
            .unwrap()
            .build()
            .unwrap();
        let _ = ScanHintRule.rewrite(plan, &context).unwrap();
        let hint = dummy_provider.get_vector_search_hint().unwrap();
        assert_eq!(hint.metric, VectorDistanceMetric::InnerProduct);
    }

    // Without a LIMIT (or sort fetch) there is no `k`, so no hint.
    #[test]
    fn test_no_limit_clause() {
        let dummy_provider = build_dummy_provider(10);
        let table_source = Arc::new(DefaultTableSource::new(dummy_provider.clone()));
        let expr = vec_distance_expr(VEC_L2SQ_DISTANCE);
        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .sort(vec![expr.sort(true, false)])
            .unwrap()
            .build()
            .unwrap();

        let context = OptimizerContext::default();
        let _ = ScanHintRule.rewrite(plan, &context).unwrap();

        assert!(dummy_provider.get_vector_search_hint().is_none());
    }

    // A nullable vector column without IS NOT NULL must not produce a hint.
    #[test]
    fn test_nullable_vector_requires_is_not_null_filter() {
        let dummy_provider = build_dummy_provider_with_nullable(10, true);
        let table_source = Arc::new(DefaultTableSource::new(dummy_provider.clone()));
        let expr = vec_distance_expr(VEC_L2SQ_DISTANCE);
        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .sort(vec![expr.sort(true, false)])
            .unwrap()
            .limit(0, Some(5))
            .unwrap()
            .build()
            .unwrap();

        let context = OptimizerContext::default();
        let _ = ScanHintRule.rewrite(plan, &context).unwrap();

        assert!(dummy_provider.get_vector_search_hint().is_none());
    }

    // Adding an explicit `v IS NOT NULL` filter re-enables the hint.
    #[test]
    fn test_nullable_vector_with_is_not_null_filter() {
        let dummy_provider = build_dummy_provider_with_nullable(10, true);
        let table_source = Arc::new(DefaultTableSource::new(dummy_provider.clone()));
        let expr = vec_distance_expr(VEC_L2SQ_DISTANCE);
        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .filter(col("v").is_not_null())
            .unwrap()
            .sort(vec![expr.sort(true, false)])
            .unwrap()
            .limit(0, Some(5))
            .unwrap()
            .build()
            .unwrap();

        let context = OptimizerContext::default();
        let _ = ScanHintRule.rewrite(plan, &context).unwrap();

        let hint = dummy_provider.get_vector_search_hint().unwrap();
        assert_eq!(hint.column_id, 10);
        assert_eq!(hint.k, 5);
    }

    // A top-k Sort node's own `fetch` supplies `k` without a Limit node.
    #[test]
    fn test_sort_fetch_for_vector_hint() {
        let dummy_provider = build_dummy_provider(10);
        let table_source = Arc::new(DefaultTableSource::new(dummy_provider.clone()));
        let expr = vec_distance_expr(VEC_L2SQ_DISTANCE);
        let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .sort_with_limit(vec![expr.sort(true, false)], Some(4))
            .unwrap()
            .build()
            .unwrap();

        let context = OptimizerContext::default();
        let _ = ScanHintRule.rewrite(plan, &context).unwrap();

        let hint = dummy_provider.get_vector_search_hint().unwrap();
        assert_eq!(hint.k, 4);
    }

    // Limit on one join branch must not pair with a sort on the other.
    #[test]
    fn test_limit_scoped_to_sort_branch() {
        let t1_provider = build_dummy_provider(10);
        let t2_provider = build_dummy_provider(20);
        let t1_source = Arc::new(DefaultTableSource::new(t1_provider.clone()));
        let t2_source = Arc::new(DefaultTableSource::new(t2_provider.clone()));
        let expr = vec_distance_expr(VEC_L2SQ_DISTANCE);

        let left = LogicalPlanBuilder::scan_with_filters("t1", t1_source, None, vec![])
            .unwrap()
            .limit(0, Some(5))
            .unwrap()
            .build()
            .unwrap();

        let right = LogicalPlanBuilder::scan_with_filters("t2", t2_source, None, vec![])
            .unwrap()
            .sort(vec![expr.sort(true, false)])
            .unwrap()
            .build()
            .unwrap();

        let join_keys: (Vec<Column>, Vec<Column>) = (vec![], vec![]);
        let plan = LogicalPlanBuilder::from(left)
            .join(right, JoinType::Inner, join_keys, None)
            .unwrap()
            .build()
            .unwrap();

        let context = OptimizerContext::default();
        let _ = ScanHintRule.rewrite(plan, &context).unwrap();

        assert!(t1_provider.get_vector_search_hint().is_none());
        assert!(t2_provider.get_vector_search_hint().is_none());
    }

    // Same as `vec_distance_expr` but with a table-qualified column.
    fn vec_distance_expr_qualified(
        function_name: &'static str,
        table_name: &str,
        column_name: &str,
    ) -> Expr {
        use datafusion_common::Column;

        let udf = create_udf(Arc::new(TestVectorFunction::new(function_name)));
        let qualified_col = Expr::Column(Column::new(Some(table_name.to_string()), column_name));
        Expr::ScalarFunction(ScalarFunction::new_udf(
            Arc::new(udf),
            vec![
                qualified_col,
                lit(ScalarValue::Utf8(Some("[1.0, 2.0]".to_string()))),
            ],
        ))
    }

    // Sort/limit above a join must not push a hint into either scan, since
    // the join can change which rows are the true top-k.
    #[test]
    fn test_no_vector_hint_above_join() {
        let t1_provider = build_dummy_provider(10);
        let t2_provider = build_dummy_provider(20);
        let t1_source = Arc::new(DefaultTableSource::new(t1_provider.clone()));
        let t2_source = Arc::new(DefaultTableSource::new(t2_provider.clone()));

        let left = LogicalPlanBuilder::scan_with_filters("t1", t1_source, None, vec![])
            .unwrap()
            .build()
            .unwrap();

        let right = LogicalPlanBuilder::scan_with_filters("t2", t2_source, None, vec![])
            .unwrap()
            .build()
            .unwrap();

        let join_keys: (Vec<Column>, Vec<Column>) = (vec![], vec![]);
        let join_plan = LogicalPlanBuilder::from(left)
            .join(right, JoinType::Inner, join_keys, None)
            .unwrap()
            .build()
            .unwrap();

        let expr = vec_distance_expr_qualified(VEC_L2SQ_DISTANCE, "t1", "v");
        let plan = LogicalPlanBuilder::from(join_plan)
            .sort(vec![expr.sort(true, false)])
            .unwrap()
            .limit(0, Some(5))
            .unwrap()
            .build()
            .unwrap();

        let context = OptimizerContext::default();
        let _ = ScanHintRule.rewrite(plan, &context).unwrap();

        assert!(t1_provider.get_vector_search_hint().is_none());
        assert!(t2_provider.get_vector_search_hint().is_none());
    }

    // A Subquery boundary likewise blocks hint propagation into its scan.
    #[test]
    fn test_no_vector_hint_above_subquery() {
        let provider = build_dummy_provider(10);
        let table_source = Arc::new(DefaultTableSource::new(provider.clone()));
        let scan_plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
            .unwrap()
            .build()
            .unwrap();

        let subquery = LogicalPlan::Subquery(Subquery {
            subquery: Arc::new(scan_plan),
            outer_ref_columns: vec![],
            spans: Default::default(),
        });

        let expr = vec_distance_expr(VEC_L2SQ_DISTANCE);
        let plan = LogicalPlanBuilder::from(subquery)
            .sort(vec![expr.sort(true, false)])
            .unwrap()
            .limit(0, Some(5))
            .unwrap()
            .build()
            .unwrap();

        let context = OptimizerContext::default();
        let _ = ScanHintRule.rewrite(plan, &context).unwrap();

        assert!(provider.get_vector_search_hint().is_none());
    }

    // A hint qualified with `t2.v` only lands on table `t2`, never `t1`.
    #[test]
    fn test_qualified_column_scopes_hint_to_correct_table() {
        let t1_provider = build_dummy_provider(10);
        let t2_provider = build_dummy_provider(20);
        let t1_source = Arc::new(DefaultTableSource::new(t1_provider.clone()));
        let t2_source = Arc::new(DefaultTableSource::new(t2_provider.clone()));

        let expr = vec_distance_expr_qualified(VEC_L2SQ_DISTANCE, "t2", "v");

        let t1_plan = LogicalPlanBuilder::scan_with_filters("t1", t1_source, None, vec![])
            .unwrap()
            .sort(vec![expr.clone().sort(true, false)])
            .unwrap()
            .limit(0, Some(5))
            .unwrap()
            .build()
            .unwrap();

        let t2_plan = LogicalPlanBuilder::scan_with_filters("t2", t2_source, None, vec![])
            .unwrap()
            .sort(vec![expr.sort(true, false)])
            .unwrap()
            .limit(0, Some(5))
            .unwrap()
            .build()
            .unwrap();

        let context = OptimizerContext::default();

        let _ = ScanHintRule.rewrite(t1_plan, &context).unwrap();
        assert!(t1_provider.get_vector_search_hint().is_none());

        let _ = ScanHintRule.rewrite(t2_plan, &context).unwrap();
        let hint = t2_provider.get_vector_search_hint().unwrap();
        assert_eq!(hint.column_id, 20);
        assert_eq!(hint.k, 5);
    }
}
|
||||
@@ -26,7 +26,7 @@ use datafusion_expr::TableProviderFilterPushDown as DfTableProviderFilterPushDow
|
||||
use datafusion_expr::expr::Expr;
|
||||
use datafusion_physical_expr::PhysicalSortExpr;
|
||||
use datafusion_physical_expr::expressions::Column;
|
||||
use store_api::storage::ScanRequest;
|
||||
use store_api::storage::{ScanRequest, VectorSearchRequest};
|
||||
|
||||
use crate::table::{TableRef, TableType};
|
||||
|
||||
@@ -52,6 +52,14 @@ impl DfTableProviderAdapter {
|
||||
self.scan_req.lock().unwrap().output_ordering = Some(order_opts.to_vec());
|
||||
}
|
||||
|
||||
/// Stores a vector-search hint on the adapter's pending scan request, to be
/// consumed when the scan is built. Overwrites any previously stored hint.
pub fn with_vector_search_hint(&self, hint: VectorSearchRequest) {
    self.scan_req.lock().unwrap().vector_search = Some(hint);
}
|
||||
|
||||
/// Returns a clone of the currently stored vector-search hint, if any,
/// without consuming it.
pub fn get_vector_search_hint(&self) -> Option<VectorSearchRequest> {
    self.scan_req.lock().unwrap().vector_search.clone()
}
|
||||
|
||||
#[cfg(feature = "testing")]
|
||||
pub fn get_scan_req(&self) -> ScanRequest {
|
||||
self.scan_req.lock().unwrap().clone()
|
||||
|
||||
Reference in New Issue
Block a user