chore: expose prefilter in lancedb rust (#674)

expose prefilter flag in vectordb rust code.
This commit is contained in:
Rob Meng
2023-12-01 00:44:14 -05:00
committed by Weston Pace
parent a94a033553
commit c1c3083b74

View File

@@ -32,6 +32,7 @@ pub struct Query {
pub refine_factor: Option<u32>,
pub metric_type: Option<MetricType>,
pub use_index: bool,
pub prefilter: bool,
}
impl Query {
@@ -56,6 +57,7 @@ impl Query {
use_index: true,
filter: None,
select: None,
prefilter: false,
}
}
@@ -74,6 +76,8 @@ impl Query {
)?;
scanner.nprobs(self.nprobes);
scanner.use_index(self.use_index);
scanner.prefilter(self.prefilter);
self.select.as_ref().map(|p| scanner.project(p.as_slice()));
self.filter.as_ref().map(|f| scanner.filter(f));
self.refine_factor.map(|rf| scanner.refine(rf));
@@ -158,6 +162,11 @@ impl Query {
self.select = columns;
self
}
pub fn prefilter(mut self, prefilter: bool) -> Query {
self.prefilter = prefilter;
self
}
}
#[cfg(test)]
@@ -167,7 +176,9 @@ mod tests {
use super::*;
use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use futures::StreamExt;
use lance::dataset::Dataset;
use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
use crate::query::Query;
@@ -200,13 +211,43 @@ mod tests {
#[tokio::test]
async fn test_execute() {
let batches = make_test_batches();
let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();
let batches = make_non_empty_batches();
let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap());
let vector = Float32Array::from_iter_values([0.1; 128]);
let query = Query::new(Arc::new(ds), vector.clone());
let result = query.execute().await;
assert_eq!(result.is_ok(), true);
let vector = Float32Array::from_iter_values([0.1; 4]);
let query = Query::new(ds.clone(), vector.clone());
let result = query
.limit(10)
.filter(Some("id % 2 == 0".to_string()))
.execute()
.await;
let mut stream = result.expect("should have result");
// should only have one batch
while let Some(batch) = stream.next().await {
// post filter should have removed some rows
assert!(batch.expect("should be Ok").num_rows() < 10);
}
let query = Query::new(ds, vector.clone());
let result = query
.limit(10)
.filter(Some("id % 2 == 0".to_string()))
.prefilter(true)
.execute()
.await;
let mut stream = result.expect("should have result");
// should only have one batch
while let Some(batch) = stream.next().await {
// pre filter should return 10 rows
assert!(batch.expect("should be Ok").num_rows() == 10);
}
}
fn make_non_empty_batches() -> impl RecordBatchReader + Send + 'static {
let vec = Box::new(RandomVector::new().named("vector".to_string()));
let id = Box::new(IncrementingInt32::new().named("id".to_string()));
BatchGenerator::new().col(vec).col(id).batch(512)
}
fn make_test_batches() -> impl RecordBatchReader + Send + 'static {