mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 02:42:57 +00:00
feat: allow custom column name in query (#709)
This commit is contained in:
@@ -359,7 +359,9 @@ mod test {
|
||||
assert_eq!(t.count_rows().await.unwrap(), 100);
|
||||
|
||||
let q = t
|
||||
.search(Some(PrimitiveArray::from_iter_values(vec![0.1, 0.1, 0.1, 0.1])))
|
||||
.search(Some(PrimitiveArray::from_iter_values(vec![
|
||||
0.1, 0.1, 0.1, 0.1,
|
||||
])))
|
||||
.limit(10)
|
||||
.execute()
|
||||
.await
|
||||
|
||||
@@ -25,6 +25,7 @@ use crate::error::Result;
|
||||
pub struct Query {
|
||||
pub dataset: Arc<Dataset>,
|
||||
pub query_vector: Option<Float32Array>,
|
||||
pub column: String,
|
||||
pub limit: Option<usize>,
|
||||
pub filter: Option<String>,
|
||||
pub select: Option<Vec<String>>,
|
||||
@@ -50,6 +51,7 @@ impl Query {
|
||||
Query {
|
||||
dataset,
|
||||
query_vector: vector,
|
||||
column: crate::table::VECTOR_COLUMN_NAME.to_string(),
|
||||
limit: None,
|
||||
nprobes: 20,
|
||||
refine_factor: None,
|
||||
@@ -71,7 +73,7 @@ impl Query {
|
||||
|
||||
if let Some(query) = self.query_vector.as_ref() {
|
||||
// If there is a vector query, default to limit=10 if unspecified
|
||||
scanner.nearest(crate::table::VECTOR_COLUMN_NAME, query, self.limit.unwrap_or(10))?;
|
||||
scanner.nearest(&self.column, query, self.limit.unwrap_or(10))?;
|
||||
} else {
|
||||
// If there is no vector query, it's ok to not have a limit
|
||||
scanner.limit(self.limit.map(|limit| limit as i64), None)?;
|
||||
@@ -87,6 +89,16 @@ impl Query {
|
||||
Ok(scanner.try_into_stream().await?)
|
||||
}
|
||||
|
||||
/// Set the column to query
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `column` - The column name
|
||||
pub fn column(mut self, column: &str) -> Query {
|
||||
self.column = column.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the maximum number of results to return.
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -176,7 +188,10 @@ mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader, cast::AsArray, Int32Array};
|
||||
use arrow_array::{
|
||||
cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
|
||||
RecordBatchReader,
|
||||
};
|
||||
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
|
||||
use futures::StreamExt;
|
||||
use lance::dataset::Dataset;
|
||||
@@ -260,7 +275,7 @@ mod tests {
|
||||
let mut stream = result.expect("should have result");
|
||||
// should only have one batch
|
||||
while let Some(batch) = stream.next().await {
|
||||
let b = batch.expect("should be Ok");
|
||||
let b = batch.expect("should be Ok");
|
||||
// cast arr into Int32Array
|
||||
let arr: &Int32Array = b["id"].as_primitive();
|
||||
assert!(arr.iter().all(|x| x.unwrap() % 2 == 0));
|
||||
|
||||
Reference in New Issue
Block a user