From f6c9d31f9888fa202bf8fd5932bfb75f5876a10b Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Tue, 30 Jun 2026 08:28:41 -0700 Subject: [PATCH] feat: add polars dataframe integration (#3584) This PR is part cleanup, part feature, part example. It removes `IntoArrow` and `IntoArrowStream`. There was only one redundant call site between the two. Once we moved everything to `Scannable` these traits no longer serve any purpose. It adds a `Scannable` impl for a polars DataFrame. We used to have this at one point for `IntoArrow` so this is more like a regression fix than anything. It adds an example (and unit test) which ensures we can ingest from a Polars DataFrame and export to one. LazyFrame support would be a follow-up (though a pretty straightforward one) but we've never had proper LazyFrame support before. --- nodejs/src/merge.rs | 10 ++-- rust/lancedb/Cargo.toml | 4 ++ rust/lancedb/examples/polars.rs | 47 +++++++++++++++ rust/lancedb/src/arrow.rs | 49 +--------------- rust/lancedb/src/data/scannable.rs | 93 ++++++++++++++++++++++++++++++ 5 files changed, 150 insertions(+), 53 deletions(-) create mode 100644 rust/lancedb/examples/polars.rs diff --git a/nodejs/src/merge.rs b/nodejs/src/merge.rs index 5ba9846bc..1f9609160 100644 --- a/nodejs/src/merge.rs +++ b/nodejs/src/merge.rs @@ -3,7 +3,7 @@ use std::time::Duration; -use lancedb::{arrow::IntoArrow, ipc::ipc_file_to_batches, table::merge::MergeInsertBuilder}; +use lancedb::{ipc::ipc_file_to_batches, table::merge::MergeInsertBuilder}; use napi::bindgen_prelude::*; use napi_derive::napi; @@ -66,11 +66,9 @@ impl NativeMergeInsertBuilder { #[napi(catch_unwind)] pub async fn execute(&self, buf: Buffer) -> napi::Result { - let data = ipc_file_to_batches(buf.to_vec()) - .and_then(IntoArrow::into_arrow) - .map_err(|e| { - napi::Error::from_reason(format!("Failed to read IPC file: {}", convert_error(&e))) - })?; + let data = ipc_file_to_batches(buf.to_vec()).map_err(|e| { + napi::Error::from_reason(format!("Failed to read IPC file: {}", convert_error(&e))) + })?; let this = self.clone(); diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index 5a5834e24..e9d148421 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -166,6 +166,10 @@ required-features = ["bedrock"] [[example]] name = "simple" +[[example]] +name = "polars" +required-features = ["polars"] + [[example]] name = "full_text_search" diff --git a/rust/lancedb/examples/polars.rs b/rust/lancedb/examples/polars.rs new file mode 100644 index 000000000..9744b9eba --- /dev/null +++ b/rust/lancedb/examples/polars.rs @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +//! This example demonstrates ingesting a Polars DataFrame into LanceDB and +//! reading it back out as a Polars DataFrame. + +use lancedb::arrow::IntoPolars; +use lancedb::query::ExecutableQuery; +use lancedb::{Result, connect}; +use polars::prelude::{DataFrame, NamedFrom, Series}; + +fn make_dataframe() -> DataFrame { + let ids = Series::new("id", &[1i32, 2, 3, 4, 5]); + let names = Series::new("name", &["Alice", "Bob", "Carol", "Dave", "Eve"]); + let scores = Series::new("score", &[9.5f64, 8.1, 7.3, 9.0, 6.5]); + DataFrame::new(vec![ids, names, scores]).unwrap() +} + +#[tokio::main] +async fn main() -> Result<()> { + let tmp = tempfile::tempdir().unwrap(); + let db = connect(tmp.path().to_str().unwrap()).execute().await?; + + // Ingest a Polars DataFrame directly — DataFrame now implements Scannable. + let df = make_dataframe(); + println!("Input DataFrame:\n{df}"); + + let table = db.create_table("people", df).execute().await?; + + // Append more rows. + let more = DataFrame::new(vec![ + Series::new("id", &[6i32, 7]), + Series::new("name", &["Frank", "Grace"]), + Series::new("score", &[7.8f64, 8.9]), + ]) + .unwrap(); + table.add(more).execute().await?; + + // Read back as a Polars DataFrame. + let result_df = table.query().execute().await?.into_polars().await?; + + println!( + "\nRound-tripped DataFrame ({} rows):\n{result_df}", + result_df.height() + ); + Ok(()) +} diff --git a/rust/lancedb/src/arrow.rs b/rust/lancedb/src/arrow.rs index fe5dd998d..d4e7c2c8c 100644 --- a/rust/lancedb/src/arrow.rs +++ b/rust/lancedb/src/arrow.rs @@ -112,54 +112,14 @@ impl>> RecordBatchStream /// A trait for converting incoming data to Arrow /// -/// Integrations should implement this trait to allow data to be -/// imported directly from the integration. For example, implementing -/// this trait for `Vec>` would allow the `Vec` to be directly -/// used in methods like [`crate::connection::Connection::create_table`] -/// or [`crate::table::Table::add`] -pub trait IntoArrow { - /// Convert the data into an iterator of Arrow batches - fn into_arrow(self) -> Result>; -} - pub type BoxedRecordBatchReader = Box; -impl IntoArrow for T { - fn into_arrow(self) -> Result> { - Ok(Box::new(self)) - } -} - -/// A trait for converting incoming data to Arrow asynchronously -/// -/// Serves the same purpose as [`IntoArrow`], but for asynchronous data. -/// -/// Note: Arrow has no async equivalent to RecordBatchReader and so -pub trait IntoArrowStream { - /// Convert the data into a stream of Arrow batches - fn into_arrow(self) -> Result; -} - impl>> SimpleRecordBatchStream { pub fn new(stream: S, schema: Arc) -> Self { Self { schema, stream } } } -impl IntoArrowStream for SendableRecordBatchStream { - fn into_arrow(self) -> Result { - Ok(self) - } -} - -impl IntoArrowStream for datafusion_physical_plan::SendableRecordBatchStream { - fn into_arrow(self) -> Result { - let schema = self.schema(); - let stream = self.map_err(|df_err| df_err.into()); - Ok(Box::pin(SimpleRecordBatchStream::new(stream, schema))) - } -} - pub trait LanceDbDatagenExt { fn into_ldb_stream( self, @@ -264,9 +224,7 @@ impl IntoPolars for SendableRecordBatchStream { #[cfg(all(test, feature = "polars"))] mod tests { use super::SendableRecordBatchStream; - use crate::arrow::{ - IntoArrow, IntoPolars, PolarsDataFrameRecordBatchReader, SimpleRecordBatchStream, - }; + use crate::arrow::{IntoPolars, PolarsDataFrameRecordBatchReader, SimpleRecordBatchStream}; use polars::prelude::{DataFrame, NamedFrom, Series}; fn get_record_batch_reader_from_polars() -> Box { @@ -280,10 +238,7 @@ mod tests { float_series = Series::new("float", &[2.0]); let df2 = DataFrame::new(vec![string_series, int_series, float_series]).unwrap(); - PolarsDataFrameRecordBatchReader::new(df1.vstack(&df2).unwrap()) - .unwrap() - .into_arrow() - .unwrap() + Box::new(PolarsDataFrameRecordBatchReader::new(df1.vstack(&df2).unwrap()).unwrap()) } #[test] diff --git a/rust/lancedb/src/data/scannable.rs b/rust/lancedb/src/data/scannable.rs index a200ac465..0f810941a 100644 --- a/rust/lancedb/src/data/scannable.rs +++ b/rust/lancedb/src/data/scannable.rs @@ -185,6 +185,43 @@ impl Scannable for SendableRecordBatchStream { } } +#[cfg(feature = "polars")] +impl Scannable for polars::frame::DataFrame { + fn schema(&self) -> SchemaRef { + crate::polars_arrow_convertors::convert_polars_df_schema_to_arrow_rb_schema( + self.schema().clone(), + ) + .expect("failed to convert Polars DataFrame schema to Arrow schema") + } + + fn scan_as_stream(&mut self) -> SendableRecordBatchStream { + let schema = Scannable::schema(self); + let batches: crate::Result> = + match crate::arrow::PolarsDataFrameRecordBatchReader::new(self.clone()) { + Err(e) => Err(e), + Ok(reader) => reader.map(|b| b.map_err(Into::into)).collect(), + }; + match batches { + Err(e) => Box::pin(SimpleRecordBatchStream { + schema, + stream: once(async move { Err(e) }), + }), + Ok(batches) => { + let stream = futures::stream::iter(batches.into_iter().map(Ok)); + Box::pin(SimpleRecordBatchStream { schema, stream }) + } + } + } + + fn num_rows(&self) -> Option { + Some(self.height()) + } + + fn rescannable(&self) -> bool { + true + } +} + #[async_trait] impl StreamingWriteSource for Box { fn arrow_schema(&self) -> SchemaRef { @@ -1089,4 +1126,60 @@ mod tests { ); } } + + #[cfg(feature = "polars")] + mod polars_tests { + use super::*; + use crate::arrow::IntoPolars; + use crate::query::ExecutableQuery; + use polars::prelude::{DataFrame, NamedFrom, Series}; + + fn make_df() -> DataFrame { + DataFrame::new(vec![ + Series::new("id", &[1i32, 2, 3]), + Series::new("val", &[1.1f64, 2.2, 3.3]), + ]) + .unwrap() + } + + #[tokio::test] + async fn test_dataframe_scannable_round_trip() { + let tmp = tempfile::tempdir().unwrap(); + let db = crate::connect(tmp.path().to_str().unwrap()) + .execute() + .await + .unwrap(); + + let df = make_df(); + let table = db.create_table("t", df.clone()).execute().await.unwrap(); + + // Append the same rows again. + table.add(df.clone()).execute().await.unwrap(); + + let result = table + .query() + .execute() + .await + .unwrap() + .into_polars() + .await + .unwrap(); + + assert_eq!(result.height(), df.height() * 2); + assert_eq!(result.schema(), df.schema()); + } + + #[tokio::test] + async fn test_dataframe_scannable_rescannable() { + let mut df = make_df(); + assert!(df.rescannable()); + + let batches1: Vec = df.scan_as_stream().try_collect().await.unwrap(); + assert_eq!(batches1.iter().map(|b| b.num_rows()).sum::(), 3); + + // Can be scanned again. + let batches2: Vec = df.scan_as_stream().try_collect().await.unwrap(); + assert_eq!(batches2.iter().map(|b| b.num_rows()).sum::(), 3); + } + } }