fix(rust): fix delete, update, query in remote SDK (#1782)

Fixes several minor issues in the Rust remote SDK:

* Delete uses `predicate`, not `filter`, as the request parameter.
* Update does not return the modified row count in the remote SDK.
* Update takes `[column, expression]` pairs, not a flat list (see the payload
  sketch below).
* The content type returned by the query node is wrong, so we shouldn't
  validate it. https://github.com/lancedb/sophon/issues/2742
* Data returned by the query endpoint is actually an Arrow IPC file, not an
  IPC stream (see the reader sketch after the commit metadata).
Will Jones, 2024-10-31 15:22:09 -07:00 (committed by GitHub)
parent 113cd6995b, commit f3fc339ef6
2 changed files with 40 additions and 56 deletions
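
The core of the fix is swapping arrow-ipc's StreamReader for FileReader: the IPC stream format is read front-to-back, while the IPC file format ends with a footer and needs a seekable reader. Below is a minimal sketch of the new read path, assuming the arrow-array/arrow-ipc crates; read_ipc_file is a hypothetical standalone helper for illustration, not the SDK's API.

    use std::io::Cursor;

    use arrow_array::RecordBatch;
    use arrow_ipc::reader::FileReader;
    use arrow_schema::ArrowError;

    // Hypothetical helper mirroring what read_arrow_stream now does with the
    // buffered response body: the file format ends in a footer, so the reader
    // needs Seek, which Cursor provides over the in-memory bytes.
    fn read_ipc_file(body: Vec<u8>) -> Result<Vec<RecordBatch>, ArrowError> {
        let reader = FileReader::try_new(Cursor::new(body), None)?; // None = no column projection
        reader.collect() // each item is a Result<RecordBatch, ArrowError>
    }

Feeding the same bytes to StreamReader::try_new, as the old code did, fails: file-format output begins with the ARROW1 magic bytes rather than an encapsulated stream message.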


@@ -23,6 +23,8 @@ pub(crate) mod table;
 pub(crate) mod util;

 const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
+#[cfg(test)]
+const ARROW_FILE_CONTENT_TYPE: &str = "application/vnd.apache.arrow.file";
 const JSON_CONTENT_TYPE: &str = "application/json";

 pub use client::{ClientConfig, RetryConfig, TimeoutConfig};


@@ -1,3 +1,4 @@
+use std::io::Cursor;
 use std::sync::{Arc, Mutex};

 use crate::index::Index;
@@ -7,10 +8,9 @@ use crate::table::AddDataMode;
 use crate::utils::{supported_btree_data_type, supported_vector_data_type};
 use crate::Error;
 use arrow_array::RecordBatchReader;
-use arrow_ipc::reader::StreamReader;
+use arrow_ipc::reader::FileReader;
 use arrow_schema::{DataType, SchemaRef};
 use async_trait::async_trait;
-use bytes::Buf;
 use datafusion_common::DataFusionError;
 use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion_physical_plan::{ExecutionPlan, SendableRecordBatchStream};
@@ -115,39 +115,14 @@ impl<S: HttpSend> RemoteTable<S> {
     async fn read_arrow_stream(
         &self,
         request_id: &str,
-        body: reqwest::Response,
+        response: reqwest::Response,
     ) -> Result<SendableRecordBatchStream> {
-        // Assert that the content type is correct
-        let content_type = body
-            .headers()
-            .get(CONTENT_TYPE)
-            .ok_or_else(|| Error::Http {
-                source: "Missing content type".into(),
-                request_id: request_id.to_string(),
-                status_code: None,
-            })?
-            .to_str()
-            .map_err(|e| Error::Http {
-                source: format!("Failed to parse content type: {}", e).into(),
-                request_id: request_id.to_string(),
-                status_code: None,
-            })?;
-        if content_type != ARROW_STREAM_CONTENT_TYPE {
-            return Err(Error::Http {
-                source: format!(
-                    "Expected content type {}, got {}",
-                    ARROW_STREAM_CONTENT_TYPE, content_type
-                )
-                .into(),
-                request_id: request_id.to_string(),
-                status_code: None,
-            });
-        }
+        let response = self.check_table_response(request_id, response).await?;
         // There isn't a way to actually stream this data yet. I have an upstream issue:
         // https://github.com/apache/arrow-rs/issues/6420
-        let body = body.bytes().await.err_to_http(request_id.into())?;
-        let reader = StreamReader::try_new(body.reader(), None)?;
+        let body = response.bytes().await.err_to_http(request_id.into())?;
+        let reader = FileReader::try_new(Cursor::new(body), None)?;
         let schema = reader.schema();
         let stream = futures::stream::iter(reader).map_err(DataFusionError::from);
         Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
     }
@@ -277,7 +252,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
             .post(&format!("/v1/table/{}/count_rows/", self.name));

         if let Some(filter) = filter {
-            request = request.json(&serde_json::json!({ "filter": filter }));
+            request = request.json(&serde_json::json!({ "predicate": filter }));
         } else {
             request = request.json(&serde_json::json!({}));
         }
@@ -399,8 +374,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
         let mut updates = Vec::new();
         for (column, expression) in update.columns {
-            updates.push(column);
-            updates.push(expression);
+            updates.push(vec![column, expression]);
        }

        let request = request.json(&serde_json::json!({
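
For illustration, the wire format this loop now produces nests each column/expression pair, matching the assertions in the updated test further down. A sketch under that assumption; build_update_body is a hypothetical helper (the real code builds the JSON inline):

    use serde_json::{json, Value};

    // Hypothetical helper: the endpoint expects `updates` to be a list of
    // two-element [column, expression] lists, not a flat list of strings.
    fn build_update_body(columns: Vec<(String, String)>, only_if: &str) -> Value {
        let updates: Vec<Vec<String>> = columns
            .into_iter()
            .map(|(column, expression)| vec![column, expression])
            .collect();
        json!({ "updates": updates, "only_if": only_if })
    }

For the test's input this yields {"updates":[["a","a + 1"],["b","b - 1"]],"only_if":"b > 10"} rather than the old flat ["a","a + 1"].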
@@ -410,19 +384,9 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
         let (request_id, response) = self.client.send(request, false).await?;

-        let response = self.check_table_response(&request_id, response).await?;
+        self.check_table_response(&request_id, response).await?;

-        let body = response.text().await.err_to_http(request_id.clone())?;
-        serde_json::from_str(&body).map_err(|e| Error::Http {
-            source: format!(
-                "Failed to parse updated rows result from response {}: {}",
-                body, e
-            )
-            .into(),
-            request_id,
-            status_code: None,
-        })
+        Ok(0) // TODO: support returning number of modified rows once supported in SaaS.
     }

     async fn delete(&self, predicate: &str) -> Result<()> {
         let body = serde_json::json!({ "predicate": predicate });
@@ -691,6 +655,7 @@ mod tests {
     use crate::{
         index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
         query::{ExecutableQuery, QueryBase},
+        remote::ARROW_FILE_CONTENT_TYPE,
         DistanceType, Error, Table,
     };
@@ -804,7 +769,7 @@ mod tests {
             );
             assert_eq!(
                 request.body().unwrap().as_bytes().unwrap(),
-                br#"{"filter":"a > 10"}"#
+                br#"{"predicate":"a > 10"}"#
             );

             http::Response::builder().status(200).body("42").unwrap()
@@ -839,6 +804,17 @@ mod tests {
         body
     }

+    fn write_ipc_file(data: &RecordBatch) -> Vec<u8> {
+        let mut body = Vec::new();
+        {
+            let mut writer = arrow_ipc::writer::FileWriter::try_new(&mut body, &data.schema())
+                .expect("Failed to create writer");
+            writer.write(data).expect("Failed to write data");
+            writer.finish().expect("Failed to finish");
+        }
+        body
+    }
+
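Side note: the only difference from the existing write_ipc_stream helper is the writer type. FileWriter::finish writes the footer (schema plus record-batch block offsets, bracketed by the ARROW1 magic bytes) that the IPC file format requires and that FileReader uses to locate batches; the stream format has no footer, which is why the two formats are not interchangeable.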
     #[tokio::test]
     async fn test_add_append() {
         let data = RecordBatch::try_new(
@@ -947,21 +923,27 @@ mod tests {
             let updates = value.get("updates").unwrap().as_array().unwrap();
             assert!(updates.len() == 2);

-            let col_name = updates[0].as_str().unwrap();
-            let expression = updates[1].as_str().unwrap();
+            let col_name = updates[0][0].as_str().unwrap();
+            let expression = updates[0][1].as_str().unwrap();
             assert_eq!(col_name, "a");
             assert_eq!(expression, "a + 1");

+            let col_name = updates[1][0].as_str().unwrap();
+            let expression = updates[1][1].as_str().unwrap();
+            assert_eq!(col_name, "b");
+            assert_eq!(expression, "b - 1");
+
             let only_if = value.get("only_if").unwrap().as_str().unwrap();
             assert_eq!(only_if, "b > 10");
         }

-        http::Response::builder().status(200).body("1").unwrap()
+        http::Response::builder().status(200).body("{}").unwrap()
     });

     table
         .update()
         .column("a", "a + 1")
+        .column("b", "b - 1")
         .only_if("b > 10")
         .execute()
         .await
@@ -1092,10 +1074,10 @@ mod tests {
             expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
             assert_eq!(body, expected_body);

-            let response_body = write_ipc_stream(&expected_data_ref);
+            let response_body = write_ipc_file(&expected_data_ref);
             http::Response::builder()
                 .status(200)
-                .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
+                .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
                 .body(response_body)
                 .unwrap()
         });
@@ -1142,10 +1124,10 @@ mod tests {
             vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
         )
         .unwrap();
-        let response_body = write_ipc_stream(&data);
+        let response_body = write_ipc_file(&data);
         http::Response::builder()
             .status(200)
-            .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
+            .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
             .body(response_body)
             .unwrap()
     });
@@ -1193,10 +1175,10 @@ mod tests {
             vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
         )
         .unwrap();
-        let response_body = write_ipc_stream(&data);
+        let response_body = write_ipc_file(&data);
         http::Response::builder()
             .status(200)
-            .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
+            .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
             .body(response_body)
             .unwrap()
     });