mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 05:19:58 +00:00
Compare commits
1 Commits
aidangomar
...
lei/better
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
09585390f7 |
@@ -32,8 +32,42 @@ pub enum Error {
|
||||
Store { message: String },
|
||||
#[snafu(display("LanceDBError: {message}"))]
|
||||
Lance { message: String },
|
||||
#[snafu(display("Bad query: {message}"))]
|
||||
InvalidQuery { message: String },
|
||||
}
|
||||
|
||||
impl Error {
|
||||
pub fn invalid_table_name(name: &str) -> Self {
|
||||
Self::InvalidTableName {
|
||||
name: name.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn table_not_found(name: &str) -> Self {
|
||||
Self::TableNotFound {
|
||||
name: name.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn table_already_exists(name: &str) -> Self {
|
||||
Self::TableAlreadyExists {
|
||||
name: name.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn invalid_query(message: &str) -> Self {
|
||||
Self::InvalidQuery {
|
||||
message: message.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_dir(path: &str, source: std::io::Error) -> Self {
|
||||
Self::CreateDir {
|
||||
path: path.to_string(),
|
||||
source,
|
||||
}
|
||||
}
|
||||
}
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
impl From<lance::Error> for Error {
|
||||
|
||||
@@ -14,12 +14,73 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::Float32Array;
|
||||
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
|
||||
use lance::dataset::Dataset;
|
||||
use arrow_array::{Array, Float32Array};
|
||||
use arrow_schema::DataType;
|
||||
use lance::dataset::{
|
||||
scanner::{DatasetRecordBatchStream, Scanner},
|
||||
Dataset,
|
||||
};
|
||||
use lance::datatypes::Schema;
|
||||
use lance::index::vector::MetricType;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
struct VectorQuery<T: Array> {
|
||||
query: T,
|
||||
column: String,
|
||||
nprobs: usize,
|
||||
refine_factor: Option<u32>,
|
||||
metric_type: Option<MetricType>,
|
||||
use_index: bool,
|
||||
}
|
||||
|
||||
/// Best effort to find potential vector columns in a [Schema], which is a fixed size column with
|
||||
/// float number, where the list size is equal to the vector dimension.
|
||||
///
|
||||
fn find_vector_columns(schema: &Schema, dim: i32) -> Vec<String> {
|
||||
schema
|
||||
.fields
|
||||
.iter()
|
||||
.filter(|f| match &f.data_type() {
|
||||
DataType::FixedSizeList(field, list_size) => {
|
||||
*list_size == dim && field.data_type().is_floating()
|
||||
}
|
||||
_ => false,
|
||||
})
|
||||
.map(|f| f.name.to_string())
|
||||
.collect()
|
||||
}
|
||||
|
||||
impl<T: Array> VectorQuery<T> {
|
||||
fn try_new(dataset: &Dataset, query: T) -> Result<Self> {
|
||||
let schema = dataset.schema();
|
||||
let dim: i32 = query.len() as i32;
|
||||
let vector_columns = find_vector_columns(schema, dim);
|
||||
|
||||
if vector_columns.is_empty() {
|
||||
return Err(Error::InvalidQuery {
|
||||
message: format!("Unable to find a vector column with dimension {}", dim),
|
||||
});
|
||||
};
|
||||
if vector_columns.len() != 1 {
|
||||
return Err(Error::invalid_query(
|
||||
"Vector query can be applied to more than one vector columns, please specify the column to use"));
|
||||
}
|
||||
|
||||
Ok(Self::with_column(query, &vector_columns[0]))
|
||||
}
|
||||
|
||||
fn with_column(query: T, column: &str) -> Self {
|
||||
VectorQuery {
|
||||
query,
|
||||
column: column.to_string(),
|
||||
nprobs: 20,
|
||||
refine_factor: None,
|
||||
metric_type: None,
|
||||
use_index: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A builder for nearest neighbor queries for LanceDB.
|
||||
pub struct Query {
|
||||
@@ -32,6 +93,7 @@ pub struct Query {
|
||||
pub refine_factor: Option<u32>,
|
||||
pub metric_type: Option<MetricType>,
|
||||
pub use_index: bool,
|
||||
vector_query: Option<VectorQuery<Float32Array>>,
|
||||
}
|
||||
|
||||
impl Query {
|
||||
@@ -48,6 +110,7 @@ impl Query {
|
||||
pub(crate) fn new(dataset: Arc<Dataset>, vector: Float32Array) -> Self {
|
||||
Query {
|
||||
dataset,
|
||||
vector_query: None,
|
||||
query_vector: vector,
|
||||
limit: 10,
|
||||
nprobes: 20,
|
||||
@@ -101,6 +164,25 @@ impl Query {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn vector_search(mut self, query: Float32Array) -> Result<Query> {
|
||||
let dim = query.len();
|
||||
|
||||
self.query_vector = query;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Vector search on a given column.
|
||||
pub fn vector_search_on(mut self, query: Float32Array, column: &str) -> Result<Query> {
|
||||
if self.vector_query.is_some() {
|
||||
return Err(Error::invalid_query("Vector search is already set"));
|
||||
};
|
||||
|
||||
let dim = query.len();
|
||||
|
||||
self.query_vector = query;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Set the number of probes to use.
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -162,9 +244,11 @@ impl Query {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
|
||||
use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader};
|
||||
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
|
||||
use lance::dataset::Dataset;
|
||||
use lance::index::vector::MetricType;
|
||||
|
||||
Reference in New Issue
Block a user