NOTE(review): this patch was reconstructed from a whitespace-collapsed copy in which
generic type parameters had been stripped (e.g. `Arc<...>` rendered as `Arc,`).
The struct field types below (`Arc<Dataset>`, `Option<Float32Array>`, `Option<usize>`,
`Option<String>`, `Option<Vec<String>>`) are inferred from usage elsewhere in the
patch — verify against the pre-image blobs (f28a96c2, 01c85d87) before applying.

The patch itself does two things:
  1. rustfmt-style reformatting (vec! literal, import list, trailing whitespace);
  2. adds a `column: String` field to `Query` (defaulting to
     `crate::table::VECTOR_COLUMN_NAME`) plus a builder-style `column()` setter,
     and routes `scanner.nearest()` through it so callers can search a
     non-default vector column.

diff --git a/rust/vectordb/src/io/object_store.rs b/rust/vectordb/src/io/object_store.rs
index 01c85d87..21b4d60a 100644
--- a/rust/vectordb/src/io/object_store.rs
+++ b/rust/vectordb/src/io/object_store.rs
@@ -359,7 +359,9 @@ mod test {
         assert_eq!(t.count_rows().await.unwrap(), 100);
 
         let q = t
-            .search(Some(PrimitiveArray::from_iter_values(vec![0.1, 0.1, 0.1, 0.1])))
+            .search(Some(PrimitiveArray::from_iter_values(vec![
+                0.1, 0.1, 0.1, 0.1,
+            ])))
             .limit(10)
             .execute()
             .await
diff --git a/rust/vectordb/src/query.rs b/rust/vectordb/src/query.rs
index f28a96c2..652b53dc 100644
--- a/rust/vectordb/src/query.rs
+++ b/rust/vectordb/src/query.rs
@@ -25,6 +25,7 @@ use crate::error::Result;
 pub struct Query {
     pub dataset: Arc<Dataset>,
     pub query_vector: Option<Float32Array>,
+    pub column: String,
     pub limit: Option<usize>,
     pub filter: Option<String>,
     pub select: Option<Vec<String>>,
@@ -50,6 +51,7 @@ impl Query {
         Query {
             dataset,
             query_vector: vector,
+            column: crate::table::VECTOR_COLUMN_NAME.to_string(),
             limit: None,
             nprobes: 20,
             refine_factor: None,
@@ -71,7 +73,7 @@ impl Query {
         if let Some(query) = self.query_vector.as_ref() {
             // If there is a vector query, default to limit=10 if unspecified
-            scanner.nearest(crate::table::VECTOR_COLUMN_NAME, query, self.limit.unwrap_or(10))?;
+            scanner.nearest(&self.column, query, self.limit.unwrap_or(10))?;
         } else {
             // If there is no vector query, it's ok to not have a limit
             scanner.limit(self.limit.map(|limit| limit as i64), None)?;
         }
@@ -87,6 +89,16 @@ impl Query {
         Ok(scanner.try_into_stream().await?)
     }
 
+    /// Set the column to query
+    ///
+    /// # Arguments
+    ///
+    /// * `column` - The column name
+    pub fn column(mut self, column: &str) -> Query {
+        self.column = column.into();
+        self
+    }
+
     /// Set the maximum number of results to return.
     ///
     /// # Arguments
@@ -176,7 +188,10 @@ mod tests {
     use std::sync::Arc;
 
     use super::*;
-    use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader, cast::AsArray, Int32Array};
+    use arrow_array::{
+        cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
+        RecordBatchReader,
+    };
     use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
     use futures::StreamExt;
     use lance::dataset::Dataset;
@@ -260,7 +275,7 @@ mod tests {
         let mut stream = result.expect("should have result");
         // should only have one batch
         while let Some(batch) = stream.next().await {
-            let b = batch.expect("should be Ok"); 
+            let b = batch.expect("should be Ok");
             // cast arr into Int32Array
             let arr: &Int32Array = b["id"].as_primitive();
             assert!(arr.iter().all(|x| x.unwrap() % 2 == 0));