feat: implement Remote connection for LanceDB Rust (#1639)

* Adding a simple test facility, which allows you to mock a single
endpoint at a time with a closure.
* Implementing all the database-level endpoints

Table-level APIs will be done in a follow up PR.

---------

Co-authored-by: Weston Pace <weston.pace@gmail.com>
This commit is contained in:
Will Jones
2024-09-13 10:53:27 -07:00
committed by GitHub
parent bf7d2d6fb0
commit 1c123b58d8
5 changed files with 368 additions and 60 deletions

View File

@@ -46,6 +46,7 @@ async-openai = { version = "0.20.0", optional = true }
serde_with = { version = "3.8.1" }
# For remote feature
reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
http = { version = "0.2", optional = true } # Matching what is in reqwest
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
polars = { version = ">=0.37,<0.40.0", optional = true }
hf-hub = { version = "0.3.2", optional = true }
@@ -68,7 +69,7 @@ aws-smithy-runtime = { version = "1.3" }
[features]
default = []
remote = ["dep:reqwest"]
remote = ["dep:reqwest", "dep:http"]
fp16kernels = ["lance-linalg/fp16kernels"]
s3-test = []
openai = ["dep:async-openai", "dep:reqwest"]

View File

@@ -335,8 +335,8 @@ impl<const HAS_DATA: bool, T: IntoArrow> CreateTableBuilder<HAS_DATA, T> {
#[derive(Clone, Debug)]
pub struct OpenTableBuilder {
parent: Arc<dyn ConnectionInternal>,
name: String,
pub(crate) parent: Arc<dyn ConnectionInternal>,
pub(crate) name: String,
index_cache_size: u32,
lance_read_params: Option<ReadParams>,
}
@@ -1095,6 +1095,25 @@ impl ConnectionInternal for Database {
}
}
#[cfg(all(test, feature = "remote"))]
mod test_utils {
use super::*;
impl Connection {
pub fn new_with_handler<T>(
handler: impl Fn(reqwest::Request) -> http::Response<T> + Clone + Send + Sync + 'static,
) -> Self
where
T: Into<reqwest::Body>,
{
let internal = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
Self {
internal,
uri: "db://test".to_string(),
}
}
}
}
#[cfg(test)]
mod tests {
use arrow_schema::{DataType, Field, Schema};
@@ -1208,9 +1227,9 @@ mod tests {
assert_eq!(tables, vec!["table1".to_owned()]);
}
fn make_data() -> impl RecordBatchReader + Send + 'static {
fn make_data() -> Box<dyn RecordBatchReader + Send + 'static> {
let id = Box::new(IncrementingInt32::new().named("id".to_string()));
BatchGenerator::new().col(id).batches(10, 2000)
Box::new(BatchGenerator::new().col(id).batches(10, 2000))
}
#[tokio::test]

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::time::Duration;
use std::{future::Future, time::Duration};
use reqwest::{
header::{HeaderMap, HeaderValue},
@@ -21,13 +21,66 @@ use reqwest::{
use crate::error::{Error, Result};
// We use the `HttpSend` trait to abstract over the `reqwest::Client` so that
// we can mock responses in tests. Based on the patterns from this blog post:
// https://write.as/balrogboogie/testing-reqwest-based-clients
#[derive(Clone, Debug)]
pub struct RestfulLanceDbClient {
pub struct RestfulLanceDbClient<S: HttpSend = Sender> {
client: reqwest::Client,
host: String,
sender: S,
}
impl RestfulLanceDbClient {
pub trait HttpSend: Clone + Send + Sync + std::fmt::Debug + 'static {
fn send(&self, req: RequestBuilder) -> impl Future<Output = Result<Response>> + Send;
}
// Default implementation of HttpSend which sends the request normally with reqwest
#[derive(Clone, Debug)]
pub struct Sender;
impl HttpSend for Sender {
async fn send(&self, request: reqwest::RequestBuilder) -> Result<reqwest::Response> {
Ok(request.send().await?)
}
}
impl RestfulLanceDbClient<Sender> {
pub fn try_new(
db_url: &str,
api_key: &str,
region: &str,
host_override: Option<String>,
) -> Result<Self> {
let parsed_url = url::Url::parse(db_url)?;
debug_assert_eq!(parsed_url.scheme(), "db");
if !parsed_url.has_host() {
return Err(Error::Http {
message: format!("Invalid database URL (missing host) '{}'", db_url),
});
}
let db_name = parsed_url.host_str().unwrap();
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.default_headers(Self::default_headers(
api_key,
region,
db_name,
host_override.is_some(),
)?)
.build()?;
let host = match host_override {
Some(host_override) => host_override,
None => format!("https://{}.{}.api.lancedb.com", db_name, region),
};
Ok(Self {
client,
host,
sender: Sender,
})
}
}
impl<S: HttpSend> RestfulLanceDbClient<S> {
pub fn host(&self) -> &str {
&self.host
}
@@ -66,36 +119,6 @@ impl RestfulLanceDbClient {
Ok(headers)
}
pub fn try_new(
db_url: &str,
api_key: &str,
region: &str,
host_override: Option<String>,
) -> Result<Self> {
let parsed_url = url::Url::parse(db_url)?;
debug_assert_eq!(parsed_url.scheme(), "db");
if !parsed_url.has_host() {
return Err(Error::Http {
message: format!("Invalid database URL (missing host) '{}'", db_url),
});
}
let db_name = parsed_url.host_str().unwrap();
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.default_headers(Self::default_headers(
api_key,
region,
db_name,
host_override.is_some(),
)?)
.build()?;
let host = match host_override {
Some(host_override) => host_override,
None => format!("https://{}.{}.api.lancedb.com", db_name, region),
};
Ok(Self { client, host })
}
pub fn get(&self, uri: &str) -> RequestBuilder {
let full_uri = format!("{}{}", self.host, uri);
self.client.get(full_uri)
@@ -106,6 +129,10 @@ impl RestfulLanceDbClient {
self.client.post(full_uri)
}
pub async fn send(&self, req: RequestBuilder) -> Result<Response> {
self.sender.send(req).await
}
async fn rsp_to_str(response: Response) -> String {
let status = response.status();
response.text().await.unwrap_or_else(|_| status.to_string())
@@ -126,3 +153,49 @@ impl RestfulLanceDbClient {
}
}
}
#[cfg(test)]
pub mod test_utils {
use std::sync::Arc;
use super::*;
#[derive(Clone)]
pub struct MockSender {
f: Arc<dyn Fn(reqwest::Request) -> reqwest::Response + Send + Sync + 'static>,
}
impl std::fmt::Debug for MockSender {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MockSender")
}
}
impl HttpSend for MockSender {
async fn send(&self, request: reqwest::RequestBuilder) -> Result<reqwest::Response> {
let request = request.build().unwrap();
let response = (self.f)(request);
Ok(response)
}
}
pub fn client_with_handler<T>(
handler: impl Fn(reqwest::Request) -> http::response::Response<T> + Send + Sync + 'static,
) -> RestfulLanceDbClient<MockSender>
where
T: Into<reqwest::Body>,
{
let wrapper = move |req: reqwest::Request| {
let response = handler(req);
response.into()
};
RestfulLanceDbClient {
client: reqwest::Client::new(),
host: "http://localhost".to_string(),
sender: MockSender {
f: Arc::new(wrapper),
},
}
}
}

View File

@@ -16,6 +16,7 @@ use std::sync::Arc;
use arrow_array::RecordBatchReader;
use async_trait::async_trait;
use http::StatusCode;
use reqwest::header::CONTENT_TYPE;
use serde::Deserialize;
use tokio::task::spawn_blocking;
@@ -27,7 +28,7 @@ use crate::embeddings::EmbeddingRegistry;
use crate::error::Result;
use crate::Table;
use super::client::RestfulLanceDbClient;
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::table::RemoteTable;
use super::util::batches_to_ipc_bytes;
@@ -39,8 +40,8 @@ struct ListTablesResponse {
}
#[derive(Debug)]
pub struct RemoteDatabase {
client: RestfulLanceDbClient,
pub struct RemoteDatabase<S: HttpSend = Sender> {
client: RestfulLanceDbClient<S>,
}
impl RemoteDatabase {
@@ -55,14 +56,32 @@ impl RemoteDatabase {
}
}
impl std::fmt::Display for RemoteDatabase {
#[cfg(all(test, feature = "remote"))]
mod test_utils {
use super::*;
use crate::remote::client::test_utils::client_with_handler;
use crate::remote::client::test_utils::MockSender;
impl RemoteDatabase<MockSender> {
pub fn new_mock<F, T>(handler: F) -> Self
where
F: Fn(reqwest::Request) -> http::Response<T> + Send + Sync + 'static,
T: Into<reqwest::Body>,
{
let client = client_with_handler(handler);
Self { client }
}
}
}
impl<S: HttpSend> std::fmt::Display for RemoteDatabase<S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "RemoteDatabase(host={})", self.client.host())
}
}
#[async_trait]
impl ConnectionInternal for RemoteDatabase {
impl<S: HttpSend> ConnectionInternal for RemoteDatabase<S> {
async fn table_names(&self, options: TableNamesBuilder) -> Result<Vec<String>> {
let mut req = self.client.get("/v1/table/");
if let Some(limit) = options.limit {
@@ -71,7 +90,7 @@ impl ConnectionInternal for RemoteDatabase {
if let Some(start_after) = options.start_after {
req = req.query(&[("page_token", start_after)]);
}
let rsp = req.send().await?;
let rsp = self.client.send(req).await?;
let rsp = self.client.check_response(rsp).await?;
Ok(rsp.json::<ListTablesResponse>().await?.tables)
}
@@ -88,15 +107,24 @@ impl ConnectionInternal for RemoteDatabase {
.await
.unwrap()?;
let rsp = self
let req = self
.client
.post(&format!("/v1/table/{}/create/", options.name))
.body(data_buffer)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
// This is currently expected by LanceDb cloud but will be removed soon.
.header("x-request-id", "na")
.send()
.await?;
.header("x-request-id", "na");
let rsp = self.client.send(req).await?;
if rsp.status() == StatusCode::BAD_REQUEST {
let body = rsp.text().await?;
if body.contains("already exists") {
return Err(crate::Error::TableAlreadyExists { name: options.name });
} else {
return Err(crate::Error::InvalidInput { message: body });
}
}
self.client.check_response(rsp).await?;
Ok(Table::new(Arc::new(RemoteTable::new(
@@ -105,19 +133,206 @@ impl ConnectionInternal for RemoteDatabase {
))))
}
async fn do_open_table(&self, _options: OpenTableBuilder) -> Result<Table> {
todo!()
async fn do_open_table(&self, options: OpenTableBuilder) -> Result<Table> {
// We describe the table to confirm it exists before moving on.
// TODO: a TTL cache of table existence
let req = self
.client
.get(&format!("/v1/table/{}/describe/", options.name));
let resp = self.client.send(req).await?;
if resp.status() == StatusCode::NOT_FOUND {
return Err(crate::Error::TableNotFound { name: options.name });
}
self.client.check_response(resp).await?;
Ok(Table::new(Arc::new(RemoteTable::new(
self.client.clone(),
options.name,
))))
}
async fn drop_table(&self, _name: &str) -> Result<()> {
todo!()
async fn drop_table(&self, name: &str) -> Result<()> {
let req = self.client.post(&format!("/v1/table/{}/drop/", name));
let resp = self.client.send(req).await?;
self.client.check_response(resp).await?;
Ok(())
}
async fn drop_db(&self) -> Result<()> {
todo!()
Err(crate::Error::NotSupported {
message: "Dropping databases is not supported in the remote API".to_string(),
})
}
fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
todo!()
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use crate::{remote::db::ARROW_STREAM_CONTENT_TYPE, Connection};
#[tokio::test]
async fn test_table_names() {
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::GET);
assert_eq!(request.url().path(), "/v1/table/");
assert_eq!(request.url().query(), None);
http::Response::builder()
.status(200)
.body(r#"{"tables": ["table1", "table2"]}"#)
.unwrap()
});
let names = conn.table_names().execute().await.unwrap();
assert_eq!(names, vec!["table1", "table2"]);
}
#[tokio::test]
async fn test_table_names_pagination() {
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::GET);
assert_eq!(request.url().path(), "/v1/table/");
assert!(request.url().query().unwrap().contains("limit=2"));
assert!(request.url().query().unwrap().contains("page_token=table2"));
http::Response::builder()
.status(200)
.body(r#"{"tables": ["table3", "table4"], "page_token": "token"}"#)
.unwrap()
});
let names = conn
.table_names()
.start_after("table2")
.limit(2)
.execute()
.await
.unwrap();
assert_eq!(names, vec!["table3", "table4"]);
}
#[tokio::test]
async fn test_open_table() {
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::GET);
assert_eq!(request.url().path(), "/v1/table/table1/describe/");
assert_eq!(request.url().query(), None);
http::Response::builder()
.status(200)
.body(r#"{"table": "table1"}"#)
.unwrap()
});
let table = conn.open_table("table1").execute().await.unwrap();
assert_eq!(table.name(), "table1");
// Storage options should be ignored.
let table = conn
.open_table("table1")
.storage_option("key", "value")
.execute()
.await
.unwrap();
assert_eq!(table.name(), "table1");
}
#[tokio::test]
async fn test_open_table_not_found() {
let conn = Connection::new_with_handler(|_| {
http::Response::builder()
.status(404)
.body("table not found")
.unwrap()
});
let result = conn.open_table("table1").execute().await;
assert!(result.is_err());
assert!(matches!(result, Err(crate::Error::TableNotFound { .. })));
}
#[tokio::test]
async fn test_create_table() {
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::POST);
assert_eq!(request.url().path(), "/v1/table/table1/create/");
assert_eq!(
request
.headers()
.get(reqwest::header::CONTENT_TYPE)
.unwrap(),
ARROW_STREAM_CONTENT_TYPE.as_bytes()
);
http::Response::builder().status(200).body("").unwrap()
});
let data = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema());
let table = conn.create_table("table1", reader).execute().await.unwrap();
assert_eq!(table.name(), "table1");
}
#[tokio::test]
async fn test_create_table_already_exists() {
let conn = Connection::new_with_handler(|_| {
http::Response::builder()
.status(400)
.body("table table1 already exists")
.unwrap()
});
let data = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema());
let result = conn.create_table("table1", reader).execute().await;
assert!(result.is_err());
assert!(
matches!(result, Err(crate::Error::TableAlreadyExists { name }) if name == "table1")
);
}
#[tokio::test]
async fn test_create_table_empty() {
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::POST);
assert_eq!(request.url().path(), "/v1/table/table1/create/");
assert_eq!(
request
.headers()
.get(reqwest::header::CONTENT_TYPE)
.unwrap(),
ARROW_STREAM_CONTENT_TYPE.as_bytes()
);
http::Response::builder().status(200).body("").unwrap()
});
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
conn.create_empty_table("table1", schema)
.execute()
.await
.unwrap();
}
#[tokio::test]
async fn test_drop_table() {
let conn = Connection::new_with_handler(|request| {
assert_eq!(request.method(), &reqwest::Method::POST);
assert_eq!(request.url().path(), "/v1/table/table1/drop/");
assert_eq!(request.url().query(), None);
assert!(request.body().is_none());
http::Response::builder().status(200).body("").unwrap()
});
conn.drop_table("table1").await.unwrap();
// NOTE: the API will return 200 even if the table does not exist. So we shouldn't expect 404.
}
}

View File

@@ -19,29 +19,29 @@ use crate::{
},
};
use super::client::RestfulLanceDbClient;
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
#[derive(Debug)]
pub struct RemoteTable {
pub struct RemoteTable<S: HttpSend = Sender> {
#[allow(dead_code)]
client: RestfulLanceDbClient,
client: RestfulLanceDbClient<S>,
name: String,
}
impl RemoteTable {
pub fn new(client: RestfulLanceDbClient, name: String) -> Self {
impl<S: HttpSend> RemoteTable<S> {
pub fn new(client: RestfulLanceDbClient<S>, name: String) -> Self {
Self { client, name }
}
}
impl std::fmt::Display for RemoteTable {
impl<S: HttpSend> std::fmt::Display for RemoteTable<S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "RemoteTable({})", self.name)
}
}
#[async_trait]
impl TableInternal for RemoteTable {
impl<S: HttpSend> TableInternal for RemoteTable<S> {
fn as_any(&self) -> &dyn std::any::Any {
self
}