feat: allow custom column name in query (#709)

This commit is contained in:
Rob Meng
2023-12-14 23:29:26 -05:00
committed by GitHub
parent bd0034a157
commit 7c09b9b9a9
2 changed files with 21 additions and 4 deletions

View File

@@ -359,7 +359,9 @@ mod test {
assert_eq!(t.count_rows().await.unwrap(), 100);
let q = t
.search(Some(PrimitiveArray::from_iter_values(vec![0.1, 0.1, 0.1, 0.1])))
.search(Some(PrimitiveArray::from_iter_values(vec![
0.1, 0.1, 0.1, 0.1,
])))
.limit(10)
.execute()
.await

View File

@@ -25,6 +25,7 @@ use crate::error::Result;
pub struct Query {
pub dataset: Arc<Dataset>,
pub query_vector: Option<Float32Array>,
pub column: String,
pub limit: Option<usize>,
pub filter: Option<String>,
pub select: Option<Vec<String>>,
@@ -50,6 +51,7 @@ impl Query {
Query {
dataset,
query_vector: vector,
column: crate::table::VECTOR_COLUMN_NAME.to_string(),
limit: None,
nprobes: 20,
refine_factor: None,
@@ -71,7 +73,7 @@ impl Query {
if let Some(query) = self.query_vector.as_ref() {
// If there is a vector query, default to limit=10 if unspecified
scanner.nearest(crate::table::VECTOR_COLUMN_NAME, query, self.limit.unwrap_or(10))?;
scanner.nearest(&self.column, query, self.limit.unwrap_or(10))?;
} else {
// If there is no vector query, it's ok to not have a limit
scanner.limit(self.limit.map(|limit| limit as i64), None)?;
@@ -87,6 +89,16 @@ impl Query {
Ok(scanner.try_into_stream().await?)
}
/// Set the column to query
///
/// # Arguments
///
/// * `column` - The column name
pub fn column(mut self, column: &str) -> Query {
self.column = column.into();
self
}
/// Set the maximum number of results to return.
///
/// # Arguments
@@ -176,7 +188,10 @@ mod tests {
use std::sync::Arc;
use super::*;
use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader, cast::AsArray, Int32Array};
use arrow_array::{
cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
RecordBatchReader,
};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use futures::StreamExt;
use lance::dataset::Dataset;
@@ -260,7 +275,7 @@ mod tests {
let mut stream = result.expect("should have result");
// should only have one batch
while let Some(batch) = stream.next().await {
let b = batch.expect("should be Ok");
let b = batch.expect("should be Ok");
// cast arr into Int32Array
let arr: &Int32Array = b["id"].as_primitive();
assert!(arr.iter().all(|x| x.unwrap() % 2 == 0));