fix: remote table doesn't apply the prefilter flag for FTS (#2145)

This commit is contained in:
BubbleCal
2025-02-24 21:37:43 +08:00
committed by GitHub
parent a99a450f2b
commit f391ed828a

View File

@@ -149,6 +149,7 @@ impl<S: HttpSend> RemoteTable<S> {
}
fn apply_query_params(body: &mut serde_json::Value, params: &QueryRequest) -> Result<()> {
body["prefilter"] = params.prefilter.into();
if let Some(offset) = params.offset {
body["offset"] = serde_json::Value::Number(serde_json::Number::from(offset));
}
@@ -211,7 +212,6 @@ impl<S: HttpSend> RemoteTable<S> {
Self::apply_query_params(body, &query.base)?;
// Apply general parameters, before we dispatch based on number of query vectors.
body["prefilter"] = query.base.prefilter.into();
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
body["nprobes"] = query.nprobes.into();
body["lower_bound"] = query.lower_bound.into();
@@ -1340,6 +1340,55 @@ mod tests {
assert_eq!(data[0].as_ref().unwrap(), &expected_data);
}
#[tokio::test]
async fn test_query_fts_default_values() {
let expected_data = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let expected_data_ref = expected_data.clone();
let table = Table::new_with_handler("my_table", move |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/v1/table/my_table/query/");
assert_eq!(
request.headers().get("Content-Type").unwrap(),
JSON_CONTENT_TYPE
);
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let expected_body = serde_json::json!({
"full_text_query": {
"columns": [],
"query": "test",
},
"prefilter": true,
"version": null,
"k": 10,
"vector": [],
});
assert_eq!(body, expected_body);
let response_body = write_ipc_file(&expected_data_ref);
http::Response::builder()
.status(200)
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
.body(response_body)
.unwrap()
});
let data = table
.query()
.full_text_search(FullTextSearchQuery::new("test".to_owned()))
.execute()
.await;
let data = data.unwrap().collect::<Vec<_>>().await;
assert_eq!(data.len(), 1);
assert_eq!(data[0].as_ref().unwrap(), &expected_data);
}
#[tokio::test]
async fn test_query_vector_all_params() {
let table = Table::new_with_handler("my_table", |request| {
@@ -1422,6 +1471,7 @@ mod tests {
"k": 10,
"vector": [],
"with_row_id": true,
"prefilter": true,
"version": null
});
assert_eq!(body, expected_body);