From f3fc339ef650517674a22995d81f9ae21e7bae5a Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 31 Oct 2024 15:22:09 -0700 Subject: [PATCH] fix(rust): fix delete, update, query in remote SDK (#1782) Fixes several minor issues with Rust remote SDK: * Delete uses `predicate` not `filter` as parameter * Update does not return the row value in remote SDK * Update takes tuples * Content type returned by query node is wrong, so we shouldn't validate it. https://github.com/lancedb/sophon/issues/2742 * Data returned by query endpoint is actually an Arrow IPC file, not IPC stream. --- rust/lancedb/src/remote.rs | 2 + rust/lancedb/src/remote/table.rs | 94 +++++++++++++------------------- 2 files changed, 40 insertions(+), 56 deletions(-) diff --git a/rust/lancedb/src/remote.rs b/rust/lancedb/src/remote.rs index 08b52f3f..7f94ea7d 100644 --- a/rust/lancedb/src/remote.rs +++ b/rust/lancedb/src/remote.rs @@ -23,6 +23,8 @@ pub(crate) mod table; pub(crate) mod util; const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream"; +#[cfg(test)] +const ARROW_FILE_CONTENT_TYPE: &str = "application/vnd.apache.arrow.file"; const JSON_CONTENT_TYPE: &str = "application/json"; pub use client::{ClientConfig, RetryConfig, TimeoutConfig}; diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 81fb7a90..f9900b2c 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -1,3 +1,4 @@ +use std::io::Cursor; use std::sync::{Arc, Mutex}; use crate::index::Index; @@ -7,10 +8,9 @@ use crate::table::AddDataMode; use crate::utils::{supported_btree_data_type, supported_vector_data_type}; use crate::Error; use arrow_array::RecordBatchReader; -use arrow_ipc::reader::StreamReader; +use arrow_ipc::reader::FileReader; use arrow_schema::{DataType, SchemaRef}; use async_trait::async_trait; -use bytes::Buf; use datafusion_common::DataFusionError; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::{ExecutionPlan, SendableRecordBatchStream}; @@ -115,39 +115,14 @@ impl RemoteTable { async fn read_arrow_stream( &self, request_id: &str, - body: reqwest::Response, + response: reqwest::Response, ) -> Result { - // Assert that the content type is correct - let content_type = body - .headers() - .get(CONTENT_TYPE) - .ok_or_else(|| Error::Http { - source: "Missing content type".into(), - request_id: request_id.to_string(), - status_code: None, - })? - .to_str() - .map_err(|e| Error::Http { - source: format!("Failed to parse content type: {}", e).into(), - request_id: request_id.to_string(), - status_code: None, - })?; - if content_type != ARROW_STREAM_CONTENT_TYPE { - return Err(Error::Http { - source: format!( - "Expected content type {}, got {}", - ARROW_STREAM_CONTENT_TYPE, content_type - ) - .into(), - request_id: request_id.to_string(), - status_code: None, - }); - } + let response = self.check_table_response(request_id, response).await?; // There isn't a way to actually stream this data yet. I have an upstream issue: // https://github.com/apache/arrow-rs/issues/6420 - let body = body.bytes().await.err_to_http(request_id.into())?; - let reader = StreamReader::try_new(body.reader(), None)?; + let body = response.bytes().await.err_to_http(request_id.into())?; + let reader = FileReader::try_new(Cursor::new(body), None)?; let schema = reader.schema(); let stream = futures::stream::iter(reader).map_err(DataFusionError::from); Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) @@ -277,7 +252,7 @@ impl TableInternal for RemoteTable { .post(&format!("/v1/table/{}/count_rows/", self.name)); if let Some(filter) = filter { - request = request.json(&serde_json::json!({ "filter": filter })); + request = request.json(&serde_json::json!({ "predicate": filter })); } else { request = request.json(&serde_json::json!({})); } @@ -399,8 +374,7 @@ impl TableInternal for RemoteTable { let mut updates = Vec::new(); for (column, expression) in update.columns { - updates.push(column); - updates.push(expression); + updates.push(vec![column, expression]); } let request = request.json(&serde_json::json!({ @@ -410,19 +384,9 @@ impl TableInternal for RemoteTable { let (request_id, response) = self.client.send(request, false).await?; - let response = self.check_table_response(&request_id, response).await?; + self.check_table_response(&request_id, response).await?; - let body = response.text().await.err_to_http(request_id.clone())?; - - serde_json::from_str(&body).map_err(|e| Error::Http { - source: format!( - "Failed to parse updated rows result from response {}: {}", - body, e - ) - .into(), - request_id, - status_code: None, - }) + Ok(0) // TODO: support returning number of modified rows once supported in SaaS. } async fn delete(&self, predicate: &str) -> Result<()> { let body = serde_json::json!({ "predicate": predicate }); @@ -691,6 +655,7 @@ mod tests { use crate::{ index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType}, query::{ExecutableQuery, QueryBase}, + remote::ARROW_FILE_CONTENT_TYPE, DistanceType, Error, Table, }; @@ -804,7 +769,7 @@ mod tests { ); assert_eq!( request.body().unwrap().as_bytes().unwrap(), - br#"{"filter":"a > 10"}"# + br#"{"predicate":"a > 10"}"# ); http::Response::builder().status(200).body("42").unwrap() @@ -839,6 +804,17 @@ mod tests { body } + fn write_ipc_file(data: &RecordBatch) -> Vec { + let mut body = Vec::new(); + { + let mut writer = arrow_ipc::writer::FileWriter::try_new(&mut body, &data.schema()) + .expect("Failed to create writer"); + writer.write(data).expect("Failed to write data"); + writer.finish().expect("Failed to finish"); + } + body + } + #[tokio::test] async fn test_add_append() { let data = RecordBatch::try_new( @@ -947,21 +923,27 @@ mod tests { let updates = value.get("updates").unwrap().as_array().unwrap(); assert!(updates.len() == 2); - let col_name = updates[0].as_str().unwrap(); - let expression = updates[1].as_str().unwrap(); + let col_name = updates[0][0].as_str().unwrap(); + let expression = updates[0][1].as_str().unwrap(); assert_eq!(col_name, "a"); assert_eq!(expression, "a + 1"); + let col_name = updates[1][0].as_str().unwrap(); + let expression = updates[1][1].as_str().unwrap(); + assert_eq!(col_name, "b"); + assert_eq!(expression, "b - 1"); + let only_if = value.get("only_if").unwrap().as_str().unwrap(); assert_eq!(only_if, "b > 10"); } - http::Response::builder().status(200).body("1").unwrap() + http::Response::builder().status(200).body("{}").unwrap() }); table .update() .column("a", "a + 1") + .column("b", "b - 1") .only_if("b > 10") .execute() .await @@ -1092,10 +1074,10 @@ mod tests { expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into(); assert_eq!(body, expected_body); - let response_body = write_ipc_stream(&expected_data_ref); + let response_body = write_ipc_file(&expected_data_ref); http::Response::builder() .status(200) - .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE) + .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE) .body(response_body) .unwrap() }); @@ -1142,10 +1124,10 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let response_body = write_ipc_stream(&data); + let response_body = write_ipc_file(&data); http::Response::builder() .status(200) - .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE) + .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE) .body(response_body) .unwrap() }); @@ -1193,10 +1175,10 @@ mod tests { vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], ) .unwrap(); - let response_body = write_ipc_stream(&data); + let response_body = write_ipc_file(&data); http::Response::builder() .status(200) - .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE) + .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE) .body(response_body) .unwrap() });