feat(rust): create index API improvement (#853)

* Extract a minimal Table trait in the Rust SDK.
* Make create_index composable in Rust.
* Fix compilation issues in the FFI bindings.
Lei Xu
2024-01-24 10:05:12 -08:00
committed by GitHub
parent 82cbcf6d07
commit 008e0b1a93
15 changed files with 665 additions and 497 deletions
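
In short: index creation moves from a parameter-struct API to a chained builder hung off the new Table trait. A minimal usage sketch of the new API (the path, table name, and column name are illustrative), assuming a database opened through Database::connect:

use vectordb::connection::{Connection, Database};

let db = Database::connect("data/sample-lancedb").await.unwrap();
let table = db.open_table("my_table").await.unwrap(); // open_table now returns a TableRef

table
    .create_index(&["vector"]) // returns an IndexBuilder (scalar index by default)
    .ivf_pq()                  // switch the builder to an IVF_PQ vector index
    .num_partitions(256)
    .build()
    .await
    .unwrap();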

View File

@@ -24,11 +24,14 @@ arrow-ord = "49.0"
arrow-schema = "49.0"
arrow-arith = "49.0"
arrow-cast = "49.0"
async-trait = "0"
chrono = "0.4.23"
half = { "version" = "=2.3.1", default-features = false, features = [
"num-traits",
] }
futures = "0"
log = "0.4"
object_store = "0.9.0"
snafu = "0.7.4"
url = "2"
num-traits = "0.2"

View File

@@ -16,16 +16,16 @@ use crate::query::Query;
use arrow_ipc::writer::FileWriter;
use napi::bindgen_prelude::*;
use napi_derive::napi;
use vectordb::{ipc::ipc_file_to_batches, table::Table as LanceDBTable};
use vectordb::{ipc::ipc_file_to_batches, table::TableRef};
#[napi]
pub struct Table {
pub(crate) table: LanceDBTable,
pub(crate) table: TableRef,
}
#[napi]
impl Table {
pub(crate) fn new(table: LanceDBTable) -> Self {
pub(crate) fn new(table: TableRef) -> Self {
Self { table }
}
@@ -46,7 +46,7 @@ impl Table {
pub async unsafe fn add(&mut self, buf: Buffer) -> napi::Result<()> {
let batches = ipc_file_to_batches(buf.to_vec())
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
self.table.add(batches, None).await.map_err(|e| {
self.table.add(Box::new(batches), None).await.map_err(|e| {
napi::Error::from_reason(format!(
"Failed to add batches to table {}: {}",
self.table, e
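
The trait's add now takes Box<dyn RecordBatchReader + Send> rather than impl RecordBatchReader, which is why call sites such as this one gain the Box::new wrapper. A sketch of a direct call, assuming batches: Vec<RecordBatch> and schema are already in scope:

use arrow_array::RecordBatchIterator;

let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
table.add(Box::new(reader), None).await?;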

View File

@@ -29,10 +29,14 @@ pub(crate) fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsP
let (deferred, promise) = cx.promise();
let channel = cx.channel();
let mut table = js_table.table.clone();
let table = js_table.table.clone();
rt.spawn(async move {
let idx_result = table.create_scalar_index(&column, replace).await;
let idx_result = table
.as_native()
.unwrap()
.create_scalar_index(&column, replace)
.await;
deferred.settle_with(&channel, move |mut cx| {
idx_result.or_throw(&mut cx)?;
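
The binding reaches the scalar-index API by downcasting the TableRef with as_native(); the unwrap() is safe here because this layer only ever constructs native tables. For other callers, a hypothetical fallible variant using the Error::Runtime variant added in this commit might look like:

use vectordb::error::{Error, Result};
use vectordb::table::{NativeTable, TableRef};

// Hypothetical helper, not part of the diff: resolve a TableRef to the
// native implementation without panicking.
fn native(table: &TableRef) -> Result<&NativeTable> {
    table.as_native().ok_or_else(|| Error::Runtime {
        message: "table is not backed by a NativeTable".to_string(),
    })
}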

View File

@@ -12,13 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use lance_index::vector::{ivf::IvfBuildParams, pq::PQBuildParams};
use lance_linalg::distance::MetricType;
use neon::context::FunctionContext;
use neon::prelude::*;
use std::convert::TryFrom;
use vectordb::index::vector::{IvfPQIndexBuilder, VectorIndexBuilder};
use vectordb::index::IndexBuilder;
use crate::error::Error::InvalidIndexType;
use crate::error::ResultExt;
@@ -29,17 +27,24 @@ use crate::table::JsTable;
pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let index_params = cx.argument::<JsObject>(0)?;
let index_params_builder = get_index_params_builder(&mut cx, index_params).or_throw(&mut cx)?;
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
let channel = cx.channel();
let mut table = js_table.table.clone();
let table = js_table.table.clone();
let column_name = index_params
.get_opt::<JsString, _, _>(&mut cx, "column")?
.map(|s| s.value(&mut cx))
.unwrap_or("vector".to_string()); // Backward compatibility
let tbl = table.clone();
let mut index_builder = tbl.create_index(&[&column_name]);
get_index_params_builder(&mut cx, index_params, &mut index_builder).or_throw(&mut cx)?;
rt.spawn(async move {
let idx_result = table.create_index(&index_params_builder).await;
let idx_result = index_builder.build().await;
deferred.settle_with(&channel, move |mut cx| {
idx_result.or_throw(&mut cx)?;
Ok(cx.boxed(JsTable::from(table)))
@@ -51,66 +56,39 @@ pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsP
fn get_index_params_builder(
cx: &mut FunctionContext,
obj: Handle<JsObject>,
) -> crate::error::Result<impl VectorIndexBuilder> {
let idx_type = obj.get::<JsString, _, _>(cx, "type")?.value(cx);
match idx_type.as_str() {
"ivf_pq" => {
let mut index_builder: IvfPQIndexBuilder = IvfPQIndexBuilder::new();
let mut pq_params = PQBuildParams::default();
obj.get_opt::<JsString, _, _>(cx, "column")?
.map(|s| index_builder.column(s.value(cx)));
obj.get_opt::<JsString, _, _>(cx, "index_name")?
.map(|s| index_builder.index_name(s.value(cx)));
if let Some(metric_type) = obj.get_opt::<JsString, _, _>(cx, "metric_type")? {
let metric_type = MetricType::try_from(metric_type.value(cx).as_str()).unwrap();
index_builder.metric_type(metric_type);
}
let num_partitions = obj.get_opt_usize(cx, "num_partitions")?;
let max_iters = obj.get_opt_usize(cx, "max_iters")?;
num_partitions.map(|np| {
let max_iters = max_iters.unwrap_or(50);
let ivf_params = IvfBuildParams {
num_partitions: np,
max_iters,
..Default::default()
};
index_builder.ivf_params(ivf_params)
});
if let Some(use_opq) = obj.get_opt::<JsBoolean, _, _>(cx, "use_opq")? {
pq_params.use_opq = use_opq.value(cx);
}
if let Some(num_sub_vectors) = obj.get_opt_usize(cx, "num_sub_vectors")? {
pq_params.num_sub_vectors = num_sub_vectors;
}
if let Some(num_bits) = obj.get_opt_usize(cx, "num_bits")? {
pq_params.num_bits = num_bits;
}
if let Some(max_iters) = obj.get_opt_usize(cx, "max_iters")? {
pq_params.max_iters = max_iters;
}
if let Some(max_opq_iters) = obj.get_opt_usize(cx, "max_opq_iters")? {
pq_params.max_opq_iters = max_opq_iters;
}
if let Some(replace) = obj.get_opt::<JsBoolean, _, _>(cx, "replace")? {
index_builder.replace(replace.value(cx));
}
Ok(index_builder)
builder: &mut IndexBuilder,
) -> crate::error::Result<()> {
match obj.get::<JsString, _, _>(cx, "type")?.value(cx).as_str() {
"ivf_pq" => builder.ivf_pq(),
_ => {
return Err(InvalidIndexType {
index_type: "".into(),
})
}
index_type => {
    return Err(InvalidIndexType {
        index_type: index_type.into(),
    })
}
};
obj.get_opt::<JsString, _, _>(cx, "index_name")?
.map(|s| builder.name(s.value(cx).as_str()));
if let Some(metric_type) = obj.get_opt::<JsString, _, _>(cx, "metric_type")? {
let metric_type = MetricType::try_from(metric_type.value(cx).as_str())?;
builder.metric_type(metric_type);
}
if let Some(np) = obj.get_opt_usize(cx, "num_partitions")? {
builder.num_partitions(np as u64);
}
if let Some(ns) = obj.get_opt_u32(cx, "num_sub_vectors")? {
builder.num_sub_vectors(ns);
}
if let Some(max_iters) = obj.get_opt_u32(cx, "max_iters")? {
builder.max_iterations(max_iters);
}
if let Some(num_bits) = obj.get_opt_u32(cx, "num_bits")? {
builder.num_bits(num_bits);
}
if let Some(replace) = obj.get_opt::<JsBoolean, _, _>(cx, "replace")? {
builder.replace(replace.value(cx));
}
Ok(())
}
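
For comparison, the chained setters this function drives map one-to-one onto the JS options. A sketch of the equivalent direct-Rust calls (all values illustrative):

use lance_linalg::distance::MetricType;

let mut builder = table.create_index(&["vector"]);
builder
    .ivf_pq()
    .name("my_index")
    .metric_type(MetricType::Cosine)
    .num_partitions(256)
    .num_sub_vectors(16)
    .num_bits(8)
    .max_iterations(50)
    .replace(true);
builder.build().await?;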

View File

@@ -70,14 +70,15 @@ impl JsQuery {
let query = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx));
rt.spawn(async move {
let mut builder = table
.search(query)
.refine_factor(refine_factor)
.nprobes(nprobes)
.filter(filter)
.metric_type(metric_type)
.select(select)
.prefilter(prefilter);
let mut builder = table.query();
if let Some(query) = query {
builder = builder
.query_vector(&query)
.refine_factor(refine_factor)
.nprobes(nprobes)
.metric_type(metric_type);
};
builder = builder.filter(filter).select(select).prefilter(prefilter);
if let Some(limit) = limit {
builder = builder.limit(limit as usize);
};
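
The shape of the new query path: table.query() yields a plain (optionally filtered) scan, and attaching a vector via query_vector turns it into an ANN search, which is the only case where knobs like nprobes apply. A sketch under the new API:

// Filtered scan, no vector involved.
let _scan = table.query().filter(Some("id % 2 == 0".to_string()));

// ANN search: the same builder plus a query vector and vector-only settings.
let results = table
    .query()
    .query_vector(&[0.1, 0.2, 0.3, 0.4])
    .nprobes(20)
    .limit(10)
    .execute()
    .await?;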

View File

@@ -20,19 +20,19 @@ use lance::io::ObjectStoreParams;
use crate::arrow::{arrow_buffer_to_record_batch, record_batch_to_buffer};
use neon::prelude::*;
use neon::types::buffer::TypedArray;
use vectordb::Table;
use vectordb::TableRef;
use crate::error::ResultExt;
use crate::{convert, get_aws_creds, get_aws_region, runtime, JsDatabase};
pub(crate) struct JsTable {
pub table: Table,
pub table: TableRef,
}
impl Finalize for JsTable {}
impl From<Table> for JsTable {
fn from(table: Table) -> Self {
impl From<TableRef> for JsTable {
fn from(table: TableRef) -> Self {
JsTable { table }
}
}
@@ -96,7 +96,7 @@ impl JsTable {
arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?;
let rt = runtime(&mut cx)?;
let channel = cx.channel();
let mut table = js_table.table.clone();
let table = js_table.table.clone();
let (deferred, promise) = cx.promise();
let write_mode = match write_mode.as_str() {
@@ -118,7 +118,7 @@ impl JsTable {
rt.spawn(async move {
let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let add_result = table.add(batch_reader, Some(params)).await;
let add_result = table.add(Box::new(batch_reader), Some(params)).await;
deferred.settle_with(&channel, move |mut cx| {
add_result.or_throw(&mut cx)?;
@@ -152,7 +152,7 @@ impl JsTable {
let (deferred, promise) = cx.promise();
let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
let channel = cx.channel();
let mut table = js_table.table.clone();
let table = js_table.table.clone();
rt.spawn(async move {
let delete_result = table.delete(&predicate).await;
@@ -167,7 +167,7 @@ impl JsTable {
pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let mut table = js_table.table.clone();
let table = js_table.table.clone();
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
@@ -218,7 +218,11 @@ impl JsTable {
let predicate = predicate.as_deref();
let update_result = table.update(predicate, updates_arg).await;
let update_result = table
.as_native()
.unwrap()
.update(predicate, updates_arg)
.await;
deferred.settle_with(&channel, move |mut cx| {
update_result.or_throw(&mut cx)?;
Ok(cx.boxed(JsTable::from(table)))
@@ -249,6 +253,8 @@ impl JsTable {
rt.spawn(async move {
let stats = table
.as_native()
.unwrap()
.cleanup_old_versions(older_than, Some(delete_unverified))
.await;
@@ -278,7 +284,7 @@ impl JsTable {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
let mut table = js_table.table.clone();
let table = js_table.table.clone();
let channel = cx.channel();
let js_options = cx.argument::<JsObject>(0)?;
@@ -310,7 +316,11 @@ impl JsTable {
}
rt.spawn(async move {
let stats = table.compact_files(options, None).await;
let stats = table
.as_native()
.unwrap()
.compact_files(options, None)
.await;
deferred.settle_with(&channel, move |mut cx| {
let stats = stats.or_throw(&mut cx)?;
@@ -349,7 +359,7 @@ impl JsTable {
let table = js_table.table.clone();
rt.spawn(async move {
let indices = table.load_indices().await;
let indices = table.as_native().unwrap().load_indices().await;
deferred.settle_with(&channel, move |mut cx| {
let indices = indices.or_throw(&mut cx)?;
@@ -389,8 +399,8 @@ impl JsTable {
rt.spawn(async move {
let load_stats = futures::try_join!(
table.count_indexed_rows(&index_uuid),
table.count_unindexed_rows(&index_uuid)
table.as_native().unwrap().count_indexed_rows(&index_uuid),
table.as_native().unwrap().count_unindexed_rows(&index_uuid)
);
deferred.settle_with(&channel, move |mut cx| {
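
futures::try_join! polls both statistics futures concurrently and short-circuits on the first error. The repeated as_native().unwrap() could also be hoisted; a small refactor sketch, not part of the diff:

let native = table.as_native().unwrap();
let (indexed, unindexed) = futures::try_join!(
    native.count_indexed_rows(&index_uuid),
    native.count_unindexed_rows(&index_uuid),
)?;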

View File

@@ -26,11 +26,11 @@ lance-index = { workspace = true }
lance-linalg = { workspace = true }
lance-testing = { workspace = true }
tokio = { version = "1.23", features = ["rt-multi-thread"] }
log = { workspace = true }
log.workspace = true
async-trait = "0"
bytes = "1"
futures = "0"
num-traits = "0"
futures.workspace = true
num-traits.workspace = true
url = { workspace = true }
serde = { version = "^1" }
serde_json = { version = "1" }

View File

@@ -27,7 +27,7 @@ use snafu::prelude::*;
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
use crate::io::object_store::MirroringObjectStoreWrapper;
use crate::table::{ReadParams, Table};
use crate::table::{NativeTable, ReadParams, TableRef};
pub const LANCE_FILE_EXTENSION: &str = "lance";
@@ -46,17 +46,20 @@ pub trait Connection: Send + Sync {
/// * `params` - Optional [`WriteParams`] to create the table.
///
/// # Returns
/// Created [`Table`], or [`Err(Error::TableAlreadyExists)`] if the table already exists.
/// Created [`TableRef`], or [`Err(Error::TableAlreadyExists)`] if the table already exists.
async fn create_table(
&self,
name: &str,
batches: Box<dyn RecordBatchReader + Send>,
params: Option<WriteParams>,
) -> Result<Table>;
) -> Result<TableRef>;
async fn open_table(&self, name: &str) -> Result<Table>;
async fn open_table(&self, name: &str) -> Result<TableRef> {
self.open_table_with_params(name, ReadParams::default())
.await
}
async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<Table>;
async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<TableRef>;
/// Drop a table in the database.
///
@@ -240,30 +243,19 @@ impl Connection for Database {
name: &str,
batches: Box<dyn RecordBatchReader + Send>,
params: Option<WriteParams>,
) -> Result<Table> {
) -> Result<TableRef> {
let table_uri = self.table_uri(name)?;
Table::create(
&table_uri,
name,
batches,
self.store_wrapper.clone(),
params,
)
.await
}
/// Open a table in the database.
///
/// # Arguments
/// * `name` - The name of the table.
///
/// # Returns
///
/// * A [Table] object.
async fn open_table(&self, name: &str) -> Result<Table> {
self.open_table_with_params(name, ReadParams::default())
.await
Ok(Arc::new(
NativeTable::create(
&table_uri,
name,
batches,
self.store_wrapper.clone(),
params,
)
.await?,
))
}
/// Open a table in the database.
@@ -274,10 +266,13 @@ impl Connection for Database {
///
/// # Returns
///
/// * A [Table] object.
async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<Table> {
/// * A [TableRef] object.
async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<TableRef> {
let table_uri = self.table_uri(name)?;
Table::open_with_params(&table_uri, name, self.store_wrapper.clone(), params).await
Ok(Arc::new(
NativeTable::open_with_params(&table_uri, name, self.store_wrapper.clone(), params)
.await?,
))
}
async fn drop_table(&self, name: &str) -> Result<()> {
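
With this change open_table is a provided trait method layered on open_table_with_params, and both creation and opening hand back an Arc<NativeTable> as a TableRef. From the caller's side (path and table name illustrative):

use vectordb::connection::{Connection, Database};
use vectordb::TableRef;

let db = Database::connect("data/sample-lancedb").await.unwrap();
let table: TableRef = db.open_table("my_table").await.unwrap();
println!("{} has {} rows", table.name(), table.count_rows().await.unwrap());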

View File

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::PoisonError;
use arrow_schema::ArrowError;
use snafu::Snafu;
@@ -35,6 +37,8 @@ pub enum Error {
Lance { message: String },
#[snafu(display("LanceDB Schema Error: {message}"))]
Schema { message: String },
#[snafu(display("Runtime error: {message}"))]
Runtime { message: String },
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -70,3 +74,11 @@ impl From<object_store::path::Error> for Error {
}
}
}
impl<T> From<PoisonError<T>> for Error {
fn from(e: PoisonError<T>) -> Self {
Self::Runtime {
message: e.to_string(),
}
}
}
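
The From<PoisonError<T>> conversion exists because NativeTable now guards its Dataset behind a Mutex; it lets a poisoned lock surface as Error::Runtime through ? instead of a panic. A minimal sketch:

use std::sync::Mutex;
use vectordb::error::Result;

// A poisoned lock converts into Error::Runtime via `?`.
fn read_version(version: &Mutex<u64>) -> Result<u64> {
    let guard = version.lock()?;
    Ok(*guard)
}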

View File

@@ -1,4 +1,4 @@
// Copyright 2023 Lance Developers.
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,4 +12,263 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{cmp::max, sync::Arc};
use lance_index::{DatasetIndexExt, IndexType};
pub use lance_linalg::distance::MetricType;
pub mod vector;
use crate::{Error, Result, Table};
/// Index Parameters.
pub enum IndexParams {
Scalar {
replace: bool,
},
IvfPq {
replace: bool,
metric_type: MetricType,
num_partitions: u64,
num_sub_vectors: u32,
num_bits: u32,
sample_rate: u32,
max_iterations: u32,
},
}
/// Builder for Index Parameters.
pub struct IndexBuilder {
table: Arc<dyn Table>,
columns: Vec<String>,
// General parameters
/// Index name.
name: Option<String>,
/// Replace the existing index.
replace: bool,
index_type: IndexType,
// Scalar index parameters
// Nothing to set here.
// IVF_PQ parameters
metric_type: MetricType,
num_partitions: Option<u64>,
// PQ related
num_sub_vectors: Option<u32>,
num_bits: u32,
/// Sampling rate used to select vectors for training k-means.
sample_rate: u32,
/// Maximum number of iterations when training k-means.
max_iterations: u32,
}
impl IndexBuilder {
pub(crate) fn new(table: Arc<dyn Table>, columns: &[&str]) -> Self {
IndexBuilder {
table,
columns: columns.iter().map(|c| c.to_string()).collect(),
name: None,
replace: true,
index_type: IndexType::Scalar,
metric_type: MetricType::L2,
num_partitions: None,
num_sub_vectors: None,
num_bits: 8,
sample_rate: 256,
max_iterations: 50,
}
}
/// Build a Scalar Index.
///
/// Accepted parameters:
/// - `replace`: Replace the existing index.
/// - `name`: Index name. Default: `None`
pub fn scalar(&mut self) -> &mut Self {
self.index_type = IndexType::Scalar;
self
}
/// Build an IVF PQ index.
///
/// Accepted parameters:
/// - `replace`: Replace the existing index.
/// - `name`: Index name. Default: `None`
/// - `metric_type`: The [MetricType] used to build the vector index.
/// - `num_partitions`: Number of IVF partitions.
/// - `num_sub_vectors`: Number of sub-vectors of PQ.
/// - `num_bits`: Number of bits used for PQ centroids.
/// - `sample_rate`: Sampling rate used to select vectors for training k-means.
/// - `max_iterations`: Maximum number of iterations when training k-means.
pub fn ivf_pq(&mut self) -> &mut Self {
self.index_type = IndexType::Vector;
self
}
/// Whether to replace the existing index, default is `true`.
pub fn replace(&mut self, v: bool) -> &mut Self {
self.replace = v;
self
}
/// Set the index name.
pub fn name(&mut self, name: &str) -> &mut Self {
self.name = Some(name.to_string());
self
}
/// The [MetricType] used to build the vector index.
///
/// Default value is [MetricType::L2].
pub fn metric_type(&mut self, metric_type: MetricType) -> &mut Self {
self.metric_type = metric_type;
self
}
/// Number of IVF partitions.
pub fn num_partitions(&mut self, num_partitions: u64) -> &mut Self {
self.num_partitions = Some(num_partitions);
self
}
/// Number of sub-vectors of PQ.
pub fn num_sub_vectors(&mut self, num_sub_vectors: u32) -> &mut Self {
self.num_sub_vectors = Some(num_sub_vectors);
self
}
/// Number of bits used for PQ centroids.
pub fn num_bits(&mut self, num_bits: u32) -> &mut Self {
self.num_bits = num_bits;
self
}
/// Sampling rate used to select vectors for training k-means.
pub fn sample_rate(&mut self, sample_rate: u32) -> &mut Self {
self.sample_rate = sample_rate;
self
}
/// Maximum number of iterations when training k-means.
pub fn max_iterations(&mut self, max_iterations: u32) -> &mut Self {
self.max_iterations = max_iterations;
self
}
/// Create the index from the accumulated parameters.
pub async fn build(&self) -> Result<()> {
if self.columns.len() != 1 {
return Err(Error::Schema {
message: "Only one column is supported for index".to_string(),
});
}
let column = &self.columns[0];
let schema = self.table.schema();
let field = schema.field_with_name(column)?;
let params = match self.index_type {
IndexType::Scalar => IndexParams::Scalar {
replace: self.replace,
},
IndexType::Vector => {
let num_partitions = if let Some(n) = self.num_partitions {
n
} else {
suggested_num_partitions(self.table.count_rows().await?)
};
let num_sub_vectors: u32 = if let Some(n) = self.num_sub_vectors {
n
} else {
match field.data_type() {
arrow_schema::DataType::FixedSizeList(_, n) => {
Ok::<u32, Error>(suggested_num_sub_vectors(*n as u32))
}
_ => Err(Error::Schema {
message: format!(
"Column '{}' is not a FixedSizeList",
&self.columns[0]
),
}),
}?
};
IndexParams::IvfPq {
replace: self.replace,
metric_type: self.metric_type,
num_partitions,
num_sub_vectors,
num_bits: self.num_bits,
sample_rate: self.sample_rate,
max_iterations: self.max_iterations,
}
}
};
let tbl = self
.table
.as_native()
.expect("Only native table is supported here");
let mut dataset = tbl.clone_inner_dataset();
match params {
IndexParams::Scalar { replace } => {
self.table
.as_native()
.unwrap()
.create_scalar_index(column, replace)
.await?
}
IndexParams::IvfPq {
replace,
metric_type,
num_partitions,
num_sub_vectors,
num_bits,
max_iterations,
..
} => {
let lance_idx_params = lance::index::vector::VectorIndexParams::ivf_pq(
num_partitions as usize,
num_bits as u8,
num_sub_vectors as usize,
false,
metric_type,
max_iterations as usize,
);
dataset
.create_index(
&[column],
IndexType::Vector,
None,
&lance_idx_params,
replace,
)
.await?;
}
}
tbl.reset_dataset(dataset);
Ok(())
}
}
fn suggested_num_partitions(rows: usize) -> u64 {
let num_partitions = (rows as f64).sqrt() as u64;
max(1, num_partitions)
}
fn suggested_num_sub_vectors(dim: u32) -> u32 {
if dim % 16 == 0 {
// Should be more aggressive than this default.
dim / 16
} else if dim % 8 == 0 {
dim / 8
} else {
log::warn!(
"The dimension of the vector is not divisible by 8 or 16, \
which may cause performance degradation in PQ"
);
1
}
}
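
A quick arithmetic check of these defaults (an illustrative test, not part of the diff): one million rows suggests sqrt(1_000_000) = 1000 partitions, and a 1536-dimensional vector gets 1536 / 16 = 96 sub-vectors.

#[test]
fn suggested_defaults_sketch() {
    assert_eq!(suggested_num_partitions(1_000_000), 1_000);
    assert_eq!(suggested_num_partitions(0), 1); // floored at one partition
    assert_eq!(suggested_num_sub_vectors(1536), 96); // dim divisible by 16
    assert_eq!(suggested_num_sub_vectors(24), 3); // falls back to dim / 8
}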

View File

@@ -14,104 +14,7 @@
use serde::Deserialize;
use lance::index::vector::pq::PQBuildParams;
use lance::index::vector::VectorIndexParams;
use lance::table::format::{Index, Manifest};
use lance_index::vector::ivf::IvfBuildParams;
use lance_linalg::distance::MetricType;
pub trait VectorIndexBuilder {
fn get_column(&self) -> Option<String>;
fn get_index_name(&self) -> Option<String>;
fn build(&self) -> VectorIndexParams;
fn get_replace(&self) -> bool;
}
pub struct IvfPQIndexBuilder {
column: Option<String>,
index_name: Option<String>,
metric_type: Option<MetricType>,
ivf_params: Option<IvfBuildParams>,
pq_params: Option<PQBuildParams>,
replace: bool,
}
impl IvfPQIndexBuilder {
pub fn new() -> IvfPQIndexBuilder {
Default::default()
}
}
impl Default for IvfPQIndexBuilder {
fn default() -> Self {
IvfPQIndexBuilder {
column: None,
index_name: None,
metric_type: None,
ivf_params: None,
pq_params: None,
replace: true,
}
}
}
impl IvfPQIndexBuilder {
pub fn column(&mut self, column: String) -> &mut IvfPQIndexBuilder {
self.column = Some(column);
self
}
pub fn index_name(&mut self, index_name: String) -> &mut IvfPQIndexBuilder {
self.index_name = Some(index_name);
self
}
pub fn metric_type(&mut self, metric_type: MetricType) -> &mut IvfPQIndexBuilder {
self.metric_type = Some(metric_type);
self
}
pub fn ivf_params(&mut self, ivf_params: IvfBuildParams) -> &mut IvfPQIndexBuilder {
self.ivf_params = Some(ivf_params);
self
}
pub fn pq_params(&mut self, pq_params: PQBuildParams) -> &mut IvfPQIndexBuilder {
self.pq_params = Some(pq_params);
self
}
pub fn replace(&mut self, replace: bool) -> &mut IvfPQIndexBuilder {
self.replace = replace;
self
}
}
impl VectorIndexBuilder for IvfPQIndexBuilder {
fn get_column(&self) -> Option<String> {
self.column.clone()
}
fn get_index_name(&self) -> Option<String> {
self.index_name.clone()
}
fn build(&self) -> VectorIndexParams {
let ivf_params = self.ivf_params.clone().unwrap_or_default();
let pq_params = self.pq_params.clone().unwrap_or_default();
VectorIndexParams::with_ivf_pq_params(
self.metric_type.unwrap_or(MetricType::L2),
ivf_params,
pq_params,
)
}
fn get_replace(&self) -> bool {
self.replace
}
}
pub struct VectorIndex {
pub columns: Vec<String>,
@@ -139,79 +42,3 @@ pub struct VectorIndexStatistics {
pub num_indexed_rows: usize,
pub num_unindexed_rows: usize,
}
#[cfg(test)]
mod tests {
use super::*;
use lance::index::vector::StageParams;
use lance_index::vector::ivf::IvfBuildParams;
use lance_index::vector::pq::PQBuildParams;
use crate::index::vector::{IvfPQIndexBuilder, VectorIndexBuilder};
#[test]
fn test_builder_no_params() {
let index_builder = IvfPQIndexBuilder::new();
assert!(index_builder.get_column().is_none());
assert!(index_builder.get_index_name().is_none());
let index_params = index_builder.build();
assert_eq!(index_params.stages.len(), 2);
if let StageParams::Ivf(ivf_params) = index_params.stages.get(0).unwrap() {
let default = IvfBuildParams::default();
assert_eq!(ivf_params.num_partitions, default.num_partitions);
assert_eq!(ivf_params.max_iters, default.max_iters);
} else {
panic!("Expected first stage to be ivf")
}
if let StageParams::PQ(pq_params) = index_params.stages.get(1).unwrap() {
assert_eq!(pq_params.use_opq, false);
} else {
panic!("Expected second stage to be pq")
}
}
#[test]
fn test_builder_all_params() {
let mut index_builder = IvfPQIndexBuilder::new();
index_builder
.column("c".to_owned())
.metric_type(MetricType::Cosine)
.index_name("index".to_owned());
assert_eq!(index_builder.column.clone().unwrap(), "c");
assert_eq!(index_builder.metric_type.unwrap(), MetricType::Cosine);
assert_eq!(index_builder.index_name.clone().unwrap(), "index");
let ivf_params = IvfBuildParams::new(500);
let mut pq_params = PQBuildParams::default();
pq_params.use_opq = true;
pq_params.max_iters = 1;
pq_params.num_bits = 8;
pq_params.num_sub_vectors = 50;
pq_params.max_opq_iters = 2;
index_builder.ivf_params(ivf_params);
index_builder.pq_params(pq_params);
let index_params = index_builder.build();
assert_eq!(index_params.stages.len(), 2);
if let StageParams::Ivf(ivf_params) = index_params.stages.get(0).unwrap() {
assert_eq!(ivf_params.num_partitions, 500);
} else {
assert!(false, "Expected first stage to be ivf")
}
if let StageParams::PQ(pq_params) = index_params.stages.get(1).unwrap() {
assert_eq!(pq_params.use_opq, true);
assert_eq!(pq_params.max_iters, 1);
assert_eq!(pq_params.num_bits, 8);
assert_eq!(pq_params.num_sub_vectors, 50);
assert_eq!(pq_params.max_opq_iters, 2);
} else {
assert!(false, "Expected second stage to be pq")
}
}
}

View File

@@ -335,14 +335,15 @@ impl WrappingObjectStore for MirroringObjectStoreWrapper {
#[cfg(all(test, not(windows)))]
mod test {
use super::*;
use crate::connection::{Connection, Database};
use arrow_array::PrimitiveArray;
use futures::TryStreamExt;
use lance::{dataset::WriteParams, io::ObjectStoreParams};
use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
use object_store::local::LocalFileSystem;
use tempfile;
use crate::connection::{Connection, Database};
#[tokio::test]
async fn test_e2e() {
let dir1 = tempfile::tempdir().unwrap().into_path();
@@ -374,9 +375,7 @@ mod test {
assert_eq!(t.count_rows().await.unwrap(), 100);
let q = t
.search(Some(PrimitiveArray::from_iter_values(vec![
0.1, 0.1, 0.1, 0.1,
])))
.search(&[0.1, 0.1, 0.1, 0.1])
.limit(10)
.execute()
.await

View File

@@ -46,7 +46,7 @@
//! #### Connect to a database.
//!
//! ```rust
//! use vectordb::{connection::{Database, Connection}, Table, WriteMode};
//! use vectordb::{connection::{Database, Connection}, WriteMode};
//! use arrow_schema::{Field, Schema};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! let db = Database::connect("data/sample-lancedb").await.unwrap();
@@ -119,7 +119,7 @@
//! # db.create_table("my_table", Box::new(batches), None).await.unwrap();
//! let table = db.open_table("my_table").await.unwrap();
//! let results = table
//! .search(Some(vec![1.0; 128]))
//! .search(&[1.0; 128])
//! .execute()
//! .await
//! .unwrap()
@@ -143,6 +143,6 @@ pub mod utils;
pub use connection::{Connection, Database};
pub use error::{Error, Result};
pub use table::Table;
pub use table::{Table, TableRef};
pub use lance::dataset::WriteMode;

View File

@@ -1,4 +1,4 @@
// Copyright 2023 Lance Developers.
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -48,10 +48,10 @@ impl Query {
/// # Returns
///
/// * A [Query] object.
pub(crate) fn new(dataset: Arc<Dataset>, vector: Option<Float32Array>) -> Self {
pub(crate) fn new(dataset: Arc<Dataset>) -> Self {
Query {
dataset,
query_vector: vector,
query_vector: None,
column: crate::table::VECTOR_COLUMN_NAME.to_string(),
limit: None,
nprobes: 20,
@@ -206,7 +206,7 @@ mod tests {
let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();
let vector = Some(Float32Array::from_iter_values([0.1, 0.2]));
let query = Query::new(Arc::new(ds), vector.clone());
let query = Query::new(Arc::new(ds)).query_vector(&[0.1, 0.2]);
assert_eq!(query.query_vector, vector);
let new_vector = Float32Array::from_iter_values([9.8, 8.7]);
@@ -232,9 +232,7 @@ mod tests {
let batches = make_non_empty_batches();
let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap());
let vector = Some(Float32Array::from_iter_values([0.1; 4]));
let query = Query::new(ds.clone(), vector.clone());
let query = Query::new(ds.clone()).query_vector(&[0.1; 4]);
let result = query
.limit(10)
.filter(Some("id % 2 == 0".to_string()))
@@ -247,7 +245,7 @@ mod tests {
assert!(batch.expect("should be Ok").num_rows() < 10);
}
let query = Query::new(ds, vector.clone());
let query = Query::new(ds).query_vector(&[0.1; 4]);
let result = query
.limit(10)
.filter(Some("id % 2 == 0".to_string()))
@@ -268,7 +266,7 @@ mod tests {
let batches = make_non_empty_batches();
let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap());
let query = Query::new(ds.clone(), None);
let query = Query::new(ds.clone());
let result = query
.filter(Some("id % 2 == 0".to_string()))
.execute()

View File

@@ -1,4 +1,4 @@
// Copyright 2023 LanceDB Developers.
// Copyright 2024 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,72 +12,170 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//! LanceDB Table APIs
use std::path::Path;
use std::sync::{Arc, Mutex};
use arrow_array::RecordBatchReader;
use arrow_schema::{Schema, SchemaRef};
use chrono::Duration;
use lance::dataset::builder::DatasetBuilder;
use lance::index::scalar::ScalarIndexParams;
use lance_index::optimize::OptimizeOptions;
use lance_index::IndexType;
use std::sync::Arc;
use arrow_array::{Float32Array, RecordBatchReader};
use arrow_schema::SchemaRef;
use lance::dataset::cleanup::RemovalStats;
use lance::dataset::optimize::{
compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
};
pub use lance::dataset::ReadParams;
use lance::dataset::{Dataset, UpdateBuilder, WriteParams};
use lance::index::scalar::ScalarIndexParams;
use lance::io::WrappingObjectStore;
use lance_index::DatasetIndexExt;
use std::path::Path;
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt, IndexType};
use crate::error::{Error, Result};
use crate::index::vector::{VectorIndex, VectorIndexBuilder, VectorIndexStatistics};
use crate::index::vector::{VectorIndex, VectorIndexStatistics};
use crate::index::IndexBuilder;
use crate::query::Query;
use crate::utils::{PatchReadParam, PatchWriteParam};
use crate::WriteMode;
pub use lance::dataset::ReadParams;
pub const VECTOR_COLUMN_NAME: &str = "vector";
/// A Table is a collection of strongly typed rows.
///
/// The type of each row is defined in an Apache Arrow [Schema].
#[async_trait::async_trait]
pub trait Table: std::fmt::Display + Send + Sync {
fn as_any(&self) -> &dyn std::any::Any;
/// Cast as [`NativeTable`], or return None if it is not a [`NativeTable`].
fn as_native(&self) -> Option<&NativeTable>;
/// Get the name of the table.
fn name(&self) -> &str;
/// Get the arrow [Schema] of the table.
fn schema(&self) -> SchemaRef;
/// Count the number of rows in this dataset.
async fn count_rows(&self) -> Result<usize>;
/// Insert new records into this Table
///
/// # Arguments
///
/// * `batches` - RecordBatches to be saved in the Table.
/// * `params` - Append to / overwrite existing records. Default: Append
async fn add(
&self,
batches: Box<dyn RecordBatchReader + Send>,
params: Option<WriteParams>,
) -> Result<()>;
/// Delete the rows from the table that match the predicate.
///
/// # Arguments
/// - `predicate` - The SQL predicate string to filter the rows to be deleted.
///
/// # Example
///
/// ```no_run
/// # use std::sync::Arc;
/// # use vectordb::connection::{Database, Connection};
/// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
/// # RecordBatchIterator, Int32Array};
/// # use arrow_schema::{Schema, Field, DataType};
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let tmpdir = tempfile::tempdir().unwrap();
/// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
/// # let schema = Arc::new(Schema::new(vec![
/// # Field::new("id", DataType::Int32, false),
/// # Field::new("vector", DataType::FixedSizeList(
/// # Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
/// # ]));
/// let batches = RecordBatchIterator::new(vec![
/// RecordBatch::try_new(schema.clone(),
/// vec![
/// Arc::new(Int32Array::from_iter_values(0..10)),
/// Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
/// (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)),
/// ]).unwrap()
/// ].into_iter().map(Ok),
/// schema.clone());
/// let tbl = db.create_table("delete_test", Box::new(batches), None).await.unwrap();
/// tbl.delete("id > 5").await.unwrap();
/// # });
/// ```
async fn delete(&self, predicate: &str) -> Result<()>;
/// Create an index on the column name.
///
/// # Examples
///
/// ```no_run
/// # use std::sync::Arc;
/// # use vectordb::connection::{Database, Connection};
/// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
/// # RecordBatchIterator, Int32Array};
/// # use arrow_schema::{Schema, Field, DataType};
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let tmpdir = tempfile::tempdir().unwrap();
/// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
/// # let tbl = db.open_table("delete_test").await.unwrap();
/// tbl.create_index(&["vector"])
/// .ivf_pq()
/// .num_partitions(256)
/// .build()
/// .await
/// .unwrap();
/// # });
/// ```
fn create_index(&self, columns: &[&str]) -> IndexBuilder;
/// Search the table with a given query vector.
fn search(&self, query: &[f32]) -> Query {
self.query().query_vector(query)
}
/// Create a Query builder.
fn query(&self) -> Query;
}
/// A reference-counted pointer to a [Table].
pub type TableRef = Arc<dyn Table>;
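
Because TableRef is just Arc<dyn Table>, downstream code can be written against the trait and stay agnostic to the backing implementation. A sketch (crate-internal, using the Result alias imported above):

// Works for any Table implementation behind the Arc, native or otherwise.
async fn describe(table: &TableRef) -> Result<String> {
    Ok(format!("{} has {} rows", table.name(), table.count_rows().await?))
}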
/// A table in a LanceDB database.
#[derive(Debug, Clone)]
pub struct Table {
pub struct NativeTable {
name: String,
uri: String,
dataset: Arc<Dataset>,
dataset: Arc<Mutex<Dataset>>,
// the object store wrapper to use on write path
store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
}
impl std::fmt::Display for Table {
impl std::fmt::Display for NativeTable {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Table({})", self.name)
}
}
impl Table {
impl NativeTable {
/// Opens an existing Table
///
/// # Arguments
///
/// * `uri` - The uri to a [Table]
/// * `uri` - The uri to a [NativeTable]
/// * `name` - The table name
///
/// # Returns
///
/// * A [Table] object.
/// * A [NativeTable] object.
pub async fn open(uri: &str) -> Result<Self> {
let name = Self::get_table_name(uri)?;
Self::open_with_params(uri, &name, None, ReadParams::default()).await
}
/// Open a Table with a given name.
pub async fn open_with_name(uri: &str, name: &str) -> Result<Self> {
Self::open_with_params(uri, name, None, ReadParams::default()).await
}
/// Opens an existing Table
///
/// # Arguments
@@ -88,7 +186,7 @@ impl Table {
///
/// # Returns
///
/// * A [Table] object.
/// * A [NativeTable] object.
pub async fn open_with_params(
uri: &str,
name: &str,
@@ -113,25 +211,26 @@ impl Table {
message: e.to_string(),
},
})?;
Ok(Table {
Ok(NativeTable {
name: name.to_string(),
uri: uri.to_string(),
dataset: Arc::new(dataset),
dataset: Arc::new(Mutex::new(dataset)),
store_wrapper: write_store_wrapper,
})
}
/// Checkout a specific version of this [`Table`]
/// Make a new clone of the internal lance dataset.
pub(crate) fn clone_inner_dataset(&self) -> Dataset {
self.dataset.lock().expect("Lock poison").clone()
}
/// Checkout a specific version of this [NativeTable]
///
pub async fn checkout(uri: &str, version: u64) -> Result<Self> {
let name = Self::get_table_name(uri)?;
Self::checkout_with_params(uri, &name, version, None, ReadParams::default()).await
}
pub async fn checkout_with_name(uri: &str, name: &str, version: u64) -> Result<Self> {
Self::checkout_with_params(uri, name, version, None, ReadParams::default()).await
}
pub async fn checkout_with_params(
uri: &str,
name: &str,
@@ -154,26 +253,27 @@ impl Table {
message: e.to_string(),
},
})?;
Ok(Table {
Ok(NativeTable {
name: name.to_string(),
uri: uri.to_string(),
dataset: Arc::new(dataset),
dataset: Arc::new(Mutex::new(dataset)),
store_wrapper: write_store_wrapper,
})
}
pub async fn checkout_latest(&self) -> Result<Self> {
let latest_version_id = self.dataset.latest_version_id().await?;
let dataset = if latest_version_id == self.dataset.version().version {
self.dataset.clone()
let dataset = self.clone_inner_dataset();
let latest_version_id = dataset.latest_version_id().await?;
let dataset = if latest_version_id == dataset.version().version {
dataset
} else {
Arc::new(self.dataset.checkout_version(latest_version_id).await?)
dataset.checkout_version(latest_version_id).await?
};
Ok(Table {
Ok(Self {
name: self.name.clone(),
uri: self.uri.clone(),
dataset,
dataset: Arc::new(Mutex::new(dataset)),
store_wrapper: self.store_wrapper.clone(),
})
}
@@ -203,7 +303,7 @@ impl Table {
///
/// # Returns
///
/// * A [Table] object.
/// * A [NativeTable] object.
pub(crate) async fn create(
uri: &str,
name: &str,
@@ -227,46 +327,22 @@ impl Table {
message: e.to_string(),
},
})?;
Ok(Table {
Ok(NativeTable {
name: name.to_string(),
uri: uri.to_string(),
dataset: Arc::new(dataset),
dataset: Arc::new(Mutex::new(dataset)),
store_wrapper: write_store_wrapper,
})
}
/// Schema of this Table.
pub fn schema(&self) -> SchemaRef {
Arc::new(self.dataset.schema().into())
}
/// Version of this Table
pub fn version(&self) -> u64 {
self.dataset.version().version
}
/// Create index on the table.
pub async fn create_index(&mut self, index_builder: &impl VectorIndexBuilder) -> Result<()> {
let mut dataset = self.dataset.as_ref().clone();
dataset
.create_index(
&[index_builder
.get_column()
.unwrap_or(VECTOR_COLUMN_NAME.to_string())
.as_str()],
IndexType::Vector,
index_builder.get_index_name(),
&index_builder.build(),
index_builder.get_replace(),
)
.await?;
self.dataset = Arc::new(dataset);
Ok(())
self.dataset.lock().expect("lock poison").version().version
}
/// Create a scalar index on the table
pub async fn create_scalar_index(&mut self, column: &str, replace: bool) -> Result<()> {
let mut dataset = self.dataset.as_ref().clone();
pub async fn create_scalar_index(&self, column: &str, replace: bool) -> Result<()> {
let mut dataset = self.clone_inner_dataset();
let params = ScalarIndexParams::default();
dataset
.create_index(&[column], IndexType::Scalar, None, &params, replace)
@@ -275,65 +351,21 @@ impl Table {
}
pub async fn optimize_indices(&mut self, options: &OptimizeOptions) -> Result<()> {
let mut dataset = self.dataset.as_ref().clone();
let mut dataset = self.clone_inner_dataset();
dataset.optimize_indices(options).await?;
Ok(())
}
/// Insert records into this Table
///
/// # Arguments
///
/// * `batches` RecordBatch to be saved in the Table
/// * `write_mode` Append / Overwrite existing records. Default: Append
/// # Returns
///
/// * The number of rows added
pub async fn add(
&mut self,
batches: impl RecordBatchReader + Send + 'static,
params: Option<WriteParams>,
) -> Result<()> {
let params = Some(params.unwrap_or(WriteParams {
mode: WriteMode::Append,
..WriteParams::default()
}));
// patch the params if we have a write store wrapper
let params = match self.store_wrapper.clone() {
Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
None => params,
};
self.dataset = Arc::new(Dataset::write(batches, &self.uri, params).await?);
Ok(())
}
pub fn query(&self) -> Query {
Query::new(self.dataset.clone(), None)
}
/// Creates a new Query object that can be executed.
///
/// # Arguments
///
/// * `query_vector` The vector used for this query.
///
/// # Returns
/// * A [Query] object.
pub fn search<T: Into<Float32Array>>(&self, query_vector: Option<T>) -> Query {
Query::new(self.dataset.clone(), query_vector.map(|q| q.into()))
Query::new(self.clone_inner_dataset().into())
}
pub fn filter(&self, expr: String) -> Query {
Query::new(self.dataset.clone(), None).filter(Some(expr))
Query::new(self.clone_inner_dataset().into()).filter(Some(expr))
}
/// Returns the number of rows in this Table
pub async fn count_rows(&self) -> Result<usize> {
Ok(self.dataset.count_rows().await?)
}
/// Merge new data into this table.
pub async fn merge(
@@ -342,26 +374,14 @@ impl Table {
left_on: &str,
right_on: &str,
) -> Result<()> {
let mut dataset = self.dataset.as_ref().clone();
let mut dataset = self.clone_inner_dataset();
dataset.merge(batches, left_on, right_on).await?;
self.dataset = Arc::new(dataset);
self.dataset = Arc::new(Mutex::new(dataset));
Ok(())
}
/// Delete rows from the table
pub async fn delete(&mut self, predicate: &str) -> Result<()> {
let mut dataset = self.dataset.as_ref().clone();
dataset.delete(predicate).await?;
self.dataset = Arc::new(dataset);
Ok(())
}
pub async fn update(
&mut self,
predicate: Option<&str>,
updates: Vec<(&str, &str)>,
) -> Result<()> {
let mut builder = UpdateBuilder::new(self.dataset.clone());
pub async fn update(&self, predicate: Option<&str>, updates: Vec<(&str, &str)>) -> Result<()> {
let mut builder = UpdateBuilder::new(self.clone_inner_dataset().into());
if let Some(predicate) = predicate {
builder = builder.update_where(predicate)?;
}
@@ -371,9 +391,8 @@ impl Table {
}
let operation = builder.build()?;
let new_ds = operation.execute().await?;
self.dataset = new_ds;
let ds = operation.execute().await?;
self.reset_dataset(ds.as_ref().clone());
Ok(())
}
@@ -393,8 +412,8 @@ impl Table {
older_than: Duration,
delete_unverified: Option<bool>,
) -> Result<RemovalStats> {
Ok(self
.dataset
let dataset = self.clone_inner_dataset();
Ok(dataset
.cleanup_old_versions(older_than, delete_unverified)
.await?)
}
@@ -406,26 +425,28 @@ impl Table {
///
/// This calls into [lance::dataset::optimize::compact_files].
pub async fn compact_files(
&mut self,
&self,
options: CompactionOptions,
remap_options: Option<Arc<dyn IndexRemapperOptions>>,
) -> Result<CompactionMetrics> {
let mut dataset = self.dataset.as_ref().clone();
let mut dataset = self.clone_inner_dataset();
let metrics = compact_files(&mut dataset, options, remap_options).await?;
self.dataset = Arc::new(dataset);
self.reset_dataset(dataset);
Ok(metrics)
}
pub fn count_fragments(&self) -> usize {
self.dataset.count_fragments()
self.dataset.lock().expect("lock poison").count_fragments()
}
pub async fn count_deleted_rows(&self) -> Result<usize> {
Ok(self.dataset.count_deleted_rows().await?)
let dataset = self.clone_inner_dataset();
Ok(dataset.count_deleted_rows().await?)
}
pub async fn num_small_files(&self, max_rows_per_group: usize) -> usize {
self.dataset.num_small_files(max_rows_per_group).await
let dataset = self.clone_inner_dataset();
dataset.num_small_files(max_rows_per_group).await
}
pub async fn count_indexed_rows(&self, index_uuid: &str) -> Result<Option<usize>> {
@@ -443,8 +464,8 @@ impl Table {
}
pub async fn load_indices(&self) -> Result<Vec<VectorIndex>> {
let (indices, mf) =
futures::try_join!(self.dataset.load_indices(), self.dataset.latest_manifest())?;
let dataset = self.clone_inner_dataset();
let (indices, mf) = futures::try_join!(dataset.load_indices(), dataset.latest_manifest())?;
Ok(indices
.iter()
.map(|i| VectorIndex::new_from_format(&mf, i))
@@ -460,10 +481,8 @@ impl Table {
if index.is_none() {
return Ok(None);
}
let index_stats = self
.dataset
.index_statistics(&index.unwrap().index_name)
.await?;
let dataset = self.clone_inner_dataset();
let index_stats = dataset.index_statistics(&index.unwrap().index_name).await?;
let index_stats: VectorIndexStatistics =
serde_json::from_str(&index_stats).map_err(|e| Error::Lance {
message: format!(
@@ -474,6 +493,71 @@ impl Table {
Ok(Some(index_stats))
}
pub(crate) fn reset_dataset(&self, dataset: Dataset) {
*self.dataset.lock().expect("lock poison") = dataset;
}
}
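
The Mutex around the Dataset is what lets the mutating methods take &self: each operation clones the dataset handle out of the lock, mutates the clone, and swaps it back with reset_dataset, so the lock is held only briefly and never across an await. The pattern in isolation, as a standalone sketch with a stand-in type:

use std::sync::{Arc, Mutex};

#[derive(Clone, Default)]
struct Dataset(Vec<u32>); // stand-in for lance::dataset::Dataset

struct Versioned {
    dataset: Arc<Mutex<Dataset>>,
}

impl Versioned {
    async fn mutate(&self) {
        // Lock only long enough to clone...
        let mut working = self.dataset.lock().expect("lock poison").clone();
        working.0.push(1); // ...do the (possibly async) work unlocked...
        // ...then lock again only to swap the new version in.
        *self.dataset.lock().expect("lock poison") = working;
    }
}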
#[async_trait::async_trait]
impl Table for NativeTable {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_native(&self) -> Option<&NativeTable> {
Some(self)
}
fn name(&self) -> &str {
self.name.as_str()
}
fn schema(&self) -> SchemaRef {
let lance_schema = { self.dataset.lock().expect("lock poison").schema().clone() };
Arc::new(Schema::from(&lance_schema))
}
async fn count_rows(&self) -> Result<usize> {
let dataset = { self.dataset.lock().expect("lock poison").clone() };
Ok(dataset.count_rows().await?)
}
async fn add(
&self,
batches: Box<dyn RecordBatchReader + Send>,
params: Option<WriteParams>,
) -> Result<()> {
let params = Some(params.unwrap_or(WriteParams {
mode: WriteMode::Append,
..WriteParams::default()
}));
// patch the params if we have a write store wrapper
let params = match self.store_wrapper.clone() {
Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
None => params,
};
self.reset_dataset(Dataset::write(batches, &self.uri, params).await?);
Ok(())
}
fn create_index(&self, columns: &[&str]) -> IndexBuilder {
IndexBuilder::new(Arc::new(self.clone()), columns)
}
fn query(&self) -> Query {
Query::new(Arc::new(self.dataset.lock().expect("lock poison").clone()))
}
/// Delete rows from the table
async fn delete(&self, predicate: &str) -> Result<()> {
let mut dataset = self.clone_inner_dataset();
dataset.delete(predicate).await?;
self.reset_dataset(dataset);
Ok(())
}
}
#[cfg(test)]
@@ -491,14 +575,11 @@ mod tests {
use arrow_schema::{DataType, Field, Schema, TimeUnit};
use futures::TryStreamExt;
use lance::dataset::{Dataset, WriteMode};
use lance::index::vector::pq::PQBuildParams;
use lance::io::{ObjectStoreParams, WrappingObjectStore};
use lance_index::vector::ivf::IvfBuildParams;
use rand::Rng;
use tempfile::tempdir;
use super::*;
use crate::index::vector::IvfPQIndexBuilder;
#[tokio::test]
async fn test_open() {
@@ -510,7 +591,9 @@ mod tests {
.await
.unwrap();
let table = Table::open(dataset_path.to_str().unwrap()).await.unwrap();
let table = NativeTable::open(dataset_path.to_str().unwrap())
.await
.unwrap();
assert_eq!(table.name, "test")
}
@@ -519,7 +602,7 @@ mod tests {
async fn test_open_not_found() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let table = Table::open(uri).await;
let table = NativeTable::open(uri).await;
assert!(matches!(table.unwrap_err(), Error::TableNotFound { .. }));
}
@@ -539,12 +622,12 @@ mod tests {
let batches = make_test_batches();
let _ = batches.schema().clone();
Table::create(&uri, "test", batches, None, None)
NativeTable::create(&uri, "test", batches, None, None)
.await
.unwrap();
let batches = make_test_batches();
let result = Table::create(&uri, "test", batches, None, None).await;
let result = NativeTable::create(&uri, "test", batches, None, None).await;
assert!(matches!(
result.unwrap_err(),
Error::TableAlreadyExists { .. }
@@ -558,7 +641,7 @@ mod tests {
let batches = make_test_batches();
let schema = batches.schema().clone();
let mut table = Table::create(&uri, "test", batches, None, None)
let table = NativeTable::create(&uri, "test", batches, None, None)
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
@@ -574,7 +657,7 @@ mod tests {
schema.clone(),
);
table.add(new_batches, None).await.unwrap();
table.add(Box::new(new_batches), None).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 20);
assert_eq!(table.name, "test");
}
@@ -586,7 +669,7 @@ mod tests {
let batches = make_test_batches();
let schema = batches.schema().clone();
let mut table = Table::create(uri, "test", batches, None, None)
let table = NativeTable::create(uri, "test", batches, None, None)
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
@@ -607,7 +690,7 @@ mod tests {
..Default::default()
};
table.add(new_batches, Some(param)).await.unwrap();
table.add(Box::new(new_batches), Some(param)).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
assert_eq!(table.name, "test");
}
@@ -640,7 +723,7 @@ mod tests {
);
Dataset::write(record_batch_iter, uri, None).await.unwrap();
let mut table = Table::open(uri).await.unwrap();
let table = NativeTable::open(uri).await.unwrap();
table
.update(Some("id > 5"), vec![("name", "'foo'")])
@@ -772,7 +855,7 @@ mod tests {
);
Dataset::write(record_batch_iter, uri, None).await.unwrap();
let mut table = Table::open(uri).await.unwrap();
let table = NativeTable::open(uri).await.unwrap();
// check it can do update for each type
let updates: Vec<(&str, &str)> = vec![
@@ -889,11 +972,10 @@ mod tests {
.await
.unwrap();
let table = Table::open(uri).await.unwrap();
let table = NativeTable::open(uri).await.unwrap();
let vector = Float32Array::from_iter_values([0.1, 0.2]);
let query = table.search(Some(vector.clone()));
assert_eq!(vector, query.query_vector.unwrap());
let query = table.search(&[0.1, 0.2]);
assert_eq!(&[0.1, 0.2], query.query_vector.unwrap().values());
}
#[derive(Default, Debug)]
@@ -937,7 +1019,7 @@ mod tests {
..Default::default()
};
assert!(!wrapper.called());
let _ = Table::open_with_params(uri, "test", None, param)
let _ = NativeTable::open_with_params(uri, "test", None, param)
.await
.unwrap();
assert!(wrapper.called());
@@ -991,23 +1073,23 @@ mod tests {
schema,
);
let mut table = Table::create(uri, "test", batches, None, None)
let table = NativeTable::create(uri, "test", batches, None, None)
.await
.unwrap();
let mut i = IvfPQIndexBuilder::new();
assert_eq!(table.count_indexed_rows("my_index").await.unwrap(), None);
assert_eq!(table.count_unindexed_rows("my_index").await.unwrap(), None);
let index_builder = i
.column("embeddings".to_string())
.index_name("my_index".to_string())
.ivf_params(IvfBuildParams::new(256))
.pq_params(PQBuildParams::default());
table
.create_index(&["embeddings"])
.ivf_pq()
.name("my_index")
.num_partitions(256)
.build()
.await
.unwrap();
table.create_index(index_builder).await.unwrap();
assert_eq!(table.dataset.load_indices().await.unwrap().len(), 1);
assert_eq!(table.load_indices().await.unwrap().len(), 1);
assert_eq!(table.count_rows().await.unwrap(), 512);
assert_eq!(table.name, "test");