fix(rust): fix delete, update, query in remote SDK (#1782)
Fixes several minor issues in the Rust remote SDK:

* Delete uses `predicate`, not `filter`, as the request parameter.
* Update does not return the number of modified rows in the remote SDK.
* Update takes a list of `[column, expression]` pairs rather than a flat list.
* The content type returned by the query node is wrong, so we shouldn't validate it. https://github.com/lancedb/sophon/issues/2742
* Data returned by the query endpoint is actually an Arrow IPC file, not an IPC stream.
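The last two bullets are the crux of the fix: the query endpoint encodes results in the Arrow IPC *file* format, while the client was decoding them with `StreamReader`, which only understands the IPC *stream* format. A minimal sketch (not part of this commit; the schema and values are invented) of why the two encodings are not interchangeable:

```rust
use std::io::Cursor;
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch};
use arrow_ipc::reader::{FileReader, StreamReader};
use arrow_ipc::writer::FileWriter;
use arrow_schema::{DataType, Field, Schema};

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )
    .unwrap();

    // Encode with the IPC *file* format, as the query endpoint does.
    let mut buf = Vec::new();
    {
        let mut writer = FileWriter::try_new(&mut buf, &schema).expect("create writer");
        writer.write(&batch).expect("write batch");
        writer.finish().expect("finish file");
    }

    // Decoding those bytes as an IPC *stream* fails: the file format opens with
    // the ARROW1 magic bytes rather than an encapsulated schema message.
    assert!(StreamReader::try_new(Cursor::new(buf.clone()), None).is_err());

    // Decoding as an IPC file round-trips the batch.
    let reader = FileReader::try_new(Cursor::new(buf), None).expect("open file reader");
    let batches: Vec<RecordBatch> = reader.collect::<Result<_, _>>().expect("read batches");
    assert_eq!(batches, vec![batch]);
}
```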
@@ -23,6 +23,8 @@ pub(crate) mod table;
 pub(crate) mod util;

 const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
+#[cfg(test)]
+const ARROW_FILE_CONTENT_TYPE: &str = "application/vnd.apache.arrow.file";
 const JSON_CONTENT_TYPE: &str = "application/json";

 pub use client::{ClientConfig, RetryConfig, TimeoutConfig};
@@ -1,3 +1,4 @@
+use std::io::Cursor;
 use std::sync::{Arc, Mutex};

 use crate::index::Index;
@@ -7,10 +8,9 @@ use crate::table::AddDataMode;
 use crate::utils::{supported_btree_data_type, supported_vector_data_type};
 use crate::Error;
 use arrow_array::RecordBatchReader;
-use arrow_ipc::reader::StreamReader;
+use arrow_ipc::reader::FileReader;
 use arrow_schema::{DataType, SchemaRef};
 use async_trait::async_trait;
-use bytes::Buf;
 use datafusion_common::DataFusionError;
 use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion_physical_plan::{ExecutionPlan, SendableRecordBatchStream};
@@ -115,39 +115,14 @@ impl<S: HttpSend> RemoteTable<S> {
     async fn read_arrow_stream(
         &self,
         request_id: &str,
-        body: reqwest::Response,
+        response: reqwest::Response,
     ) -> Result<SendableRecordBatchStream> {
-        // Assert that the content type is correct
-        let content_type = body
-            .headers()
-            .get(CONTENT_TYPE)
-            .ok_or_else(|| Error::Http {
-                source: "Missing content type".into(),
-                request_id: request_id.to_string(),
-                status_code: None,
-            })?
-            .to_str()
-            .map_err(|e| Error::Http {
-                source: format!("Failed to parse content type: {}", e).into(),
-                request_id: request_id.to_string(),
-                status_code: None,
-            })?;
-        if content_type != ARROW_STREAM_CONTENT_TYPE {
-            return Err(Error::Http {
-                source: format!(
-                    "Expected content type {}, got {}",
-                    ARROW_STREAM_CONTENT_TYPE, content_type
-                )
-                .into(),
-                request_id: request_id.to_string(),
-                status_code: None,
-            });
-        }
+        let response = self.check_table_response(request_id, response).await?;

         // There isn't a way to actually stream this data yet. I have an upstream issue:
         // https://github.com/apache/arrow-rs/issues/6420
-        let body = body.bytes().await.err_to_http(request_id.into())?;
-        let reader = StreamReader::try_new(body.reader(), None)?;
+        let body = response.bytes().await.err_to_http(request_id.into())?;
+        let reader = FileReader::try_new(Cursor::new(body), None)?;
         let schema = reader.schema();
         let stream = futures::stream::iter(reader).map_err(DataFusionError::from);
         Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
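As the comment in the hunk above notes, the whole response body is still buffered before decoding; https://github.com/apache/arrow-rs/issues/6420 tracks true streaming. For context, a hypothetical caller-side sketch (the `drain` helper is not part of this commit) of consuming the returned `SendableRecordBatchStream`:

```rust
use arrow_array::RecordBatch;
use datafusion_common::DataFusionError;
use datafusion_physical_plan::SendableRecordBatchStream;
use futures::TryStreamExt;

// Drain the stream into memory; any decode error surfaces as a DataFusionError.
async fn drain(stream: SendableRecordBatchStream) -> Result<Vec<RecordBatch>, DataFusionError> {
    stream.try_collect().await
}
```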
@@ -277,7 +252,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
             .post(&format!("/v1/table/{}/count_rows/", self.name));

         if let Some(filter) = filter {
-            request = request.json(&serde_json::json!({ "filter": filter }));
+            request = request.json(&serde_json::json!({ "predicate": filter }));
         } else {
             request = request.json(&serde_json::json!({}));
         }
@@ -399,8 +374,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {

         let mut updates = Vec::new();
         for (column, expression) in update.columns {
-            updates.push(column);
-            updates.push(expression);
+            updates.push(vec![column, expression]);
         }

         let request = request.json(&serde_json::json!({
@@ -410,19 +384,9 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {

         let (request_id, response) = self.client.send(request, false).await?;

-        let response = self.check_table_response(&request_id, response).await?;
+        self.check_table_response(&request_id, response).await?;

-        let body = response.text().await.err_to_http(request_id.clone())?;
-
-        serde_json::from_str(&body).map_err(|e| Error::Http {
-            source: format!(
-                "Failed to parse updated rows result from response {}: {}",
-                body, e
-            )
-            .into(),
-            request_id,
-            status_code: None,
-        })
+        Ok(0) // TODO: support returning number of modified rows once supported in SaaS.
     }

     async fn delete(&self, predicate: &str) -> Result<()> {
         let body = serde_json::json!({ "predicate": predicate });
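With the tuple change, the update request body is shaped as below. This is a sketch using `serde_json` to mirror the payload that the test further down asserts on; the surrounding request plumbing and endpoint path are omitted:

```rust
use serde_json::json;

fn main() {
    // Each entry in "updates" is a [column, expression] pair, not two flat
    // list items as before.
    let body = json!({
        "updates": [["a", "a + 1"], ["b", "b - 1"]],
        "only_if": "b > 10",
    });
    assert_eq!(body["updates"][1][0], "b");
    assert_eq!(body["updates"][1][1], "b - 1");
    println!("{body}");
}
```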
@@ -691,6 +655,7 @@ mod tests {
     use crate::{
         index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
         query::{ExecutableQuery, QueryBase},
+        remote::ARROW_FILE_CONTENT_TYPE,
         DistanceType, Error, Table,
     };

@@ -804,7 +769,7 @@ mod tests {
             );
             assert_eq!(
                 request.body().unwrap().as_bytes().unwrap(),
-                br#"{"filter":"a > 10"}"#
+                br#"{"predicate":"a > 10"}"#
             );

             http::Response::builder().status(200).body("42").unwrap()
@@ -839,6 +804,17 @@ mod tests {
         body
     }

+    fn write_ipc_file(data: &RecordBatch) -> Vec<u8> {
+        let mut body = Vec::new();
+        {
+            let mut writer = arrow_ipc::writer::FileWriter::try_new(&mut body, &data.schema())
+                .expect("Failed to create writer");
+            writer.write(data).expect("Failed to write data");
+            writer.finish().expect("Failed to finish");
+        }
+        body
+    }
+
     #[tokio::test]
     async fn test_add_append() {
         let data = RecordBatch::try_new(
@@ -947,21 +923,27 @@ mod tests {
                 let updates = value.get("updates").unwrap().as_array().unwrap();
                 assert!(updates.len() == 2);

-                let col_name = updates[0].as_str().unwrap();
-                let expression = updates[1].as_str().unwrap();
+                let col_name = updates[0][0].as_str().unwrap();
+                let expression = updates[0][1].as_str().unwrap();
                 assert_eq!(col_name, "a");
                 assert_eq!(expression, "a + 1");

+                let col_name = updates[1][0].as_str().unwrap();
+                let expression = updates[1][1].as_str().unwrap();
+                assert_eq!(col_name, "b");
+                assert_eq!(expression, "b - 1");
+
                 let only_if = value.get("only_if").unwrap().as_str().unwrap();
                 assert_eq!(only_if, "b > 10");
             }

-            http::Response::builder().status(200).body("1").unwrap()
+            http::Response::builder().status(200).body("{}").unwrap()
         });

         table
             .update()
             .column("a", "a + 1")
+            .column("b", "b - 1")
             .only_if("b > 10")
             .execute()
             .await
@@ -1092,10 +1074,10 @@ mod tests {
             expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
             assert_eq!(body, expected_body);

-            let response_body = write_ipc_stream(&expected_data_ref);
+            let response_body = write_ipc_file(&expected_data_ref);
             http::Response::builder()
                 .status(200)
-                .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
+                .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
                 .body(response_body)
                 .unwrap()
         });
@@ -1142,10 +1124,10 @@ mod tests {
                 vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
             )
             .unwrap();
-            let response_body = write_ipc_stream(&data);
+            let response_body = write_ipc_file(&data);
             http::Response::builder()
                 .status(200)
-                .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
+                .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
                 .body(response_body)
                 .unwrap()
         });
@@ -1193,10 +1175,10 @@ mod tests {
                 vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
             )
             .unwrap();
-            let response_body = write_ipc_stream(&data);
+            let response_body = write_ipc_file(&data);
             http::Response::builder()
                 .status(200)
-                .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
+                .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
                 .body(response_body)
                 .unwrap()
         });