From d32360b99d6d4fc2f1ccf1cdfb2fa25836ccaf7a Mon Sep 17 00:00:00 2001 From: Bert Date: Tue, 26 Nov 2024 11:38:36 -0500 Subject: [PATCH] feat: support overwrite and exist_ok mode for remote create_table (#1883) Support passing modes "overwrite" and "exist_ok" when creating a remote table. --- rust/lancedb/src/connection.rs | 4 +- rust/lancedb/src/remote/db.rs | 109 ++++++++++++++++++++++++++++++++- 2 files changed, 109 insertions(+), 4 deletions(-) diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 40329b66..8946beee 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -133,7 +133,7 @@ impl IntoArrow for NoData { /// A builder for configuring a [`Connection::create_table`] operation pub struct CreateTableBuilder { - parent: Arc, + pub(crate) parent: Arc, pub(crate) name: String, pub(crate) data: Option, pub(crate) mode: CreateTableMode, @@ -341,7 +341,7 @@ pub struct OpenTableBuilder { } impl OpenTableBuilder { - fn new(parent: Arc, name: String) -> Self { + pub(crate) fn new(parent: Arc, name: String) -> Self { Self { parent, name, diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index 05b5dfe2..fc10fbdb 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -23,7 +23,8 @@ use serde::Deserialize; use tokio::task::spawn_blocking; use crate::connection::{ - ConnectionInternal, CreateTableBuilder, NoData, OpenTableBuilder, TableNamesBuilder, + ConnectionInternal, CreateTableBuilder, CreateTableMode, NoData, OpenTableBuilder, + TableNamesBuilder, }; use crate::embeddings::EmbeddingRegistry; use crate::error::Result; @@ -95,6 +96,16 @@ impl std::fmt::Display for RemoteDatabase { } } +impl From<&CreateTableMode> for &'static str { + fn from(val: &CreateTableMode) -> Self { + match val { + CreateTableMode::Create => "create", + CreateTableMode::Overwrite => "overwrite", + CreateTableMode::ExistOk(_) => "exist_ok", + } + } +} + #[async_trait] impl 
ConnectionInternal for RemoteDatabase { async fn table_names(&self, options: TableNamesBuilder) -> Result> { @@ -133,14 +144,40 @@ impl ConnectionInternal for RemoteDatabase { let req = self .client .post(&format!("/v1/table/{}/create/", options.name)) + .query(&[("mode", Into::<&str>::into(&options.mode))]) .body(data_buffer) .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE); + let (request_id, rsp) = self.client.send(req, false).await?; if rsp.status() == StatusCode::BAD_REQUEST { let body = rsp.text().await.err_to_http(request_id.clone())?; if body.contains("already exists") { - return Err(crate::Error::TableAlreadyExists { name: options.name }); + return match options.mode { + CreateTableMode::Create => { + Err(crate::Error::TableAlreadyExists { name: options.name }) + } + CreateTableMode::ExistOk(callback) => { + let builder = OpenTableBuilder::new(options.parent, options.name); + let builder = (callback)(builder); + builder.execute().await + } + + // This should not happen, as we explicitly set the mode to overwrite and the server + // shouldn't return an error if the table already exists. + // + // However if the server is an older version that doesn't support the mode parameter, + // then we'll get the 400 response. 
+ CreateTableMode::Overwrite => Err(crate::Error::Http { + source: format!( + "unexpected response from server for create mode overwrite: {}", + body + ) + .into(), + request_id, + status_code: Some(StatusCode::BAD_REQUEST), + }), + }; } else { return Err(crate::Error::InvalidInput { message: body }); } @@ -214,6 +251,7 @@ mod tests { use arrow_schema::{DataType, Field, Schema}; use crate::{ + connection::CreateTableMode, remote::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE}, Connection, Error, }; @@ -382,6 +420,73 @@ mod tests { ); } + #[tokio::test] + async fn test_create_table_modes() { + let test_cases = [ + (None, "mode=create"), + (Some(CreateTableMode::Create), "mode=create"), + (Some(CreateTableMode::Overwrite), "mode=overwrite"), + ( + Some(CreateTableMode::ExistOk(Box::new(|b| b))), + "mode=exist_ok", + ), + ]; + + for (mode, expected_query_string) in test_cases { + let conn = Connection::new_with_handler(move |request| { + assert_eq!(request.method(), &reqwest::Method::POST); + assert_eq!(request.url().path(), "/v1/table/table1/create/"); + assert_eq!(request.url().query(), Some(expected_query_string)); + + http::Response::builder().status(200).body("").unwrap() + }); + + let data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema()); + let mut builder = conn.create_table("table1", reader); + if let Some(mode) = mode { + builder = builder.mode(mode); + } + builder.execute().await.unwrap(); + } + + // check that the open table callback is called with exist_ok + let conn = Connection::new_with_handler(|request| match request.url().path() { + "/v1/table/table1/create/" => http::Response::builder() + .status(400) + .body("Table table1 already exists") + .unwrap(), + "/v1/table/table1/describe/" => http::Response::builder().status(200).body("").unwrap(), + _ => { + 
panic!("unexpected path: {:?}", request.url().path()); + } + }); + let data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + + let called: Arc<OnceLock<bool>> = Arc::new(OnceLock::new()); + let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema()); + let called_in_cb = called.clone(); + conn.create_table("table1", reader) + .mode(CreateTableMode::ExistOk(Box::new(move |b| { + called_in_cb.clone().set(true).unwrap(); + b + }))) + .execute() + .await + .unwrap(); + + let called = *called.get().unwrap_or(&false); + assert!(called); + } + #[tokio::test] async fn test_create_table_empty() { let conn = Connection::new_with_handler(|request| {