From e61ba7f4e21ed801912e9ea75f07d8c40b29adc2 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 4 Oct 2024 08:43:07 -0700 Subject: [PATCH] fix(rust): remote SDK bugs (#1723) A few bugs uncovered by integration tests: * We didn't prepend `/v1` to the Table endpoint URLs * `/create_index` takes `metric_type` not `distance_type`. (This is also an error in the OpenAPI docs.) * `/create_index` expects the `metric_type` parameter to always be lowercase. * We were writing an IPC file message when we were supposed to send an IPC stream message. --- rust/lancedb/src/remote/table.rs | 67 ++++++++++++++++++-------------- rust/lancedb/src/remote/util.rs | 2 +- 2 files changed, 38 insertions(+), 31 deletions(-) diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 16caad70..4c6182ce 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -50,7 +50,9 @@ impl RemoteTable { } async fn describe(&self) -> Result { - let request = self.client.post(&format!("/table/{}/describe/", self.name)); + let request = self + .client + .post(&format!("/v1/table/{}/describe/", self.name)); let response = self.client.send(request, true).await?; let response = self.check_table_response(response).await?; @@ -249,7 +251,7 @@ impl TableInternal for RemoteTable { async fn count_rows(&self, filter: Option) -> Result { let mut request = self .client - .post(&format!("/table/{}/count_rows/", self.name)); + .post(&format!("/v1/table/{}/count_rows/", self.name)); if let Some(filter) = filter { request = request.json(&serde_json::json!({ "filter": filter })); @@ -275,7 +277,7 @@ impl TableInternal for RemoteTable { let body = Self::reader_as_body(data)?; let mut request = self .client - .post(&format!("/table/{}/insert/", self.name)) + .post(&format!("/v1/table/{}/insert/", self.name)) .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE) .body(body); @@ -298,7 +300,7 @@ impl TableInternal for RemoteTable { query: &VectorQuery, _options: QueryExecutionOptions, ) -> Result> { - let request = self.client.post(&format!("/table/{}/query/", self.name)); + let request = self.client.post(&format!("/v1/table/{}/query/", self.name)); let mut body = serde_json::Value::Object(Default::default()); Self::apply_query_params(&mut body, &query.base)?; @@ -351,7 +353,7 @@ impl TableInternal for RemoteTable { ) -> Result { let request = self .client - .post(&format!("/table/{}/query/", self.name)) + .post(&format!("/v1/table/{}/query/", self.name)) .header(CONTENT_TYPE, JSON_CONTENT_TYPE); let mut body = serde_json::Value::Object(Default::default()); @@ -366,7 +368,9 @@ impl TableInternal for RemoteTable { Ok(DatasetRecordBatchStream::new(stream)) } async fn update(&self, update: UpdateBuilder) -> Result { - let request = self.client.post(&format!("/table/{}/update/", self.name)); + let request = self + .client + .post(&format!("/v1/table/{}/update/", self.name)); let mut updates = Vec::new(); for (column, expression) in update.columns { @@ -396,7 +400,7 @@ impl TableInternal for RemoteTable { let body = serde_json::json!({ "predicate": predicate }); let request = self .client - .post(&format!("/table/{}/delete/", self.name)) + .post(&format!("/v1/table/{}/delete/", self.name)) .json(&body); let response = self.client.send(request, false).await?; self.check_table_response(response).await?; @@ -406,7 +410,7 @@ impl TableInternal for RemoteTable { async fn create_index(&self, mut index: IndexBuilder) -> Result<()> { let request = self .client - .post(&format!("/table/{}/create_index/", self.name)); + .post(&format!("/v1/table/{}/create_index/", self.name)); let column = match index.columns.len() { 0 => { @@ -463,7 +467,9 @@ impl TableInternal for RemoteTable { }; body["index_type"] = serde_json::Value::String(index_type.into()); if let Some(distance_type) = distance_type { - body["distance_type"] = serde_json::Value::String(distance_type.to_string()); + // Phalanx expects this to be lowercase right now. + body["metric_type"] = + serde_json::Value::String(distance_type.to_string().to_lowercase()); } let request = request.json(&body); @@ -484,7 +490,7 @@ impl TableInternal for RemoteTable { let body = Self::reader_as_body(new_data)?; let request = self .client - .post(&format!("/table/{}/merge_insert/", self.name)) + .post(&format!("/v1/table/{}/merge_insert/", self.name)) .query(&query) .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE) .body(body); @@ -525,9 +531,10 @@ impl TableInternal for RemoteTable { }) } async fn index_stats(&self, index_name: &str) -> Result> { - let request = self - .client - .post(&format!("/table/{}/index/{}/stats/", self.name, index_name)); + let request = self.client.post(&format!( + "/v1/table/{}/index/{}/stats/", + self.name, index_name + )); let response = self.client.send(request, true).await?; if response.status() == StatusCode::NOT_FOUND { @@ -651,7 +658,7 @@ mod tests { async fn test_version() { let table = Table::new_with_handler("my_table", |request| { assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/table/my_table/describe/"); + assert_eq!(request.url().path(), "/v1/table/my_table/describe/"); http::Response::builder() .status(200) @@ -667,7 +674,7 @@ mod tests { async fn test_schema() { let table = Table::new_with_handler("my_table", |request| { assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/table/my_table/describe/"); + assert_eq!(request.url().path(), "/v1/table/my_table/describe/"); http::Response::builder() .status(200) @@ -696,7 +703,7 @@ mod tests { async fn test_count_rows() { let table = Table::new_with_handler("my_table", |request| { assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/table/my_table/count_rows/"); + assert_eq!(request.url().path(), "/v1/table/my_table/count_rows/"); assert_eq!( request.headers().get("Content-Type").unwrap(), JSON_CONTENT_TYPE @@ -711,7 +718,7 @@ mod tests { let table = Table::new_with_handler("my_table", |request| { assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/table/my_table/count_rows/"); + assert_eq!(request.url().path(), "/v1/table/my_table/count_rows/"); assert_eq!( request.headers().get("Content-Type").unwrap(), JSON_CONTENT_TYPE @@ -764,7 +771,7 @@ mod tests { let (sender, receiver) = std::sync::mpsc::channel(); let table = Table::new_with_handler("my_table", move |mut request| { assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/table/my_table/insert/"); + assert_eq!(request.url().path(), "/v1/table/my_table/insert/"); // If mode is specified, it should be "append". Append is default // so it's not required. assert!(request @@ -808,7 +815,7 @@ mod tests { let (sender, receiver) = std::sync::mpsc::channel(); let table = Table::new_with_handler("my_table", move |mut request| { assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/table/my_table/insert/"); + assert_eq!(request.url().path(), "/v1/table/my_table/insert/"); assert_eq!( request .url() @@ -849,7 +856,7 @@ mod tests { async fn test_update() { let table = Table::new_with_handler("my_table", |request| { assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/table/my_table/update/"); + assert_eq!(request.url().path(), "/v1/table/my_table/update/"); assert_eq!( request.headers().get("Content-Type").unwrap(), JSON_CONTENT_TYPE @@ -897,7 +904,7 @@ mod tests { // Default parameters let table = Table::new_with_handler("my_table", |request| { assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/table/my_table/merge_insert/"); + assert_eq!(request.url().path(), "/v1/table/my_table/merge_insert/"); let params = request.url().query_pairs().collect::>(); assert_eq!(params["on"], "some_col"); @@ -920,7 +927,7 @@ mod tests { let (sender, receiver) = std::sync::mpsc::channel(); let table = Table::new_with_handler("my_table", move |mut request| { assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/table/my_table/merge_insert/"); + assert_eq!(request.url().path(), "/v1/table/my_table/merge_insert/"); assert_eq!( request.headers().get("Content-Type").unwrap(), ARROW_STREAM_CONTENT_TYPE @@ -960,7 +967,7 @@ mod tests { async fn test_delete() { let table = Table::new_with_handler("my_table", |request| { assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/table/my_table/delete/"); + assert_eq!(request.url().path(), "/v1/table/my_table/delete/"); assert_eq!( request.headers().get("Content-Type").unwrap(), JSON_CONTENT_TYPE @@ -988,7 +995,7 @@ mod tests { let table = Table::new_with_handler("my_table", move |request| { assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/table/my_table/query/"); + assert_eq!(request.url().path(), "/v1/table/my_table/query/"); assert_eq!( request.headers().get("Content-Type").unwrap(), JSON_CONTENT_TYPE @@ -1029,7 +1036,7 @@ mod tests { async fn test_query_vector_all_params() { let table = Table::new_with_handler("my_table", |request| { assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/table/my_table/query/"); + assert_eq!(request.url().path(), "/v1/table/my_table/query/"); assert_eq!( request.headers().get("Content-Type").unwrap(), JSON_CONTENT_TYPE @@ -1085,7 +1092,7 @@ mod tests { async fn test_query_fts() { let table = Table::new_with_handler("my_table", |request| { assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/table/my_table/query/"); + assert_eq!(request.url().path(), "/v1/table/my_table/query/"); assert_eq!( request.headers().get("Content-Type").unwrap(), JSON_CONTENT_TYPE @@ -1151,7 +1158,7 @@ mod tests { for (index_type, distance_type, index) in cases { let table = Table::new_with_handler("my_table", move |request| { assert_eq!(request.method(), "POST"); - assert_eq!(request.url().path(), "/table/my_table/create_index/"); + assert_eq!(request.url().path(), "/v1/table/my_table/create_index/"); assert_eq!( request.headers().get("Content-Type").unwrap(), JSON_CONTENT_TYPE @@ -1163,7 +1170,7 @@ mod tests { "index_type": index_type, }); if let Some(distance_type) = distance_type { - expected_body["distance_type"] = distance_type.into(); + expected_body["metric_type"] = distance_type.to_lowercase().into(); } assert_eq!(body, expected_body); @@ -1180,7 +1187,7 @@ mod tests { assert_eq!(request.method(), "POST"); assert_eq!( request.url().path(), - "/table/my_table/index/my_index/stats/" + "/v1/table/my_table/index/my_index/stats/" ); let response_body = serde_json::json!({ @@ -1210,7 +1217,7 @@ mod tests { assert_eq!(request.method(), "POST"); assert_eq!( request.url().path(), - "/table/my_table/index/my_index/stats/" + "/v1/table/my_table/index/my_index/stats/" ); http::Response::builder().status(404).body("").unwrap() diff --git a/rust/lancedb/src/remote/util.rs b/rust/lancedb/src/remote/util.rs index b594ed6e..d13be70f 100644 --- a/rust/lancedb/src/remote/util.rs +++ b/rust/lancedb/src/remote/util.rs @@ -9,7 +9,7 @@ pub fn batches_to_ipc_bytes(batches: impl RecordBatchReader) -> Result> let buf = Vec::with_capacity(WRITE_BUF_SIZE); let mut buf = Cursor::new(buf); { - let mut writer = arrow_ipc::writer::FileWriter::try_new(&mut buf, &batches.schema())?; + let mut writer = arrow_ipc::writer::StreamWriter::try_new(&mut buf, &batches.schema())?; for batch in batches { let batch = batch?;