Mirror of https://github.com/lancedb/lancedb.git
feat: record the server version for remote table (#2147)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
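In brief: the remote SDK now records the version reported by the server (read from the `phalanx-version` response header when tables are listed, created, or opened, defaulting to 0.1.0 when the header is absent) on every `RemoteTable`, and uses it to gate behavior. Servers at or above 0.2.0 accept all query vectors in one request; older servers are sent one request per vector and the results are merged. The sketch below is an illustrative, standalone rendering of that version gate using the `semver` crate; it mirrors the `ServerVersion` type and `support_multivector` check added in this commit but is not LanceDB's public API (`from_header` is a hypothetical simplification of the crate's `parse` / `parse_server_version` pair, and error handling is elided). A second sketch of the per-vector request fan-out for older servers follows the diff.

// Illustrative sketch only; assumes `semver = "1"` as a dependency.
use semver::Version;

// Lowest server version the SDK assumes when the server does not report one.
const DEFAULT_SERVER_VERSION: Version = Version::new(0, 1, 0);

struct ServerVersion(Version);

impl ServerVersion {
    // Hypothetical helper: parse the header value if present, otherwise fall back
    // to the default (the real code reports parse errors instead of ignoring them).
    fn from_header(header: Option<&str>) -> Self {
        header
            .and_then(|v| Version::parse(v).ok())
            .map(ServerVersion)
            .unwrap_or(ServerVersion(DEFAULT_SERVER_VERSION))
    }

    // Multi-vector queries in a single request need a server at or above 0.2.0.
    fn support_multivector(&self) -> bool {
        self.0 >= Version::new(0, 2, 0)
    }
}

fn main() {
    let old = ServerVersion::from_header(None); // header missing -> 0.1.0
    let newer = ServerVersion::from_header(Some("0.2.0"));
    assert!(!old.support_multivector());
    assert!(newer.support_multivector());
}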
@@ -70,6 +70,7 @@ candle-core = { version = "0.6.0", optional = true }
 candle-transformers = { version = "0.6.0", optional = true }
 candle-nn = { version = "0.6.0", optional = true }
 tokenizers = { version = "0.19.1", optional = true }
+semver = { workspace = true }

 # For a workaround, see workspace Cargo.toml
 crunchy.workspace = true
@@ -87,6 +88,7 @@ aws-config = { version = "1.0" }
 aws-smithy-runtime = { version = "1.3" }
 datafusion.workspace = true
 http-body = "1" # Matching reqwest
+rstest = "0.23.0"


 [features]
@@ -19,12 +19,41 @@ use crate::database::{
 };
 use crate::error::Result;
 use crate::table::BaseTable;
+use crate::Error;

 use super::client::{ClientConfig, HttpSend, RequestResultExt, RestfulLanceDbClient, Sender};
 use super::table::RemoteTable;
-use super::util::batches_to_ipc_bytes;
+use super::util::{batches_to_ipc_bytes, parse_server_version};
 use super::ARROW_STREAM_CONTENT_TYPE;

+// the versions of the server that we support
+// for any new feature that we need to change the SDK behavior, we should bump the server version,
+// and add a feature flag as method of `ServerVersion` here.
+pub const DEFAULT_SERVER_VERSION: semver::Version = semver::Version::new(0, 1, 0);
+#[derive(Debug, Clone)]
+pub struct ServerVersion(pub semver::Version);
+
+impl Default for ServerVersion {
+    fn default() -> Self {
+        Self(DEFAULT_SERVER_VERSION.clone())
+    }
+}
+
+impl ServerVersion {
+    pub fn parse(version: &str) -> Result<Self> {
+        let version = Self(
+            semver::Version::parse(version).map_err(|e| Error::InvalidInput {
+                message: e.to_string(),
+            })?,
+        );
+        Ok(version)
+    }
+
+    pub fn support_multivector(&self) -> bool {
+        self.0 >= semver::Version::new(0, 2, 0)
+    }
+}
+
 #[derive(Deserialize)]
 struct ListTablesResponse {
     tables: Vec<String>,
@@ -33,7 +62,7 @@ struct ListTablesResponse {
 #[derive(Debug)]
 pub struct RemoteDatabase<S: HttpSend = Sender> {
     client: RestfulLanceDbClient<S>,
-    table_cache: Cache<String, ()>,
+    table_cache: Cache<String, Arc<RemoteTable<S>>>,
 }

 impl RemoteDatabase {
@@ -115,13 +144,19 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
         }
         let (request_id, rsp) = self.client.send(req, true).await?;
         let rsp = self.client.check_response(&request_id, rsp).await?;
+        let version = parse_server_version(&request_id, &rsp)?;
         let tables = rsp
             .json::<ListTablesResponse>()
             .await
             .err_to_http(request_id)?
             .tables;
         for table in &tables {
-            self.table_cache.insert(table.clone(), ()).await;
+            let remote_table = Arc::new(RemoteTable::new(
+                self.client.clone(),
+                table.clone(),
+                version.clone(),
+            ));
+            self.table_cache.insert(table.clone(), remote_table).await;
         }
         Ok(tables)
     }
@@ -187,34 +222,42 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
                 return Err(crate::Error::InvalidInput { message: body });
             }
         }

-        self.client.check_response(&request_id, rsp).await?;
-
-        self.table_cache.insert(request.name.clone(), ()).await;
-
-        Ok(Arc::new(RemoteTable::new(
+        let rsp = self.client.check_response(&request_id, rsp).await?;
+        let version = parse_server_version(&request_id, &rsp)?;
+        let table = Arc::new(RemoteTable::new(
             self.client.clone(),
-            request.name,
-        )))
+            request.name.clone(),
+            version,
+        ));
+        self.table_cache
+            .insert(request.name.clone(), table.clone())
+            .await;
+
+        Ok(table)
     }

     async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
         // We describe the table to confirm it exists before moving on.
-        if self.table_cache.get(&request.name).await.is_none() {
+        if let Some(table) = self.table_cache.get(&request.name).await {
+            Ok(table.clone())
+        } else {
             let req = self
                 .client
                 .post(&format!("/v1/table/{}/describe/", request.name));
-            let (request_id, resp) = self.client.send(req, true).await?;
-            if resp.status() == StatusCode::NOT_FOUND {
+            let (request_id, rsp) = self.client.send(req, true).await?;
+            if rsp.status() == StatusCode::NOT_FOUND {
                 return Err(crate::Error::TableNotFound { name: request.name });
             }
-            self.client.check_response(&request_id, resp).await?;
+            let rsp = self.client.check_response(&request_id, rsp).await?;
+            let version = parse_server_version(&request_id, &rsp)?;
+            let table = Arc::new(RemoteTable::new(
+                self.client.clone(),
+                request.name.clone(),
+                version,
+            ));
+            self.table_cache.insert(request.name, table.clone()).await;
+            Ok(table)
         }
-
-        Ok(Arc::new(RemoteTable::new(
-            self.client.clone(),
-            request.name,
-        )))
     }

     async fn rename_table(&self, current_name: &str, new_name: &str) -> Result<()> {
@@ -224,8 +267,10 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
         let req = req.json(&serde_json::json!({ "new_table_name": new_name }));
         let (request_id, resp) = self.client.send(req, false).await?;
         self.client.check_response(&request_id, resp).await?;
-        self.table_cache.remove(current_name).await;
-        self.table_cache.insert(new_name.into(), ()).await;
+        let table = self.table_cache.remove(current_name).await;
+        if let Some(table) = table {
+            self.table_cache.insert(new_name.into(), table).await;
+        }
         Ok(())
     }
@@ -10,7 +10,7 @@ use crate::index::IndexStatistics;
 use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest};
 use crate::table::{AddDataMode, AnyQuery, Filter};
 use crate::utils::{supported_btree_data_type, supported_vector_data_type};
-use crate::{DistanceType, Error};
+use crate::{DistanceType, Error, Table};
 use arrow_array::RecordBatchReader;
 use arrow_ipc::reader::FileReader;
 use arrow_schema::{DataType, SchemaRef};
@@ -24,7 +24,7 @@ use http::StatusCode;
 use lance::arrow::json::{JsonDataType, JsonSchema};
 use lance::dataset::scanner::DatasetRecordBatchStream;
 use lance::dataset::{ColumnAlteration, NewColumnTransform, Version};
-use lance_datafusion::exec::OneShotExec;
+use lance_datafusion::exec::{execute_plan, OneShotExec};
 use serde::{Deserialize, Serialize};
 use tokio::sync::RwLock;
@@ -41,6 +41,7 @@ use crate::{

 use super::client::RequestResultExt;
 use super::client::{HttpSend, RestfulLanceDbClient, Sender};
+use super::db::ServerVersion;
 use super::ARROW_STREAM_CONTENT_TYPE;

 #[derive(Debug)]
@@ -48,15 +49,21 @@ pub struct RemoteTable<S: HttpSend = Sender> {
     #[allow(dead_code)]
     client: RestfulLanceDbClient<S>,
     name: String,
+    server_version: ServerVersion,

     version: RwLock<Option<u64>>,
 }

 impl<S: HttpSend> RemoteTable<S> {
-    pub fn new(client: RestfulLanceDbClient<S>, name: String) -> Self {
+    pub fn new(
+        client: RestfulLanceDbClient<S>,
+        name: String,
+        server_version: ServerVersion,
+    ) -> Self {
         Self {
             client,
             name,
+            server_version,
             version: RwLock::new(None),
         }
     }
@@ -212,10 +219,11 @@ impl<S: HttpSend> RemoteTable<S> {
     }

     fn apply_vector_query_params(
-        body: &mut serde_json::Value,
+        &self,
+        mut body: serde_json::Value,
         query: &VectorQueryRequest,
-    ) -> Result<()> {
-        Self::apply_query_params(body, &query.base)?;
+    ) -> Result<Vec<serde_json::Value>> {
+        Self::apply_query_params(&mut body, &query.base)?;

         // Apply general parameters, before we dispatch based on number of query vectors.
         body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
@@ -256,25 +264,40 @@
             }
         }

-        match query.query_vector.len() {
+        let bodies = match query.query_vector.len() {
             0 => {
                 // Server takes empty vector, not null or undefined.
                 body["vector"] = serde_json::Value::Array(Vec::new());
+                vec![body]
             }
             1 => {
                 body["vector"] = vector_to_json(&query.query_vector[0])?;
+                vec![body]
             }
             _ => {
-                let vectors = query
-                    .query_vector
-                    .iter()
-                    .map(vector_to_json)
-                    .collect::<Result<Vec<_>>>()?;
-                body["vector"] = serde_json::Value::Array(vectors);
+                if self.server_version.support_multivector() {
+                    let vectors = query
+                        .query_vector
+                        .iter()
+                        .map(vector_to_json)
+                        .collect::<Result<Vec<_>>>()?;
+                    body["vector"] = serde_json::Value::Array(vectors);
+                    vec![body]
+                } else {
+                    // Server does not support multiple vectors in a single query.
+                    // We need to send multiple requests.
+                    let mut bodies = Vec::with_capacity(query.query_vector.len());
+                    for vector in &query.query_vector {
+                        let mut body = body.clone();
+                        body["vector"] = vector_to_json(vector)?;
+                        bodies.push(body);
+                    }
+                    bodies
+                }
             }
-        }
+        };

-        Ok(())
+        Ok(bodies)
     }

     async fn check_mutable(&self) -> Result<()> {
@@ -299,27 +322,34 @@ impl<S: HttpSend> RemoteTable<S> {
         &self,
         query: &AnyQuery,
         _options: QueryExecutionOptions,
-    ) -> Result<Pin<Box<dyn RecordBatchStream + Send>>> {
+    ) -> Result<Vec<Pin<Box<dyn RecordBatchStream + Send>>>> {
         let request = self.client.post(&format!("/v1/table/{}/query/", self.name));

         let version = self.current_version().await;
         let mut body = serde_json::json!({ "version": version });

-        match query {
+        let requests = match query {
             AnyQuery::Query(query) => {
                 Self::apply_query_params(&mut body, query)?;
                 // Empty vector can be passed if no vector search is performed.
                 body["vector"] = serde_json::Value::Array(Vec::new());
+                vec![request.json(&body)]
             }
             AnyQuery::VectorQuery(query) => {
-                Self::apply_vector_query_params(&mut body, query)?;
+                let bodies = self.apply_vector_query_params(body, query)?;
+                bodies
+                    .into_iter()
+                    .map(|body| request.try_clone().unwrap().json(&body))
+                    .collect()
             }
-        }
+        };

-        let request = request.json(&body);
-        let (request_id, response) = self.client.send(request, true).await?;
-        let stream = self.read_arrow_stream(&request_id, response).await?;
-        Ok(stream)
+        let futures = requests.into_iter().map(|req| async move {
+            let (request_id, response) = self.client.send(req, true).await?;
+            self.read_arrow_stream(&request_id, response).await
+        });
+        let streams = futures::future::try_join_all(futures).await?;
+        Ok(streams)
     }
 }
@@ -342,7 +372,7 @@ mod test_utils {
     use crate::remote::client::test_utils::MockSender;

     impl RemoteTable<MockSender> {
-        pub fn new_mock<F, T>(name: String, handler: F) -> Self
+        pub fn new_mock<F, T>(name: String, handler: F, version: Option<semver::Version>) -> Self
         where
             F: Fn(reqwest::Request) -> http::Response<T> + Send + Sync + 'static,
             T: Into<reqwest::Body>,
@@ -351,6 +381,7 @@ mod test_utils {
             Self {
                 client,
                 name,
+                server_version: version.map(ServerVersion).unwrap_or_default(),
                 version: RwLock::new(None),
             }
         }
@@ -491,8 +522,17 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
         query: &AnyQuery,
         options: QueryExecutionOptions,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let stream = self.execute_query(query, options).await?;
-        Ok(Arc::new(OneShotExec::new(stream)))
+        let streams = self.execute_query(query, options).await?;
+        if streams.len() == 1 {
+            let stream = streams.into_iter().next().unwrap();
+            Ok(Arc::new(OneShotExec::new(stream)))
+        } else {
+            let stream_execs = streams
+                .into_iter()
+                .map(|stream| Arc::new(OneShotExec::new(stream)) as Arc<dyn ExecutionPlan>)
+                .collect();
+            Table::multi_vector_plan(stream_execs)
+        }
     }

     async fn query(
@@ -500,8 +540,24 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
         query: &AnyQuery,
         _options: QueryExecutionOptions,
     ) -> Result<DatasetRecordBatchStream> {
-        let stream = self.execute_query(query, _options).await?;
-        Ok(DatasetRecordBatchStream::new(stream))
+        let streams = self.execute_query(query, _options).await?;
+
+        if streams.len() == 1 {
+            Ok(DatasetRecordBatchStream::new(
+                streams.into_iter().next().unwrap(),
+            ))
+        } else {
+            let stream_execs = streams
+                .into_iter()
+                .map(|stream| Arc::new(OneShotExec::new(stream)) as Arc<dyn ExecutionPlan>)
+                .collect();
+            let plan = Table::multi_vector_plan(stream_execs)?;
+
+            Ok(DatasetRecordBatchStream::new(execute_plan(
+                plan,
+                Default::default(),
+            )?))
+        }
     }
     async fn update(&self, update: UpdateBuilder) -> Result<u64> {
         self.check_mutable().await?;
@@ -884,8 +940,10 @@ mod tests {
     use futures::{future::BoxFuture, StreamExt, TryFutureExt};
     use lance_index::scalar::FullTextSearchQuery;
    use reqwest::Body;
+    use rstest::rstest;

     use crate::index::vector::IvfFlatIndexBuilder;
+    use crate::remote::db::DEFAULT_SERVER_VERSION;
     use crate::remote::JSON_CONTENT_TYPE;
     use crate::{
         index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType},
@@ -1554,9 +1612,12 @@ mod tests {
            .unwrap();
     }

+    #[rstest]
+    #[case(DEFAULT_SERVER_VERSION.clone())]
+    #[case(semver::Version::new(0, 2, 0))]
     #[tokio::test]
-    async fn test_query_multiple_vectors() {
-        let table = Table::new_with_handler("my_table", |request| {
+    async fn test_batch_queries(#[case] version: semver::Version) {
+        let table = Table::new_with_handler_version("my_table", version.clone(), move |request| {
             assert_eq!(request.method(), "POST");
             assert_eq!(request.url().path(), "/v1/table/my_table/query/");
             assert_eq!(
@@ -1566,20 +1627,32 @@ mod tests {
            let body: serde_json::Value =
                serde_json::from_slice(request.body().unwrap().as_bytes().unwrap()).unwrap();
            let query_vectors = body["vector"].as_array().unwrap();
-            assert_eq!(query_vectors.len(), 2);
-            assert_eq!(query_vectors[0].as_array().unwrap().len(), 3);
-            assert_eq!(query_vectors[1].as_array().unwrap().len(), 3);
-            let data = RecordBatch::try_new(
-                Arc::new(Schema::new(vec![
-                    Field::new("a", DataType::Int32, false),
-                    Field::new("query_index", DataType::Int32, false),
-                ])),
-                vec![
-                    Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])),
-                    Arc::new(Int32Array::from(vec![0, 0, 0, 1, 1, 1])),
-                ],
-            )
-            .unwrap();
+            let version = ServerVersion(version.clone());
+            let data = if version.support_multivector() {
+                assert_eq!(query_vectors.len(), 2);
+                assert_eq!(query_vectors[0].as_array().unwrap().len(), 3);
+                assert_eq!(query_vectors[1].as_array().unwrap().len(), 3);
+                RecordBatch::try_new(
+                    Arc::new(Schema::new(vec![
+                        Field::new("a", DataType::Int32, false),
+                        Field::new("query_index", DataType::Int32, false),
+                    ])),
+                    vec![
+                        Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])),
+                        Arc::new(Int32Array::from(vec![0, 0, 0, 1, 1, 1])),
+                    ],
+                )
+                .unwrap()
+            } else {
+                // it's single flat vector, so here the length is dim
+                assert_eq!(query_vectors.len(), 3);
+                RecordBatch::try_new(
+                    Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
+                    vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+                )
+                .unwrap()
+            };
+
            let response_body = write_ipc_file(&data);
            http::Response::builder()
                .status(200)
@@ -4,9 +4,12 @@
 use std::io::Cursor;

 use arrow_array::RecordBatchReader;
+use reqwest::Response;

 use crate::Result;

+use super::db::ServerVersion;
+
 pub fn batches_to_ipc_bytes(batches: impl RecordBatchReader) -> Result<Vec<u8>> {
     const WRITE_BUF_SIZE: usize = 4096;
     let buf = Vec::with_capacity(WRITE_BUF_SIZE);
@@ -22,3 +25,24 @@ pub fn batches_to_ipc_bytes(batches: impl RecordBatchReader) -> Result<Vec<u8>>
     }
     Ok(buf.into_inner())
 }
+
+pub fn parse_server_version(req_id: &str, rsp: &Response) -> Result<ServerVersion> {
+    let version = rsp
+        .headers()
+        .get("phalanx-version")
+        .map(|v| {
+            let v = v.to_str().map_err(|e| crate::Error::Http {
+                source: e.into(),
+                request_id: req_id.to_string(),
+                status_code: Some(rsp.status()),
+            })?;
+            ServerVersion::parse(v).map_err(|e| crate::Error::Http {
+                source: e.into(),
+                request_id: req_id.to_string(),
+                status_code: Some(rsp.status()),
+            })
+        })
+        .transpose()?
+        .unwrap_or_default();
+    Ok(version)
+}
@@ -509,6 +509,27 @@ mod test_utils {
             let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
                 name.into(),
                 handler,
+                None,
             ));
             Self {
                 inner,
                 // Registry is unused.
                 embedding_registry: Arc::new(MemoryRegistry::new()),
             }
         }
+
+        pub fn new_with_handler_version<T>(
+            name: impl Into<String>,
+            version: semver::Version,
+            handler: impl Fn(reqwest::Request) -> http::Response<T> + Clone + Send + Sync + 'static,
+        ) -> Self
+        where
+            T: Into<reqwest::Body>,
+        {
+            let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
+                name.into(),
+                handler,
+                Some(version),
+            ));
+            Self {
+                inner,
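For servers older than 0.2.0, the `execute_query` hunk above splits a multi-vector query into one request per vector and awaits them concurrently with `futures::future::try_join_all`. The following is a minimal, standalone sketch of that fan-out pattern, not LanceDB code; it assumes the `futures` and `tokio` crates, and `send_one` is a hypothetical stand-in for the real HTTP call.

// Fan out one request per query vector and await them concurrently.
// Assumes `futures = "0.3"` and `tokio = { version = "1", features = ["full"] }`.

async fn send_one(vector: Vec<f32>) -> Result<usize, String> {
    // Placeholder for a real HTTP call; here it just returns the vector's dimension.
    Ok(vector.len())
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let query_vectors = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];

    // One future per query vector; try_join_all fails fast on the first error
    // and otherwise returns the results in input order, so callers can still
    // associate each result set with the vector that produced it.
    let futures = query_vectors.into_iter().map(send_one);
    let results = futures::future::try_join_all(futures).await?;

    assert_eq!(results, vec![3, 3]);
    Ok(())
}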