flexible vector column

This commit is contained in:
Lei Xu
2023-08-29 22:09:58 -07:00
parent 8391ffee84
commit 09585390f7
2 changed files with 123 additions and 5 deletions

View File

@@ -32,8 +32,42 @@ pub enum Error {
Store { message: String },
#[snafu(display("LanceDBError: {message}"))]
Lance { message: String },
#[snafu(display("Bad query: {message}"))]
InvalidQuery { message: String },
}
impl Error {
pub fn invalid_table_name(name: &str) -> Self {
Self::InvalidTableName {
name: name.to_string(),
}
}
pub fn table_not_found(name: &str) -> Self {
Self::TableNotFound {
name: name.to_string(),
}
}
pub fn table_already_exists(name: &str) -> Self {
Self::TableAlreadyExists {
name: name.to_string(),
}
}
pub fn invalid_query(message: &str) -> Self {
Self::InvalidQuery {
message: message.to_string(),
}
}
pub fn create_dir(path: &str, source: std::io::Error) -> Self {
Self::CreateDir {
path: path.to_string(),
source,
}
}
}
/// Shorthand for [`std::result::Result`] with this module's [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
impl From<lance::Error> for Error {

View File

@@ -14,12 +14,73 @@
use std::sync::Arc;
use arrow_array::Float32Array;
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
use lance::dataset::Dataset;
use arrow_array::{Array, Float32Array};
use arrow_schema::DataType;
use lance::dataset::{
scanner::{DatasetRecordBatchStream, Scanner},
Dataset,
};
use lance::datatypes::Schema;
use lance::index::vector::MetricType;
use crate::error::Result;
use crate::error::{Error, Result};
/// Parameters of a vector search bound to a single column.
///
/// Built either by `VectorQuery::try_new`, which picks the column from the dataset
/// schema, or by `VectorQuery::with_column`, where the caller names the column.
struct VectorQuery<T: Array> {
/// The query vector; `try_new` treats its length as the vector dimension.
query: T,
/// Name of the column this query is bound to.
column: String,
/// Initialised to 20 by `with_column`.
/// NOTE(review): presumably the number of probes (spelled `nprobes` on `Query`) — confirm.
nprobs: usize,
/// Optional refine factor; `None` unless set.
/// NOTE(review): exact semantics are defined by the search backend — confirm.
refine_factor: Option<u32>,
/// Optional distance metric; `None` unless set.
metric_type: Option<MetricType>,
/// Initialised to `true` by `with_column`.
/// NOTE(review): presumably toggles use of a vector index — confirm.
use_index: bool,
}
/// Collects the names of the top-level fields in `schema` that look like vector columns
/// of dimension `dim`.
///
/// A field qualifies when its type is a `FixedSizeList` whose list size equals `dim` and
/// whose item type is a floating-point type. Only `schema.fields` itself is scanned, and
/// names are returned in schema order.
fn find_vector_columns(schema: &Schema, dim: i32) -> Vec<String> {
    let mut names = Vec::new();
    for candidate in schema.fields.iter() {
        if let DataType::FixedSizeList(item, size) = &candidate.data_type() {
            if *size == dim && item.data_type().is_floating() {
                names.push(candidate.name.to_string());
            }
        }
    }
    names
}
impl<T: Array> VectorQuery<T> {
fn try_new(dataset: &Dataset, query: T) -> Result<Self> {
let schema = dataset.schema();
let dim: i32 = query.len() as i32;
let vector_columns = find_vector_columns(schema, dim);
if vector_columns.is_empty() {
return Err(Error::InvalidQuery {
message: format!("Unable to find a vector column with dimension {}", dim),
});
};
if vector_columns.len() != 1 {
return Err(Error::invalid_query(
"Vector query can be applied to more than one vector columns, please specify the column to use"));
}
Ok(Self::with_column(query, &vector_columns[0]))
}
fn with_column(query: T, column: &str) -> Self {
VectorQuery {
query,
column: column.to_string(),
nprobs: 20,
refine_factor: None,
metric_type: None,
use_index: true,
}
}
}
/// A builder for nearest neighbor queries for LanceDB.
pub struct Query {
@@ -32,6 +93,7 @@ pub struct Query {
pub refine_factor: Option<u32>,
pub metric_type: Option<MetricType>,
pub use_index: bool,
vector_query: Option<VectorQuery<Float32Array>>,
}
impl Query {
@@ -48,6 +110,7 @@ impl Query {
pub(crate) fn new(dataset: Arc<Dataset>, vector: Float32Array) -> Self {
Query {
dataset,
vector_query: None,
query_vector: vector,
limit: 10,
nprobes: 20,
@@ -101,6 +164,25 @@ impl Query {
self
}
/// Replaces the query vector of this builder with `query`.
///
/// NOTE(review): this currently always returns `Ok`; the `Result` return type leaves
/// room for validation that is not performed yet.
pub fn vector_search(mut self, query: Float32Array) -> Result<Query> {
    self.query_vector = query;
    Ok(self)
}
/// Vector search on a given column.
///
/// Replaces the query vector of this builder with `query`.
///
/// NOTE(review): `column` is accepted but not used yet; this call does not change which
/// column is searched.
///
/// # Errors
///
/// Returns an invalid-query error if a vector query has already been set on this builder.
pub fn vector_search_on(mut self, query: Float32Array, column: &str) -> Result<Query> {
    if self.vector_query.is_some() {
        return Err(Error::invalid_query("Vector search is already set"));
    }
    self.query_vector = query;
    Ok(self)
}
/// Set the number of probes to use.
///
/// # Arguments
@@ -162,9 +244,11 @@ impl Query {
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use lance::dataset::Dataset;
use lance::index::vector::MetricType;