mirror of https://github.com/lancedb/lancedb.git
synced 2026-01-04 19:02:58 +00:00
feat(rust): create index API improvement (#853)
* Extract a minimal Table interface in Rust SDK
* Make create_index composable in Rust
* Fix compiling issues from ffi
This commit is contained in:
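The headline change: index creation is now driven from the table itself through a chainable builder, instead of constructing a separate VectorIndexBuilder and handing it to Table::create_index. A minimal sketch of the new call shape (database path and table name are illustrative; the builder methods are the ones added in this diff):

use vectordb::connection::{Connection, Database};

#[tokio::main]
async fn main() -> vectordb::Result<()> {
    let db = Database::connect("data/sample-lancedb").await?;
    // open_table now returns a TableRef (Arc<dyn Table>) instead of a concrete Table.
    let table = db.open_table("my_table").await?;
    // Composable index construction: pick the index type, chain only the
    // parameters you care about, and let the rest fall back to defaults.
    table
        .create_index(&["vector"])
        .ivf_pq()
        .num_partitions(256)
        .build()
        .await?;
    Ok(())
}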
@@ -24,11 +24,14 @@ arrow-ord = "49.0"
arrow-schema = "49.0"
arrow-arith = "49.0"
arrow-cast = "49.0"
async-trait = "0"
chrono = "0.4.23"
half = { "version" = "=2.3.1", default-features = false, features = [
    "num-traits",
] }
futures = "0"
log = "0.4"
object_store = "0.9.0"
snafu = "0.7.4"
url = "2"
num-traits = "0.2"
@@ -16,16 +16,16 @@ use crate::query::Query;
use arrow_ipc::writer::FileWriter;
use napi::bindgen_prelude::*;
use napi_derive::napi;
-use vectordb::{ipc::ipc_file_to_batches, table::Table as LanceDBTable};
+use vectordb::{ipc::ipc_file_to_batches, table::TableRef};

#[napi]
pub struct Table {
-    pub(crate) table: LanceDBTable,
+    pub(crate) table: TableRef,
}

#[napi]
impl Table {
-    pub(crate) fn new(table: LanceDBTable) -> Self {
+    pub(crate) fn new(table: TableRef) -> Self {
        Self { table }
    }

@@ -46,7 +46,7 @@ impl Table {
    pub async unsafe fn add(&mut self, buf: Buffer) -> napi::Result<()> {
        let batches = ipc_file_to_batches(buf.to_vec())
            .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
-        self.table.add(batches, None).await.map_err(|e| {
+        self.table.add(Box::new(batches), None).await.map_err(|e| {
            napi::Error::from_reason(format!(
                "Failed to add batches to table {}: {}",
                self.table, e

@@ -29,10 +29,14 @@ pub(crate) fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsP

    let (deferred, promise) = cx.promise();
    let channel = cx.channel();
-    let mut table = js_table.table.clone();
+    let table = js_table.table.clone();

    rt.spawn(async move {
-        let idx_result = table.create_scalar_index(&column, replace).await;
+        let idx_result = table
+            .as_native()
+            .unwrap()
+            .create_scalar_index(&column, replace)
+            .await;

        deferred.settle_with(&channel, move |mut cx| {
            idx_result.or_throw(&mut cx)?;

@@ -12,13 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

-use lance_index::vector::{ivf::IvfBuildParams, pq::PQBuildParams};
use lance_linalg::distance::MetricType;
use neon::context::FunctionContext;
use neon::prelude::*;
-use std::convert::TryFrom;

-use vectordb::index::vector::{IvfPQIndexBuilder, VectorIndexBuilder};
+use vectordb::index::IndexBuilder;

use crate::error::Error::InvalidIndexType;
use crate::error::ResultExt;

@@ -29,17 +27,24 @@ use crate::table::JsTable;
pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
    let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
    let index_params = cx.argument::<JsObject>(0)?;
-    let index_params_builder = get_index_params_builder(&mut cx, index_params).or_throw(&mut cx)?;

    let rt = runtime(&mut cx)?;

    let (deferred, promise) = cx.promise();
    let channel = cx.channel();
-    let mut table = js_table.table.clone();
+    let table = js_table.table.clone();

+    let column_name = index_params
+        .get_opt::<JsString, _, _>(&mut cx, "column")?
+        .map(|s| s.value(&mut cx))
+        .unwrap_or("vector".to_string()); // Backward compatibility
+
+    let tbl = table.clone();
+    let mut index_builder = tbl.create_index(&[&column_name]);
+    get_index_params_builder(&mut cx, index_params, &mut index_builder).or_throw(&mut cx)?;

    rt.spawn(async move {
-        let idx_result = table.create_index(&index_params_builder).await;
-
+        let idx_result = index_builder.build().await;
        deferred.settle_with(&channel, move |mut cx| {
            idx_result.or_throw(&mut cx)?;
            Ok(cx.boxed(JsTable::from(table)))
@@ -51,66 +56,39 @@ pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsP
fn get_index_params_builder(
    cx: &mut FunctionContext,
    obj: Handle<JsObject>,
-) -> crate::error::Result<impl VectorIndexBuilder> {
-    let idx_type = obj.get::<JsString, _, _>(cx, "type")?.value(cx);
-
-    match idx_type.as_str() {
-        "ivf_pq" => {
-            let mut index_builder: IvfPQIndexBuilder = IvfPQIndexBuilder::new();
-            let mut pq_params = PQBuildParams::default();
-
-            obj.get_opt::<JsString, _, _>(cx, "column")?
-                .map(|s| index_builder.column(s.value(cx)));
-
-            obj.get_opt::<JsString, _, _>(cx, "index_name")?
-                .map(|s| index_builder.index_name(s.value(cx)));
-
-            if let Some(metric_type) = obj.get_opt::<JsString, _, _>(cx, "metric_type")? {
-                let metric_type = MetricType::try_from(metric_type.value(cx).as_str()).unwrap();
-                index_builder.metric_type(metric_type);
-            }
-
-            let num_partitions = obj.get_opt_usize(cx, "num_partitions")?;
-            let max_iters = obj.get_opt_usize(cx, "max_iters")?;
-
-            num_partitions.map(|np| {
-                let max_iters = max_iters.unwrap_or(50);
-                let ivf_params = IvfBuildParams {
-                    num_partitions: np,
-                    max_iters,
-                    ..Default::default()
-                };
-                index_builder.ivf_params(ivf_params)
-            });
-
-            if let Some(use_opq) = obj.get_opt::<JsBoolean, _, _>(cx, "use_opq")? {
-                pq_params.use_opq = use_opq.value(cx);
-            }
-
-            if let Some(num_sub_vectors) = obj.get_opt_usize(cx, "num_sub_vectors")? {
-                pq_params.num_sub_vectors = num_sub_vectors;
-            }
-
-            if let Some(num_bits) = obj.get_opt_usize(cx, "num_bits")? {
-                pq_params.num_bits = num_bits;
-            }
-
-            if let Some(max_iters) = obj.get_opt_usize(cx, "max_iters")? {
-                pq_params.max_iters = max_iters;
-            }
-
-            if let Some(max_opq_iters) = obj.get_opt_usize(cx, "max_opq_iters")? {
-                pq_params.max_opq_iters = max_opq_iters;
-            }
-
-            if let Some(replace) = obj.get_opt::<JsBoolean, _, _>(cx, "replace")? {
-                index_builder.replace(replace.value(cx));
-            }
-
-            Ok(index_builder)
+    builder: &mut IndexBuilder,
+) -> crate::error::Result<()> {
+    match obj.get::<JsString, _, _>(cx, "type")?.value(cx).as_str() {
+        "ivf_pq" => builder.ivf_pq(),
-        _ => {
-            return Err(InvalidIndexType {
-                index_type: "".into(),
-            })
-        }
+        index_type => Err(InvalidIndexType {
+            index_type: index_type.into(),
+        }),
+    };

+    obj.get_opt::<JsString, _, _>(cx, "index_name")?
+        .map(|s| builder.name(s.value(cx).as_str()));

+    if let Some(metric_type) = obj.get_opt::<JsString, _, _>(cx, "metric_type")? {
+        let metric_type = MetricType::try_from(metric_type.value(cx).as_str())?;
+        builder.metric_type(metric_type);
+    }

+    if let Some(np) = obj.get_opt_usize(cx, "num_partitions")? {
+        builder.num_partitions(np as u64);
+    }
+    if let Some(ns) = obj.get_opt_u32(cx, "num_sub_vectors")? {
+        builder.num_sub_vectors(ns);
+    }
+    if let Some(max_iters) = obj.get_opt_u32(cx, "max_iters")? {
+        builder.max_iterations(max_iters);
+    }
+    if let Some(num_bits) = obj.get_opt_u32(cx, "num_bits")? {
+        builder.num_bits(num_bits);
+    }
+    if let Some(replace) = obj.get_opt::<JsBoolean, _, _>(cx, "replace")? {
+        builder.replace(replace.value(cx));
+    }
+    Ok(())
}

@@ -70,14 +70,15 @@ impl JsQuery {
        let query = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx));

        rt.spawn(async move {
-            let mut builder = table
-                .search(query)
-                .refine_factor(refine_factor)
-                .nprobes(nprobes)
-                .filter(filter)
-                .metric_type(metric_type)
-                .select(select)
-                .prefilter(prefilter);
+            let mut builder = table.query();
+            if let Some(query) = query {
+                builder = builder
+                    .query_vector(&query)
+                    .refine_factor(refine_factor)
+                    .nprobes(nprobes)
+                    .metric_type(metric_type);
+            };
+            builder = builder.filter(filter).select(select).prefilter(prefilter);
            if let Some(limit) = limit {
                builder = builder.limit(limit as usize);
            };
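With the query vector no longer baked into builder construction, the binding can apply the vector-only knobs conditionally, and a plain filtered scan reuses the same path. A sketch against the Rust API introduced in this diff (the four-element vector is illustrative):

async fn scan_and_search(table: vectordb::TableRef) {
    // Plain scan: no query vector, so ANN parameters never come into play.
    let _scan = table.query().filter(Some("id % 2 == 0".to_string()));
    // Vector search: `search` seeds the builder with a query vector, after
    // which nprobes / refine_factor / metric_type become meaningful.
    let _search = table.search(&[0.1, 0.2, 0.3, 0.4]).nprobes(20).limit(10);
}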
@@ -20,19 +20,19 @@ use lance::io::ObjectStoreParams;
use crate::arrow::{arrow_buffer_to_record_batch, record_batch_to_buffer};
use neon::prelude::*;
use neon::types::buffer::TypedArray;
-use vectordb::Table;
+use vectordb::TableRef;

use crate::error::ResultExt;
use crate::{convert, get_aws_creds, get_aws_region, runtime, JsDatabase};

pub(crate) struct JsTable {
-    pub table: Table,
+    pub table: TableRef,
}

impl Finalize for JsTable {}

-impl From<Table> for JsTable {
-    fn from(table: Table) -> Self {
+impl From<TableRef> for JsTable {
+    fn from(table: TableRef) -> Self {
        JsTable { table }
    }
}

@@ -96,7 +96,7 @@ impl JsTable {
        arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?;
        let rt = runtime(&mut cx)?;
        let channel = cx.channel();
-        let mut table = js_table.table.clone();
+        let table = js_table.table.clone();

        let (deferred, promise) = cx.promise();
        let write_mode = match write_mode.as_str() {

@@ -118,7 +118,7 @@ impl JsTable {

        rt.spawn(async move {
            let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
-            let add_result = table.add(batch_reader, Some(params)).await;
+            let add_result = table.add(Box::new(batch_reader), Some(params)).await;

            deferred.settle_with(&channel, move |mut cx| {
                add_result.or_throw(&mut cx)?;

@@ -152,7 +152,7 @@ impl JsTable {
        let (deferred, promise) = cx.promise();
        let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
        let channel = cx.channel();
-        let mut table = js_table.table.clone();
+        let table = js_table.table.clone();

        rt.spawn(async move {
            let delete_result = table.delete(&predicate).await;

@@ -167,7 +167,7 @@ impl JsTable {

    pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> {
        let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
-        let mut table = js_table.table.clone();
+        let table = js_table.table.clone();

        let rt = runtime(&mut cx)?;
        let (deferred, promise) = cx.promise();

@@ -218,7 +218,11 @@ impl JsTable {

            let predicate = predicate.as_deref();

-            let update_result = table.update(predicate, updates_arg).await;
+            let update_result = table
+                .as_native()
+                .unwrap()
+                .update(predicate, updates_arg)
+                .await;
            deferred.settle_with(&channel, move |mut cx| {
                update_result.or_throw(&mut cx)?;
                Ok(cx.boxed(JsTable::from(table)))

@@ -249,6 +253,8 @@ impl JsTable {

        rt.spawn(async move {
            let stats = table
+                .as_native()
+                .unwrap()
                .cleanup_old_versions(older_than, Some(delete_unverified))
                .await;

@@ -278,7 +284,7 @@ impl JsTable {
        let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
        let rt = runtime(&mut cx)?;
        let (deferred, promise) = cx.promise();
-        let mut table = js_table.table.clone();
+        let table = js_table.table.clone();
        let channel = cx.channel();

        let js_options = cx.argument::<JsObject>(0)?;

@@ -310,7 +316,11 @@ impl JsTable {
        }

        rt.spawn(async move {
-            let stats = table.compact_files(options, None).await;
+            let stats = table
+                .as_native()
+                .unwrap()
+                .compact_files(options, None)
+                .await;

            deferred.settle_with(&channel, move |mut cx| {
                let stats = stats.or_throw(&mut cx)?;

@@ -349,7 +359,7 @@ impl JsTable {
        let table = js_table.table.clone();

        rt.spawn(async move {
-            let indices = table.load_indices().await;
+            let indices = table.as_native().unwrap().load_indices().await;

            deferred.settle_with(&channel, move |mut cx| {
                let indices = indices.or_throw(&mut cx)?;

@@ -389,8 +399,8 @@ impl JsTable {

        rt.spawn(async move {
            let load_stats = futures::try_join!(
-                table.count_indexed_rows(&index_uuid),
-                table.count_unindexed_rows(&index_uuid)
+                table.as_native().unwrap().count_indexed_rows(&index_uuid),
+                table.as_native().unwrap().count_unindexed_rows(&index_uuid)
            );

            deferred.settle_with(&channel, move |mut cx| {

@@ -26,11 +26,11 @@ lance-index = { workspace = true }
lance-linalg = { workspace = true }
lance-testing = { workspace = true }
tokio = { version = "1.23", features = ["rt-multi-thread"] }
-log = { workspace = true }
+log.workspace = true
async-trait = "0"
bytes = "1"
-futures = "0"
-num-traits = "0"
+futures.workspace = true
+num-traits.workspace = true
url = { workspace = true }
serde = { version = "^1" }
serde_json = { version = "1" }

@@ -27,7 +27,7 @@ use snafu::prelude::*;

use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
use crate::io::object_store::MirroringObjectStoreWrapper;
-use crate::table::{ReadParams, Table};
+use crate::table::{NativeTable, ReadParams, TableRef};

pub const LANCE_FILE_EXTENSION: &str = "lance";

@@ -46,17 +46,20 @@ pub trait Connection: Send + Sync {
    /// * `params` - Optional [`WriteParams`] to create the table.
    ///
    /// # Returns
-    /// Created [`Table`], or [`Err(Error::TableAlreadyExists)`] if the table already exists.
+    /// Created [`TableRef`], or [`Err(Error::TableAlreadyExists)`] if the table already exists.
    async fn create_table(
        &self,
        name: &str,
        batches: Box<dyn RecordBatchReader + Send>,
        params: Option<WriteParams>,
-    ) -> Result<Table>;
+    ) -> Result<TableRef>;

-    async fn open_table(&self, name: &str) -> Result<Table>;
+    async fn open_table(&self, name: &str) -> Result<TableRef> {
+        self.open_table_with_params(name, ReadParams::default())
+            .await
+    }

-    async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<Table>;
+    async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<TableRef>;

    /// Drop a table in the database.
    ///

@@ -240,30 +243,19 @@ impl Connection for Database {
        name: &str,
        batches: Box<dyn RecordBatchReader + Send>,
        params: Option<WriteParams>,
-    ) -> Result<Table> {
+    ) -> Result<TableRef> {
        let table_uri = self.table_uri(name)?;

-        Table::create(
-            &table_uri,
-            name,
-            batches,
-            self.store_wrapper.clone(),
-            params,
-        )
-        .await
-    }
-
-    /// Open a table in the database.
-    ///
-    /// # Arguments
-    /// * `name` - The name of the table.
-    ///
-    /// # Returns
-    ///
-    /// * A [Table] object.
-    async fn open_table(&self, name: &str) -> Result<Table> {
-        self.open_table_with_params(name, ReadParams::default())
-            .await
+        Ok(Arc::new(
+            NativeTable::create(
+                &table_uri,
+                name,
+                batches,
+                self.store_wrapper.clone(),
+                params,
+            )
+            .await?,
+        ))
    }

    /// Open a table in the database.

@@ -274,10 +266,13 @@ impl Connection for Database {
    ///
    /// # Returns
    ///
-    /// * A [Table] object.
-    async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<Table> {
+    /// * A [TableRef] object.
+    async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<TableRef> {
        let table_uri = self.table_uri(name)?;
-        Table::open_with_params(&table_uri, name, self.store_wrapper.clone(), params).await
+        Ok(Arc::new(
+            NativeTable::open_with_params(&table_uri, name, self.store_wrapper.clone(), params)
+                .await?,
+        ))
    }

    async fn drop_table(&self, name: &str) -> Result<()> {

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

+use std::sync::PoisonError;
+
use arrow_schema::ArrowError;
use snafu::Snafu;

@@ -35,6 +37,8 @@ pub enum Error {
    Lance { message: String },
    #[snafu(display("LanceDB Schema Error: {message}"))]
    Schema { message: String },
+    #[snafu(display("Runtime error: {message}"))]
+    Runtime { message: String },
}

pub type Result<T> = std::result::Result<T, Error>;

@@ -70,3 +74,11 @@ impl From<object_store::path::Error> for Error {
        }
    }
}
+
+impl<T> From<PoisonError<T>> for Error {
+    fn from(e: PoisonError<T>) -> Self {
+        Self::Runtime {
+            message: e.to_string(),
+        }
+    }
+}
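A hedged sketch of what the new Runtime variant buys: a poisoned std Mutex lock can now be propagated with `?` instead of unwrapped, since the blanket impl converts any PoisonError into Error::Runtime. (The crate itself mostly uses expect("lock poison") at call sites; the function below is hypothetical.)

use std::sync::Mutex;
use vectordb::Result;

fn locked_len(cells: &Mutex<Vec<u8>>) -> Result<usize> {
    // `lock()` fails with PoisonError if a thread panicked while holding the
    // guard; `?` routes that through From<PoisonError<T>> into Error::Runtime.
    let guard = cells.lock()?;
    Ok(guard.len())
}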
@@ -1,4 +1,4 @@
-// Copyright 2023 Lance Developers.
+// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.

@@ -12,4 +12,263 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{cmp::max, sync::Arc};

use lance_index::{DatasetIndexExt, IndexType};
pub use lance_linalg::distance::MetricType;

pub mod vector;

use crate::{Error, Result, Table};

/// Index Parameters.
pub enum IndexParams {
    Scalar {
        replace: bool,
    },
    IvfPq {
        replace: bool,
        metric_type: MetricType,
        num_partitions: u64,
        num_sub_vectors: u32,
        num_bits: u32,
        sample_rate: u32,
        max_iterations: u32,
    },
}

/// Builder for Index Parameters.
pub struct IndexBuilder {
    table: Arc<dyn Table>,
    columns: Vec<String>,
    // General parameters
    /// Index name.
    name: Option<String>,
    /// Replace the existing index.
    replace: bool,

    index_type: IndexType,

    // Scalar index parameters
    // Nothing to set here.

    // IVF_PQ parameters
    metric_type: MetricType,
    num_partitions: Option<u64>,
    // PQ related
    num_sub_vectors: Option<u32>,
    num_bits: u32,

    /// The rate to find samples to train kmeans.
    sample_rate: u32,
    /// Max iteration to train kmeans.
    max_iterations: u32,
}

impl IndexBuilder {
    pub(crate) fn new(table: Arc<dyn Table>, columns: &[&str]) -> Self {
        IndexBuilder {
            table,
            columns: columns.iter().map(|c| c.to_string()).collect(),
            name: None,
            replace: true,
            index_type: IndexType::Scalar,
            metric_type: MetricType::L2,
            num_partitions: None,
            num_sub_vectors: None,
            num_bits: 8,
            sample_rate: 256,
            max_iterations: 50,
        }
    }

    /// Build a Scalar Index.
    ///
    /// Accepted parameters:
    /// - `replace`: Replace the existing index.
    /// - `name`: Index name. Default: `None`
    pub fn scalar(&mut self) -> &mut Self {
        self.index_type = IndexType::Scalar;
        self
    }

    /// Build an IVF PQ index.
    ///
    /// Accepted parameters:
    /// - `replace`: Replace the existing index.
    /// - `name`: Index name. Default: `None`
    /// - `metric_type`: [MetricType] to use to build Vector Index.
    /// - `num_partitions`: Number of IVF partitions.
    /// - `num_sub_vectors`: Number of sub-vectors of PQ.
    /// - `num_bits`: Number of bits used for PQ centroids.
    /// - `sample_rate`: The rate to find samples to train kmeans.
    /// - `max_iterations`: Max iteration to train kmeans.
    pub fn ivf_pq(&mut self) -> &mut Self {
        self.index_type = IndexType::Vector;
        self
    }

    /// Whether to replace the existing index, default is `true`.
    pub fn replace(&mut self, v: bool) -> &mut Self {
        self.replace = v;
        self
    }

    /// Set the index name.
    pub fn name(&mut self, name: &str) -> &mut Self {
        self.name = Some(name.to_string());
        self
    }

    /// [MetricType] to use to build Vector Index.
    ///
    /// Default value is [MetricType::L2].
    pub fn metric_type(&mut self, metric_type: MetricType) -> &mut Self {
        self.metric_type = metric_type;
        self
    }

    /// Number of IVF partitions.
    pub fn num_partitions(&mut self, num_partitions: u64) -> &mut Self {
        self.num_partitions = Some(num_partitions);
        self
    }

    /// Number of sub-vectors of PQ.
    pub fn num_sub_vectors(&mut self, num_sub_vectors: u32) -> &mut Self {
        self.num_sub_vectors = Some(num_sub_vectors);
        self
    }

    /// Number of bits used for PQ centroids.
    pub fn num_bits(&mut self, num_bits: u32) -> &mut Self {
        self.num_bits = num_bits;
        self
    }

    /// The rate to find samples to train kmeans.
    pub fn sample_rate(&mut self, sample_rate: u32) -> &mut Self {
        self.sample_rate = sample_rate;
        self
    }

    /// Max iteration to train kmeans.
    pub fn max_iterations(&mut self, max_iterations: u32) -> &mut Self {
        self.max_iterations = max_iterations;
        self
    }

    /// Build the parameters.
    pub async fn build(&self) -> Result<()> {
        if self.columns.len() != 1 {
            return Err(Error::Schema {
                message: "Only one column is supported for index".to_string(),
            });
        }
        let column = &self.columns[0];
        let schema = self.table.schema();
        let field = schema.field_with_name(column)?;

        let params = match self.index_type {
            IndexType::Scalar => IndexParams::Scalar {
                replace: self.replace,
            },
            IndexType::Vector => {
                let num_partitions = if let Some(n) = self.num_partitions {
                    n
                } else {
                    suggested_num_partitions(self.table.count_rows().await?)
                };
                let num_sub_vectors: u32 = if let Some(n) = self.num_sub_vectors {
                    n
                } else {
                    match field.data_type() {
                        arrow_schema::DataType::FixedSizeList(_, n) => {
                            Ok::<u32, Error>(suggested_num_sub_vectors(*n as u32))
                        }
                        _ => Err(Error::Schema {
                            message: format!(
                                "Column '{}' is not a FixedSizeList",
                                &self.columns[0]
                            ),
                        }),
                    }?
                };
                IndexParams::IvfPq {
                    replace: self.replace,
                    metric_type: self.metric_type,
                    num_partitions,
                    num_sub_vectors,
                    num_bits: self.num_bits,
                    sample_rate: self.sample_rate,
                    max_iterations: self.max_iterations,
                }
            }
        };

        let tbl = self
            .table
            .as_native()
            .expect("Only native table is supported here");
        let mut dataset = tbl.clone_inner_dataset();
        match params {
            IndexParams::Scalar { replace } => {
                self.table
                    .as_native()
                    .unwrap()
                    .create_scalar_index(column, replace)
                    .await?
            }
            IndexParams::IvfPq {
                replace,
                metric_type,
                num_partitions,
                num_sub_vectors,
                num_bits,
                max_iterations,
                ..
            } => {
                let lance_idx_params = lance::index::vector::VectorIndexParams::ivf_pq(
                    num_partitions as usize,
                    num_bits as u8,
                    num_sub_vectors as usize,
                    false,
                    metric_type,
                    max_iterations as usize,
                );
                dataset
                    .create_index(
                        &[column],
                        IndexType::Vector,
                        None,
                        &lance_idx_params,
                        replace,
                    )
                    .await?;
            }
        }
        tbl.reset_dataset(dataset);
        Ok(())
    }
}

fn suggested_num_partitions(rows: usize) -> u64 {
    let num_partitions = (rows as f64).sqrt() as u64;
    max(1, num_partitions)
}

fn suggested_num_sub_vectors(dim: u32) -> u32 {
    if dim % 16 == 0 {
        // Should be more aggressive than this default.
        dim / 16
    } else if dim % 8 == 0 {
        dim / 8
    } else {
        log::warn!(
            "The dimension of the vector is not divisible by 8 or 16, \
             which may cause performance degradation in PQ"
        );
        1
    }
}
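To make the two fallbacks above concrete, here is what they pick for a few representative shapes (a hypothetical test; both functions are private to this module, so it would only run alongside it):

#[cfg(test)]
mod heuristic_tests {
    use super::*;

    #[test]
    fn suggested_defaults() {
        assert_eq!(suggested_num_partitions(1_000_000), 1_000); // sqrt(rows)
        assert_eq!(suggested_num_partitions(0), 1);             // floor of 1
        assert_eq!(suggested_num_sub_vectors(128), 8);          // dim / 16
        assert_eq!(suggested_num_sub_vectors(120), 15);         // dim / 8 fallback
        assert_eq!(suggested_num_sub_vectors(100), 1);          // neither divides
    }
}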
@@ -14,104 +14,7 @@

use serde::Deserialize;

use lance::index::vector::pq::PQBuildParams;
use lance::index::vector::VectorIndexParams;
use lance::table::format::{Index, Manifest};
use lance_index::vector::ivf::IvfBuildParams;
use lance_linalg::distance::MetricType;

pub trait VectorIndexBuilder {
    fn get_column(&self) -> Option<String>;
    fn get_index_name(&self) -> Option<String>;
    fn build(&self) -> VectorIndexParams;

    fn get_replace(&self) -> bool;
}

pub struct IvfPQIndexBuilder {
    column: Option<String>,
    index_name: Option<String>,
    metric_type: Option<MetricType>,
    ivf_params: Option<IvfBuildParams>,
    pq_params: Option<PQBuildParams>,
    replace: bool,
}

impl IvfPQIndexBuilder {
    pub fn new() -> IvfPQIndexBuilder {
        Default::default()
    }
}

impl Default for IvfPQIndexBuilder {
    fn default() -> Self {
        IvfPQIndexBuilder {
            column: None,
            index_name: None,
            metric_type: None,
            ivf_params: None,
            pq_params: None,
            replace: true,
        }
    }
}

impl IvfPQIndexBuilder {
    pub fn column(&mut self, column: String) -> &mut IvfPQIndexBuilder {
        self.column = Some(column);
        self
    }

    pub fn index_name(&mut self, index_name: String) -> &mut IvfPQIndexBuilder {
        self.index_name = Some(index_name);
        self
    }

    pub fn metric_type(&mut self, metric_type: MetricType) -> &mut IvfPQIndexBuilder {
        self.metric_type = Some(metric_type);
        self
    }

    pub fn ivf_params(&mut self, ivf_params: IvfBuildParams) -> &mut IvfPQIndexBuilder {
        self.ivf_params = Some(ivf_params);
        self
    }

    pub fn pq_params(&mut self, pq_params: PQBuildParams) -> &mut IvfPQIndexBuilder {
        self.pq_params = Some(pq_params);
        self
    }

    pub fn replace(&mut self, replace: bool) -> &mut IvfPQIndexBuilder {
        self.replace = replace;
        self
    }
}

impl VectorIndexBuilder for IvfPQIndexBuilder {
    fn get_column(&self) -> Option<String> {
        self.column.clone()
    }

    fn get_index_name(&self) -> Option<String> {
        self.index_name.clone()
    }

    fn build(&self) -> VectorIndexParams {
        let ivf_params = self.ivf_params.clone().unwrap_or_default();
        let pq_params = self.pq_params.clone().unwrap_or_default();

        VectorIndexParams::with_ivf_pq_params(
            self.metric_type.unwrap_or(MetricType::L2),
            ivf_params,
            pq_params,
        )
    }

    fn get_replace(&self) -> bool {
        self.replace
    }
}

pub struct VectorIndex {
    pub columns: Vec<String>,

@@ -139,79 +42,3 @@ pub struct VectorIndexStatistics {
    pub num_indexed_rows: usize,
    pub num_unindexed_rows: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    use lance::index::vector::StageParams;
    use lance_index::vector::ivf::IvfBuildParams;
    use lance_index::vector::pq::PQBuildParams;

    use crate::index::vector::{IvfPQIndexBuilder, VectorIndexBuilder};

    #[test]
    fn test_builder_no_params() {
        let index_builder = IvfPQIndexBuilder::new();
        assert!(index_builder.get_column().is_none());
        assert!(index_builder.get_index_name().is_none());

        let index_params = index_builder.build();
        assert_eq!(index_params.stages.len(), 2);
        if let StageParams::Ivf(ivf_params) = index_params.stages.get(0).unwrap() {
            let default = IvfBuildParams::default();
            assert_eq!(ivf_params.num_partitions, default.num_partitions);
            assert_eq!(ivf_params.max_iters, default.max_iters);
        } else {
            panic!("Expected first stage to be ivf")
        }

        if let StageParams::PQ(pq_params) = index_params.stages.get(1).unwrap() {
            assert_eq!(pq_params.use_opq, false);
        } else {
            panic!("Expected second stage to be pq")
        }
    }

    #[test]
    fn test_builder_all_params() {
        let mut index_builder = IvfPQIndexBuilder::new();

        index_builder
            .column("c".to_owned())
            .metric_type(MetricType::Cosine)
            .index_name("index".to_owned());

        assert_eq!(index_builder.column.clone().unwrap(), "c");
        assert_eq!(index_builder.metric_type.unwrap(), MetricType::Cosine);
        assert_eq!(index_builder.index_name.clone().unwrap(), "index");

        let ivf_params = IvfBuildParams::new(500);
        let mut pq_params = PQBuildParams::default();
        pq_params.use_opq = true;
        pq_params.max_iters = 1;
        pq_params.num_bits = 8;
        pq_params.num_sub_vectors = 50;
        pq_params.max_opq_iters = 2;
        index_builder.ivf_params(ivf_params);
        index_builder.pq_params(pq_params);

        let index_params = index_builder.build();
        assert_eq!(index_params.stages.len(), 2);
        if let StageParams::Ivf(ivf_params) = index_params.stages.get(0).unwrap() {
            assert_eq!(ivf_params.num_partitions, 500);
        } else {
            assert!(false, "Expected first stage to be ivf")
        }

        if let StageParams::PQ(pq_params) = index_params.stages.get(1).unwrap() {
            assert_eq!(pq_params.use_opq, true);
            assert_eq!(pq_params.max_iters, 1);
            assert_eq!(pq_params.num_bits, 8);
            assert_eq!(pq_params.num_sub_vectors, 50);
            assert_eq!(pq_params.max_opq_iters, 2);
        } else {
            assert!(false, "Expected second stage to be pq")
        }
    }
}

@@ -335,14 +335,15 @@ impl WrappingObjectStore for MirroringObjectStoreWrapper {
#[cfg(all(test, not(windows)))]
mod test {
    use super::*;
-    use crate::connection::{Connection, Database};
    use arrow_array::PrimitiveArray;

    use futures::TryStreamExt;
    use lance::{dataset::WriteParams, io::ObjectStoreParams};
    use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
    use object_store::local::LocalFileSystem;
    use tempfile;

+    use crate::connection::{Connection, Database};
+
    #[tokio::test]
    async fn test_e2e() {
        let dir1 = tempfile::tempdir().unwrap().into_path();

@@ -374,9 +375,7 @@ mod test {
        assert_eq!(t.count_rows().await.unwrap(), 100);

        let q = t
-            .search(Some(PrimitiveArray::from_iter_values(vec![
-                0.1, 0.1, 0.1, 0.1,
-            ])))
+            .search(&[0.1, 0.1, 0.1, 0.1])
            .limit(10)
            .execute()
            .await

@@ -46,7 +46,7 @@
//! #### Connect to a database.
//!
//! ```rust
-//! use vectordb::{connection::{Database, Connection}, Table, WriteMode};
+//! use vectordb::{connection::{Database, Connection}, WriteMode};
//! use arrow_schema::{Field, Schema};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! let db = Database::connect("data/sample-lancedb").await.unwrap();

@@ -119,7 +119,7 @@
//! # db.create_table("my_table", Box::new(batches), None).await.unwrap();
//! let table = db.open_table("my_table").await.unwrap();
//! let results = table
-//!     .search(Some(vec![1.0; 128]))
+//!     .search(&[1.0; 128])
//!     .execute()
//!     .await
//!     .unwrap()

@@ -143,6 +143,6 @@ pub mod utils;

pub use connection::{Connection, Database};
pub use error::{Error, Result};
-pub use table::Table;
+pub use table::{Table, TableRef};

pub use lance::dataset::WriteMode;
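Because TableRef is just Arc<dyn Table>, handles are cheap to clone and can be moved into spawned tasks — exactly what the Node bindings above do with js_table.table.clone(). A small sketch under those assumptions:

use vectordb::TableRef;

async fn row_count(table: TableRef) -> usize {
    let t = table.clone(); // bumps the Arc refcount; no table data is copied
    tokio::spawn(async move { t.count_rows().await.unwrap() })
        .await
        .unwrap()
}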
@@ -1,4 +1,4 @@
-// Copyright 2023 Lance Developers.
+// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.

@@ -48,10 +48,10 @@ impl Query {
    /// # Returns
    ///
    /// * A [Query] object.
-    pub(crate) fn new(dataset: Arc<Dataset>, vector: Option<Float32Array>) -> Self {
+    pub(crate) fn new(dataset: Arc<Dataset>) -> Self {
        Query {
            dataset,
-            query_vector: vector,
+            query_vector: None,
            column: crate::table::VECTOR_COLUMN_NAME.to_string(),
            limit: None,
            nprobes: 20,

@@ -206,7 +206,7 @@ mod tests {
        let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();

        let vector = Some(Float32Array::from_iter_values([0.1, 0.2]));
-        let query = Query::new(Arc::new(ds), vector.clone());
+        let query = Query::new(Arc::new(ds)).query_vector(&[0.1, 0.2]);
        assert_eq!(query.query_vector, vector);

        let new_vector = Float32Array::from_iter_values([9.8, 8.7]);

@@ -232,9 +232,7 @@ mod tests {
        let batches = make_non_empty_batches();
        let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap());

-        let vector = Some(Float32Array::from_iter_values([0.1; 4]));
-
-        let query = Query::new(ds.clone(), vector.clone());
+        let query = Query::new(ds.clone()).query_vector(&[0.1; 4]);
        let result = query
            .limit(10)
            .filter(Some("id % 2 == 0".to_string()))

@@ -247,7 +245,7 @@ mod tests {
            assert!(batch.expect("should be Ok").num_rows() < 10);
        }

-        let query = Query::new(ds, vector.clone());
+        let query = Query::new(ds).query_vector(&[0.1; 4]);
        let result = query
            .limit(10)
            .filter(Some("id % 2 == 0".to_string()))

@@ -268,7 +266,7 @@ mod tests {
        let batches = make_non_empty_batches();
        let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap());

-        let query = Query::new(ds.clone(), None);
+        let query = Query::new(ds.clone());
        let result = query
            .filter(Some("id % 2 == 0".to_string()))
            .execute()

@@ -1,4 +1,4 @@
-// Copyright 2023 LanceDB Developers.
+// Copyright 2024 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.

@@ -12,72 +12,170 @@
// See the License for the specific language governing permissions and
// limitations under the License.

//! LanceDB Table APIs

use std::path::Path;
use std::sync::{Arc, Mutex};

use arrow_array::RecordBatchReader;
use arrow_schema::{Schema, SchemaRef};
use chrono::Duration;
use lance::dataset::builder::DatasetBuilder;
use lance::index::scalar::ScalarIndexParams;
use lance_index::optimize::OptimizeOptions;
use lance_index::IndexType;
use std::sync::Arc;

use arrow_array::{Float32Array, RecordBatchReader};
use arrow_schema::SchemaRef;
use lance::dataset::cleanup::RemovalStats;
use lance::dataset::optimize::{
    compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
};
pub use lance::dataset::ReadParams;
use lance::dataset::{Dataset, UpdateBuilder, WriteParams};
use lance::index::scalar::ScalarIndexParams;
use lance::io::WrappingObjectStore;
use lance_index::DatasetIndexExt;
use std::path::Path;
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt, IndexType};

use crate::error::{Error, Result};
-use crate::index::vector::{VectorIndex, VectorIndexBuilder, VectorIndexStatistics};
+use crate::index::vector::{VectorIndex, VectorIndexStatistics};
+use crate::index::IndexBuilder;
use crate::query::Query;
use crate::utils::{PatchReadParam, PatchWriteParam};
use crate::WriteMode;

pub use lance::dataset::ReadParams;

pub const VECTOR_COLUMN_NAME: &str = "vector";

/// A Table is a collection of strongly typed Rows.
///
/// The type of each row is defined in an Apache Arrow [Schema].
#[async_trait::async_trait]
pub trait Table: std::fmt::Display + Send + Sync {
    fn as_any(&self) -> &dyn std::any::Any;

    /// Cast as [`NativeTable`], or return None if it is not a [`NativeTable`].
    fn as_native(&self) -> Option<&NativeTable>;

    /// Get the name of the table.
    fn name(&self) -> &str;

    /// Get the arrow [Schema] of the table.
    fn schema(&self) -> SchemaRef;

    /// Count the number of rows in this dataset.
    async fn count_rows(&self) -> Result<usize>;

    /// Insert new records into this Table
    ///
    /// # Arguments
    ///
    /// * `batches` RecordBatch to be saved in the Table
    /// * `params` Append / Overwrite existing records. Default: Append
    async fn add(
        &self,
        batches: Box<dyn RecordBatchReader + Send>,
        params: Option<WriteParams>,
    ) -> Result<()>;

    /// Delete the rows from table that match the predicate.
    ///
    /// # Arguments
    /// - `predicate` - The SQL predicate string to filter the rows to be deleted.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use std::sync::Arc;
    /// # use vectordb::connection::{Database, Connection};
    /// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
    /// #   RecordBatchIterator, Int32Array};
    /// # use arrow_schema::{Schema, Field, DataType};
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
    /// let tmpdir = tempfile::tempdir().unwrap();
    /// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
    /// # let schema = Arc::new(Schema::new(vec![
    /// #   Field::new("id", DataType::Int32, false),
    /// #   Field::new("vector", DataType::FixedSizeList(
    /// #     Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
    /// # ]));
    /// let batches = RecordBatchIterator::new(vec![
    ///     RecordBatch::try_new(schema.clone(),
    ///         vec![
    ///             Arc::new(Int32Array::from_iter_values(0..10)),
    ///             Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
    ///                 (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)),
    ///         ]).unwrap()
    ///     ].into_iter().map(Ok),
    ///     schema.clone());
    /// let tbl = db.create_table("delete_test", Box::new(batches), None).await.unwrap();
    /// tbl.delete("id > 5").await.unwrap();
    /// # });
    /// ```
    async fn delete(&self, predicate: &str) -> Result<()>;

    /// Create an index on the column name.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// # use std::sync::Arc;
    /// # use vectordb::connection::{Database, Connection};
    /// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
    /// #   RecordBatchIterator, Int32Array};
    /// # use arrow_schema::{Schema, Field, DataType};
    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
    /// let tmpdir = tempfile::tempdir().unwrap();
    /// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
    /// # let tbl = db.open_table("delete_test").await.unwrap();
    /// tbl.create_index(&["vector"])
    ///     .ivf_pq()
    ///     .num_partitions(256)
    ///     .build()
    ///     .await
    ///     .unwrap();
    /// # });
    /// ```
    fn create_index(&self, column: &[&str]) -> IndexBuilder;

    /// Search the table with a given query vector.
    fn search(&self, query: &[f32]) -> Query {
        self.query().query_vector(query)
    }

    /// Create a Query builder.
    fn query(&self) -> Query;
}

/// Reference to a Table pointer.
pub type TableRef = Arc<dyn Table>;

/// A table in a LanceDB database.
#[derive(Debug, Clone)]
-pub struct Table {
+pub struct NativeTable {
    name: String,
    uri: String,
-    dataset: Arc<Dataset>,
+    dataset: Arc<Mutex<Dataset>>,

    // the object store wrapper to use on write path
    store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
}

-impl std::fmt::Display for Table {
+impl std::fmt::Display for NativeTable {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Table({})", self.name)
    }
}

-impl Table {
+impl NativeTable {
    /// Opens an existing Table
    ///
    /// # Arguments
    ///
-    /// * `uri` - The uri to a [Table]
+    /// * `uri` - The uri to a [NativeTable]
    /// * `name` - The table name
    ///
    /// # Returns
    ///
-    /// * A [Table] object.
+    /// * A [NativeTable] object.
    pub async fn open(uri: &str) -> Result<Self> {
        let name = Self::get_table_name(uri)?;
        Self::open_with_params(uri, &name, None, ReadParams::default()).await
    }

    /// Open a Table with a given name.
    pub async fn open_with_name(uri: &str, name: &str) -> Result<Self> {
        Self::open_with_params(uri, name, None, ReadParams::default()).await
    }

    /// Opens an existing Table
    ///
    /// # Arguments

@@ -88,7 +186,7 @@ impl Table {
    ///
    /// # Returns
    ///
-    /// * A [Table] object.
+    /// * A [NativeTable] object.
    pub async fn open_with_params(
        uri: &str,
        name: &str,

@@ -113,25 +211,26 @@ impl Table {
                message: e.to_string(),
            },
        })?;
-        Ok(Table {
+        Ok(NativeTable {
            name: name.to_string(),
            uri: uri.to_string(),
-            dataset: Arc::new(dataset),
+            dataset: Arc::new(Mutex::new(dataset)),
            store_wrapper: write_store_wrapper,
        })
    }

-    /// Checkout a specific version of this [`Table`]
+    /// Make a new clone of the internal lance dataset.
+    pub(crate) fn clone_inner_dataset(&self) -> Dataset {
+        self.dataset.lock().expect("Lock poison").clone()
+    }
+
+    /// Checkout a specific version of this [NativeTable]
    ///
    pub async fn checkout(uri: &str, version: u64) -> Result<Self> {
        let name = Self::get_table_name(uri)?;
        Self::checkout_with_params(uri, &name, version, None, ReadParams::default()).await
    }

    pub async fn checkout_with_name(uri: &str, name: &str, version: u64) -> Result<Self> {
        Self::checkout_with_params(uri, name, version, None, ReadParams::default()).await
    }

    pub async fn checkout_with_params(
        uri: &str,
        name: &str,

@@ -154,26 +253,27 @@ impl Table {
                message: e.to_string(),
            },
        })?;
-        Ok(Table {
+        Ok(NativeTable {
            name: name.to_string(),
            uri: uri.to_string(),
-            dataset: Arc::new(dataset),
+            dataset: Arc::new(Mutex::new(dataset)),
            store_wrapper: write_store_wrapper,
        })
    }

    pub async fn checkout_latest(&self) -> Result<Self> {
-        let latest_version_id = self.dataset.latest_version_id().await?;
-        let dataset = if latest_version_id == self.dataset.version().version {
-            self.dataset.clone()
+        let dataset = self.clone_inner_dataset();
+        let latest_version_id = dataset.latest_version_id().await?;
+        let dataset = if latest_version_id == dataset.version().version {
+            dataset
        } else {
-            Arc::new(self.dataset.checkout_version(latest_version_id).await?)
+            dataset.checkout_version(latest_version_id).await?
        };

-        Ok(Table {
+        Ok(Self {
            name: self.name.clone(),
            uri: self.uri.clone(),
-            dataset,
+            dataset: Arc::new(Mutex::new(dataset)),
            store_wrapper: self.store_wrapper.clone(),
        })
    }

@@ -203,7 +303,7 @@ impl Table {
    ///
    /// # Returns
    ///
-    /// * A [Table] object.
+    /// * A [TableImpl] object.
    pub(crate) async fn create(
        uri: &str,
        name: &str,

@@ -227,46 +327,22 @@ impl Table {
                message: e.to_string(),
            },
        })?;
-        Ok(Table {
+        Ok(NativeTable {
            name: name.to_string(),
            uri: uri.to_string(),
-            dataset: Arc::new(dataset),
+            dataset: Arc::new(Mutex::new(dataset)),
            store_wrapper: write_store_wrapper,
        })
    }

-    /// Schema of this Table.
-    pub fn schema(&self) -> SchemaRef {
-        Arc::new(self.dataset.schema().into())
-    }
-
    /// Version of this Table
    pub fn version(&self) -> u64 {
-        self.dataset.version().version
-    }
-
-    /// Create index on the table.
-    pub async fn create_index(&mut self, index_builder: &impl VectorIndexBuilder) -> Result<()> {
-        let mut dataset = self.dataset.as_ref().clone();
-        dataset
-            .create_index(
-                &[index_builder
-                    .get_column()
-                    .unwrap_or(VECTOR_COLUMN_NAME.to_string())
-                    .as_str()],
-                IndexType::Vector,
-                index_builder.get_index_name(),
-                &index_builder.build(),
-                index_builder.get_replace(),
-            )
-            .await?;
-        self.dataset = Arc::new(dataset);
-        Ok(())
+        self.dataset.lock().expect("lock poison").version().version
    }

    /// Create a scalar index on the table
-    pub async fn create_scalar_index(&mut self, column: &str, replace: bool) -> Result<()> {
-        let mut dataset = self.dataset.as_ref().clone();
+    pub async fn create_scalar_index(&self, column: &str, replace: bool) -> Result<()> {
+        let mut dataset = self.clone_inner_dataset();
        let params = ScalarIndexParams::default();
        dataset
            .create_index(&[column], IndexType::Scalar, None, &params, replace)

@@ -275,65 +351,21 @@ impl Table {
    }

    pub async fn optimize_indices(&mut self, options: &OptimizeOptions) -> Result<()> {
-        let mut dataset = self.dataset.as_ref().clone();
+        let mut dataset = self.clone_inner_dataset();
        dataset.optimize_indices(options).await?;

        Ok(())
    }

-    /// Insert records into this Table
-    ///
-    /// # Arguments
-    ///
-    /// * `batches` RecordBatch to be saved in the Table
-    /// * `write_mode` Append / Overwrite existing records. Default: Append
-    /// # Returns
-    ///
-    /// * The number of rows added
-    pub async fn add(
-        &mut self,
-        batches: impl RecordBatchReader + Send + 'static,
-        params: Option<WriteParams>,
-    ) -> Result<()> {
-        let params = Some(params.unwrap_or(WriteParams {
-            mode: WriteMode::Append,
-            ..WriteParams::default()
-        }));
-
-        // patch the params if we have a write store wrapper
-        let params = match self.store_wrapper.clone() {
-            Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
-            None => params,
-        };
-
-        self.dataset = Arc::new(Dataset::write(batches, &self.uri, params).await?);
-        Ok(())
-    }
-
-    pub fn query(&self) -> Query {
-        Query::new(self.dataset.clone(), None)
-    }
-
-    /// Creates a new Query object that can be executed.
-    ///
-    /// # Arguments
-    ///
-    /// * `query_vector` The vector used for this query.
-    ///
-    /// # Returns
-    /// * A [Query] object.
-    pub fn search<T: Into<Float32Array>>(&self, query_vector: Option<T>) -> Query {
-        Query::new(self.dataset.clone(), query_vector.map(|q| q.into()))
+        Query::new(self.clone_inner_dataset().into())
    }

    pub fn filter(&self, expr: String) -> Query {
-        Query::new(self.dataset.clone(), None).filter(Some(expr))
+        Query::new(self.clone_inner_dataset().into()).filter(Some(expr))
    }

-    /// Returns the number of rows in this Table
-    pub async fn count_rows(&self) -> Result<usize> {
-        Ok(self.dataset.count_rows().await?)
-    }
-
    /// Merge new data into this table.
    pub async fn merge(

@@ -342,26 +374,14 @@ impl Table {
        left_on: &str,
        right_on: &str,
    ) -> Result<()> {
-        let mut dataset = self.dataset.as_ref().clone();
+        let mut dataset = self.clone_inner_dataset();
        dataset.merge(batches, left_on, right_on).await?;
-        self.dataset = Arc::new(dataset);
+        self.dataset = Arc::new(Mutex::new(dataset));
        Ok(())
    }

-    /// Delete rows from the table
-    pub async fn delete(&mut self, predicate: &str) -> Result<()> {
-        let mut dataset = self.dataset.as_ref().clone();
-        dataset.delete(predicate).await?;
-        self.dataset = Arc::new(dataset);
-        Ok(())
-    }
-
-    pub async fn update(
-        &mut self,
-        predicate: Option<&str>,
-        updates: Vec<(&str, &str)>,
-    ) -> Result<()> {
-        let mut builder = UpdateBuilder::new(self.dataset.clone());
+    pub async fn update(&self, predicate: Option<&str>, updates: Vec<(&str, &str)>) -> Result<()> {
+        let mut builder = UpdateBuilder::new(self.clone_inner_dataset().into());
        if let Some(predicate) = predicate {
            builder = builder.update_where(predicate)?;
        }

@@ -371,9 +391,8 @@ impl Table {
        }

        let operation = builder.build()?;
-        let new_ds = operation.execute().await?;
-        self.dataset = new_ds;
-
+        let ds = operation.execute().await?;
+        self.reset_dataset(ds.as_ref().clone());
        Ok(())
    }

@@ -393,8 +412,8 @@ impl Table {
        older_than: Duration,
        delete_unverified: Option<bool>,
    ) -> Result<RemovalStats> {
-        Ok(self
-            .dataset
+        let dataset = self.clone_inner_dataset();
+        Ok(dataset
            .cleanup_old_versions(older_than, delete_unverified)
            .await?)
    }

@@ -406,26 +425,28 @@ impl Table {
    ///
    /// This calls into [lance::dataset::optimize::compact_files].
    pub async fn compact_files(
-        &mut self,
+        &self,
        options: CompactionOptions,
        remap_options: Option<Arc<dyn IndexRemapperOptions>>,
    ) -> Result<CompactionMetrics> {
-        let mut dataset = self.dataset.as_ref().clone();
+        let mut dataset = self.clone_inner_dataset();
        let metrics = compact_files(&mut dataset, options, remap_options).await?;
-        self.dataset = Arc::new(dataset);
+        self.reset_dataset(dataset);
        Ok(metrics)
    }

    pub fn count_fragments(&self) -> usize {
-        self.dataset.count_fragments()
+        self.dataset.lock().expect("lock poison").count_fragments()
    }

    pub async fn count_deleted_rows(&self) -> Result<usize> {
-        Ok(self.dataset.count_deleted_rows().await?)
+        let dataset = self.clone_inner_dataset();
+        Ok(dataset.count_deleted_rows().await?)
    }

    pub async fn num_small_files(&self, max_rows_per_group: usize) -> usize {
-        self.dataset.num_small_files(max_rows_per_group).await
+        let dataset = self.clone_inner_dataset();
+        dataset.num_small_files(max_rows_per_group).await
    }

    pub async fn count_indexed_rows(&self, index_uuid: &str) -> Result<Option<usize>> {

@@ -443,8 +464,8 @@ impl Table {
    }

    pub async fn load_indices(&self) -> Result<Vec<VectorIndex>> {
-        let (indices, mf) =
-            futures::try_join!(self.dataset.load_indices(), self.dataset.latest_manifest())?;
+        let dataset = self.clone_inner_dataset();
+        let (indices, mf) = futures::try_join!(dataset.load_indices(), dataset.latest_manifest())?;
        Ok(indices
            .iter()
            .map(|i| VectorIndex::new_from_format(&mf, i))

@@ -460,10 +481,8 @@ impl Table {
        if index.is_none() {
            return Ok(None);
        }
-        let index_stats = self
-            .dataset
-            .index_statistics(&index.unwrap().index_name)
-            .await?;
+        let dataset = self.clone_inner_dataset();
+        let index_stats = dataset.index_statistics(&index.unwrap().index_name).await?;
        let index_stats: VectorIndexStatistics =
            serde_json::from_str(&index_stats).map_err(|e| Error::Lance {
                message: format!(

@@ -474,6 +493,71 @@ impl Table {

        Ok(Some(index_stats))
    }
+
+    pub(crate) fn reset_dataset(&self, dataset: Dataset) {
+        *self.dataset.lock().expect("lock poison") = dataset;
+    }
}

+#[async_trait::async_trait]
+impl Table for NativeTable {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn as_native(&self) -> Option<&NativeTable> {
+        Some(self)
+    }
+
+    fn name(&self) -> &str {
+        self.name.as_str()
+    }
+
+    fn schema(&self) -> SchemaRef {
+        let lance_schema = { self.dataset.lock().expect("lock poison").schema().clone() };
+        Arc::new(Schema::from(&lance_schema))
+    }
+
+    async fn count_rows(&self) -> Result<usize> {
+        let dataset = { self.dataset.lock().expect("lock poison").clone() };
+        Ok(dataset.count_rows().await?)
+    }
+
+    async fn add(
+        &self,
+        batches: Box<dyn RecordBatchReader + Send>,
+        params: Option<WriteParams>,
+    ) -> Result<()> {
+        let params = Some(params.unwrap_or(WriteParams {
+            mode: WriteMode::Append,
+            ..WriteParams::default()
+        }));
+
+        // patch the params if we have a write store wrapper
+        let params = match self.store_wrapper.clone() {
+            Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
+            None => params,
+        };
+
+        self.reset_dataset(Dataset::write(batches, &self.uri, params).await?);
+        Ok(())
+    }
+
+    fn create_index(&self, columns: &[&str]) -> IndexBuilder {
+        IndexBuilder::new(Arc::new(self.clone()), columns)
+    }
+
+    fn query(&self) -> Query {
+        Query::new(Arc::new(self.dataset.lock().expect("lock poison").clone()))
+    }
+
+    /// Delete rows from the table
+    async fn delete(&self, predicate: &str) -> Result<()> {
+        let mut dataset = self.clone_inner_dataset();
+        dataset.delete(predicate).await?;
+        self.reset_dataset(dataset);
+        Ok(())
+    }
+}
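Every mutating method above follows the same clone-modify-swap discipline around Arc<Mutex<Dataset>>: clone the inner dataset under the lock, drop the guard before any await, mutate the private copy, then reset_dataset publishes it. The pattern in isolation, with a toy struct standing in for lance's Dataset (all names here are illustrative):

use std::sync::{Arc, Mutex};

#[derive(Clone)]
struct Inner {
    version: u64,
}

struct Handle {
    inner: Arc<Mutex<Inner>>,
}

impl Handle {
    async fn bump(&self) {
        // 1. Clone under the lock; the guard is dropped before we await anything.
        let mut copy = self.inner.lock().expect("lock poison").clone();
        // 2. Mutate (and potentially await) without holding the lock.
        copy.version += 1;
        // 3. Swap the new state back in, analogous to reset_dataset().
        *self.inner.lock().expect("lock poison") = copy;
    }
}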

#[cfg(test)]
@@ -491,14 +575,11 @@ mod tests {
    use arrow_schema::{DataType, Field, Schema, TimeUnit};
    use futures::TryStreamExt;
    use lance::dataset::{Dataset, WriteMode};
-    use lance::index::vector::pq::PQBuildParams;
    use lance::io::{ObjectStoreParams, WrappingObjectStore};
-    use lance_index::vector::ivf::IvfBuildParams;
    use rand::Rng;
    use tempfile::tempdir;

    use super::*;
-    use crate::index::vector::IvfPQIndexBuilder;

    #[tokio::test]
    async fn test_open() {

@@ -510,7 +591,9 @@ mod tests {
        .await
        .unwrap();

-        let table = Table::open(dataset_path.to_str().unwrap()).await.unwrap();
+        let table = NativeTable::open(dataset_path.to_str().unwrap())
+            .await
+            .unwrap();

        assert_eq!(table.name, "test")
    }

@@ -519,7 +602,7 @@ mod tests {
    async fn test_open_not_found() {
        let tmp_dir = tempdir().unwrap();
        let uri = tmp_dir.path().to_str().unwrap();
-        let table = Table::open(uri).await;
+        let table = NativeTable::open(uri).await;
        assert!(matches!(table.unwrap_err(), Error::TableNotFound { .. }));
    }

@@ -539,12 +622,12 @@ mod tests {

        let batches = make_test_batches();
        let _ = batches.schema().clone();
-        Table::create(&uri, "test", batches, None, None)
+        NativeTable::create(&uri, "test", batches, None, None)
            .await
            .unwrap();

        let batches = make_test_batches();
-        let result = Table::create(&uri, "test", batches, None, None).await;
+        let result = NativeTable::create(&uri, "test", batches, None, None).await;
        assert!(matches!(
            result.unwrap_err(),
            Error::TableAlreadyExists { .. }

@@ -558,7 +641,7 @@ mod tests {

        let batches = make_test_batches();
        let schema = batches.schema().clone();
-        let mut table = Table::create(&uri, "test", batches, None, None)
+        let table = NativeTable::create(&uri, "test", batches, None, None)
            .await
            .unwrap();
        assert_eq!(table.count_rows().await.unwrap(), 10);

@@ -574,7 +657,7 @@ mod tests {
            schema.clone(),
        );

-        table.add(new_batches, None).await.unwrap();
+        table.add(Box::new(new_batches), None).await.unwrap();
        assert_eq!(table.count_rows().await.unwrap(), 20);
        assert_eq!(table.name, "test");
    }

@@ -586,7 +669,7 @@ mod tests {

        let batches = make_test_batches();
        let schema = batches.schema().clone();
-        let mut table = Table::create(uri, "test", batches, None, None)
+        let table = NativeTable::create(uri, "test", batches, None, None)
            .await
            .unwrap();
        assert_eq!(table.count_rows().await.unwrap(), 10);

@@ -607,7 +690,7 @@ mod tests {
            ..Default::default()
        };

-        table.add(new_batches, Some(param)).await.unwrap();
+        table.add(Box::new(new_batches), Some(param)).await.unwrap();
        assert_eq!(table.count_rows().await.unwrap(), 10);
        assert_eq!(table.name, "test");
    }

@@ -640,7 +723,7 @@ mod tests {
        );

        Dataset::write(record_batch_iter, uri, None).await.unwrap();
-        let mut table = Table::open(uri).await.unwrap();
+        let table = NativeTable::open(uri).await.unwrap();

        table
            .update(Some("id > 5"), vec![("name", "'foo'")])

@@ -772,7 +855,7 @@ mod tests {
        );

        Dataset::write(record_batch_iter, uri, None).await.unwrap();
-        let mut table = Table::open(uri).await.unwrap();
+        let table = NativeTable::open(uri).await.unwrap();

        // check it can do update for each type
        let updates: Vec<(&str, &str)> = vec![

@@ -889,11 +972,10 @@ mod tests {
        .await
        .unwrap();

-        let table = Table::open(uri).await.unwrap();
+        let table = NativeTable::open(uri).await.unwrap();

-        let vector = Float32Array::from_iter_values([0.1, 0.2]);
-        let query = table.search(Some(vector.clone()));
-        assert_eq!(vector, query.query_vector.unwrap());
+        let query = table.search(&[0.1, 0.2]);
+        assert_eq!(&[0.1, 0.2], query.query_vector.unwrap().values());
    }

    #[derive(Default, Debug)]

@@ -937,7 +1019,7 @@ mod tests {
            ..Default::default()
        };
        assert!(!wrapper.called());
-        let _ = Table::open_with_params(uri, "test", None, param)
+        let _ = NativeTable::open_with_params(uri, "test", None, param)
            .await
            .unwrap();
        assert!(wrapper.called());

@@ -991,23 +1073,23 @@ mod tests {
            schema,
        );

-        let mut table = Table::create(uri, "test", batches, None, None)
+        let table = NativeTable::create(uri, "test", batches, None, None)
            .await
            .unwrap();
-        let mut i = IvfPQIndexBuilder::new();

        assert_eq!(table.count_indexed_rows("my_index").await.unwrap(), None);
        assert_eq!(table.count_unindexed_rows("my_index").await.unwrap(), None);

-        let index_builder = i
-            .column("embeddings".to_string())
-            .index_name("my_index".to_string())
-            .ivf_params(IvfBuildParams::new(256))
-            .pq_params(PQBuildParams::default());
+        table
+            .create_index(&["embeddings"])
+            .ivf_pq()
+            .name("my_index")
+            .num_partitions(256)
+            .build()
+            .await
+            .unwrap();

-        table.create_index(index_builder).await.unwrap();
-
-        assert_eq!(table.dataset.load_indices().await.unwrap().len(), 1);
+        assert_eq!(table.load_indices().await.unwrap().len(), 1);
        assert_eq!(table.count_rows().await.unwrap(), 512);
        assert_eq!(table.name, "test");