diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 33d01858..e9c06a82 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -998,4 +998,18 @@ describe("column name options", () => { const results = await table.query().where("`camelCase` = 1").toArray(); expect(results[0].camelCase).toBe(1); }); + + test("can make multiple vector queries in one go", async () => { + const results = await table + .query() + .nearestTo([0.1, 0.2]) + .addQueryVector([0.1, 0.2]) + .limit(1) + .toArray(); + console.log(results); + expect(results.length).toBe(2); + results.sort((a, b) => a.query_index - b.query_index); + expect(results[0].query_index).toBe(0); + expect(results[1].query_index).toBe(1); + }); }); diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index 9a1ab8bf..d29babb3 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -492,6 +492,42 @@ export class VectorQuery extends QueryBase { super.doCall((inner) => inner.bypassVectorIndex()); return this; } + + /* + * Add a query vector to the search + * + * This method can be called multiple times to add multiple query vectors + * to the search. If multiple query vectors are added, then they will be searched + * in parallel, and the results will be concatenated. A column called `query_index` + * will be added to indicate the index of the query vector that produced the result. + * + * Performance wise, this is equivalent to running multiple queries concurrently. + */ + addQueryVector(vector: IntoVector): VectorQuery { + if (vector instanceof Promise) { + const res = (async () => { + try { + const v = await vector; + const arr = Float32Array.from(v); + // + // biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping + const value: any = this.addQueryVector(arr); + const inner = value.inner as + | NativeVectorQuery + | Promise; + return inner; + } catch (e) { + return Promise.reject(e); + } + })(); + return new VectorQuery(res); + } else { + super.doCall((inner) => { + inner.addQueryVector(Float32Array.from(vector)); + }); + return this; + } + } } /** A builder for LanceDB queries. */ diff --git a/nodejs/src/query.rs b/nodejs/src/query.rs index 448ca134..57eb24c4 100644 --- a/nodejs/src/query.rs +++ b/nodejs/src/query.rs @@ -135,6 +135,16 @@ impl VectorQuery { self.inner = self.inner.clone().column(&column); } + #[napi] + pub fn add_query_vector(&mut self, vector: Float32Array) -> Result<()> { + self.inner = self + .inner + .clone() + .add_query_vector(vector.as_ref()) + .default_error()?; + Ok(()) + } + #[napi] pub fn distance_type(&mut self, distance_type: String) -> napi::Result<()> { let distance_type = parse_distance_type(distance_type)?; diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 09eaa414..b9fb1ec4 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -1491,7 +1491,7 @@ class AsyncQuery(AsyncQueryBase): return pa.array(vec) def nearest_to( - self, query_vector: Optional[Union[VEC, Tuple]] = None + self, query_vector: Optional[Union[VEC, Tuple, List[VEC]]] = None ) -> AsyncVectorQuery: """ Find the nearest vectors to the given query vector. @@ -1529,10 +1529,30 @@ class AsyncQuery(AsyncQueryBase): Vector searches always have a [limit][]. If `limit` has not been called then a default `limit` of 10 will be used. + + Typically, a single vector is passed in as the query. However, you can also + pass in multiple vectors. This can be useful if you want to find the nearest + vectors to multiple query vectors. This is not expected to be faster than + making multiple queries concurrently; it is just a convenience method. + If multiple vectors are passed in then an additional column `query_index` + will be added to the results. This column will contain the index of the + query vector that the result is nearest to. """ - return AsyncVectorQuery( - self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector)) - ) + if ( + isinstance(query_vector, list) + and len(query_vector) > 0 + and not isinstance(query_vector[0], (float, int)) + ): + # multiple have been passed + query_vectors = [AsyncQuery._query_vec_to_array(v) for v in query_vector] + new_self = self._inner.nearest_to(query_vectors[0]) + for v in query_vectors[1:]: + new_self.add_query_vector(v) + return AsyncVectorQuery(new_self) + else: + return AsyncVectorQuery( + self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector)) + ) def nearest_to_text( self, query: str, columns: Union[str, List[str]] = [] diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index bc3a2783..62d39f0c 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -229,6 +229,17 @@ def test_query_sync_maximal(): ) +def test_query_sync_multiple_vectors(): + def handler(_body): + return pa.table({"id": [1]}) + + with query_test_table(handler) as table: + results = table.search([[1, 2, 3], [4, 5, 6]]).limit(1).to_list() + assert len(results) == 2 + results.sort(key=lambda x: x["query_index"]) + assert results == [{"id": 1, "query_index": 0}, {"id": 1, "query_index": 1}] + + def test_query_sync_fts(): def handler(body): assert body == { diff --git a/python/src/query.rs b/python/src/query.rs index 0f93f9ce..8e4042df 100644 --- a/python/src/query.rs +++ b/python/src/query.rs @@ -142,6 +142,13 @@ impl VectorQuery { self.inner = self.inner.clone().only_if(predicate); } + pub fn add_query_vector(&mut self, vector: Bound<'_, PyAny>) -> PyResult<()> { + let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?; + let array = make_array(data); + self.inner = self.inner.clone().add_query_vector(array).infer_error()?; + Ok(()) + } + pub fn select(&mut self, columns: Vec<(String, String)>) { self.inner = self.inner.clone().select(Select::dynamic(&columns)); } diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index 135f46a1..db9bf311 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -475,6 +475,7 @@ impl QueryBase for T { /// Options for controlling the execution of a query #[non_exhaustive] +#[derive(Debug, Clone)] pub struct QueryExecutionOptions { /// The maximum number of rows that will be contained in a single /// `RecordBatch` delivered by the query. @@ -650,7 +651,7 @@ impl Query { pub fn nearest_to(self, vector: impl IntoQueryVector) -> Result { let mut vector_query = self.into_vector(); let query_vector = vector.to_query_vector(&DataType::Float32, "default")?; - vector_query.query_vector = Some(query_vector); + vector_query.query_vector.push(query_vector); Ok(vector_query) } } @@ -701,7 +702,7 @@ pub struct VectorQuery { // the column based on the dataset's schema. pub(crate) column: Option, // IVF PQ - ANN search. - pub(crate) query_vector: Option>, + pub(crate) query_vector: Vec>, pub(crate) nprobes: usize, pub(crate) refine_factor: Option, pub(crate) distance_type: Option, @@ -714,7 +715,7 @@ impl VectorQuery { Self { base, column: None, - query_vector: None, + query_vector: Vec::new(), nprobes: 20, refine_factor: None, distance_type: None, @@ -734,6 +735,22 @@ impl VectorQuery { self } + /// Add another query vector to the search. + /// + /// Multiple searches will be dispatched as part of the query. + /// This is a convenience method for adding multiple query vectors + /// to the search. It is not expected to be faster than issuing + /// multiple queries concurrently. + /// + /// The output data will contain an additional columns `query_index` which + /// will contain the index of the query vector that was used to generate the + /// result. + pub fn add_query_vector(mut self, vector: impl IntoQueryVector) -> Result { + let query_vector = vector.to_query_vector(&DataType::Float32, "default")?; + self.query_vector.push(query_vector); + Ok(self) + } + /// Set the number of partitions to search (probe) /// /// This argument is only used when the vector column has an IVF PQ index. @@ -854,6 +871,7 @@ mod tests { use std::sync::Arc; use super::*; + use arrow::{compute::concat_batches, datatypes::Int32Type}; use arrow_array::{ cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader, @@ -883,7 +901,10 @@ mod tests { let vector = Float32Array::from_iter_values([0.1, 0.2]); let query = table.query().nearest_to(&[0.1, 0.2]).unwrap(); - assert_eq!(*query.query_vector.unwrap().as_ref().as_primitive(), vector); + assert_eq!( + *query.query_vector.first().unwrap().as_ref().as_primitive(), + vector + ); let new_vector = Float32Array::from_iter_values([9.8, 8.7]); @@ -899,7 +920,7 @@ mod tests { .refine_factor(999); assert_eq!( - *query.query_vector.unwrap().as_ref().as_primitive(), + *query.query_vector.first().unwrap().as_ref().as_primitive(), new_vector ); assert_eq!(query.base.limit.unwrap(), 100); @@ -1197,4 +1218,34 @@ mod tests { assert!(batch.column_by_name("_rowid").is_some()); } } + + #[tokio::test] + async fn test_multiple_query_vectors() { + let tmp_dir = tempdir().unwrap(); + let table = make_test_table(&tmp_dir).await; + let query = table + .query() + .nearest_to(&[0.1, 0.2, 0.3, 0.4]) + .unwrap() + .add_query_vector(&[0.5, 0.6, 0.7, 0.8]) + .unwrap() + .limit(1); + + let plan = query.explain_plan(true).await.unwrap(); + assert!(plan.contains("UnionExec")); + + let results = query + .execute() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + let results = concat_batches(&results[0].schema(), &results).unwrap(); + assert_eq!(results.num_rows(), 2); // One result for each query vector. + let query_index = results["query_index"].as_primitive::(); + // We don't guarantee order. + assert!(query_index.values().contains(&0)); + assert!(query_index.values().contains(&1)); + } } diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index a8754cc3..55cabf95 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -6,7 +6,7 @@ use crate::index::IndexStatistics; use crate::query::Select; use crate::table::AddDataMode; use crate::utils::{supported_btree_data_type, supported_vector_data_type}; -use crate::Error; +use crate::{Error, Table}; use arrow_array::RecordBatchReader; use arrow_ipc::reader::FileReader; use arrow_schema::{DataType, SchemaRef}; @@ -185,6 +185,71 @@ impl RemoteTable { Ok(()) } + + fn apply_vector_query_params( + mut body: serde_json::Value, + query: &VectorQuery, + ) -> Result> { + Self::apply_query_params(&mut body, &query.base)?; + + // Apply general parameters, before we dispatch based on number of query vectors. + body["prefilter"] = query.base.prefilter.into(); + body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default()); + body["nprobes"] = query.nprobes.into(); + body["refine_factor"] = query.refine_factor.into(); + if let Some(vector_column) = query.column.as_ref() { + body["vector_column"] = serde_json::Value::String(vector_column.clone()); + } + if !query.use_index { + body["bypass_vector_index"] = serde_json::Value::Bool(true); + } + + fn vector_to_json(vector: &arrow_array::ArrayRef) -> Result { + match vector.data_type() { + DataType::Float32 => { + let array = vector + .as_any() + .downcast_ref::() + .unwrap(); + Ok(serde_json::Value::Array( + array + .values() + .iter() + .map(|v| { + serde_json::Value::Number( + serde_json::Number::from_f64(*v as f64).unwrap(), + ) + }) + .collect(), + )) + } + _ => Err(Error::InvalidInput { + message: "VectorQuery vector must be of type Float32".into(), + }), + } + } + + match query.query_vector.len() { + 0 => { + // Server takes empty vector, not null or undefined. + body["vector"] = serde_json::Value::Array(Vec::new()); + Ok(vec![body]) + } + 1 => { + body["vector"] = vector_to_json(&query.query_vector[0])?; + Ok(vec![body]) + } + _ => { + let mut bodies = Vec::with_capacity(query.query_vector.len()); + for vector in &query.query_vector { + let mut body = body.clone(); + body["vector"] = vector_to_json(vector)?; + bodies.push(body); + } + Ok(bodies) + } + } + } } #[derive(Deserialize)] @@ -306,51 +371,29 @@ impl TableInternal for RemoteTable { ) -> Result> { let request = self.client.post(&format!("/v1/table/{}/query/", self.name)); - let mut body = serde_json::Value::Object(Default::default()); - Self::apply_query_params(&mut body, &query.base)?; + let body = serde_json::Value::Object(Default::default()); + let bodies = Self::apply_vector_query_params(body, query)?; - body["prefilter"] = query.base.prefilter.into(); - body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default()); - body["nprobes"] = query.nprobes.into(); - body["refine_factor"] = query.refine_factor.into(); - - let vector: Vec = if let Some(vector) = query.query_vector.as_ref() { - match vector.data_type() { - DataType::Float32 => vector - .as_any() - .downcast_ref::() - .unwrap() - .values() - .iter() - .cloned() - .collect(), - _ => { - return Err(Error::InvalidInput { - message: "VectorQuery vector must be of type Float32".into(), - }) - } - } + let mut futures = Vec::with_capacity(bodies.len()); + for body in bodies { + let request = request.try_clone().unwrap().json(&body); + let future = async move { + let (request_id, response) = self.client.send(request, true).await?; + self.read_arrow_stream(&request_id, response).await + }; + futures.push(future); + } + let streams = futures::future::try_join_all(futures).await?; + if streams.len() == 1 { + let stream = streams.into_iter().next().unwrap(); + Ok(Arc::new(OneShotExec::new(stream))) } else { - // Server takes empty vector, not null or undefined. - Vec::new() - }; - body["vector"] = serde_json::json!(vector); - - if let Some(vector_column) = query.column.as_ref() { - body["vector_column"] = serde_json::Value::String(vector_column.clone()); + let stream_execs = streams + .into_iter() + .map(|stream| Arc::new(OneShotExec::new(stream)) as Arc) + .collect(); + Table::multi_vector_plan(stream_execs) } - - if !query.use_index { - body["bypass_vector_index"] = serde_json::Value::Bool(true); - } - - let request = request.json(&body); - - let (request_id, response) = self.client.send(request, true).await?; - - let stream = self.read_arrow_stream(&request_id, response).await?; - - Ok(Arc::new(OneShotExec::new(stream))) } async fn plain_query( @@ -655,6 +698,7 @@ mod tests { use super::*; + use arrow::{array::AsArray, compute::concat_batches, datatypes::Int32Type}; use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator}; use arrow_schema::{DataType, Field, Schema}; use futures::{future::BoxFuture, StreamExt, TryFutureExt}; @@ -1207,6 +1251,52 @@ mod tests { .unwrap(); } + #[tokio::test] + async fn test_query_multiple_vectors() { + let table = Table::new_with_handler("my_table", |request| { + assert_eq!(request.method(), "POST"); + assert_eq!(request.url().path(), "/v1/table/my_table/query/"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); + let data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let response_body = write_ipc_file(&data); + http::Response::builder() + .status(200) + .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE) + .body(response_body) + .unwrap() + }); + + let query = table + .query() + .nearest_to(vec![0.1, 0.2, 0.3]) + .unwrap() + .add_query_vector(vec![0.4, 0.5, 0.6]) + .unwrap(); + let plan = query.explain_plan(true).await.unwrap(); + assert!(plan.contains("UnionExec"), "Plan: {}", plan); + + let results = query + .execute() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + let results = concat_batches(&results[0].schema(), &results).unwrap(); + + let query_index = results["query_index"].as_primitive::(); + // We don't guarantee order. + assert!(query_index.values().contains(&0)); + assert!(query_index.values().contains(&1)); + } + #[tokio::test] async fn test_create_index() { let cases = [ diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index ee5e5bba..11415e52 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -24,6 +24,9 @@ use arrow_array::{RecordBatchIterator, RecordBatchReader}; use arrow_schema::{Field, Schema, SchemaRef}; use async_trait::async_trait; use datafusion_physical_plan::display::DisplayableExecutionPlan; +use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::union::UnionExec; use datafusion_physical_plan::ExecutionPlan; use futures::{StreamExt, TryStreamExt}; use lance::dataset::builder::DatasetBuilder; @@ -972,6 +975,57 @@ impl Table { ) -> Result> { self.inner.index_stats(index_name.as_ref()).await } + + // Take many execution plans and map them into a single plan that adds + // a query_index column and unions them. + pub(crate) fn multi_vector_plan( + plans: Vec>, + ) -> Result> { + if plans.is_empty() { + return Err(Error::InvalidInput { + message: "No plans provided".to_string(), + }); + } + // Projection to keeping all existing columns + let first_plan = plans[0].clone(); + let project_all_columns = first_plan + .schema() + .fields() + .iter() + .enumerate() + .map(|(i, field)| { + let expr = + datafusion_physical_plan::expressions::Column::new(field.name().as_str(), i); + let expr = Arc::new(expr) as Arc; + (expr, field.name().clone()) + }) + .collect::>(); + + let projected_plans = plans + .into_iter() + .enumerate() + .map(|(plan_i, plan)| { + let query_index = datafusion_common::ScalarValue::Int32(Some(plan_i as i32)); + let query_index_expr = + datafusion_physical_plan::expressions::Literal::new(query_index); + let query_index_expr = + Arc::new(query_index_expr) as Arc; + let mut projections = vec![(query_index_expr, "query_index".to_string())]; + projections.extend_from_slice(&project_all_columns); + let projection = ProjectionExec::try_new(projections, plan).unwrap(); + Arc::new(projection) as Arc + }) + .collect::>(); + + let unioned = Arc::new(UnionExec::new(projected_plans)); + // We require 1 partition in the final output + let repartitioned = RepartitionExec::try_new( + unioned, + datafusion_physical_plan::Partitioning::RoundRobinBatch(1), + ) + .unwrap(); + Ok(Arc::new(repartitioned)) + } } impl From for Table { @@ -1784,9 +1838,25 @@ impl TableInternal for NativeTable { ) -> Result> { let ds_ref = self.dataset.get().await?; + if query.query_vector.len() > 1 { + // If there are multiple query vectors, create a plan for each of them and union them. + let query_vecs = query.query_vector.clone(); + let plan_futures = query_vecs + .into_iter() + .map(|query_vector| { + let mut sub_query = query.clone(); + sub_query.query_vector = vec![query_vector]; + let options_ref = options.clone(); + async move { self.create_plan(&sub_query, options_ref).await } + }) + .collect::>(); + let plans = futures::future::try_join_all(plan_futures).await?; + return Table::multi_vector_plan(plans); + } + let mut scanner: Scanner = ds_ref.scan(); - if let Some(query_vector) = query.query_vector.as_ref() { + if let Some(query_vector) = query.query_vector.first() { // If there is a vector query, default to limit=10 if unspecified let column = if let Some(col) = query.column.as_ref() { col.clone() @@ -1828,18 +1898,11 @@ impl TableInternal for NativeTable { query_vector, query.base.limit.unwrap_or(DEFAULT_TOP_K), )?; - scanner.limit( - query.base.limit.map(|limit| limit as i64), - query.base.offset.map(|offset| offset as i64), - )?; - } else { - // If there is no vector query, it's ok to not have a limit - scanner.limit( - query.base.limit.map(|limit| limit as i64), - query.base.offset.map(|offset| offset as i64), - )?; } - + scanner.limit( + query.base.limit.map(|limit| limit as i64), + query.base.offset.map(|offset| offset as i64), + )?; scanner.nprobs(query.nprobes); scanner.use_index(query.use_index); scanner.prefilter(query.base.prefilter);