feat(rust): remote client write data endpoint (#1645)
Implements:

* Add
* Update
* Delete
* Merge-Insert

---------

Co-authored-by: Weston Pace <weston.pace@gmail.com>
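From the client's point of view, the new endpoints surface through the existing `Table` API. A minimal usage sketch, mirroring the calls exercised by the tests in this diff (connection setup is omitted; `lancedb::Table` and `lancedb::Result` are assumed to be the crate's public re-exports):

```rust
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};

// Sketch of the four write paths this commit implements, against a remote table.
async fn write_examples(table: &lancedb::Table) -> lancedb::Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )
    .unwrap();

    // Add: POST /table/{name}/insert/ with an Arrow IPC stream body.
    table
        .add(RecordBatchIterator::new([Ok(batch.clone())], schema.clone()))
        .execute()
        .await?;

    // Update: POST /table/{name}/update/ with a JSON body of column/expression
    // pairs plus an optional filter.
    table
        .update()
        .column("a", "a + 1")
        .only_if("b > 10")
        .execute()
        .await?;

    // Delete: POST /table/{name}/delete/ with a JSON predicate.
    table.delete("id in (1, 2, 3)").await?;

    // Merge-insert: POST /table/{name}/merge_insert/ with the parameters in the
    // query string and an Arrow IPC stream body.
    let data = Box::new(RecordBatchIterator::new([Ok(batch.clone())], schema));
    table.merge_insert(&["a"]).execute(data).await?;

    Ok(())
}
```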
@@ -45,8 +45,8 @@ serde_json = { version = "1" }
 async-openai = { version = "0.20.0", optional = true }
 serde_with = { version = "3.8.1" }
 # For remote feature
-reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
-http = { version = "0.2", optional = true } # Matching what is in reqwest
+reqwest = { version = "0.12.0", features = ["gzip", "json", "stream"], optional = true }
+http = { version = "1", optional = true } # Matching what is in reqwest
 polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
 polars = { version = ">=0.37,<0.40.0", optional = true }
 hf-hub = { version = "0.3.2", optional = true }
@@ -66,6 +66,7 @@ aws-sdk-s3 = { version = "1.38.0" }
 aws-sdk-kms = { version = "1.37" }
 aws-config = { version = "1.0" }
 aws-smithy-runtime = { version = "1.3" }
+http-body = "1" # Matching reqwest
 
 [features]
 default = []
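The reqwest 0.11 to 0.12 bump moves the client onto the http 1.x types, and the newly added `stream` feature is what exposes `reqwest::Body::wrap_stream`, which the write path below uses to stream Arrow IPC bytes instead of buffering whole payloads (`http-body` is likewise only needed by the new test helper that drains a request body). A minimal sketch of the gated API, assuming the `stream` feature is enabled (the helper is illustrative, not part of the commit):

```rust
use std::convert::Infallible;

// Hypothetical helper: Body::wrap_stream (behind reqwest's "stream" feature)
// accepts any Stream of Result<impl Into<Bytes>, impl Into<Box<dyn Error + ...>>>.
fn chunked_body(chunks: Vec<Vec<u8>>) -> reqwest::Body {
    let stream = futures::stream::iter(chunks.into_iter().map(Ok::<_, Infallible>));
    reqwest::Body::wrap_stream(stream)
}
```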
@@ -1,16 +1,18 @@
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 use crate::table::dataset::DatasetReadGuard;
+use crate::table::AddDataMode;
 use crate::Error;
 use arrow_array::RecordBatchReader;
 use arrow_schema::SchemaRef;
 use async_trait::async_trait;
 use datafusion_physical_plan::ExecutionPlan;
+use http::header::CONTENT_TYPE;
 use http::StatusCode;
 use lance::arrow::json::JsonSchema;
 use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
 use lance::dataset::{ColumnAlteration, NewColumnTransform};
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 
 use crate::{
     connection::NoData,
@@ -24,6 +26,7 @@ use crate::{
 };
 
 use super::client::{HttpSend, RestfulLanceDbClient, Sender};
+use super::ARROW_STREAM_CONTENT_TYPE;
 
 #[derive(Debug)]
 pub struct RemoteTable<S: HttpSend = Sender> {
@@ -41,13 +44,7 @@ impl<S: HttpSend> RemoteTable<S> {
         let request = self.client.post(&format!("/table/{}/describe/", self.name));
         let response = self.client.send(request).await?;
 
-        if response.status() == StatusCode::NOT_FOUND {
-            return Err(Error::TableNotFound {
-                name: self.name.clone(),
-            });
-        }
-
-        let response = self.client.check_response(response).await?;
+        let response = self.check_table_response(response).await?;
 
         let body = response.text().await?;
 
@@ -55,6 +52,39 @@ impl<S: HttpSend> RemoteTable<S> {
             message: format!("Failed to parse table description: {}", e),
         })
     }
+
+    fn reader_as_body(data: Box<dyn RecordBatchReader + Send>) -> Result<reqwest::Body> {
+        // TODO: Once Phalanx supports compression, we should use it here.
+        let mut writer = arrow_ipc::writer::StreamWriter::try_new(Vec::new(), &data.schema())?;
+
+        // Mutex is just here to make it sync. We shouldn't have any contention.
+        let mut data = Mutex::new(data);
+        let body_iter = std::iter::from_fn(move || match data.get_mut().unwrap().next() {
+            Some(Ok(batch)) => {
+                writer.write(&batch).ok()?;
+                let buffer = std::mem::take(writer.get_mut());
+                Some(Ok(buffer))
+            }
+            Some(Err(e)) => Some(Err(e)),
+            None => {
+                writer.finish().ok()?;
+                let buffer = std::mem::take(writer.get_mut());
+                Some(Ok(buffer))
+            }
+        });
+        let body_stream = futures::stream::iter(body_iter);
+        Ok(reqwest::Body::wrap_stream(body_stream))
+    }
+
+    async fn check_table_response(&self, response: reqwest::Response) -> Result<reqwest::Response> {
+        if response.status() == StatusCode::NOT_FOUND {
+            return Err(Error::TableNotFound {
+                name: self.name.clone(),
+            });
+        }
+
+        self.client.check_response(response).await
+    }
 }
 
 #[derive(Deserialize)]
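`reader_as_body` turns the batch reader into a chunked HTTP body while still producing one well-formed Arrow IPC stream: the `StreamWriter` accumulates into a `Vec<u8>`, and `std::mem::take` drains that buffer after each batch (and once more after `finish()` for the end-of-stream marker), so the concatenated chunks are byte-identical to a one-shot encoding. A standalone sketch of that invariant, using the same arrow_ipc API as the diff (commentary, not part of the commit):

```rust
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};

fn main() {
    let batch = RecordBatch::try_new(
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )
    .unwrap();

    // Chunked encoding: drain the writer's buffer after every write, like the
    // body iterator in reader_as_body.
    let mut writer =
        arrow_ipc::writer::StreamWriter::try_new(Vec::new(), &batch.schema()).unwrap();
    let mut chunks: Vec<Vec<u8>> = Vec::new();
    for _ in 0..2 {
        writer.write(&batch).unwrap();
        chunks.push(std::mem::take(writer.get_mut()));
    }
    writer.finish().unwrap();
    chunks.push(std::mem::take(writer.get_mut())); // end-of-stream marker

    // One-shot encoding of the same two batches.
    let mut whole = Vec::new();
    {
        let mut w =
            arrow_ipc::writer::StreamWriter::try_new(&mut whole, &batch.schema()).unwrap();
        w.write(&batch).unwrap();
        w.write(&batch).unwrap();
        w.finish().unwrap();
    }

    assert_eq!(chunks.concat(), whole);
}
```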
@@ -133,13 +163,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
 
         let response = self.client.send(request).await?;
 
-        if response.status() == StatusCode::NOT_FOUND {
-            return Err(Error::TableNotFound {
-                name: self.name.clone(),
-            });
-        }
-
-        let response = self.client.check_response(response).await?;
+        let response = self.check_table_response(response).await?;
 
         let body = response.text().await?;
 
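In the `add()` implementation below, append intentionally sends no `mode` query parameter, since it is the server-side default; only overwrite adds `mode=overwrite`, and the `test_add_append`/`test_add_overwrite` tests pin this down. A tiny sketch of the mapping (the `insert_url` helper is hypothetical, not part of the commit):

```rust
// Hypothetical helper illustrating how AddDataMode maps onto the insert URL.
// Append adds no parameter; overwrite matches request.query(&[("mode", "overwrite")]).
fn insert_url(table: &str, overwrite: bool) -> String {
    let mut url = format!("/table/{}/insert/", table);
    if overwrite {
        url.push_str("?mode=overwrite");
    }
    url
}

#[test]
fn mode_maps_to_query_param() {
    assert_eq!(insert_url("my_table", false), "/table/my_table/insert/");
    assert_eq!(
        insert_url("my_table", true),
        "/table/my_table/insert/?mode=overwrite"
    );
}
```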
@@ -149,10 +173,28 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
     }
     async fn add(
         &self,
-        _add: AddDataBuilder<NoData>,
-        _data: Box<dyn RecordBatchReader + Send>,
+        add: AddDataBuilder<NoData>,
+        data: Box<dyn RecordBatchReader + Send>,
     ) -> Result<()> {
-        todo!()
+        let body = Self::reader_as_body(data)?;
+        let mut request = self
+            .client
+            .post(&format!("/table/{}/insert/", self.name))
+            .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
+            .body(body);
+
+        match add.mode {
+            AddDataMode::Append => {}
+            AddDataMode::Overwrite => {
+                request = request.query(&[("mode", "overwrite")]);
+            }
+        }
+
+        let response = self.client.send(request).await?;
+
+        self.check_table_response(response).await?;
+
+        Ok(())
     }
     async fn build_plan(
         &self,
@@ -160,71 +202,170 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
         _query: &VectorQuery,
         _options: Option<QueryExecutionOptions>,
     ) -> Result<Scanner> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "build_plan is not supported on LanceDB cloud.".into(),
+        })
     }
     async fn create_plan(
         &self,
         _query: &VectorQuery,
         _options: QueryExecutionOptions,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        unimplemented!()
+        Err(Error::NotSupported {
+            message: "create_plan is not supported on LanceDB cloud.".into(),
+        })
     }
     async fn explain_plan(&self, _query: &VectorQuery, _verbose: bool) -> Result<String> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "explain_plan is not supported on LanceDB cloud.".into(),
+        })
     }
     async fn plain_query(
         &self,
         _query: &Query,
         _options: QueryExecutionOptions,
     ) -> Result<DatasetRecordBatchStream> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "plain_query is not yet supported on LanceDB cloud.".into(),
+        })
     }
-    async fn update(&self, _update: UpdateBuilder) -> Result<()> {
-        todo!()
+    async fn update(&self, update: UpdateBuilder) -> Result<()> {
+        let request = self.client.post(&format!("/table/{}/update/", self.name));
+
+        let mut updates = Vec::new();
+        for (column, expression) in update.columns {
+            updates.push(column);
+            updates.push(expression);
+        }
+
+        let request = request.json(&serde_json::json!({
+            "updates": updates,
+            "only_if": update.filter,
+        }));
+
+        let response = self.client.send(request).await?;
+
+        self.check_table_response(response).await?;
+
+        Ok(())
     }
-    async fn delete(&self, _predicate: &str) -> Result<()> {
-        todo!()
+    async fn delete(&self, predicate: &str) -> Result<()> {
+        let body = serde_json::json!({ "predicate": predicate });
+        let request = self
+            .client
+            .post(&format!("/table/{}/delete/", self.name))
+            .json(&body);
+        let response = self.client.send(request).await?;
+        self.check_table_response(response).await?;
+        Ok(())
     }
     async fn create_index(&self, _index: IndexBuilder) -> Result<()> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "create_index is not yet supported on LanceDB cloud.".into(),
+        })
     }
     async fn merge_insert(
         &self,
-        _params: MergeInsertBuilder,
-        _new_data: Box<dyn RecordBatchReader + Send>,
+        params: MergeInsertBuilder,
+        new_data: Box<dyn RecordBatchReader + Send>,
     ) -> Result<()> {
-        todo!()
+        let query = MergeInsertRequest::try_from(params)?;
+        let body = Self::reader_as_body(new_data)?;
+        let request = self
+            .client
+            .post(&format!("/table/{}/merge_insert/", self.name))
+            .query(&query)
+            .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
+            .body(body);
+
+        let response = self.client.send(request).await?;
+
+        self.check_table_response(response).await?;
+
+        Ok(())
     }
     async fn optimize(&self, _action: OptimizeAction) -> Result<OptimizeStats> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "optimize is not supported on LanceDB cloud.".into(),
+        })
     }
     async fn add_columns(
         &self,
         _transforms: NewColumnTransform,
         _read_columns: Option<Vec<String>>,
     ) -> Result<()> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "add_columns is not yet supported.".into(),
+        })
     }
     async fn alter_columns(&self, _alterations: &[ColumnAlteration]) -> Result<()> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "alter_columns is not yet supported.".into(),
+        })
     }
     async fn drop_columns(&self, _columns: &[&str]) -> Result<()> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "drop_columns is not yet supported.".into(),
+        })
    }
     async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "list_indices is not yet supported.".into(),
+        })
     }
     async fn table_definition(&self) -> Result<TableDefinition> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "table_definition is not supported on LanceDB cloud.".into(),
+        })
     }
 }
+
+#[derive(Serialize)]
+struct MergeInsertRequest {
+    on: String,
+    when_matched_update_all: bool,
+    when_matched_update_all_filt: Option<String>,
+    when_not_matched_insert_all: bool,
+    when_not_matched_by_source_delete: bool,
+    when_not_matched_by_source_delete_filt: Option<String>,
+}
+
+impl TryFrom<MergeInsertBuilder> for MergeInsertRequest {
+    type Error = Error;
+
+    fn try_from(value: MergeInsertBuilder) -> Result<Self> {
+        if value.on.is_empty() {
+            return Err(Error::InvalidInput {
+                message: "MergeInsertBuilder missing required 'on' field".into(),
+            });
+        } else if value.on.len() > 1 {
+            return Err(Error::NotSupported {
+                message: "MergeInsertBuilder only supports a single 'on' column".into(),
+            });
+        }
+        let on = value.on[0].clone();
+
+        Ok(Self {
+            on,
+            when_matched_update_all: value.when_matched_update_all,
+            when_matched_update_all_filt: value.when_matched_update_all_filt,
+            when_not_matched_insert_all: value.when_not_matched_insert_all,
+            when_not_matched_by_source_delete: value.when_not_matched_by_source_delete,
+            when_not_matched_by_source_delete_filt: value.when_not_matched_by_source_delete_filt,
+        })
+    }
+}
 
 #[cfg(test)]
 mod tests {
     use std::{collections::HashMap, pin::Pin};
 
     use super::*;
 
     use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
     use arrow_schema::{DataType, Field, Schema};
-    use futures::{future::BoxFuture, TryFutureExt};
+    use futures::{future::BoxFuture, StreamExt, TryFutureExt};
+    use reqwest::Body;
 
     use crate::{Error, Table};
 
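`MergeInsertRequest` is handed to reqwest's `.query(&...)`, which serializes it with serde_urlencoded: `bool` fields become the strings `true`/`false`, and `None` options are omitted from the query string altogether, which is exactly what `test_merge_insert` asserts. A standalone sketch of that behavior, trimmed to three fields (assumes serde_urlencoded as an explicit dependency, though reqwest already pulls it in transitively):

```rust
use serde::Serialize;

// A trimmed mirror of the MergeInsertRequest struct above.
#[derive(Serialize)]
struct MergeInsertQuery {
    on: String,
    when_matched_update_all: bool,
    when_matched_update_all_filt: Option<String>,
}

fn main() {
    let q = MergeInsertQuery {
        on: "some_col".into(),
        when_matched_update_all: false,
        when_matched_update_all_filt: None,
    };
    // serde_urlencoded (the serializer behind reqwest's RequestBuilder::query)
    // skips None fields instead of emitting an empty value.
    assert_eq!(
        serde_urlencoded::to_string(&q).unwrap(),
        "on=some_col&when_matched_update_all=false"
    );
}
```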
@@ -237,12 +378,27 @@ mod tests {
                 .unwrap()
         });
 
+        let batch = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
+            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+        )
+        .unwrap();
+        let example_data = || {
+            Box::new(RecordBatchIterator::new(
+                [Ok(batch.clone())],
+                batch.schema(),
+            ))
+        };
+
         // All endpoints should translate 404 to TableNotFound.
         let results: Vec<BoxFuture<'_, Result<()>>> = vec![
             Box::pin(table.version().map_ok(|_| ())),
             Box::pin(table.schema().map_ok(|_| ())),
             Box::pin(table.count_rows(None).map_ok(|_| ())),
-            // TODO: other endpoints.
+            Box::pin(table.update().column("a", "a + 1").execute()),
+            Box::pin(table.add(example_data()).execute().map_ok(|_| ())),
+            Box::pin(table.merge_insert(&["test"]).execute(example_data())),
+            Box::pin(table.delete("false")), // TODO: other endpoints.
         ];
 
         for result in results {
@@ -324,4 +480,245 @@ mod tests {
         let count = table.count_rows(Some("a > 10".into())).await.unwrap();
         assert_eq!(count, 42);
     }
+
+    async fn collect_body(body: Body) -> Vec<u8> {
+        use http_body::Body;
+        let mut body = body;
+        let mut data = Vec::new();
+        let mut body_pin = Pin::new(&mut body);
+        futures::stream::poll_fn(|cx| body_pin.as_mut().poll_frame(cx))
+            .for_each(|frame| {
+                data.extend_from_slice(frame.unwrap().data_ref().unwrap());
+                futures::future::ready(())
+            })
+            .await;
+        data
+    }
+
+    fn write_ipc_stream(data: &RecordBatch) -> Vec<u8> {
+        let mut body = Vec::new();
+        {
+            let mut writer = arrow_ipc::writer::StreamWriter::try_new(&mut body, &data.schema())
+                .expect("Failed to create writer");
+            writer.write(data).expect("Failed to write data");
+            writer.finish().expect("Failed to finish");
+        }
+        body
+    }
+
+    #[tokio::test]
+    async fn test_add_append() {
+        let data = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
+            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+        )
+        .unwrap();
+
+        let (sender, receiver) = std::sync::mpsc::channel();
+        let table = Table::new_with_handler("my_table", move |mut request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/table/my_table/insert/");
+            // If mode is specified, it should be "append". Append is default
+            // so it's not required.
+            assert!(request
+                .url()
+                .query_pairs()
+                .filter(|(k, _)| k == "mode")
+                .all(|(_, v)| v == "append"));
+
+            assert_eq!(
+                request.headers().get("Content-Type").unwrap(),
+                ARROW_STREAM_CONTENT_TYPE
+            );
+
+            let mut body_out = reqwest::Body::from(Vec::new());
+            std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out);
+            sender.send(body_out).unwrap();
+
+            http::Response::builder().status(200).body("").unwrap()
+        });
+
+        table
+            .add(RecordBatchIterator::new([Ok(data.clone())], data.schema()))
+            .execute()
+            .await
+            .unwrap();
+
+        let body = receiver.recv().unwrap();
+        let body = collect_body(body).await;
+        let expected_body = write_ipc_stream(&data);
+        assert_eq!(&body, &expected_body);
+    }
+
+    #[tokio::test]
+    async fn test_add_overwrite() {
+        let data = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
+            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+        )
+        .unwrap();
+
+        let (sender, receiver) = std::sync::mpsc::channel();
+        let table = Table::new_with_handler("my_table", move |mut request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/table/my_table/insert/");
+            assert_eq!(
+                request
+                    .url()
+                    .query_pairs()
+                    .find(|(k, _)| k == "mode")
+                    .map(|kv| kv.1)
+                    .as_deref(),
+                Some("overwrite"),
+                "Expected mode=overwrite"
+            );
+
+            assert_eq!(
+                request.headers().get("Content-Type").unwrap(),
+                ARROW_STREAM_CONTENT_TYPE
+            );
+
+            let mut body_out = reqwest::Body::from(Vec::new());
+            std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out);
+            sender.send(body_out).unwrap();
+
+            http::Response::builder().status(200).body("").unwrap()
+        });
+
+        table
+            .add(RecordBatchIterator::new([Ok(data.clone())], data.schema()))
+            .mode(AddDataMode::Overwrite)
+            .execute()
+            .await
+            .unwrap();
+
+        let body = receiver.recv().unwrap();
+        let body = collect_body(body).await;
+        let expected_body = write_ipc_stream(&data);
+        assert_eq!(&body, &expected_body);
+    }
+
+    #[tokio::test]
+    async fn test_update() {
+        let table = Table::new_with_handler("my_table", |request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/table/my_table/update/");
+
+            if let Some(body) = request.body().unwrap().as_bytes() {
+                let body = std::str::from_utf8(body).unwrap();
+                let value: serde_json::Value = serde_json::from_str(body).unwrap();
+                let updates = value.get("updates").unwrap().as_array().unwrap();
+                assert!(updates.len() == 2);
+
+                let col_name = updates[0].as_str().unwrap();
+                let expression = updates[1].as_str().unwrap();
+                assert_eq!(col_name, "a");
+                assert_eq!(expression, "a + 1");
+
+                let only_if = value.get("only_if").unwrap().as_str().unwrap();
+                assert_eq!(only_if, "b > 10");
+            }
+
+            http::Response::builder().status(200).body("").unwrap()
+        });
+
+        table
+            .update()
+            .column("a", "a + 1")
+            .only_if("b > 10")
+            .execute()
+            .await
+            .unwrap();
+    }
+
+    #[tokio::test]
+    async fn test_merge_insert() {
+        let batch = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
+            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+        )
+        .unwrap();
+        let data = Box::new(RecordBatchIterator::new(
+            [Ok(batch.clone())],
+            batch.schema(),
+        ));
+
+        // Default parameters
+        let table = Table::new_with_handler("my_table", |request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/table/my_table/merge_insert/");
+
+            let params = request.url().query_pairs().collect::<HashMap<_, _>>();
+            assert_eq!(params["on"], "some_col");
+            assert_eq!(params["when_matched_update_all"], "false");
+            assert_eq!(params["when_not_matched_insert_all"], "false");
+            assert_eq!(params["when_not_matched_by_source_delete"], "false");
+            assert!(!params.contains_key("when_matched_update_all_filt"));
+            assert!(!params.contains_key("when_not_matched_by_source_delete_filt"));
+
+            http::Response::builder().status(200).body("").unwrap()
+        });
+
+        table
+            .merge_insert(&["some_col"])
+            .execute(data)
+            .await
+            .unwrap();
+
+        // All parameters specified
+        let (sender, receiver) = std::sync::mpsc::channel();
+        let table = Table::new_with_handler("my_table", move |mut request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/table/my_table/merge_insert/");
+            assert_eq!(
+                request.headers().get("Content-Type").unwrap(),
+                ARROW_STREAM_CONTENT_TYPE
+            );
+
+            let params = request.url().query_pairs().collect::<HashMap<_, _>>();
+            assert_eq!(params["on"], "some_col");
+            assert_eq!(params["when_matched_update_all"], "true");
+            assert_eq!(params["when_not_matched_insert_all"], "false");
+            assert_eq!(params["when_not_matched_by_source_delete"], "true");
+            assert_eq!(params["when_matched_update_all_filt"], "a = 1");
+            assert_eq!(params["when_not_matched_by_source_delete_filt"], "b = 2");
+
+            let mut body_out = reqwest::Body::from(Vec::new());
+            std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out);
+            sender.send(body_out).unwrap();
+
+            http::Response::builder().status(200).body("").unwrap()
+        });
+        let mut builder = table.merge_insert(&["some_col"]);
+        builder
+            .when_matched_update_all(Some("a = 1".into()))
+            .when_not_matched_by_source_delete(Some("b = 2".into()));
+        let data = Box::new(RecordBatchIterator::new(
+            [Ok(batch.clone())],
+            batch.schema(),
+        ));
+        builder.execute(data).await.unwrap();
+
+        let body = receiver.recv().unwrap();
+        let body = collect_body(body).await;
+        let expected_body = write_ipc_stream(&batch);
+        assert_eq!(&body, &expected_body);
+    }
+
+    #[tokio::test]
+    async fn test_delete() {
+        let table = Table::new_with_handler("my_table", |request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/table/my_table/delete/");
+
+            let body = request.body().unwrap().as_bytes().unwrap();
+            let body: serde_json::Value = serde_json::from_slice(body).unwrap();
+            let predicate = body.get("predicate").unwrap().as_str().unwrap();
+            assert_eq!(predicate, "id in (1, 2, 3)");
+
+            http::Response::builder().status(200).body("").unwrap()
+        });
+
+        table.delete("id in (1, 2, 3)").await.unwrap();
+    }
 }
@@ -26,12 +26,12 @@ use super::TableInternal;
 #[derive(Debug, Clone)]
 pub struct MergeInsertBuilder {
     table: Arc<dyn TableInternal>,
-    pub(super) on: Vec<String>,
-    pub(super) when_matched_update_all: bool,
-    pub(super) when_matched_update_all_filt: Option<String>,
-    pub(super) when_not_matched_insert_all: bool,
-    pub(super) when_not_matched_by_source_delete: bool,
-    pub(super) when_not_matched_by_source_delete_filt: Option<String>,
+    pub(crate) on: Vec<String>,
+    pub(crate) when_matched_update_all: bool,
+    pub(crate) when_matched_update_all_filt: Option<String>,
+    pub(crate) when_not_matched_insert_all: bool,
+    pub(crate) when_not_matched_by_source_delete: bool,
+    pub(crate) when_not_matched_by_source_delete_filt: Option<String>,
 }
 
 impl MergeInsertBuilder {