mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 19:02:58 +00:00
feat: improve the rust table query API and documents (#860)
* Make the API easier to type * Handle `String`, `&str`, `[String]` and `[&str]` without manual conversion * Rename functions to be verbs * Improve the Rust docstrings * Promote `query()` and `search()` to the public `Table` trait
This commit is contained in:
@@ -33,7 +33,7 @@ impl Query {
|
||||
|
||||
#[napi]
|
||||
pub fn vector(&mut self, vector: Float32Array) {
|
||||
let inn = self.inner.clone().query_vector(&vector);
|
||||
let inn = self.inner.clone().nearest_to(&vector);
|
||||
self.inner = inn;
|
||||
}
|
||||
|
||||
|
||||
@@ -40,17 +40,6 @@ impl JsQuery {
|
||||
}
|
||||
projection_vec
|
||||
});
|
||||
let filter = query_obj
|
||||
.get_opt::<JsString, _, _>(&mut cx, "_filter")?
|
||||
.map(|s| s.value(&mut cx));
|
||||
let refine_factor = query_obj
|
||||
.get_opt_u32(&mut cx, "_refineFactor")
|
||||
.or_throw(&mut cx)?;
|
||||
let nprobes = query_obj.get_usize(&mut cx, "_nprobes").or_throw(&mut cx)?;
|
||||
let metric_type = query_obj
|
||||
.get_opt::<JsString, _, _>(&mut cx, "_metricType")?
|
||||
.map(|s| s.value(&mut cx))
|
||||
.map(|s| MetricType::try_from(s.as_str()).unwrap());
|
||||
|
||||
let prefilter = query_obj
|
||||
.get::<JsBoolean, _, _>(&mut cx, "_prefilter")?
|
||||
@@ -65,25 +54,41 @@ impl JsQuery {
|
||||
|
||||
let (deferred, promise) = cx.promise();
|
||||
let channel = cx.channel();
|
||||
let query_vector = query_obj.get_opt::<JsArray, _, _>(&mut cx, "_queryVector")?;
|
||||
let table = js_table.table.clone();
|
||||
let query = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx));
|
||||
|
||||
let query_vector = query_obj.get_opt::<JsArray, _, _>(&mut cx, "_queryVector")?;
|
||||
let mut builder = table.query();
|
||||
if let Some(query) = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx)) {
|
||||
builder = builder.nearest_to(&query);
|
||||
if let Some(metric_type) = query_obj
|
||||
.get_opt::<JsString, _, _>(&mut cx, "_metricType")?
|
||||
.map(|s| s.value(&mut cx))
|
||||
.map(|s| MetricType::try_from(s.as_str()).unwrap())
|
||||
{
|
||||
builder = builder.metric_type(metric_type);
|
||||
}
|
||||
|
||||
let nprobes = query_obj.get_usize(&mut cx, "_nprobes").or_throw(&mut cx)?;
|
||||
builder = builder.nprobes(nprobes);
|
||||
};
|
||||
|
||||
if let Some(filter) = query_obj
|
||||
.get_opt::<JsString, _, _>(&mut cx, "_filter")?
|
||||
.map(|s| s.value(&mut cx))
|
||||
{
|
||||
builder = builder.filter(filter);
|
||||
}
|
||||
if let Some(select) = select {
|
||||
builder = builder.select(select.as_slice());
|
||||
}
|
||||
if let Some(limit) = limit {
|
||||
builder = builder.limit(limit as usize);
|
||||
};
|
||||
|
||||
builder = builder.prefilter(prefilter);
|
||||
|
||||
rt.spawn(async move {
|
||||
let mut builder = table.query();
|
||||
if let Some(query) = query {
|
||||
builder = builder
|
||||
.query_vector(&query)
|
||||
.refine_factor(refine_factor)
|
||||
.nprobes(nprobes)
|
||||
.metric_type(metric_type);
|
||||
};
|
||||
builder = builder.filter(filter).select(select).prefilter(prefilter);
|
||||
if let Some(limit) = limit {
|
||||
builder = builder.limit(limit as usize);
|
||||
};
|
||||
|
||||
let record_batch_stream = builder.execute();
|
||||
let record_batch_stream = builder.execute_stream();
|
||||
let results = record_batch_stream
|
||||
.and_then(|stream| {
|
||||
stream
|
||||
|
||||
@@ -14,13 +14,12 @@
|
||||
|
||||
use std::{cmp::max, sync::Arc};
|
||||
|
||||
use arrow_schema::Schema;
|
||||
use lance_index::{DatasetIndexExt, IndexType};
|
||||
pub use lance_linalg::distance::MetricType;
|
||||
|
||||
pub mod vector;
|
||||
|
||||
use crate::{Error, Result, Table};
|
||||
use crate::{utils::default_vector_column, Error, Result, Table};
|
||||
|
||||
/// Index Parameters.
|
||||
pub enum IndexParams {
|
||||
@@ -110,6 +109,7 @@ impl IndexBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// The columns to build index on.
|
||||
pub fn columns(&mut self, cols: &[&str]) -> &mut Self {
|
||||
self.columns = cols.iter().map(|s| s.to_string()).collect();
|
||||
self
|
||||
@@ -174,7 +174,7 @@ impl IndexBuilder {
|
||||
let columns = if self.columns.is_empty() {
|
||||
// By default we create vector index.
|
||||
index_type = &IndexType::Vector;
|
||||
vec![default_column_for_index(&schema)?]
|
||||
vec![default_vector_column(&schema, None)?]
|
||||
} else {
|
||||
self.columns.clone()
|
||||
};
|
||||
@@ -290,83 +290,3 @@ fn suggested_num_sub_vectors(dim: u32) -> u32 {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
/// Find one default column to create index.
|
||||
fn default_column_for_index(schema: &Schema) -> Result<String> {
|
||||
// Try to find one fixed size list array column.
|
||||
let candidates = schema
|
||||
.fields()
|
||||
.iter()
|
||||
.filter_map(|field| match field.data_type() {
|
||||
arrow_schema::DataType::FixedSizeList(f, _) if f.data_type().is_floating() => {
|
||||
Some(field.name())
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
if candidates.is_empty() {
|
||||
Err(Error::Store {
|
||||
message: "No vector column found to create index".to_string(),
|
||||
})
|
||||
} else if candidates.len() != 1 {
|
||||
Err(Error::Store {
|
||||
message: format!(
|
||||
"More than one vector columns found, \
|
||||
please specify which column to create index: {:?}",
|
||||
candidates
|
||||
),
|
||||
})
|
||||
} else {
|
||||
Ok(candidates[0].to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use arrow_schema::{DataType, Field};
|
||||
|
||||
#[test]
|
||||
fn test_guess_default_column() {
|
||||
let schema_no_vector = Schema::new(vec![
|
||||
Field::new("id", DataType::Int16, true),
|
||||
Field::new("tag", DataType::Utf8, false),
|
||||
]);
|
||||
assert!(default_column_for_index(&schema_no_vector)
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("No vector column"));
|
||||
|
||||
let schema_with_vec_col = Schema::new(vec![
|
||||
Field::new("id", DataType::Int16, true),
|
||||
Field::new(
|
||||
"vec",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, false)), 10),
|
||||
false,
|
||||
),
|
||||
]);
|
||||
assert_eq!(
|
||||
default_column_for_index(&schema_with_vec_col).unwrap(),
|
||||
"vec"
|
||||
);
|
||||
|
||||
let multi_vec_col = Schema::new(vec![
|
||||
Field::new("id", DataType::Int16, true),
|
||||
Field::new(
|
||||
"vec",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, false)), 10),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"vec2",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, false)), 50),
|
||||
false,
|
||||
),
|
||||
]);
|
||||
assert!(default_column_for_index(&multi_vec_col)
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("More than one"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -377,7 +377,7 @@ mod test {
|
||||
let q = t
|
||||
.search(&[0.1, 0.1, 0.1, 0.1])
|
||||
.limit(10)
|
||||
.execute()
|
||||
.execute_stream()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -46,8 +46,8 @@
|
||||
//! #### Connect to a database.
|
||||
//!
|
||||
//! ```rust
|
||||
//! use vectordb::{connection::{Database, Connection}, WriteMode};
|
||||
//! use arrow_schema::{Field, Schema};
|
||||
//! use vectordb::connection::Database;
|
||||
//! # use arrow_schema::{Field, Schema};
|
||||
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
//! let db = Database::connect("data/sample-lancedb").await.unwrap();
|
||||
//! # });
|
||||
@@ -55,7 +55,7 @@
|
||||
//!
|
||||
//! LanceDB uses [arrow-rs](https://github.com/apache/arrow-rs) to define schema, data types and array itself.
|
||||
//! It treats [`FixedSizeList<Float16/Float32>`](https://docs.rs/arrow/latest/arrow/array/struct.FixedSizeListArray.html)
|
||||
//! columns as vectors.
|
||||
//! columns as vector columns.
|
||||
//!
|
||||
//! #### Create a table
|
||||
//!
|
||||
@@ -90,6 +90,27 @@
|
||||
//! # });
|
||||
//! ```
|
||||
//!
|
||||
//! #### Create vector index (IVF_PQ)
|
||||
//!
|
||||
//! ```no_run
|
||||
//! # use std::sync::Arc;
|
||||
//! # use vectordb::connection::{Database, Connection};
|
||||
//! # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
|
||||
//! # RecordBatchIterator, Int32Array};
|
||||
//! # use arrow_schema::{Schema, Field, DataType};
|
||||
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
//! # let tmpdir = tempfile::tempdir().unwrap();
|
||||
//! # let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
|
||||
//! # let tbl = db.open_table("idx_test").await.unwrap();
|
||||
//! tbl.create_index(&["vector"])
|
||||
//! .ivf_pq()
|
||||
//! .num_partitions(256)
|
||||
//! .build()
|
||||
//! .await
|
||||
//! .unwrap();
|
||||
//! # });
|
||||
//! ```
|
||||
//!
|
||||
//! #### Open table and run search
|
||||
//!
|
||||
//! ```rust
|
||||
@@ -120,7 +141,7 @@
|
||||
//! let table = db.open_table("my_table").await.unwrap();
|
||||
//! let results = table
|
||||
//! .search(&[1.0; 128])
|
||||
//! .execute()
|
||||
//! .execute_stream()
|
||||
//! .await
|
||||
//! .unwrap()
|
||||
//! .try_collect::<Vec<_>>()
|
||||
|
||||
@@ -15,26 +15,42 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::Float32Array;
|
||||
use arrow_schema::Schema;
|
||||
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
|
||||
use lance::dataset::Dataset;
|
||||
use lance_linalg::distance::MetricType;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::utils::default_vector_column;
|
||||
|
||||
const DEFAULT_TOP_K: usize = 10;
|
||||
|
||||
/// A builder for nearest neighbor queries for LanceDB.
|
||||
#[derive(Clone)]
|
||||
pub struct Query {
|
||||
pub dataset: Arc<Dataset>,
|
||||
pub query_vector: Option<Float32Array>,
|
||||
pub column: String,
|
||||
pub limit: Option<usize>,
|
||||
pub filter: Option<String>,
|
||||
pub select: Option<Vec<String>>,
|
||||
pub nprobes: usize,
|
||||
pub refine_factor: Option<u32>,
|
||||
pub metric_type: Option<MetricType>,
|
||||
pub use_index: bool,
|
||||
pub prefilter: bool,
|
||||
dataset: Arc<Dataset>,
|
||||
|
||||
// The column to run the query on. If not specified, we will attempt to guess
|
||||
// the column based on the dataset's schema.
|
||||
column: Option<String>,
|
||||
|
||||
// IVF PQ - ANN search.
|
||||
query_vector: Option<Float32Array>,
|
||||
nprobes: usize,
|
||||
refine_factor: Option<u32>,
|
||||
metric_type: Option<MetricType>,
|
||||
|
||||
/// Limit the number of rows to return.
|
||||
limit: Option<usize>,
|
||||
/// Apply filter to the returned rows.
|
||||
filter: Option<String>,
|
||||
/// Select column projection.
|
||||
select: Option<Vec<String>>,
|
||||
|
||||
/// Default is true. Set to false to enforce a brute force search.
|
||||
use_index: bool,
|
||||
/// Apply filter before ANN search.
|
||||
prefilter: bool,
|
||||
}
|
||||
|
||||
impl Query {
|
||||
@@ -42,17 +58,13 @@ impl Query {
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dataset` - The table / dataset the query will be run against.
|
||||
/// * `vector` The vector used for this query.
|
||||
/// * `dataset` - Lance dataset.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * A [Query] object.
|
||||
pub(crate) fn new(dataset: Arc<Dataset>) -> Self {
|
||||
Query {
|
||||
dataset,
|
||||
query_vector: None,
|
||||
column: crate::table::VECTOR_COLUMN_NAME.to_string(),
|
||||
column: None,
|
||||
limit: None,
|
||||
nprobes: 20,
|
||||
refine_factor: None,
|
||||
@@ -64,17 +76,24 @@ impl Query {
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute the queries and return its results.
|
||||
/// Convert the query plan to a [`DatasetRecordBatchStream`]
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * A [DatasetRecordBatchStream] with the query's results.
|
||||
pub async fn execute(&self) -> Result<DatasetRecordBatchStream> {
|
||||
pub async fn execute_stream(&self) -> Result<DatasetRecordBatchStream> {
|
||||
let mut scanner: Scanner = self.dataset.scan();
|
||||
|
||||
if let Some(query) = self.query_vector.as_ref() {
|
||||
// If there is a vector query, default to limit=10 if unspecified
|
||||
scanner.nearest(&self.column, query, self.limit.unwrap_or(10))?;
|
||||
let column = if let Some(col) = self.column.as_ref() {
|
||||
col.clone()
|
||||
} else {
|
||||
// Infer a vector column with the same dimension as the query vector.
|
||||
let arrow_schema = Schema::from(self.dataset.schema());
|
||||
default_vector_column(&arrow_schema, Some(query.len() as i32))?
|
||||
};
|
||||
scanner.nearest(&column, query, self.limit.unwrap_or(DEFAULT_TOP_K))?;
|
||||
} else {
|
||||
// If there is no vector query, it's ok to not have a limit
|
||||
scanner.limit(self.limit.map(|limit| limit as i64), None)?;
|
||||
@@ -95,8 +114,8 @@ impl Query {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `column` - The column name
|
||||
pub fn column(mut self, column: &str) -> Query {
|
||||
self.column = column.into();
|
||||
pub fn column(mut self, column: &str) -> Self {
|
||||
self.column = Some(column.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
@@ -105,17 +124,17 @@ impl Query {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `limit` - The maximum number of results to return.
|
||||
pub fn limit(mut self, limit: usize) -> Query {
|
||||
pub fn limit(mut self, limit: usize) -> Self {
|
||||
self.limit = Some(limit);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the vector used for this query.
|
||||
/// Find the nearest vectors to the given query vector.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vector` - The vector that will be used for search.
|
||||
pub fn query_vector(mut self, vector: &[f32]) -> Query {
|
||||
pub fn nearest_to(mut self, vector: &[f32]) -> Self {
|
||||
self.query_vector = Some(Float32Array::from(vector.to_vec()));
|
||||
self
|
||||
}
|
||||
@@ -125,7 +144,7 @@ impl Query {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `nprobes` - The number of probes to use.
|
||||
pub fn nprobes(mut self, nprobes: usize) -> Query {
|
||||
pub fn nprobes(mut self, nprobes: usize) -> Self {
|
||||
self.nprobes = nprobes;
|
||||
self
|
||||
}
|
||||
@@ -135,8 +154,8 @@ impl Query {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `refine_factor` - The refine factor to use.
|
||||
pub fn refine_factor(mut self, refine_factor: Option<u32>) -> Query {
|
||||
self.refine_factor = refine_factor;
|
||||
pub fn refine_factor(mut self, refine_factor: u32) -> Self {
|
||||
self.refine_factor = Some(refine_factor);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -145,8 +164,8 @@ impl Query {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `metric_type` - The distance metric to use. By default [MetricType::L2] is used.
|
||||
pub fn metric_type(mut self, metric_type: Option<MetricType>) -> Query {
|
||||
self.metric_type = metric_type;
|
||||
pub fn metric_type(mut self, metric_type: MetricType) -> Self {
|
||||
self.metric_type = Some(metric_type);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -155,7 +174,7 @@ impl Query {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `use_index` - Whether to use an ANN index if available.
|
||||
pub fn use_index(mut self, use_index: bool) -> Query {
|
||||
pub fn use_index(mut self, use_index: bool) -> Self {
|
||||
self.use_index = use_index;
|
||||
self
|
||||
}
|
||||
@@ -164,21 +183,21 @@ impl Query {
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `filter` - value A filter in the same format used by a sql WHERE clause.
|
||||
pub fn filter(mut self, filter: Option<String>) -> Query {
|
||||
self.filter = filter;
|
||||
/// * `filter` - SQL filter
|
||||
pub fn filter(mut self, filter: impl AsRef<str>) -> Self {
|
||||
self.filter = Some(filter.as_ref().to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Return only the specified columns.
|
||||
///
|
||||
/// Only select the specified columns. If not specified, all columns will be returned.
|
||||
pub fn select(mut self, columns: Option<Vec<String>>) -> Query {
|
||||
self.select = columns;
|
||||
pub fn select(mut self, columns: &[impl AsRef<str>]) -> Self {
|
||||
self.select = Some(columns.iter().map(|c| c.as_ref().to_string()).collect());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn prefilter(mut self, prefilter: bool) -> Query {
|
||||
pub fn prefilter(mut self, prefilter: bool) -> Self {
|
||||
self.prefilter = prefilter;
|
||||
self
|
||||
}
|
||||
@@ -197,8 +216,10 @@ mod tests {
|
||||
use futures::StreamExt;
|
||||
use lance::dataset::Dataset;
|
||||
use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
|
||||
use tempfile::tempdir;
|
||||
|
||||
use crate::query::Query;
|
||||
use crate::table::{NativeTable, Table};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_setters_getters() {
|
||||
@@ -206,18 +227,18 @@ mod tests {
|
||||
let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();
|
||||
|
||||
let vector = Some(Float32Array::from_iter_values([0.1, 0.2]));
|
||||
let query = Query::new(Arc::new(ds)).query_vector(&[0.1, 0.2]);
|
||||
let query = Query::new(Arc::new(ds)).nearest_to(&[0.1, 0.2]);
|
||||
assert_eq!(query.query_vector, vector);
|
||||
|
||||
let new_vector = Float32Array::from_iter_values([9.8, 8.7]);
|
||||
|
||||
let query = query
|
||||
.query_vector(&[9.8, 8.7])
|
||||
.nearest_to(&[9.8, 8.7])
|
||||
.limit(100)
|
||||
.nprobes(1000)
|
||||
.use_index(true)
|
||||
.metric_type(Some(MetricType::Cosine))
|
||||
.refine_factor(Some(999));
|
||||
.metric_type(MetricType::Cosine)
|
||||
.refine_factor(999);
|
||||
|
||||
assert_eq!(query.query_vector.unwrap(), new_vector);
|
||||
assert_eq!(query.limit.unwrap(), 100);
|
||||
@@ -232,12 +253,8 @@ mod tests {
|
||||
let batches = make_non_empty_batches();
|
||||
let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap());
|
||||
|
||||
let query = Query::new(ds.clone()).query_vector(&[0.1; 4]);
|
||||
let result = query
|
||||
.limit(10)
|
||||
.filter(Some("id % 2 == 0".to_string()))
|
||||
.execute()
|
||||
.await;
|
||||
let query = Query::new(ds.clone()).nearest_to(&[0.1; 4]);
|
||||
let result = query.limit(10).filter("id % 2 == 0").execute_stream().await;
|
||||
let mut stream = result.expect("should have result");
|
||||
// should only have one batch
|
||||
while let Some(batch) = stream.next().await {
|
||||
@@ -245,12 +262,12 @@ mod tests {
|
||||
assert!(batch.expect("should be Ok").num_rows() < 10);
|
||||
}
|
||||
|
||||
let query = Query::new(ds).query_vector(&[0.1; 4]);
|
||||
let query = Query::new(ds).nearest_to(&[0.1; 4]);
|
||||
let result = query
|
||||
.limit(10)
|
||||
.filter(Some("id % 2 == 0".to_string()))
|
||||
.filter(String::from("id % 2 == 0")) // Work with String too
|
||||
.prefilter(true)
|
||||
.execute()
|
||||
.execute_stream()
|
||||
.await;
|
||||
let mut stream = result.expect("should have result");
|
||||
// should only have one batch
|
||||
@@ -267,10 +284,7 @@ mod tests {
|
||||
let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap());
|
||||
|
||||
let query = Query::new(ds.clone());
|
||||
let result = query
|
||||
.filter(Some("id % 2 == 0".to_string()))
|
||||
.execute()
|
||||
.await;
|
||||
let result = query.filter("id % 2 == 0").execute_stream().await;
|
||||
let mut stream = result.expect("should have result");
|
||||
// should only have one batch
|
||||
while let Some(batch) = stream.next().await {
|
||||
@@ -308,4 +322,21 @@ mod tests {
|
||||
schema,
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_search() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let dataset_path = tmp_dir.path().join("test.lance");
|
||||
let uri = dataset_path.to_str().unwrap();
|
||||
|
||||
let batches = make_test_batches();
|
||||
Dataset::write(batches, dataset_path.to_str().unwrap(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let table = NativeTable::open(uri).await.unwrap();
|
||||
|
||||
let query = table.search(&[0.1, 0.2]);
|
||||
assert_eq!(&[0.1, 0.2], query.query_vector.unwrap().values());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,7 +120,7 @@ pub trait Table: std::fmt::Display + Send + Sync {
|
||||
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
/// let tmpdir = tempfile::tempdir().unwrap();
|
||||
/// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
|
||||
/// # let tbl = db.open_table("delete_test").await.unwrap();
|
||||
/// # let tbl = db.open_table("idx_test").await.unwrap();
|
||||
/// tbl.create_index(&["vector"])
|
||||
/// .ivf_pq()
|
||||
/// .num_partitions(256)
|
||||
@@ -132,11 +132,67 @@ pub trait Table: std::fmt::Display + Send + Sync {
|
||||
fn create_index(&self, column: &[&str]) -> IndexBuilder;
|
||||
|
||||
/// Search the table with a given query vector.
|
||||
///
|
||||
/// This is a convenience method for preparing an ANN query.
|
||||
fn search(&self, query: &[f32]) -> Query {
|
||||
self.query().query_vector(query)
|
||||
self.query().nearest_to(query)
|
||||
}
|
||||
|
||||
/// Create a Query builder.
|
||||
/// Create a generic [`Query`] Builder.
|
||||
///
|
||||
/// When appropriate, various indices and statistics based pruning will be used to
|
||||
/// accelerate the query.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ## Run a vector search (ANN) query.
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use arrow_array::RecordBatch;
|
||||
/// # use futures::TryStreamExt;
|
||||
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
/// # let tbl = vectordb::table::NativeTable::open("/tmp/tbl").await.unwrap();
|
||||
/// let stream = tbl.query().nearest_to(&[1.0, 2.0, 3.0])
|
||||
/// .refine_factor(5)
|
||||
/// .nprobes(10)
|
||||
/// .execute_stream()
|
||||
/// .await
|
||||
/// .unwrap();
|
||||
/// let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
|
||||
/// # });
|
||||
/// ```
|
||||
///
|
||||
/// ## Run a SQL-style filter
|
||||
/// ```no_run
|
||||
/// # use arrow_array::RecordBatch;
|
||||
/// # use futures::TryStreamExt;
|
||||
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
/// # let tbl = vectordb::table::NativeTable::open("/tmp/tbl").await.unwrap();
|
||||
/// let stream = tbl
|
||||
/// .query()
|
||||
/// .filter("id > 5")
|
||||
/// .limit(1000)
|
||||
/// .execute_stream()
|
||||
/// .await
|
||||
/// .unwrap();
|
||||
/// let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
|
||||
/// # });
|
||||
/// ```
|
||||
///
|
||||
/// ## Run a full scan query.
|
||||
/// ```no_run
|
||||
/// # use arrow_array::RecordBatch;
|
||||
/// # use futures::TryStreamExt;
|
||||
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
/// # let tbl = vectordb::table::NativeTable::open("/tmp/tbl").await.unwrap();
|
||||
/// let stream = tbl
|
||||
/// .query()
|
||||
/// .execute_stream()
|
||||
/// .await
|
||||
/// .unwrap();
|
||||
/// let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
|
||||
/// # });
|
||||
/// ```
|
||||
fn query(&self) -> Query;
|
||||
}
|
||||
|
||||
@@ -362,7 +418,7 @@ impl NativeTable {
|
||||
}
|
||||
|
||||
pub fn filter(&self, expr: String) -> Query {
|
||||
Query::new(self.clone_inner_dataset().into()).filter(Some(expr))
|
||||
Query::new(self.clone_inner_dataset().into()).filter(expr)
|
||||
}
|
||||
|
||||
/// Returns the number of rows in this Table
|
||||
@@ -961,23 +1017,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_search() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let dataset_path = tmp_dir.path().join("test.lance");
|
||||
let uri = dataset_path.to_str().unwrap();
|
||||
|
||||
let batches = make_test_batches();
|
||||
Dataset::write(batches, dataset_path.to_str().unwrap(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let table = NativeTable::open(uri).await.unwrap();
|
||||
|
||||
let query = table.search(&[0.1, 0.2]);
|
||||
assert_eq!(&[0.1, 0.2], query.query_vector.unwrap().values());
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
struct NoOpCacheWrapper {
|
||||
called: AtomicBool,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_schema::Schema;
|
||||
|
||||
use lance::dataset::{ReadParams, WriteParams};
|
||||
use lance::io::{ObjectStoreParams, WrappingObjectStore};
|
||||
|
||||
@@ -63,3 +65,86 @@ impl PatchReadParam for ReadParams {
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the one default vector column, optionally restricted to columns whose dimension equals `dim`.
|
||||
pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result<String> {
|
||||
// Try to find one fixed size list array column.
|
||||
let candidates = schema
|
||||
.fields()
|
||||
.iter()
|
||||
.filter_map(|field| match field.data_type() {
|
||||
arrow_schema::DataType::FixedSizeList(f, d)
|
||||
if f.data_type().is_floating()
|
||||
&& dim.map(|expect| *d == expect).unwrap_or(true) =>
|
||||
{
|
||||
Some(field.name())
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
if candidates.is_empty() {
|
||||
Err(Error::Store {
|
||||
message: "No vector column found to create index".to_string(),
|
||||
})
|
||||
} else if candidates.len() != 1 {
|
||||
Err(Error::Store {
|
||||
message: format!(
|
||||
"More than one vector columns found, \
|
||||
please specify which column to create index: {:?}",
|
||||
candidates
|
||||
),
|
||||
})
|
||||
} else {
|
||||
Ok(candidates[0].to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use arrow_schema::{DataType, Field};
|
||||
|
||||
#[test]
|
||||
fn test_guess_default_column() {
|
||||
let schema_no_vector = Schema::new(vec![
|
||||
Field::new("id", DataType::Int16, true),
|
||||
Field::new("tag", DataType::Utf8, false),
|
||||
]);
|
||||
assert!(default_vector_column(&schema_no_vector, None)
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("No vector column"));
|
||||
|
||||
let schema_with_vec_col = Schema::new(vec![
|
||||
Field::new("id", DataType::Int16, true),
|
||||
Field::new(
|
||||
"vec",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, false)), 10),
|
||||
false,
|
||||
),
|
||||
]);
|
||||
assert_eq!(
|
||||
default_vector_column(&schema_with_vec_col, None).unwrap(),
|
||||
"vec"
|
||||
);
|
||||
|
||||
let multi_vec_col = Schema::new(vec![
|
||||
Field::new("id", DataType::Int16, true),
|
||||
Field::new(
|
||||
"vec",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, false)), 10),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"vec2",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, false)), 50),
|
||||
false,
|
||||
),
|
||||
]);
|
||||
assert!(default_vector_column(&multi_vec_col, None)
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("More than one"));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user