feat: expose table trait (#2097)

Similar to
c269524b2f,
this PR reworks another internal trait (this time
`TableInternal`, renamed here to `BaseTable`) and exposes it as a public
trait. These two PRs together should make it possible for others to
integrate LanceDB on top of other catalogs.
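
Since `BaseTable` is now public, a third-party catalog can hand back its own table objects and still use the high-level API. A minimal sketch, assuming a hypothetical `my_catalog` crate whose `open_table` resolves a table through an external catalog and returns the now-public trait object:

```rust
use std::sync::Arc;

use lancedb::table::{BaseTable, Table};

// `my_catalog` is hypothetical; any `Database` implementation that returns
// `Arc<dyn BaseTable>` (as `create_table`/`open_table` now do) works the same.
async fn open_via_catalog(name: &str) -> lancedb::Result<Table> {
    let base: Arc<dyn BaseTable> = my_catalog::open_table(name).await?;
    // `Table::new` is public, so any `BaseTable` implementation (native,
    // remote, or third-party) can be wrapped into the regular `Table` API.
    Ok(Table::new(base))
}
```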

This PR also adds a basic `TableProvider` implementation for tables,
although some work remains to be done here (filter pushdown is not yet
enabled).
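
As a rough sketch of what this enables (a usage example, not code from this PR; it assumes the full `datafusion` and `anyhow` crates as dependencies, and the `lancedb::table::datafusion` module path is inferred from this diff), a table can be registered with a DataFusion `SessionContext`:

```rust
use std::sync::Arc;

use datafusion::prelude::SessionContext;
use lancedb::table::datafusion::BaseTableAdapter;

async fn query_with_datafusion(table: lancedb::Table) -> anyhow::Result<()> {
    // Adapt the LanceDB table into a DataFusion TableProvider.
    // `BaseTableAdapter` and `Table::base_table` are added in this PR.
    let provider = BaseTableAdapter::try_new(table.base_table().clone()).await?;
    let ctx = SessionContext::new();
    ctx.register_table("t", Arc::new(provider))?;
    // Projection and limit are pushed into the LanceDB scan; filters are not
    // pushed down yet, so DataFusion applies them above the scan.
    ctx.sql("SELECT * FROM t LIMIT 10").await?.show().await?;
    Ok(())
}
```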
Weston Pace (committed by GitHub)
2025-02-05 18:13:51 -08:00
parent ef3093bc23
commit 6bf742c759
14 changed files with 619 additions and 232 deletions

Cargo.lock (generated)

@@ -4135,7 +4135,10 @@ dependencies = [
"candle-transformers",
"chrono",
"crunchy",
"datafusion-catalog",
"datafusion-common",
"datafusion-execution",
"datafusion-expr",
"datafusion-physical-plan",
"futures",
"half",


@@ -42,7 +42,10 @@ arrow-arith = "53.2"
arrow-cast = "53.2"
async-trait = "0"
chrono = "0.4.35"
datafusion-common = "44.0"
datafusion-catalog = "44.0"
datafusion-common = { version = "44.0", default-features = false }
datafusion-execution = "44.0"
datafusion-expr = "44.0"
datafusion-physical-plan = "44.0"
env_logger = "0.10"
half = { "version" = "=2.4.1", default-features = false, features = [


@@ -7,8 +7,7 @@ use arrow::pyarrow::FromPyArrow;
use lancedb::index::scalar::FullTextSearchQuery;
use lancedb::query::QueryExecutionOptions;
use lancedb::query::{
ExecutableQuery, HasQuery, Query as LanceDbQuery, QueryBase, Select,
VectorQuery as LanceDbVectorQuery,
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::{PyAnyMethods, PyDictMethods};
@@ -313,7 +312,8 @@ impl VectorQuery {
}
pub fn nearest_to_text(&mut self, query: Bound<'_, PyDict>) -> PyResult<HybridQuery> {
let fts_query = Query::new(self.inner.mut_query().clone()).nearest_to_text(query)?;
let base_query = self.inner.clone().into_plain();
let fts_query = Query::new(base_query).nearest_to_text(query)?;
Ok(HybridQuery {
inner_vec: self.clone(),
inner_fts: fts_query,
@@ -411,10 +411,14 @@ impl HybridQuery {
}
pub fn get_limit(&mut self) -> Option<u32> {
self.inner_fts.inner.limit.map(|i| i as u32)
self.inner_fts
.inner
.current_request()
.limit
.map(|i| i as u32)
}
pub fn get_with_row_id(&mut self) -> bool {
self.inner_fts.inner.with_row_id
self.inner_fts.inner.current_request().with_row_id
}
}


@@ -19,7 +19,10 @@ arrow-ord = { workspace = true }
arrow-cast = { workspace = true }
arrow-ipc.workspace = true
chrono = { workspace = true }
datafusion-catalog.workspace = true
datafusion-common.workspace = true
datafusion-execution.workspace = true
datafusion-expr.workspace = true
datafusion-physical-plan.workspace = true
object_store = { workspace = true }
snafu = { workspace = true }
@@ -33,7 +36,7 @@ lance-table = { workspace = true }
lance-linalg = { workspace = true }
lance-testing = { workspace = true }
lance-encoding = { workspace = true }
moka = { workspace = true}
moka = { workspace = true }
pin-project = { workspace = true }
tokio = { version = "1.23", features = ["rt-multi-thread"] }
log.workspace = true
@@ -82,7 +85,7 @@ aws-sdk-s3 = { version = "1.38.0" }
aws-sdk-kms = { version = "1.37" }
aws-config = { version = "1.0" }
aws-smithy-runtime = { version = "1.3" }
http-body = "1" # Matching reqwest
http-body = "1" # Matching reqwest
[features]
@@ -98,7 +101,7 @@ sentence-transformers = [
"dep:candle-core",
"dep:candle-transformers",
"dep:candle-nn",
"dep:tokenizers"
"dep:tokenizers",
]
# TLS


@@ -21,7 +21,7 @@ use arrow_array::RecordBatchReader;
use lance::dataset::ReadParams;
use crate::error::Result;
use crate::table::{TableDefinition, TableInternal, WriteOptions};
use crate::table::{BaseTable, TableDefinition, WriteOptions};
pub mod listing;
@@ -120,9 +120,9 @@ pub trait Database:
/// List the names of tables in the database
async fn table_names(&self, request: TableNamesRequest) -> Result<Vec<String>>;
/// Create a table in the database
async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn TableInternal>>;
async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn BaseTable>>;
/// Open a table in the database
async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn TableInternal>>;
async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>>;
/// Rename a table in the database
async fn rename_table(&self, old_name: &str, new_name: &str) -> Result<()>;
/// Drop a table in the database


@@ -22,8 +22,8 @@ use crate::table::NativeTable;
use crate::utils::validate_table_name;
use super::{
CreateTableData, CreateTableMode, CreateTableRequest, Database, DatabaseOptions,
OpenTableRequest, TableInternal, TableNamesRequest,
BaseTable, CreateTableData, CreateTableMode, CreateTableRequest, Database, DatabaseOptions,
OpenTableRequest, TableNamesRequest,
};
/// File extension to indicate a lance table
@@ -356,10 +356,7 @@ impl Database for ListingDatabase {
Ok(f)
}
async fn create_table(
&self,
mut request: CreateTableRequest,
) -> Result<Arc<dyn TableInternal>> {
async fn create_table(&self, mut request: CreateTableRequest) -> Result<Arc<dyn BaseTable>> {
let table_uri = self.table_uri(&request.name)?;
// Inherit storage options from the connection
let storage_options = request
@@ -452,7 +449,7 @@ impl Database for ListingDatabase {
}
}
async fn open_table(&self, mut request: OpenTableRequest) -> Result<Arc<dyn TableInternal>> {
async fn open_table(&self, mut request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
let table_uri = self.table_uri(&request.name)?;
// Inherit storage options from the connection


@@ -8,7 +8,7 @@ use serde::Deserialize;
use serde_with::skip_serializing_none;
use vector::IvfFlatIndexBuilder;
use crate::{table::TableInternal, DistanceType, Error, Result};
use crate::{table::BaseTable, DistanceType, Error, Result};
use self::{
scalar::{BTreeIndexBuilder, BitmapIndexBuilder, LabelListIndexBuilder},
@@ -65,14 +65,14 @@ pub enum Index {
///
/// The methods on this builder are used to specify options common to all indices.
pub struct IndexBuilder {
parent: Arc<dyn TableInternal>,
parent: Arc<dyn BaseTable>,
pub(crate) index: Index,
pub(crate) columns: Vec<String>,
pub(crate) replace: bool,
}
impl IndexBuilder {
pub(crate) fn new(parent: Arc<dyn TableInternal>, columns: Vec<String>, index: Index) -> Self {
pub(crate) fn new(parent: Arc<dyn BaseTable>, columns: Vec<String>, index: Index) -> Self {
Self {
parent,
index,


@@ -20,12 +20,12 @@ use lance_index::scalar::FullTextSearchQuery;
use lance_index::vector::DIST_COL;
use lance_io::stream::RecordBatchStreamAdapter;
use crate::arrow::SendableRecordBatchStream;
use crate::error::{Error, Result};
use crate::rerankers::rrf::RRFReranker;
use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker};
use crate::table::TableInternal;
use crate::table::BaseTable;
use crate::DistanceType;
use crate::{arrow::SendableRecordBatchStream, table::AnyQuery};
mod hybrid;
@@ -449,7 +449,7 @@ pub trait QueryBase {
}
pub trait HasQuery {
fn mut_query(&mut self) -> &mut Query;
fn mut_query(&mut self) -> &mut QueryRequest;
}
impl<T: HasQuery> QueryBase for T {
@@ -577,6 +577,65 @@ pub trait ExecutableQuery {
fn explain_plan(&self, verbose: bool) -> impl Future<Output = Result<String>> + Send;
}
/// A basic query into a table without any kind of search
///
/// This will result in a (potentially filtered) scan if executed
#[derive(Debug, Clone)]
pub struct QueryRequest {
/// limit the number of rows to return.
pub limit: Option<usize>,
/// Offset of the query.
pub offset: Option<usize>,
/// Apply filter to the returned rows.
pub filter: Option<String>,
/// Perform a full text search on the table.
pub full_text_search: Option<FullTextSearchQuery>,
/// Select column projection.
pub select: Select,
/// If set to true, the query is executed only on the indexed data,
/// and yields faster results.
///
/// By default, this is false.
pub fast_search: bool,
/// If set to true, the query will return the `_rowid` meta column.
///
/// By default, this is false.
pub with_row_id: bool,
/// If set to false, the filter will be applied after the vector search.
pub prefilter: bool,
/// Implementation of reranker that can be used to reorder or combine query
/// results, especially if using hybrid search
pub reranker: Option<Arc<dyn Reranker>>,
/// Configure how query results are normalized when doing hybrid search
pub norm: Option<NormalizeMethod>,
}
impl Default for QueryRequest {
fn default() -> Self {
Self {
limit: Some(DEFAULT_TOP_K),
offset: None,
filter: None,
full_text_search: None,
select: Select::All,
fast_search: false,
with_row_id: false,
prefilter: true,
reranker: None,
norm: None,
}
}
}
/// A builder for LanceDB queries.
///
/// See [`crate::Table::query`] for more details on queries
@@ -591,59 +650,15 @@ pub trait ExecutableQuery {
/// times.
#[derive(Debug, Clone)]
pub struct Query {
parent: Arc<dyn TableInternal>,
/// limit the number of rows to return.
pub limit: Option<usize>,
/// Offset of the query.
pub(crate) offset: Option<usize>,
/// Apply filter to the returned rows.
pub(crate) filter: Option<String>,
/// Perform a full text search on the table.
pub(crate) full_text_search: Option<FullTextSearchQuery>,
/// Select column projection.
pub(crate) select: Select,
/// If set to true, the query is executed only on the indexed data,
/// and yields faster results.
///
/// By default, this is false.
pub(crate) fast_search: bool,
/// If set to true, the query will return the `_rowid` meta column.
///
/// By default, this is false.
pub with_row_id: bool,
/// If set to false, the filter will be applied after the vector search.
pub(crate) prefilter: bool,
/// Implementation of reranker that can be used to reorder or combine query
/// results, especially if using hybrid search
pub(crate) reranker: Option<Arc<dyn Reranker>>,
/// Configure how query results are normalized when doing hybrid search
pub(crate) norm: Option<NormalizeMethod>,
parent: Arc<dyn BaseTable>,
request: QueryRequest,
}
impl Query {
pub(crate) fn new(parent: Arc<dyn TableInternal>) -> Self {
pub(crate) fn new(parent: Arc<dyn BaseTable>) -> Self {
Self {
parent,
limit: Some(DEFAULT_TOP_K),
offset: None,
filter: None,
full_text_search: None,
select: Select::All,
fast_search: false,
with_row_id: false,
prefilter: true,
reranker: None,
norm: None,
request: QueryRequest::default(),
}
}
@@ -691,38 +706,98 @@ impl Query {
pub fn nearest_to(self, vector: impl IntoQueryVector) -> Result<VectorQuery> {
let mut vector_query = self.into_vector();
let query_vector = vector.to_query_vector(&DataType::Float32, "default")?;
vector_query.query_vector.push(query_vector);
vector_query.request.query_vector.push(query_vector);
Ok(vector_query)
}
pub fn into_request(self) -> QueryRequest {
self.request
}
pub fn current_request(&self) -> &QueryRequest {
&self.request
}
}
impl HasQuery for Query {
fn mut_query(&mut self) -> &mut Query {
self
fn mut_query(&mut self) -> &mut QueryRequest {
&mut self.request
}
}
impl ExecutableQuery for Query {
async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
self.parent
.clone()
.create_plan(&self.clone().into_vector(), options)
.await
let req = AnyQuery::Query(self.request.clone());
self.parent.clone().create_plan(&req, options).await
}
async fn execute_with_options(
&self,
options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
let query = AnyQuery::Query(self.request.clone());
Ok(SendableRecordBatchStream::from(
self.parent.clone().plain_query(self, options).await?,
self.parent.clone().query(&query, options).await?,
))
}
async fn explain_plan(&self, verbose: bool) -> Result<String> {
self.parent
.explain_plan(&self.clone().into_vector(), verbose)
.await
let query = AnyQuery::Query(self.request.clone());
self.parent.explain_plan(&query, verbose).await
}
}
/// A request for a nearest-neighbors search into a table
#[derive(Debug, Clone)]
pub struct VectorQueryRequest {
/// The base query
pub base: QueryRequest,
/// The column to run the search on
///
/// If None, then the table will need to auto-detect which column to use
pub column: Option<String>,
/// The vector(s) to search for
pub query_vector: Vec<Arc<dyn Array>>,
/// The number of partitions to search
pub nprobes: usize,
/// The lower bound (inclusive) of the distance to search for.
pub lower_bound: Option<f32>,
/// The upper bound (exclusive) of the distance to search for.
pub upper_bound: Option<f32>,
/// The number of candidates to return during the refine step for HNSW,
/// defaults to 1.5 * limit.
pub ef: Option<usize>,
/// A multiplier to control how many additional rows are taken during the refine step
pub refine_factor: Option<u32>,
/// The distance type to use for the search
pub distance_type: Option<DistanceType>,
/// Default is true. Set to false to enforce a brute force search.
pub use_index: bool,
}
impl Default for VectorQueryRequest {
fn default() -> Self {
Self {
base: QueryRequest::default(),
column: None,
query_vector: Vec::new(),
nprobes: 20,
lower_bound: None,
upper_bound: None,
ef: None,
refine_factor: None,
distance_type: None,
use_index: true,
}
}
}
impl VectorQueryRequest {
pub fn from_plain_query(query: QueryRequest) -> Self {
Self {
base: query,
..Default::default()
}
}
}
@@ -737,39 +812,30 @@ impl ExecutableQuery for Query {
/// the query and retrieve results.
#[derive(Debug, Clone)]
pub struct VectorQuery {
pub(crate) base: Query,
// The column to run the query on. If not specified, we will attempt to guess
// the column based on the dataset's schema.
pub(crate) column: Option<String>,
// IVF PQ - ANN search.
pub(crate) query_vector: Vec<Arc<dyn Array>>,
pub(crate) nprobes: usize,
// The lower bound (inclusive) of the distance to search for.
pub(crate) lower_bound: Option<f32>,
// The upper bound (exclusive) of the distance to search for.
pub(crate) upper_bound: Option<f32>,
// The number of candidates to return during the refine step for HNSW,
// defaults to 1.5 * limit.
pub(crate) ef: Option<usize>,
pub(crate) refine_factor: Option<u32>,
pub(crate) distance_type: Option<DistanceType>,
/// Default is true. Set to false to enforce a brute force search.
pub(crate) use_index: bool,
parent: Arc<dyn BaseTable>,
request: VectorQueryRequest,
}
impl VectorQuery {
fn new(base: Query) -> Self {
Self {
base,
column: None,
query_vector: Vec::new(),
nprobes: 20,
lower_bound: None,
upper_bound: None,
ef: None,
refine_factor: None,
distance_type: None,
use_index: true,
parent: base.parent,
request: VectorQueryRequest::from_plain_query(base.request),
}
}
pub fn into_request(self) -> VectorQueryRequest {
self.request
}
pub fn current_request(&self) -> &VectorQueryRequest {
&self.request
}
pub fn into_plain(self) -> Query {
Query {
parent: self.parent,
request: self.request.base,
}
}
@@ -781,7 +847,7 @@ impl VectorQuery {
/// This parameter must be specified if the table has more than one column
/// whose data type is a fixed-size-list of floats.
pub fn column(mut self, column: &str) -> Self {
self.column = Some(column.to_string());
self.request.column = Some(column.to_string());
self
}
@@ -797,7 +863,7 @@ impl VectorQuery {
/// result.
pub fn add_query_vector(mut self, vector: impl IntoQueryVector) -> Result<Self> {
let query_vector = vector.to_query_vector(&DataType::Float32, "default")?;
self.query_vector.push(query_vector);
self.request.query_vector.push(query_vector);
Ok(self)
}
@@ -822,15 +888,15 @@ impl VectorQuery {
/// your actual data to find the smallest possible value that will still give
/// you the desired recall.
pub fn nprobes(mut self, nprobes: usize) -> Self {
self.nprobes = nprobes;
self.request.nprobes = nprobes;
self
}
/// Set the distance range for vector search,
/// only rows with distances in the range [lower_bound, upper_bound) will be returned
pub fn distance_range(mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) -> Self {
self.lower_bound = lower_bound;
self.upper_bound = upper_bound;
self.request.lower_bound = lower_bound;
self.request.upper_bound = upper_bound;
self
}
@@ -842,7 +908,7 @@ impl VectorQuery {
/// Increasing this value will increase the recall of your query but will
/// also increase the latency of your query. The default value is 1.5*limit.
pub fn ef(mut self, ef: usize) -> Self {
self.ef = Some(ef);
self.request.ef = Some(ef);
self
}
@@ -874,7 +940,7 @@ impl VectorQuery {
/// and the quantized result vectors. This can be considerably different than the true
/// distance between the query vector and the actual uncompressed vector.
pub fn refine_factor(mut self, refine_factor: u32) -> Self {
self.refine_factor = Some(refine_factor);
self.request.refine_factor = Some(refine_factor);
self
}
@@ -891,7 +957,7 @@ impl VectorQuery {
///
/// By default [`DistanceType::L2`] is used.
pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
self.distance_type = Some(distance_type);
self.request.distance_type = Some(distance_type);
self
}
@@ -903,16 +969,19 @@ impl VectorQuery {
/// the vector index can give you ground truth results which you can use to
/// calculate your recall to select an appropriate value for nprobes.
pub fn bypass_vector_index(mut self) -> Self {
self.use_index = false;
self.request.use_index = false;
self
}
pub async fn execute_hybrid(&self) -> Result<SendableRecordBatchStream> {
// clone query and specify we want to include row IDs, which can be needed for reranking
let fts_query = self.base.clone().with_row_id();
let mut fts_query = Query::new(self.parent.clone());
fts_query.request = self.request.base.clone();
fts_query = fts_query.with_row_id();
let mut vector_query = self.clone().with_row_id();
vector_query.base.full_text_search = None;
vector_query.request.base.full_text_search = None;
let (fts_results, vec_results) = try_join!(fts_query.execute(), vector_query.execute())?;
let (fts_results, vec_results) = try_join!(
@@ -928,7 +997,7 @@ impl VectorQuery {
let mut fts_results = concat_batches(&fts_schema, fts_results.iter())?;
let mut vec_results = concat_batches(&vec_schema, vec_results.iter())?;
if matches!(self.base.norm, Some(NormalizeMethod::Rank)) {
if matches!(self.request.base.norm, Some(NormalizeMethod::Rank)) {
vec_results = hybrid::rank(vec_results, DIST_COL, None)?;
fts_results = hybrid::rank(fts_results, SCORE_COL, None)?;
}
@@ -937,14 +1006,20 @@ impl VectorQuery {
fts_results = hybrid::normalize_scores(fts_results, SCORE_COL, None)?;
let reranker = self
.request
.base
.reranker
.clone()
.unwrap_or(Arc::new(RRFReranker::default()));
let fts_query = self.base.full_text_search.as_ref().ok_or(Error::Runtime {
message: "there should be an FTS search".to_string(),
})?;
let fts_query = self
.request
.base
.full_text_search
.as_ref()
.ok_or(Error::Runtime {
message: "there should be an FTS search".to_string(),
})?;
let mut results = reranker
.rerank_hybrid(&fts_query.query, vec_results, fts_results)
@@ -952,12 +1027,12 @@ impl VectorQuery {
check_reranker_result(&results)?;
let limit = self.base.limit.unwrap_or(DEFAULT_TOP_K);
let limit = self.request.base.limit.unwrap_or(DEFAULT_TOP_K);
if results.num_rows() > limit {
results = results.slice(0, limit);
}
if !self.base.with_row_id {
if !self.request.base.with_row_id {
results = results.drop_column(ROW_ID)?;
}
@@ -969,14 +1044,15 @@ impl VectorQuery {
impl ExecutableQuery for VectorQuery {
async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
self.base.parent.clone().create_plan(self, options).await
let query = AnyQuery::VectorQuery(self.request.clone());
self.parent.clone().create_plan(&query, options).await
}
async fn execute_with_options(
&self,
options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> {
if self.base.full_text_search.is_some() {
if self.request.base.full_text_search.is_some() {
let hybrid_result = async move { self.execute_hybrid().await }.boxed().await?;
return Ok(hybrid_result);
}
@@ -990,13 +1066,14 @@ impl ExecutableQuery for VectorQuery {
}
async fn explain_plan(&self, verbose: bool) -> Result<String> {
self.base.parent.explain_plan(self, verbose).await
let query = AnyQuery::VectorQuery(self.request.clone());
self.parent.explain_plan(&query, verbose).await
}
}
impl HasQuery for VectorQuery {
fn mut_query(&mut self) -> &mut Query {
&mut self.base
fn mut_query(&mut self) -> &mut QueryRequest {
&mut self.request.base
}
}
@@ -1036,7 +1113,13 @@ mod tests {
let vector = Float32Array::from_iter_values([0.1, 0.2]);
let query = table.query().nearest_to(&[0.1, 0.2]).unwrap();
assert_eq!(
*query.query_vector.first().unwrap().as_ref().as_primitive(),
*query
.request
.query_vector
.first()
.unwrap()
.as_ref()
.as_primitive(),
vector
);
@@ -1054,15 +1137,21 @@ mod tests {
.refine_factor(999);
assert_eq!(
*query.query_vector.first().unwrap().as_ref().as_primitive(),
*query
.request
.query_vector
.first()
.unwrap()
.as_ref()
.as_primitive(),
new_vector
);
assert_eq!(query.base.limit.unwrap(), 100);
assert_eq!(query.base.offset.unwrap(), 1);
assert_eq!(query.nprobes, 1000);
assert!(query.use_index);
assert_eq!(query.distance_type, Some(DistanceType::Cosine));
assert_eq!(query.refine_factor, Some(999));
assert_eq!(query.request.base.limit.unwrap(), 100);
assert_eq!(query.request.base.offset.unwrap(), 1);
assert_eq!(query.request.nprobes, 1000);
assert!(query.request.use_index);
assert_eq!(query.request.distance_type, Some(DistanceType::Cosine));
assert_eq!(query.request.refine_factor, Some(999));
}
#[tokio::test]


@@ -14,6 +14,7 @@ pub(crate) mod util;
const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
#[cfg(test)]
const ARROW_FILE_CONTENT_TYPE: &str = "application/vnd.apache.arrow.file";
#[cfg(test)]
const JSON_CONTENT_TYPE: &str = "application/json";
pub use client::{ClientConfig, RetryConfig, TimeoutConfig};


@@ -18,7 +18,7 @@ use crate::database::{
TableNamesRequest,
};
use crate::error::Result;
use crate::table::TableInternal;
use crate::table::BaseTable;
use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender};
use super::table::RemoteTable;
@@ -126,7 +126,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
Ok(tables)
}
async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn TableInternal>> {
async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn BaseTable>> {
let data = match request.data {
CreateTableData::Data(data) => data,
CreateTableData::Empty(table_definition) => {
@@ -198,7 +198,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
)))
}
async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn TableInternal>> {
async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
// We describe the table to confirm it exists before moving on.
if self.table_cache.get(&request.name).is_none() {
let req = self


@@ -2,12 +2,13 @@
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::io::Cursor;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use crate::index::Index;
use crate::index::IndexStatistics;
use crate::query::Select;
use crate::table::AddDataMode;
use crate::query::{QueryRequest, Select, VectorQueryRequest};
use crate::table::{AddDataMode, AnyQuery, Filter};
use crate::utils::{supported_btree_data_type, supported_vector_data_type};
use crate::{DistanceType, Error, Table};
use arrow_array::RecordBatchReader;
@@ -16,14 +17,14 @@ use arrow_schema::{DataType, SchemaRef};
use async_trait::async_trait;
use datafusion_common::DataFusionError;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{ExecutionPlan, SendableRecordBatchStream};
use datafusion_physical_plan::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
use futures::TryStreamExt;
use http::header::CONTENT_TYPE;
use http::StatusCode;
use lance::arrow::json::{JsonDataType, JsonSchema};
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::dataset::{ColumnAlteration, NewColumnTransform, Version};
use lance_datafusion::exec::OneShotExec;
use lance_datafusion::exec::{execute_plan, OneShotExec};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
@@ -31,16 +32,16 @@ use crate::{
connection::NoData,
error::Result,
index::{IndexBuilder, IndexConfig},
query::{Query, QueryExecutionOptions, VectorQuery},
query::QueryExecutionOptions,
table::{
merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats,
TableDefinition, TableInternal, UpdateBuilder,
merge::MergeInsertBuilder, AddDataBuilder, BaseTable, OptimizeAction, OptimizeStats,
TableDefinition, UpdateBuilder,
},
};
use super::client::RequestResultExt;
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE};
use super::ARROW_STREAM_CONTENT_TYPE;
#[derive(Debug)]
pub struct RemoteTable<S: HttpSend = Sender> {
@@ -147,7 +148,7 @@ impl<S: HttpSend> RemoteTable<S> {
Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
}
fn apply_query_params(body: &mut serde_json::Value, params: &Query) -> Result<()> {
fn apply_query_params(body: &mut serde_json::Value, params: &QueryRequest) -> Result<()> {
if let Some(offset) = params.offset {
body["offset"] = serde_json::Value::Number(serde_json::Number::from(offset));
}
@@ -205,7 +206,7 @@ impl<S: HttpSend> RemoteTable<S> {
fn apply_vector_query_params(
mut body: serde_json::Value,
query: &VectorQuery,
query: &VectorQueryRequest,
) -> Result<Vec<serde_json::Value>> {
Self::apply_query_params(&mut body, &query.base)?;
@@ -288,6 +289,45 @@ impl<S: HttpSend> RemoteTable<S> {
let read_guard = self.version.read().await;
*read_guard
}
async fn execute_query(
&self,
query: &AnyQuery,
_options: QueryExecutionOptions,
) -> Result<Vec<Pin<Box<dyn RecordBatchStream + Send>>>> {
let request = self.client.post(&format!("/v1/table/{}/query/", self.name));
let version = self.current_version().await;
let mut body = serde_json::json!({ "version": version });
match query {
AnyQuery::Query(query) => {
Self::apply_query_params(&mut body, query)?;
// Empty vector can be passed if no vector search is performed.
body["vector"] = serde_json::Value::Array(Vec::new());
let request = request.json(&body);
let (request_id, response) = self.client.send(request, true).await?;
let stream = self.read_arrow_stream(&request_id, response).await?;
Ok(vec![stream])
}
AnyQuery::VectorQuery(query) => {
let bodies = Self::apply_vector_query_params(body, query)?;
let mut futures = Vec::with_capacity(bodies.len());
for body in bodies {
let request = request.try_clone().unwrap().json(&body);
let future = async move {
let (request_id, response) = self.client.send(request, true).await?;
self.read_arrow_stream(&request_id, response).await
};
futures.push(future);
}
futures::future::try_join_all(futures).await
}
}
}
}
#[derive(Deserialize)]
@@ -325,13 +365,10 @@ mod test_utils {
}
#[async_trait]
impl<S: HttpSend> TableInternal for RemoteTable<S> {
impl<S: HttpSend> BaseTable for RemoteTable<S> {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_native(&self) -> Option<&NativeTable> {
None
}
fn name(&self) -> &str {
&self.name
}
@@ -398,7 +435,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
let schema = self.describe().await?.schema;
Ok(Arc::new(schema.try_into()?))
}
async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
async fn count_rows(&self, filter: Option<Filter>) -> Result<usize> {
let mut request = self
.client
.post(&format!("/v1/table/{}/count_rows/", self.name));
@@ -406,6 +443,11 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
let version = self.current_version().await;
if let Some(filter) = filter {
let Filter::Sql(filter) = filter else {
return Err(Error::NotSupported {
message: "querying a remote table with a datafusion filter".to_string(),
});
};
request = request.json(&serde_json::json!({ "predicate": filter, "version": version }));
} else {
let body = serde_json::json!({ "version": version });
@@ -453,25 +495,11 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
async fn create_plan(
&self,
query: &VectorQuery,
_options: QueryExecutionOptions,
query: &AnyQuery,
options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
let request = self.client.post(&format!("/v1/table/{}/query/", self.name));
let streams = self.execute_query(query, options).await?;
let version = self.current_version().await;
let body = serde_json::json!({ "version": version });
let bodies = Self::apply_vector_query_params(body, query)?;
let mut futures = Vec::with_capacity(bodies.len());
for body in bodies {
let request = request.try_clone().unwrap().json(&body);
let future = async move {
let (request_id, response) = self.client.send(request, true).await?;
self.read_arrow_stream(&request_id, response).await
};
futures.push(future);
}
let streams = futures::future::try_join_all(futures).await?;
if streams.len() == 1 {
let stream = streams.into_iter().next().unwrap();
Ok(Arc::new(OneShotExec::new(stream)))
@@ -484,29 +512,29 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
}
}
async fn plain_query(
async fn query(
&self,
query: &Query,
query: &AnyQuery,
_options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
let request = self
.client
.post(&format!("/v1/table/{}/query/", self.name))
.header(CONTENT_TYPE, JSON_CONTENT_TYPE);
let streams = self.execute_query(query, _options).await?;
let version = self.current_version().await;
let mut body = serde_json::json!({ "version": version });
Self::apply_query_params(&mut body, query)?;
// Empty vector can be passed if no vector search is performed.
body["vector"] = serde_json::Value::Array(Vec::new());
if streams.len() == 1 {
Ok(DatasetRecordBatchStream::new(
streams.into_iter().next().unwrap(),
))
} else {
let stream_execs = streams
.into_iter()
.map(|stream| Arc::new(OneShotExec::new(stream)) as Arc<dyn ExecutionPlan>)
.collect();
let plan = Table::multi_vector_plan(stream_execs)?;
let request = request.json(&body);
let (request_id, response) = self.client.send(request, true).await?;
let stream = self.read_arrow_stream(&request_id, response).await?;
Ok(DatasetRecordBatchStream::new(stream))
Ok(DatasetRecordBatchStream::new(execute_plan(
plan,
Default::default(),
)?))
}
}
async fn update(&self, update: UpdateBuilder) -> Result<u64> {
self.check_mutable().await?;
@@ -891,6 +919,7 @@ mod tests {
use reqwest::Body;
use crate::index::vector::IvfFlatIndexBuilder;
use crate::remote::JSON_CONTENT_TYPE;
use crate::{
index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
query::{ExecutableQuery, QueryBase},


@@ -12,6 +12,7 @@ use arrow::datatypes::{Float32Type, UInt8Type};
use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_expr::Expr;
use datafusion_physical_plan::display::DisplayableExecutionPlan;
use datafusion_physical_plan::projection::ProjectionExec;
use datafusion_physical_plan::repartition::RepartitionExec;
@@ -21,12 +22,13 @@ use futures::{StreamExt, TryStreamExt};
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::cleanup::RemovalStats;
use lance::dataset::optimize::{compact_files, CompactionMetrics, IndexRemapperOptions};
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
use lance::dataset::scanner::Scanner;
pub use lance::dataset::ColumnAlteration;
pub use lance::dataset::NewColumnTransform;
pub use lance::dataset::ReadParams;
pub use lance::dataset::Version;
use lance::dataset::{
Dataset, InsertBuilder, UpdateBuilder as LanceUpdateBuilder, Version, WhenMatched, WriteMode,
Dataset, InsertBuilder, UpdateBuilder as LanceUpdateBuilder, WhenMatched, WriteMode,
WriteParams,
};
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
@@ -60,7 +62,8 @@ use crate::index::{
};
use crate::index::{IndexConfig, IndexStatisticsImpl};
use crate::query::{
IntoQueryVector, Query, QueryExecutionOptions, Select, VectorQuery, DEFAULT_TOP_K,
IntoQueryVector, Query, QueryExecutionOptions, QueryRequest, Select, VectorQuery,
VectorQueryRequest, DEFAULT_TOP_K,
};
use crate::utils::{
default_vector_column, supported_bitmap_data_type, supported_btree_data_type,
@@ -71,11 +74,13 @@ use crate::utils::{
use self::dataset::DatasetConsistencyWrapper;
use self::merge::MergeInsertBuilder;
pub mod datafusion;
pub(crate) mod dataset;
pub mod merge;
pub use chrono::Duration;
pub use lance::dataset::optimize::CompactionOptions;
pub use lance::dataset::scanner::DatasetRecordBatchStream;
pub use lance_index::optimize::OptimizeOptions;
/// Defines the type of column
@@ -273,7 +278,7 @@ pub enum AddDataMode {
/// A builder for configuring a [`crate::connection::Connection::create_table`] or [`Table::add`]
/// operation
pub struct AddDataBuilder<T: IntoArrow> {
parent: Arc<dyn TableInternal>,
parent: Arc<dyn BaseTable>,
pub(crate) data: T,
pub(crate) mode: AddDataMode,
pub(crate) write_options: WriteOptions,
@@ -318,13 +323,13 @@ impl<T: IntoArrow> AddDataBuilder<T> {
/// A builder for configuring an [`Table::update`] operation
#[derive(Debug, Clone)]
pub struct UpdateBuilder {
parent: Arc<dyn TableInternal>,
parent: Arc<dyn BaseTable>,
pub(crate) filter: Option<String>,
pub(crate) columns: Vec<(String, String)>,
}
impl UpdateBuilder {
fn new(parent: Arc<dyn TableInternal>) -> Self {
fn new(parent: Arc<dyn BaseTable>) -> Self {
Self {
parent,
filter: None,
@@ -381,64 +386,102 @@ impl UpdateBuilder {
}
}
/// Filters that can be used to limit the rows returned by a query
pub enum Filter {
/// A SQL filter string
Sql(String),
/// A Datafusion logical expression
Datafusion(Expr),
}
/// A query that can be used to search a LanceDB table
pub enum AnyQuery {
Query(QueryRequest),
VectorQuery(VectorQueryRequest),
}
/// A trait for anything "table-like". This is used for both native tables (which target
/// Lance datasets) and remote tables (which target LanceDB cloud)
///
/// This trait is still EXPERIMENTAL and subject to change in the future
#[async_trait]
pub trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync {
#[allow(dead_code)]
pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
/// Get a reference to std::any::Any
fn as_any(&self) -> &dyn std::any::Any;
/// Cast as [`NativeTable`], or return None if it is not a [`NativeTable`].
fn as_native(&self) -> Option<&NativeTable>;
/// Get the name of the table.
fn name(&self) -> &str;
/// Get the arrow [Schema] of the table.
async fn schema(&self) -> Result<SchemaRef>;
/// Count the number of rows in this table.
async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
async fn count_rows(&self, filter: Option<Filter>) -> Result<usize>;
/// Create a physical plan for the query.
async fn create_plan(
&self,
query: &VectorQuery,
query: &AnyQuery,
options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>>;
async fn plain_query(
/// Execute a query and return the results as a stream of RecordBatches.
async fn query(
&self,
query: &Query,
query: &AnyQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream>;
async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result<String> {
/// Explain the plan for a query.
async fn explain_plan(&self, query: &AnyQuery, verbose: bool) -> Result<String> {
let plan = self.create_plan(query, Default::default()).await?;
let display = DisplayableExecutionPlan::new(plan.as_ref());
Ok(format!("{}", display.indent(verbose)))
}
/// Add new records to the table.
async fn add(
&self,
add: AddDataBuilder<NoData>,
data: Box<dyn arrow_array::RecordBatchReader + Send>,
) -> Result<()>;
/// Delete rows from the table.
async fn delete(&self, predicate: &str) -> Result<()>;
/// Update rows in the table.
async fn update(&self, update: UpdateBuilder) -> Result<u64>;
/// Create an index on the provided column(s).
async fn create_index(&self, index: IndexBuilder) -> Result<()>;
/// List the indices on the table.
async fn list_indices(&self) -> Result<Vec<IndexConfig>>;
/// Drop an index from the table.
async fn drop_index(&self, name: &str) -> Result<()>;
/// Get statistics about the index.
async fn index_stats(&self, index_name: &str) -> Result<Option<IndexStatistics>>;
/// Merge insert new records into the table.
async fn merge_insert(
&self,
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<()>;
/// Optimize the dataset.
async fn optimize(&self, action: OptimizeAction) -> Result<OptimizeStats>;
/// Add columns to the table.
async fn add_columns(
&self,
transforms: NewColumnTransform,
read_columns: Option<Vec<String>>,
) -> Result<()>;
/// Alter columns in the table.
async fn alter_columns(&self, alterations: &[ColumnAlteration]) -> Result<()>;
/// Drop columns from the table.
async fn drop_columns(&self, columns: &[&str]) -> Result<()>;
/// Get the version of the table.
async fn version(&self) -> Result<u64>;
/// Checkout a specific version of the table.
async fn checkout(&self, version: u64) -> Result<()>;
/// Checkout the latest version of the table.
async fn checkout_latest(&self) -> Result<()>;
/// Restore the table to the currently checked out version.
async fn restore(&self) -> Result<()>;
/// List the versions of the table.
async fn list_versions(&self) -> Result<Vec<Version>>;
/// Get the table definition.
async fn table_definition(&self) -> Result<TableDefinition>;
/// Get the table URI
fn dataset_uri(&self) -> &str;
}
@@ -447,7 +490,7 @@ pub trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync {
/// The type of the each row is defined in Apache Arrow [Schema].
#[derive(Clone)]
pub struct Table {
inner: Arc<dyn TableInternal>,
inner: Arc<dyn BaseTable>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
}
@@ -483,15 +526,19 @@ impl std::fmt::Display for Table {
}
impl Table {
pub fn new(inner: Arc<dyn TableInternal>) -> Self {
pub fn new(inner: Arc<dyn BaseTable>) -> Self {
Self {
inner,
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
pub fn base_table(&self) -> &Arc<dyn BaseTable> {
&self.inner
}
pub(crate) fn new_with_embedding_registry(
inner: Arc<dyn TableInternal>,
inner: Arc<dyn BaseTable>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
) -> Self {
Self {
@@ -524,7 +571,7 @@ impl Table {
///
/// * `filter` if present, only count rows matching the filter
pub async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
self.inner.count_rows(filter).await
self.inner.count_rows(filter.map(Filter::Sql)).await
}
/// Insert new records into this Table
@@ -1063,6 +1110,17 @@ impl From<NativeTable> for Table {
}
}
pub trait NativeTableExt {
/// Cast as [`NativeTable`], or return None if it is not a [`NativeTable`].
fn as_native(&self) -> Option<&NativeTable>;
}
impl NativeTableExt for Arc<dyn BaseTable> {
fn as_native(&self) -> Option<&NativeTable> {
self.as_any().downcast_ref::<NativeTable>()
}
}
/// A table in a LanceDB database.
#[derive(Debug, Clone)]
pub struct NativeTable {
@@ -1676,7 +1734,7 @@ impl NativeTable {
async fn generic_query(
&self,
query: &VectorQuery,
query: &AnyQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
let plan = self.create_plan(query, options).await?;
@@ -1766,15 +1824,11 @@ impl NativeTable {
}
#[async_trait::async_trait]
impl TableInternal for NativeTable {
impl BaseTable for NativeTable {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_native(&self) -> Option<&NativeTable> {
Some(self)
}
fn name(&self) -> &str {
self.name.as_str()
}
@@ -1830,8 +1884,15 @@ impl TableInternal for NativeTable {
TableDefinition::try_from_rich_schema(schema)
}
async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
Ok(self.dataset.get().await?.count_rows(filter).await?)
async fn count_rows(&self, filter: Option<Filter>) -> Result<usize> {
let dataset = self.dataset.get().await?;
match filter {
None => Ok(dataset.count_rows(None).await?),
Some(Filter::Sql(sql)) => Ok(dataset.count_rows(Some(sql)).await?),
Some(Filter::Datafusion(_)) => Err(Error::NotSupported {
message: "Datafusion filters are not yet supported".to_string(),
}),
}
}
async fn add(
@@ -1925,9 +1986,14 @@ impl TableInternal for NativeTable {
async fn create_plan(
&self,
query: &VectorQuery,
query: &AnyQuery,
options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
let query = match query {
AnyQuery::VectorQuery(query) => query.clone(),
AnyQuery::Query(query) => VectorQueryRequest::from_plain_query(query.clone()),
};
let ds_ref = self.dataset.get().await?;
let mut column = query.column.clone();
let schema = ds_ref.schema();
@@ -1975,7 +2041,10 @@ impl TableInternal for NativeTable {
let mut sub_query = query.clone();
sub_query.query_vector = vec![query_vector];
let options_ref = options.clone();
async move { self.create_plan(&sub_query, options_ref).await }
async move {
self.create_plan(&AnyQuery::VectorQuery(sub_query), options_ref)
.await
}
})
.collect::<Vec<_>>();
let plans = futures::future::try_join_all(plan_futures).await?;
@@ -2073,13 +2142,12 @@ impl TableInternal for NativeTable {
Ok(scanner.create_plan().await?)
}
async fn plain_query(
async fn query(
&self,
query: &Query,
query: &AnyQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
self.generic_query(&query.clone().into_vector(), options)
.await
self.generic_query(query, options).await
}
async fn merge_insert(
@@ -2348,7 +2416,10 @@ mod tests {
assert_eq!(table.count_rows(None).await.unwrap(), 10);
assert_eq!(
table.count_rows(Some("i >= 5".to_string())).await.unwrap(),
table
.count_rows(Some(Filter::Sql("i >= 5".to_string())))
.await
.unwrap(),
5
);
}


@@ -0,0 +1,187 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! This module contains adapters to allow LanceDB tables to be used as DataFusion table providers.
use std::{collections::HashMap, sync::Arc};
use arrow_schema::Schema as ArrowSchema;
use async_trait::async_trait;
use datafusion_catalog::{Session, TableProvider};
use datafusion_common::{DataFusionError, Result as DataFusionResult, Statistics};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType};
use datafusion_physical_plan::{
stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
};
use futures::{TryFutureExt, TryStreamExt};
use super::{AnyQuery, BaseTable};
use crate::{
query::{QueryExecutionOptions, QueryRequest, Select},
Result,
};
/// Datafusion attempts to maintain batch metadata
///
/// This is needless and it triggers bugs in DF. This operator erases metadata from the batches.
#[derive(Debug)]
struct MetadataEraserExec {
input: Arc<dyn ExecutionPlan>,
schema: Arc<ArrowSchema>,
properties: PlanProperties,
}
impl MetadataEraserExec {
fn compute_properties_from_input(
input: &Arc<dyn ExecutionPlan>,
schema: &Arc<ArrowSchema>,
) -> PlanProperties {
let input_properties = input.properties();
let eq_properties = input_properties
.eq_properties
.clone()
.with_new_schema(schema.clone())
.unwrap();
input_properties.clone().with_eq_properties(eq_properties)
}
fn new(input: Arc<dyn ExecutionPlan>) -> Self {
let schema = Arc::new(
input
.schema()
.as_ref()
.clone()
.with_metadata(HashMap::new()),
);
Self {
properties: Self::compute_properties_from_input(&input, &schema),
input,
schema,
}
}
}
impl DisplayAs for MetadataEraserExec {
fn fmt_as(&self, _: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MetadataEraserExec")
}
}
impl ExecutionPlan for MetadataEraserExec {
fn name(&self) -> &str {
"MetadataEraserExec"
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn properties(&self) -> &PlanProperties {
&self.properties
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
assert_eq!(children.len(), 1);
let new_properties = Self::compute_properties_from_input(&children[0], &self.schema);
Ok(Arc::new(Self {
input: children[0].clone(),
schema: self.schema.clone(),
properties: new_properties,
}))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> DataFusionResult<SendableRecordBatchStream> {
let stream = self.input.execute(partition, context)?;
let schema = self.schema.clone();
let stream = stream.map_ok(move |batch| batch.with_schema(schema.clone()).unwrap());
Ok(
Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))
as SendableRecordBatchStream,
)
}
}
#[derive(Debug)]
pub struct BaseTableAdapter {
table: Arc<dyn BaseTable>,
schema: Arc<ArrowSchema>,
}
impl BaseTableAdapter {
pub async fn try_new(table: Arc<dyn BaseTable>) -> Result<Self> {
let schema = table.schema().await?;
Ok(Self { table, schema })
}
}
#[async_trait]
impl TableProvider for BaseTableAdapter {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn schema(&self) -> Arc<ArrowSchema> {
self.schema.clone()
}
fn table_type(&self) -> TableType {
TableType::Base
}
async fn scan(
&self,
_state: &dyn Session,
projection: Option<&Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
let mut query = QueryRequest::default();
if let Some(projection) = projection {
let field_names = projection
.iter()
.map(|i| self.schema.field(*i).name().to_string())
.collect();
query.select = Select::Columns(field_names);
}
assert!(filters.is_empty());
if let Some(limit) = limit {
query.limit = Some(limit);
} else {
// Need to override the default of 10
query.limit = None;
}
let plan = self
.table
.create_plan(&AnyQuery::Query(query), QueryExecutionOptions::default())
.map_err(|err| DataFusionError::External(err.into()))
.await?;
Ok(Arc::new(MetadataEraserExec::new(plan)))
}
fn supports_filters_pushdown(
&self,
filters: &[&Expr],
) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
// TODO: Pushdown unsupported until we can support datafusion filters in BaseTable::create_plan
Ok(vec![
TableProviderFilterPushDown::Unsupported;
filters.len()
])
}
fn statistics(&self) -> Option<Statistics> {
// TODO
None
}
}


@@ -7,14 +7,14 @@ use arrow_array::RecordBatchReader;
use crate::Result;
use super::TableInternal;
use super::BaseTable;
/// A builder used to create and run a merge insert operation
///
/// See [`super::Table::merge_insert`] for more context
#[derive(Debug, Clone)]
pub struct MergeInsertBuilder {
table: Arc<dyn TableInternal>,
table: Arc<dyn BaseTable>,
pub(crate) on: Vec<String>,
pub(crate) when_matched_update_all: bool,
pub(crate) when_matched_update_all_filt: Option<String>,
@@ -24,7 +24,7 @@ pub struct MergeInsertBuilder {
}
impl MergeInsertBuilder {
pub(super) fn new(table: Arc<dyn TableInternal>, on: Vec<String>) -> Self {
pub(super) fn new(table: Arc<dyn BaseTable>, on: Vec<String>) -> Self {
Self {
table,
on,