// Reviewer notes for the patch below. These lines are leading text placed before the first
// `diff --git` header; the patch text itself is reproduced unmodified after the notes.
//
// NOTE(review): this copy of the patch has its newlines collapsed and appears to have lost
// generic parameter lists (e.g. `pub type Result = std::result::Result;`, bare `Vec` / `Option`,
// `impl From for Error`, `struct VectorQuery` using an undeclared `T`). Presumably an extraction
// artifact; confirm against the original commit before applying.
//
// rust/vectordb/src/error.rs
//   * Adds the `InvalidQuery { message }` variant, displayed as "Bad query: {message}".
//   * Adds constructor helpers on `Error`: `invalid_table_name`, `table_not_found`,
//     `table_already_exists`, `invalid_query` and `create_dir`. Each one copies its `&str`
//     argument into the owned `String` field of the matching variant; `create_dir` also
//     stores the supplied `std::io::Error` as `source`.
//
// rust/vectordb/src/query.rs
//   * Adds a private `VectorQuery` struct with the fields `query`, `column`, `nprobs`,
//     `refine_factor`, `metric_type` and `use_index`.
//   * Adds `find_vector_columns(schema, dim)`: collects the names of the schema fields whose
//     data type is `DataType::FixedSizeList` with `list_size == dim` and whose item type
//     reports `is_floating()`.
//   * Adds `VectorQuery::try_new`: takes `dim` from `query.len()`, returns `Error::InvalidQuery`
//     when no column matches or when the number of matches is not exactly one, and otherwise
//     delegates to `with_column` with the single matching column name.
//   * Adds `VectorQuery::with_column`, which fills in `nprobs = 20`, `refine_factor = None`,
//     `metric_type = None` and `use_index = true`.
//   * Adds a `vector_query` field to `Query`, initialised to `None` in `Query::new`.
//   * Adds `Query::vector_search` and `Query::vector_search_on`.
//   * In the test module, `use super::*;` is added and `Float32Array` is dropped from the
//     explicit `arrow_array` import list.
//
// NOTE(review): items to confirm before merging
//   * `Query::vector_search` and `Query::vector_search_on` both bind `dim = query.len()` but never
//     read it; the only effect visible in these hunks is overwriting `self.query_vector`.
//   * `vector_search_on` never uses its `column` parameter, and no hunk shown here assigns
//     `Some(..)` to `self.vector_query`, so within the code shown the
//     "Vector search is already set" guard cannot trigger. Presumably follow-up work; verify.
//   * `query.len() as i32` in `VectorQuery::try_new` is a narrowing cast from `usize`; a checked
//     conversion (`i32::try_from`) would avoid silent wrap-around for very long inputs.
//   * `VectorQuery` spells its field `nprobs`, while `Query::new` sets `nprobes`; confirm
//     which spelling is intended.
//   * The multi-column error text ("can be applied to more than one vector columns") may read
//     as permission rather than as an ambiguity error; consider rewording.
diff --git a/rust/vectordb/src/error.rs b/rust/vectordb/src/error.rs index 4a8a9820..c7043eb9 100644 --- a/rust/vectordb/src/error.rs +++ b/rust/vectordb/src/error.rs @@ -32,8 +32,42 @@ pub enum Error { Store { message: String }, #[snafu(display("LanceDBError: {message}"))] Lance { message: String }, + #[snafu(display("Bad query: {message}"))] + InvalidQuery { message: String }, } +impl Error { + pub fn invalid_table_name(name: &str) -> Self { + Self::InvalidTableName { + name: name.to_string(), + } + } + + pub fn table_not_found(name: &str) -> Self { + Self::TableNotFound { + name: name.to_string(), + } + } + + pub fn table_already_exists(name: &str) -> Self { + Self::TableAlreadyExists { + name: name.to_string(), + } + } + + pub fn invalid_query(message: &str) -> Self { + Self::InvalidQuery { + message: message.to_string(), + } + } + + pub fn create_dir(path: &str, source: std::io::Error) -> Self { + Self::CreateDir { + path: path.to_string(), + source, + } + } +} pub type Result = std::result::Result; impl From for Error { diff --git a/rust/vectordb/src/query.rs b/rust/vectordb/src/query.rs index fb7bbb38..ddf27c47 100644 --- a/rust/vectordb/src/query.rs +++ b/rust/vectordb/src/query.rs @@ -14,12 +14,73 @@ use std::sync::Arc; -use arrow_array::Float32Array; -use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner}; -use lance::dataset::Dataset; +use arrow_array::{Array, Float32Array}; +use arrow_schema::DataType; +use lance::dataset::{ + scanner::{DatasetRecordBatchStream, Scanner}, + Dataset, +}; +use lance::datatypes::Schema; use lance::index::vector::MetricType; -use crate::error::Result; +use crate::error::{Error, Result}; + +struct VectorQuery { + query: T, + column: String, + nprobs: usize, + refine_factor: Option, + metric_type: Option, + use_index: bool, +} + +/// Best effort to find potential vector columns in a [Schema], which is a fixed size column with +/// float number, where the list size is equal to the vector dimension. 
+/// +fn find_vector_columns(schema: &Schema, dim: i32) -> Vec { + schema + .fields + .iter() + .filter(|f| match &f.data_type() { + DataType::FixedSizeList(field, list_size) => { + *list_size == dim && field.data_type().is_floating() + } + _ => false, + }) + .map(|f| f.name.to_string()) + .collect() +} + +impl VectorQuery { + fn try_new(dataset: &Dataset, query: T) -> Result { + let schema = dataset.schema(); + let dim: i32 = query.len() as i32; + let vector_columns = find_vector_columns(schema, dim); + + if vector_columns.is_empty() { + return Err(Error::InvalidQuery { + message: format!("Unable to find a vector column with dimension {}", dim), + }); + }; + if vector_columns.len() != 1 { + return Err(Error::invalid_query( + "Vector query can be applied to more than one vector columns, please specify the column to use")); + } + + Ok(Self::with_column(query, &vector_columns[0])) + } + + fn with_column(query: T, column: &str) -> Self { + VectorQuery { + query, + column: column.to_string(), + nprobs: 20, + refine_factor: None, + metric_type: None, + use_index: true, + } + } +} /// A builder for nearest neighbor queries for LanceDB. pub struct Query { @@ -32,6 +93,7 @@ pub struct Query { pub refine_factor: Option, pub metric_type: Option, pub use_index: bool, + vector_query: Option>, } impl Query { @@ -48,6 +110,7 @@ impl Query { pub(crate) fn new(dataset: Arc, vector: Float32Array) -> Self { Query { dataset, + vector_query: None, query_vector: vector, limit: 10, nprobes: 20, @@ -101,6 +164,25 @@ impl Query { self } + pub fn vector_search(mut self, query: Float32Array) -> Result { + let dim = query.len(); + + self.query_vector = query; + Ok(self) + } + + /// Vector search on a given column. 
+ pub fn vector_search_on(mut self, query: Float32Array, column: &str) -> Result { + if self.vector_query.is_some() { + return Err(Error::invalid_query("Vector search is already set")); + }; + + let dim = query.len(); + + self.query_vector = query; + Ok(self) + } + /// Set the number of probes to use. /// /// # Arguments @@ -162,9 +244,11 @@ impl Query { #[cfg(test)] mod tests { + use super::*; + use std::sync::Arc; - use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader}; + use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader}; use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema}; use lance::dataset::Dataset; use lance::index::vector::MetricType;