mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-25 06:19:57 +00:00
Compare commits
4 Commits
lancedb-cl
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ca0b15354 | ||
|
|
d8c217b47d | ||
|
|
b724b1a01f | ||
|
|
abd75e0ead |
14
Cargo.toml
14
Cargo.toml
@@ -23,13 +23,13 @@ rust-version = "1.80.0" # TODO: lower this once we upgrade Lance again.
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.19.2", "features" = [
|
||||
"dynamodb",
|
||||
], git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
|
||||
lance-index = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
|
||||
lance-linalg = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
|
||||
lance-table = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
|
||||
lance-testing = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
|
||||
lance-datafusion = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
|
||||
lance-encoding = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2-beta.3" }
|
||||
], git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||
lance-index = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||
lance-linalg = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||
lance-table = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||
lance-testing = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||
lance-datafusion = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||
lance-encoding = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "52.2", optional = false }
|
||||
arrow-array = "52.2"
|
||||
|
||||
@@ -998,4 +998,18 @@ describe("column name options", () => {
|
||||
const results = await table.query().where("`camelCase` = 1").toArray();
|
||||
expect(results[0].camelCase).toBe(1);
|
||||
});
|
||||
|
||||
test("can make multiple vector queries in one go", async () => {
|
||||
const results = await table
|
||||
.query()
|
||||
.nearestTo([0.1, 0.2])
|
||||
.addQueryVector([0.1, 0.2])
|
||||
.limit(1)
|
||||
.toArray();
|
||||
console.log(results);
|
||||
expect(results.length).toBe(2);
|
||||
results.sort((a, b) => a.query_index - b.query_index);
|
||||
expect(results[0].query_index).toBe(0);
|
||||
expect(results[1].query_index).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -492,6 +492,42 @@ export class VectorQuery extends QueryBase<NativeVectorQuery> {
|
||||
super.doCall((inner) => inner.bypassVectorIndex());
|
||||
return this;
|
||||
}
|
||||
|
||||
/*
|
||||
* Add a query vector to the search
|
||||
*
|
||||
* This method can be called multiple times to add multiple query vectors
|
||||
* to the search. If multiple query vectors are added, then they will be searched
|
||||
* in parallel, and the results will be concatenated. A column called `query_index`
|
||||
* will be added to indicate the index of the query vector that produced the result.
|
||||
*
|
||||
* Performance wise, this is equivalent to running multiple queries concurrently.
|
||||
*/
|
||||
addQueryVector(vector: IntoVector): VectorQuery {
|
||||
if (vector instanceof Promise) {
|
||||
const res = (async () => {
|
||||
try {
|
||||
const v = await vector;
|
||||
const arr = Float32Array.from(v);
|
||||
//
|
||||
// biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping
|
||||
const value: any = this.addQueryVector(arr);
|
||||
const inner = value.inner as
|
||||
| NativeVectorQuery
|
||||
| Promise<NativeVectorQuery>;
|
||||
return inner;
|
||||
} catch (e) {
|
||||
return Promise.reject(e);
|
||||
}
|
||||
})();
|
||||
return new VectorQuery(res);
|
||||
} else {
|
||||
super.doCall((inner) => {
|
||||
inner.addQueryVector(Float32Array.from(vector));
|
||||
});
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** A builder for LanceDB queries. */
|
||||
|
||||
@@ -135,6 +135,16 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().column(&column);
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn add_query_vector(&mut self, vector: Float32Array) -> Result<()> {
|
||||
self.inner = self
|
||||
.inner
|
||||
.clone()
|
||||
.add_query_vector(vector.as_ref())
|
||||
.default_error()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn distance_type(&mut self, distance_type: String) -> napi::Result<()> {
|
||||
let distance_type = parse_distance_type(distance_type)?;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.16.0-beta.0"
|
||||
current_version = "0.16.0-beta.1"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.16.0-beta.0"
|
||||
version = "0.16.0-beta.1"
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
license.workspace = true
|
||||
|
||||
@@ -4,7 +4,7 @@ name = "lancedb"
|
||||
dependencies = [
|
||||
"deprecation",
|
||||
"nest-asyncio~=1.0",
|
||||
"pylance==0.19.2-beta.3",
|
||||
"pylance==0.19.2",
|
||||
"tqdm>=4.27.0",
|
||||
"pydantic>=1.10",
|
||||
"packaging",
|
||||
|
||||
@@ -943,12 +943,16 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||
def to_arrow(self) -> pa.Table:
|
||||
ds = self._table.to_lance()
|
||||
return ds.to_table(
|
||||
query = Query(
|
||||
columns=self._columns,
|
||||
filter=self._where,
|
||||
limit=self._limit,
|
||||
k=self._limit or 10,
|
||||
with_row_id=self._with_row_id,
|
||||
vector=[],
|
||||
# not actually respected in remote query
|
||||
offset=self._offset or 0,
|
||||
)
|
||||
return self._table._execute_query(query).read_all()
|
||||
|
||||
def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
|
||||
"""Rerank the results using the specified reranker.
|
||||
@@ -1491,7 +1495,7 @@ class AsyncQuery(AsyncQueryBase):
|
||||
return pa.array(vec)
|
||||
|
||||
def nearest_to(
|
||||
self, query_vector: Optional[Union[VEC, Tuple]] = None
|
||||
self, query_vector: Optional[Union[VEC, Tuple, List[VEC]]] = None
|
||||
) -> AsyncVectorQuery:
|
||||
"""
|
||||
Find the nearest vectors to the given query vector.
|
||||
@@ -1529,10 +1533,30 @@ class AsyncQuery(AsyncQueryBase):
|
||||
|
||||
Vector searches always have a [limit][]. If `limit` has not been called then
|
||||
a default `limit` of 10 will be used.
|
||||
|
||||
Typically, a single vector is passed in as the query. However, you can also
|
||||
pass in multiple vectors. This can be useful if you want to find the nearest
|
||||
vectors to multiple query vectors. This is not expected to be faster than
|
||||
making multiple queries concurrently; it is just a convenience method.
|
||||
If multiple vectors are passed in then an additional column `query_index`
|
||||
will be added to the results. This column will contain the index of the
|
||||
query vector that the result is nearest to.
|
||||
"""
|
||||
return AsyncVectorQuery(
|
||||
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
|
||||
)
|
||||
if (
|
||||
isinstance(query_vector, list)
|
||||
and len(query_vector) > 0
|
||||
and not isinstance(query_vector[0], (float, int))
|
||||
):
|
||||
# multiple have been passed
|
||||
query_vectors = [AsyncQuery._query_vec_to_array(v) for v in query_vector]
|
||||
new_self = self._inner.nearest_to(query_vectors[0])
|
||||
for v in query_vectors[1:]:
|
||||
new_self.add_query_vector(v)
|
||||
return AsyncVectorQuery(new_self)
|
||||
else:
|
||||
return AsyncVectorQuery(
|
||||
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
|
||||
)
|
||||
|
||||
def nearest_to_text(
|
||||
self, query: str, columns: Union[str, List[str]] = []
|
||||
|
||||
@@ -327,10 +327,6 @@ class RemoteTable(Table):
|
||||
- and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
# empty query builder is not supported in saas, raise error
|
||||
if query is None and query_type != "hybrid":
|
||||
raise ValueError("Empty query is not supported")
|
||||
|
||||
return LanceQueryBuilder.create(
|
||||
self,
|
||||
query,
|
||||
|
||||
@@ -197,6 +197,23 @@ def test_query_sync_minimal():
|
||||
assert data == expected
|
||||
|
||||
|
||||
def test_query_sync_empty_query():
|
||||
def handler(body):
|
||||
assert body == {
|
||||
"k": 10,
|
||||
"filter": "true",
|
||||
"vector": [],
|
||||
"columns": ["id"],
|
||||
}
|
||||
|
||||
return pa.table({"id": [1, 2, 3]})
|
||||
|
||||
with query_test_table(handler) as table:
|
||||
data = table.search(None).where("true").select(["id"]).limit(10).to_list()
|
||||
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
|
||||
assert data == expected
|
||||
|
||||
|
||||
def test_query_sync_maximal():
|
||||
def handler(body):
|
||||
assert body == {
|
||||
@@ -229,6 +246,17 @@ def test_query_sync_maximal():
|
||||
)
|
||||
|
||||
|
||||
def test_query_sync_multiple_vectors():
|
||||
def handler(_body):
|
||||
return pa.table({"id": [1]})
|
||||
|
||||
with query_test_table(handler) as table:
|
||||
results = table.search([[1, 2, 3], [4, 5, 6]]).limit(1).to_list()
|
||||
assert len(results) == 2
|
||||
results.sort(key=lambda x: x["query_index"])
|
||||
assert results == [{"id": 1, "query_index": 0}, {"id": 1, "query_index": 1}]
|
||||
|
||||
|
||||
def test_query_sync_fts():
|
||||
def handler(body):
|
||||
assert body == {
|
||||
|
||||
@@ -892,10 +892,15 @@ def test_empty_query(db):
|
||||
table = LanceTable.create(db, "my_table2", data=[{"id": i} for i in range(100)])
|
||||
df = table.search().select(["id"]).to_pandas()
|
||||
assert len(df) == 10
|
||||
# None is the same as default
|
||||
df = table.search().select(["id"]).limit(None).to_pandas()
|
||||
assert len(df) == 100
|
||||
assert len(df) == 10
|
||||
# invalid limist is the same as None, wihch is the same as default
|
||||
df = table.search().select(["id"]).limit(-1).to_pandas()
|
||||
assert len(df) == 100
|
||||
assert len(df) == 10
|
||||
# valid limit should work
|
||||
df = table.search().select(["id"]).limit(42).to_pandas()
|
||||
assert len(df) == 42
|
||||
|
||||
|
||||
def test_search_with_schema_inf_single_vector(db):
|
||||
|
||||
@@ -142,6 +142,13 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().only_if(predicate);
|
||||
}
|
||||
|
||||
pub fn add_query_vector(&mut self, vector: Bound<'_, PyAny>) -> PyResult<()> {
|
||||
let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?;
|
||||
let array = make_array(data);
|
||||
self.inner = self.inner.clone().add_query_vector(array).infer_error()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn select(&mut self, columns: Vec<(String, String)>) {
|
||||
self.inner = self.inner.clone().select(Select::dynamic(&columns));
|
||||
}
|
||||
|
||||
@@ -475,6 +475,7 @@ impl<T: HasQuery> QueryBase for T {
|
||||
|
||||
/// Options for controlling the execution of a query
|
||||
#[non_exhaustive]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QueryExecutionOptions {
|
||||
/// The maximum number of rows that will be contained in a single
|
||||
/// `RecordBatch` delivered by the query.
|
||||
@@ -650,7 +651,7 @@ impl Query {
|
||||
pub fn nearest_to(self, vector: impl IntoQueryVector) -> Result<VectorQuery> {
|
||||
let mut vector_query = self.into_vector();
|
||||
let query_vector = vector.to_query_vector(&DataType::Float32, "default")?;
|
||||
vector_query.query_vector = Some(query_vector);
|
||||
vector_query.query_vector.push(query_vector);
|
||||
Ok(vector_query)
|
||||
}
|
||||
}
|
||||
@@ -701,7 +702,7 @@ pub struct VectorQuery {
|
||||
// the column based on the dataset's schema.
|
||||
pub(crate) column: Option<String>,
|
||||
// IVF PQ - ANN search.
|
||||
pub(crate) query_vector: Option<Arc<dyn Array>>,
|
||||
pub(crate) query_vector: Vec<Arc<dyn Array>>,
|
||||
pub(crate) nprobes: usize,
|
||||
pub(crate) refine_factor: Option<u32>,
|
||||
pub(crate) distance_type: Option<DistanceType>,
|
||||
@@ -714,7 +715,7 @@ impl VectorQuery {
|
||||
Self {
|
||||
base,
|
||||
column: None,
|
||||
query_vector: None,
|
||||
query_vector: Vec::new(),
|
||||
nprobes: 20,
|
||||
refine_factor: None,
|
||||
distance_type: None,
|
||||
@@ -734,6 +735,22 @@ impl VectorQuery {
|
||||
self
|
||||
}
|
||||
|
||||
/// Add another query vector to the search.
|
||||
///
|
||||
/// Multiple searches will be dispatched as part of the query.
|
||||
/// This is a convenience method for adding multiple query vectors
|
||||
/// to the search. It is not expected to be faster than issuing
|
||||
/// multiple queries concurrently.
|
||||
///
|
||||
/// The output data will contain an additional columns `query_index` which
|
||||
/// will contain the index of the query vector that was used to generate the
|
||||
/// result.
|
||||
pub fn add_query_vector(mut self, vector: impl IntoQueryVector) -> Result<Self> {
|
||||
let query_vector = vector.to_query_vector(&DataType::Float32, "default")?;
|
||||
self.query_vector.push(query_vector);
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Set the number of partitions to search (probe)
|
||||
///
|
||||
/// This argument is only used when the vector column has an IVF PQ index.
|
||||
@@ -854,6 +871,7 @@ mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use arrow::{compute::concat_batches, datatypes::Int32Type};
|
||||
use arrow_array::{
|
||||
cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
|
||||
RecordBatchReader,
|
||||
@@ -883,7 +901,10 @@ mod tests {
|
||||
|
||||
let vector = Float32Array::from_iter_values([0.1, 0.2]);
|
||||
let query = table.query().nearest_to(&[0.1, 0.2]).unwrap();
|
||||
assert_eq!(*query.query_vector.unwrap().as_ref().as_primitive(), vector);
|
||||
assert_eq!(
|
||||
*query.query_vector.first().unwrap().as_ref().as_primitive(),
|
||||
vector
|
||||
);
|
||||
|
||||
let new_vector = Float32Array::from_iter_values([9.8, 8.7]);
|
||||
|
||||
@@ -899,7 +920,7 @@ mod tests {
|
||||
.refine_factor(999);
|
||||
|
||||
assert_eq!(
|
||||
*query.query_vector.unwrap().as_ref().as_primitive(),
|
||||
*query.query_vector.first().unwrap().as_ref().as_primitive(),
|
||||
new_vector
|
||||
);
|
||||
assert_eq!(query.base.limit.unwrap(), 100);
|
||||
@@ -1197,4 +1218,34 @@ mod tests {
|
||||
assert!(batch.column_by_name("_rowid").is_some());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_query_vectors() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let table = make_test_table(&tmp_dir).await;
|
||||
let query = table
|
||||
.query()
|
||||
.nearest_to(&[0.1, 0.2, 0.3, 0.4])
|
||||
.unwrap()
|
||||
.add_query_vector(&[0.5, 0.6, 0.7, 0.8])
|
||||
.unwrap()
|
||||
.limit(1);
|
||||
|
||||
let plan = query.explain_plan(true).await.unwrap();
|
||||
assert!(plan.contains("UnionExec"));
|
||||
|
||||
let results = query
|
||||
.execute()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
let results = concat_batches(&results[0].schema(), &results).unwrap();
|
||||
assert_eq!(results.num_rows(), 2); // One result for each query vector.
|
||||
let query_index = results["query_index"].as_primitive::<Int32Type>();
|
||||
// We don't guarantee order.
|
||||
assert!(query_index.values().contains(&0));
|
||||
assert!(query_index.values().contains(&1));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::index::IndexStatistics;
|
||||
use crate::query::Select;
|
||||
use crate::table::AddDataMode;
|
||||
use crate::utils::{supported_btree_data_type, supported_vector_data_type};
|
||||
use crate::Error;
|
||||
use crate::{Error, Table};
|
||||
use arrow_array::RecordBatchReader;
|
||||
use arrow_ipc::reader::FileReader;
|
||||
use arrow_schema::{DataType, SchemaRef};
|
||||
@@ -185,6 +185,71 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn apply_vector_query_params(
|
||||
mut body: serde_json::Value,
|
||||
query: &VectorQuery,
|
||||
) -> Result<Vec<serde_json::Value>> {
|
||||
Self::apply_query_params(&mut body, &query.base)?;
|
||||
|
||||
// Apply general parameters, before we dispatch based on number of query vectors.
|
||||
body["prefilter"] = query.base.prefilter.into();
|
||||
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
|
||||
body["nprobes"] = query.nprobes.into();
|
||||
body["refine_factor"] = query.refine_factor.into();
|
||||
if let Some(vector_column) = query.column.as_ref() {
|
||||
body["vector_column"] = serde_json::Value::String(vector_column.clone());
|
||||
}
|
||||
if !query.use_index {
|
||||
body["bypass_vector_index"] = serde_json::Value::Bool(true);
|
||||
}
|
||||
|
||||
fn vector_to_json(vector: &arrow_array::ArrayRef) -> Result<serde_json::Value> {
|
||||
match vector.data_type() {
|
||||
DataType::Float32 => {
|
||||
let array = vector
|
||||
.as_any()
|
||||
.downcast_ref::<arrow_array::Float32Array>()
|
||||
.unwrap();
|
||||
Ok(serde_json::Value::Array(
|
||||
array
|
||||
.values()
|
||||
.iter()
|
||||
.map(|v| {
|
||||
serde_json::Value::Number(
|
||||
serde_json::Number::from_f64(*v as f64).unwrap(),
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
_ => Err(Error::InvalidInput {
|
||||
message: "VectorQuery vector must be of type Float32".into(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
match query.query_vector.len() {
|
||||
0 => {
|
||||
// Server takes empty vector, not null or undefined.
|
||||
body["vector"] = serde_json::Value::Array(Vec::new());
|
||||
Ok(vec![body])
|
||||
}
|
||||
1 => {
|
||||
body["vector"] = vector_to_json(&query.query_vector[0])?;
|
||||
Ok(vec![body])
|
||||
}
|
||||
_ => {
|
||||
let mut bodies = Vec::with_capacity(query.query_vector.len());
|
||||
for vector in &query.query_vector {
|
||||
let mut body = body.clone();
|
||||
body["vector"] = vector_to_json(vector)?;
|
||||
bodies.push(body);
|
||||
}
|
||||
Ok(bodies)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -306,51 +371,29 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
) -> Result<Arc<dyn ExecutionPlan>> {
|
||||
let request = self.client.post(&format!("/v1/table/{}/query/", self.name));
|
||||
|
||||
let mut body = serde_json::Value::Object(Default::default());
|
||||
Self::apply_query_params(&mut body, &query.base)?;
|
||||
let body = serde_json::Value::Object(Default::default());
|
||||
let bodies = Self::apply_vector_query_params(body, query)?;
|
||||
|
||||
body["prefilter"] = query.base.prefilter.into();
|
||||
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
|
||||
body["nprobes"] = query.nprobes.into();
|
||||
body["refine_factor"] = query.refine_factor.into();
|
||||
|
||||
let vector: Vec<f32> = if let Some(vector) = query.query_vector.as_ref() {
|
||||
match vector.data_type() {
|
||||
DataType::Float32 => vector
|
||||
.as_any()
|
||||
.downcast_ref::<arrow_array::Float32Array>()
|
||||
.unwrap()
|
||||
.values()
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect(),
|
||||
_ => {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "VectorQuery vector must be of type Float32".into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
let mut futures = Vec::with_capacity(bodies.len());
|
||||
for body in bodies {
|
||||
let request = request.try_clone().unwrap().json(&body);
|
||||
let future = async move {
|
||||
let (request_id, response) = self.client.send(request, true).await?;
|
||||
self.read_arrow_stream(&request_id, response).await
|
||||
};
|
||||
futures.push(future);
|
||||
}
|
||||
let streams = futures::future::try_join_all(futures).await?;
|
||||
if streams.len() == 1 {
|
||||
let stream = streams.into_iter().next().unwrap();
|
||||
Ok(Arc::new(OneShotExec::new(stream)))
|
||||
} else {
|
||||
// Server takes empty vector, not null or undefined.
|
||||
Vec::new()
|
||||
};
|
||||
body["vector"] = serde_json::json!(vector);
|
||||
|
||||
if let Some(vector_column) = query.column.as_ref() {
|
||||
body["vector_column"] = serde_json::Value::String(vector_column.clone());
|
||||
let stream_execs = streams
|
||||
.into_iter()
|
||||
.map(|stream| Arc::new(OneShotExec::new(stream)) as Arc<dyn ExecutionPlan>)
|
||||
.collect();
|
||||
Table::multi_vector_plan(stream_execs)
|
||||
}
|
||||
|
||||
if !query.use_index {
|
||||
body["bypass_vector_index"] = serde_json::Value::Bool(true);
|
||||
}
|
||||
|
||||
let request = request.json(&body);
|
||||
|
||||
let (request_id, response) = self.client.send(request, true).await?;
|
||||
|
||||
let stream = self.read_arrow_stream(&request_id, response).await?;
|
||||
|
||||
Ok(Arc::new(OneShotExec::new(stream)))
|
||||
}
|
||||
|
||||
async fn plain_query(
|
||||
@@ -655,6 +698,7 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
use arrow::{array::AsArray, compute::concat_batches, datatypes::Int32Type};
|
||||
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use futures::{future::BoxFuture, StreamExt, TryFutureExt};
|
||||
@@ -1207,6 +1251,52 @@ mod tests {
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_query_multiple_vectors() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/query/");
|
||||
assert_eq!(
|
||||
request.headers().get("Content-Type").unwrap(),
|
||||
JSON_CONTENT_TYPE
|
||||
);
|
||||
let data = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
|
||||
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
|
||||
)
|
||||
.unwrap();
|
||||
let response_body = write_ipc_file(&data);
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
|
||||
.body(response_body)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let query = table
|
||||
.query()
|
||||
.nearest_to(vec![0.1, 0.2, 0.3])
|
||||
.unwrap()
|
||||
.add_query_vector(vec![0.4, 0.5, 0.6])
|
||||
.unwrap();
|
||||
let plan = query.explain_plan(true).await.unwrap();
|
||||
assert!(plan.contains("UnionExec"), "Plan: {}", plan);
|
||||
|
||||
let results = query
|
||||
.execute()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
let results = concat_batches(&results[0].schema(), &results).unwrap();
|
||||
|
||||
let query_index = results["query_index"].as_primitive::<Int32Type>();
|
||||
// We don't guarantee order.
|
||||
assert!(query_index.values().contains(&0));
|
||||
assert!(query_index.values().contains(&1));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_index() {
|
||||
let cases = [
|
||||
|
||||
@@ -24,6 +24,9 @@ use arrow_array::{RecordBatchIterator, RecordBatchReader};
|
||||
use arrow_schema::{Field, Schema, SchemaRef};
|
||||
use async_trait::async_trait;
|
||||
use datafusion_physical_plan::display::DisplayableExecutionPlan;
|
||||
use datafusion_physical_plan::projection::ProjectionExec;
|
||||
use datafusion_physical_plan::repartition::RepartitionExec;
|
||||
use datafusion_physical_plan::union::UnionExec;
|
||||
use datafusion_physical_plan::ExecutionPlan;
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use lance::dataset::builder::DatasetBuilder;
|
||||
@@ -972,6 +975,57 @@ impl Table {
|
||||
) -> Result<Option<IndexStatistics>> {
|
||||
self.inner.index_stats(index_name.as_ref()).await
|
||||
}
|
||||
|
||||
// Take many execution plans and map them into a single plan that adds
|
||||
// a query_index column and unions them.
|
||||
pub(crate) fn multi_vector_plan(
|
||||
plans: Vec<Arc<dyn ExecutionPlan>>,
|
||||
) -> Result<Arc<dyn ExecutionPlan>> {
|
||||
if plans.is_empty() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "No plans provided".to_string(),
|
||||
});
|
||||
}
|
||||
// Projection to keeping all existing columns
|
||||
let first_plan = plans[0].clone();
|
||||
let project_all_columns = first_plan
|
||||
.schema()
|
||||
.fields()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, field)| {
|
||||
let expr =
|
||||
datafusion_physical_plan::expressions::Column::new(field.name().as_str(), i);
|
||||
let expr = Arc::new(expr) as Arc<dyn datafusion_physical_plan::PhysicalExpr>;
|
||||
(expr, field.name().clone())
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let projected_plans = plans
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(plan_i, plan)| {
|
||||
let query_index = datafusion_common::ScalarValue::Int32(Some(plan_i as i32));
|
||||
let query_index_expr =
|
||||
datafusion_physical_plan::expressions::Literal::new(query_index);
|
||||
let query_index_expr =
|
||||
Arc::new(query_index_expr) as Arc<dyn datafusion_physical_plan::PhysicalExpr>;
|
||||
let mut projections = vec![(query_index_expr, "query_index".to_string())];
|
||||
projections.extend_from_slice(&project_all_columns);
|
||||
let projection = ProjectionExec::try_new(projections, plan).unwrap();
|
||||
Arc::new(projection) as Arc<dyn datafusion_physical_plan::ExecutionPlan>
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let unioned = Arc::new(UnionExec::new(projected_plans));
|
||||
// We require 1 partition in the final output
|
||||
let repartitioned = RepartitionExec::try_new(
|
||||
unioned,
|
||||
datafusion_physical_plan::Partitioning::RoundRobinBatch(1),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(Arc::new(repartitioned))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<NativeTable> for Table {
|
||||
@@ -1784,9 +1838,25 @@ impl TableInternal for NativeTable {
|
||||
) -> Result<Arc<dyn ExecutionPlan>> {
|
||||
let ds_ref = self.dataset.get().await?;
|
||||
|
||||
if query.query_vector.len() > 1 {
|
||||
// If there are multiple query vectors, create a plan for each of them and union them.
|
||||
let query_vecs = query.query_vector.clone();
|
||||
let plan_futures = query_vecs
|
||||
.into_iter()
|
||||
.map(|query_vector| {
|
||||
let mut sub_query = query.clone();
|
||||
sub_query.query_vector = vec![query_vector];
|
||||
let options_ref = options.clone();
|
||||
async move { self.create_plan(&sub_query, options_ref).await }
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let plans = futures::future::try_join_all(plan_futures).await?;
|
||||
return Table::multi_vector_plan(plans);
|
||||
}
|
||||
|
||||
let mut scanner: Scanner = ds_ref.scan();
|
||||
|
||||
if let Some(query_vector) = query.query_vector.as_ref() {
|
||||
if let Some(query_vector) = query.query_vector.first() {
|
||||
// If there is a vector query, default to limit=10 if unspecified
|
||||
let column = if let Some(col) = query.column.as_ref() {
|
||||
col.clone()
|
||||
@@ -1828,18 +1898,11 @@ impl TableInternal for NativeTable {
|
||||
query_vector,
|
||||
query.base.limit.unwrap_or(DEFAULT_TOP_K),
|
||||
)?;
|
||||
scanner.limit(
|
||||
query.base.limit.map(|limit| limit as i64),
|
||||
query.base.offset.map(|offset| offset as i64),
|
||||
)?;
|
||||
} else {
|
||||
// If there is no vector query, it's ok to not have a limit
|
||||
scanner.limit(
|
||||
query.base.limit.map(|limit| limit as i64),
|
||||
query.base.offset.map(|offset| offset as i64),
|
||||
)?;
|
||||
}
|
||||
|
||||
scanner.limit(
|
||||
query.base.limit.map(|limit| limit as i64),
|
||||
query.base.offset.map(|offset| offset as i64),
|
||||
)?;
|
||||
scanner.nprobs(query.nprobes);
|
||||
scanner.use_index(query.use_index);
|
||||
scanner.prefilter(query.base.prefilter);
|
||||
|
||||
Reference in New Issue
Block a user