chore(rust): provide a Connection trait to match python and nodejs SDK (#846)

In NodeJS and Python, LanceDB establishes a connection to a db. In Rust
core, it is called Database.
We should be consistent with the naming.
This commit is contained in:
Lei Xu
2024-01-22 17:35:02 -08:00
committed by GitHub
parent 66eaa2a00e
commit 4c303ba293
5 changed files with 78 additions and 55 deletions

View File

@@ -22,8 +22,9 @@ use object_store::CredentialProvider;
use once_cell::sync::OnceCell;
use tokio::runtime::Runtime;
use vectordb::database::Database;
use vectordb::connection::Database;
use vectordb::table::ReadParams;
use vectordb::Connection;
use crate::error::ResultExt;
use crate::query::JsQuery;
@@ -38,7 +39,7 @@ mod query;
mod table;
struct JsDatabase {
database: Arc<Database>,
database: Arc<dyn Connection + 'static>,
}
impl Finalize for JsDatabase {}

View File

@@ -77,7 +77,7 @@ impl JsTable {
rt.spawn(async move {
let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let table_rst = database
.create_table(&table_name, batch_reader, Some(params))
.create_table(&table_name, Box::new(batch_reader), Some(params))
.await;
deferred.settle_with(&channel, move |mut cx| {

View File

@@ -31,6 +31,40 @@ use crate::table::{ReadParams, Table};
pub const LANCE_FILE_EXTENSION: &str = "lance";
/// A connection to LanceDB
#[async_trait::async_trait]
pub trait Connection: Send + Sync {
/// Get the names of all tables in the database.
async fn table_names(&self) -> Result<Vec<String>>;
/// Create a new table in the database.
///
/// # Parameters
///
/// * `name` - The name of the table.
/// * `batches` - The initial data to write to the table.
/// * `params` - Optional [`WriteParams`] to create the table.
///
/// # Returns
/// Created [`Table`], or [`Err(Error::TableAlreadyExists)`] if the table already exists.
async fn create_table(
&self,
name: &str,
batches: Box<dyn RecordBatchReader + Send>,
params: Option<WriteParams>,
) -> Result<Table>;
async fn open_table(&self, name: &str) -> Result<Table>;
async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<Table>;
/// Drop a table in the database.
///
/// # Arguments
/// * `name` - The name of the table.
async fn drop_table(&self, name: &str) -> Result<()>;
}
pub struct Database {
object_store: ObjectStore,
query_string: Option<String>,
@@ -52,7 +86,7 @@ impl Database {
///
/// # Arguments
///
/// * `path` - URI where the database is located, can be a local file or a supported remote cloud storage
/// * `uri` - URI where the database is located, can be a local file or a supported remote cloud storage
///
/// # Returns
///
@@ -158,12 +192,30 @@ impl Database {
Ok(())
}
/// Get the names of all tables in the database.
///
/// # Returns
///
/// * A [`Vec<String>`] with all table names.
pub async fn table_names(&self) -> Result<Vec<String>> {
/// Get the URI of a table in the database.
fn table_uri(&self, name: &str) -> Result<String> {
let path = Path::new(&self.uri);
let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
let mut uri = table_uri
.as_path()
.to_str()
.context(InvalidTableNameSnafu { name })?
.to_string();
// If there are query string set on the connection, propagate to lance
if let Some(query) = self.query_string.as_ref() {
uri.push('?');
uri.push_str(query.as_str());
}
Ok(uri)
}
}
#[async_trait::async_trait]
impl Connection for Database {
async fn table_names(&self) -> Result<Vec<String>> {
let mut f = self
.object_store
.read_dir(self.base_path.clone())
@@ -183,16 +235,10 @@ impl Database {
Ok(f)
}
/// Create a new table in the database.
///
/// # Arguments
/// * `name` - The name of the table.
/// * `batches` - The initial data to write to the table.
/// * `params` - Optional [`WriteParams`] to create the table.
pub async fn create_table(
async fn create_table(
&self,
name: &str,
batches: impl RecordBatchReader + Send + 'static,
batches: Box<dyn RecordBatchReader + Send>,
params: Option<WriteParams>,
) -> Result<Table> {
let table_uri = self.table_uri(name)?;
@@ -215,7 +261,7 @@ impl Database {
/// # Returns
///
/// * A [Table] object.
pub async fn open_table(&self, name: &str) -> Result<Table> {
async fn open_table(&self, name: &str) -> Result<Table> {
self.open_table_with_params(name, ReadParams::default())
.await
}
@@ -229,41 +275,17 @@ impl Database {
/// # Returns
///
/// * A [Table] object.
pub async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<Table> {
async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<Table> {
let table_uri = self.table_uri(name)?;
Table::open_with_params(&table_uri, name, self.store_wrapper.clone(), params).await
}
/// Drop a table in the database.
///
/// # Arguments
/// * `name` - The name of the table.
pub async fn drop_table(&self, name: &str) -> Result<()> {
async fn drop_table(&self, name: &str) -> Result<()> {
let dir_name = format!("{}.{}", name, LANCE_EXTENSION);
let full_path = self.base_path.child(dir_name.clone());
self.object_store.remove_dir_all(full_path).await?;
Ok(())
}
/// Get the URI of a table in the database.
fn table_uri(&self, name: &str) -> Result<String> {
let path = Path::new(&self.uri);
let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
let mut uri = table_uri
.as_path()
.to_str()
.context(InvalidTableNameSnafu { name })?
.to_string();
// If there are query string set on the connection, propagate to lance
if let Some(query) = self.query_string.as_ref() {
uri.push('?');
uri.push_str(query.as_str());
}
Ok(uri)
}
}
#[cfg(test)]
@@ -272,7 +294,7 @@ mod tests {
use tempfile::tempdir;
use crate::database::Database;
use super::*;
#[tokio::test]
async fn test_connect() {

View File

@@ -335,7 +335,7 @@ impl WrappingObjectStore for MirroringObjectStoreWrapper {
#[cfg(all(test, not(windows)))]
mod test {
use super::*;
use crate::Database;
use crate::connection::{Connection, Database};
use arrow_array::PrimitiveArray;
use futures::TryStreamExt;
use lance::{dataset::WriteParams, io::object_store::ObjectStoreParams};
@@ -365,7 +365,7 @@ mod test {
datagen = datagen.col(Box::new(RandomVector::default().named("vector".into())));
let res = db
.create_table("test", datagen.batch(100), Some(param.clone()))
.create_table("test", Box::new(datagen.batch(100)), Some(param.clone()))
.await;
// leave this here for easy debugging

View File

@@ -46,7 +46,7 @@
//! #### Connect to a database.
//!
//! ```rust
//! use vectordb::{Database, Table, WriteMode};
//! use vectordb::{connection::{Database, Connection}, Table, WriteMode};
//! use arrow_schema::{Field, Schema};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! let db = Database::connect("data/sample-lancedb").await.unwrap();
@@ -66,7 +66,7 @@
//! use arrow_schema::{DataType, Schema, Field};
//! use arrow_array::{RecordBatch, RecordBatchIterator};
//! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type};
//! # use vectordb::Database;
//! # use vectordb::connection::{Database, Connection};
//!
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! # let tmpdir = tempfile::tempdir().unwrap();
@@ -86,7 +86,7 @@
//! ]).unwrap()
//! ].into_iter().map(Ok),
//! schema.clone());
//! db.create_table("my_table", batches, None).await.unwrap();
//! db.create_table("my_table", Box::new(batches), None).await.unwrap();
//! # });
//! ```
//!
@@ -98,7 +98,7 @@
//! # use arrow_schema::{DataType, Schema, Field};
//! # use arrow_array::{RecordBatch, RecordBatchIterator};
//! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type};
//! # use vectordb::Database;
//! # use vectordb::connection::{Database, Connection};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! # let tmpdir = tempfile::tempdir().unwrap();
//! # let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
@@ -116,7 +116,7 @@
//! # ]).unwrap()
//! # ].into_iter().map(Ok),
//! # schema.clone());
//! # db.create_table("my_table", batches, None).await.unwrap();
//! # db.create_table("my_table", Box::new(batches), None).await.unwrap();
//! let table = db.open_table("my_table").await.unwrap();
//! let results = table
//! .search(Some(vec![1.0; 128]))
@@ -131,8 +131,8 @@
//!
//! ```
pub mod connection;
pub mod data;
pub mod database;
pub mod error;
pub mod index;
pub mod io;
@@ -140,7 +140,7 @@ pub mod query;
pub mod table;
pub mod utils;
pub use database::Database;
pub use connection::Connection;
pub use table::Table;
pub use lance::dataset::WriteMode;