From 8877eb020db7c7c0f772cacea268de249d7bfcfd Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 27 Feb 2025 15:55:59 +0800 Subject: [PATCH] feat: record the server version for remote table (#2147) Signed-off-by: BubbleCal --- Cargo.lock | 38 ++++++ Cargo.toml | 1 + python/python/tests/test_remote_db.py | 24 ++-- rust/lancedb/Cargo.toml | 2 + rust/lancedb/src/remote/db.rs | 89 ++++++++++---- rust/lancedb/src/remote/table.rs | 161 +++++++++++++++++++------- rust/lancedb/src/remote/util.rs | 24 ++++ rust/lancedb/src/table.rs | 21 ++++ 8 files changed, 287 insertions(+), 73 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d6864598..48fc174e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3969,6 +3969,8 @@ dependencies = [ "random_word", "regex", "reqwest", + "rstest", + "semver 1.0.25", "serde", "serde_json", "serde_with", @@ -5938,6 +5940,12 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +[[package]] +name = "relative-path" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" + [[package]] name = "reqwest" version = "0.12.12" @@ -6043,6 +6051,36 @@ dependencies = [ "byteorder", ] +[[package]] +name = "rstest" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a2c585be59b6b5dd66a9d2084aa1d8bd52fbdb806eafdeffb52791147862035" +dependencies = [ + "futures", + "futures-timer", + "rstest_macros", + "rustc_version", +] + +[[package]] +name = "rstest_macros" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "825ea780781b15345a146be27eaefb05085e337e869bff01b4306a4fd4a9ad5a" +dependencies = [ + "cfg-if", + "glob", + "proc-macro-crate", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn 2.0.98", + "unicode-ident", +] + [[package]] name = "rust-stemmers" version = "1.2.0" diff --git a/Cargo.toml b/Cargo.toml index 7b31e3eb..0befcad6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,6 +62,7 @@ num-traits = "0.2" rand = "0.8" regex = "1.10" lazy_static = "1" +semver = "1.0.25" # Temporary pins to work around downstream issues # https://github.com/apache/arrow-rs/commit/2fddf85afcd20110ce783ed5b4cdeb82293da30b diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index 34590259..b46d6880 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -9,6 +9,7 @@ import json import threading from unittest.mock import MagicMock import uuid +from packaging.version import Version import lancedb from lancedb.conftest import MockTextEmbeddingFunction @@ -277,11 +278,12 @@ def test_table_create_indices(): @contextlib.contextmanager -def query_test_table(query_handler): +def query_test_table(query_handler, *, server_version=Version("0.1.0")): def handler(request): if request.path == "/v1/table/test/describe/": request.send_response(200) request.send_header("Content-Type", "application/json") + request.send_header("phalanx-version", str(server_version)) request.end_headers() request.wfile.write(b"{}") elif request.path == "/v1/table/test/query/": @@ -388,17 +390,25 @@ def test_query_sync_maximal(): ) -def test_query_sync_multiple_vectors(): +@pytest.mark.parametrize("server_version", [Version("0.1.0"), Version("0.2.0")]) +def test_query_sync_batch_queries(server_version): def handler(body): # TODO: we will add the ability to get the server version, # so that we can decide how to perform batch quires. vectors = body["vector"] - res = [] - for i, vector in enumerate(vectors): - res.append({"id": 1, "query_index": i}) - return pa.Table.from_pylist(res) + if server_version >= Version( + "0.2.0" + ): # we can handle batch queries in single request since 0.2.0 + assert len(vectors) == 2 + res = [] + for i, vector in enumerate(vectors): + res.append({"id": 1, "query_index": i}) + return pa.Table.from_pylist(res) + else: + assert len(vectors) == 3 # matching dim + return pa.table({"id": [1]}) - with query_test_table(handler) as table: + with query_test_table(handler, server_version=server_version) as table: results = table.search([[1, 2, 3], [4, 5, 6]]).limit(1).to_list() assert len(results) == 2 results.sort(key=lambda x: x["query_index"]) diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index da052ff2..f9a8fab6 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -70,6 +70,7 @@ candle-core = { version = "0.6.0", optional = true } candle-transformers = { version = "0.6.0", optional = true } candle-nn = { version = "0.6.0", optional = true } tokenizers = { version = "0.19.1", optional = true } +semver = { workspace = true } # For a workaround, see workspace Cargo.toml crunchy.workspace = true @@ -87,6 +88,7 @@ aws-config = { version = "1.0" } aws-smithy-runtime = { version = "1.3" } datafusion.workspace = true http-body = "1" # Matching reqwest +rstest = "0.23.0" [features] diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index efa886db..7d4d8da0 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -19,12 +19,41 @@ use crate::database::{ }; use crate::error::Result; use crate::table::BaseTable; +use crate::Error; use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender}; use super::table::RemoteTable; -use super::util::batches_to_ipc_bytes; +use super::util::{batches_to_ipc_bytes, parse_server_version}; use super::ARROW_STREAM_CONTENT_TYPE; +// the versions of the server that we support +// for any new feature that we need to change the SDK behavior, we should bump the server version, +// and add a feature flag as method of `ServerVersion` here. +pub const DEFAULT_SERVER_VERSION: semver::Version = semver::Version::new(0, 1, 0); +#[derive(Debug, Clone)] +pub struct ServerVersion(pub semver::Version); + +impl Default for ServerVersion { + fn default() -> Self { + Self(DEFAULT_SERVER_VERSION.clone()) + } +} + +impl ServerVersion { + pub fn parse(version: &str) -> Result { + let version = Self( + semver::Version::parse(version).map_err(|e| Error::InvalidInput { + message: e.to_string(), + })?, + ); + Ok(version) + } + + pub fn support_multivector(&self) -> bool { + self.0 >= semver::Version::new(0, 2, 0) + } +} + #[derive(Deserialize)] struct ListTablesResponse { tables: Vec, @@ -33,7 +62,7 @@ struct ListTablesResponse { #[derive(Debug)] pub struct RemoteDatabase { client: RestfulLanceDbClient, - table_cache: Cache, + table_cache: Cache>>, } impl RemoteDatabase { @@ -115,13 +144,19 @@ impl Database for RemoteDatabase { } let (request_id, rsp) = self.client.send(req, true).await?; let rsp = self.client.check_response(&request_id, rsp).await?; + let version = parse_server_version(&request_id, &rsp)?; let tables = rsp .json::() .await .err_to_http(request_id)? .tables; for table in &tables { - self.table_cache.insert(table.clone(), ()).await; + let remote_table = Arc::new(RemoteTable::new( + self.client.clone(), + table.clone(), + version.clone(), + )); + self.table_cache.insert(table.clone(), remote_table).await; } Ok(tables) } @@ -187,34 +222,42 @@ impl Database for RemoteDatabase { return Err(crate::Error::InvalidInput { message: body }); } } - - self.client.check_response(&request_id, rsp).await?; - - self.table_cache.insert(request.name.clone(), ()).await; - - Ok(Arc::new(RemoteTable::new( + let rsp = self.client.check_response(&request_id, rsp).await?; + let version = parse_server_version(&request_id, &rsp)?; + let table = Arc::new(RemoteTable::new( self.client.clone(), - request.name, - ))) + request.name.clone(), + version, + )); + self.table_cache + .insert(request.name.clone(), table.clone()) + .await; + + Ok(table) } async fn open_table(&self, request: OpenTableRequest) -> Result> { // We describe the table to confirm it exists before moving on. - if self.table_cache.get(&request.name).await.is_none() { + if let Some(table) = self.table_cache.get(&request.name).await { + Ok(table.clone()) + } else { let req = self .client .post(&format!("/v1/table/{}/describe/", request.name)); - let (request_id, resp) = self.client.send(req, true).await?; - if resp.status() == StatusCode::NOT_FOUND { + let (request_id, rsp) = self.client.send(req, true).await?; + if rsp.status() == StatusCode::NOT_FOUND { return Err(crate::Error::TableNotFound { name: request.name }); } - self.client.check_response(&request_id, resp).await?; + let rsp = self.client.check_response(&request_id, rsp).await?; + let version = parse_server_version(&request_id, &rsp)?; + let table = Arc::new(RemoteTable::new( + self.client.clone(), + request.name.clone(), + version, + )); + self.table_cache.insert(request.name, table.clone()).await; + Ok(table) } - - Ok(Arc::new(RemoteTable::new( - self.client.clone(), - request.name, - ))) } async fn rename_table(&self, current_name: &str, new_name: &str) -> Result<()> { @@ -224,8 +267,10 @@ impl Database for RemoteDatabase { let req = req.json(&serde_json::json!({ "new_table_name": new_name })); let (request_id, resp) = self.client.send(req, false).await?; self.client.check_response(&request_id, resp).await?; - self.table_cache.remove(current_name).await; - self.table_cache.insert(new_name.into(), ()).await; + let table = self.table_cache.remove(current_name).await; + if let Some(table) = table { + self.table_cache.insert(new_name.into(), table).await; + } Ok(()) } diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 88ee3f45..b11dd496 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -10,7 +10,7 @@ use crate::index::IndexStatistics; use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest}; use crate::table::{AddDataMode, AnyQuery, Filter}; use crate::utils::{supported_btree_data_type, supported_vector_data_type}; -use crate::{DistanceType, Error}; +use crate::{DistanceType, Error, Table}; use arrow_array::RecordBatchReader; use arrow_ipc::reader::FileReader; use arrow_schema::{DataType, SchemaRef}; @@ -24,7 +24,7 @@ use http::StatusCode; use lance::arrow::json::{JsonDataType, JsonSchema}; use lance::dataset::scanner::DatasetRecordBatchStream; use lance::dataset::{ColumnAlteration, NewColumnTransform, Version}; -use lance_datafusion::exec::OneShotExec; +use lance_datafusion::exec::{execute_plan, OneShotExec}; use serde::{Deserialize, Serialize}; use tokio::sync::RwLock; @@ -41,6 +41,7 @@ use crate::{ use super::client::RequestResultExt; use super::client::{HttpSend, RestfulLanceDbClient, Sender}; +use super::db::ServerVersion; use super::ARROW_STREAM_CONTENT_TYPE; #[derive(Debug)] @@ -48,15 +49,21 @@ pub struct RemoteTable { #[allow(dead_code)] client: RestfulLanceDbClient, name: String, + server_version: ServerVersion, version: RwLock>, } impl RemoteTable { - pub fn new(client: RestfulLanceDbClient, name: String) -> Self { + pub fn new( + client: RestfulLanceDbClient, + name: String, + server_version: ServerVersion, + ) -> Self { Self { client, name, + server_version, version: RwLock::new(None), } } @@ -212,10 +219,11 @@ impl RemoteTable { } fn apply_vector_query_params( - body: &mut serde_json::Value, + &self, + mut body: serde_json::Value, query: &VectorQueryRequest, - ) -> Result<()> { - Self::apply_query_params(body, &query.base)?; + ) -> Result> { + Self::apply_query_params(&mut body, &query.base)?; // Apply general parameters, before we dispatch based on number of query vectors. body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default()); @@ -256,25 +264,40 @@ impl RemoteTable { } } - match query.query_vector.len() { + let bodies = match query.query_vector.len() { 0 => { // Server takes empty vector, not null or undefined. body["vector"] = serde_json::Value::Array(Vec::new()); + vec![body] } 1 => { body["vector"] = vector_to_json(&query.query_vector[0])?; + vec![body] } _ => { - let vectors = query - .query_vector - .iter() - .map(vector_to_json) - .collect::>>()?; - body["vector"] = serde_json::Value::Array(vectors); + if self.server_version.support_multivector() { + let vectors = query + .query_vector + .iter() + .map(vector_to_json) + .collect::>>()?; + body["vector"] = serde_json::Value::Array(vectors); + vec![body] + } else { + // Server does not support multiple vectors in a single query. + // We need to send multiple requests. + let mut bodies = Vec::with_capacity(query.query_vector.len()); + for vector in &query.query_vector { + let mut body = body.clone(); + body["vector"] = vector_to_json(vector)?; + bodies.push(body); + } + bodies + } } - } + }; - Ok(()) + Ok(bodies) } async fn check_mutable(&self) -> Result<()> { @@ -299,27 +322,34 @@ impl RemoteTable { &self, query: &AnyQuery, _options: QueryExecutionOptions, - ) -> Result>> { + ) -> Result>>> { let request = self.client.post(&format!("/v1/table/{}/query/", self.name)); let version = self.current_version().await; let mut body = serde_json::json!({ "version": version }); - match query { + let requests = match query { AnyQuery::Query(query) => { Self::apply_query_params(&mut body, query)?; // Empty vector can be passed if no vector search is performed. body["vector"] = serde_json::Value::Array(Vec::new()); + vec![request.json(&body)] } AnyQuery::VectorQuery(query) => { - Self::apply_vector_query_params(&mut body, query)?; + let bodies = self.apply_vector_query_params(body, query)?; + bodies + .into_iter() + .map(|body| request.try_clone().unwrap().json(&body)) + .collect() } - } + }; - let request = request.json(&body); - let (request_id, response) = self.client.send(request, true).await?; - let stream = self.read_arrow_stream(&request_id, response).await?; - Ok(stream) + let futures = requests.into_iter().map(|req| async move { + let (request_id, response) = self.client.send(req, true).await?; + self.read_arrow_stream(&request_id, response).await + }); + let streams = futures::future::try_join_all(futures).await?; + Ok(streams) } } @@ -342,7 +372,7 @@ mod test_utils { use crate::remote::client::test_utils::MockSender; impl RemoteTable { - pub fn new_mock(name: String, handler: F) -> Self + pub fn new_mock(name: String, handler: F, version: Option) -> Self where F: Fn(reqwest::Request) -> http::Response + Send + Sync + 'static, T: Into, @@ -351,6 +381,7 @@ mod test_utils { Self { client, name, + server_version: version.map(ServerVersion).unwrap_or_default(), version: RwLock::new(None), } } @@ -491,8 +522,17 @@ impl BaseTable for RemoteTable { query: &AnyQuery, options: QueryExecutionOptions, ) -> Result> { - let stream = self.execute_query(query, options).await?; - Ok(Arc::new(OneShotExec::new(stream))) + let streams = self.execute_query(query, options).await?; + if streams.len() == 1 { + let stream = streams.into_iter().next().unwrap(); + Ok(Arc::new(OneShotExec::new(stream))) + } else { + let stream_execs = streams + .into_iter() + .map(|stream| Arc::new(OneShotExec::new(stream)) as Arc) + .collect(); + Table::multi_vector_plan(stream_execs) + } } async fn query( @@ -500,8 +540,24 @@ impl BaseTable for RemoteTable { query: &AnyQuery, _options: QueryExecutionOptions, ) -> Result { - let stream = self.execute_query(query, _options).await?; - Ok(DatasetRecordBatchStream::new(stream)) + let streams = self.execute_query(query, _options).await?; + + if streams.len() == 1 { + Ok(DatasetRecordBatchStream::new( + streams.into_iter().next().unwrap(), + )) + } else { + let stream_execs = streams + .into_iter() + .map(|stream| Arc::new(OneShotExec::new(stream)) as Arc) + .collect(); + let plan = Table::multi_vector_plan(stream_execs)?; + + Ok(DatasetRecordBatchStream::new(execute_plan( + plan, + Default::default(), + )?)) + } } async fn update(&self, update: UpdateBuilder) -> Result { self.check_mutable().await?; @@ -884,8 +940,10 @@ mod tests { use futures::{future::BoxFuture, StreamExt, TryFutureExt}; use lance_index::scalar::FullTextSearchQuery; use reqwest::Body; + use rstest::rstest; use crate::index::vector::IvfFlatIndexBuilder; + use crate::remote::db::DEFAULT_SERVER_VERSION; use crate::remote::JSON_CONTENT_TYPE; use crate::{ index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType}, @@ -1554,9 +1612,12 @@ mod tests { .unwrap(); } + #[rstest] + #[case(DEFAULT_SERVER_VERSION.clone())] + #[case(semver::Version::new(0, 2, 0))] #[tokio::test] - async fn test_query_multiple_vectors() { - let table = Table::new_with_handler("my_table", |request| { + async fn test_batch_queries(#[case] version: semver::Version) { + let table = Table::new_with_handler_version("my_table", version.clone(), move |request| { assert_eq!(request.method(), "POST"); assert_eq!(request.url().path(), "/v1/table/my_table/query/"); assert_eq!( @@ -1566,20 +1627,32 @@ mod tests { let body: serde_json::Value = serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap(); let query_vectors = body["vector"].as_array().unwrap(); - assert_eq!(query_vectors.len(), 2); - assert_eq!(query_vectors[0].as_array().unwrap().len(), 3); - assert_eq!(query_vectors[1].as_array().unwrap().len(), 3); - let data = RecordBatch::try_new( - Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("query_index", DataType::Int32, false), - ])), - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])), - Arc::new(Int32Array::from(vec![0, 0, 0, 1, 1, 1])), - ], - ) - .unwrap(); + let version = ServerVersion(version.clone()); + let data = if version.support_multivector() { + assert_eq!(query_vectors.len(), 2); + assert_eq!(query_vectors[0].as_array().unwrap().len(), 3); + assert_eq!(query_vectors[1].as_array().unwrap().len(), 3); + RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("query_index", DataType::Int32, false), + ])), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])), + Arc::new(Int32Array::from(vec![0, 0, 0, 1, 1, 1])), + ], + ) + .unwrap() + } else { + // it's single flat vector, so here the length is dim + assert_eq!(query_vectors.len(), 3); + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap() + }; + let response_body = write_ipc_file(&data); http::Response::builder() .status(200) diff --git a/rust/lancedb/src/remote/util.rs b/rust/lancedb/src/remote/util.rs index b9fe824c..4b92ec0c 100644 --- a/rust/lancedb/src/remote/util.rs +++ b/rust/lancedb/src/remote/util.rs @@ -4,9 +4,12 @@ use std::io::Cursor; use arrow_array::RecordBatchReader; +use reqwest::Response; use crate::Result; +use super::db::ServerVersion; + pub fn batches_to_ipc_bytes(batches: impl RecordBatchReader) -> Result> { const WRITE_BUF_SIZE: usize = 4096; let buf = Vec::with_capacity(WRITE_BUF_SIZE); @@ -22,3 +25,24 @@ pub fn batches_to_ipc_bytes(batches: impl RecordBatchReader) -> Result> } Ok(buf.into_inner()) } + +pub fn parse_server_version(req_id: &str, rsp: &Response) -> Result { + let version = rsp + .headers() + .get("phalanx-version") + .map(|v| { + let v = v.to_str().map_err(|e| crate::Error::Http { + source: e.into(), + request_id: req_id.to_string(), + status_code: Some(rsp.status()), + })?; + ServerVersion::parse(v).map_err(|e| crate::Error::Http { + source: e.into(), + request_id: req_id.to_string(), + status_code: Some(rsp.status()), + }) + }) + .transpose()? + .unwrap_or_default(); + Ok(version) +} diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 4ba9d6c0..0747df90 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -509,6 +509,27 @@ mod test_utils { let inner = Arc::new(crate::remote::table::RemoteTable::new_mock( name.into(), handler, + None, + )); + Self { + inner, + // Registry is unused. + embedding_registry: Arc::new(MemoryRegistry::new()), + } + } + + pub fn new_with_handler_version( + name: impl Into, + version: semver::Version, + handler: impl Fn(reqwest::Request) -> http::Response + Clone + Send + Sync + 'static, + ) -> Self + where + T: Into, + { + let inner = Arc::new(crate::remote::table::RemoteTable::new_mock( + name.into(), + handler, + Some(version), )); Self { inner,