feat(rust): add trait for incoming data (#1128)

This will make it easier for 3rd party integrations. They simply need to
implement `IntoArrow` for their types in order for those types to be
used in ingestion.
This commit is contained in:
Weston Pace
2024-03-19 07:15:49 -07:00
parent 85a9ef472f
commit abde77eafb
6 changed files with 135 additions and 54 deletions

View File

@@ -101,3 +101,21 @@ impl<S: Stream<Item = Result<arrow_array::RecordBatch>>> RecordBatchStream
self.schema.clone()
}
}
/// A trait for converting incoming data to Arrow
///
/// Integrations should implement this trait to allow data to be
/// imported directly from the integration. For example, implementing
/// this trait for `Vec<Vec<...>>` would allow the `Vec` to be directly
/// used in methods like [`crate::connection::Connection::create_table`]
/// or [`crate::table::Table::add`]
pub trait IntoArrow {
/// Convert the data into an Arrow [`arrow_array::RecordBatchReader`]
///
/// The conversion consumes `self` and hands back a boxed, `Send` reader.
/// An `Err` signals that the data could not be converted.
fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>>;
}
/// Any Arrow [`arrow_array::RecordBatchReader`] can be used as incoming data.
///
/// The impl is written over `T` rather than `Box<T>` so that both plain
/// readers (e.g. a `RecordBatchIterator`) and type-erased
/// `Box<dyn RecordBatchReader + Send>` values qualify. The latter is the type
/// callers had to pass before this trait existed; an impl restricted to
/// `Box<T>` with a sized `T` would reject it. Boxed readers are covered
/// because arrow implements `RecordBatchReader` for `Box<R>`.
impl<T: arrow_array::RecordBatchReader + Send + 'static> IntoArrow for T {
    /// Box the reader; this conversion never fails.
    fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>> {
        Ok(Box::new(self))
    }
}

View File

@@ -27,6 +27,7 @@ use object_store::{
};
use snafu::prelude::*;
use crate::arrow::IntoArrow;
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
use crate::io::object_store::MirroringObjectStoreWrapper;
use crate::table::{NativeTable, WriteOptions};
@@ -116,23 +117,27 @@ impl TableNamesBuilder {
}
}
/// Placeholder data type for builders that carry no data of their own.
///
/// It fills the `T: IntoArrow` parameter in `CreateTableBuilder<false, NoData>`
/// and `AddDataBuilder<NoData>`, where any actual input is passed separately
/// as an Arrow reader.
pub struct NoData {}
impl IntoArrow for NoData {
/// Always panics: `NoData` only marks the absence of data, so reaching this
/// conversion indicates a bug in the caller.
fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>> {
unreachable!("NoData should never be converted to Arrow")
}
}
/// A builder for configuring a [`Connection::create_table`] operation
pub struct CreateTableBuilder<const HAS_DATA: bool> {
pub struct CreateTableBuilder<const HAS_DATA: bool, T: IntoArrow> {
parent: Arc<dyn ConnectionInternal>,
pub(crate) name: String,
pub(crate) data: Option<Box<dyn RecordBatchReader + Send>>,
pub(crate) data: Option<T>,
pub(crate) schema: Option<SchemaRef>,
pub(crate) mode: CreateTableMode,
pub(crate) write_options: WriteOptions,
}
// Builder methods that only apply when we have initial data
impl CreateTableBuilder<true> {
fn new(
parent: Arc<dyn ConnectionInternal>,
name: String,
data: Box<dyn RecordBatchReader + Send>,
) -> Self {
impl<T: IntoArrow> CreateTableBuilder<true, T> {
fn new(parent: Arc<dyn ConnectionInternal>, name: String, data: T) -> Self {
Self {
parent,
name,
@@ -151,12 +156,32 @@ impl CreateTableBuilder<true> {
/// Execute the create table operation
pub async fn execute(self) -> Result<Table> {
self.parent.clone().do_create_table(self).await
let parent = self.parent.clone();
let (data, builder) = self.extract_data()?;
parent.do_create_table(builder, data).await
}
fn extract_data(
mut self,
) -> Result<(
Box<dyn RecordBatchReader + Send>,
CreateTableBuilder<false, NoData>,
)> {
let data = self.data.take().unwrap().into_arrow()?;
let builder = CreateTableBuilder::<false, NoData> {
parent: self.parent,
name: self.name,
data: None,
schema: self.schema,
mode: self.mode,
write_options: self.write_options,
};
Ok((data, builder))
}
}
// Builder methods that only apply when we do not have initial data
impl CreateTableBuilder<false> {
impl CreateTableBuilder<false, NoData> {
fn new(parent: Arc<dyn ConnectionInternal>, name: String, schema: SchemaRef) -> Self {
Self {
parent,
@@ -174,7 +199,7 @@ impl CreateTableBuilder<false> {
}
}
impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
impl<const HAS_DATA: bool, T: IntoArrow> CreateTableBuilder<HAS_DATA, T> {
/// Set the mode for creating the table
///
/// This controls what happens if a table with the given name already exists
@@ -237,17 +262,24 @@ pub(crate) trait ConnectionInternal:
Send + Sync + std::fmt::Debug + std::fmt::Display + 'static
{
async fn table_names(&self, options: TableNamesBuilder) -> Result<Vec<String>>;
async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<Table>;
async fn do_create_table(
&self,
options: CreateTableBuilder<false, NoData>,
data: Box<dyn RecordBatchReader + Send>,
) -> Result<Table>;
async fn do_open_table(&self, options: OpenTableBuilder) -> Result<Table>;
async fn drop_table(&self, name: &str) -> Result<()>;
async fn drop_db(&self) -> Result<()>;
async fn do_create_empty_table(&self, options: CreateTableBuilder<false>) -> Result<Table> {
let batches = RecordBatchIterator::new(vec![], options.schema.unwrap());
let opts = CreateTableBuilder::<true>::new(options.parent, options.name, Box::new(batches))
.mode(options.mode)
.write_options(options.write_options);
self.do_create_table(opts).await
/// Create a table that has a schema but no rows.
///
/// Builds an empty record-batch reader over the builder's schema and hands
/// it to [`Self::do_create_table`] together with the untouched options.
///
/// # Panics
///
/// Panics if `options.schema` is `None`.
async fn do_create_empty_table(
    &self,
    options: CreateTableBuilder<false, NoData>,
) -> Result<Table> {
    // Clone the schema so `options` can still be moved into `do_create_table`.
    let schema = options.schema.clone().unwrap();
    let empty = RecordBatchIterator::new(vec![], schema);
    self.do_create_table(options, Box::new(empty)).await
}
}
@@ -285,12 +317,12 @@ impl Connection {
///
/// * `name` - The name of the table
/// * `initial_data` - The initial data to write to the table
pub fn create_table(
pub fn create_table<T: IntoArrow>(
&self,
name: impl Into<String>,
initial_data: Box<dyn RecordBatchReader + Send>,
) -> CreateTableBuilder<true> {
CreateTableBuilder::<true>::new(self.internal.clone(), name.into(), initial_data)
initial_data: T,
) -> CreateTableBuilder<true, T> {
CreateTableBuilder::<true, T>::new(self.internal.clone(), name.into(), initial_data)
}
/// Create an empty table with a given schema
@@ -303,8 +335,8 @@ impl Connection {
&self,
name: impl Into<String>,
schema: SchemaRef,
) -> CreateTableBuilder<false> {
CreateTableBuilder::<false>::new(self.internal.clone(), name.into(), schema)
) -> CreateTableBuilder<false, NoData> {
CreateTableBuilder::<false, NoData>::new(self.internal.clone(), name.into(), schema)
}
/// Open an existing table in the database
@@ -694,7 +726,11 @@ impl ConnectionInternal for Database {
Ok(f)
}
async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<Table> {
async fn do_create_table(
&self,
options: CreateTableBuilder<false, NoData>,
data: Box<dyn RecordBatchReader + Send>,
) -> Result<Table> {
let table_uri = self.table_uri(&options.name)?;
let mut write_params = options.write_options.lance_write_params.unwrap_or_default();
@@ -705,7 +741,7 @@ impl ConnectionInternal for Database {
match NativeTable::create(
&table_uri,
&options.name,
options.data.unwrap(),
data,
self.store_wrapper.clone(),
Some(write_params),
self.read_consistency_interval,

View File

@@ -77,7 +77,7 @@ impl Select {
/// may be installed. These models may accept something other than f32. For example,
/// sentence transformers typically expect the query to be a string. This means that
/// any kind of conversion library should expect to convert more than just f32.
pub trait ToQueryVector {
pub trait IntoQueryVector {
/// Convert the user's query vector input to a query vector
///
/// This trait exists to allow users to provide many different types as
@@ -112,7 +112,7 @@ pub trait ToQueryVector {
}
// TODO: perhaps support some casts like f32->f64 and maybe even f64->f32?
impl ToQueryVector for Arc<dyn Array> {
impl IntoQueryVector for Arc<dyn Array> {
fn to_query_vector(
self,
data_type: &DataType,
@@ -147,7 +147,7 @@ impl ToQueryVector for Arc<dyn Array> {
}
}
impl ToQueryVector for &dyn Array {
impl IntoQueryVector for &dyn Array {
fn to_query_vector(
self,
data_type: &DataType,
@@ -167,7 +167,7 @@ impl ToQueryVector for &dyn Array {
}
}
impl ToQueryVector for &[f16] {
impl IntoQueryVector for &[f16] {
fn to_query_vector(
self,
data_type: &DataType,
@@ -197,7 +197,7 @@ impl ToQueryVector for &[f16] {
}
}
impl ToQueryVector for &[f32] {
impl IntoQueryVector for &[f32] {
fn to_query_vector(
self,
data_type: &DataType,
@@ -227,7 +227,7 @@ impl ToQueryVector for &[f32] {
}
}
impl ToQueryVector for &[f64] {
impl IntoQueryVector for &[f64] {
fn to_query_vector(
self,
data_type: &DataType,
@@ -257,7 +257,7 @@ impl ToQueryVector for &[f64] {
}
}
impl<const N: usize> ToQueryVector for &[f16; N] {
impl<const N: usize> IntoQueryVector for &[f16; N] {
fn to_query_vector(
self,
data_type: &DataType,
@@ -268,7 +268,7 @@ impl<const N: usize> ToQueryVector for &[f16; N] {
}
}
impl<const N: usize> ToQueryVector for &[f32; N] {
impl<const N: usize> IntoQueryVector for &[f32; N] {
fn to_query_vector(
self,
data_type: &DataType,
@@ -279,7 +279,7 @@ impl<const N: usize> ToQueryVector for &[f32; N] {
}
}
impl<const N: usize> ToQueryVector for &[f64; N] {
impl<const N: usize> IntoQueryVector for &[f64; N] {
fn to_query_vector(
self,
data_type: &DataType,
@@ -290,7 +290,7 @@ impl<const N: usize> ToQueryVector for &[f64; N] {
}
}
impl ToQueryVector for Vec<f16> {
impl IntoQueryVector for Vec<f16> {
fn to_query_vector(
self,
data_type: &DataType,
@@ -301,7 +301,7 @@ impl ToQueryVector for Vec<f16> {
}
}
impl ToQueryVector for Vec<f32> {
impl IntoQueryVector for Vec<f32> {
fn to_query_vector(
self,
data_type: &DataType,
@@ -312,7 +312,7 @@ impl ToQueryVector for Vec<f32> {
}
}
impl ToQueryVector for Vec<f64> {
impl IntoQueryVector for Vec<f64> {
fn to_query_vector(
self,
data_type: &DataType,
@@ -530,7 +530,7 @@ impl Query {
/// # Arguments
///
/// * `vector` - The vector that will be used for search.
pub fn nearest_to(self, vector: impl ToQueryVector) -> Result<VectorQuery> {
pub fn nearest_to(self, vector: impl IntoQueryVector) -> Result<VectorQuery> {
let mut vector_query = self.into_vector();
let query_vector = vector.to_query_vector(&DataType::Float32, "default")?;
vector_query.query_vector = Some(query_vector);

View File

@@ -14,13 +14,14 @@
use std::sync::Arc;
use arrow_array::RecordBatchReader;
use async_trait::async_trait;
use reqwest::header::CONTENT_TYPE;
use serde::Deserialize;
use tokio::task::spawn_blocking;
use crate::connection::{
ConnectionInternal, CreateTableBuilder, OpenTableBuilder, TableNamesBuilder,
ConnectionInternal, CreateTableBuilder, NoData, OpenTableBuilder, TableNamesBuilder,
};
use crate::error::Result;
use crate::Table;
@@ -74,8 +75,11 @@ impl ConnectionInternal for RemoteDatabase {
Ok(rsp.json::<ListTablesResponse>().await?.tables)
}
async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<Table> {
let data = options.data.unwrap();
async fn do_create_table(
&self,
options: CreateTableBuilder<false, NoData>,
data: Box<dyn RecordBatchReader + Send>,
) -> Result<Table> {
// TODO: https://github.com/lancedb/lancedb/issues/1026
// We should accept data from an async source. In the meantime, spawn this as blocking
// to make sure we don't block the tokio runtime if the source is slow.

View File

@@ -4,6 +4,7 @@ use async_trait::async_trait;
use lance::dataset::{scanner::DatasetRecordBatchStream, ColumnAlteration, NewColumnTransform};
use crate::{
connection::NoData,
error::Result,
index::{IndexBuilder, IndexConfig},
query::{Query, QueryExecutionOptions, VectorQuery},
@@ -63,7 +64,11 @@ impl TableInternal for RemoteTable {
async fn count_rows(&self, _filter: Option<String>) -> Result<usize> {
todo!()
}
async fn add(&self, _add: AddDataBuilder) -> Result<()> {
async fn add(
&self,
_add: AddDataBuilder<NoData>,
_data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
todo!()
}
async fn plain_query(

View File

@@ -42,6 +42,8 @@ use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
use log::info;
use snafu::whatever;
use crate::arrow::IntoArrow;
use crate::connection::NoData;
use crate::error::{Error, Result};
use crate::index::vector::{IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics};
use crate::index::IndexConfig;
@@ -50,7 +52,7 @@ use crate::index::{
Index, IndexBuilder,
};
use crate::query::{
Query, QueryExecutionOptions, Select, ToQueryVector, VectorQuery, DEFAULT_TOP_K,
IntoQueryVector, Query, QueryExecutionOptions, Select, VectorQuery, DEFAULT_TOP_K,
};
use crate::utils::{default_vector_column, PatchReadParam, PatchWriteParam};
@@ -124,14 +126,14 @@ pub enum AddDataMode {
/// A builder for configuring a [`crate::connection::Connection::create_table`] or [`Table::add`]
/// operation
pub struct AddDataBuilder {
pub struct AddDataBuilder<T: IntoArrow> {
parent: Arc<dyn TableInternal>,
pub(crate) data: Box<dyn RecordBatchReader + Send>,
pub(crate) data: T,
pub(crate) mode: AddDataMode,
pub(crate) write_options: WriteOptions,
}
impl std::fmt::Debug for AddDataBuilder {
impl<T: IntoArrow> std::fmt::Debug for AddDataBuilder<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AddDataBuilder")
.field("parent", &self.parent)
@@ -141,7 +143,7 @@ impl std::fmt::Debug for AddDataBuilder {
}
}
impl AddDataBuilder {
impl<T: IntoArrow> AddDataBuilder<T> {
pub fn mode(mut self, mode: AddDataMode) -> Self {
self.mode = mode;
self
@@ -153,7 +155,15 @@ impl AddDataBuilder {
}
pub async fn execute(self) -> Result<()> {
self.parent.clone().add(self).await
let parent = self.parent.clone();
let data = self.data.into_arrow()?;
let without_data = AddDataBuilder::<NoData> {
data: NoData {},
mode: self.mode,
parent: self.parent,
write_options: self.write_options,
};
parent.add(without_data, data).await
}
}
@@ -233,7 +243,6 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
async fn schema(&self) -> Result<SchemaRef>;
/// Count the number of rows in this table.
async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
async fn add(&self, add: AddDataBuilder) -> Result<()>;
async fn plain_query(
&self,
query: &Query,
@@ -244,6 +253,11 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream>;
/// Add the rows produced by `data` to this table.
///
/// `add` carries the remaining settings of the operation (mode and write
/// options); its own data slot is the `NoData` placeholder because the
/// input has already been converted to an Arrow reader.
async fn add(
&self,
add: AddDataBuilder<NoData>,
data: Box<dyn arrow_array::RecordBatchReader + Send>,
) -> Result<()>;
async fn delete(&self, predicate: &str) -> Result<()>;
async fn update(&self, update: UpdateBuilder) -> Result<()>;
async fn create_index(&self, index: IndexBuilder) -> Result<()>;
@@ -319,7 +333,7 @@ impl Table {
///
/// * `batches` data to be added to the Table
/// * `options` options to control how data is added
pub fn add(&self, batches: Box<dyn RecordBatchReader + Send>) -> AddDataBuilder {
pub fn add<T: IntoArrow>(&self, batches: T) -> AddDataBuilder<T> {
AddDataBuilder {
parent: self.inner.clone(),
data: batches,
@@ -637,7 +651,7 @@ impl Table {
/// This is a convenience method for preparing a vector query and
/// is the same thing as calling `nearest_to` on the builder returned
/// by `query`. See [`Query::nearest_to`] for more details.
pub fn vector_search(&self, query: impl ToQueryVector) -> Result<VectorQuery> {
pub fn vector_search(&self, query: impl IntoQueryVector) -> Result<VectorQuery> {
self.query().nearest_to(query)
}
@@ -1288,7 +1302,11 @@ impl TableInternal for NativeTable {
}
}
async fn add(&self, add: AddDataBuilder) -> Result<()> {
async fn add(
&self,
add: AddDataBuilder<NoData>,
data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
let lance_params = add.write_options.lance_write_params.unwrap_or(WriteParams {
mode: match add.mode {
AddDataMode::Append => WriteMode::Append,
@@ -1305,7 +1323,7 @@ impl TableInternal for NativeTable {
self.dataset.ensure_mutable().await?;
let dataset = Dataset::write(add.data, &self.uri, Some(lance_params)).await?;
let dataset = Dataset::write(data, &self.uri, Some(lance_params)).await?;
self.dataset.set_latest(dataset).await;
Ok(())
}