feat(rust): remote client write data endpoint (#1645)

* Implements the following write operations for the remote client:
  * Add
  * Update
  * Delete
  * Merge-Insert
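
A rough usage sketch of the new operations from the Rust client, mirroring the
tests in this change (`table` is assumed to be an already-opened remote table
and `batch` an existing `RecordBatch`; everything else here is hypothetical):

    table
        .add(RecordBatchIterator::new([Ok(batch.clone())], batch.schema()))
        .execute()
        .await?; // POST /table/{name}/insert/
    table
        .update()
        .column("a", "a + 1")
        .only_if("b > 10")
        .execute()
        .await?; // POST /table/{name}/update/
    table.delete("id in (1, 2, 3)").await?; // POST /table/{name}/delete/
    table
        .merge_insert(&["some_col"])
        .execute(Box::new(RecordBatchIterator::new(
            [Ok(batch.clone())],
            batch.schema(),
        )))
        .await?; // POST /table/{name}/merge_insert/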

---------

Co-authored-by: Weston Pace <weston.pace@gmail.com>
Author: Will Jones
Date: 2024-09-18 15:02:56 -07:00 (committed via GitHub)
Commit: 521e665f57 (parent: ffb28dd4fc)
3 changed files with 445 additions and 47 deletions


@@ -45,8 +45,8 @@ serde_json = { version = "1" }
async-openai = { version = "0.20.0", optional = true }
serde_with = { version = "3.8.1" }
# For remote feature
reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
http = { version = "0.2", optional = true } # Matching what is in reqwest
reqwest = { version = "0.12.0", features = ["gzip", "json", "stream"], optional = true }
http = { version = "1", optional = true } # Matching what is in reqwest
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
polars = { version = ">=0.37,<0.40.0", optional = true }
hf-hub = { version = "0.3.2", optional = true }
@@ -66,6 +66,7 @@ aws-sdk-s3 = { version = "1.38.0" }
aws-sdk-kms = { version = "1.37" }
aws-config = { version = "1.0" }
aws-smithy-runtime = { version = "1.3" }
http-body = "1" # Matching reqwest
[features]
default = []
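
# The switch to reqwest 0.12 with the added "stream" feature is what enables
# streaming request bodies via `reqwest::Body::wrap_stream` (used by the new
# write endpoints below); http and http-body move to 1.x to match. A minimal
# sketch of that API in isolation (hypothetical helper, not part of this diff):
#
#     use futures::stream;
#
#     fn chunked_body() -> reqwest::Body {
#         // wrap_stream accepts any Send + 'static TryStream of byte chunks.
#         let chunks: Vec<Result<Vec<u8>, std::io::Error>> =
#             vec![Ok(b"hello ".to_vec()), Ok(b"world".to_vec())];
#         reqwest::Body::wrap_stream(stream::iter(chunks))
#     }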


@@ -1,16 +1,18 @@
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use crate::table::dataset::DatasetReadGuard;
use crate::table::AddDataMode;
use crate::Error;
use arrow_array::RecordBatchReader;
use arrow_schema::SchemaRef;
use async_trait::async_trait;
use datafusion_physical_plan::ExecutionPlan;
use http::header::CONTENT_TYPE;
use http::StatusCode;
use lance::arrow::json::JsonSchema;
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
use lance::dataset::{ColumnAlteration, NewColumnTransform};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use crate::{
connection::NoData,
@@ -24,6 +26,7 @@ use crate::{
};
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::ARROW_STREAM_CONTENT_TYPE;
#[derive(Debug)]
pub struct RemoteTable<S: HttpSend = Sender> {
@@ -41,13 +44,7 @@ impl<S: HttpSend> RemoteTable<S> {
let request = self.client.post(&format!("/table/{}/describe/", self.name));
let response = self.client.send(request).await?;
if response.status() == StatusCode::NOT_FOUND {
return Err(Error::TableNotFound {
name: self.name.clone(),
});
}
let response = self.client.check_response(response).await?;
let response = self.check_table_response(response).await?;
let body = response.text().await?;
@@ -55,6 +52,39 @@ impl<S: HttpSend> RemoteTable<S> {
message: format!("Failed to parse table description: {}", e),
})
}
fn reader_as_body(data: Box<dyn RecordBatchReader + Send>) -> Result<reqwest::Body> {
// TODO: Once Phalanx supports compression, we should use it here.
let mut writer = arrow_ipc::writer::StreamWriter::try_new(Vec::new(), &data.schema())?;
// The Mutex is just here to make the reader Sync. We shouldn't have any
// contention.
let mut data = Mutex::new(data);
let body_iter = std::iter::from_fn(move || match data.get_mut().unwrap().next() {
Some(Ok(batch)) => {
// Encode one IPC message per batch; an encoding error ends the stream.
writer.write(&batch).ok()?;
let buffer = std::mem::take(writer.get_mut());
Some(Ok(buffer))
}
// Propagate reader errors into the request body stream.
Some(Err(e)) => Some(Err(e)),
None => {
// Reader is exhausted: flush the end-of-stream marker. On the next
// call, finish() fails on the already-finished writer and `.ok()?`
// terminates the iterator.
writer.finish().ok()?;
let buffer = std::mem::take(writer.get_mut());
Some(Ok(buffer))
}
});
let body_stream = futures::stream::iter(body_iter);
Ok(reqwest::Body::wrap_stream(body_stream))
}
async fn check_table_response(&self, response: reqwest::Response) -> Result<reqwest::Response> {
if response.status() == StatusCode::NOT_FOUND {
return Err(Error::TableNotFound {
name: self.name.clone(),
});
}
self.client.check_response(response).await
}
}
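// For reference: the request body built by reader_as_body above is a plain
// Arrow IPC stream, so the receiving end (or a test) can decode it with
// arrow-ipc's StreamReader. A minimal sketch, assuming the bytes have been
// collected into `body: Vec<u8>`:
//
//     let cursor = std::io::Cursor::new(body);
//     let reader = arrow_ipc::reader::StreamReader::try_new(cursor, None)?;
//     let batches: Vec<arrow_array::RecordBatch> =
//         reader.collect::<std::result::Result<_, _>>()?;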
#[derive(Deserialize)]
@@ -133,13 +163,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
let response = self.client.send(request).await?;
if response.status() == StatusCode::NOT_FOUND {
return Err(Error::TableNotFound {
name: self.name.clone(),
});
}
let response = self.client.check_response(response).await?;
let response = self.check_table_response(response).await?;
let body = response.text().await?;
@@ -149,10 +173,28 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
}
async fn add(
&self,
_add: AddDataBuilder<NoData>,
_data: Box<dyn RecordBatchReader + Send>,
add: AddDataBuilder<NoData>,
data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
todo!()
let body = Self::reader_as_body(data)?;
let mut request = self
.client
.post(&format!("/table/{}/insert/", self.name))
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
.body(body);
match add.mode {
AddDataMode::Append => {}
AddDataMode::Overwrite => {
request = request.query(&[("mode", "overwrite")]);
}
}
let response = self.client.send(request).await?;
self.check_table_response(response).await?;
Ok(())
}
async fn build_plan(
&self,
@@ -160,71 +202,170 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
_query: &VectorQuery,
_options: Option<QueryExecutionOptions>,
) -> Result<Scanner> {
todo!()
Err(Error::NotSupported {
message: "build_plan is not supported on LanceDB cloud.".into(),
})
}
async fn create_plan(
&self,
_query: &VectorQuery,
_options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
unimplemented!()
Err(Error::NotSupported {
message: "create_plan is not supported on LanceDB cloud.".into(),
})
}
async fn explain_plan(&self, _query: &VectorQuery, _verbose: bool) -> Result<String> {
todo!()
Err(Error::NotSupported {
message: "explain_plan is not supported on LanceDB cloud.".into(),
})
}
async fn plain_query(
&self,
_query: &Query,
_options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
todo!()
Err(Error::NotSupported {
message: "plain_query is not yet supported on LanceDB cloud.".into(),
})
}
async fn update(&self, _update: UpdateBuilder) -> Result<()> {
todo!()
async fn update(&self, update: UpdateBuilder) -> Result<()> {
let request = self.client.post(&format!("/table/{}/update/", self.name));
let mut updates = Vec::new();
for (column, expression) in update.columns {
updates.push(column);
updates.push(expression);
}
let request = request.json(&serde_json::json!({
"updates": updates,
"only_if": update.filter,
}));
let response = self.client.send(request).await?;
self.check_table_response(response).await?;
Ok(())
}
async fn delete(&self, _predicate: &str) -> Result<()> {
todo!()
async fn delete(&self, predicate: &str) -> Result<()> {
let body = serde_json::json!({ "predicate": predicate });
let request = self
.client
.post(&format!("/table/{}/delete/", self.name))
.json(&body);
let response = self.client.send(request).await?;
self.check_table_response(response).await?;
Ok(())
}
async fn create_index(&self, _index: IndexBuilder) -> Result<()> {
todo!()
Err(Error::NotSupported {
message: "create_index is not yet supported on LanceDB cloud.".into(),
})
}
async fn merge_insert(
&self,
_params: MergeInsertBuilder,
_new_data: Box<dyn RecordBatchReader + Send>,
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
todo!()
let query = MergeInsertRequest::try_from(params)?;
let body = Self::reader_as_body(new_data)?;
let request = self
.client
.post(&format!("/table/{}/merge_insert/", self.name))
.query(&query)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
.body(body);
let response = self.client.send(request).await?;
self.check_table_response(response).await?;
Ok(())
}
async fn optimize(&self, _action: OptimizeAction) -> Result<OptimizeStats> {
todo!()
Err(Error::NotSupported {
message: "optimize is not supported on LanceDB cloud.".into(),
})
}
async fn add_columns(
&self,
_transforms: NewColumnTransform,
_read_columns: Option<Vec<String>>,
) -> Result<()> {
todo!()
Err(Error::NotSupported {
message: "add_columns is not yet supported.".into(),
})
}
async fn alter_columns(&self, _alterations: &[ColumnAlteration]) -> Result<()> {
todo!()
Err(Error::NotSupported {
message: "alter_columns is not yet supported.".into(),
})
}
async fn drop_columns(&self, _columns: &[&str]) -> Result<()> {
todo!()
Err(Error::NotSupported {
message: "drop_columns is not yet supported.".into(),
})
}
async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
todo!()
Err(Error::NotSupported {
message: "list_indices is not yet supported.".into(),
})
}
async fn table_definition(&self) -> Result<TableDefinition> {
todo!()
Err(Error::NotSupported {
message: "table_definition is not supported on LanceDB cloud.".into(),
})
}
}
#[derive(Serialize)]
struct MergeInsertRequest {
on: String,
when_matched_update_all: bool,
when_matched_update_all_filt: Option<String>,
when_not_matched_insert_all: bool,
when_not_matched_by_source_delete: bool,
when_not_matched_by_source_delete_filt: Option<String>,
}
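// Note: this struct is serialized into the query string via reqwest's
// `.query(&...)` (serde_urlencoded underneath): `None` filters are omitted
// entirely and booleans appear as "true"/"false", which is what the merge
// insert tests below assert. Default parameters serialize to roughly:
//
//     on=some_col&when_matched_update_all=false&when_not_matched_insert_all=false&when_not_matched_by_source_delete=false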
impl TryFrom<MergeInsertBuilder> for MergeInsertRequest {
type Error = Error;
fn try_from(value: MergeInsertBuilder) -> Result<Self> {
if value.on.is_empty() {
return Err(Error::InvalidInput {
message: "MergeInsertBuilder missing required 'on' field".into(),
});
} else if value.on.len() > 1 {
return Err(Error::NotSupported {
message: "MergeInsertBuilder only supports a single 'on' column".into(),
});
}
let on = value.on[0].clone();
Ok(Self {
on,
when_matched_update_all: value.when_matched_update_all,
when_matched_update_all_filt: value.when_matched_update_all_filt,
when_not_matched_insert_all: value.when_not_matched_insert_all,
when_not_matched_by_source_delete: value.when_not_matched_by_source_delete,
when_not_matched_by_source_delete_filt: value.when_not_matched_by_source_delete_filt,
})
}
}
#[cfg(test)]
mod tests {
use std::{collections::HashMap, pin::Pin};
use super::*;
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use futures::{future::BoxFuture, TryFutureExt};
use futures::{future::BoxFuture, StreamExt, TryFutureExt};
use reqwest::Body;
use crate::{Error, Table};
@@ -237,12 +378,27 @@ mod tests {
.unwrap()
});
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let example_data = || {
Box::new(RecordBatchIterator::new(
[Ok(batch.clone())],
batch.schema(),
))
};
// All endpoints should translate 404 to TableNotFound.
let results: Vec<BoxFuture<'_, Result<()>>> = vec![
Box::pin(table.version().map_ok(|_| ())),
Box::pin(table.schema().map_ok(|_| ())),
Box::pin(table.count_rows(None).map_ok(|_| ())),
// TODO: other endpoints.
Box::pin(table.update().column("a", "a + 1").execute()),
Box::pin(table.add(example_data()).execute().map_ok(|_| ())),
Box::pin(table.merge_insert(&["test"]).execute(example_data())),
Box::pin(table.delete("false")),
// TODO: other endpoints.
];
for result in results {
@@ -324,4 +480,245 @@ mod tests {
let count = table.count_rows(Some("a > 10".into())).await.unwrap();
assert_eq!(count, 42);
}
async fn collect_body(body: Body) -> Vec<u8> {
use http_body::Body;
let mut body = body;
let mut data = Vec::new();
let mut body_pin = Pin::new(&mut body);
futures::stream::poll_fn(|cx| body_pin.as_mut().poll_frame(cx))
.for_each(|frame| {
data.extend_from_slice(frame.unwrap().data_ref().unwrap());
futures::future::ready(())
})
.await;
data
}
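// Equivalent, shorter alternative using the http-body-util crate's
// BodyExt::collect (http-body-util is not a dependency here, so this is
// only a sketch):
//
//     use http_body_util::BodyExt;
//     let data = body.collect().await.unwrap().to_bytes().to_vec();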
fn write_ipc_stream(data: &RecordBatch) -> Vec<u8> {
let mut body = Vec::new();
{
let mut writer = arrow_ipc::writer::StreamWriter::try_new(&mut body, &data.schema())
.expect("Failed to create writer");
writer.write(data).expect("Failed to write data");
writer.finish().expect("Failed to finish");
}
body
}
#[tokio::test]
async fn test_add_append() {
let data = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let (sender, receiver) = std::sync::mpsc::channel();
let table = Table::new_with_handler("my_table", move |mut request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/table/my_table/insert/");
// If mode is specified, it should be "append". Append is the default,
// so the parameter is optional.
assert!(request
.url()
.query_pairs()
.filter(|(k, _)| k == "mode")
.all(|(_, v)| v == "append"));
assert_eq!(
request.headers().get("Content-Type").unwrap(),
ARROW_STREAM_CONTENT_TYPE
);
let mut body_out = reqwest::Body::from(Vec::new());
std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out);
sender.send(body_out).unwrap();
http::Response::builder().status(200).body("").unwrap()
});
table
.add(RecordBatchIterator::new([Ok(data.clone())], data.schema()))
.execute()
.await
.unwrap();
let body = receiver.recv().unwrap();
let body = collect_body(body).await;
let expected_body = write_ipc_stream(&data);
assert_eq!(&body, &expected_body);
}
#[tokio::test]
async fn test_add_overwrite() {
let data = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let (sender, receiver) = std::sync::mpsc::channel();
let table = Table::new_with_handler("my_table", move |mut request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/table/my_table/insert/");
assert_eq!(
request
.url()
.query_pairs()
.find(|(k, _)| k == "mode")
.map(|kv| kv.1)
.as_deref(),
Some("overwrite"),
"Expected mode=overwrite"
);
assert_eq!(
request.headers().get("Content-Type").unwrap(),
ARROW_STREAM_CONTENT_TYPE
);
let mut body_out = reqwest::Body::from(Vec::new());
std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out);
sender.send(body_out).unwrap();
http::Response::builder().status(200).body("").unwrap()
});
table
.add(RecordBatchIterator::new([Ok(data.clone())], data.schema()))
.mode(AddDataMode::Overwrite)
.execute()
.await
.unwrap();
let body = receiver.recv().unwrap();
let body = collect_body(body).await;
let expected_body = write_ipc_stream(&data);
assert_eq!(&body, &expected_body);
}
#[tokio::test]
async fn test_update() {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/table/my_table/update/");
if let Some(body) = request.body().unwrap().as_bytes() {
let body = std::str::from_utf8(body).unwrap();
let value: serde_json::Value = serde_json::from_str(body).unwrap();
let updates = value.get("updates").unwrap().as_array().unwrap();
assert_eq!(updates.len(), 2);
let col_name = updates[0].as_str().unwrap();
let expression = updates[1].as_str().unwrap();
assert_eq!(col_name, "a");
assert_eq!(expression, "a + 1");
let only_if = value.get("only_if").unwrap().as_str().unwrap();
assert_eq!(only_if, "b > 10");
}
http::Response::builder().status(200).body("").unwrap()
});
table
.update()
.column("a", "a + 1")
.only_if("b > 10")
.execute()
.await
.unwrap();
}
#[tokio::test]
async fn test_merge_insert() {
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let data = Box::new(RecordBatchIterator::new(
[Ok(batch.clone())],
batch.schema(),
));
// Default parameters
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/table/my_table/merge_insert/");
let params = request.url().query_pairs().collect::<HashMap<_, _>>();
assert_eq!(params["on"], "some_col");
assert_eq!(params["when_matched_update_all"], "false");
assert_eq!(params["when_not_matched_insert_all"], "false");
assert_eq!(params["when_not_matched_by_source_delete"], "false");
assert!(!params.contains_key("when_matched_update_all_filt"));
assert!(!params.contains_key("when_not_matched_by_source_delete_filt"));
http::Response::builder().status(200).body("").unwrap()
});
table
.merge_insert(&["some_col"])
.execute(data)
.await
.unwrap();
// All parameters specified
let (sender, receiver) = std::sync::mpsc::channel();
let table = Table::new_with_handler("my_table", move |mut request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/table/my_table/merge_insert/");
assert_eq!(
request.headers().get("Content-Type").unwrap(),
ARROW_STREAM_CONTENT_TYPE
);
let params = request.url().query_pairs().collect::<HashMap<_, _>>();
assert_eq!(params["on"], "some_col");
assert_eq!(params["when_matched_update_all"], "true");
assert_eq!(params["when_not_matched_insert_all"], "false");
assert_eq!(params["when_not_matched_by_source_delete"], "true");
assert_eq!(params["when_matched_update_all_filt"], "a = 1");
assert_eq!(params["when_not_matched_by_source_delete_filt"], "b = 2");
let mut body_out = reqwest::Body::from(Vec::new());
std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out);
sender.send(body_out).unwrap();
http::Response::builder().status(200).body("").unwrap()
});
let mut builder = table.merge_insert(&["some_col"]);
builder
.when_matched_update_all(Some("a = 1".into()))
.when_not_matched_by_source_delete(Some("b = 2".into()));
let data = Box::new(RecordBatchIterator::new(
[Ok(batch.clone())],
batch.schema(),
));
builder.execute(data).await.unwrap();
let body = receiver.recv().unwrap();
let body = collect_body(body).await;
let expected_body = write_ipc_stream(&batch);
assert_eq!(&body, &expected_body);
}
#[tokio::test]
async fn test_delete() {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/table/my_table/delete/");
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let predicate = body.get("predicate").unwrap().as_str().unwrap();
assert_eq!(predicate, "id in (1, 2, 3)");
http::Response::builder().status(200).body("").unwrap()
});
table.delete("id in (1, 2, 3)").await.unwrap();
}
}


@@ -26,12 +26,12 @@ use super::TableInternal;
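// These fields are pub(crate) (rather than pub(super)) so that the remote
// client's MergeInsertRequest::try_from can read them directly.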
#[derive(Debug, Clone)]
pub struct MergeInsertBuilder {
table: Arc<dyn TableInternal>,
pub(super) on: Vec<String>,
pub(super) when_matched_update_all: bool,
pub(super) when_matched_update_all_filt: Option<String>,
pub(super) when_not_matched_insert_all: bool,
pub(super) when_not_matched_by_source_delete: bool,
pub(super) when_not_matched_by_source_delete_filt: Option<String>,
pub(crate) on: Vec<String>,
pub(crate) when_matched_update_all: bool,
pub(crate) when_matched_update_all_filt: Option<String>,
pub(crate) when_not_matched_insert_all: bool,
pub(crate) when_not_matched_by_source_delete: bool,
pub(crate) when_not_matched_by_source_delete_filt: Option<String>,
}
impl MergeInsertBuilder {