diff --git a/Cargo.lock b/Cargo.lock index 689714197..9cb765593 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5001,6 +5001,7 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", + "datafusion-physical-expr", "datafusion-physical-plan", "futures", "half", diff --git a/Cargo.toml b/Cargo.toml index 33cf8693b..89420ed00 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,7 @@ datafusion-common = { version = "50.1", default-features = false } datafusion-execution = "50.1" datafusion-expr = "50.1" datafusion-physical-plan = "50.1" +datafusion-physical-expr = "50.1" env_logger = "0.11" half = { "version" = "2.6.0", default-features = false, features = [ "num-traits", diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index 7476ef619..178d216f7 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -25,6 +25,7 @@ datafusion-catalog.workspace = true datafusion-common.workspace = true datafusion-execution.workspace = true datafusion-expr.workspace = true +datafusion-physical-expr.workspace = true datafusion-physical-plan.workspace = true datafusion.workspace = true object_store = { workspace = true } diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index c3339d3c5..dbcd32464 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors +pub mod insert; + use crate::index::Index; use crate::index::IndexStatistics; use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest}; @@ -1508,6 +1510,21 @@ impl BaseTable for RemoteTable { })?; Ok(stats) } + + async fn create_insert_exec( + &self, + input: Arc, + write_params: lance::dataset::WriteParams, + ) -> Result> { + let overwrite = matches!(write_params.mode, lance::dataset::WriteMode::Overwrite); + Ok(Arc::new(insert::RemoteInsertExec::new( + self.name.clone(), + self.identifier.clone(), + 
self.client.clone(), + input, + overwrite, + ))) + } } #[derive(Serialize)] diff --git a/rust/lancedb/src/remote/table/insert.rs b/rust/lancedb/src/remote/table/insert.rs new file mode 100644 index 000000000..04caaaa4a --- /dev/null +++ b/rust/lancedb/src/remote/table/insert.rs @@ -0,0 +1,438 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +//! DataFusion ExecutionPlan for inserting data into remote LanceDB tables. + +use std::any::Any; +use std::sync::{Arc, Mutex}; + +use arrow_array::{ArrayRef, RecordBatch, UInt64Array}; +use arrow_ipc::CompressionType; +use arrow_schema::ArrowError; +use datafusion_common::{DataFusionError, Result as DataFusionResult}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; +use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; +use futures::StreamExt; +use http::header::CONTENT_TYPE; + +use crate::remote::client::{HttpSend, RestfulLanceDbClient, Sender}; +use crate::remote::table::RemoteTable; +use crate::remote::ARROW_STREAM_CONTENT_TYPE; +use crate::table::datafusion::insert::COUNT_SCHEMA; +use crate::table::AddResult; +use crate::Error; + +/// ExecutionPlan for inserting data into a remote LanceDB table. +/// +/// This plan: +/// 1. Requires single partition (no parallel remote inserts yet) +/// 2. Streams data as Arrow IPC to `/v1/table/{id}/insert/` endpoint +/// 3. Stores AddResult for retrieval after execution +#[derive(Debug)] +pub struct RemoteInsertExec { + table_name: String, + identifier: String, + client: RestfulLanceDbClient, + input: Arc, + overwrite: bool, + properties: PlanProperties, + add_result: Arc>>, +} + +impl RemoteInsertExec { + /// Create a new RemoteInsertExec. 
+    pub fn new(
+        table_name: String,
+        identifier: String,
+        client: RestfulLanceDbClient,
+        input: Arc,
+        overwrite: bool,
+    ) -> Self {
+        // The plan always reports a single output partition carrying the count schema.
+        let schema = COUNT_SCHEMA.clone();
+        let properties = PlanProperties::new(
+            EquivalenceProperties::new(schema),
+            datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
+            datafusion_physical_plan::execution_plan::EmissionType::Final,
+            datafusion_physical_plan::execution_plan::Boundedness::Bounded,
+        );
+
+        Self {
+            table_name,
+            identifier,
+            client,
+            input,
+            overwrite,
+            properties,
+            add_result: Arc::new(Mutex::new(None)),
+        }
+    }
+
+    /// Get the add result after execution.
+    // TODO: this will be used when we wire this up to Table::add().
+    #[allow(dead_code)]
+    pub fn add_result(&self) -> Option {
+        self.add_result.lock().unwrap().clone()
+    }
+
+    /// Encode `data` as an LZ4-compressed Arrow IPC stream and wrap it as a request body.
+    ///
+    /// Each chunk handed to the body holds the bytes the IPC writer produced since
+    /// the previous chunk. Once the input is exhausted the writer is finished, the
+    /// remaining bytes are emitted as the last chunk, and the body ends. The input
+    /// stream is not polled again after it has reported its end.
+    ///
+    /// # Errors
+    ///
+    /// Fails up front if the compression option or the writer cannot be set up.
+    /// Errors from the input stream, from writing a batch, or from finishing the
+    /// writer are reported through the body stream.
+    fn stream_as_body(data: SendableRecordBatchStream) -> DataFusionResult {
+        let options = arrow_ipc::writer::IpcWriteOptions::default()
+            .try_with_compression(Some(CompressionType::LZ4_FRAME))?;
+        let writer = arrow_ipc::writer::StreamWriter::try_new_with_options(
+            Vec::new(),
+            &data.schema(),
+            options,
+        )?;
+
+        // The writer slot becomes `None` after the final chunk has been produced;
+        // that is the signal to end the body without touching `data` again.
+        let stream = futures::stream::try_unfold(
+            (data, Some(writer)),
+            |(mut data, writer)| async move {
+                let Some(mut writer) = writer else {
+                    return Ok(None);
+                };
+                match data.next().await {
+                    Some(Ok(batch)) => {
+                        writer.write(&batch)?;
+                        let buffer = std::mem::take(writer.get_mut());
+                        Ok(Some((buffer, (data, Some(writer)))))
+                    }
+                    Some(Err(e)) => Err(e),
+                    None => {
+                        // Surface every failure to finish the stream instead of
+                        // reacting to a single error variant only.
+                        writer.finish().map_err(|err: ArrowError| {
+                            DataFusionError::from(err)
+                                .context("Failed to finish Arrow IPC stream")
+                        })?;
+                        let buffer = std::mem::take(writer.get_mut());
+                        Ok(Some((buffer, (data, None))))
+                    }
+                }
+            },
+        );
+
+        Ok(reqwest::Body::wrap_stream(stream))
+    }
+}
+
+impl DisplayAs for RemoteInsertExec {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                write!(
+                    f,
+                    "RemoteInsertExec: table={}, overwrite={}",
+                    self.table_name, self.overwrite
+                )
+            }
+            DisplayFormatType::TreeRender => {
+                write!(f, "RemoteInsertExec")
+            }
+        }
+    }
+}
+
+impl ExecutionPlan for RemoteInsertExec {
+    fn name(&self) -> &str {
+        Self::static_name()
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.properties
+    }
+
+    fn children(&self) -> Vec<&Arc> {
+        vec![&self.input]
+    }
+
+    fn maintains_input_order(&self) -> Vec {
+        vec![false]
+    }
+
+    fn required_input_distribution(&self) -> Vec {
+        // Until we have a separate commit endpoint, we need to do all inserts in a single partition
+        vec![datafusion_physical_plan::Distribution::SinglePartition]
+    }
+
+    fn benefits_from_input_partitioning(&self) -> Vec {
+        vec![false]
+    }
+
+    fn with_new_children(
+        self: Arc,
+        children: Vec>,
+    ) -> DataFusionResult> {
+        if children.len() != 1 {
+            return Err(DataFusionError::Internal(
+                "RemoteInsertExec requires exactly one child".to_string(),
+            ));
+        }
+        // Rebuilding through `Self::new` gives the replacement node its own,
+        // initially empty `add_result` slot.
+        Ok(Arc::new(Self::new(
+            self.table_name.clone(),
+            self.identifier.clone(),
+            self.client.clone(),
+            children[0].clone(),
+            self.overwrite,
+        )))
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc,
+    ) -> DataFusionResult {
+        if partition != 0 {
+            return Err(DataFusionError::Internal(
+                "RemoteInsertExec only supports single partition execution".to_string(),
+            ));
+        }
+
+        let input_stream = self.input.execute(0, context)?;
+        let client = self.client.clone();
+        let identifier = self.identifier.clone();
+        let overwrite = self.overwrite;
+        let add_result
= self.add_result.clone(); + let table_name = self.table_name.clone(); + + let stream = futures::stream::once(async move { + let mut request = client + .post(&format!("/v1/table/{}/insert/", identifier)) + .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE); + + if overwrite { + request = request.query(&[("mode", "overwrite")]); + } + + let body = Self::stream_as_body(input_stream)?; + let request = request.body(body); + + let (request_id, response) = client + .send(request) + .await + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + let response = + RemoteTable::::handle_table_not_found(&table_name, response, &request_id) + .await + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + let response = client + .check_response(&request_id, response) + .await + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + let body_text = response.text().await.map_err(|e| { + DataFusionError::External(Box::new(Error::Http { + source: Box::new(e), + request_id: request_id.clone(), + status_code: None, + })) + })?; + + let parsed_result = if body_text.trim().is_empty() { + // Backward compatible with old servers + AddResult { version: 0 } + } else { + serde_json::from_str(&body_text).map_err(|e| { + DataFusionError::External(Box::new(Error::Http { + source: format!("Failed to parse add response: {}", e).into(), + request_id: request_id.clone(), + status_code: None, + })) + })? 
+ }; + + { + let mut res_lock = add_result.lock().map_err(|_| { + DataFusionError::Execution("Failed to acquire lock for add_result".to_string()) + })?; + *res_lock = Some(parsed_result); + } + + // Return a single batch with count 0 (actual count is tracked in add_result) + let count_array: ArrayRef = Arc::new(UInt64Array::from(vec![0u64])); + let batch = RecordBatch::try_new(COUNT_SCHEMA.clone(), vec![count_array])?; + Ok::<_, DataFusionError>(batch) + }); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + COUNT_SCHEMA.clone(), + stream, + ))) + } +} + +#[cfg(test)] +mod tests { + use arrow_array::record_batch; + use arrow_schema::{DataType, Field, Schema as ArrowSchema}; + use datafusion::prelude::SessionContext; + use datafusion_catalog::MemTable; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + + use crate::remote::ARROW_STREAM_CONTENT_TYPE; + use crate::table::datafusion::BaseTableAdapter; + use crate::Table; + + fn schema_json() -> &'static str { + r#"{"fields": [{"name": "id", "type": {"type": "int32"}, "nullable": true}]}"# + } + + #[tokio::test] + async fn test_remote_insert_exec_execute_empty() { + let request_count = Arc::new(AtomicUsize::new(0)); + let request_count_clone = request_count.clone(); + + let table = Table::new_with_handler("my_table", move |request| { + let path = request.url().path(); + + if path == "/v1/table/my_table/describe/" { + // Return schema for BaseTableAdapter::try_new + return http::Response::builder() + .status(200) + .body(format!(r#"{{"version": 1, "schema": {}}}"#, schema_json())) + .unwrap(); + } + + if path == "/v1/table/my_table/insert/" { + assert_eq!(request.method(), "POST"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + ARROW_STREAM_CONTENT_TYPE + ); + request_count_clone.fetch_add(1, Ordering::SeqCst); + + return http::Response::builder() + .status(200) + .body(r#"{"version": 2}"#.to_string()) + .unwrap(); + } + + panic!("Unexpected request path: {}", path); + }); + + let 
schema = Arc::new(ArrowSchema::new(vec![Field::new( + "id", + DataType::Int32, + true, + )])); + + // Create empty MemTable (no batches) + let source_table = MemTable::try_new(schema, vec![vec![]]).unwrap(); + + let ctx = SessionContext::new(); + + // Register the remote table as insert target + let provider = BaseTableAdapter::try_new(table.base_table().clone()) + .await + .unwrap(); + ctx.register_table("my_table", Arc::new(provider)).unwrap(); + + // Register empty source + ctx.register_table("empty_source", Arc::new(source_table)) + .unwrap(); + + // Execute the INSERT + ctx.sql("INSERT INTO my_table SELECT * FROM empty_source") + .await + .unwrap() + .collect() + .await + .unwrap(); + + // Verify: should have made exactly one HTTP request even with empty input + assert_eq!(request_count.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn test_remote_insert_exec_multi_partition() { + let request_count = Arc::new(AtomicUsize::new(0)); + let request_count_clone = request_count.clone(); + + let table = Table::new_with_handler("my_table", move |request| { + let path = request.url().path(); + + if path == "/v1/table/my_table/describe/" { + // Return schema for BaseTableAdapter::try_new + return http::Response::builder() + .status(200) + .body(format!(r#"{{"version": 1, "schema": {}}}"#, schema_json())) + .unwrap(); + } + + if path == "/v1/table/my_table/insert/" { + assert_eq!(request.method(), "POST"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + ARROW_STREAM_CONTENT_TYPE + ); + request_count_clone.fetch_add(1, Ordering::SeqCst); + + return http::Response::builder() + .status(200) + .body(r#"{"version": 2}"#.to_string()) + .unwrap(); + } + + panic!("Unexpected request path: {}", path); + }); + + let schema = Arc::new(ArrowSchema::new(vec![Field::new( + "id", + DataType::Int32, + true, + )])); + + // Create MemTable with multiple partitions and multiple batches + let source_table = MemTable::try_new( + schema, + vec![ + // Partition 0 
+ vec![ + record_batch!(("id", Int32, [1, 2])).unwrap(), + record_batch!(("id", Int32, [3, 4])).unwrap(), + ], + // Partition 1 + vec![record_batch!(("id", Int32, [5, 6, 7])).unwrap()], + // Partition 2 + vec![record_batch!(("id", Int32, [8])).unwrap()], + ], + ) + .unwrap(); + + let ctx = SessionContext::new(); + + // Register the remote table as insert target + let provider = BaseTableAdapter::try_new(table.base_table().clone()) + .await + .unwrap(); + ctx.register_table("my_table", Arc::new(provider)).unwrap(); + + // Register multi-partition source + ctx.register_table("multi_partition_source", Arc::new(source_table)) + .unwrap(); + + // Get the physical plan and verify it includes a repartition to 1 + let df = ctx + .sql("INSERT INTO my_table SELECT * FROM multi_partition_source") + .await + .unwrap(); + let plan = df.clone().create_physical_plan().await.unwrap(); + let plan_str = datafusion::physical_plan::displayable(plan.as_ref()) + .indent(true) + .to_string(); + + // The plan should include a CoalescePartitionsExec to merge partitions + assert!( + plan_str.contains("CoalescePartitionsExec"), + "Expected CoalescePartitionsExec in plan:\n{}", + plan_str + ); + + // Execute the INSERT + df.collect().await.unwrap(); + + // Verify: should have made exactly one HTTP request despite multiple input partitions + assert_eq!(request_count.load(Ordering::SeqCst), 1); + } +} diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 456c242ee..9ebafbed4 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -537,6 +537,19 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync { ) -> Result<()>; /// Get statistics on the table async fn stats(&self) -> Result; + /// Create an ExecutionPlan for inserting data into the table. + /// + /// This is used by the DataFusion TableProvider implementation to support + /// INSERT INTO statements. 
+ async fn create_insert_exec( + &self, + _input: Arc, + _write_params: WriteParams, + ) -> Result> { + Err(Error::NotSupported { + message: "create_insert_exec not implemented".to_string(), + }) + } } /// A Table is a collection of strong typed Rows. @@ -3247,6 +3260,21 @@ impl BaseTable for NativeTable { }; Ok(stats) } + + async fn create_insert_exec( + &self, + input: Arc, + write_params: WriteParams, + ) -> Result> { + let ds = self.dataset.get().await?; + let dataset = Arc::new((*ds).clone()); + Ok(Arc::new(datafusion::insert::InsertExec::new( + self.dataset.clone(), + dataset, + input, + write_params, + ))) + } } #[skip_serializing_none] diff --git a/rust/lancedb/src/table/datafusion.rs b/rust/lancedb/src/table/datafusion.rs index c7fa5fe8d..760871631 100644 --- a/rust/lancedb/src/table/datafusion.rs +++ b/rust/lancedb/src/table/datafusion.rs @@ -3,6 +3,7 @@ //! This module contains adapters to allow LanceDB tables to be used as DataFusion table providers. +pub mod insert; pub mod udtf; use std::{collections::HashMap, sync::Arc}; @@ -13,11 +14,12 @@ use async_trait::async_trait; use datafusion_catalog::{Session, TableProvider}; use datafusion_common::{DataFusionError, Result as DataFusionResult, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType}; +use datafusion_expr::{dml::InsertOp, Expr, TableProviderFilterPushDown, TableType}; use datafusion_physical_plan::{ stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; use futures::{TryFutureExt, TryStreamExt}; +use lance::dataset::{WriteMode, WriteParams}; use super::{AnyQuery, BaseTable}; use crate::{ @@ -250,6 +252,33 @@ impl TableProvider for BaseTableAdapter { // TODO None } + + async fn insert_into( + &self, + _state: &dyn Session, + input: Arc, + insert_op: InsertOp, + ) -> DataFusionResult> { + let mode = match insert_op { + InsertOp::Append => WriteMode::Append, + 
InsertOp::Overwrite => WriteMode::Overwrite, + InsertOp::Replace => { + return Err(DataFusionError::NotImplemented( + "Replace mode is not supported for LanceDB tables".to_string(), + )) + } + }; + + let write_params = WriteParams { + mode, + ..Default::default() + }; + + self.table + .create_insert_exec(input, write_params) + .await + .map_err(|e| DataFusionError::External(e.into())) + } } #[cfg(test)] diff --git a/rust/lancedb/src/table/datafusion/insert.rs b/rust/lancedb/src/table/datafusion/insert.rs new file mode 100644 index 000000000..e3ee371ed --- /dev/null +++ b/rust/lancedb/src/table/datafusion/insert.rs @@ -0,0 +1,446 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +//! DataFusion ExecutionPlan for inserting data into LanceDB tables. + +use std::any::Any; +use std::sync::{Arc, LazyLock, Mutex}; + +use arrow_array::{RecordBatch, UInt64Array}; +use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef}; +use datafusion_common::{DataFusionError, Result as DataFusionResult}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, +}; +use lance::dataset::transaction::{Operation, Transaction}; +use lance::dataset::{CommitBuilder, InsertBuilder, WriteParams}; +use lance::Dataset; +use lance_table::format::Fragment; + +use crate::table::dataset::DatasetConsistencyWrapper; + +pub(crate) static COUNT_SCHEMA: LazyLock = LazyLock::new(|| { + Arc::new(ArrowSchema::new(vec![Field::new( + "count", + DataType::UInt64, + false, + )])) +}); + +fn operation_fragments(operation: &Operation) -> &[Fragment] { + match operation { + Operation::Append { fragments } => 
fragments,
+        Operation::Overwrite { fragments, .. } => fragments,
+        _ => &[],
+    }
+}
+
+/// Total number of rows recorded in the fragments of `operation`.
+///
+/// A fragment that does not report a row count contributes zero.
+fn count_rows_from_operation(operation: &Operation) -> u64 {
+    operation_fragments(operation)
+        .iter()
+        .map(|f| f.num_rows().unwrap_or(0) as u64)
+        .sum()
+}
+
+/// Mutable access to the fragment list of an `Append` or `Overwrite` operation.
+///
+/// # Panics
+///
+/// Panics for every other operation kind.
+fn operation_fragments_mut(operation: &mut Operation) -> &mut Vec {
+    match operation {
+        Operation::Append { fragments } => fragments,
+        Operation::Overwrite { fragments, .. } => fragments,
+        _ => panic!("Unsupported operation type for getting mutable fragments"),
+    }
+}
+
+/// Fold a list of transactions into a single transaction.
+///
+/// The first transaction in the list is the base; the fragments of every later
+/// transaction are appended to it in list order, so the merged fragment list
+/// follows the order of `transactions`. Returns `None` when the list is empty.
+///
+/// # Panics
+///
+/// Panics if there is more than one transaction and the base operation is
+/// neither `Append` nor `Overwrite`.
+fn merge_transactions(transactions: Vec) -> Option {
+    let mut remaining = transactions.into_iter();
+    let mut merged = remaining.next()?;
+
+    for txn in remaining {
+        operation_fragments_mut(&mut merged.operation)
+            .extend_from_slice(operation_fragments(&txn.operation));
+    }
+
+    Some(merged)
+}
+
+/// ExecutionPlan for inserting data into a native LanceDB table.
+///
+/// This plan executes inserts by:
+/// 1. Each partition writes data independently using InsertBuilder::execute_uncommitted_stream
+/// 2. The last partition to complete commits all transactions atomically
+/// 3.
Returns the count of inserted rows per partition +#[derive(Debug)] +pub struct InsertExec { + ds_wrapper: DatasetConsistencyWrapper, + dataset: Arc, + input: Arc, + write_params: WriteParams, + properties: PlanProperties, + partial_transactions: Arc>>, +} + +impl InsertExec { + pub fn new( + ds_wrapper: DatasetConsistencyWrapper, + dataset: Arc, + input: Arc, + write_params: WriteParams, + ) -> Self { + let schema = COUNT_SCHEMA.clone(); + let num_partitions = input.output_partitioning().partition_count(); + let properties = PlanProperties::new( + EquivalenceProperties::new(schema), + Partitioning::UnknownPartitioning(num_partitions), + EmissionType::Final, + Boundedness::Bounded, + ); + + Self { + ds_wrapper, + dataset, + input, + write_params, + properties, + partial_transactions: Arc::new(Mutex::new(Vec::with_capacity(num_partitions))), + } + } +} + +impl DisplayAs for InsertExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "InsertExec: mode={:?}", self.write_params.mode) + } + DisplayFormatType::TreeRender => { + write!(f, "InsertExec") + } + } + } +} + +impl ExecutionPlan for InsertExec { + fn name(&self) -> &str { + Self::static_name() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn maintains_input_order(&self) -> Vec { + vec![false] + } + + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DataFusionResult> { + if children.len() != 1 { + return Err(DataFusionError::Internal( + "InsertExec requires exactly one child".to_string(), + )); + } + Ok(Arc::new(Self::new( + self.ds_wrapper.clone(), + self.dataset.clone(), + children[0].clone(), + self.write_params.clone(), + ))) + } + + fn execute( + &self, + 
partition: usize,
+        context: Arc,
+    ) -> DataFusionResult {
+        // Clone everything the returned stream needs so the async block owns its inputs.
+        let input_stream = self.input.execute(partition, context)?;
+        let dataset = self.dataset.clone();
+        let write_params = self.write_params.clone();
+        let partial_transactions = self.partial_transactions.clone();
+        let total_partitions = self.input.output_partitioning().partition_count();
+        let ds_wrapper = self.ds_wrapper.clone();
+
+        let stream = futures::stream::once(async move {
+            // Write this partition's data without committing it.
+            let transaction = InsertBuilder::new(dataset.clone())
+                .with_params(&write_params)
+                .execute_uncommitted_stream(input_stream)
+                .await?;
+
+            let num_rows = count_rows_from_operation(&transaction.operation);
+
+            // Record this partition's transaction. The partition whose push brings
+            // the shared list up to `total_partitions` entries takes the whole list
+            // and becomes responsible for the commit.
+            let to_commit = {
+                // Don't hold the lock over an await point.
+                let mut txns = partial_transactions.lock().unwrap();
+                txns.push(transaction);
+                if txns.len() == total_partitions {
+                    Some(std::mem::take(&mut *txns))
+                } else {
+                    None
+                }
+            };
+
+            // Commit the merged transaction and hand the resulting dataset to the wrapper.
+            if let Some(transactions) = to_commit {
+                if let Some(merged_txn) = merge_transactions(transactions) {
+                    let new_dataset = CommitBuilder::new(dataset.clone())
+                        .execute(merged_txn)
+                        .await?;
+                    ds_wrapper.set_latest(new_dataset).await;
+                }
+            }
+
+            // Emit a single-row batch holding the number of rows this partition wrote.
+            Ok(RecordBatch::try_new(
+                COUNT_SCHEMA.clone(),
+                vec![Arc::new(UInt64Array::from(vec![num_rows]))],
+            )?)
+ }); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + COUNT_SCHEMA.clone(), + stream, + ))) + } +} + +#[cfg(test)] +mod tests { + use std::vec; + + use super::*; + use arrow_array::{record_batch, Int32Array, RecordBatchIterator}; + use datafusion::prelude::SessionContext; + use datafusion_catalog::MemTable; + use tempfile::tempdir; + + use crate::connect; + + #[tokio::test] + async fn test_insert_via_sql() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + + let db = connect(uri).execute().await.unwrap(); + + // Create initial table + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + let schema = batch.schema(); + let reader = RecordBatchIterator::new(vec![Ok(batch)], schema); + + let table = db + .create_table("test_insert", Box::new(reader)) + .execute() + .await + .unwrap(); + + // Verify initial count + assert_eq!(table.count_rows(None).await.unwrap(), 3); + + let ctx = SessionContext::new(); + let provider = + crate::table::datafusion::BaseTableAdapter::try_new(table.base_table().clone()) + .await + .unwrap(); + ctx.register_table("test_insert", Arc::new(provider)) + .unwrap(); + + ctx.sql("INSERT INTO test_insert VALUES (4), (5), (6)") + .await + .unwrap() + .collect() + .await + .unwrap(); + + // Verify final count + table.checkout_latest().await.unwrap(); + assert_eq!(table.count_rows(None).await.unwrap(), 6); + } + + #[tokio::test] + async fn test_insert_overwrite_via_sql() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + + let db = connect(uri).execute().await.unwrap(); + + // Create initial table with 3 rows + let batch = record_batch!(("id", Int32, [1, 2, 3])).unwrap(); + let schema = batch.schema(); + let reader = RecordBatchIterator::new(vec![Ok(batch)], schema); + + let table = db + .create_table("test_overwrite", Box::new(reader)) + .execute() + .await + .unwrap(); + + assert_eq!(table.count_rows(None).await.unwrap(), 3); + + let ctx = SessionContext::new(); + let 
provider = + crate::table::datafusion::BaseTableAdapter::try_new(table.base_table().clone()) + .await + .unwrap(); + ctx.register_table("test_overwrite", Arc::new(provider)) + .unwrap(); + + ctx.sql("INSERT OVERWRITE INTO test_overwrite VALUES (10), (20)") + .await + .unwrap() + .collect() + .await + .unwrap(); + + // Verify: should have 2 rows (overwritten, not appended) + table.checkout_latest().await.unwrap(); + assert_eq!(table.count_rows(None).await.unwrap(), 2); + } + + #[tokio::test] + async fn test_insert_empty_batch() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + + let db = connect(uri).execute().await.unwrap(); + + // Create initial table + let schema = Arc::new(ArrowSchema::new(vec![Field::new( + "id", + DataType::Int32, + false, + )])); + let batches = vec![RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap()]; + let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone()); + + let table = db + .create_table("test_empty", Box::new(reader)) + .execute() + .await + .unwrap(); + + assert_eq!(table.count_rows(None).await.unwrap(), 3); + + let ctx = SessionContext::new(); + let provider = + crate::table::datafusion::BaseTableAdapter::try_new(table.base_table().clone()) + .await + .unwrap(); + ctx.register_table("test_empty", Arc::new(provider)) + .unwrap(); + + let source_schema = Arc::new(ArrowSchema::new(vec![Field::new( + "id", + DataType::Int32, + false, + )])); + // Empty batches + let source_reader = RecordBatchIterator::new( + std::iter::empty::>(), + source_schema, + ); + let source_table = db + .create_table("empty_source", Box::new(source_reader)) + .execute() + .await + .unwrap(); + let source_provider = + crate::table::datafusion::BaseTableAdapter::try_new(source_table.base_table().clone()) + .await + .unwrap(); + ctx.register_table("empty_source", Arc::new(source_provider)) + .unwrap(); + + // Execute INSERT with empty source + 
ctx.sql("INSERT INTO test_empty SELECT * FROM empty_source") + .await + .unwrap() + .collect() + .await + .unwrap(); + + // Verify: should still have 3 rows (nothing inserted) + table.checkout_latest().await.unwrap(); + assert_eq!(table.count_rows(None).await.unwrap(), 3); + } + + #[tokio::test] + async fn test_insert_multiple_batches() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + + let db = connect(uri).execute().await.unwrap(); + + // Create initial table + let schema = Arc::new(ArrowSchema::new(vec![Field::new( + "id", + DataType::Int32, + true, + )])); + let batches = + vec![ + RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![1]))]) + .unwrap(), + ]; + let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone()); + + let table = db + .create_table("test_multi_batch", Box::new(reader)) + .execute() + .await + .unwrap(); + + let ctx = SessionContext::new(); + let provider = + crate::table::datafusion::BaseTableAdapter::try_new(table.base_table().clone()) + .await + .unwrap(); + ctx.register_table("test_multi_batch", Arc::new(provider)) + .unwrap(); + + // Memtable with multiple batches and multiple partitions + let source_table = MemTable::try_new( + schema.clone(), + vec![ + // Partition 0 + vec![ + record_batch!(("id", Int32, [2, 3])).unwrap(), + record_batch!(("id", Int32, [4, 5])).unwrap(), + ], + // Partition 1 + vec![record_batch!(("id", Int32, [6, 7, 8])).unwrap()], + ], + ) + .unwrap(); + ctx.register_table("multi_batch_source", Arc::new(source_table)) + .unwrap(); + + ctx.sql("INSERT INTO test_multi_batch SELECT * FROM multi_batch_source") + .await + .unwrap() + .collect() + .await + .unwrap(); + + // Verify: should have 1 + 2 + 2 + 3 = 8 rows + table.checkout_latest().await.unwrap(); + assert_eq!(table.count_rows(None).await.unwrap(), 8); + } +}