diff --git a/python/src/error.rs b/python/src/error.rs index c04eac7b..4688b523 100644 --- a/python/src/error.rs +++ b/python/src/error.rs @@ -35,21 +35,16 @@ impl PythonErrorExt for std::result::Result { match &self { Ok(_) => Ok(self.unwrap()), Err(err) => match err { - LanceError::InvalidInput { .. } => self.value_error(), - LanceError::InvalidTableName { .. } => self.value_error(), - LanceError::TableNotFound { .. } => self.value_error(), - LanceError::Schema { .. } => self.value_error(), + LanceError::InvalidInput { .. } + | LanceError::InvalidTableName { .. } + | LanceError::TableNotFound { .. } + | LanceError::Schema { .. } => self.value_error(), LanceError::CreateDir { .. } => self.os_error(), - LanceError::TableAlreadyExists { .. } => self.runtime_error(), LanceError::ObjectStore { .. } => Err(PyIOError::new_err(err.to_string())), - LanceError::Lance { .. } => self.runtime_error(), - LanceError::Runtime { .. } => self.runtime_error(), - LanceError::Http { .. } => self.runtime_error(), - LanceError::Arrow { .. } => self.runtime_error(), LanceError::NotSupported { .. } => { Err(PyNotImplementedError::new_err(err.to_string())) } - LanceError::Other { .. 
} => self.runtime_error(), + _ => self.runtime_error(), }, } } diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index 95a19379..df8330ac 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -40,6 +40,8 @@ serde = { version = "^1" } serde_json = { version = "1" } # For remote feature reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true } +polars-arrow = { version = ">=0.37", optional = true } +polars = { version = ">=0.37", optional = true} [dev-dependencies] tempfile = "3.5.0" @@ -56,3 +58,4 @@ default = [] remote = ["dep:reqwest"] fp16kernels = ["lance-linalg/fp16kernels"] s3-test = [] +polars = ["dep:polars-arrow", "dep:polars"] diff --git a/rust/lancedb/src/arrow.rs b/rust/lancedb/src/arrow.rs index da990975..f9440bed 100644 --- a/rust/lancedb/src/arrow.rs +++ b/rust/lancedb/src/arrow.rs @@ -14,10 +14,12 @@ use std::{pin::Pin, sync::Arc}; -pub use arrow_array; pub use arrow_schema; use futures::{Stream, StreamExt}; +#[cfg(feature = "polars")] +use {crate::polars_arrow_convertors, polars::frame::ArrowChunk, polars::prelude::DataFrame}; + use crate::error::Result; /// An iterator of batches that also has a schema @@ -119,3 +121,171 @@ impl IntoArrow for T { Ok(Box::new(self)) } } + +#[cfg(feature = "polars")] +/// An iterator of record batches formed from a Polars DataFrame. +pub struct PolarsDataFrameRecordBatchReader { + chunks: std::vec::IntoIter, + arrow_schema: Arc, +} + +#[cfg(feature = "polars")] +impl PolarsDataFrameRecordBatchReader { + /// Creates a new `PolarsDataFrameRecordBatchReader` from a given Polars DataFrame. + /// If the input dataframe does not have aligned chunks, this function undergoes + /// the costly operation of reallocating each series as a single contiguous chunk. 
+ pub fn new(mut df: DataFrame) -> Result { + df.align_chunks(); + let arrow_schema = + polars_arrow_convertors::convert_polars_df_schema_to_arrow_rb_schema(df.schema())?; + Ok(Self { + chunks: df + .iter_chunks(polars_arrow_convertors::POLARS_ARROW_FLAVOR) + .collect::>() + .into_iter(), + arrow_schema, + }) + } +} + +#[cfg(feature = "polars")] +impl Iterator for PolarsDataFrameRecordBatchReader { + type Item = std::result::Result; + + fn next(&mut self) -> Option { + self.chunks.next().map(|chunk| { + let columns: std::result::Result, arrow_schema::ArrowError> = + chunk + .into_arrays() + .into_iter() + .zip(self.arrow_schema.fields.iter()) + .map(|(polars_array, arrow_field)| { + polars_arrow_convertors::convert_polars_arrow_array_to_arrow_rs_array( + polars_array, + arrow_field.data_type().clone(), + ) + }) + .collect(); + arrow_array::RecordBatch::try_new(self.arrow_schema.clone(), columns?) + }) + } +} + +#[cfg(feature = "polars")] +impl arrow_array::RecordBatchReader for PolarsDataFrameRecordBatchReader { + fn schema(&self) -> Arc { + self.arrow_schema.clone() + } +} + +/// A trait for converting the result of a LanceDB query into a Polars DataFrame with aligned +/// chunks. The resulting Polars DataFrame will have aligned chunks, but the series's +/// chunks are not guaranteed to be contiguous. 
+#[cfg(feature = "polars")] +pub trait IntoPolars { + fn into_polars(self) -> impl std::future::Future> + Send; +} + +#[cfg(feature = "polars")] +impl IntoPolars for SendableRecordBatchStream { + async fn into_polars(mut self) -> Result { + let polars_schema = + polars_arrow_convertors::convert_arrow_rb_schema_to_polars_df_schema(&self.schema())?; + let mut acc_df: DataFrame = DataFrame::from(&polars_schema); + while let Some(record_batch) = self.next().await { + let new_df = polars_arrow_convertors::convert_arrow_rb_to_polars_df( + &record_batch?, + &polars_schema, + )?; + acc_df = acc_df.vstack(&new_df)?; + } + Ok(acc_df) + } +} + +#[cfg(all(test, feature = "polars"))] +mod tests { + use super::SendableRecordBatchStream; + use crate::arrow::{ + IntoArrow, IntoPolars, PolarsDataFrameRecordBatchReader, SimpleRecordBatchStream, + }; + use polars::prelude::{DataFrame, NamedFrom, Series}; + + fn get_record_batch_reader_from_polars() -> Box { + let mut string_series = Series::new("string", &["ab"]); + let mut int_series = Series::new("int", &[1]); + let mut float_series = Series::new("float", &[1.0]); + let df1 = DataFrame::new(vec![string_series, int_series, float_series]).unwrap(); + + string_series = Series::new("string", &["bc"]); + int_series = Series::new("int", &[2]); + float_series = Series::new("float", &[2.0]); + let df2 = DataFrame::new(vec![string_series, int_series, float_series]).unwrap(); + + PolarsDataFrameRecordBatchReader::new(df1.vstack(&df2).unwrap()) + .unwrap() + .into_arrow() + .unwrap() + } + + #[test] + fn from_polars_to_arrow() { + let record_batch_reader = get_record_batch_reader_from_polars(); + let schema = record_batch_reader.schema(); + + // Test schema conversion + assert_eq!( + schema + .fields + .iter() + .map(|field| (field.name().as_str(), field.data_type())) + .collect::>(), + vec![ + ("string", &arrow_schema::DataType::LargeUtf8), + ("int", &arrow_schema::DataType::Int32), + ("float", &arrow_schema::DataType::Float64) + ] + ); + 
let record_batches: Vec = + record_batch_reader.map(|result| result.unwrap()).collect(); + assert_eq!(record_batches.len(), 2); + assert_eq!(schema, record_batches[0].schema()); + assert_eq!(record_batches[0].schema(), record_batches[1].schema()); + + // Test number of rows + assert_eq!(record_batches[0].num_rows(), 1); + assert_eq!(record_batches[1].num_rows(), 1); + } + + #[tokio::test] + async fn from_arrow_to_polars() { + let record_batch_reader = get_record_batch_reader_from_polars(); + let schema = record_batch_reader.schema(); + let stream: SendableRecordBatchStream = Box::pin(SimpleRecordBatchStream { + schema: schema.clone(), + stream: futures::stream::iter( + record_batch_reader + .into_iter() + .map(|r| r.map_err(Into::into)), + ), + }); + let df = stream.into_polars().await.unwrap(); + + // Test number of chunks and rows + assert_eq!(df.n_chunks(), 2); + assert_eq!(df.height(), 2); + + // Test schema conversion + assert_eq!( + df.schema() + .into_iter() + .map(|(name, datatype)| (name.to_string(), datatype)) + .collect::>(), + vec![ + ("string".to_string(), polars::prelude::DataType::String), + ("int".to_owned(), polars::prelude::DataType::Int32), + ("float".to_owned(), polars::prelude::DataType::Float64) + ] + ); + } +} diff --git a/rust/lancedb/src/error.rs b/rust/lancedb/src/error.rs index a528a177..1f14ef57 100644 --- a/rust/lancedb/src/error.rs +++ b/rust/lancedb/src/error.rs @@ -112,3 +112,13 @@ impl From for Error { } } } + +#[cfg(feature = "polars")] +impl From for Error { + fn from(source: polars::prelude::PolarsError) -> Self { + Self::Other { + message: "Error in Polars DataFrame integration.".to_string(), + source: Some(Box::new(source)), + } + } +} diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs index 4fae0fa6..5e0d8f17 100644 --- a/rust/lancedb/src/lib.rs +++ b/rust/lancedb/src/lib.rs @@ -198,6 +198,8 @@ pub mod error; pub mod index; pub mod io; pub mod ipc; +#[cfg(feature = "polars")] +mod polars_arrow_convertors; pub mod 
query; #[cfg(feature = "remote")] pub(crate) mod remote; diff --git a/rust/lancedb/src/polars_arrow_convertors.rs b/rust/lancedb/src/polars_arrow_convertors.rs new file mode 100644 index 00000000..79db4fa2 --- /dev/null +++ b/rust/lancedb/src/polars_arrow_convertors.rs @@ -0,0 +1,123 @@ +/// Polars and LanceDB both use Arrow for their in-memory representation, but use +/// different Rust Arrow implementations. LanceDB uses the arrow-rs crate and +/// Polars uses the polars-arrow crate. +/// +/// This module defines zero-copy conversions (of the underlying buffers) +/// between polars-arrow and arrow-rs using the C FFI. +/// +/// The polars-arrow crate does implement conversions to and from arrow-rs, but +/// requires a feature-flagged dependency on arrow-rs. The version of arrow-rs +/// depended on by polars-arrow and LanceDB may not be compatible, +/// which necessitates using the C FFI. +use crate::error::Result; +use polars::prelude::{DataFrame, Series}; +use std::{mem, sync::Arc}; + +/// When interpreting Polars dataframes as polars-arrow record batches, +/// one must decide whether to use Arrow string/binary view types +/// instead of the standard Arrow string/binary types. +/// For now, we will not use string view types because conversions +/// for string view types from polars-arrow to arrow-rs are not yet implemented. +/// See: https://lists.apache.org/thread/w88tpz76ox8h3rxkjl4so6rg3f1rv7wt for the +/// differences in the types. +pub const POLARS_ARROW_FLAVOR: bool = false; +const IS_ARRAY_NULLABLE: bool = true; + +/// Converts a Polars DataFrame schema to an Arrow RecordBatch schema. 
+pub fn convert_polars_df_schema_to_arrow_rb_schema( + polars_df_schema: polars::prelude::Schema, +) -> Result> { + let arrow_fields: Result> = polars_df_schema + .into_iter() + .map(|(name, df_dtype)| { + let polars_arrow_dtype = df_dtype.to_arrow(POLARS_ARROW_FLAVOR); + let polars_field = + polars_arrow::datatypes::Field::new(name, polars_arrow_dtype, IS_ARRAY_NULLABLE); + convert_polars_arrow_field_to_arrow_rs_field(polars_field) + }) + .collect(); + Ok(Arc::new(arrow_schema::Schema::new(arrow_fields?))) +} + +/// Converts an Arrow RecordBatch schema to a Polars DataFrame schema. +pub fn convert_arrow_rb_schema_to_polars_df_schema( + arrow_schema: &arrow_schema::Schema, +) -> Result { + let polars_df_fields: Result> = arrow_schema + .fields() + .iter() + .map(|arrow_rs_field| { + let polars_arrow_field = convert_arrow_rs_field_to_polars_arrow_field(arrow_rs_field)?; + Ok(polars::prelude::Field::new( + arrow_rs_field.name(), + polars::datatypes::DataType::from(polars_arrow_field.data_type()), + )) + }) + .collect(); + Ok(polars::prelude::Schema::from_iter(polars_df_fields?)) +} + +/// Converts an Arrow RecordBatch to a Polars DataFrame, using a provided Polars DataFrame schema. +pub fn convert_arrow_rb_to_polars_df( + arrow_rb: &arrow::record_batch::RecordBatch, + polars_schema: &polars::prelude::Schema, +) -> Result { + let mut columns: Vec = Vec::with_capacity(arrow_rb.num_columns()); + + for (i, column) in arrow_rb.columns().iter().enumerate() { + let polars_df_dtype = polars_schema.try_get_at_index(i)?.1; + let polars_arrow_dtype = polars_df_dtype.to_arrow(POLARS_ARROW_FLAVOR); + let polars_array = + convert_arrow_rs_array_to_polars_arrow_array(column, polars_arrow_dtype)?; + columns.push(Series::from_arrow( + polars_schema.try_get_at_index(i)?.0, + polars_array, + )?); + } + + Ok(DataFrame::from_iter(columns)) +} + +/// Converts a polars-arrow Arrow array to an arrow-rs Arrow array. 
+pub fn convert_polars_arrow_array_to_arrow_rs_array( + polars_array: Box, + arrow_datatype: arrow_schema::DataType, +) -> std::result::Result { + let polars_c_array = polars_arrow::ffi::export_array_to_c(polars_array); + let arrow_c_array = unsafe { mem::transmute(polars_c_array) }; + Ok(arrow_array::make_array(unsafe { + arrow::ffi::from_ffi_and_data_type(arrow_c_array, arrow_datatype) + }?)) +} + +/// Converts an arrow-rs Arrow array to a polars-arrow Arrow array. +fn convert_arrow_rs_array_to_polars_arrow_array( + arrow_rs_array: &Arc, + polars_arrow_dtype: polars::datatypes::ArrowDataType, +) -> Result> { + let arrow_c_array = arrow::ffi::FFI_ArrowArray::new(&arrow_rs_array.to_data()); + let polars_c_array = unsafe { mem::transmute(arrow_c_array) }; + Ok(unsafe { polars_arrow::ffi::import_array_from_c(polars_c_array, polars_arrow_dtype) }?) +} + +fn convert_polars_arrow_field_to_arrow_rs_field( + polars_arrow_field: polars_arrow::datatypes::Field, +) -> Result { + let polars_c_schema = polars_arrow::ffi::export_field_to_c(&polars_arrow_field); + let arrow_c_schema: arrow::ffi::FFI_ArrowSchema = unsafe { mem::transmute(polars_c_schema) }; + let arrow_rs_dtype = arrow_schema::DataType::try_from(&arrow_c_schema)?; + Ok(arrow_schema::Field::new( + polars_arrow_field.name, + arrow_rs_dtype, + IS_ARRAY_NULLABLE, + )) +} + +fn convert_arrow_rs_field_to_polars_arrow_field( + arrow_rs_field: &arrow_schema::Field, +) -> Result { + let arrow_rs_dtype = arrow_rs_field.data_type(); + let arrow_c_schema = arrow::ffi::FFI_ArrowSchema::try_from(arrow_rs_dtype)?; + let polars_c_schema: polars_arrow::ffi::ArrowSchema = unsafe { mem::transmute(arrow_c_schema) }; + Ok(unsafe { polars_arrow::ffi::import_field_from_c(&polars_c_schema) }?) +}