diff --git a/Cargo.lock b/Cargo.lock
index 149628f9..f96d3494 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4135,7 +4135,10 @@ dependencies = [
  "candle-transformers",
  "chrono",
  "crunchy",
+ "datafusion-catalog",
  "datafusion-common",
+ "datafusion-execution",
+ "datafusion-expr",
  "datafusion-physical-plan",
  "futures",
  "half",
diff --git a/Cargo.toml b/Cargo.toml
index a4c44354..62a2d486 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -42,7 +42,10 @@ arrow-arith = "53.2"
 arrow-cast = "53.2"
 async-trait = "0"
 chrono = "0.4.35"
-datafusion-common = "44.0"
+datafusion-catalog = "44.0"
+datafusion-common = { version = "44.0", default-features = false }
+datafusion-execution = "44.0"
+datafusion-expr = "44.0"
 datafusion-physical-plan = "44.0"
 env_logger = "0.10"
 half = { "version" = "=2.4.1", default-features = false, features = [
diff --git a/python/src/query.rs b/python/src/query.rs
index 723d7c81..f3d1511c 100644
--- a/python/src/query.rs
+++ b/python/src/query.rs
@@ -7,8 +7,7 @@ use arrow::pyarrow::FromPyArrow;
 use lancedb::index::scalar::FullTextSearchQuery;
 use lancedb::query::QueryExecutionOptions;
 use lancedb::query::{
-    ExecutableQuery, HasQuery, Query as LanceDbQuery, QueryBase, Select,
-    VectorQuery as LanceDbVectorQuery,
+    ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
 };
 use pyo3::exceptions::PyRuntimeError;
 use pyo3::prelude::{PyAnyMethods, PyDictMethods};
@@ -313,7 +312,8 @@ impl VectorQuery {
     }
 
     pub fn nearest_to_text(&mut self, query: Bound<'_, PyDict>) -> PyResult<HybridQuery> {
-        let fts_query = Query::new(self.inner.mut_query().clone()).nearest_to_text(query)?;
+        let base_query = self.inner.clone().into_plain();
+        let fts_query = Query::new(base_query).nearest_to_text(query)?;
         Ok(HybridQuery {
             inner_vec: self.clone(),
             inner_fts: fts_query,
@@ -411,10 +411,14 @@ impl HybridQuery {
     }
 
     pub fn get_limit(&mut self) -> Option<u32> {
-        self.inner_fts.inner.limit.map(|i| i as u32)
+        self.inner_fts
+            .inner
+            .current_request()
+            .limit
+            .map(|i| i as u32)
     }
 
     pub fn get_with_row_id(&mut self) -> bool {
-        self.inner_fts.inner.with_row_id
+        self.inner_fts.inner.current_request().with_row_id
     }
 }
diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml
index fd896ae0..28dafa0b 100644
--- a/rust/lancedb/Cargo.toml
+++ b/rust/lancedb/Cargo.toml
@@ -19,7 +19,10 @@ arrow-ord = { workspace = true }
 arrow-cast = { workspace = true }
 arrow-ipc.workspace = true
 chrono = { workspace = true }
+datafusion-catalog.workspace = true
 datafusion-common.workspace = true
+datafusion-execution.workspace = true
+datafusion-expr.workspace = true
 datafusion-physical-plan.workspace = true
 object_store = { workspace = true }
 snafu = { workspace = true }
@@ -33,7 +36,7 @@ lance-table = { workspace = true }
 lance-linalg = { workspace = true }
 lance-testing = { workspace = true }
 lance-encoding = { workspace = true }
-moka = { workspace = true}
+moka = { workspace = true }
 pin-project = { workspace = true }
 tokio = { version = "1.23", features = ["rt-multi-thread"] }
 log.workspace = true
@@ -82,7 +85,7 @@ aws-sdk-s3 = { version = "1.38.0" }
 aws-sdk-kms = { version = "1.37" }
 aws-config = { version = "1.0" }
 aws-smithy-runtime = { version = "1.3" }
-http-body = "1" # Matching reqwest
+http-body = "1"            # Matching reqwest
 
 [features]
@@ -98,7 +101,7 @@ sentence-transformers = [
     "dep:candle-core",
     "dep:candle-transformers",
     "dep:candle-nn",
-    "dep:tokenizers"
+    "dep:tokenizers",
 ]
 
 # TLS
diff --git a/rust/lancedb/src/database.rs b/rust/lancedb/src/database.rs
index 2516e4b5..1f5e7cb8 100644
--- a/rust/lancedb/src/database.rs
+++ b/rust/lancedb/src/database.rs
@@ -21,7 +21,7 @@ use arrow_array::RecordBatchReader;
 use lance::dataset::ReadParams;
 
 use crate::error::Result;
-use crate::table::{TableDefinition, TableInternal, WriteOptions};
+use crate::table::{BaseTable, TableDefinition, WriteOptions};
 
 pub mod listing;
 
@@ -120,9 +120,9 @@
     /// List the names of tables in the database
     async fn table_names(&self, request: TableNamesRequest) -> Result<Vec<String>>;
     /// Create a table in the database
-    async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn TableInternal>>;
+    async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn BaseTable>>;
     /// Open a table in the database
-    async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn TableInternal>>;
+    async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>>;
     /// Rename a table in the database
     async fn rename_table(&self, old_name: &str, new_name: &str) -> Result<()>;
     /// Drop a table in the database
diff --git a/rust/lancedb/src/database/listing.rs b/rust/lancedb/src/database/listing.rs
index 38ce87ff..fe361252 100644
--- a/rust/lancedb/src/database/listing.rs
+++ b/rust/lancedb/src/database/listing.rs
@@ -22,8 +22,8 @@ use crate::table::NativeTable;
 use crate::utils::validate_table_name;
 
 use super::{
-    CreateTableData, CreateTableMode, CreateTableRequest, Database, DatabaseOptions,
-    OpenTableRequest, TableInternal, TableNamesRequest,
+    BaseTable, CreateTableData, CreateTableMode, CreateTableRequest, Database, DatabaseOptions,
+    OpenTableRequest, TableNamesRequest,
 };
 
 /// File extension to indicate a lance table
@@ -356,10 +356,7 @@ impl Database for ListingDatabase {
         Ok(f)
     }
 
-    async fn create_table(
-        &self,
-        mut request: CreateTableRequest,
-    ) -> Result<Arc<dyn TableInternal>> {
+    async fn create_table(&self, mut request: CreateTableRequest) -> Result<Arc<dyn BaseTable>> {
         let table_uri = self.table_uri(&request.name)?;
         // Inherit storage options from the connection
         let storage_options = request
@@ -452,7 +449,7 @@ impl Database for ListingDatabase {
         }
     }
 
-    async fn open_table(&self, mut request: OpenTableRequest) -> Result<Arc<dyn TableInternal>> {
+    async fn open_table(&self, mut request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
         let table_uri = self.table_uri(&request.name)?;
 
         // Inherit storage options from the connection
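For orientation, here is a minimal sketch (not part of the diff) of how a caller consumes the reworked `Database` trait: `open_table` now returns the object-safe `Arc<dyn BaseTable>` instead of the crate-private `TableInternal`, and the public `Table` wrapper accepts it directly. The module paths are assumptions inferred from the imports above.

```rust
use std::sync::Arc;

use lancedb::database::{Database, OpenTableRequest};
use lancedb::table::{BaseTable, Table};

// Works the same for any backend: ListingDatabase (local) or RemoteDatabase (cloud).
async fn open_as_table(db: &dyn Database, request: OpenTableRequest) -> lancedb::Result<Table> {
    // `open_table` now yields an implementation-agnostic handle...
    let base: Arc<dyn BaseTable> = db.open_table(request).await?;
    // ...which the public `Table` wrapper accepts directly (see table.rs below).
    Ok(Table::new(base))
}
```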
diff --git a/rust/lancedb/src/index.rs b/rust/lancedb/src/index.rs
index cc11f617..efcef35f 100644
--- a/rust/lancedb/src/index.rs
+++ b/rust/lancedb/src/index.rs
@@ -8,7 +8,7 @@ use serde::Deserialize;
 use serde_with::skip_serializing_none;
 use vector::IvfFlatIndexBuilder;
 
-use crate::{table::TableInternal, DistanceType, Error, Result};
+use crate::{table::BaseTable, DistanceType, Error, Result};
 
 use self::{
     scalar::{BTreeIndexBuilder, BitmapIndexBuilder, LabelListIndexBuilder},
@@ -65,14 +65,14 @@ pub enum Index {
 ///
 /// The methods on this builder are used to specify options common to all indices.
 pub struct IndexBuilder {
-    parent: Arc<dyn TableInternal>,
+    parent: Arc<dyn BaseTable>,
     pub(crate) index: Index,
     pub(crate) columns: Vec<String>,
     pub(crate) replace: bool,
 }
 
 impl IndexBuilder {
-    pub(crate) fn new(parent: Arc<dyn TableInternal>, columns: Vec<String>, index: Index) -> Self {
+    pub(crate) fn new(parent: Arc<dyn BaseTable>, columns: Vec<String>, index: Index) -> Self {
         Self {
             parent,
             index,
diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs
index 1e09d6a0..1a69f8e0 100644
--- a/rust/lancedb/src/query.rs
+++ b/rust/lancedb/src/query.rs
@@ -20,12 +20,12 @@ use lance_index::scalar::FullTextSearchQuery;
 use lance_index::vector::DIST_COL;
 use lance_io::stream::RecordBatchStreamAdapter;
 
-use crate::arrow::SendableRecordBatchStream;
 use crate::error::{Error, Result};
 use crate::rerankers::rrf::RRFReranker;
 use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker};
-use crate::table::TableInternal;
+use crate::table::BaseTable;
 use crate::DistanceType;
+use crate::{arrow::SendableRecordBatchStream, table::AnyQuery};
 
 mod hybrid;
 
@@ -449,7 +449,7 @@ pub trait QueryBase {
 }
 
 pub trait HasQuery {
-    fn mut_query(&mut self) -> &mut Query;
+    fn mut_query(&mut self) -> &mut QueryRequest;
 }
 
 impl<T: HasQuery> QueryBase for T {
@@ -577,6 +577,65 @@ pub trait ExecutableQuery {
     fn explain_plan(&self, verbose: bool) -> impl Future<Output = Result<String>> + Send;
 }
 
+/// A basic query into a table without any kind of search
+///
+/// This will result in a (potentially filtered) scan if executed
+#[derive(Debug, Clone)]
+pub struct QueryRequest {
+    /// limit the number of rows to return.
+    pub limit: Option<usize>,
+
+    /// Offset of the query.
+    pub offset: Option<usize>,
+
+    /// Apply filter to the returned rows.
+    pub filter: Option<String>,
+
+    /// Perform a full text search on the table.
+    pub full_text_search: Option<FullTextSearchQuery>,
+
+    /// Select column projection.
+    pub select: Select,
+
+    /// If set to true, the query is executed only on the indexed data,
+    /// and yields faster results.
+    ///
+    /// By default, this is false.
+    pub fast_search: bool,
+
+    /// If set to true, the query will return the `_rowid` meta column.
+    ///
+    /// By default, this is false.
+    pub with_row_id: bool,

+    /// If set to false, the filter will be applied after the vector search.
+    pub prefilter: bool,
+
+    /// Implementation of reranker that can be used to reorder or combine query
+    /// results, especially if using hybrid search
+    pub reranker: Option<Arc<dyn Reranker>>,
+
+    /// Configure how query results are normalized when doing hybrid search
+    pub norm: Option<NormalizeMethod>,
+}
+
+impl Default for QueryRequest {
+    fn default() -> Self {
+        Self {
+            limit: Some(DEFAULT_TOP_K),
+            offset: None,
+            filter: None,
+            full_text_search: None,
+            select: Select::All,
+            fast_search: false,
+            with_row_id: false,
+            prefilter: true,
+            reranker: None,
+            norm: None,
+        }
+    }
+}
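Because `QueryRequest` is now a public, plain-data struct with a `Default` impl, a request can be described declaratively and handed to any table implementation. A hedged sketch (the helper name is ours; the `assert` assumes `DEFAULT_TOP_K` is still 10):

```rust
use lancedb::query::{QueryRequest, Select};

fn build_request() -> QueryRequest {
    let request = QueryRequest {
        filter: Some("price > 10".to_string()),
        select: Select::Columns(vec!["id".into(), "price".into()]),
        with_row_id: true,
        ..Default::default()
    };
    // Fields not overridden keep their defaults, e.g. limit = Some(DEFAULT_TOP_K).
    assert_eq!(request.limit, Some(10));
    request
}
```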
 /// A builder for LanceDB queries.
 ///
 /// See [`crate::Table::query`] for more details on queries
 /// ...
 /// times.
 #[derive(Debug, Clone)]
 pub struct Query {
-    parent: Arc<dyn TableInternal>,
-
-    /// limit the number of rows to return.
-    pub limit: Option<usize>,
-
-    /// Offset of the query.
-    pub(crate) offset: Option<usize>,
-
-    /// Apply filter to the returned rows.
-    pub(crate) filter: Option<String>,
-
-    /// Perform a full text search on the table.
-    pub(crate) full_text_search: Option<FullTextSearchQuery>,
-
-    /// Select column projection.
-    pub(crate) select: Select,
-
-    /// If set to true, the query is executed only on the indexed data,
-    /// and yields faster results.
-    ///
-    /// By default, this is false.
-    pub(crate) fast_search: bool,
-
-    /// If set to true, the query will return the `_rowid` meta column.
-    ///
-    /// By default, this is false.
-    pub with_row_id: bool,
-
-    /// If set to false, the filter will be applied after the vector search.
-    pub(crate) prefilter: bool,
-
-    /// Implementation of reranker that can be used to reorder or combine query
-    /// results, especially if using hybrid search
-    pub(crate) reranker: Option<Arc<dyn Reranker>>,
-
-    /// Configure how query results are normalized when doing hybrid search
-    pub(crate) norm: Option<NormalizeMethod>,
+    parent: Arc<dyn BaseTable>,
+    request: QueryRequest,
 }
 
 impl Query {
-    pub(crate) fn new(parent: Arc<dyn TableInternal>) -> Self {
+    pub(crate) fn new(parent: Arc<dyn BaseTable>) -> Self {
         Self {
             parent,
-            limit: Some(DEFAULT_TOP_K),
-            offset: None,
-            filter: None,
-            full_text_search: None,
-            select: Select::All,
-            fast_search: false,
-            with_row_id: false,
-            prefilter: true,
-            reranker: None,
-            norm: None,
+            request: QueryRequest::default(),
         }
     }
 
@@ -691,38 +706,98 @@ impl Query {
     pub fn nearest_to(self, vector: impl IntoQueryVector) -> Result<VectorQuery> {
         let mut vector_query = self.into_vector();
         let query_vector = vector.to_query_vector(&DataType::Float32, "default")?;
-        vector_query.query_vector.push(query_vector);
+        vector_query.request.query_vector.push(query_vector);
         Ok(vector_query)
     }
+
+    pub fn into_request(self) -> QueryRequest {
+        self.request
+    }
+
+    pub fn current_request(&self) -> &QueryRequest {
+        &self.request
+    }
 }
 
 impl HasQuery for Query {
-    fn mut_query(&mut self) -> &mut Query {
-        self
+    fn mut_query(&mut self) -> &mut QueryRequest {
+        &mut self.request
     }
 }
 
 impl ExecutableQuery for Query {
     async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
-        self.parent
-            .clone()
-            .create_plan(&self.clone().into_vector(), options)
-            .await
+        let req = AnyQuery::Query(self.request.clone());
+        self.parent.clone().create_plan(&req, options).await
     }
 
     async fn execute_with_options(
         &self,
         options: QueryExecutionOptions,
     ) -> Result<SendableRecordBatchStream> {
+        let query = AnyQuery::Query(self.request.clone());
         Ok(SendableRecordBatchStream::from(
-            self.parent.clone().plain_query(self, options).await?,
+            self.parent.clone().query(&query, options).await?,
        ))
    }
 
     async fn explain_plan(&self, verbose: bool) -> Result<String> {
-        self.parent
-            .explain_plan(&self.clone().into_vector(), verbose)
-            .await
+        let query = AnyQuery::Query(self.request.clone());
+        self.parent.explain_plan(&query, verbose).await
     }
 }
 
+/// A request for a nearest-neighbors search into a table
+#[derive(Debug, Clone)]
+pub struct VectorQueryRequest {
+    /// The base query
+    pub base: QueryRequest,
+    /// The column to run the search on
+    ///
+    /// If None, then the table will need to auto-detect which column to use
+    pub column: Option<String>,
+    /// The vector(s) to search for
+    pub query_vector: Vec<Arc<dyn Array>>,
+    /// The number of partitions to search
+    pub nprobes: usize,
+    /// The lower bound (inclusive) of the distance to search for.
+    pub lower_bound: Option<f32>,
+    /// The upper bound (exclusive) of the distance to search for.
+    pub upper_bound: Option<f32>,
+    /// The number of candidates to return during the refine step for HNSW,
+    /// defaults to 1.5 * limit.
+    pub ef: Option<usize>,
+    /// A multiplier to control how many additional rows are taken during the refine step
+    pub refine_factor: Option<u32>,
+    /// The distance type to use for the search
+    pub distance_type: Option<DistanceType>,
+    /// Default is true. Set to false to enforce a brute force search.
+    pub use_index: bool,
+}
+
+impl Default for VectorQueryRequest {
+    fn default() -> Self {
+        Self {
+            base: QueryRequest::default(),
+            column: None,
+            query_vector: Vec::new(),
+            nprobes: 20,
+            lower_bound: None,
+            upper_bound: None,
+            ef: None,
+            refine_factor: None,
+            distance_type: None,
+            use_index: true,
+        }
+    }
+}
+
+impl VectorQueryRequest {
+    pub fn from_plain_query(query: QueryRequest) -> Self {
+        Self {
+            base: query,
+            ..Default::default()
+        }
+    }
+}
 
@@ -737,39 +812,30 @@
 /// the query and retrieve results.
 #[derive(Debug, Clone)]
 pub struct VectorQuery {
-    pub(crate) base: Query,
-    // The column to run the query on. If not specified, we will attempt to guess
-    // the column based on the dataset's schema.
-    pub(crate) column: Option<String>,
-    // IVF PQ - ANN search.
-    pub(crate) query_vector: Vec<Arc<dyn Array>>,
-    pub(crate) nprobes: usize,
-    // The lower bound (inclusive) of the distance to search for.
-    pub(crate) lower_bound: Option<f32>,
-    // The upper bound (exclusive) of the distance to search for.
-    pub(crate) upper_bound: Option<f32>,
-    // The number of candidates to return during the refine step for HNSW,
-    // defaults to 1.5 * limit.
-    pub(crate) ef: Option<usize>,
-    pub(crate) refine_factor: Option<u32>,
-    pub(crate) distance_type: Option<DistanceType>,
-    /// Default is true. Set to false to enforce a brute force search.
-    pub(crate) use_index: bool,
+    parent: Arc<dyn BaseTable>,
+    request: VectorQueryRequest,
 }
 
 impl VectorQuery {
     fn new(base: Query) -> Self {
         Self {
-            base,
-            column: None,
-            query_vector: Vec::new(),
-            nprobes: 20,
-            lower_bound: None,
-            upper_bound: None,
-            ef: None,
-            refine_factor: None,
-            distance_type: None,
-            use_index: true,
+            parent: base.parent,
+            request: VectorQueryRequest::from_plain_query(base.request),
         }
     }
 
+    pub fn into_request(self) -> VectorQueryRequest {
+        self.request
+    }
+
+    pub fn current_request(&self) -> &VectorQueryRequest {
+        &self.request
+    }
+
+    pub fn into_plain(self) -> Query {
+        Query {
+            parent: self.parent,
+            request: self.request.base,
+        }
+    }
+
@@ -781,7 +847,7 @@ impl VectorQuery {
     /// This parameter must be specified if the table has more than one column
     /// whose data type is a fixed-size-list of floats.
     pub fn column(mut self, column: &str) -> Self {
-        self.column = Some(column.to_string());
+        self.request.column = Some(column.to_string());
         self
     }
 
@@ -797,7 +863,7 @@ impl VectorQuery {
     /// result.
     pub fn add_query_vector(mut self, vector: impl IntoQueryVector) -> Result<Self> {
         let query_vector = vector.to_query_vector(&DataType::Float32, "default")?;
-        self.query_vector.push(query_vector);
+        self.request.query_vector.push(query_vector);
         Ok(self)
     }
 
@@ -822,15 +888,15 @@ impl VectorQuery {
     /// your actual data to find the smallest possible value that will still give
     /// you the desired recall.
     pub fn nprobes(mut self, nprobes: usize) -> Self {
-        self.nprobes = nprobes;
+        self.request.nprobes = nprobes;
         self
     }
 
     /// Set the distance range for vector search,
     /// only rows with distances in the range [lower_bound, upper_bound) will be returned
     pub fn distance_range(mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) -> Self {
-        self.lower_bound = lower_bound;
-        self.upper_bound = upper_bound;
+        self.request.lower_bound = lower_bound;
+        self.request.upper_bound = upper_bound;
         self
     }
 
@@ -842,7 +908,7 @@ impl VectorQuery {
     /// Increasing this value will increase the recall of your query but will
     /// also increase the latency of your query. The default value is 1.5*limit.
     pub fn ef(mut self, ef: usize) -> Self {
-        self.ef = Some(ef);
+        self.request.ef = Some(ef);
         self
     }
 
@@ -874,7 +940,7 @@ impl VectorQuery {
     /// and the quantized result vectors. This can be considerably different than the true
     /// distance between the query vector and the actual uncompressed vector.
     pub fn refine_factor(mut self, refine_factor: u32) -> Self {
-        self.refine_factor = Some(refine_factor);
+        self.request.refine_factor = Some(refine_factor);
         self
     }
 
@@ -891,7 +957,7 @@ impl VectorQuery {
     ///
     /// By default [`DistanceType::L2`] is used.
     pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
-        self.distance_type = Some(distance_type);
+        self.request.distance_type = Some(distance_type);
         self
     }
 
@@ -903,16 +969,19 @@ impl VectorQuery {
     /// the vector index can give you ground truth results which you can use to
     /// calculate your recall to select an appropriate value for nprobes.
     pub fn bypass_vector_index(mut self) -> Self {
-        self.use_index = false;
+        self.request.use_index = false;
         self
     }
 
     pub async fn execute_hybrid(&self) -> Result<SendableRecordBatchStream> {
         // clone query and specify we want to include row IDs, which can be needed for reranking
-        let fts_query = self.base.clone().with_row_id();
+        let mut fts_query = Query::new(self.parent.clone());
+        fts_query.request = self.request.base.clone();
+        fts_query = fts_query.with_row_id();
+
         let mut vector_query = self.clone().with_row_id();
-        vector_query.base.full_text_search = None;
+        vector_query.request.base.full_text_search = None;
 
         let (fts_results, vec_results) = try_join!(fts_query.execute(), vector_query.execute())?;
         let (fts_results, vec_results) = try_join!(
@@ -928,7 +997,7 @@ impl VectorQuery {
         let mut fts_results = concat_batches(&fts_schema, fts_results.iter())?;
         let mut vec_results = concat_batches(&vec_schema, vec_results.iter())?;
 
-        if matches!(self.base.norm, Some(NormalizeMethod::Rank)) {
+        if matches!(self.request.base.norm, Some(NormalizeMethod::Rank)) {
             vec_results = hybrid::rank(vec_results, DIST_COL, None)?;
             fts_results = hybrid::rank(fts_results, SCORE_COL, None)?;
         }
@@ -937,14 +1006,20 @@ impl VectorQuery {
         fts_results = hybrid::normalize_scores(fts_results, SCORE_COL, None)?;
 
         let reranker = self
+            .request
             .base
             .reranker
             .clone()
             .unwrap_or(Arc::new(RRFReranker::default()));
 
-        let fts_query = self.base.full_text_search.as_ref().ok_or(Error::Runtime {
-            message: "there should be an FTS search".to_string(),
-        })?;
+        let fts_query = self
+            .request
+            .base
+            .full_text_search
+            .as_ref()
+            .ok_or(Error::Runtime {
+                message: "there should be an FTS search".to_string(),
+            })?;
 
         let mut results = reranker
             .rerank_hybrid(&fts_query.query, vec_results, fts_results)
@@ -952,12 +1027,12 @@ impl VectorQuery {
 
         check_reranker_result(&results)?;
 
-        let limit = self.base.limit.unwrap_or(DEFAULT_TOP_K);
+        let limit = self.request.base.limit.unwrap_or(DEFAULT_TOP_K);
         if results.num_rows() > limit {
             results = results.slice(0, limit);
         }
 
-        if !self.base.with_row_id {
+        if !self.request.base.with_row_id {
             results = results.drop_column(ROW_ID)?;
         }
 
@@ -969,14 +1044,15 @@ impl VectorQuery {
 
 impl ExecutableQuery for VectorQuery {
     async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
-        self.base.parent.clone().create_plan(self, options).await
+        let query = AnyQuery::VectorQuery(self.request.clone());
+        self.parent.clone().create_plan(&query, options).await
     }
 
     async fn execute_with_options(
         &self,
         options: QueryExecutionOptions,
     ) -> Result<SendableRecordBatchStream> {
-        if self.base.full_text_search.is_some() {
+        if self.request.base.full_text_search.is_some() {
             let hybrid_result = async move { self.execute_hybrid().await }.boxed().await?;
             return Ok(hybrid_result);
         }
@@ -990,13 +1066,14 @@ impl ExecutableQuery for VectorQuery {
     }
 
     async fn explain_plan(&self, verbose: bool) -> Result<String> {
-        self.base.parent.explain_plan(self, verbose).await
+        let query = AnyQuery::VectorQuery(self.request.clone());
+        self.parent.explain_plan(&query, verbose).await
     }
 }
 
 impl HasQuery for VectorQuery {
-    fn mut_query(&mut self) -> &mut Query {
-        &mut self.base
+    fn mut_query(&mut self) -> &mut QueryRequest {
+        &mut self.request.base
     }
 }
 
@@ -1036,7 +1113,13 @@ mod tests {
         let vector = Float32Array::from_iter_values([0.1, 0.2]);
         let query = table.query().nearest_to(&[0.1, 0.2]).unwrap();
         assert_eq!(
-            *query.query_vector.first().unwrap().as_ref().as_primitive(),
+            *query
+                .request
+                .query_vector
+                .first()
+                .unwrap()
+                .as_ref()
+                .as_primitive(),
             vector
         );
 
@@ -1054,15 +1137,21 @@ mod tests {
             .refine_factor(999);
 
         assert_eq!(
-            *query.query_vector.first().unwrap().as_ref().as_primitive(),
+            *query
+                .request
+                .query_vector
+                .first()
+                .unwrap()
+                .as_ref()
+                .as_primitive(),
             new_vector
         );
-        assert_eq!(query.base.limit.unwrap(), 100);
-        assert_eq!(query.base.offset.unwrap(), 1);
-        assert_eq!(query.nprobes, 1000);
-        assert!(query.use_index);
-        assert_eq!(query.distance_type, Some(DistanceType::Cosine));
-        assert_eq!(query.refine_factor, Some(999));
+        assert_eq!(query.request.base.limit.unwrap(), 100);
+        assert_eq!(query.request.base.offset.unwrap(), 1);
+        assert_eq!(query.request.nprobes, 1000);
+        assert!(query.request.use_index);
+        assert_eq!(query.request.distance_type, Some(DistanceType::Cosine));
+        assert_eq!(query.request.refine_factor, Some(999));
     }
 
     #[tokio::test]
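The builder API is unchanged for callers, but every setter now writes into the owned request, which can be inspected before execution. A sketch (not part of the diff), assuming a `table: lancedb::Table` with a float32 vector column:

```rust
use lancedb::query::{ExecutableQuery, QueryBase};

async fn demo(table: &lancedb::Table) -> lancedb::Result<()> {
    let query = table
        .query()
        .nearest_to(&[0.1, 0.2, 0.3, 0.4])?
        .nprobes(50)
        .limit(10);

    // The accumulated state is plain data now, so an integration can inspect
    // (or log/serialize) it before running the search.
    assert_eq!(query.current_request().nprobes, 50);
    assert_eq!(query.current_request().base.limit, Some(10));

    let _batches = query.execute().await?;
    Ok(())
}
```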
diff --git a/rust/lancedb/src/remote.rs b/rust/lancedb/src/remote.rs
index 45c007ee..22da6de5 100644
--- a/rust/lancedb/src/remote.rs
+++ b/rust/lancedb/src/remote.rs
@@ -14,6 +14,7 @@ pub(crate) mod util;
 const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
 #[cfg(test)]
 const ARROW_FILE_CONTENT_TYPE: &str = "application/vnd.apache.arrow.file";
+#[cfg(test)]
 const JSON_CONTENT_TYPE: &str = "application/json";
 
 pub use client::{ClientConfig, RetryConfig, TimeoutConfig};
diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs
index 31a88652..04b2436c 100644
--- a/rust/lancedb/src/remote/db.rs
+++ b/rust/lancedb/src/remote/db.rs
@@ -18,7 +18,7 @@ use crate::database::{
     TableNamesRequest,
 };
 use crate::error::Result;
-use crate::table::TableInternal;
+use crate::table::BaseTable;
 
 use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender};
 use super::table::RemoteTable;
@@ -126,7 +126,7 @@
         Ok(tables)
     }
 
-    async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn TableInternal>> {
+    async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn BaseTable>> {
         let data = match request.data {
             CreateTableData::Data(data) => data,
             CreateTableData::Empty(table_definition) => {
@@ -198,7 +198,7 @@
         )))
     }
 
-    async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn TableInternal>> {
+    async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
         // We describe the table to confirm it exists before moving on.
         if self.table_cache.get(&request.name).is_none() {
             let req = self
diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs
index 9a5428c9..3b3c0904 100644
--- a/rust/lancedb/src/remote/table.rs
+++ b/rust/lancedb/src/remote/table.rs
@@ -2,12 +2,13 @@
 // SPDX-FileCopyrightText: Copyright The LanceDB Authors
 
 use std::io::Cursor;
+use std::pin::Pin;
 use std::sync::{Arc, Mutex};
 
 use crate::index::Index;
 use crate::index::IndexStatistics;
-use crate::query::Select;
-use crate::table::AddDataMode;
+use crate::query::{QueryRequest, Select, VectorQueryRequest};
+use crate::table::{AddDataMode, AnyQuery, Filter};
 use crate::utils::{supported_btree_data_type, supported_vector_data_type};
 use crate::{DistanceType, Error, Table};
 use arrow_array::RecordBatchReader;
@@ -16,14 +17,14 @@ use arrow_schema::{DataType, SchemaRef};
 use async_trait::async_trait;
 use datafusion_common::DataFusionError;
 use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
-use datafusion_physical_plan::{ExecutionPlan, SendableRecordBatchStream};
+use datafusion_physical_plan::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
 use futures::TryStreamExt;
 use http::header::CONTENT_TYPE;
 use http::StatusCode;
 use lance::arrow::json::{JsonDataType, JsonSchema};
 use lance::dataset::scanner::DatasetRecordBatchStream;
 use lance::dataset::{ColumnAlteration, NewColumnTransform, Version};
-use lance_datafusion::exec::OneShotExec;
+use lance_datafusion::exec::{execute_plan, OneShotExec};
 use serde::{Deserialize, Serialize};
 use tokio::sync::RwLock;
 
@@ -31,16 +32,16 @@ use crate::{
     connection::NoData,
     error::Result,
     index::{IndexBuilder, IndexConfig},
-    query::{Query, QueryExecutionOptions, VectorQuery},
+    query::QueryExecutionOptions,
     table::{
-        merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats,
-        TableDefinition, TableInternal, UpdateBuilder,
+        merge::MergeInsertBuilder, AddDataBuilder, BaseTable, OptimizeAction, OptimizeStats,
+        TableDefinition, UpdateBuilder,
     },
 };
 
 use super::client::RequestResultExt;
 use super::client::{HttpSend, RestfulLanceDbClient, Sender};
-use super::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE};
+use super::ARROW_STREAM_CONTENT_TYPE;
 
 #[derive(Debug)]
 pub struct RemoteTable<S: HttpSend = Sender> {
@@ -147,7 +148,7 @@
         Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
     }
 
-    fn apply_query_params(body: &mut serde_json::Value, params: &Query) -> Result<()> {
+    fn apply_query_params(body: &mut serde_json::Value, params: &QueryRequest) -> Result<()> {
         if let Some(offset) = params.offset {
             body["offset"] = serde_json::Value::Number(serde_json::Number::from(offset));
         }
@@ -205,7 +206,7 @@
 
     fn apply_vector_query_params(
         mut body: serde_json::Value,
-        query: &VectorQuery,
+        query: &VectorQueryRequest,
     ) -> Result<Vec<serde_json::Value>> {
         Self::apply_query_params(&mut body, &query.base)?;
 
@@ -288,6 +289,45 @@
         let read_guard = self.version.read().await;
         *read_guard
     }
+
+    async fn execute_query(
+        &self,
+        query: &AnyQuery,
+        _options: QueryExecutionOptions,
+    ) -> Result<Vec<Pin<Box<dyn RecordBatchStream + Send>>>> {
+        let request = self.client.post(&format!("/v1/table/{}/query/", self.name));
+
+        let version = self.current_version().await;
+        let mut body = serde_json::json!({ "version": version });
+
+        match query {
+            AnyQuery::Query(query) => {
+                Self::apply_query_params(&mut body, query)?;
+                // Empty vector can be passed if no vector search is performed.
+                body["vector"] = serde_json::Value::Array(Vec::new());
+
+                let request = request.json(&body);
+
+                let (request_id, response) = self.client.send(request, true).await?;
+
+                let stream = self.read_arrow_stream(&request_id, response).await?;
+                Ok(vec![stream])
+            }
+            AnyQuery::VectorQuery(query) => {
+                let bodies = Self::apply_vector_query_params(body, query)?;
+                let mut futures = Vec::with_capacity(bodies.len());
+                for body in bodies {
+                    let request = request.try_clone().unwrap().json(&body);
+                    let future = async move {
+                        let (request_id, response) = self.client.send(request, true).await?;
+                        self.read_arrow_stream(&request_id, response).await
+                    };
+                    futures.push(future);
+                }
+                futures::future::try_join_all(futures).await
+            }
+        }
+    }
 }
 
 #[derive(Deserialize)]
@@ -325,13 +365,10 @@ mod test_utils {
 }
 
 #[async_trait]
-impl<S: HttpSend> TableInternal for RemoteTable<S> {
+impl<S: HttpSend> BaseTable for RemoteTable<S> {
     fn as_any(&self) -> &dyn std::any::Any {
         self
     }
-    fn as_native(&self) -> Option<&NativeTable> {
-        None
-    }
     fn name(&self) -> &str {
         &self.name
     }
@@ -398,7 +435,7 @@
         let schema = self.describe().await?.schema;
         Ok(Arc::new(schema.try_into()?))
     }
-    async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
+    async fn count_rows(&self, filter: Option<Filter>) -> Result<usize> {
         let mut request = self
             .client
             .post(&format!("/v1/table/{}/count_rows/", self.name));
@@ -406,6 +443,11 @@
         let version = self.current_version().await;
 
         if let Some(filter) = filter {
+            let Filter::Sql(filter) = filter else {
+                return Err(Error::NotSupported {
+                    message: "querying a remote table with a datafusion filter".to_string(),
+                });
+            };
             request = request.json(&serde_json::json!({ "predicate": filter, "version": version }));
         } else {
             let body = serde_json::json!({ "version": version });
@@ -453,25 +495,11 @@
 
     async fn create_plan(
         &self,
-        query: &VectorQuery,
-        _options: QueryExecutionOptions,
+        query: &AnyQuery,
+        options: QueryExecutionOptions,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let request = self.client.post(&format!("/v1/table/{}/query/", self.name));
+        let streams = self.execute_query(query, options).await?;
 
-        let version = self.current_version().await;
-        let body = serde_json::json!({ "version": version });
-        let bodies = Self::apply_vector_query_params(body, query)?;
-
-        let mut futures = Vec::with_capacity(bodies.len());
-        for body in bodies {
-            let request = request.try_clone().unwrap().json(&body);
-            let future = async move {
-                let (request_id, response) = self.client.send(request, true).await?;
-                self.read_arrow_stream(&request_id, response).await
-            };
-            futures.push(future);
-        }
-        let streams = futures::future::try_join_all(futures).await?;
         if streams.len() == 1 {
             let stream = streams.into_iter().next().unwrap();
             Ok(Arc::new(OneShotExec::new(stream)))
@@ -484,29 +512,29 @@
         }
     }
 
-    async fn plain_query(
+    async fn query(
         &self,
-        query: &Query,
+        query: &AnyQuery,
         _options: QueryExecutionOptions,
     ) -> Result<DatasetRecordBatchStream> {
-        let request = self
-            .client
-            .post(&format!("/v1/table/{}/query/", self.name))
-            .header(CONTENT_TYPE, JSON_CONTENT_TYPE);
+        let streams = self.execute_query(query, _options).await?;
 
-        let version = self.current_version().await;
-        let mut body = serde_json::json!({ "version": version });
-        Self::apply_query_params(&mut body, query)?;
-        // Empty vector can be passed if no vector search is performed.
-        body["vector"] = serde_json::Value::Array(Vec::new());
+        if streams.len() == 1 {
+            Ok(DatasetRecordBatchStream::new(
+                streams.into_iter().next().unwrap(),
+            ))
+        } else {
+            let stream_execs = streams
+                .into_iter()
+                .map(|stream| Arc::new(OneShotExec::new(stream)) as Arc<dyn ExecutionPlan>)
+                .collect();
+            let plan = Table::multi_vector_plan(stream_execs)?;
 
-        let request = request.json(&body);
-
-        let (request_id, response) = self.client.send(request, true).await?;
-
-        let stream = self.read_arrow_stream(&request_id, response).await?;
-
-        Ok(DatasetRecordBatchStream::new(stream))
+            Ok(DatasetRecordBatchStream::new(execute_plan(
+                plan,
+                Default::default(),
+            )?))
+        }
     }
 
     async fn update(&self, update: UpdateBuilder) -> Result<u64> {
         self.check_mutable().await?;
@@ -891,6 +919,7 @@ mod tests {
     use reqwest::Body;
 
     use crate::index::vector::IvfFlatIndexBuilder;
+    use crate::remote::JSON_CONTENT_TYPE;
     use crate::{
         index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
         query::{ExecutableQuery, QueryBase},
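Both backends now accept the same `AnyQuery` envelope, so generic code can drive a native or a remote table through one code path. A hedged sketch (the `?` on the stream items assumes the existing `From<lance::Error>` conversion on `lancedb::Error`):

```rust
use std::sync::Arc;

use futures::TryStreamExt;
use lancedb::query::{QueryExecutionOptions, QueryRequest};
use lancedb::table::{AnyQuery, BaseTable};

async fn scan_all(table: Arc<dyn BaseTable>) -> lancedb::Result<()> {
    let request = QueryRequest {
        limit: None, // full scan instead of the default top-k
        ..Default::default()
    };
    let mut stream = table
        .query(&AnyQuery::Query(request), QueryExecutionOptions::default())
        .await?;
    // `DatasetRecordBatchStream` is a `futures::Stream` of record batches.
    while let Some(batch) = stream.try_next().await? {
        println!("got {} rows", batch.num_rows());
    }
    Ok(())
}
```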
diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs
index b8315638..7104088a 100644
--- a/rust/lancedb/src/table.rs
+++ b/rust/lancedb/src/table.rs
@@ -12,6 +12,7 @@ use arrow::datatypes::{Float32Type, UInt8Type};
 use arrow_array::{RecordBatchIterator, RecordBatchReader};
 use arrow_schema::{DataType, Field, Schema, SchemaRef};
 use async_trait::async_trait;
+use datafusion_expr::Expr;
 use datafusion_physical_plan::display::DisplayableExecutionPlan;
 use datafusion_physical_plan::projection::ProjectionExec;
 use datafusion_physical_plan::repartition::RepartitionExec;
@@ -21,12 +22,13 @@ use futures::{StreamExt, TryStreamExt};
 use lance::dataset::builder::DatasetBuilder;
 use lance::dataset::cleanup::RemovalStats;
 use lance::dataset::optimize::{compact_files, CompactionMetrics, IndexRemapperOptions};
-use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
+use lance::dataset::scanner::Scanner;
 pub use lance::dataset::ColumnAlteration;
 pub use lance::dataset::NewColumnTransform;
 pub use lance::dataset::ReadParams;
+pub use lance::dataset::Version;
 use lance::dataset::{
-    Dataset, InsertBuilder, UpdateBuilder as LanceUpdateBuilder, Version, WhenMatched, WriteMode,
+    Dataset, InsertBuilder, UpdateBuilder as LanceUpdateBuilder, WhenMatched, WriteMode,
     WriteParams,
 };
 use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
@@ -60,7 +62,8 @@ use crate::index::{
 };
 use crate::index::{IndexConfig, IndexStatisticsImpl};
 use crate::query::{
-    IntoQueryVector, Query, QueryExecutionOptions, Select, VectorQuery, DEFAULT_TOP_K,
+    IntoQueryVector, Query, QueryExecutionOptions, QueryRequest, Select, VectorQuery,
+    VectorQueryRequest, DEFAULT_TOP_K,
 };
 use crate::utils::{
     default_vector_column, supported_bitmap_data_type, supported_btree_data_type,
@@ -71,11 +74,13 @@ use crate::utils::{
 use self::dataset::DatasetConsistencyWrapper;
 use self::merge::MergeInsertBuilder;
 
+pub mod datafusion;
 pub(crate) mod dataset;
 pub mod merge;
 
 pub use chrono::Duration;
 pub use lance::dataset::optimize::CompactionOptions;
+pub use lance::dataset::scanner::DatasetRecordBatchStream;
 pub use lance_index::optimize::OptimizeOptions;
 
 /// Defines the type of column
@@ -273,7 +278,7 @@ pub enum AddDataMode {
 /// A builder for configuring a [`crate::connection::Connection::create_table`] or [`Table::add`]
 /// operation
 pub struct AddDataBuilder<T: IntoArrow> {
-    parent: Arc<dyn TableInternal>,
+    parent: Arc<dyn BaseTable>,
     pub(crate) data: T,
     pub(crate) mode: AddDataMode,
     pub(crate) write_options: WriteOptions,
@@ -318,13 +323,13 @@ impl<T: IntoArrow> AddDataBuilder<T> {
 /// A builder for configuring an [`Table::update`] operation
 #[derive(Debug, Clone)]
 pub struct UpdateBuilder {
-    parent: Arc<dyn TableInternal>,
+    parent: Arc<dyn BaseTable>,
     pub(crate) filter: Option<String>,
     pub(crate) columns: Vec<(String, String)>,
 }
 
 impl UpdateBuilder {
-    fn new(parent: Arc<dyn TableInternal>) -> Self {
+    fn new(parent: Arc<dyn BaseTable>) -> Self {
         Self {
             parent,
             filter: None,
@@ -381,64 +386,102 @@ impl UpdateBuilder {
     }
 }
 
+/// Filters that can be used to limit the rows returned by a query
+pub enum Filter {
+    /// A SQL filter string
+    Sql(String),
+    /// A Datafusion logical expression
+    Datafusion(Expr),
+}
+
+/// A query that can be used to search a LanceDB table
+pub enum AnyQuery {
+    Query(QueryRequest),
+    VectorQuery(VectorQueryRequest),
+}
+
+/// A trait for anything "table-like". This is used for both native tables (which target
+/// Lance datasets) and remote tables (which target LanceDB cloud)
+///
+/// This trait is still EXPERIMENTAL and subject to change in the future
 #[async_trait]
-pub trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync {
-    #[allow(dead_code)]
+pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
+    /// Get a reference to std::any::Any
     fn as_any(&self) -> &dyn std::any::Any;
-    /// Cast as [`NativeTable`], or return None it if is not a [`NativeTable`].
-    fn as_native(&self) -> Option<&NativeTable>;
     /// Get the name of the table.
     fn name(&self) -> &str;
     /// Get the arrow [Schema] of the table.
     async fn schema(&self) -> Result<SchemaRef>;
     /// Count the number of rows in this table.
-    async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
+    async fn count_rows(&self, filter: Option<Filter>) -> Result<usize>;
+    /// Create a physical plan for the query.
     async fn create_plan(
         &self,
-        query: &VectorQuery,
+        query: &AnyQuery,
         options: QueryExecutionOptions,
     ) -> Result<Arc<dyn ExecutionPlan>>;
-    async fn plain_query(
+    /// Execute a query and return the results as a stream of RecordBatches.
+    async fn query(
         &self,
-        query: &Query,
+        query: &AnyQuery,
         options: QueryExecutionOptions,
     ) -> Result<DatasetRecordBatchStream>;
-    async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result<String> {
+    /// Explain the plan for a query.
+    async fn explain_plan(&self, query: &AnyQuery, verbose: bool) -> Result<String> {
         let plan = self.create_plan(query, Default::default()).await?;
         let display = DisplayableExecutionPlan::new(plan.as_ref());
 
         Ok(format!("{}", display.indent(verbose)))
     }
+    /// Add new records to the table.
     async fn add(
         &self,
         add: AddDataBuilder<NoData>,
         data: Box<dyn RecordBatchReader + Send>,
     ) -> Result<()>;
+    /// Delete rows from the table.
     async fn delete(&self, predicate: &str) -> Result<()>;
+    /// Update rows in the table.
     async fn update(&self, update: UpdateBuilder) -> Result<u64>;
+    /// Create an index on the provided column(s).
     async fn create_index(&self, index: IndexBuilder) -> Result<()>;
+    /// List the indices on the table.
     async fn list_indices(&self) -> Result<Vec<IndexConfig>>;
+    /// Drop an index from the table.
     async fn drop_index(&self, name: &str) -> Result<()>;
+    /// Get statistics about the index.
     async fn index_stats(&self, index_name: &str) -> Result<Option<IndexStatistics>>;
+    /// Merge insert new records into the table.
     async fn merge_insert(
         &self,
         params: MergeInsertBuilder,
         new_data: Box<dyn RecordBatchReader + Send>,
     ) -> Result<()>;
+    /// Optimize the dataset.
     async fn optimize(&self, action: OptimizeAction) -> Result<OptimizeStats>;
+    /// Add columns to the table.
     async fn add_columns(
         &self,
         transforms: NewColumnTransform,
         read_columns: Option<Vec<String>>,
     ) -> Result<()>;
+    /// Alter columns in the table.
     async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<()>;
+    /// Drop columns from the table.
     async fn drop_columns(&self, columns: &[&str]) -> Result<()>;
+    /// Get the version of the table.
     async fn version(&self) -> Result<u64>;
+    /// Checkout a specific version of the table.
     async fn checkout(&self, version: u64) -> Result<()>;
+    /// Checkout the latest version of the table.
     async fn checkout_latest(&self) -> Result<()>;
+    /// Restore the table to the currently checked out version.
     async fn restore(&self) -> Result<()>;
+    /// List the versions of the table.
     async fn list_versions(&self) -> Result<Vec<Version>>;
+    /// Get the table definition.
     async fn table_definition(&self) -> Result<TableDefinition>;
+    /// Get the table URI
     fn dataset_uri(&self) -> &str;
 }
 
@@ -447,7 +490,7 @@ pub trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync {
 /// The type of the each row is defined in Apache Arrow [Schema].
 #[derive(Clone)]
 pub struct Table {
-    inner: Arc<dyn TableInternal>,
+    inner: Arc<dyn BaseTable>,
     embedding_registry: Arc<dyn EmbeddingRegistry>,
 }
 
@@ -483,15 +526,19 @@ impl std::fmt::Display for Table {
 }
 
 impl Table {
-    pub fn new(inner: Arc<dyn TableInternal>) -> Self {
+    pub fn new(inner: Arc<dyn BaseTable>) -> Self {
         Self {
             inner,
             embedding_registry: Arc::new(MemoryRegistry::new()),
         }
     }
 
+    pub fn base_table(&self) -> &Arc<dyn BaseTable> {
+        &self.inner
+    }
+
     pub(crate) fn new_with_embedding_registry(
-        inner: Arc<dyn TableInternal>,
+        inner: Arc<dyn BaseTable>,
         embedding_registry: Arc<dyn EmbeddingRegistry>,
     ) -> Self {
         Self {
@@ -524,7 +571,7 @@ impl Table {
     ///
     /// * `filter` if present, only count rows matching the filter
     pub async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
-        self.inner.count_rows(filter).await
+        self.inner.count_rows(filter.map(Filter::Sql)).await
     }
 
     /// Insert new records into this Table
@@ -1063,6 +1110,17 @@ impl From<NativeTable> for Table {
     }
 }
 
+pub trait NativeTableExt {
+    /// Cast as [`NativeTable`], or return None if it is not a [`NativeTable`].
+    fn as_native(&self) -> Option<&NativeTable>;
+}
+
+impl NativeTableExt for Arc<dyn BaseTable> {
+    fn as_native(&self) -> Option<&NativeTable> {
+        self.as_any().downcast_ref::<NativeTable>()
+    }
+}
+
 /// A table in a LanceDB database.
 #[derive(Debug, Clone)]
 pub struct NativeTable {
@@ -1676,7 +1734,7 @@ impl NativeTable {
 
     async fn generic_query(
         &self,
-        query: &VectorQuery,
+        query: &AnyQuery,
         options: QueryExecutionOptions,
     ) -> Result<DatasetRecordBatchStream> {
         let plan = self.create_plan(query, options).await?;
@@ -1766,15 +1824,11 @@ impl NativeTable {
 }
 
 #[async_trait::async_trait]
-impl TableInternal for NativeTable {
+impl BaseTable for NativeTable {
     fn as_any(&self) -> &dyn std::any::Any {
         self
     }
 
-    fn as_native(&self) -> Option<&NativeTable> {
-        Some(self)
-    }
-
     fn name(&self) -> &str {
         self.name.as_str()
     }
@@ -1830,8 +1884,15 @@
         TableDefinition::try_from_rich_schema(schema)
     }
 
-    async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
-        Ok(self.dataset.get().await?.count_rows(filter).await?)
+    async fn count_rows(&self, filter: Option<Filter>) -> Result<usize> {
+        let dataset = self.dataset.get().await?;
+        match filter {
+            None => Ok(dataset.count_rows(None).await?),
+            Some(Filter::Sql(sql)) => Ok(dataset.count_rows(Some(sql)).await?),
+            Some(Filter::Datafusion(_)) => Err(Error::NotSupported {
+                message: "Datafusion filters are not yet supported".to_string(),
+            }),
+        }
     }
 
     async fn add(
@@ -1925,9 +1986,14 @@
 
     async fn create_plan(
         &self,
-        query: &VectorQuery,
+        query: &AnyQuery,
         options: QueryExecutionOptions,
     ) -> Result<Arc<dyn ExecutionPlan>> {
+        let query = match query {
+            AnyQuery::VectorQuery(query) => query.clone(),
+            AnyQuery::Query(query) => VectorQueryRequest::from_plain_query(query.clone()),
+        };
+
         let ds_ref = self.dataset.get().await?;
         let mut column = query.column.clone();
         let schema = ds_ref.schema();
@@ -1975,7 +2041,10 @@
                 let mut sub_query = query.clone();
                 sub_query.query_vector = vec![query_vector];
                 let options_ref = options.clone();
-                async move { self.create_plan(&sub_query, options_ref).await }
+                async move {
+                    self.create_plan(&AnyQuery::VectorQuery(sub_query), options_ref)
+                        .await
+                }
             })
             .collect::<Vec<_>>();
         let plans = futures::future::try_join_all(plan_futures).await?;
@@ -2073,13 +2142,12 @@
         Ok(scanner.create_plan().await?)
     }
 
-    async fn plain_query(
+    async fn query(
         &self,
-        query: &Query,
+        query: &AnyQuery,
         options: QueryExecutionOptions,
     ) -> Result<DatasetRecordBatchStream> {
-        self.generic_query(&query.clone().into_vector(), options)
-            .await
+        self.generic_query(query, options).await
     }
 
     async fn merge_insert(
@@ -2348,7 +2416,10 @@ mod tests {
         assert_eq!(table.count_rows(None).await.unwrap(), 10);
         assert_eq!(
-            table.count_rows(Some("i >= 5".to_string())).await.unwrap(),
+            table
+                .count_rows(Some(Filter::Sql("i >= 5".to_string())))
+                .await
+                .unwrap(),
             5
         );
     }
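A short sketch (not part of the diff) of the two call sites for the new `Filter` enum and the `NativeTableExt` escape hatch, assuming a `table: lancedb::Table` opened elsewhere:

```rust
use lancedb::table::{BaseTable, Filter, NativeTableExt};

async fn demo(table: &lancedb::Table) -> lancedb::Result<()> {
    // The public `Table::count_rows` still takes a SQL string and wraps it
    // in `Filter::Sql` internally...
    let n = table.count_rows(Some("i >= 5".to_string())).await?;
    println!("{n} rows match");

    // ...while `BaseTable::count_rows` takes the enum directly.
    let n = table
        .base_table()
        .count_rows(Some(Filter::Sql("i >= 5".to_string())))
        .await?;
    println!("{n} rows match");

    // Downcasting moved off the trait onto an extension trait for Arc<dyn BaseTable>.
    if let Some(native) = table.base_table().as_native() {
        println!("backed by a local dataset at {}", native.dataset_uri());
    }
    Ok(())
}
```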
diff --git a/rust/lancedb/src/table/datafusion.rs b/rust/lancedb/src/table/datafusion.rs
new file mode 100644
index 00000000..4564786b
--- /dev/null
+++ b/rust/lancedb/src/table/datafusion.rs
@@ -0,0 +1,187 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+//! This module contains adapters to allow LanceDB tables to be used as DataFusion table providers.
+
+use std::{collections::HashMap, sync::Arc};
+
+use arrow_schema::Schema as ArrowSchema;
+use async_trait::async_trait;
+use datafusion_catalog::{Session, TableProvider};
+use datafusion_common::{DataFusionError, Result as DataFusionResult, Statistics};
+use datafusion_execution::{SendableRecordBatchStream, TaskContext};
+use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType};
+use datafusion_physical_plan::{
+    stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
+};
+use futures::{TryFutureExt, TryStreamExt};
+
+use super::{AnyQuery, BaseTable};
+use crate::{
+    query::{QueryExecutionOptions, QueryRequest, Select},
+    Result,
+};
+
+/// Datafusion attempts to maintain batch metadata
+///
+/// This is unnecessary, and it triggers bugs in DF. This operator erases metadata from the batches.
+#[derive(Debug)]
+struct MetadataEraserExec {
+    input: Arc<dyn ExecutionPlan>,
+    schema: Arc<ArrowSchema>,
+    properties: PlanProperties,
+}
+
+impl MetadataEraserExec {
+    fn compute_properties_from_input(
+        input: &Arc<dyn ExecutionPlan>,
+        schema: &Arc<ArrowSchema>,
+    ) -> PlanProperties {
+        let input_properties = input.properties();
+        let eq_properties = input_properties
+            .eq_properties
+            .clone()
+            .with_new_schema(schema.clone())
+            .unwrap();
+        input_properties.clone().with_eq_properties(eq_properties)
+    }
+
+    fn new(input: Arc<dyn ExecutionPlan>) -> Self {
+        let schema = Arc::new(
+            input
+                .schema()
+                .as_ref()
+                .clone()
+                .with_metadata(HashMap::new()),
+        );
+        Self {
+            properties: Self::compute_properties_from_input(&input, &schema),
+            input,
+            schema,
+        }
+    }
+}
+
+impl DisplayAs for MetadataEraserExec {
+    fn fmt_as(&self, _: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "MetadataEraserExec")
+    }
+}
+
+impl ExecutionPlan for MetadataEraserExec {
+    fn name(&self) -> &str {
+        "MetadataEraserExec"
+    }
+
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.properties
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![&self.input]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
+        assert_eq!(children.len(), 1);
+        let new_properties = Self::compute_properties_from_input(&children[0], &self.schema);
+        Ok(Arc::new(Self {
+            input: children[0].clone(),
+            schema: self.schema.clone(),
+            properties: new_properties,
+        }))
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> DataFusionResult<SendableRecordBatchStream> {
+        let stream = self.input.execute(partition, context)?;
+        let schema = self.schema.clone();
+        let stream = stream.map_ok(move |batch| batch.with_schema(schema.clone()).unwrap());
+        Ok(
+            Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))
+                as SendableRecordBatchStream,
+        )
+    }
+}
+
+#[derive(Debug)]
+pub struct BaseTableAdapter {
+    table: Arc<dyn BaseTable>,
+    schema: Arc<ArrowSchema>,
+}
+
+impl BaseTableAdapter {
+    pub async fn try_new(table: Arc<dyn BaseTable>) -> Result<Self> {
+        let schema = table.schema().await?;
+        Ok(Self { table, schema })
+    }
+}
+
+#[async_trait]
+impl TableProvider for BaseTableAdapter {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn schema(&self) -> Arc<ArrowSchema> {
+        self.schema.clone()
+    }
+
+    fn table_type(&self) -> TableType {
+        TableType::Base
+    }
+
+    async fn scan(
+        &self,
+        _state: &dyn Session,
+        projection: Option<&Vec<usize>>,
+        filters: &[Expr],
+        limit: Option<usize>,
+    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
+        let mut query = QueryRequest::default();
+        if let Some(projection) = projection {
+            let field_names = projection
+                .iter()
+                .map(|i| self.schema.field(*i).name().to_string())
+                .collect();
+            query.select = Select::Columns(field_names);
+        }
+        assert!(filters.is_empty());
+        if let Some(limit) = limit {
+            query.limit = Some(limit);
+        } else {
+            // Need to override the default of 10
+            query.limit = None;
+        }
+        let plan = self
+            .table
+            .create_plan(&AnyQuery::Query(query), QueryExecutionOptions::default())
+            .map_err(|err| DataFusionError::External(err.into()))
+            .await?;
+        Ok(Arc::new(MetadataEraserExec::new(plan)))
+    }
+
+    fn supports_filters_pushdown(
+        &self,
+        filters: &[&Expr],
+    ) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
+        // TODO: Pushdown unsupported until we can support datafusion filters in BaseTable::create_plan
+        Ok(vec![
+            TableProviderFilterPushDown::Unsupported;
+            filters.len()
+        ])
+    }
+
+    fn statistics(&self) -> Option<Statistics> {
+        // TODO
+        None
+    }
+}
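A sketch (not part of the diff) of wiring the new adapter into a DataFusion `SessionContext`. It assumes the umbrella `datafusion` crate (and `anyhow`, for brevity) on the caller's side; the path `lancedb::table::datafusion::BaseTableAdapter` comes from this diff:

```rust
use std::sync::Arc;

use datafusion::prelude::SessionContext;
use lancedb::table::datafusion::BaseTableAdapter;

async fn query_with_sql(table: &lancedb::Table) -> anyhow::Result<()> {
    // Adapt the table (native or remote) into a DataFusion TableProvider.
    let provider = BaseTableAdapter::try_new(table.base_table().clone()).await?;

    let ctx = SessionContext::new();
    ctx.register_table("my_table", Arc::new(provider))?;

    // Projection and LIMIT are pushed into the LanceDB scan; filters are
    // evaluated by DataFusion until predicate pushdown is supported.
    let df = ctx
        .sql("SELECT id, price FROM my_table WHERE price > 10 LIMIT 5")
        .await?;
    df.show().await?;
    Ok(())
}
```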
diff --git a/rust/lancedb/src/table/merge.rs b/rust/lancedb/src/table/merge.rs
index 32d12bca..ea2999a2 100644
--- a/rust/lancedb/src/table/merge.rs
+++ b/rust/lancedb/src/table/merge.rs
@@ -7,14 +7,14 @@ use arrow_array::RecordBatchReader;
 
 use crate::Result;
 
-use super::TableInternal;
+use super::BaseTable;
 
 /// A builder used to create and run a merge insert operation
 ///
 /// See [`super::Table::merge_insert`] for more context
 #[derive(Debug, Clone)]
 pub struct MergeInsertBuilder {
-    table: Arc<dyn TableInternal>,
+    table: Arc<dyn BaseTable>,
     pub(crate) on: Vec<String>,
     pub(crate) when_matched_update_all: bool,
     pub(crate) when_matched_update_all_filt: Option<String>,
@@ -24,7 +24,7 @@ pub struct MergeInsertBuilder {
 }
 
 impl MergeInsertBuilder {
-    pub(super) fn new(table: Arc<dyn TableInternal>, on: Vec<String>) -> Self {
+    pub(super) fn new(table: Arc<dyn BaseTable>, on: Vec<String>) -> Self {
         Self {
             table,
             on,