Mirror of https://github.com/lancedb/lancedb.git, synced 2026-01-07 12:22:59 +00:00
refactor: rust vectordb API stabilization of the Connection trait (#993)
This is the start of a more comprehensive refactor and stabilization of the Rust API. The `Connection` trait is cleaned up so that it no longer requires `lance` and matches the `Connection` trait in the other language APIs. In addition, the concrete implementation `Database` is now hidden.

BREAKING CHANGE: The struct `crate::connection::Database` is gone. Several examples opened a connection using `Database::connect` or `Database::connect_with_params`; users should now use `vectordb::connect`.

BREAKING CHANGE: The `connect`, `create_table`, and `open_table` methods now all return a builder object, so a call like `conn.open_table(..., opt1, opt2)` becomes `conn.open_table(...).opt1(opt1).opt2(opt2).execute()`. The structure of options has also changed slightly, but no options capability has been removed.

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
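For readers migrating, a minimal before/after sketch of the builder-style API this commit introduces (assumed to run inside an async fn returning `vectordb::Result`; the URI and option values are illustrative, not taken from the diff below):

    // Before (removed): the concrete type was used directly and
    // options were positional arguments.
    // let db = Database::connect("data/sample-lancedb").await?;
    // let tbl = db.open_table_with_params("my_table", ReadParams::default()).await?;

    // After: every entry point returns a builder; execute() runs the operation.
    let db = vectordb::connect("data/sample-lancedb").execute().await?;
    let tbl = db
        .open_table("my_table")
        .index_cache_size(512) // formerly carried by ConnectOptions / open params
        .execute()
        .await?;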
@@ -12,18 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use napi::bindgen_prelude::*;
use napi_derive::*;

use crate::table::Table;
use vectordb::connection::{Connection as LanceDBConnection, Database};
use vectordb::connection::Connection as LanceDBConnection;
use vectordb::ipc::ipc_file_to_batches;

#[napi]
pub struct Connection {
    conn: Arc<dyn LanceDBConnection>,
    conn: LanceDBConnection,
}

#[napi]
@@ -32,9 +30,9 @@ impl Connection {
    #[napi(factory)]
    pub async fn new(uri: String) -> napi::Result<Self> {
        Ok(Self {
            conn: Arc::new(Database::connect(&uri).await.map_err(|e| {
            conn: vectordb::connect(&uri).execute().await.map_err(|e| {
                napi::Error::from_reason(format!("Failed to connect to database: {}", e))
            })?),
            })?,
        })
    }

@@ -59,7 +57,8 @@ impl Connection {
            .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
        let tbl = self
            .conn
            .create_table(&name, Box::new(batches), None)
            .create_table(&name, Box::new(batches))
            .execute()
            .await
            .map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
        Ok(Table::new(tbl))
@@ -70,6 +69,7 @@ impl Connection {
        let tbl = self
            .conn
            .open_table(&name)
            .execute()
            .await
            .map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
        Ok(Table::new(tbl))

@@ -15,6 +15,7 @@
use arrow_ipc::writer::FileWriter;
use napi::bindgen_prelude::*;
use napi_derive::napi;
use vectordb::table::AddDataOptions;
use vectordb::{ipc::ipc_file_to_batches, table::TableRef};

use crate::index::IndexBuilder;
@@ -48,12 +49,15 @@ impl Table {
    pub async fn add(&self, buf: Buffer) -> napi::Result<()> {
        let batches = ipc_file_to_batches(buf.to_vec())
            .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
        self.table.add(Box::new(batches), None).await.map_err(|e| {
            napi::Error::from_reason(format!(
                "Failed to add batches to table {}: {}",
                self.table, e
            ))
        })
        self.table
            .add(Box::new(batches), AddDataOptions::default())
            .await
            .map_err(|e| {
                napi::Error::from_reason(format!(
                    "Failed to add batches to table {}: {}",
                    self.table, e
                ))
            })
    }

    #[napi]

@@ -22,9 +22,9 @@ use object_store::CredentialProvider;
use once_cell::sync::OnceCell;
use tokio::runtime::Runtime;

use vectordb::connection::Database;
use vectordb::connect;
use vectordb::connection::Connection;
use vectordb::table::ReadParams;
use vectordb::{ConnectOptions, Connection};

use crate::error::ResultExt;
use crate::query::JsQuery;
@@ -39,7 +39,7 @@ mod query;
mod table;

struct JsDatabase {
    database: Arc<dyn Connection + 'static>,
    database: Connection,
}

impl Finalize for JsDatabase {}
@@ -89,23 +89,23 @@ fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> {
    let channel = cx.channel();
    let (deferred, promise) = cx.promise();

    let mut conn_options = ConnectOptions::new(&path);
    let mut conn_builder = connect(&path);
    if let Some(region) = region {
        conn_options = conn_options.region(&region);
        conn_builder = conn_builder.region(&region);
    }
    if let Some(aws_creds) = aws_creds {
        conn_options = conn_options.aws_creds(AwsCredential {
        conn_builder = conn_builder.aws_creds(AwsCredential {
            key_id: aws_creds.key_id,
            secret_key: aws_creds.secret_key,
            token: aws_creds.token,
        });
    }
    rt.spawn(async move {
        let database = Database::connect_with_options(&conn_options).await;
        let database = conn_builder.execute().await;

        deferred.settle_with(&channel, move |mut cx| {
            let db = JsDatabase {
                database: Arc::new(database.or_throw(&mut cx)?),
                database: database.or_throw(&mut cx)?,
            };
            Ok(cx.boxed(db))
        });
@@ -217,7 +217,11 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {

    let (deferred, promise) = cx.promise();
    rt.spawn(async move {
        let table_rst = database.open_table_with_params(&table_name, params).await;
        let table_rst = database
            .open_table(&table_name)
            .lance_read_params(params)
            .execute()
            .await;

        deferred.settle_with(&channel, move |mut cx| {
            let js_table = JsTable::from(table_rst.or_throw(&mut cx)?);

@@ -18,7 +18,7 @@ use arrow_array::{RecordBatch, RecordBatchIterator};
use lance::dataset::optimize::CompactionOptions;
use lance::dataset::{WriteMode, WriteParams};
use lance::io::ObjectStoreParams;
use vectordb::table::OptimizeAction;
use vectordb::table::{AddDataOptions, OptimizeAction, WriteOptions};

use crate::arrow::{arrow_buffer_to_record_batch, record_batch_to_buffer};
use neon::prelude::*;
@@ -80,7 +80,11 @@ impl JsTable {
        rt.spawn(async move {
            let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
            let table_rst = database
                .create_table(&table_name, Box::new(batch_reader), Some(params))
                .create_table(&table_name, Box::new(batch_reader))
                .write_options(WriteOptions {
                    lance_write_params: Some(params),
                })
                .execute()
                .await;

            deferred.settle_with(&channel, move |mut cx| {
@@ -121,7 +125,13 @@ impl JsTable {

        rt.spawn(async move {
            let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
            let add_result = table.add(Box::new(batch_reader), Some(params)).await;
            let opts = AddDataOptions {
                write_options: WriteOptions {
                    lance_write_params: Some(params),
                },
                ..Default::default()
            };
            let add_result = table.add(Box::new(batch_reader), opts).await;

            deferred.settle_with(&channel, move |mut cx| {
                add_result.or_throw(&mut cx)?;

@@ -19,7 +19,8 @@ use arrow_array::{FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterat
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;

use vectordb::Connection;
use vectordb::connection::Connection;
use vectordb::table::AddDataOptions;
use vectordb::{connect, Result, Table, TableRef};

#[tokio::main]
@@ -29,18 +30,18 @@ async fn main() -> Result<()> {
    }
    // --8<-- [start:connect]
    let uri = "data/sample-lancedb";
    let db = connect(uri).await?;
    let db = connect(uri).execute().await?;
    // --8<-- [end:connect]

    // --8<-- [start:list_names]
    println!("{:?}", db.table_names().await?);
    // --8<-- [end:list_names]
    let tbl = create_table(db.clone()).await?;
    let tbl = create_table(&db).await?;
    create_index(tbl.as_ref()).await?;
    let batches = search(tbl.as_ref()).await?;
    println!("{:?}", batches);

    create_empty_table(db.clone()).await.unwrap();
    create_empty_table(&db).await.unwrap();

    // --8<-- [start:delete]
    tbl.delete("id > 24").await.unwrap();
@@ -55,17 +56,14 @@ async fn main() -> Result<()> {
#[allow(dead_code)]
async fn open_with_existing_tbl() -> Result<()> {
    let uri = "data/sample-lancedb";
    let db = connect(uri).await?;
    let db = connect(uri).execute().await?;
    // --8<-- [start:open_with_existing_file]
    let _ = db
        .open_table_with_params("my_table", Default::default())
        .await
        .unwrap();
    let _ = db.open_table("my_table").execute().await.unwrap();
    // --8<-- [end:open_with_existing_file]
    Ok(())
}

async fn create_table(db: Arc<dyn Connection>) -> Result<TableRef> {
async fn create_table(db: &Connection) -> Result<TableRef> {
    // --8<-- [start:create_table]
    const TOTAL: usize = 1000;
    const DIM: usize = 128;
@@ -102,7 +100,8 @@ async fn create_table(db: Arc<dyn Connection>) -> Result<TableRef> {
        schema.clone(),
    );
    let tbl = db
        .create_table("my_table", Box::new(batches), None)
        .create_table("my_table", Box::new(batches))
        .execute()
        .await
        .unwrap();
    // --8<-- [end:create_table]
@@ -126,21 +125,21 @@ async fn create_table(db: Arc<dyn Connection>) -> Result<TableRef> {
        schema.clone(),
    );
    // --8<-- [start:add]
    tbl.add(Box::new(new_batches), None).await.unwrap();
    tbl.add(Box::new(new_batches), AddDataOptions::default())
        .await
        .unwrap();
    // --8<-- [end:add]

    Ok(tbl)
}

async fn create_empty_table(db: Arc<dyn Connection>) -> Result<TableRef> {
async fn create_empty_table(db: &Connection) -> Result<TableRef> {
    // --8<-- [start:create_empty_table]
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("item", DataType::Utf8, true),
    ]));
    let batches = RecordBatchIterator::new(vec![], schema.clone());
    db.create_table("empty_table", Box::new(batches), None)
        .await
    db.create_empty_table("empty_table", schema).execute().await
    // --8<-- [end:create_empty_table]
}


@@ -13,14 +13,14 @@
// limitations under the License.

//! LanceDB Database
//!

use std::fs::create_dir_all;
use std::path::Path;
use std::sync::Arc;

use arrow_array::RecordBatchReader;
use lance::dataset::WriteParams;
use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::SchemaRef;
use lance::dataset::{ReadParams, WriteMode};
use lance::io::{ObjectStore, ObjectStoreParams, WrappingObjectStore};
use object_store::{
    aws::AwsCredential, local::LocalFileSystem, CredentialProvider, StaticCredentialProvider,
@@ -29,73 +29,283 @@ use snafu::prelude::*;

use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
use crate::io::object_store::MirroringObjectStoreWrapper;
use crate::table::{NativeTable, ReadParams, TableRef};
use crate::table::{NativeTable, TableRef, WriteOptions};

pub const LANCE_FILE_EXTENSION: &str = "lance";

/// A connection to LanceDB
#[async_trait::async_trait]
pub trait Connection: Send + Sync {
    /// Get the names of all tables in the database.
    async fn table_names(&self) -> Result<Vec<String>>;
pub type TableBuilderCallback = Box<dyn FnOnce(OpenTableBuilder) -> OpenTableBuilder + Send>;

    /// Create a new table in the database.
/// Describes what happens when creating a table and a table with
/// the same name already exists
pub enum CreateTableMode {
    /// If the table already exists, an error is returned
    Create,
    /// If the table already exists, it is opened. Any provided data is
    /// ignored. The function will be passed an OpenTableBuilder to customize
    /// how the table is opened
    ExistOk(TableBuilderCallback),
    /// If the table already exists, it is overwritten
    Overwrite,
}

impl CreateTableMode {
    pub fn exist_ok(
        callback: impl FnOnce(OpenTableBuilder) -> OpenTableBuilder + Send + 'static,
    ) -> Self {
        Self::ExistOk(Box::new(callback))
    }
}

impl Default for CreateTableMode {
    fn default() -> Self {
        Self::Create
    }
}
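For orientation, a minimal sketch of how this mode is meant to be used, assuming a `db: Connection` and an Arrow `schema` are already in hand (it mirrors the test added at the bottom of this file):

    // If "my_table" already exists, open it instead of failing; the callback
    // customizes the OpenTableBuilder used for the fallback open.
    let table = db
        .create_empty_table("my_table", schema.clone())
        .mode(CreateTableMode::exist_ok(|builder| builder.index_cache_size(16)))
        .execute()
        .await?;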

/// Describes what happens when a vector either contains NaN or
/// does not have enough values
#[derive(Clone, Debug, Default)]
enum BadVectorHandling {
    /// An error is returned
    #[default]
    Error,
    #[allow(dead_code)] // https://github.com/lancedb/lancedb/issues/992
    /// The offending row is dropped
    Drop,
    #[allow(dead_code)] // https://github.com/lancedb/lancedb/issues/992
    /// The invalid/missing items are replaced by fill_value
    Fill(f32),
}

/// A builder for configuring a [`Connection::create_table`] operation
pub struct CreateTableBuilder<const HAS_DATA: bool> {
    parent: Arc<dyn ConnectionInternal>,
    name: String,
    data: Option<Box<dyn RecordBatchReader + Send>>,
    schema: Option<SchemaRef>,
    mode: CreateTableMode,
    write_options: WriteOptions,
}

// Builder methods that only apply when we have initial data
impl CreateTableBuilder<true> {
    fn new(
        parent: Arc<dyn ConnectionInternal>,
        name: String,
        data: Box<dyn RecordBatchReader + Send>,
    ) -> Self {
        Self {
            parent,
            name,
            data: Some(data),
            schema: None,
            mode: CreateTableMode::default(),
            write_options: WriteOptions::default(),
        }
    }

    /// Apply the given write options when writing the initial data
    pub fn write_options(mut self, write_options: WriteOptions) -> Self {
        self.write_options = write_options;
        self
    }

    /// Execute the create table operation
    pub async fn execute(self) -> Result<TableRef> {
        self.parent.clone().do_create_table(self).await
    }
}

// Builder methods that only apply when we do not have initial data
impl CreateTableBuilder<false> {
    fn new(parent: Arc<dyn ConnectionInternal>, name: String, schema: SchemaRef) -> Self {
        Self {
            parent,
            name,
            data: None,
            schema: Some(schema),
            mode: CreateTableMode::default(),
            write_options: WriteOptions::default(),
        }
    }

    /// Execute the create table operation
    pub async fn execute(self) -> Result<TableRef> {
        self.parent.clone().do_create_empty_table(self).await
    }
}

impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
    /// Set the mode for creating the table
    ///
    /// This controls what happens if a table with the given name already exists
    pub fn mode(mut self, mode: CreateTableMode) -> Self {
        self.mode = mode;
        self
    }
}

#[derive(Clone, Debug)]
pub struct OpenTableBuilder {
    parent: Arc<dyn ConnectionInternal>,
    name: String,
    index_cache_size: u32,
    lance_read_params: Option<ReadParams>,
}

impl OpenTableBuilder {
    fn new(parent: Arc<dyn ConnectionInternal>, name: String) -> Self {
        Self {
            parent,
            name,
            index_cache_size: 256,
            lance_read_params: None,
        }
    }

    /// Set the size of the index cache, specified as a number of entries
    ///
    /// The default value is 256
    ///
    /// The exact meaning of an "entry" will depend on the type of index:
    /// * IVF - there is one entry for each IVF partition
    /// * BTREE - there is one entry for the entire index
    ///
    /// This cache applies to the entire opened table, across all indices.
    /// Setting this value higher will increase performance on larger datasets
    /// at the expense of more RAM
    pub fn index_cache_size(mut self, index_cache_size: u32) -> Self {
        self.index_cache_size = index_cache_size;
        self
    }

    /// Advanced parameters that can be used to customize table reads
    ///
    /// If set, these will take precedence over any overlapping `OpenTableOptions` options
    pub fn lance_read_params(mut self, params: ReadParams) -> Self {
        self.lance_read_params = Some(params);
        self
    }

    /// Open the table
    pub async fn execute(self) -> Result<TableRef> {
        self.parent.clone().do_open_table(self).await
    }
}

#[async_trait::async_trait]
trait ConnectionInternal: Send + Sync + std::fmt::Debug + 'static {
    async fn table_names(&self) -> Result<Vec<String>>;
    async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<TableRef>;
    async fn do_open_table(&self, options: OpenTableBuilder) -> Result<TableRef>;
    async fn drop_table(&self, name: &str) -> Result<()>;
    async fn drop_db(&self) -> Result<()>;

    async fn do_create_empty_table(&self, options: CreateTableBuilder<false>) -> Result<TableRef> {
        let batches = RecordBatchIterator::new(vec![], options.schema.unwrap());
        let opts = CreateTableBuilder::<true>::new(options.parent, options.name, Box::new(batches))
            .mode(options.mode)
            .write_options(options.write_options);
        self.do_create_table(opts).await
    }
}

/// A connection to LanceDB
#[derive(Clone)]
pub struct Connection {
    uri: String,
    internal: Arc<dyn ConnectionInternal>,
}

impl Connection {
    /// Get the URI of the connection
    pub fn uri(&self) -> &str {
        self.uri.as_str()
    }

    /// Get the names of all tables in the database.
    pub async fn table_names(&self) -> Result<Vec<String>> {
        self.internal.table_names().await
    }

    /// Create a new table from data
    ///
    /// # Parameters
    ///
    /// * `name` - The name of the table.
    /// * `batches` - The initial data to write to the table.
    /// * `params` - Optional [`WriteParams`] to create the table.
    ///
    /// # Returns
    /// Created [`TableRef`], or [`Err(Error::TableAlreadyExists)`] if the table already exists.
    async fn create_table(
    /// * `name` - The name of the table
    /// * `initial_data` - The initial data to write to the table
    pub fn create_table(
        &self,
        name: &str,
        batches: Box<dyn RecordBatchReader + Send>,
        params: Option<WriteParams>,
    ) -> Result<TableRef>;

    async fn open_table(&self, name: &str) -> Result<TableRef> {
        self.open_table_with_params(name, ReadParams::default())
            .await
        name: impl Into<String>,
        initial_data: Box<dyn RecordBatchReader + Send>,
    ) -> CreateTableBuilder<true> {
        CreateTableBuilder::<true>::new(self.internal.clone(), name.into(), initial_data)
    }

    async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<TableRef>;
    /// Create an empty table with a given schema
    ///
    /// # Parameters
    ///
    /// * `name` - The name of the table
    /// * `schema` - The schema of the table
    pub fn create_empty_table(
        &self,
        name: impl Into<String>,
        schema: SchemaRef,
    ) -> CreateTableBuilder<false> {
        CreateTableBuilder::<false>::new(self.internal.clone(), name.into(), schema)
    }

    /// Open an existing table in the database
    ///
    /// # Arguments
    /// * `name` - The name of the table
    ///
    /// # Returns
    /// Created [`TableRef`], or [`Error::TableNotFound`] if the table does not exist.
    pub fn open_table(&self, name: impl Into<String>) -> OpenTableBuilder {
        OpenTableBuilder::new(self.internal.clone(), name.into())
    }

    /// Drop a table in the database.
    ///
    /// # Arguments
    /// * `name` - The name of the table.
    async fn drop_table(&self, name: &str) -> Result<()>;
    /// * `name` - The name of the table to drop
    pub async fn drop_table(&self, name: impl AsRef<str>) -> Result<()> {
        self.internal.drop_table(name.as_ref()).await
    }

    /// Drop the database
    ///
    /// This is the same as dropping all of the tables
    pub async fn drop_db(&self) -> Result<()> {
        self.internal.drop_db().await
    }
}
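To make the two creation paths above concrete, a small hedged sketch assuming an existing `db: Connection` and the usual Arrow imports (table names and data are illustrative):

    use std::sync::Arc;
    use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
    use arrow_schema::{DataType, Field, Schema};

    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    // CreateTableBuilder<false>: empty table from a schema.
    let empty = db.create_empty_table("t_empty", schema.clone()).execute().await?;
    // CreateTableBuilder<true>: table seeded with initial data.
    let data = RecordBatchIterator::new(
        vec![RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from_iter_values(0..10))],
        )
        .unwrap()]
        .into_iter()
        .map(Ok),
        schema.clone(),
    );
    let seeded = db.create_table("t_data", Box::new(data)).execute().await?;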

#[derive(Debug)]
pub struct ConnectOptions {
pub struct ConnectBuilder {
    /// Database URI
    ///
    /// # Accepted URI formats
    /// ### Accepted URI formats
    ///
    /// - `/path/to/database` - local database on file system.
    /// - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
    /// - `db://dbname` - Lance Cloud
    pub uri: String,
    /// - `db://dbname` - LanceDB Cloud
    uri: String,

    /// Lance Cloud API key
    pub api_key: Option<String>,
    /// Lance Cloud region
    pub region: Option<String>,
    /// Lance Cloud host override
    pub host_override: Option<String>,
    /// LanceDB Cloud API key, required if using Lance Cloud
    api_key: Option<String>,
    /// LanceDB Cloud region, required if using Lance Cloud
    region: Option<String>,
    /// LanceDB Cloud host override, only required if using an on-premises Lance Cloud instance
    host_override: Option<String>,

    /// User provided AWS credentials
    pub aws_creds: Option<AwsCredential>,

    /// The maximum number of indices to cache in memory. Defaults to 256.
    pub index_cache_size: u32,
    aws_creds: Option<AwsCredential>,
}

impl ConnectOptions {
impl ConnectBuilder {
    /// Create a new [`ConnectOptions`] with the given database URI.
    pub fn new(uri: &str) -> Self {
        Self {
@@ -104,7 +314,6 @@ impl ConnectOptions {
            region: None,
            host_override: None,
            aws_creds: None,
            index_cache_size: 256,
        }
    }

@@ -124,15 +333,18 @@ impl ConnectOptions {
    }

    /// [`AwsCredential`] to use when connecting to S3.
    ///
    pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self {
        self.aws_creds = Some(aws_creds);
        self
    }

    pub fn index_cache_size(mut self, index_cache_size: u32) -> Self {
        self.index_cache_size = index_cache_size;
        self
    /// Establishes a connection to the database
    pub async fn execute(self) -> Result<Connection> {
        let internal = Arc::new(Database::connect_with_options(&self).await?);
        Ok(Connection {
            internal,
            uri: self.uri,
        })
    }
}

@@ -140,29 +352,14 @@ impl ConnectOptions {
///
/// # Arguments
///
/// - `uri` - URI where the database is located, can be a local file or a supported remote cloud storage
///
/// ## Accepted URI formats
///
/// - `/path/to/database` - local database on file system.
/// - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
/// - `db://dbname` - Lance Cloud
///
pub async fn connect(uri: &str) -> Result<Arc<dyn Connection>> {
    let options = ConnectOptions::new(uri);
    connect_with_options(&options).await
/// * `uri` - URI where the database is located, can be a local directory, supported remote cloud storage,
///   or a LanceDB Cloud database. See [ConnectOptions::uri] for a list of accepted formats
pub fn connect(uri: &str) -> ConnectBuilder {
    ConnectBuilder::new(uri)
}

/// Connect with [`ConnectOptions`].
///
/// # Arguments
/// - `options` - [`ConnectOptions`] to connect to the database.
pub async fn connect_with_options(options: &ConnectOptions) -> Result<Arc<dyn Connection>> {
    let db = Database::connect(&options.uri).await?;
    Ok(Arc::new(db))
}

pub struct Database {
#[derive(Debug)]
struct Database {
    object_store: ObjectStore,
    query_string: Option<String>,

@@ -179,21 +376,7 @@ const MIRRORED_STORE: &str = "mirroredStore";

/// A connection to LanceDB
impl Database {
    /// Connects to LanceDB
    ///
    /// # Arguments
    ///
    /// * `uri` - URI where the database is located, can be a local file or a supported remote cloud storage
    ///
    /// # Returns
    ///
    /// * A [Database] object.
    pub async fn connect(uri: &str) -> Result<Self> {
        let options = ConnectOptions::new(uri);
        Self::connect_with_options(&options).await
    }

    pub async fn connect_with_options(options: &ConnectOptions) -> Result<Self> {
    async fn connect_with_options(options: &ConnectBuilder) -> Result<Self> {
        let uri = &options.uri;
        let parse_res = url::Url::parse(uri);

@@ -333,7 +516,7 @@ impl Database {
    }

#[async_trait::async_trait]
impl Connection for Database {
impl ConnectionInternal for Database {
    async fn table_names(&self) -> Result<Vec<String>> {
        let mut f = self
            .object_store
@@ -354,40 +537,47 @@ impl Connection for Database {
        Ok(f)
    }

    async fn create_table(
        &self,
        name: &str,
        batches: Box<dyn RecordBatchReader + Send>,
        params: Option<WriteParams>,
    ) -> Result<TableRef> {
        let table_uri = self.table_uri(name)?;
    async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<TableRef> {
        let table_uri = self.table_uri(&options.name)?;

        Ok(Arc::new(
            NativeTable::create(
                &table_uri,
                name,
                batches,
                self.store_wrapper.clone(),
                params,
            )
            .await?,
        ))
        let mut write_params = options.write_options.lance_write_params.unwrap_or_default();
        if matches!(&options.mode, CreateTableMode::Overwrite) {
            write_params.mode = WriteMode::Overwrite;
        }

        match NativeTable::create(
            &table_uri,
            &options.name,
            options.data.unwrap(),
            self.store_wrapper.clone(),
            Some(write_params),
        )
        .await
        {
            Ok(table) => Ok(Arc::new(table)),
            Err(Error::TableAlreadyExists { name }) => match options.mode {
                CreateTableMode::Create => Err(Error::TableAlreadyExists { name }),
                CreateTableMode::ExistOk(callback) => {
                    let builder = OpenTableBuilder::new(options.parent, options.name);
                    let builder = (callback)(builder);
                    builder.execute().await
                }
                CreateTableMode::Overwrite => unreachable!(),
            },
            Err(err) => Err(err),
        }
    }

    /// Open a table in the database.
    ///
    /// # Arguments
    /// * `name` - The name of the table.
    /// * `params` - The parameters to open the table.
    ///
    /// # Returns
    ///
    /// * A [TableRef] object.
    async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<TableRef> {
        let table_uri = self.table_uri(name)?;
    async fn do_open_table(&self, options: OpenTableBuilder) -> Result<TableRef> {
        let table_uri = self.table_uri(&options.name)?;
        Ok(Arc::new(
            NativeTable::open_with_params(&table_uri, name, self.store_wrapper.clone(), params)
                .await?,
            NativeTable::open_with_params(
                &table_uri,
                &options.name,
                self.store_wrapper.clone(),
                options.lance_read_params,
            )
            .await?,
        ))
    }

@@ -397,12 +587,17 @@ impl Connection for Database {
        self.object_store.remove_dir_all(full_path).await?;
        Ok(())
    }

    async fn drop_db(&self) -> Result<()> {
        todo!()
    }
}

#[cfg(test)]
mod tests {
    use std::fs::create_dir_all;

    use arrow_schema::{DataType, Field, Schema};
    use tempfile::tempdir;

    use super::*;
@@ -411,7 +606,7 @@ mod tests {
    async fn test_connect() {
        let tmp_dir = tempdir().unwrap();
        let uri = tmp_dir.path().to_str().unwrap();
        let db = Database::connect(uri).await.unwrap();
        let db = connect(uri).execute().await.unwrap();

        assert_eq!(db.uri, uri);
    }
@@ -429,7 +624,8 @@ mod tests {
        let relative_root = std::path::PathBuf::from(relative_ancestors.join("/"));
        let relative_uri = relative_root.join(&uri);

        let db = Database::connect(relative_uri.to_str().unwrap())
        let db = connect(relative_uri.to_str().unwrap())
            .execute()
            .await
            .unwrap();

@@ -444,7 +640,7 @@ mod tests {
        create_dir_all(tmp_dir.path().join("invalidlance")).unwrap();

        let uri = tmp_dir.path().to_str().unwrap();
        let db = Database::connect(uri).await.unwrap();
        let db = connect(uri).execute().await.unwrap();
        let tables = db.table_names().await.unwrap();
        assert_eq!(tables.len(), 2);
        assert!(tables[0].eq(&String::from("table1")));
@@ -462,10 +658,44 @@ mod tests {
        create_dir_all(tmp_dir.path().join("table1.lance")).unwrap();

        let uri = tmp_dir.path().to_str().unwrap();
        let db = Database::connect(uri).await.unwrap();
        let db = connect(uri).execute().await.unwrap();
        db.drop_table("table1").await.unwrap();

        let tables = db.table_names().await.unwrap();
        assert_eq!(tables.len(), 0);
    }

    #[tokio::test]
    async fn test_create_table_already_exists() {
        let tmp_dir = tempdir().unwrap();
        let uri = tmp_dir.path().to_str().unwrap();
        let db = connect(uri).execute().await.unwrap();
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
        db.create_empty_table("test", schema.clone())
            .execute()
            .await
            .unwrap();
        // TODO: None of the open table options are "inspectable" right now but once one is we
        // should assert we are passing these options in correctly
        db.create_empty_table("test", schema)
            .mode(CreateTableMode::exist_ok(|builder| {
                builder.index_cache_size(16)
            }))
            .execute()
            .await
            .unwrap();
        let other_schema = Arc::new(Schema::new(vec![Field::new("y", DataType::Int32, false)]));
        assert!(db
            .create_empty_table("test", other_schema.clone())
            .execute()
            .await
            .is_err());
        let overwritten = db
            .create_empty_table("test", other_schema.clone())
            .mode(CreateTableMode::Overwrite)
            .execute()
            .await
            .unwrap();
        assert_eq!(other_schema, overwritten.schema());
    }
}

@@ -174,7 +174,6 @@ fn coerce_schema_batch(
}

/// Coerce the reader (input data) to match the given [Schema].
///
pub fn coerce_schema(
    reader: impl RecordBatchReader + Send + 'static,
    schema: Arc<Schema>,

@@ -342,7 +342,7 @@ mod test {
    use object_store::local::LocalFileSystem;
    use tempfile;

    use crate::connection::{Connection, Database};
    use crate::{connect, table::WriteOptions};

    #[tokio::test]
    async fn test_e2e() {
@@ -354,7 +354,7 @@ mod test {
            secondary: Arc::new(secondary_store),
        });

        let db = Database::connect(dir1.to_str().unwrap()).await.unwrap();
        let db = connect(dir1.to_str().unwrap()).execute().await.unwrap();

        let mut param = WriteParams::default();
        let store_params = ObjectStoreParams {
@@ -368,7 +368,11 @@ mod test {
        datagen = datagen.col(Box::new(RandomVector::default().named("vector".into())));

        let res = db
            .create_table("test", Box::new(datagen.batch(100)), Some(param.clone()))
            .create_table("test", Box::new(datagen.batch(100)))
            .write_options(WriteOptions {
                lance_write_params: Some(param),
            })
            .execute()
            .await;

        // leave this here for easy debugging

@@ -43,10 +43,9 @@
//! #### Connect to a database.
//!
//! ```rust
//! use vectordb::connect;
//! # use arrow_schema::{Field, Schema};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! let db = connect("data/sample-lancedb").await.unwrap();
//! let db = vectordb::connect("data/sample-lancedb").execute().await.unwrap();
//! # });
//! ```
//!
@@ -56,14 +55,20 @@
//! - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
//! - `db://dbname` - Lance Cloud
//!
//! You can also use [`ConnectOptions`] to configure the connectoin to the database.
//! You can also use [`ConnectOptions`] to configure the connection to the database.
//!
//! ```rust
//! use vectordb::{connect_with_options, ConnectOptions};
//! use object_store::aws::AwsCredential;
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! let options = ConnectOptions::new("data/sample-lancedb")
//!     .index_cache_size(1024);
//! let db = connect_with_options(&options).await.unwrap();
//! let db = vectordb::connect("data/sample-lancedb")
//!     .aws_creds(AwsCredential {
//!         key_id: "some_key".to_string(),
//!         secret_key: "some_secret".to_string(),
//!         token: None,
//!     })
//!     .execute()
//!     .await
//!     .unwrap();
//! # });
//! ```
//!
@@ -79,31 +84,44 @@
//!
//! ```rust
//! # use std::sync::Arc;
//! use arrow_schema::{DataType, Schema, Field};
//! use arrow_array::{RecordBatch, RecordBatchIterator};
//! use arrow_schema::{DataType, Field, Schema};
//! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type};
//! # use vectordb::connection::{Database, Connection};
//! # use vectordb::connect;
//!
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! # let tmpdir = tempfile::tempdir().unwrap();
//! # let db = connect(tmpdir.path().to_str().unwrap()).await.unwrap();
//! # let db = vectordb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
//! let schema = Arc::new(Schema::new(vec![
//!     Field::new("id", DataType::Int32, false),
//!     Field::new("vector", DataType::FixedSizeList(
//!         Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
//!     Field::new("id", DataType::Int32, false),
//!     Field::new(
//!         "vector",
//!         DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 128),
//!         true,
//!     ),
//! ]));
//! // Create a RecordBatch stream.
//! let batches = RecordBatchIterator::new(vec![
//!     RecordBatch::try_new(schema.clone(),
//! let batches = RecordBatchIterator::new(
//!     vec![RecordBatch::try_new(
//!         schema.clone(),
//!         vec![
//!             Arc::new(Int32Array::from_iter_values(0..1000)),
//!             Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
//!                 (0..1000).map(|_| Some(vec![Some(1.0); 128])), 128)),
//!         ]).unwrap()
//!     ].into_iter().map(Ok),
//!     schema.clone());
//! db.create_table("my_table", Box::new(batches), None).await.unwrap();
//!             Arc::new(Int32Array::from_iter_values(0..256)),
//!             Arc::new(
//!                 FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
//!                     (0..256).map(|_| Some(vec![Some(1.0); 128])),
//!                     128,
//!                 ),
//!             ),
//!         ],
//!     )
//!     .unwrap()]
//!     .into_iter()
//!     .map(Ok),
//!     schema.clone(),
//! );
//! db.create_table("my_table", Box::new(batches))
//!     .execute()
//!     .await
//!     .unwrap();
//! # });
//! ```
//!
@@ -111,14 +129,13 @@
//!
//! ```no_run
//! # use std::sync::Arc;
//! # use vectordb::connect;
//! # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
//! # RecordBatchIterator, Int32Array};
//! # use arrow_schema::{Schema, Field, DataType};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! # let tmpdir = tempfile::tempdir().unwrap();
//! # let db = connect(tmpdir.path().to_str().unwrap()).await.unwrap();
//! # let tbl = db.open_table("idx_test").await.unwrap();
//! # let db = vectordb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
//! # let tbl = db.open_table("idx_test").execute().await.unwrap();
//! tbl.create_index(&["vector"])
//!     .ivf_pq()
//!     .num_partitions(256)
@@ -136,10 +153,9 @@
//! # use arrow_schema::{DataType, Schema, Field};
//! # use arrow_array::{RecordBatch, RecordBatchIterator};
//! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type};
//! # use vectordb::connection::{Database, Connection};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! # let tmpdir = tempfile::tempdir().unwrap();
//! # let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
//! # let db = vectordb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
//! # let schema = Arc::new(Schema::new(vec![
//! #   Field::new("id", DataType::Int32, false),
//! #   Field::new("vector", DataType::FixedSizeList(
@@ -154,8 +170,8 @@
//! #     ]).unwrap()
//! #   ].into_iter().map(Ok),
//! #   schema.clone());
//! # db.create_table("my_table", Box::new(batches), None).await.unwrap();
//! # let table = db.open_table("my_table").await.unwrap();
//! # db.create_table("my_table", Box::new(batches)).execute().await.unwrap();
//! # let table = db.open_table("my_table").execute().await.unwrap();
//! let results = table
//!     .search(&[1.0; 128])
//!     .execute_stream()
@@ -165,8 +181,6 @@
//!     .await
//!     .unwrap();
//! # });
//!
//!
//! ```

pub mod connection;
@@ -179,10 +193,8 @@ pub mod query;
pub mod table;
pub mod utils;

pub use connection::{Connection, Database};
pub use error::{Error, Result};
pub use table::{Table, TableRef};

/// Connect to a database
pub use connection::{connect, connect_with_options, ConnectOptions};
pub use lance::dataset::WriteMode;
pub use connection::connect;

@@ -60,7 +60,6 @@ impl Query {
    /// # Arguments
    ///
    /// * `dataset` - Lance dataset.
    ///
    pub(crate) fn new(dataset: Arc<Dataset>) -> Self {
        Self {
            dataset,

@@ -17,7 +17,7 @@
use std::path::Path;
use std::sync::{Arc, Mutex};

use arrow_array::RecordBatchReader;
use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::{Schema, SchemaRef};
use async_trait::async_trait;
use chrono::Duration;
@@ -27,7 +27,7 @@ use lance::dataset::optimize::{
    compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
};
pub use lance::dataset::ReadParams;
use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteParams};
use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteMode, WriteParams};
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
use lance::io::WrappingObjectStore;
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
@@ -38,7 +38,6 @@ use crate::index::vector::{VectorIndex, VectorIndexStatistics};
use crate::index::IndexBuilder;
use crate::query::Query;
use crate::utils::{PatchReadParam, PatchWriteParam};
use crate::WriteMode;

use self::merge::{MergeInsert, MergeInsertBuilder};

@@ -85,6 +84,35 @@ pub struct OptimizeStats {
    pub prune: Option<RemovalStats>,
}

/// Options to use when writing data
#[derive(Clone, Debug, Default)]
pub struct WriteOptions {
    // Coming soon: https://github.com/lancedb/lancedb/issues/992
    // /// What behavior to take if the data contains invalid vectors
    // pub on_bad_vectors: BadVectorHandling,
    /// Advanced parameters that can be used to customize table creation
    ///
    /// If set, these will take precedence over any overlapping `OpenTableOptions` options
    pub lance_write_params: Option<WriteParams>,
}

#[derive(Debug, Clone, Default)]
pub enum AddDataMode {
    /// Rows will be appended to the table (the default)
    #[default]
    Append,
    /// The existing table will be overwritten with the new data
    Overwrite,
}

#[derive(Debug, Default, Clone)]
pub struct AddDataOptions {
    /// Whether to add new rows (the default) or replace the existing data
    pub mode: AddDataMode,
    /// Options to use when writing the data
    pub write_options: WriteOptions,
}
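A quick hedged sketch of how these two knobs combine in `Table::add`, assuming a `table: TableRef` and a `batches` reader are already in hand:

    // Replace the table's contents instead of appending; any explicit
    // lance_write_params would take precedence over this mode.
    table
        .add(
            Box::new(batches),
            AddDataOptions {
                mode: AddDataMode::Overwrite,
                ..Default::default()
            },
        )
        .await?;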

/// A Table is a collection of strong typed Rows.
///
/// The type of the each row is defined in Apache Arrow [Schema].
@@ -112,12 +140,12 @@ pub trait Table: std::fmt::Display + Send + Sync {
    ///
    /// # Arguments
    ///
    /// * `batches` RecordBatch to be saved in the Table
    /// * `params` Append / Overwrite existing records. Default: Append
    /// * `batches` data to be added to the Table
    /// * `options` options to control how data is added
    async fn add(
        &self,
        batches: Box<dyn RecordBatchReader + Send>,
        params: Option<WriteParams>,
        options: AddDataOptions,
    ) -> Result<()>;

    /// Delete the rows from table that match the predicate.
@@ -129,28 +157,43 @@ pub trait Table: std::fmt::Display + Send + Sync {
    ///
    /// ```no_run
    /// # use std::sync::Arc;
    /// # use vectordb::connection::{Database, Connection};
    /// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
    /// # RecordBatchIterator, Int32Array};
    /// # use arrow_schema::{Schema, Field, DataType};
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
    /// let tmpdir = tempfile::tempdir().unwrap();
    /// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
    /// let db = vectordb::connect(tmpdir.path().to_str().unwrap())
    ///     .execute()
    ///     .await
    ///     .unwrap();
    /// # let schema = Arc::new(Schema::new(vec![
    /// #   Field::new("id", DataType::Int32, false),
    /// #   Field::new("vector", DataType::FixedSizeList(
    /// #     Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
    /// # ]));
    /// let batches = RecordBatchIterator::new(vec![
    ///     RecordBatch::try_new(schema.clone(),
    ///         vec![
    ///             Arc::new(Int32Array::from_iter_values(0..10)),
    ///             Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
    ///                 (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)),
    ///         ]).unwrap()
    ///     ].into_iter().map(Ok),
    ///     schema.clone());
    /// let tbl = db.create_table("delete_test", Box::new(batches), None).await.unwrap();
    /// let batches = RecordBatchIterator::new(
    ///     vec![RecordBatch::try_new(
    ///         schema.clone(),
    ///         vec![
    ///             Arc::new(Int32Array::from_iter_values(0..10)),
    ///             Arc::new(
    ///                 FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
    ///                     (0..10).map(|_| Some(vec![Some(1.0); 128])),
    ///                     128,
    ///                 ),
    ///             ),
    ///         ],
    ///     )
    ///     .unwrap()]
    ///     .into_iter()
    ///     .map(Ok),
    ///     schema.clone(),
    /// );
    /// let tbl = db
    ///     .create_table("delete_test", Box::new(batches))
    ///     .execute()
    ///     .await
    ///     .unwrap();
    /// tbl.delete("id > 5").await.unwrap();
    /// # });
    /// ```
@@ -162,14 +205,16 @@ pub trait Table: std::fmt::Display + Send + Sync {
    ///
    /// ```no_run
    /// # use std::sync::Arc;
    /// # use vectordb::connection::{Database, Connection};
    /// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
    /// # RecordBatchIterator, Int32Array};
    /// # use arrow_schema::{Schema, Field, DataType};
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
    /// let tmpdir = tempfile::tempdir().unwrap();
    /// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
    /// # let tbl = db.open_table("idx_test").await.unwrap();
    /// let db = vectordb::connect(tmpdir.path().to_str().unwrap())
    ///     .execute()
    ///     .await
    ///     .unwrap();
    /// # let tbl = db.open_table("idx_test").execute().await.unwrap();
    /// tbl.create_index(&["vector"])
    ///     .ivf_pq()
    ///     .num_partitions(256)
@@ -214,32 +259,44 @@ pub trait Table: std::fmt::Display + Send + Sync {
    ///
    /// ```no_run
    /// # use std::sync::Arc;
    /// # use vectordb::connection::{Database, Connection};
    /// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
    /// # RecordBatchIterator, Int32Array};
    /// # use arrow_schema::{Schema, Field, DataType};
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
    /// let tmpdir = tempfile::tempdir().unwrap();
    /// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
    /// # let tbl = db.open_table("idx_test").await.unwrap();
    /// let db = vectordb::connect(tmpdir.path().to_str().unwrap())
    ///     .execute()
    ///     .await
    ///     .unwrap();
    /// # let tbl = db.open_table("idx_test").execute().await.unwrap();
    /// # let schema = Arc::new(Schema::new(vec![
    /// #  Field::new("id", DataType::Int32, false),
    /// #  Field::new("vector", DataType::FixedSizeList(
    /// #    Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
    /// # ]));
    /// let new_data = RecordBatchIterator::new(vec![
    ///     RecordBatch::try_new(schema.clone(),
    ///         vec![
    ///             Arc::new(Int32Array::from_iter_values(0..10)),
    ///             Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
    ///                 (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)),
    ///         ]).unwrap()
    ///     ].into_iter().map(Ok),
    ///     schema.clone());
    /// let new_data = RecordBatchIterator::new(
    ///     vec![RecordBatch::try_new(
    ///         schema.clone(),
    ///         vec![
    ///             Arc::new(Int32Array::from_iter_values(0..10)),
    ///             Arc::new(
    ///                 FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
    ///                     (0..10).map(|_| Some(vec![Some(1.0); 128])),
    ///                     128,
    ///                 ),
    ///             ),
    ///         ],
    ///     )
    ///     .unwrap()]
    ///     .into_iter()
    ///     .map(Ok),
    ///     schema.clone(),
    /// );
    /// // Perform an upsert operation
    /// let mut merge_insert = tbl.merge_insert(&["id"]);
    /// merge_insert.when_matched_update_all(None)
    ///     .when_not_matched_insert_all();
    /// merge_insert
    ///     .when_matched_update_all(None)
    ///     .when_not_matched_insert_all();
    /// merge_insert.execute(Box::new(new_data)).await.unwrap();
    /// # });
    /// ```
@@ -266,7 +323,9 @@ pub trait Table: std::fmt::Display + Send + Sync {
    /// # use futures::TryStreamExt;
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
    /// # let tbl = vectordb::table::NativeTable::open("/tmp/tbl").await.unwrap();
    /// let stream = tbl.query().nearest_to(&[1.0, 2.0, 3.0])
    /// let stream = tbl
    ///     .query()
    ///     .nearest_to(&[1.0, 2.0, 3.0])
    ///     .refine_factor(5)
    ///     .nprobes(10)
    ///     .execute_stream()
@@ -299,11 +358,7 @@ pub trait Table: std::fmt::Display + Send + Sync {
    /// # use futures::TryStreamExt;
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
    /// # let tbl = vectordb::table::NativeTable::open("/tmp/tbl").await.unwrap();
    /// let stream = tbl
    ///     .query()
    ///     .execute_stream()
    ///     .await
    ///     .unwrap();
    /// let stream = tbl.query().execute_stream().await.unwrap();
    /// let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
    /// # });
    /// ```
@@ -351,7 +406,7 @@ impl NativeTable {
    /// * A [NativeTable] object.
    pub async fn open(uri: &str) -> Result<Self> {
        let name = Self::get_table_name(uri)?;
        Self::open_with_params(uri, &name, None, ReadParams::default()).await
        Self::open_with_params(uri, &name, None, None).await
    }

    /// Opens an existing Table
@@ -369,8 +424,9 @@ impl NativeTable {
        uri: &str,
        name: &str,
        write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
        params: ReadParams,
        params: Option<ReadParams>,
    ) -> Result<Self> {
        let params = params.unwrap_or_default();
        // patch the params if we have a write store wrapper
        let params = match write_store_wrapper.clone() {
            Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
@@ -403,7 +459,6 @@ impl NativeTable {
    }

    /// Checkout a specific version of this [NativeTable]
    ///
    pub async fn checkout(uri: &str, version: u64) -> Result<Self> {
        let name = Self::get_table_name(uri)?;
        Self::checkout_with_params(uri, &name, version, None, ReadParams::default()).await
@@ -489,13 +544,14 @@ impl NativeTable {
        write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
        params: Option<WriteParams>,
    ) -> Result<Self> {
        let params = params.unwrap_or_default();
        // patch the params if we have a write store wrapper
        let params = match write_store_wrapper.clone() {
            Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
            None => params,
        };

        let dataset = Dataset::write(batches, uri, params)
        let dataset = Dataset::write(batches, uri, Some(params))
            .await
            .map_err(|e| match e {
                lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists {
@@ -513,6 +569,17 @@ impl NativeTable {
        })
    }

    pub async fn create_empty(
        uri: &str,
        name: &str,
        schema: SchemaRef,
        write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
        params: Option<WriteParams>,
    ) -> Result<Self> {
        let batches = RecordBatchIterator::new(vec![], schema);
        Self::create(uri, name, batches, write_store_wrapper, params).await
    }
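As a minimal hedged sketch of the new helper above, assuming `vectordb::table::NativeTable` is in scope (the path and names are placeholders):

    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    // No store wrapper and no custom write params (both None).
    let tbl = NativeTable::create_empty("/tmp/empty.lance", "empty", schema, None, None).await?;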

    /// Version of this Table
    pub fn version(&self) -> u64 {
        self.dataset.lock().expect("lock poison").version().version
@@ -740,20 +807,26 @@ impl Table for NativeTable {
    async fn add(
        &self,
        batches: Box<dyn RecordBatchReader + Send>,
        params: Option<WriteParams>,
        params: AddDataOptions,
    ) -> Result<()> {
        let params = Some(params.unwrap_or(WriteParams {
            mode: WriteMode::Append,
            ..WriteParams::default()
        }));
        let lance_params = params
            .write_options
            .lance_write_params
            .unwrap_or(WriteParams {
                mode: match params.mode {
                    AddDataMode::Append => WriteMode::Append,
                    AddDataMode::Overwrite => WriteMode::Overwrite,
                },
                ..Default::default()
            });

        // patch the params if we have a write store wrapper
        let params = match self.store_wrapper.clone() {
            Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
            None => params,
        let lance_params = match self.store_wrapper.clone() {
            Some(wrapper) => lance_params.patch_with_store_wrapper(wrapper)?,
            None => lance_params,
        };

        self.reset_dataset(Dataset::write(batches, &self.uri, params).await?);
        self.reset_dataset(Dataset::write(batches, &self.uri, Some(lance_params)).await?);
        Ok(())
    }

@@ -881,25 +954,6 @@ mod tests {
        assert_eq!(c.to_str().unwrap(), "s3://bucket/path/to/file/subfile");
    }

    #[tokio::test]
    async fn test_create_already_exists() {
        let tmp_dir = tempdir().unwrap();
        let uri = tmp_dir.path().to_str().unwrap();

        let batches = make_test_batches();
        let _ = batches.schema().clone();
        NativeTable::create(uri, "test", batches, None, None)
            .await
            .unwrap();

        let batches = make_test_batches();
        let result = NativeTable::create(uri, "test", batches, None, None).await;
        assert!(matches!(
            result.unwrap_err(),
            Error::TableAlreadyExists { .. }
        ));
    }

    #[tokio::test]
    async fn test_count_rows() {
        let tmp_dir = tempdir().unwrap();
@@ -940,7 +994,10 @@ mod tests {
            schema.clone(),
        );

        table.add(Box::new(new_batches), None).await.unwrap();
        table
            .add(Box::new(new_batches), AddDataOptions::default())
            .await
            .unwrap();
        assert_eq!(table.count_rows(None).await.unwrap(), 20);
        assert_eq!(table.name, "test");
    }
@@ -1003,23 +1060,47 @@ mod tests {
            .unwrap();
        assert_eq!(table.count_rows(None).await.unwrap(), 10);

        let new_batches = RecordBatchIterator::new(
            vec![RecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int32Array::from_iter_values(100..110))],
            )
            .unwrap()]
            .into_iter()
            .map(Ok),
        let batches = vec![RecordBatch::try_new(
            schema.clone(),
        );
            vec![Arc::new(Int32Array::from_iter_values(100..110))],
        )
        .unwrap()]
        .into_iter()
        .map(Ok);

        let new_batches = RecordBatchIterator::new(batches.clone(), schema.clone());

        // Can overwrite using AddDataOptions::mode
        table
            .add(
                Box::new(new_batches),
                AddDataOptions {
                    mode: AddDataMode::Overwrite,
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        assert_eq!(table.count_rows(None).await.unwrap(), 10);
        assert_eq!(table.name, "test");

        // Can overwrite using underlying WriteParams (which
        // take precedence over AddDataOptions::mode)

        let param: WriteParams = WriteParams {
            mode: WriteMode::Overwrite,
            ..Default::default()
        };

        table.add(Box::new(new_batches), Some(param)).await.unwrap();
        let opts = AddDataOptions {
            write_options: WriteOptions {
                lance_write_params: Some(param),
            },
            mode: AddDataMode::Append,
        };

        let new_batches = RecordBatchIterator::new(batches.clone(), schema.clone());
        table.add(Box::new(new_batches), opts).await.unwrap();
        assert_eq!(table.count_rows(None).await.unwrap(), 10);
        assert_eq!(table.name, "test");
    }
@@ -1329,7 +1410,7 @@ mod tests {
            ..Default::default()
        };
        assert!(!wrapper.called());
        let _ = NativeTable::open_with_params(uri, "test", None, param)
        let _ = NativeTable::open_with_params(uri, "test", None, Some(param))
            .await
            .unwrap();
        assert!(wrapper.called());

@@ -32,20 +32,17 @@ impl PatchStoreParam for Option<ObjectStoreParams> {
}

pub trait PatchWriteParam {
    fn patch_with_store_wrapper(
        self,
        wrapper: Arc<dyn WrappingObjectStore>,
    ) -> Result<Option<WriteParams>>;
    fn patch_with_store_wrapper(self, wrapper: Arc<dyn WrappingObjectStore>)
        -> Result<WriteParams>;
}

impl PatchWriteParam for Option<WriteParams> {
impl PatchWriteParam for WriteParams {
    fn patch_with_store_wrapper(
        self,
        mut self,
        wrapper: Arc<dyn WrappingObjectStore>,
    ) -> Result<Option<WriteParams>> {
        let mut params = self.unwrap_or_default();
        params.store_params = params.store_params.patch_with_store_wrapper(wrapper)?;
        Ok(Some(params))
    ) -> Result<WriteParams> {
        self.store_params = self.store_params.patch_with_store_wrapper(wrapper)?;
        Ok(self)
    }
}