Compare commits

...

4 Commits

Author SHA1 Message Date
Lance Release
4ca0b15354 Bump version: 0.16.0-beta.0 → 0.16.0-beta.1 2024-11-14 04:41:56 +00:00
Rob Meng
d8c217b47d chore: bump lance to 0.19.2 (#1829) 2024-11-13 23:23:02 -05:00
Rob Meng
b724b1a01f feat: support remote empty query (#1828)
Support sending empty query types to remote LanceDB. Also include offset
and limit, which were previously omitted.
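
A minimal sketch of what this enables through the sync remote API (the
connection URI and API key are placeholders; the query shape mirrors the
new `test_query_sync_empty_query` below):

```python
# Hedged sketch: a filter-only ("empty") query against a remote table.
# "db://my-project" and the api_key are placeholders.
import lancedb

db = lancedb.connect("db://my-project", api_key="...")
table = db.open_table("demo")
rows = table.search(None).where("true").select(["id"]).limit(10).to_list()
```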
2024-11-13 23:04:52 -05:00
Will Jones
abd75e0ead feat: search multiple query vectors as one query (#1811)
Allows users to pass multiple query vectors as part of a single query
plan. This just runs the queries in parallel without any further
optimization. It's mostly a convenience.

Previously, I think this was only handled by the sync Python remote API.
This makes it common across all SDKs.

Closes https://github.com/lancedb/lancedb/issues/1803

```python
>>> import lancedb
>>> import asyncio
>>> 
>>> async def main():
...     db = await lancedb.connect_async("./demo")
...     table = await db.create_table("demo", [{"id": 1, "vector": [1, 2, 3]}, {"id": 2, "vector": [4, 5, 6]}], mode="overwrite")
...     return await table.query().nearest_to([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [4.0, 5.0, 6.0]]).limit(1).to_pandas()
... 
>>> asyncio.run(main())
   query_index  id           vector  _distance
0            2   2  [4.0, 5.0, 6.0]        0.0
1            1   2  [4.0, 5.0, 6.0]        0.0
2            0   1  [1.0, 2.0, 3.0]        0.0
```
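
Since row order across queries is not guaranteed, the `query_index` column
can be used to regroup results; a hedged follow-up reusing `main()` from
the example above:

```python
>>> df = asyncio.run(main())
>>> {i: g["id"].tolist() for i, g in df.groupby("query_index")}
{0: [1], 1: [2], 2: [2]}
```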
2024-11-13 16:05:16 -08:00
15 changed files with 407 additions and 83 deletions

View File

@@ -23,13 +23,13 @@ rust-version = "1.80.0" # TODO: lower this once we upgrade Lance again.
[workspace.dependencies]
lance = { "version" = "=0.19.2", "features" = [
"dynamodb",
], git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
lance-index = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
lance-linalg = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
lance-table = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
lance-testing = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
lance-datafusion = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
lance-encoding = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
], git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
lance-index = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
lance-linalg = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
lance-table = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
lance-testing = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
lance-datafusion = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
lance-encoding = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
# Note that this one does not include pyarrow
arrow = { version = "52.2", optional = false }
arrow-array = "52.2"

View File

@@ -998,4 +998,18 @@ describe("column name options", () => {
const results = await table.query().where("`camelCase` = 1").toArray();
expect(results[0].camelCase).toBe(1);
});
test("can make multiple vector queries in one go", async () => {
const results = await table
.query()
.nearestTo([0.1, 0.2])
.addQueryVector([0.1, 0.2])
.limit(1)
.toArray();
console.log(results);
expect(results.length).toBe(2);
results.sort((a, b) => a.query_index - b.query_index);
expect(results[0].query_index).toBe(0);
expect(results[1].query_index).toBe(1);
});
});

View File

@@ -492,6 +492,42 @@ export class VectorQuery extends QueryBase<NativeVectorQuery> {
super.doCall((inner) => inner.bypassVectorIndex());
return this;
}
/**
* Add a query vector to the search
*
* This method can be called multiple times to add multiple query vectors
* to the search. If multiple query vectors are added, then they will be searched
* in parallel, and the results will be concatenated. A column called `query_index`
* will be added to indicate the index of the query vector that produced the result.
*
* Performance-wise, this is equivalent to running multiple queries concurrently.
*/
addQueryVector(vector: IntoVector): VectorQuery {
if (vector instanceof Promise) {
const res = (async () => {
try {
const v = await vector;
const arr = Float32Array.from(v);
// biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping
const value: any = this.addQueryVector(arr);
const inner = value.inner as
| NativeVectorQuery
| Promise<NativeVectorQuery>;
return inner;
} catch (e) {
return Promise.reject(e);
}
})();
return new VectorQuery(res);
} else {
super.doCall((inner) => {
inner.addQueryVector(Float32Array.from(vector));
});
return this;
}
}
}
/** A builder for LanceDB queries. */

View File

@@ -135,6 +135,16 @@ impl VectorQuery {
self.inner = self.inner.clone().column(&column);
}
#[napi]
pub fn add_query_vector(&mut self, vector: Float32Array) -> Result<()> {
self.inner = self
.inner
.clone()
.add_query_vector(vector.as_ref())
.default_error()?;
Ok(())
}
#[napi]
pub fn distance_type(&mut self, distance_type: String) -> napi::Result<()> {
let distance_type = parse_distance_type(distance_type)?;

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.16.0-beta.0"
current_version = "0.16.0-beta.1"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.16.0-beta.0"
version = "0.16.0-beta.1"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true

View File

@@ -4,7 +4,7 @@ name = "lancedb"
dependencies = [
"deprecation",
"nest-asyncio~=1.0",
"pylance==0.19.2-beta.3",
"pylance==0.19.2",
"tqdm>=4.27.0",
"pydantic>=1.10",
"packaging",

View File

@@ -943,12 +943,16 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
class LanceEmptyQueryBuilder(LanceQueryBuilder):
def to_arrow(self) -> pa.Table:
ds = self._table.to_lance()
return ds.to_table(
query = Query(
columns=self._columns,
filter=self._where,
limit=self._limit,
k=self._limit or 10,
with_row_id=self._with_row_id,
vector=[],
# not actually respected in remote query
offset=self._offset or 0,
)
return self._table._execute_query(query).read_all()
def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
"""Rerank the results using the specified reranker.
@@ -1491,7 +1495,7 @@ class AsyncQuery(AsyncQueryBase):
return pa.array(vec)
def nearest_to(
self, query_vector: Optional[Union[VEC, Tuple]] = None
self, query_vector: Optional[Union[VEC, Tuple, List[VEC]]] = None
) -> AsyncVectorQuery:
"""
Find the nearest vectors to the given query vector.
@@ -1529,10 +1533,30 @@ class AsyncQuery(AsyncQueryBase):
Vector searches always have a [limit][]. If `limit` has not been called then
a default `limit` of 10 will be used.
Typically, a single vector is passed in as the query. However, you can also
pass in multiple vectors. This can be useful if you want to find the nearest
vectors to multiple query vectors. This is not expected to be faster than
making multiple queries concurrently; it is just a convenience method.
If multiple vectors are passed in then an additional column `query_index`
will be added to the results. This column will contain the index of the
query vector that the result is nearest to.
"""
return AsyncVectorQuery(
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
)
if (
isinstance(query_vector, list)
and len(query_vector) > 0
and not isinstance(query_vector[0], (float, int))
):
# multiple query vectors have been passed
query_vectors = [AsyncQuery._query_vec_to_array(v) for v in query_vector]
new_self = self._inner.nearest_to(query_vectors[0])
for v in query_vectors[1:]:
new_self.add_query_vector(v)
return AsyncVectorQuery(new_self)
else:
return AsyncVectorQuery(
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
)
def nearest_to_text(
self, query: str, columns: Union[str, List[str]] = []
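
The synchronous API takes the same list-of-vectors input through `search`;
a minimal hedged sketch mirroring the new `test_query_sync_multiple_vectors`
below (connection details are placeholders):

```python
# Hedged sketch: multi-vector search via the sync remote API.
import lancedb

db = lancedb.connect("db://my-project", api_key="...")
table = db.open_table("demo")
results = table.search([[1, 2, 3], [4, 5, 6]]).limit(1).to_list()
# One result per query vector here; rows carry a `query_index` column and
# cross-query ordering is not guaranteed.
results.sort(key=lambda r: r["query_index"])
```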

View File

@@ -327,10 +327,6 @@ class RemoteTable(Table):
- and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
# empty query builder is not supported in saas, raise error
if query is None and query_type != "hybrid":
raise ValueError("Empty query is not supported")
return LanceQueryBuilder.create(
self,
query,

View File

@@ -197,6 +197,23 @@ def test_query_sync_minimal():
assert data == expected
def test_query_sync_empty_query():
def handler(body):
assert body == {
"k": 10,
"filter": "true",
"vector": [],
"columns": ["id"],
}
return pa.table({"id": [1, 2, 3]})
with query_test_table(handler) as table:
data = table.search(None).where("true").select(["id"]).limit(10).to_list()
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
assert data == expected
def test_query_sync_maximal():
def handler(body):
assert body == {
@@ -229,6 +246,17 @@ def test_query_sync_maximal():
)
def test_query_sync_multiple_vectors():
def handler(_body):
return pa.table({"id": [1]})
with query_test_table(handler) as table:
results = table.search([[1, 2, 3], [4, 5, 6]]).limit(1).to_list()
assert len(results) == 2
results.sort(key=lambda x: x["query_index"])
assert results == [{"id": 1, "query_index": 0}, {"id": 1, "query_index": 1}]
def test_query_sync_fts():
def handler(body):
assert body == {

View File

@@ -892,10 +892,15 @@ def test_empty_query(db):
table = LanceTable.create(db, "my_table2", data=[{"id": i} for i in range(100)])
df = table.search().select(["id"]).to_pandas()
assert len(df) == 10
# None is the same as default
df = table.search().select(["id"]).limit(None).to_pandas()
assert len(df) == 100
assert len(df) == 10
# invalid limit is the same as None, which is the same as default
df = table.search().select(["id"]).limit(-1).to_pandas()
assert len(df) == 100
assert len(df) == 10
# valid limit should work
df = table.search().select(["id"]).limit(42).to_pandas()
assert len(df) == 42
def test_search_with_schema_inf_single_vector(db):

View File

@@ -142,6 +142,13 @@ impl VectorQuery {
self.inner = self.inner.clone().only_if(predicate);
}
pub fn add_query_vector(&mut self, vector: Bound<'_, PyAny>) -> PyResult<()> {
let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?;
let array = make_array(data);
self.inner = self.inner.clone().add_query_vector(array).infer_error()?;
Ok(())
}
pub fn select(&mut self, columns: Vec<(String, String)>) {
self.inner = self.inner.clone().select(Select::dynamic(&columns));
}

View File

@@ -475,6 +475,7 @@ impl<T: HasQuery> QueryBase for T {
/// Options for controlling the execution of a query
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct QueryExecutionOptions {
/// The maximum number of rows that will be contained in a single
/// `RecordBatch` delivered by the query.
@@ -650,7 +651,7 @@ impl Query {
pub fn nearest_to(self, vector: impl IntoQueryVector) -> Result<VectorQuery> {
let mut vector_query = self.into_vector();
let query_vector = vector.to_query_vector(&DataType::Float32, "default")?;
vector_query.query_vector = Some(query_vector);
vector_query.query_vector.push(query_vector);
Ok(vector_query)
}
}
@@ -701,7 +702,7 @@ pub struct VectorQuery {
// the column based on the dataset's schema.
pub(crate) column: Option<String>,
// IVF PQ - ANN search.
pub(crate) query_vector: Option<Arc<dyn Array>>,
pub(crate) query_vector: Vec<Arc<dyn Array>>,
pub(crate) nprobes: usize,
pub(crate) refine_factor: Option<u32>,
pub(crate) distance_type: Option<DistanceType>,
@@ -714,7 +715,7 @@ impl VectorQuery {
Self {
base,
column: None,
query_vector: None,
query_vector: Vec::new(),
nprobes: 20,
refine_factor: None,
distance_type: None,
@@ -734,6 +735,22 @@ impl VectorQuery {
self
}
/// Add another query vector to the search.
///
/// Multiple searches will be dispatched as part of the query.
/// This is a convenience method for adding multiple query vectors
/// to the search. It is not expected to be faster than issuing
/// multiple queries concurrently.
///
/// The output data will contain an additional column `query_index` which
/// will contain the index of the query vector that was used to generate the
/// result.
pub fn add_query_vector(mut self, vector: impl IntoQueryVector) -> Result<Self> {
let query_vector = vector.to_query_vector(&DataType::Float32, "default")?;
self.query_vector.push(query_vector);
Ok(self)
}
/// Set the number of partitions to search (probe)
///
/// This argument is only used when the vector column has an IVF PQ index.
@@ -854,6 +871,7 @@ mod tests {
use std::sync::Arc;
use super::*;
use arrow::{compute::concat_batches, datatypes::Int32Type};
use arrow_array::{
cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
RecordBatchReader,
@@ -883,7 +901,10 @@ mod tests {
let vector = Float32Array::from_iter_values([0.1, 0.2]);
let query = table.query().nearest_to(&[0.1, 0.2]).unwrap();
assert_eq!(*query.query_vector.unwrap().as_ref().as_primitive(), vector);
assert_eq!(
*query.query_vector.first().unwrap().as_ref().as_primitive(),
vector
);
let new_vector = Float32Array::from_iter_values([9.8, 8.7]);
@@ -899,7 +920,7 @@ mod tests {
.refine_factor(999);
assert_eq!(
*query.query_vector.unwrap().as_ref().as_primitive(),
*query.query_vector.first().unwrap().as_ref().as_primitive(),
new_vector
);
assert_eq!(query.base.limit.unwrap(), 100);
@@ -1197,4 +1218,34 @@ mod tests {
assert!(batch.column_by_name("_rowid").is_some());
}
}
#[tokio::test]
async fn test_multiple_query_vectors() {
let tmp_dir = tempdir().unwrap();
let table = make_test_table(&tmp_dir).await;
let query = table
.query()
.nearest_to(&[0.1, 0.2, 0.3, 0.4])
.unwrap()
.add_query_vector(&[0.5, 0.6, 0.7, 0.8])
.unwrap()
.limit(1);
let plan = query.explain_plan(true).await.unwrap();
assert!(plan.contains("UnionExec"));
let results = query
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let results = concat_batches(&results[0].schema(), &results).unwrap();
assert_eq!(results.num_rows(), 2); // One result for each query vector.
let query_index = results["query_index"].as_primitive::<Int32Type>();
// We don't guarantee order.
assert!(query_index.values().contains(&0));
assert!(query_index.values().contains(&1));
}
}

View File

@@ -6,7 +6,7 @@ use crate::index::IndexStatistics;
use crate::query::Select;
use crate::table::AddDataMode;
use crate::utils::{supported_btree_data_type, supported_vector_data_type};
use crate::Error;
use crate::{Error, Table};
use arrow_array::RecordBatchReader;
use arrow_ipc::reader::FileReader;
use arrow_schema::{DataType, SchemaRef};
@@ -185,6 +185,71 @@ impl<S: HttpSend> RemoteTable<S> {
Ok(())
}
fn apply_vector_query_params(
mut body: serde_json::Value,
query: &VectorQuery,
) -> Result<Vec<serde_json::Value>> {
Self::apply_query_params(&mut body, &query.base)?;
// Apply general parameters before dispatching based on the number of query vectors.
body["prefilter"] = query.base.prefilter.into();
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
body["nprobes"] = query.nprobes.into();
body["refine_factor"] = query.refine_factor.into();
if let Some(vector_column) = query.column.as_ref() {
body["vector_column"] = serde_json::Value::String(vector_column.clone());
}
if !query.use_index {
body["bypass_vector_index"] = serde_json::Value::Bool(true);
}
fn vector_to_json(vector: &arrow_array::ArrayRef) -> Result<serde_json::Value> {
match vector.data_type() {
DataType::Float32 => {
let array = vector
.as_any()
.downcast_ref::<arrow_array::Float32Array>()
.unwrap();
Ok(serde_json::Value::Array(
array
.values()
.iter()
.map(|v| {
serde_json::Value::Number(
serde_json::Number::from_f64(*v as f64).unwrap(),
)
})
.collect(),
))
}
_ => Err(Error::InvalidInput {
message: "VectorQuery vector must be of type Float32".into(),
}),
}
}
match query.query_vector.len() {
0 => {
// Server takes empty vector, not null or undefined.
body["vector"] = serde_json::Value::Array(Vec::new());
Ok(vec![body])
}
1 => {
body["vector"] = vector_to_json(&query.query_vector[0])?;
Ok(vec![body])
}
_ => {
let mut bodies = Vec::with_capacity(query.query_vector.len());
for vector in &query.query_vector {
let mut body = body.clone();
body["vector"] = vector_to_json(vector)?;
bodies.push(body);
}
Ok(bodies)
}
}
}
}
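
In effect, a two-vector query fans out into one request body per vector,
identical except for `vector`; a hedged illustration (field names follow
the code above, values and defaults are illustrative):

```python
# Hedged sketch of the fan-out performed by apply_vector_query_params.
shared = {"k": 1, "prefilter": False, "distance_type": "l2", "nprobes": 20}
bodies = [
    {**shared, "vector": [0.1, 0.2, 0.3]},
    {**shared, "vector": [0.4, 0.5, 0.6]},
]  # each body is POSTed separately; the result streams are unioned client-side
```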
#[derive(Deserialize)]
@@ -306,51 +371,29 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
) -> Result<Arc<dyn ExecutionPlan>> {
let request = self.client.post(&format!("/v1/table/{}/query/", self.name));
let mut body = serde_json::Value::Object(Default::default());
Self::apply_query_params(&mut body, &query.base)?;
let body = serde_json::Value::Object(Default::default());
let bodies = Self::apply_vector_query_params(body, query)?;
body["prefilter"] = query.base.prefilter.into();
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
body["nprobes"] = query.nprobes.into();
body["refine_factor"] = query.refine_factor.into();
let vector: Vec<f32> = if let Some(vector) = query.query_vector.as_ref() {
match vector.data_type() {
DataType::Float32 => vector
.as_any()
.downcast_ref::<arrow_array::Float32Array>()
.unwrap()
.values()
.iter()
.cloned()
.collect(),
_ => {
return Err(Error::InvalidInput {
message: "VectorQuery vector must be of type Float32".into(),
})
}
}
let mut futures = Vec::with_capacity(bodies.len());
for body in bodies {
let request = request.try_clone().unwrap().json(&body);
let future = async move {
let (request_id, response) = self.client.send(request, true).await?;
self.read_arrow_stream(&request_id, response).await
};
futures.push(future);
}
let streams = futures::future::try_join_all(futures).await?;
if streams.len() == 1 {
let stream = streams.into_iter().next().unwrap();
Ok(Arc::new(OneShotExec::new(stream)))
} else {
// Server takes empty vector, not null or undefined.
Vec::new()
};
body["vector"] = serde_json::json!(vector);
if let Some(vector_column) = query.column.as_ref() {
body["vector_column"] = serde_json::Value::String(vector_column.clone());
let stream_execs = streams
.into_iter()
.map(|stream| Arc::new(OneShotExec::new(stream)) as Arc<dyn ExecutionPlan>)
.collect();
Table::multi_vector_plan(stream_execs)
}
if !query.use_index {
body["bypass_vector_index"] = serde_json::Value::Bool(true);
}
let request = request.json(&body);
let (request_id, response) = self.client.send(request, true).await?;
let stream = self.read_arrow_stream(&request_id, response).await?;
Ok(Arc::new(OneShotExec::new(stream)))
}
async fn plain_query(
@@ -655,6 +698,7 @@ mod tests {
use super::*;
use arrow::{array::AsArray, compute::concat_batches, datatypes::Int32Type};
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use futures::{future::BoxFuture, StreamExt, TryFutureExt};
@@ -1207,6 +1251,52 @@ mod tests {
.unwrap();
}
#[tokio::test]
async fn test_query_multiple_vectors() {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/v1/table/my_table/query/");
assert_eq!(
request.headers().get("Content-Type").unwrap(),
JSON_CONTENT_TYPE
);
let data = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let response_body = write_ipc_file(&data);
http::Response::builder()
.status(200)
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
.body(response_body)
.unwrap()
});
let query = table
.query()
.nearest_to(vec![0.1, 0.2, 0.3])
.unwrap()
.add_query_vector(vec![0.4, 0.5, 0.6])
.unwrap();
let plan = query.explain_plan(true).await.unwrap();
assert!(plan.contains("UnionExec"), "Plan: {}", plan);
let results = query
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let results = concat_batches(&results[0].schema(), &results).unwrap();
let query_index = results["query_index"].as_primitive::<Int32Type>();
// We don't guarantee order.
assert!(query_index.values().contains(&0));
assert!(query_index.values().contains(&1));
}
#[tokio::test]
async fn test_create_index() {
let cases = [

View File

@@ -24,6 +24,9 @@ use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::{Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_physical_plan::display::DisplayableExecutionPlan;
use datafusion_physical_plan::projection::ProjectionExec;
use datafusion_physical_plan::repartition::RepartitionExec;
use datafusion_physical_plan::union::UnionExec;
use datafusion_physical_plan::ExecutionPlan;
use futures::{StreamExt, TryStreamExt};
use lance::dataset::builder::DatasetBuilder;
@@ -972,6 +975,57 @@ impl Table {
) -> Result<Option<IndexStatistics>> {
self.inner.index_stats(index_name.as_ref()).await
}
// Take many execution plans and map them into a single plan that adds
// a query_index column and unions them.
pub(crate) fn multi_vector_plan(
plans: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
if plans.is_empty() {
return Err(Error::InvalidInput {
message: "No plans provided".to_string(),
});
}
// Projection keeping all existing columns
let first_plan = plans[0].clone();
let project_all_columns = first_plan
.schema()
.fields()
.iter()
.enumerate()
.map(|(i, field)| {
let expr =
datafusion_physical_plan::expressions::Column::new(field.name().as_str(), i);
let expr = Arc::new(expr) as Arc<dyn datafusion_physical_plan::PhysicalExpr>;
(expr, field.name().clone())
})
.collect::<Vec<_>>();
let projected_plans = plans
.into_iter()
.enumerate()
.map(|(plan_i, plan)| {
let query_index = datafusion_common::ScalarValue::Int32(Some(plan_i as i32));
let query_index_expr =
datafusion_physical_plan::expressions::Literal::new(query_index);
let query_index_expr =
Arc::new(query_index_expr) as Arc<dyn datafusion_physical_plan::PhysicalExpr>;
let mut projections = vec![(query_index_expr, "query_index".to_string())];
projections.extend_from_slice(&project_all_columns);
let projection = ProjectionExec::try_new(projections, plan).unwrap();
Arc::new(projection) as Arc<dyn datafusion_physical_plan::ExecutionPlan>
})
.collect::<Vec<_>>();
let unioned = Arc::new(UnionExec::new(projected_plans));
// We require 1 partition in the final output
let repartitioned = RepartitionExec::try_new(
unioned,
datafusion_physical_plan::Partitioning::RoundRobinBatch(1),
)
.unwrap();
Ok(Arc::new(repartitioned))
}
}
impl From<NativeTable> for Table {
@@ -1784,9 +1838,25 @@ impl TableInternal for NativeTable {
) -> Result<Arc<dyn ExecutionPlan>> {
let ds_ref = self.dataset.get().await?;
if query.query_vector.len() > 1 {
// If there are multiple query vectors, create a plan for each of them and union them.
let query_vecs = query.query_vector.clone();
let plan_futures = query_vecs
.into_iter()
.map(|query_vector| {
let mut sub_query = query.clone();
sub_query.query_vector = vec![query_vector];
let options_ref = options.clone();
async move { self.create_plan(&sub_query, options_ref).await }
})
.collect::<Vec<_>>();
let plans = futures::future::try_join_all(plan_futures).await?;
return Table::multi_vector_plan(plans);
}
let mut scanner: Scanner = ds_ref.scan();
if let Some(query_vector) = query.query_vector.as_ref() {
if let Some(query_vector) = query.query_vector.first() {
// If there is a vector query, default to limit=10 if unspecified
let column = if let Some(col) = query.column.as_ref() {
col.clone()
@@ -1828,18 +1898,11 @@ impl TableInternal for NativeTable {
query_vector,
query.base.limit.unwrap_or(DEFAULT_TOP_K),
)?;
scanner.limit(
query.base.limit.map(|limit| limit as i64),
query.base.offset.map(|offset| offset as i64),
)?;
} else {
// If there is no vector query, it's ok to not have a limit
scanner.limit(
query.base.limit.map(|limit| limit as i64),
query.base.offset.map(|offset| offset as i64),
)?;
}
scanner.limit(
query.base.limit.map(|limit| limit as i64),
query.base.offset.map(|offset| offset as i64),
)?;
scanner.nprobs(query.nprobes);
scanner.use_index(query.use_index);
scanner.prefilter(query.base.prefilter);