diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index facf15343..f488c6377 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -3,11 +3,18 @@ pub mod insert; +use super::client::RequestResultExt; +use super::client::{HttpSend, RestfulLanceDbClient, Sender}; +use super::db::ServerVersion; +use super::util::stream_as_body; +use super::ARROW_STREAM_CONTENT_TYPE; use crate::data::scannable::Scannable; +use crate::index::waiter::wait_for_index; use crate::index::Index; use crate::index::IndexStatistics; use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest}; use crate::remote::util::stream_as_ipc; +use crate::table::query::create_multi_vector_plan; use crate::table::AddColumnsResult; use crate::table::AddResult; use crate::table::AlterColumnsResult; @@ -18,7 +25,16 @@ use crate::table::Tags; use crate::table::UpdateResult; use crate::table::{AddDataMode, AnyQuery, Filter, TableStatistics}; use crate::utils::{supported_btree_data_type, supported_vector_data_type}; -use crate::{DistanceType, Error, Table}; +use crate::{ + error::Result, + index::{IndexBuilder, IndexConfig}, + query::QueryExecutionOptions, + table::{ + merge::MergeInsertBuilder, AddDataBuilder, BaseTable, OptimizeAction, OptimizeStats, + TableDefinition, UpdateBuilder, + }, +}; +use crate::{DistanceType, Error}; use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader}; use arrow_ipc::reader::FileReader; use arrow_schema::{DataType, SchemaRef}; @@ -45,22 +61,6 @@ use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use tokio::sync::RwLock; -use super::client::RequestResultExt; -use super::client::{HttpSend, RestfulLanceDbClient, Sender}; -use super::db::ServerVersion; -use super::util::stream_as_body; -use super::ARROW_STREAM_CONTENT_TYPE; -use crate::index::waiter::wait_for_index; -use crate::{ - error::Result, - index::{IndexBuilder, IndexConfig}, - query::QueryExecutionOptions, - table::{ - merge::MergeInsertBuilder, AddDataBuilder, BaseTable, OptimizeAction, OptimizeStats, - TableDefinition, UpdateBuilder, - }, -}; - const REQUEST_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("x-request-timeout-ms"); const METRIC_TYPE_KEY: &str = "metric_type"; const INDEX_TYPE_KEY: &str = "index_type"; @@ -1309,7 +1309,7 @@ impl BaseTable for RemoteTable { .into_iter() .map(|stream| Arc::new(OneShotExec::new(stream)) as Arc) .collect(); - Table::multi_vector_plan(stream_execs) + create_multi_vector_plan(stream_execs) } } @@ -1329,7 +1329,7 @@ impl BaseTable for RemoteTable { .into_iter() .map(|stream| Arc::new(OneShotExec::new(stream)) as Arc) .collect(); - let plan = Table::multi_vector_plan(stream_execs)?; + let plan = create_multi_vector_plan(stream_execs)?; Ok(DatasetRecordBatchStream::new(execute_plan( plan, diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index b7761f5f9..3ca3f521f 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -3,20 +3,14 @@ //! LanceDB Table APIs -use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder}; -use arrow::datatypes::{Float32Type, UInt8Type}; use arrow_array::{RecordBatch, RecordBatchReader}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; use datafusion_expr::Expr; use datafusion_physical_plan::display::DisplayableExecutionPlan; -use datafusion_physical_plan::projection::ProjectionExec; -use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::union::UnionExec; use datafusion_physical_plan::ExecutionPlan; use futures::{FutureExt, StreamExt, TryFutureExt}; use lance::dataset::builder::DatasetBuilder; -use lance::dataset::scanner::Scanner; pub use lance::dataset::ColumnAlteration; pub use lance::dataset::NewColumnTransform; pub use lance::dataset::ReadParams; @@ -26,7 +20,6 @@ use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatch use lance::index::vector::utils::infer_vector_dim; use lance::index::vector::VectorIndexParams; use lance::io::{ObjectStoreParams, WrappingObjectStore}; -use lance_datafusion::exec::{analyze_plan as lance_analyze_plan, execute_plan}; use lance_datafusion::utils::StreamingWriteSource; use lance_index::scalar::{BuiltinIndexType, ScalarIndexParams}; use lance_index::vector::bq::RQBuildParams; @@ -37,10 +30,8 @@ use lance_index::vector::sq::builder::SQBuildParams; use lance_index::DatasetIndexExt; use lance_index::IndexType; use lance_io::object_store::{LanceNamespaceStorageOptionsProvider, StorageOptionsAccessor}; -use lance_namespace::models::{ - QueryTableRequest as NsQueryTableRequest, QueryTableRequestColumns, - QueryTableRequestFullTextQuery, QueryTableRequestVector, StringFtsQuery, -}; +pub use query::AnyQuery; + use lance_namespace::LanceNamespace; use lance_table::format::Manifest; use lance_table::io::commit::ManifestNamingScheme; @@ -58,14 +49,10 @@ use crate::index::vector::VectorIndex; use crate::index::IndexStatistics; use crate::index::{vector::suggested_num_sub_vectors, Index, IndexBuilder}; use crate::index::{IndexConfig, IndexStatisticsImpl}; -use crate::query::{ - IntoQueryVector, Query, QueryExecutionOptions, QueryFilter, QueryRequest, Select, TakeQuery, - VectorQuery, VectorQueryRequest, DEFAULT_TOP_K, -}; +use crate::query::{IntoQueryVector, Query, QueryExecutionOptions, TakeQuery, VectorQuery}; use crate::utils::{ - default_vector_column, supported_bitmap_data_type, supported_btree_data_type, - supported_fts_data_type, supported_label_list_data_type, supported_vector_data_type, - PatchReadParam, PatchWriteParam, TimeoutStream, + supported_bitmap_data_type, supported_btree_data_type, supported_fts_data_type, + supported_label_list_data_type, supported_vector_data_type, PatchReadParam, PatchWriteParam, }; use self::dataset::DatasetConsistencyWrapper; @@ -77,12 +64,11 @@ pub(crate) mod dataset; pub mod delete; pub mod merge; pub mod optimize; +pub mod query; pub mod schema_evolution; pub mod update; - -pub use add_data::{AddDataBuilder, AddDataMode, AddResult}; - use crate::index::waiter::wait_for_index; +pub use add_data::{AddDataBuilder, AddDataMode, AddResult}; pub use chrono::Duration; pub use delete::DeleteResult; use futures::future::{join_all, Either}; @@ -206,13 +192,6 @@ pub enum Filter { Datafusion(Expr), } -/// A query that can be used to search a LanceDB table -#[derive(Debug, Clone)] -pub enum AnyQuery { - Query(QueryRequest), - VectorQuery(VectorQueryRequest), -} - #[async_trait] pub trait Tags: Send + Sync { /// List the tags of the table. @@ -1191,59 +1170,6 @@ impl Table { self.inner.tags().await } - // Take many execution plans and map them into a single plan that adds - // a query_index column and unions them. - pub(crate) fn multi_vector_plan( - plans: Vec>, - ) -> Result> { - if plans.is_empty() { - return Err(Error::InvalidInput { - message: "No plans provided".to_string(), - }); - } - // Projection to keeping all existing columns - let first_plan = plans[0].clone(); - let project_all_columns = first_plan - .schema() - .fields() - .iter() - .enumerate() - .map(|(i, field)| { - let expr = - datafusion_physical_plan::expressions::Column::new(field.name().as_str(), i); - let expr = Arc::new(expr) as Arc; - (expr, field.name().clone()) - }) - .collect::>(); - - let projected_plans = plans - .into_iter() - .enumerate() - .map(|(plan_i, plan)| { - let query_index = datafusion_common::ScalarValue::Int32(Some(plan_i as i32)); - let query_index_expr = - datafusion_physical_plan::expressions::Literal::new(query_index); - let query_index_expr = - Arc::new(query_index_expr) as Arc; - let mut projections = vec![(query_index_expr, "query_index".to_string())]; - projections.extend_from_slice(&project_all_columns); - let projection = ProjectionExec::try_new(projections, plan).unwrap(); - Arc::new(projection) as Arc - }) - .collect::>(); - - let unioned = UnionExec::try_new(projected_plans).map_err(|err| Error::Runtime { - message: err.to_string(), - })?; - // We require 1 partition in the final output - let repartitioned = RepartitionExec::try_new( - unioned, - datafusion_physical_plan::Partitioning::RoundRobinBatch(1), - ) - .unwrap(); - Ok(Arc::new(repartitioned)) - } - /// Retrieve statistics on the table pub async fn stats(&self) -> Result { self.inner.stats().await @@ -1308,7 +1234,8 @@ pub struct NativeTable { read_consistency_interval: Option, // Optional namespace client for server-side query execution. // When set, queries will be executed on the namespace server instead of locally. - namespace_client: Option>, + // pub (crate) namespace_client so query.rs can access the fields + pub(crate) namespace_client: Option>, } impl std::fmt::Debug for NativeTable { @@ -2037,292 +1964,6 @@ impl NativeTable { } } - async fn generic_query( - &self, - query: &AnyQuery, - options: QueryExecutionOptions, - ) -> Result { - let plan = self.create_plan(query, options.clone()).await?; - let inner = execute_plan(plan, Default::default())?; - let inner = if let Some(timeout) = options.timeout { - TimeoutStream::new_boxed(inner, timeout) - } else { - inner - }; - Ok(DatasetRecordBatchStream::new(inner)) - } - - /// Execute a query on the namespace server instead of locally. - async fn namespace_query( - &self, - namespace_client: Arc, - query: &AnyQuery, - _options: QueryExecutionOptions, - ) -> Result { - // Build table_id from namespace + table name - let mut table_id = self.namespace.clone(); - table_id.push(self.name.clone()); - - // Convert AnyQuery to namespace QueryTableRequest - let mut ns_request = self.convert_to_namespace_query(query)?; - // Set the table ID on the request - ns_request.id = Some(table_id); - - // Call the namespace query_table API - let response_bytes = namespace_client - .query_table(ns_request) - .await - .map_err(|e| Error::Runtime { - message: format!("Failed to execute server-side query: {}", e), - })?; - - // Parse the Arrow IPC response into a RecordBatchStream - self.parse_arrow_ipc_response(response_bytes).await - } - - /// Convert a QueryFilter to a SQL string for the namespace API. - fn filter_to_sql(&self, filter: &QueryFilter) -> Result { - match filter { - QueryFilter::Sql(sql) => Ok(sql.clone()), - QueryFilter::Substrait(_) => Err(Error::NotSupported { - message: "Substrait filters are not supported for server-side queries".to_string(), - }), - QueryFilter::Datafusion(_) => Err(Error::NotSupported { - message: "Datafusion expression filters are not supported for server-side queries. Use SQL filter instead.".to_string(), - }), - } - } - - /// Convert an AnyQuery to the namespace QueryTableRequest format. - fn convert_to_namespace_query(&self, query: &AnyQuery) -> Result { - match query { - AnyQuery::VectorQuery(vq) => { - // Extract the query vector(s) - let vector = self.extract_query_vector(&vq.query_vector)?; - - // Convert filter to SQL string - let filter = match &vq.base.filter { - Some(f) => Some(self.filter_to_sql(f)?), - None => None, - }; - - // Convert select to columns list - let columns = match &vq.base.select { - Select::All => None, - Select::Columns(cols) => Some(Box::new(QueryTableRequestColumns { - column_names: Some(cols.clone()), - column_aliases: None, - })), - Select::Dynamic(_) => { - return Err(Error::NotSupported { - message: - "Dynamic column selection is not supported for server-side queries" - .to_string(), - }); - } - }; - - // Check for unsupported features - if vq.base.reranker.is_some() { - return Err(Error::NotSupported { - message: "Reranker is not supported for server-side queries".to_string(), - }); - } - - // Convert FTS query if present - let full_text_query = vq.base.full_text_search.as_ref().map(|fts| { - let columns = fts.columns(); - let columns_vec = if columns.is_empty() { - None - } else { - Some(columns.into_iter().collect()) - }; - Box::new(QueryTableRequestFullTextQuery { - string_query: Some(Box::new(StringFtsQuery { - query: fts.query.to_string(), - columns: columns_vec, - })), - structured_query: None, - }) - }); - - Ok(NsQueryTableRequest { - id: None, // Will be set in namespace_query - k: vq.base.limit.unwrap_or(10) as i32, - vector: Box::new(vector), - vector_column: vq.column.clone(), - filter, - columns, - offset: vq.base.offset.map(|o| o as i32), - distance_type: vq.distance_type.map(|dt| dt.to_string()), - nprobes: Some(vq.minimum_nprobes as i32), - ef: vq.ef.map(|e| e as i32), - refine_factor: vq.refine_factor.map(|r| r as i32), - lower_bound: vq.lower_bound, - upper_bound: vq.upper_bound, - prefilter: Some(vq.base.prefilter), - fast_search: Some(vq.base.fast_search), - with_row_id: Some(vq.base.with_row_id), - bypass_vector_index: Some(!vq.use_index), - full_text_query, - ..Default::default() - }) - } - AnyQuery::Query(q) => { - // For non-vector queries, pass an empty vector (similar to remote table implementation) - if q.reranker.is_some() { - return Err(Error::NotSupported { - message: "Reranker is not supported for server-side query execution" - .to_string(), - }); - } - - let filter = q - .filter - .as_ref() - .map(|f| self.filter_to_sql(f)) - .transpose()?; - - let columns = match &q.select { - Select::All => None, - Select::Columns(cols) => Some(Box::new(QueryTableRequestColumns { - column_names: Some(cols.clone()), - column_aliases: None, - })), - Select::Dynamic(_) => { - return Err(Error::NotSupported { - message: "Dynamic columns are not supported for server-side query" - .to_string(), - }); - } - }; - - // Handle full text search if present - let full_text_query = q.full_text_search.as_ref().map(|fts| { - let columns_vec = if fts.columns().is_empty() { - None - } else { - Some(fts.columns().iter().cloned().collect()) - }; - Box::new(QueryTableRequestFullTextQuery { - string_query: Some(Box::new(StringFtsQuery { - query: fts.query.to_string(), - columns: columns_vec, - })), - structured_query: None, - }) - }); - - // Empty vector for non-vector queries - let vector = Box::new(QueryTableRequestVector { - single_vector: Some(vec![]), - multi_vector: None, - }); - - Ok(NsQueryTableRequest { - id: None, // Will be set by caller - vector, - k: q.limit.unwrap_or(10) as i32, - filter, - columns, - prefilter: Some(q.prefilter), - offset: q.offset.map(|o| o as i32), - vector_column: None, // No vector column for plain queries - with_row_id: Some(q.with_row_id), - bypass_vector_index: Some(true), // No vector index for plain queries - full_text_query, - ..Default::default() - }) - } - } - } - - /// Extract query vector(s) from Arrow arrays into the namespace format. - fn extract_query_vector( - &self, - query_vectors: &[Arc], - ) -> Result { - if query_vectors.is_empty() { - return Err(Error::InvalidInput { - message: "Query vector is required for vector search".to_string(), - }); - } - - // Handle single vector case - if query_vectors.len() == 1 { - let arr = &query_vectors[0]; - let single_vector = self.array_to_f32_vec(arr)?; - Ok(QueryTableRequestVector { - single_vector: Some(single_vector), - multi_vector: None, - }) - } else { - // Handle multi-vector case - let multi_vector: Result>> = query_vectors - .iter() - .map(|arr| self.array_to_f32_vec(arr)) - .collect(); - Ok(QueryTableRequestVector { - single_vector: None, - multi_vector: Some(multi_vector?), - }) - } - } - - /// Convert an Arrow array to a Vec. - fn array_to_f32_vec(&self, arr: &Arc) -> Result> { - // Handle FixedSizeList (common for vectors) - if let Some(fsl) = arr - .as_any() - .downcast_ref::() - { - let values = fsl.values(); - if let Some(f32_arr) = values.as_any().downcast_ref::() { - return Ok(f32_arr.values().to_vec()); - } - } - - // Handle direct Float32Array - if let Some(f32_arr) = arr.as_any().downcast_ref::() { - return Ok(f32_arr.values().to_vec()); - } - - Err(Error::InvalidInput { - message: "Query vector must be Float32 type".to_string(), - }) - } - - /// Parse Arrow IPC response from the namespace server. - async fn parse_arrow_ipc_response( - &self, - bytes: bytes::Bytes, - ) -> Result { - use arrow_ipc::reader::StreamReader; - use std::io::Cursor; - - let cursor = Cursor::new(bytes); - let reader = StreamReader::try_new(cursor, None).map_err(|e| Error::Runtime { - message: format!("Failed to parse Arrow IPC response: {}", e), - })?; - - // Collect all record batches - let schema = reader.schema(); - let batches: Vec<_> = reader - .into_iter() - .collect::, _>>() - .map_err(|e| Error::Runtime { - message: format!("Failed to read Arrow IPC batches: {}", e), - })?; - - // Create a stream from the batches - let stream = futures::stream::iter(batches.into_iter().map(Ok)); - let record_batch_stream = Box::pin( - datafusion_physical_plan::stream::RecordBatchStreamAdapter::new(schema, stream), - ); - - Ok(DatasetRecordBatchStream::new(record_batch_stream)) - } - /// Check whether the table uses V2 manifest paths. /// /// See [Self::migrate_manifest_paths_v2] and [ManifestNamingScheme] for @@ -2564,167 +2205,7 @@ impl BaseTable for NativeTable { query: &AnyQuery, options: QueryExecutionOptions, ) -> Result> { - let query = match query { - AnyQuery::VectorQuery(query) => query.clone(), - AnyQuery::Query(query) => VectorQueryRequest::from_plain_query(query.clone()), - }; - - let ds_ref = self.dataset.get().await?; - let schema = ds_ref.schema(); - let mut column = query.column.clone(); - - let mut query_vector = query.query_vector.first().cloned(); - if query.query_vector.len() > 1 { - if column.is_none() { - // Infer a vector column with the same dimension of the query vector. - let arrow_schema = Schema::from(ds_ref.schema()); - column = Some(default_vector_column( - &arrow_schema, - Some(query.query_vector[0].len() as i32), - )?); - } - let vector_field = schema.field(column.as_ref().unwrap()).unwrap(); - if let DataType::List(_) = vector_field.data_type() { - // it's multivector, then the vectors should be treated as single query - // concatenate the vectors into a FixedSizeList> - // it's also possible to concatenate the vectors into a List>, - // but FixedSizeList is more efficient and easier to construct - let vectors = query - .query_vector - .iter() - .map(|arr| arr.as_ref()) - .collect::>(); - let dim = vectors[0].len(); - let mut fsl_builder = FixedSizeListBuilder::with_capacity( - Float32Builder::with_capacity(dim), - dim as i32, - vectors.len(), - ); - for vec in vectors { - fsl_builder - .values() - .append_slice(vec.as_primitive::().values()); - fsl_builder.append(true); - } - query_vector = Some(Arc::new(fsl_builder.finish())); - } else { - // If there are multiple query vectors, create a plan for each of them and union them. - let query_vecs = query.query_vector.clone(); - let plan_futures = query_vecs - .into_iter() - .map(|query_vector| { - let mut sub_query = query.clone(); - sub_query.query_vector = vec![query_vector]; - let options_ref = options.clone(); - async move { - self.create_plan(&AnyQuery::VectorQuery(sub_query), options_ref) - .await - } - }) - .collect::>(); - let plans = futures::future::try_join_all(plan_futures).await?; - return Table::multi_vector_plan(plans); - } - } - - let mut scanner: Scanner = ds_ref.scan(); - - if let Some(query_vector) = query_vector { - // If there is a vector query, default to limit=10 if unspecified - let column = if let Some(col) = column { - col - } else { - // Infer a vector column with the same dimension of the query vector. - let arrow_schema = Schema::from(ds_ref.schema()); - default_vector_column(&arrow_schema, Some(query_vector.len() as i32))? - }; - - let (_, element_type) = lance::index::vector::utils::get_vector_type(schema, &column)?; - let is_binary = matches!(element_type, DataType::UInt8); - let top_k = query.base.limit.unwrap_or(DEFAULT_TOP_K) + query.base.offset.unwrap_or(0); - if is_binary { - let query_vector = arrow::compute::cast(&query_vector, &DataType::UInt8)?; - let query_vector = query_vector.as_primitive::(); - scanner.nearest(&column, query_vector, top_k)?; - } else { - scanner.nearest(&column, query_vector.as_ref(), top_k)?; - } - scanner.minimum_nprobes(query.minimum_nprobes); - if let Some(maximum_nprobes) = query.maximum_nprobes { - scanner.maximum_nprobes(maximum_nprobes); - } - } - scanner.limit( - query.base.limit.map(|limit| limit as i64), - query.base.offset.map(|offset| offset as i64), - )?; - if let Some(ef) = query.ef { - scanner.ef(ef); - } - scanner.distance_range(query.lower_bound, query.upper_bound); - scanner.use_index(query.use_index); - scanner.prefilter(query.base.prefilter); - match query.base.select { - Select::Columns(ref columns) => { - scanner.project(columns.as_slice())?; - } - Select::Dynamic(ref select_with_transform) => { - scanner.project_with_transform(select_with_transform.as_slice())?; - } - Select::All => {} - } - - if query.base.with_row_id { - scanner.with_row_id(); - } - - scanner.batch_size(options.max_batch_length as usize); - - if query.base.fast_search { - scanner.fast_search(); - } - - match &query.base.select { - Select::Columns(select) => { - scanner.project(select.as_slice())?; - } - Select::Dynamic(select_with_transform) => { - scanner.project_with_transform(select_with_transform.as_slice())?; - } - Select::All => { /* Do nothing */ } - } - - if let Some(filter) = &query.base.filter { - match filter { - QueryFilter::Sql(sql) => { - scanner.filter(sql)?; - } - QueryFilter::Substrait(substrait) => { - scanner.filter_substrait(substrait)?; - } - QueryFilter::Datafusion(expr) => { - scanner.filter_expr(expr.clone()); - } - } - } - - if let Some(fts) = &query.base.full_text_search { - scanner.full_text_search(fts.clone())?; - } - - if let Some(refine_factor) = query.refine_factor { - scanner.refine(refine_factor); - } - - if let Some(distance_type) = query.distance_type { - scanner.distance_metric(distance_type.into()); - } - - if query.base.disable_scoring_autoprojection { - scanner.disable_scoring_autoprojection(); - } - - Ok(scanner.create_plan().await?) + query::create_plan(self, query, options).await } async fn query( @@ -2732,13 +2213,7 @@ impl BaseTable for NativeTable { query: &AnyQuery, options: QueryExecutionOptions, ) -> Result { - // If namespace client is configured, use server-side query execution - if let Some(ref namespace_client) = self.namespace_client { - return self - .namespace_query(namespace_client.clone(), query, options) - .await; - } - self.generic_query(query, options).await + query::execute_query(self, query, options).await } async fn analyze_plan( @@ -2746,8 +2221,7 @@ impl BaseTable for NativeTable { query: &AnyQuery, options: QueryExecutionOptions, ) -> Result { - let plan = self.create_plan(query, options).await?; - Ok(lance_analyze_plan(plan, Default::default()).await?) + query::analyze_query_plan(self, query, options).await } async fn merge_insert( @@ -3104,8 +2578,8 @@ mod tests { use arrow_array::{ builder::{ListBuilder, StringBuilder}, - Array, BooleanArray, FixedSizeListArray, Float32Array, Int32Array, LargeStringArray, - RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray, + Array, BooleanArray, FixedSizeListArray, Int32Array, LargeStringArray, RecordBatch, + RecordBatchIterator, RecordBatchReader, StringArray, }; use arrow_array::{BinaryArray, LargeBinaryArray}; use arrow_data::ArrayDataBuilder; @@ -3121,9 +2595,9 @@ mod tests { use crate::connection::ConnectBuilder; use crate::index::scalar::{BTreeIndexBuilder, BitmapIndexBuilder}; use crate::index::vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder}; + use crate::query::Select; use crate::query::{ExecutableQuery, QueryBase}; use crate::test_utils::connection::new_test_connection; - #[tokio::test] async fn test_open() { let tmp_dir = tempdir().unwrap(); @@ -4389,105 +3863,4 @@ mod tests { assert_eq!(result.len(), 1); assert_eq!(result[0].index_type, crate::index::IndexType::Bitmap); } - - #[tokio::test] - async fn test_convert_to_namespace_query_vector() { - let tmp_dir = tempdir().unwrap(); - let dataset_path = tmp_dir.path().join("test_ns_query.lance"); - - let batch = make_test_batches(); - let reader = RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()); - Dataset::write(reader, dataset_path.to_str().unwrap(), None) - .await - .unwrap(); - - let table = NativeTable::open(dataset_path.to_str().unwrap()) - .await - .unwrap(); - - // Create a vector query - let query_vector = Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0, 4.0])); - let vq = VectorQueryRequest { - base: QueryRequest { - limit: Some(10), - offset: Some(5), - filter: Some(QueryFilter::Sql("id > 0".to_string())), - select: Select::Columns(vec!["id".to_string()]), - ..Default::default() - }, - column: Some("vector".to_string()), - query_vector: vec![query_vector as Arc], - minimum_nprobes: 20, - distance_type: Some(crate::DistanceType::L2), - ..Default::default() - }; - - let any_query = AnyQuery::VectorQuery(vq); - let ns_request = table.convert_to_namespace_query(&any_query).unwrap(); - - assert_eq!(ns_request.k, 10); - assert_eq!(ns_request.offset, Some(5)); - assert_eq!(ns_request.filter, Some("id > 0".to_string())); - assert_eq!( - ns_request - .columns - .as_ref() - .and_then(|c| c.column_names.as_ref()), - Some(&vec!["id".to_string()]) - ); - assert_eq!(ns_request.vector_column, Some("vector".to_string())); - assert_eq!(ns_request.distance_type, Some("l2".to_string())); - assert!(ns_request.vector.single_vector.is_some()); - assert_eq!( - ns_request.vector.single_vector.as_ref().unwrap(), - &vec![1.0, 2.0, 3.0, 4.0] - ); - } - - #[tokio::test] - async fn test_convert_to_namespace_query_plain_query() { - let tmp_dir = tempdir().unwrap(); - let dataset_path = tmp_dir.path().join("test_ns_plain.lance"); - - let batch = make_test_batches(); - let reader = RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema()); - Dataset::write(reader, dataset_path.to_str().unwrap(), None) - .await - .unwrap(); - - let table = NativeTable::open(dataset_path.to_str().unwrap()) - .await - .unwrap(); - - // Create a plain (non-vector) query with filter and select - let q = QueryRequest { - limit: Some(20), - offset: Some(5), - filter: Some(QueryFilter::Sql("id > 5".to_string())), - select: Select::Columns(vec!["id".to_string()]), - with_row_id: true, - ..Default::default() - }; - - let any_query = AnyQuery::Query(q); - let ns_request = table.convert_to_namespace_query(&any_query).unwrap(); - - // Plain queries should pass an empty vector - assert_eq!(ns_request.k, 20); - assert_eq!(ns_request.offset, Some(5)); - assert_eq!(ns_request.filter, Some("id > 5".to_string())); - assert_eq!( - ns_request - .columns - .as_ref() - .and_then(|c| c.column_names.as_ref()), - Some(&vec!["id".to_string()]) - ); - assert_eq!(ns_request.with_row_id, Some(true)); - assert_eq!(ns_request.bypass_vector_index, Some(true)); - assert!(ns_request.vector_column.is_none()); // No vector column for plain queries - - // Should have an empty vector - assert!(ns_request.vector.single_vector.as_ref().unwrap().is_empty()); - } } diff --git a/rust/lancedb/src/table/query.rs b/rust/lancedb/src/table/query.rs new file mode 100644 index 000000000..c63d2c79e --- /dev/null +++ b/rust/lancedb/src/table/query.rs @@ -0,0 +1,739 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +use std::sync::Arc; + +use super::NativeTable; +use crate::error::{Error, Result}; +use crate::query::{ + QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest, DEFAULT_TOP_K, +}; +use crate::utils::{default_vector_column, TimeoutStream}; +use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder}; +use arrow::datatypes::{Float32Type, UInt8Type}; +use arrow_array::Array; +use arrow_schema::{DataType, Schema}; +use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::union::UnionExec; +use datafusion_physical_plan::ExecutionPlan; +use futures::future::try_join_all; +use lance::dataset::scanner::DatasetRecordBatchStream; +use lance::dataset::scanner::Scanner; +use lance_datafusion::exec::{analyze_plan as lance_analyze_plan, execute_plan}; +use lance_namespace::models::{ + QueryTableRequest as NsQueryTableRequest, QueryTableRequestColumns, + QueryTableRequestFullTextQuery, QueryTableRequestVector, StringFtsQuery, +}; +use lance_namespace::LanceNamespace; + +#[derive(Debug, Clone)] +pub enum AnyQuery { + Query(QueryRequest), + VectorQuery(VectorQueryRequest), +} + +//Decide between namespace or local +pub async fn execute_query( + table: &NativeTable, + query: &AnyQuery, + options: QueryExecutionOptions, +) -> Result { + // If namespace client is configured, use server-side query execution + if let Some(ref namespace_client) = table.namespace_client { + return execute_namespace_query(table, namespace_client.clone(), query, options).await; + } + execute_generic_query(table, query, options).await +} + +pub async fn analyze_query_plan( + table: &NativeTable, + query: &AnyQuery, + options: QueryExecutionOptions, +) -> Result { + let plan = create_plan(table, query, options).await?; + Ok(lance_analyze_plan(plan, Default::default()).await?) +} + +/// Local Execution Path (DataFusion) +async fn execute_generic_query( + table: &NativeTable, + query: &AnyQuery, + options: QueryExecutionOptions, +) -> Result { + let plan = create_plan(table, query, options.clone()).await?; + let inner = execute_plan(plan, Default::default())?; + let inner = if let Some(timeout) = options.timeout { + TimeoutStream::new_boxed(inner, timeout) + } else { + inner + }; + Ok(DatasetRecordBatchStream::new(inner)) +} + +pub async fn create_plan( + table: &NativeTable, + query: &AnyQuery, + options: QueryExecutionOptions, +) -> Result> { + let query = match query { + AnyQuery::VectorQuery(query) => query.clone(), + AnyQuery::Query(query) => VectorQueryRequest::from_plain_query(query.clone()), + }; + + let ds_ref = table.dataset.get().await?; + let schema = ds_ref.schema(); + let mut column = query.column.clone(); + + let mut query_vector = query.query_vector.first().cloned(); + if query.query_vector.len() > 1 { + if column.is_none() { + // Infer a vector column with the same dimension of the query vector. + let arrow_schema = Schema::from(ds_ref.schema()); + column = Some(default_vector_column( + &arrow_schema, + Some(query.query_vector[0].len() as i32), + )?); + } + let vector_field = schema.field(column.as_ref().unwrap()).unwrap(); + if let DataType::List(_) = vector_field.data_type() { + // Multivector handling: concatenate into FixedSizeList> + let vectors = query + .query_vector + .iter() + .map(|arr| arr.as_ref()) + .collect::>(); + let dim = vectors[0].len(); + let mut fsl_builder = FixedSizeListBuilder::with_capacity( + Float32Builder::with_capacity(dim), + dim as i32, + vectors.len(), + ); + for vec in vectors { + fsl_builder + .values() + .append_slice(vec.as_primitive::().values()); + fsl_builder.append(true); + } + query_vector = Some(Arc::new(fsl_builder.finish())); + } else { + // Multiple query vectors: create a plan for each and union them + let query_vecs = query.query_vector.clone(); + let plan_futures = query_vecs + .into_iter() + .map(|query_vector| { + let mut sub_query = query.clone(); + sub_query.query_vector = vec![query_vector]; + let options_ref = options.clone(); + async move { + create_plan(table, &AnyQuery::VectorQuery(sub_query), options_ref).await + } + }) + .collect::>(); + let plans = try_join_all(plan_futures).await?; + return create_multi_vector_plan(plans); + } + } + + let mut scanner: Scanner = ds_ref.scan(); + + if let Some(query_vector) = query_vector { + let column = if let Some(col) = column { + col + } else { + let arrow_schema = Schema::from(ds_ref.schema()); + default_vector_column(&arrow_schema, Some(query_vector.len() as i32))? + }; + + let (_, element_type) = lance::index::vector::utils::get_vector_type(schema, &column)?; + let is_binary = matches!(element_type, DataType::UInt8); + let top_k = query.base.limit.unwrap_or(DEFAULT_TOP_K) + query.base.offset.unwrap_or(0); + + if is_binary { + let query_vector = arrow::compute::cast(&query_vector, &DataType::UInt8)?; + let query_vector = query_vector.as_primitive::(); + scanner.nearest(&column, query_vector, top_k)?; + } else { + scanner.nearest(&column, query_vector.as_ref(), top_k)?; + } + + scanner.minimum_nprobes(query.minimum_nprobes); + if let Some(maximum_nprobes) = query.maximum_nprobes { + scanner.maximum_nprobes(maximum_nprobes); + } + } + + scanner.limit( + query.base.limit.map(|limit| limit as i64), + query.base.offset.map(|offset| offset as i64), + )?; + + if let Some(ef) = query.ef { + scanner.ef(ef); + } + + scanner.distance_range(query.lower_bound, query.upper_bound); + scanner.use_index(query.use_index); + scanner.prefilter(query.base.prefilter); + + match query.base.select { + Select::Columns(ref columns) => { + scanner.project(columns.as_slice())?; + } + Select::Dynamic(ref select_with_transform) => { + scanner.project_with_transform(select_with_transform.as_slice())?; + } + Select::All => {} + } + + if query.base.with_row_id { + scanner.with_row_id(); + } + + scanner.batch_size(options.max_batch_length as usize); + + if query.base.fast_search { + scanner.fast_search(); + } + + if let Some(filter) = &query.base.filter { + match filter { + QueryFilter::Sql(sql) => { + scanner.filter(sql)?; + } + QueryFilter::Substrait(substrait) => { + scanner.filter_substrait(substrait)?; + } + QueryFilter::Datafusion(expr) => { + scanner.filter_expr(expr.clone()); + } + } + } + + if let Some(fts) = &query.base.full_text_search { + scanner.full_text_search(fts.clone())?; + } + + if let Some(refine_factor) = query.refine_factor { + scanner.refine(refine_factor); + } + + if let Some(distance_type) = query.distance_type { + scanner.distance_metric(distance_type.into()); + } + + if query.base.disable_scoring_autoprojection { + scanner.disable_scoring_autoprojection(); + } + + Ok(scanner.create_plan().await?) +} + +//Helper functions below + +// Take many execution plans and map them into a single plan that adds +// a query_index column and unions them. +pub(crate) fn create_multi_vector_plan( + plans: Vec>, +) -> Result> { + if plans.is_empty() { + return Err(Error::InvalidInput { + message: "No plans provided".to_string(), + }); + } + // Projection to keeping all existing columns + let first_plan = plans[0].clone(); + let project_all_columns = first_plan + .schema() + .fields() + .iter() + .enumerate() + .map(|(i, field)| { + let expr = datafusion_physical_plan::expressions::Column::new(field.name().as_str(), i); + let expr = Arc::new(expr) as Arc; + (expr, field.name().clone()) + }) + .collect::>(); + + let projected_plans = plans + .into_iter() + .enumerate() + .map(|(plan_i, plan)| { + let query_index = datafusion_common::ScalarValue::Int32(Some(plan_i as i32)); + let query_index_expr = datafusion_physical_plan::expressions::Literal::new(query_index); + let query_index_expr = + Arc::new(query_index_expr) as Arc; + let mut projections = vec![(query_index_expr, "query_index".to_string())]; + projections.extend_from_slice(&project_all_columns); + let projection = ProjectionExec::try_new(projections, plan).unwrap(); + Arc::new(projection) as Arc + }) + .collect::>(); + + let unioned = UnionExec::try_new(projected_plans).map_err(|err| Error::Runtime { + message: err.to_string(), + })?; + // We require 1 partition in the final output + let repartitioned = RepartitionExec::try_new( + unioned, + datafusion_physical_plan::Partitioning::RoundRobinBatch(1), + ) + .unwrap(); + Ok(Arc::new(repartitioned)) +} + +/// Execute a query on the namespace server instead of locally. +async fn execute_namespace_query( + table: &NativeTable, + namespace_client: Arc, + query: &AnyQuery, + _options: QueryExecutionOptions, +) -> Result { + // Build table_id from namespace + table name + let mut table_id = table.namespace.clone(); + table_id.push(table.name.clone()); + + // Convert AnyQuery to namespace QueryTableRequest + let mut ns_request = convert_to_namespace_query(query)?; + // Set the table ID on the request + ns_request.id = Some(table_id); + + // Call the namespace query_table API + let response_bytes = namespace_client + .query_table(ns_request) + .await + .map_err(|e| Error::Runtime { + message: format!("Failed to execute server-side query: {}", e), + })?; + + // Parse the Arrow IPC response into a RecordBatchStream + parse_arrow_ipc_response(response_bytes).await +} + +/// Convert an AnyQuery to the namespace QueryTableRequest format. +fn convert_to_namespace_query(query: &AnyQuery) -> Result { + match query { + AnyQuery::VectorQuery(vq) => { + // Extract the query vector(s) + let vector = extract_query_vector(&vq.query_vector)?; + + // Convert filter to SQL string + let filter = match &vq.base.filter { + Some(f) => Some(filter_to_sql(f)?), + None => None, + }; + + // Convert select to columns list + let columns = match &vq.base.select { + Select::All => None, + Select::Columns(cols) => Some(Box::new(QueryTableRequestColumns { + column_names: Some(cols.clone()), + column_aliases: None, + })), + Select::Dynamic(_) => { + return Err(Error::NotSupported { + message: + "Dynamic column selection is not supported for server-side queries" + .to_string(), + }); + } + }; + + // Check for unsupported features + if vq.base.reranker.is_some() { + return Err(Error::NotSupported { + message: "Reranker is not supported for server-side queries".to_string(), + }); + } + + // Convert FTS query if present + let full_text_query = vq.base.full_text_search.as_ref().map(|fts| { + let columns = fts.columns(); + let columns_vec = if columns.is_empty() { + None + } else { + Some(columns.into_iter().collect()) + }; + Box::new(QueryTableRequestFullTextQuery { + string_query: Some(Box::new(StringFtsQuery { + query: fts.query.to_string(), + columns: columns_vec, + })), + structured_query: None, + }) + }); + + Ok(NsQueryTableRequest { + id: None, // Will be set in namespace_query + k: vq.base.limit.unwrap_or(10) as i32, + vector: Box::new(vector), + vector_column: vq.column.clone(), + filter, + columns, + offset: vq.base.offset.map(|o| o as i32), + distance_type: vq.distance_type.map(|dt| dt.to_string()), + nprobes: Some(vq.minimum_nprobes as i32), + ef: vq.ef.map(|e| e as i32), + refine_factor: vq.refine_factor.map(|r| r as i32), + lower_bound: vq.lower_bound, + upper_bound: vq.upper_bound, + prefilter: Some(vq.base.prefilter), + fast_search: Some(vq.base.fast_search), + with_row_id: Some(vq.base.with_row_id), + bypass_vector_index: Some(!vq.use_index), + full_text_query, + ..Default::default() + }) + } + AnyQuery::Query(q) => { + // For non-vector queries, pass an empty vector (similar to remote table implementation) + if q.reranker.is_some() { + return Err(Error::NotSupported { + message: "Reranker is not supported for server-side query execution" + .to_string(), + }); + } + + let filter = q.filter.as_ref().map(filter_to_sql).transpose()?; + + let columns = match &q.select { + Select::All => None, + Select::Columns(cols) => Some(Box::new(QueryTableRequestColumns { + column_names: Some(cols.clone()), + column_aliases: None, + })), + Select::Dynamic(_) => { + return Err(Error::NotSupported { + message: "Dynamic columns are not supported for server-side query" + .to_string(), + }); + } + }; + + // Handle full text search if present + let full_text_query = q.full_text_search.as_ref().map(|fts| { + let columns_vec = if fts.columns().is_empty() { + None + } else { + Some(fts.columns().iter().cloned().collect()) + }; + Box::new(QueryTableRequestFullTextQuery { + string_query: Some(Box::new(StringFtsQuery { + query: fts.query.to_string(), + columns: columns_vec, + })), + structured_query: None, + }) + }); + + // Empty vector for non-vector queries + let vector = Box::new(QueryTableRequestVector { + single_vector: Some(vec![]), + multi_vector: None, + }); + + Ok(NsQueryTableRequest { + id: None, // Will be set by caller + vector, + k: q.limit.unwrap_or(10) as i32, + filter, + columns, + prefilter: Some(q.prefilter), + offset: q.offset.map(|o| o as i32), + vector_column: None, // No vector column for plain queries + with_row_id: Some(q.with_row_id), + bypass_vector_index: Some(true), // No vector index for plain queries + full_text_query, + ..Default::default() + }) + } + } +} + +fn filter_to_sql(filter: &QueryFilter) -> Result { + match filter { + QueryFilter::Sql(sql) => Ok(sql.clone()), + QueryFilter::Substrait(_) => Err(Error::NotSupported { + message: "Substrait filters are not supported for server-side queries".to_string(), + }), + QueryFilter::Datafusion(_) => Err(Error::NotSupported { + message: "Datafusion expression filters are not supported for server-side queries. Use SQL filter instead.".to_string(), + }), + } +} + +/// Extract query vector(s) from Arrow arrays into the namespace format. +fn extract_query_vector( + query_vectors: &[Arc], +) -> Result { + if query_vectors.is_empty() { + return Err(Error::InvalidInput { + message: "Query vector is required for vector search".to_string(), + }); + } + + // Handle single vector case + if query_vectors.len() == 1 { + let arr = &query_vectors[0]; + let single_vector = array_to_f32_vec(arr)?; + Ok(QueryTableRequestVector { + single_vector: Some(single_vector), + multi_vector: None, + }) + } else { + // Handle multi-vector case + let multi_vector: Result>> = + query_vectors.iter().map(array_to_f32_vec).collect(); + Ok(QueryTableRequestVector { + single_vector: None, + multi_vector: Some(multi_vector?), + }) + } +} + +/// Convert an Arrow array to a Vec. +fn array_to_f32_vec(arr: &Arc) -> Result> { + // Handle FixedSizeList (common for vectors) + if let Some(fsl) = arr + .as_any() + .downcast_ref::() + { + let values = fsl.values(); + if let Some(f32_arr) = values.as_any().downcast_ref::() { + return Ok(f32_arr.values().to_vec()); + } + } + + // Handle direct Float32Array + if let Some(f32_arr) = arr.as_any().downcast_ref::() { + return Ok(f32_arr.values().to_vec()); + } + + Err(Error::InvalidInput { + message: "Query vector must be Float32 type".to_string(), + }) +} + +/// Parse Arrow IPC response from the namespace server. +async fn parse_arrow_ipc_response(bytes: bytes::Bytes) -> Result { + use arrow_ipc::reader::StreamReader; + use std::io::Cursor; + + let cursor = Cursor::new(bytes); + let reader = StreamReader::try_new(cursor, None).map_err(|e| Error::Runtime { + message: format!("Failed to parse Arrow IPC response: {}", e), + })?; + + // Collect all record batches + let schema = reader.schema(); + let batches: Vec<_> = reader + .into_iter() + .collect::, _>>() + .map_err(|e| Error::Runtime { + message: format!("Failed to read Arrow IPC batches: {}", e), + })?; + + // Create a stream from the batches + let stream = futures::stream::iter(batches.into_iter().map(Ok)); + let record_batch_stream = + Box::pin(datafusion_physical_plan::stream::RecordBatchStreamAdapter::new(schema, stream)); + + Ok(DatasetRecordBatchStream::new(record_batch_stream)) +} + +#[cfg(test)] +#[allow(deprecated)] +mod tests { + use arrow_array::Float32Array; + use futures::TryStreamExt; + use std::sync::Arc; + + use super::*; + use crate::query::QueryExecutionOptions; + + #[test] + fn test_convert_to_namespace_query_vector() { + let query_vector = Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0, 4.0])); + + let vq = VectorQueryRequest { + base: QueryRequest { + limit: Some(10), + offset: Some(5), + filter: Some(QueryFilter::Sql("id > 0".to_string())), + select: Select::Columns(vec!["id".to_string()]), + ..Default::default() + }, + column: Some("vector".to_string()), + // We cast here to satisfy the struct definition + query_vector: vec![query_vector as Arc], + minimum_nprobes: 20, + distance_type: Some(crate::DistanceType::L2), + ..Default::default() + }; + + let any_query = AnyQuery::VectorQuery(vq); + + let ns_request = convert_to_namespace_query(&any_query).unwrap(); + + assert_eq!(ns_request.k, 10); + assert_eq!(ns_request.offset, Some(5)); + assert_eq!(ns_request.filter, Some("id > 0".to_string())); + assert_eq!( + ns_request + .columns + .as_ref() + .and_then(|c| c.column_names.as_ref()), + Some(&vec!["id".to_string()]) + ); + assert_eq!(ns_request.vector_column, Some("vector".to_string())); + assert_eq!(ns_request.distance_type, Some("l2".to_string())); + + // Verify the vector data was extracted correctly + assert!(ns_request.vector.single_vector.is_some()); + assert_eq!( + ns_request.vector.single_vector.as_ref().unwrap(), + &vec![1.0, 2.0, 3.0, 4.0] + ); + } + + #[test] + fn test_convert_to_namespace_query_plain_query() { + let q = QueryRequest { + limit: Some(20), + offset: Some(5), + filter: Some(QueryFilter::Sql("id > 5".to_string())), + select: Select::Columns(vec!["id".to_string()]), + with_row_id: true, + ..Default::default() + }; + + let any_query = AnyQuery::Query(q); + + let ns_request = convert_to_namespace_query(&any_query).unwrap(); + + assert_eq!(ns_request.k, 20); + assert_eq!(ns_request.offset, Some(5)); + assert_eq!(ns_request.filter, Some("id > 5".to_string())); + assert_eq!( + ns_request + .columns + .as_ref() + .and_then(|c| c.column_names.as_ref()), + Some(&vec!["id".to_string()]) + ); + assert_eq!(ns_request.with_row_id, Some(true)); + assert_eq!(ns_request.bypass_vector_index, Some(true)); + assert!(ns_request.vector_column.is_none()); + + assert!(ns_request.vector.single_vector.as_ref().unwrap().is_empty()); + } + + #[tokio::test] + async fn test_execute_query_local_routing() { + use crate::connect; + use crate::table::query::execute_query; + use arrow_array::{Int32Array, RecordBatch}; + use arrow_schema::{DataType, Field, Schema}; + + let conn = connect("memory://").execute().await.unwrap(); + + let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], + ) + .unwrap(); + + let table = conn + .create_table("test_routing", vec![batch]) + .execute() + .await + .unwrap(); + + let native_table = table.as_native().unwrap(); + + // Setup a request + let req = QueryRequest { + filter: Some(QueryFilter::Sql("id > 3".to_string())), + ..Default::default() + }; + let query = AnyQuery::Query(req); + + // Action: Call execute_query directly + // This validates that execute_query correctly routes to the local DataFusion engine + // when table.namespace_client is None. + let stream = execute_query(native_table, &query, QueryExecutionOptions::default()) + .await + .unwrap(); + + // Verify results + let batches = stream.try_collect::>().await.unwrap(); + let count: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(count, 2); // 4 and 5 + } + + #[tokio::test] + async fn test_create_plan_multivector_structure() { + use arrow_array::{Float32Array, RecordBatch}; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_physical_plan::display::DisplayableExecutionPlan; + + use crate::table::query::create_plan; + + use crate::connect; + + let conn = connect("memory://").execute().await.unwrap(); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "vector", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2), + false, + ), + ])); + + let batch = RecordBatch::new_empty(schema.clone()); + let table = conn + .create_table("test_plan", vec![batch]) + .execute() + .await + .unwrap(); + let native_table = table.as_native().unwrap(); + + // This triggers the "create_multi_vector_plan" logic branch + let q1 = Arc::new(Float32Array::from(vec![1.0, 2.0])); + let q2 = Arc::new(Float32Array::from(vec![3.0, 4.0])); + + let req = VectorQueryRequest { + column: Some("vector".to_string()), + query_vector: vec![q1, q2], + ..Default::default() + }; + let query = AnyQuery::VectorQuery(req); + + // Create the Plan + let plan = create_plan(native_table, &query, QueryExecutionOptions::default()) + .await + .unwrap(); + + // formatting it allows us to see the hierarchy + let display = DisplayableExecutionPlan::new(plan.as_ref()) + .indent(true) + .to_string(); + + // We expect a RepartitionExec wrapping a UnionExec + assert!( + display.contains("RepartitionExec"), + "Plan should include Repartitioning" + ); + assert!( + display.contains("UnionExec"), + "Plan should include a Union of multiple searches" + ); + // We expect the projection to add the 'query_index' column (logic inside multi_vector_plan) + assert!( + display.contains("query_index"), + "Plan should add query_index column" + ); + } +}