mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-30 10:20:40 +00:00
feat(rust): implement TableProvider::insert_into() for LanceDB tables (#2939)
Implements `InsertExec` and `RemoteInsertExec` to support running inserts in DataFusion. ## Context In https://github.com/lancedb/lancedb/pull/2929, I've prototyped moving the insert pipeline into DataFusion. This will enable parallelism at two levels: 1. Running preprocessing, such as casting the input schema or computing embeddings 2. Writing out files This PR is just the first part of running the actual writes. In the end, the plans might look like: ``` InsertExec RepartitionExec num_partitions=<write_parallelism> ProjectionExec vector=compute_embedding() RepartitionExec num_partitions=<num_cpus> DataSourceExec ``` where `num_cpus` is used to take advantage of all cores, while `write_parallelism` might be less than `num_cpus` if there are too few rows to want to split writes across `num_cpus` files. Later PRs will move the preprocessing steps into DataFusion, and then hook this up to the `Table::add()` implementations. ## Relation to future SQL work We eventually plan on having the Remote SDK go through a FlightSQL endpoint. Then for most queries we will send just the SQL string to the server, and not run any sort of DataFusion plan on the client. However, I think writes will be a little special, especially bulk writes where we need to upload large streams of data and likely want parallelism. So we'll have different code paths for writes, and I think using DataFusion makes sense, especially as long as we are doing the pre-processing on the client side still.
This commit is contained in:
@@ -25,6 +25,7 @@ datafusion-catalog.workspace = true
|
||||
datafusion-common.workspace = true
|
||||
datafusion-execution.workspace = true
|
||||
datafusion-expr.workspace = true
|
||||
datafusion-physical-expr.workspace = true
|
||||
datafusion-physical-plan.workspace = true
|
||||
datafusion.workspace = true
|
||||
object_store = { workspace = true }
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
pub mod insert;
|
||||
|
||||
use crate::index::Index;
|
||||
use crate::index::IndexStatistics;
|
||||
use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest};
|
||||
@@ -1508,6 +1510,21 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
})?;
|
||||
Ok(stats)
|
||||
}
|
||||
|
||||
async fn create_insert_exec(
|
||||
&self,
|
||||
input: Arc<dyn ExecutionPlan>,
|
||||
write_params: lance::dataset::WriteParams,
|
||||
) -> Result<Arc<dyn ExecutionPlan>> {
|
||||
let overwrite = matches!(write_params.mode, lance::dataset::WriteMode::Overwrite);
|
||||
Ok(Arc::new(insert::RemoteInsertExec::new(
|
||||
self.name.clone(),
|
||||
self.identifier.clone(),
|
||||
self.client.clone(),
|
||||
input,
|
||||
overwrite,
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
||||
438
rust/lancedb/src/remote/table/insert.rs
Normal file
438
rust/lancedb/src/remote/table/insert.rs
Normal file
@@ -0,0 +1,438 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! DataFusion ExecutionPlan for inserting data into remote LanceDB tables.
|
||||
|
||||
use std::any::Any;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use arrow_array::{ArrayRef, RecordBatch, UInt64Array};
|
||||
use arrow_ipc::CompressionType;
|
||||
use arrow_schema::ArrowError;
|
||||
use datafusion_common::{DataFusionError, Result as DataFusionResult};
|
||||
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
|
||||
use datafusion_physical_expr::EquivalenceProperties;
|
||||
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
|
||||
use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
|
||||
use futures::StreamExt;
|
||||
use http::header::CONTENT_TYPE;
|
||||
|
||||
use crate::remote::client::{HttpSend, RestfulLanceDbClient, Sender};
|
||||
use crate::remote::table::RemoteTable;
|
||||
use crate::remote::ARROW_STREAM_CONTENT_TYPE;
|
||||
use crate::table::datafusion::insert::COUNT_SCHEMA;
|
||||
use crate::table::AddResult;
|
||||
use crate::Error;
|
||||
|
||||
/// ExecutionPlan for inserting data into a remote LanceDB table.
|
||||
///
|
||||
/// This plan:
|
||||
/// 1. Requires single partition (no parallel remote inserts yet)
|
||||
/// 2. Streams data as Arrow IPC to `/v1/table/{id}/insert/` endpoint
|
||||
/// 3. Stores AddResult for retrieval after execution
|
||||
#[derive(Debug)]
|
||||
pub struct RemoteInsertExec<S: HttpSend = Sender> {
|
||||
table_name: String,
|
||||
identifier: String,
|
||||
client: RestfulLanceDbClient<S>,
|
||||
input: Arc<dyn ExecutionPlan>,
|
||||
overwrite: bool,
|
||||
properties: PlanProperties,
|
||||
add_result: Arc<Mutex<Option<AddResult>>>,
|
||||
}
|
||||
|
||||
impl<S: HttpSend + 'static> RemoteInsertExec<S> {
|
||||
/// Create a new RemoteInsertExec.
|
||||
pub fn new(
|
||||
table_name: String,
|
||||
identifier: String,
|
||||
client: RestfulLanceDbClient<S>,
|
||||
input: Arc<dyn ExecutionPlan>,
|
||||
overwrite: bool,
|
||||
) -> Self {
|
||||
let schema = COUNT_SCHEMA.clone();
|
||||
let properties = PlanProperties::new(
|
||||
EquivalenceProperties::new(schema),
|
||||
datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
|
||||
datafusion_physical_plan::execution_plan::EmissionType::Final,
|
||||
datafusion_physical_plan::execution_plan::Boundedness::Bounded,
|
||||
);
|
||||
|
||||
Self {
|
||||
table_name,
|
||||
identifier,
|
||||
client,
|
||||
input,
|
||||
overwrite,
|
||||
properties,
|
||||
add_result: Arc::new(Mutex::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the add result after execution.
|
||||
// TODO: this will be used when we wire this up to Table::add().
|
||||
#[allow(dead_code)]
|
||||
pub fn add_result(&self) -> Option<AddResult> {
|
||||
self.add_result.lock().unwrap().clone()
|
||||
}
|
||||
|
||||
fn stream_as_body(data: SendableRecordBatchStream) -> DataFusionResult<reqwest::Body> {
|
||||
let options = arrow_ipc::writer::IpcWriteOptions::default()
|
||||
.try_with_compression(Some(CompressionType::LZ4_FRAME))?;
|
||||
let writer = arrow_ipc::writer::StreamWriter::try_new_with_options(
|
||||
Vec::new(),
|
||||
&data.schema(),
|
||||
options,
|
||||
)?;
|
||||
|
||||
let stream = futures::stream::try_unfold((data, writer), move |(mut data, mut writer)| {
|
||||
async move {
|
||||
match data.next().await {
|
||||
Some(Ok(batch)) => {
|
||||
writer.write(&batch)?;
|
||||
let buffer = std::mem::take(writer.get_mut());
|
||||
Ok(Some((buffer, (data, writer))))
|
||||
}
|
||||
Some(Err(e)) => Err(e),
|
||||
None => {
|
||||
if let Err(ArrowError::IpcError(_msg)) = writer.finish() {
|
||||
// Will error if already closed.
|
||||
return Ok(None);
|
||||
};
|
||||
let buffer = std::mem::take(writer.get_mut());
|
||||
Ok(Some((buffer, (data, writer))))
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(reqwest::Body::wrap_stream(stream))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: HttpSend + 'static> DisplayAs for RemoteInsertExec<S> {
|
||||
fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match t {
|
||||
DisplayFormatType::Default | DisplayFormatType::Verbose => {
|
||||
write!(
|
||||
f,
|
||||
"RemoteInsertExec: table={}, overwrite={}",
|
||||
self.table_name, self.overwrite
|
||||
)
|
||||
}
|
||||
DisplayFormatType::TreeRender => {
|
||||
write!(f, "RemoteInsertExec")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
|
||||
fn name(&self) -> &str {
|
||||
Self::static_name()
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn properties(&self) -> &PlanProperties {
|
||||
&self.properties
|
||||
}
|
||||
|
||||
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
|
||||
vec![&self.input]
|
||||
}
|
||||
|
||||
fn maintains_input_order(&self) -> Vec<bool> {
|
||||
vec![false]
|
||||
}
|
||||
|
||||
fn required_input_distribution(&self) -> Vec<datafusion_physical_plan::Distribution> {
|
||||
// Until we have a separate commit endpoint, we need to do all inserts in a single partition
|
||||
vec![datafusion_physical_plan::Distribution::SinglePartition]
|
||||
}
|
||||
|
||||
fn benefits_from_input_partitioning(&self) -> Vec<bool> {
|
||||
vec![false]
|
||||
}
|
||||
|
||||
fn with_new_children(
|
||||
self: Arc<Self>,
|
||||
children: Vec<Arc<dyn ExecutionPlan>>,
|
||||
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
|
||||
if children.len() != 1 {
|
||||
return Err(DataFusionError::Internal(
|
||||
"RemoteInsertExec requires exactly one child".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(Arc::new(Self::new(
|
||||
self.table_name.clone(),
|
||||
self.identifier.clone(),
|
||||
self.client.clone(),
|
||||
children[0].clone(),
|
||||
self.overwrite,
|
||||
)))
|
||||
}
|
||||
|
||||
fn execute(
|
||||
&self,
|
||||
partition: usize,
|
||||
context: Arc<TaskContext>,
|
||||
) -> DataFusionResult<SendableRecordBatchStream> {
|
||||
if partition != 0 {
|
||||
return Err(DataFusionError::Internal(
|
||||
"RemoteInsertExec only supports single partition execution".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let input_stream = self.input.execute(0, context)?;
|
||||
let client = self.client.clone();
|
||||
let identifier = self.identifier.clone();
|
||||
let overwrite = self.overwrite;
|
||||
let add_result = self.add_result.clone();
|
||||
let table_name = self.table_name.clone();
|
||||
|
||||
let stream = futures::stream::once(async move {
|
||||
let mut request = client
|
||||
.post(&format!("/v1/table/{}/insert/", identifier))
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);
|
||||
|
||||
if overwrite {
|
||||
request = request.query(&[("mode", "overwrite")]);
|
||||
}
|
||||
|
||||
let body = Self::stream_as_body(input_stream)?;
|
||||
let request = request.body(body);
|
||||
|
||||
let (request_id, response) = client
|
||||
.send(request)
|
||||
.await
|
||||
.map_err(|e| DataFusionError::External(Box::new(e)))?;
|
||||
|
||||
let response =
|
||||
RemoteTable::<Sender>::handle_table_not_found(&table_name, response, &request_id)
|
||||
.await
|
||||
.map_err(|e| DataFusionError::External(Box::new(e)))?;
|
||||
|
||||
let response = client
|
||||
.check_response(&request_id, response)
|
||||
.await
|
||||
.map_err(|e| DataFusionError::External(Box::new(e)))?;
|
||||
|
||||
let body_text = response.text().await.map_err(|e| {
|
||||
DataFusionError::External(Box::new(Error::Http {
|
||||
source: Box::new(e),
|
||||
request_id: request_id.clone(),
|
||||
status_code: None,
|
||||
}))
|
||||
})?;
|
||||
|
||||
let parsed_result = if body_text.trim().is_empty() {
|
||||
// Backward compatible with old servers
|
||||
AddResult { version: 0 }
|
||||
} else {
|
||||
serde_json::from_str(&body_text).map_err(|e| {
|
||||
DataFusionError::External(Box::new(Error::Http {
|
||||
source: format!("Failed to parse add response: {}", e).into(),
|
||||
request_id: request_id.clone(),
|
||||
status_code: None,
|
||||
}))
|
||||
})?
|
||||
};
|
||||
|
||||
{
|
||||
let mut res_lock = add_result.lock().map_err(|_| {
|
||||
DataFusionError::Execution("Failed to acquire lock for add_result".to_string())
|
||||
})?;
|
||||
*res_lock = Some(parsed_result);
|
||||
}
|
||||
|
||||
// Return a single batch with count 0 (actual count is tracked in add_result)
|
||||
let count_array: ArrayRef = Arc::new(UInt64Array::from(vec![0u64]));
|
||||
let batch = RecordBatch::try_new(COUNT_SCHEMA.clone(), vec![count_array])?;
|
||||
Ok::<_, DataFusionError>(batch)
|
||||
});
|
||||
|
||||
Ok(Box::pin(RecordBatchStreamAdapter::new(
|
||||
COUNT_SCHEMA.clone(),
|
||||
stream,
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use arrow_array::record_batch;
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use datafusion::prelude::SessionContext;
    use datafusion_catalog::MemTable;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use crate::remote::ARROW_STREAM_CONTENT_TYPE;
    use crate::table::datafusion::BaseTableAdapter;
    use crate::Table;

    /// JSON schema served by the mocked describe endpoint.
    fn schema_json() -> &'static str {
        r#"{"fields": [{"name": "id", "type": {"type": "int32"}, "nullable": true}]}"#
    }

    /// Arrow schema matching `schema_json`: one nullable Int32 `id` column.
    fn id_schema() -> Arc<ArrowSchema> {
        let id = Field::new("id", DataType::Int32, true);
        Arc::new(ArrowSchema::new(vec![id]))
    }

    /// Builds a mocked remote table named `my_table`.
    ///
    /// The handler answers describe requests with the test schema, validates
    /// and counts insert requests in `insert_requests`, and panics on any
    /// other path.
    fn mock_table(insert_requests: Arc<AtomicUsize>) -> Table {
        Table::new_with_handler("my_table", move |request| {
            match request.url().path() {
                "/v1/table/my_table/describe/" => {
                    // Return schema for BaseTableAdapter::try_new
                    http::Response::builder()
                        .status(200)
                        .body(format!(r#"{{"version": 1, "schema": {}}}"#, schema_json()))
                        .unwrap()
                }
                "/v1/table/my_table/insert/" => {
                    assert_eq!(request.method(), "POST");
                    assert_eq!(
                        request.headers().get("Content-Type").unwrap(),
                        ARROW_STREAM_CONTENT_TYPE
                    );
                    insert_requests.fetch_add(1, Ordering::SeqCst);

                    http::Response::builder()
                        .status(200)
                        .body(r#"{"version": 2}"#.to_string())
                        .unwrap()
                }
                path => panic!("Unexpected request path: {}", path),
            }
        })
    }

    /// Registers the mocked remote table as the insert target `my_table`.
    async fn register_target(ctx: &SessionContext, table: &Table) {
        let provider = BaseTableAdapter::try_new(table.base_table().clone())
            .await
            .unwrap();
        ctx.register_table("my_table", Arc::new(provider)).unwrap();
    }

    #[tokio::test]
    async fn test_remote_insert_exec_execute_empty() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let table = mock_table(request_count.clone());

        // Create empty MemTable (no batches)
        let source_table = MemTable::try_new(id_schema(), vec![vec![]]).unwrap();

        let ctx = SessionContext::new();
        register_target(&ctx, &table).await;

        // Register empty source
        ctx.register_table("empty_source", Arc::new(source_table))
            .unwrap();

        // Execute the INSERT
        let df = ctx
            .sql("INSERT INTO my_table SELECT * FROM empty_source")
            .await
            .unwrap();
        df.collect().await.unwrap();

        // Verify: should have made exactly one HTTP request even with empty input
        assert_eq!(request_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_remote_insert_exec_multi_partition() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let table = mock_table(request_count.clone());

        // Create MemTable with multiple partitions and multiple batches
        let partition_0 = vec![
            record_batch!(("id", Int32, [1, 2])).unwrap(),
            record_batch!(("id", Int32, [3, 4])).unwrap(),
        ];
        let partition_1 = vec![record_batch!(("id", Int32, [5, 6, 7])).unwrap()];
        let partition_2 = vec![record_batch!(("id", Int32, [8])).unwrap()];
        let source_table =
            MemTable::try_new(id_schema(), vec![partition_0, partition_1, partition_2]).unwrap();

        let ctx = SessionContext::new();
        register_target(&ctx, &table).await;

        // Register multi-partition source
        ctx.register_table("multi_partition_source", Arc::new(source_table))
            .unwrap();

        // Get the physical plan and verify the input partitions are merged into one
        let df = ctx
            .sql("INSERT INTO my_table SELECT * FROM multi_partition_source")
            .await
            .unwrap();
        let plan = df.clone().create_physical_plan().await.unwrap();
        let plan_str = datafusion::physical_plan::displayable(plan.as_ref())
            .indent(true)
            .to_string();

        // The plan should include a CoalescePartitionsExec to merge partitions
        assert!(
            plan_str.contains("CoalescePartitionsExec"),
            "Expected CoalescePartitionsExec in plan:\n{}",
            plan_str
        );

        // Execute the INSERT
        df.collect().await.unwrap();

        // Verify: should have made exactly one HTTP request despite multiple input partitions
        assert_eq!(request_count.load(Ordering::SeqCst), 1);
    }
}
|
||||
@@ -537,6 +537,19 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
|
||||
) -> Result<()>;
|
||||
/// Get statistics on the table
|
||||
async fn stats(&self) -> Result<TableStatistics>;
|
||||
/// Create an ExecutionPlan for inserting data into the table.
|
||||
///
|
||||
/// This is used by the DataFusion TableProvider implementation to support
|
||||
/// INSERT INTO statements.
|
||||
async fn create_insert_exec(
|
||||
&self,
|
||||
_input: Arc<dyn datafusion_physical_plan::ExecutionPlan>,
|
||||
_write_params: WriteParams,
|
||||
) -> Result<Arc<dyn datafusion_physical_plan::ExecutionPlan>> {
|
||||
Err(Error::NotSupported {
|
||||
message: "create_insert_exec not implemented".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A Table is a collection of strong typed Rows.
|
||||
@@ -3247,6 +3260,21 @@ impl BaseTable for NativeTable {
|
||||
};
|
||||
Ok(stats)
|
||||
}
|
||||
|
||||
async fn create_insert_exec(
|
||||
&self,
|
||||
input: Arc<dyn datafusion_physical_plan::ExecutionPlan>,
|
||||
write_params: WriteParams,
|
||||
) -> Result<Arc<dyn datafusion_physical_plan::ExecutionPlan>> {
|
||||
let ds = self.dataset.get().await?;
|
||||
let dataset = Arc::new((*ds).clone());
|
||||
Ok(Arc::new(datafusion::insert::InsertExec::new(
|
||||
self.dataset.clone(),
|
||||
dataset,
|
||||
input,
|
||||
write_params,
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
//! This module contains adapters to allow LanceDB tables to be used as DataFusion table providers.
|
||||
|
||||
pub mod insert;
|
||||
pub mod udtf;
|
||||
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
@@ -13,11 +14,12 @@ use async_trait::async_trait;
|
||||
use datafusion_catalog::{Session, TableProvider};
|
||||
use datafusion_common::{DataFusionError, Result as DataFusionResult, Statistics};
|
||||
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
|
||||
use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType};
|
||||
use datafusion_expr::{dml::InsertOp, Expr, TableProviderFilterPushDown, TableType};
|
||||
use datafusion_physical_plan::{
|
||||
stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
|
||||
};
|
||||
use futures::{TryFutureExt, TryStreamExt};
|
||||
use lance::dataset::{WriteMode, WriteParams};
|
||||
|
||||
use super::{AnyQuery, BaseTable};
|
||||
use crate::{
|
||||
@@ -250,6 +252,33 @@ impl TableProvider for BaseTableAdapter {
|
||||
// TODO
|
||||
None
|
||||
}
|
||||
|
||||
async fn insert_into(
|
||||
&self,
|
||||
_state: &dyn Session,
|
||||
input: Arc<dyn ExecutionPlan>,
|
||||
insert_op: InsertOp,
|
||||
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
|
||||
let mode = match insert_op {
|
||||
InsertOp::Append => WriteMode::Append,
|
||||
InsertOp::Overwrite => WriteMode::Overwrite,
|
||||
InsertOp::Replace => {
|
||||
return Err(DataFusionError::NotImplemented(
|
||||
"Replace mode is not supported for LanceDB tables".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let write_params = WriteParams {
|
||||
mode,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
self.table
|
||||
.create_insert_exec(input, write_params)
|
||||
.await
|
||||
.map_err(|e| DataFusionError::External(e.into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
446
rust/lancedb/src/table/datafusion/insert.rs
Normal file
446
rust/lancedb/src/table/datafusion/insert.rs
Normal file
@@ -0,0 +1,446 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! DataFusion ExecutionPlan for inserting data into LanceDB tables.
|
||||
|
||||
use std::any::Any;
|
||||
use std::sync::{Arc, LazyLock, Mutex};
|
||||
|
||||
use arrow_array::{RecordBatch, UInt64Array};
|
||||
use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
|
||||
use datafusion_common::{DataFusionError, Result as DataFusionResult};
|
||||
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
|
||||
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
|
||||
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
|
||||
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
|
||||
use datafusion_physical_plan::{
|
||||
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
|
||||
};
|
||||
use lance::dataset::transaction::{Operation, Transaction};
|
||||
use lance::dataset::{CommitBuilder, InsertBuilder, WriteParams};
|
||||
use lance::Dataset;
|
||||
use lance_table::format::Fragment;
|
||||
|
||||
use crate::table::dataset::DatasetConsistencyWrapper;
|
||||
|
||||
pub(crate) static COUNT_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
|
||||
Arc::new(ArrowSchema::new(vec![Field::new(
|
||||
"count",
|
||||
DataType::UInt64,
|
||||
false,
|
||||
)]))
|
||||
});
|
||||
|
||||
fn operation_fragments(operation: &Operation) -> &[Fragment] {
|
||||
match operation {
|
||||
Operation::Append { fragments } => fragments,
|
||||
Operation::Overwrite { fragments, .. } => fragments,
|
||||
_ => &[],
|
||||
}
|
||||
}
|
||||
|
||||
fn count_rows_from_operation(operation: &Operation) -> u64 {
|
||||
operation_fragments(operation)
|
||||
.iter()
|
||||
.map(|f| f.num_rows().unwrap_or(0) as u64)
|
||||
.sum()
|
||||
}
|
||||
|
||||
fn operation_fragments_mut(operation: &mut Operation) -> &mut Vec<Fragment> {
|
||||
match operation {
|
||||
Operation::Append { fragments } => fragments,
|
||||
Operation::Overwrite { fragments, .. } => fragments,
|
||||
_ => panic!("Unsupported operation type for getting mutable fragments"),
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_transactions(mut transactions: Vec<Transaction>) -> Option<Transaction> {
|
||||
let mut first = transactions.pop()?;
|
||||
|
||||
for txn in transactions {
|
||||
let first_fragments = operation_fragments_mut(&mut first.operation);
|
||||
let txn_fragments = operation_fragments(&txn.operation);
|
||||
first_fragments.extend_from_slice(txn_fragments);
|
||||
}
|
||||
|
||||
Some(first)
|
||||
}
|
||||
|
||||
/// ExecutionPlan for inserting data into a native LanceDB table.
|
||||
///
|
||||
/// This plan executes inserts by:
|
||||
/// 1. Each partition writes data independently using InsertBuilder::execute_uncommitted_stream
|
||||
/// 2. The last partition to complete commits all transactions atomically
|
||||
/// 3. Returns the count of inserted rows per partition
|
||||
#[derive(Debug)]
|
||||
pub struct InsertExec {
|
||||
ds_wrapper: DatasetConsistencyWrapper,
|
||||
dataset: Arc<Dataset>,
|
||||
input: Arc<dyn ExecutionPlan>,
|
||||
write_params: WriteParams,
|
||||
properties: PlanProperties,
|
||||
partial_transactions: Arc<Mutex<Vec<Transaction>>>,
|
||||
}
|
||||
|
||||
impl InsertExec {
|
||||
pub fn new(
|
||||
ds_wrapper: DatasetConsistencyWrapper,
|
||||
dataset: Arc<Dataset>,
|
||||
input: Arc<dyn ExecutionPlan>,
|
||||
write_params: WriteParams,
|
||||
) -> Self {
|
||||
let schema = COUNT_SCHEMA.clone();
|
||||
let num_partitions = input.output_partitioning().partition_count();
|
||||
let properties = PlanProperties::new(
|
||||
EquivalenceProperties::new(schema),
|
||||
Partitioning::UnknownPartitioning(num_partitions),
|
||||
EmissionType::Final,
|
||||
Boundedness::Bounded,
|
||||
);
|
||||
|
||||
Self {
|
||||
ds_wrapper,
|
||||
dataset,
|
||||
input,
|
||||
write_params,
|
||||
properties,
|
||||
partial_transactions: Arc::new(Mutex::new(Vec::with_capacity(num_partitions))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DisplayAs for InsertExec {
|
||||
fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match t {
|
||||
DisplayFormatType::Default | DisplayFormatType::Verbose => {
|
||||
write!(f, "InsertExec: mode={:?}", self.write_params.mode)
|
||||
}
|
||||
DisplayFormatType::TreeRender => {
|
||||
write!(f, "InsertExec")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutionPlan for InsertExec {
|
||||
fn name(&self) -> &str {
|
||||
Self::static_name()
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn properties(&self) -> &PlanProperties {
|
||||
&self.properties
|
||||
}
|
||||
|
||||
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
|
||||
vec![&self.input]
|
||||
}
|
||||
|
||||
fn maintains_input_order(&self) -> Vec<bool> {
|
||||
vec![false]
|
||||
}
|
||||
|
||||
fn benefits_from_input_partitioning(&self) -> Vec<bool> {
|
||||
vec![false]
|
||||
}
|
||||
|
||||
fn with_new_children(
|
||||
self: Arc<Self>,
|
||||
children: Vec<Arc<dyn ExecutionPlan>>,
|
||||
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
|
||||
if children.len() != 1 {
|
||||
return Err(DataFusionError::Internal(
|
||||
"InsertExec requires exactly one child".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(Arc::new(Self::new(
|
||||
self.ds_wrapper.clone(),
|
||||
self.dataset.clone(),
|
||||
children[0].clone(),
|
||||
self.write_params.clone(),
|
||||
)))
|
||||
}
|
||||
|
||||
fn execute(
|
||||
&self,
|
||||
partition: usize,
|
||||
context: Arc<TaskContext>,
|
||||
) -> DataFusionResult<SendableRecordBatchStream> {
|
||||
let input_stream = self.input.execute(partition, context)?;
|
||||
let dataset = self.dataset.clone();
|
||||
let write_params = self.write_params.clone();
|
||||
let partial_transactions = self.partial_transactions.clone();
|
||||
let total_partitions = self.input.output_partitioning().partition_count();
|
||||
let ds_wrapper = self.ds_wrapper.clone();
|
||||
|
||||
let stream = futures::stream::once(async move {
|
||||
let transaction = InsertBuilder::new(dataset.clone())
|
||||
.with_params(&write_params)
|
||||
.execute_uncommitted_stream(input_stream)
|
||||
.await?;
|
||||
|
||||
let num_rows = count_rows_from_operation(&transaction.operation);
|
||||
|
||||
let to_commit = {
|
||||
// Don't hold the lock over an await point.
|
||||
let mut txns = partial_transactions.lock().unwrap();
|
||||
txns.push(transaction);
|
||||
if txns.len() == total_partitions {
|
||||
Some(std::mem::take(&mut *txns))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(transactions) = to_commit {
|
||||
if let Some(merged_txn) = merge_transactions(transactions) {
|
||||
let new_dataset = CommitBuilder::new(dataset.clone())
|
||||
.execute(merged_txn)
|
||||
.await?;
|
||||
ds_wrapper.set_latest(new_dataset).await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(RecordBatch::try_new(
|
||||
COUNT_SCHEMA.clone(),
|
||||
vec![Arc::new(UInt64Array::from(vec![num_rows]))],
|
||||
)?)
|
||||
});
|
||||
|
||||
Ok(Box::pin(RecordBatchStreamAdapter::new(
|
||||
COUNT_SCHEMA.clone(),
|
||||
stream,
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::vec;

    use super::*;
    use arrow_array::{record_batch, Int32Array, RecordBatchIterator};
    use datafusion::prelude::SessionContext;
    use datafusion_catalog::MemTable;
    use tempfile::tempdir;

    use crate::connect;

    /// Schema with a single Int32 `id` column of the requested nullability.
    fn id_schema(nullable: bool) -> SchemaRef {
        let id = Field::new("id", DataType::Int32, nullable);
        Arc::new(ArrowSchema::new(vec![id]))
    }

    /// Registers a LanceDB table with the session under `name`.
    async fn register_lancedb_table(ctx: &SessionContext, name: &str, table: &crate::Table) {
        let provider =
            crate::table::datafusion::BaseTableAdapter::try_new(table.base_table().clone())
                .await
                .unwrap();
        ctx.register_table(name, Arc::new(provider)).unwrap();
    }

    /// Plans and fully executes a SQL statement, discarding its output.
    async fn run_sql(ctx: &SessionContext, sql: &str) {
        let df = ctx.sql(sql).await.unwrap();
        df.collect().await.unwrap();
    }

    #[tokio::test]
    async fn test_insert_via_sql() {
        let tmp_dir = tempdir().unwrap();
        let uri = tmp_dir.path().to_str().unwrap();
        let db = connect(uri).execute().await.unwrap();

        // Create initial table
        let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
        let schema = batch.schema();
        let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
        let table = db
            .create_table("test_insert", Box::new(reader))
            .execute()
            .await
            .unwrap();

        // Verify initial count
        assert_eq!(table.count_rows(None).await.unwrap(), 3);

        let ctx = SessionContext::new();
        register_lancedb_table(&ctx, "test_insert", &table).await;

        run_sql(&ctx, "INSERT INTO test_insert VALUES (4), (5), (6)").await;

        // Verify final count
        table.checkout_latest().await.unwrap();
        assert_eq!(table.count_rows(None).await.unwrap(), 6);
    }

    #[tokio::test]
    async fn test_insert_overwrite_via_sql() {
        let tmp_dir = tempdir().unwrap();
        let uri = tmp_dir.path().to_str().unwrap();
        let db = connect(uri).execute().await.unwrap();

        // Create initial table with 3 rows
        let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap();
        let schema = batch.schema();
        let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
        let table = db
            .create_table("test_overwrite", Box::new(reader))
            .execute()
            .await
            .unwrap();

        assert_eq!(table.count_rows(None).await.unwrap(), 3);

        let ctx = SessionContext::new();
        register_lancedb_table(&ctx, "test_overwrite", &table).await;

        run_sql(
            &ctx,
            "INSERT OVERWRITE INTO test_overwrite VALUES (10), (20)",
        )
        .await;

        // Verify: should have 2 rows (overwritten, not appended)
        table.checkout_latest().await.unwrap();
        assert_eq!(table.count_rows(None).await.unwrap(), 2);
    }

    #[tokio::test]
    async fn test_insert_empty_batch() {
        let tmp_dir = tempdir().unwrap();
        let uri = tmp_dir.path().to_str().unwrap();
        let db = connect(uri).execute().await.unwrap();

        // Create initial table
        let schema = id_schema(false);
        let initial = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )
        .unwrap();
        let batches = vec![initial];
        let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
        let table = db
            .create_table("test_empty", Box::new(reader))
            .execute()
            .await
            .unwrap();

        assert_eq!(table.count_rows(None).await.unwrap(), 3);

        let ctx = SessionContext::new();
        register_lancedb_table(&ctx, "test_empty", &table).await;

        // Empty batches
        let source_reader = RecordBatchIterator::new(
            std::iter::empty::<Result<RecordBatch, arrow_schema::ArrowError>>(),
            id_schema(false),
        );
        let source_table = db
            .create_table("empty_source", Box::new(source_reader))
            .execute()
            .await
            .unwrap();
        register_lancedb_table(&ctx, "empty_source", &source_table).await;

        // Execute INSERT with empty source
        run_sql(&ctx, "INSERT INTO test_empty SELECT * FROM empty_source").await;

        // Verify: should still have 3 rows (nothing inserted)
        table.checkout_latest().await.unwrap();
        assert_eq!(table.count_rows(None).await.unwrap(), 3);
    }

    #[tokio::test]
    async fn test_insert_multiple_batches() {
        let tmp_dir = tempdir().unwrap();
        let uri = tmp_dir.path().to_str().unwrap();
        let db = connect(uri).execute().await.unwrap();

        // Create initial table
        let schema = id_schema(true);
        let initial =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![1]))])
                .unwrap();
        let batches = vec![initial];
        let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
        let table = db
            .create_table("test_multi_batch", Box::new(reader))
            .execute()
            .await
            .unwrap();

        let ctx = SessionContext::new();
        register_lancedb_table(&ctx, "test_multi_batch", &table).await;

        // Memtable with multiple batches and multiple partitions
        let partition_0 = vec![
            record_batch!(("id", Int32, [2, 3])).unwrap(),
            record_batch!(("id", Int32, [4, 5])).unwrap(),
        ];
        let partition_1 = vec![record_batch!(("id", Int32, [6, 7, 8])).unwrap()];
        let source_table =
            MemTable::try_new(schema.clone(), vec![partition_0, partition_1]).unwrap();
        ctx.register_table("multi_batch_source", Arc::new(source_table))
            .unwrap();

        run_sql(
            &ctx,
            "INSERT INTO test_multi_batch SELECT * FROM multi_batch_source",
        )
        .await;

        // Verify: should have 1 + 2 + 2 + 3 = 8 rows
        table.checkout_latest().await.unwrap();
        assert_eq!(table.count_rows(None).await.unwrap(), 8);
    }
}
|
||||
Reference in New Issue
Block a user