feat: support vector search with distance thresholds (#1993)

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
BubbleCal
2025-01-06 13:23:39 +08:00
committed by GitHub
parent f76c4a5ce1
commit f4dea72cc5
9 changed files with 211 additions and 8 deletions

View File

@@ -23,14 +23,14 @@ rust-version = "1.78.0"
[workspace.dependencies]
lance = { "version" = "=0.21.1", "features" = [
"dynamodb",
], git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-io = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-index = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-linalg = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-table = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-testing = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-datafusion = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-encoding = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
], git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-io = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-index = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-linalg = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-table = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-testing = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-datafusion = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-encoding = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
# Note that this one does not include pyarrow
arrow = { version = "53.2", optional = false }
arrow-array = "53.2"

View File

@@ -115,6 +115,9 @@ class Query(pydantic.BaseModel):
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
nprobes: int = 10
lower_bound: Optional[float] = None
upper_bound: Optional[float] = None
# Refine factor.
refine_factor: Optional[int] = None
@@ -604,6 +607,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._query = query
self._metric = "L2"
self._nprobes = 20
self._lower_bound = None
self._upper_bound = None
self._refine_factor = None
self._vector_column = vector_column
self._prefilter = False
@@ -649,6 +654,30 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._nprobes = nprobes
return self
def distance_range(
self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
) -> LanceVectorQueryBuilder:
"""Set the distance range to use.
Only rows with distances within range [lower_bound, upper_bound)
will be returned.
Parameters
----------
lower: Optional[float]
The lower bound of the distance range.
upper_bound: Optional[float]
The upper bound of the distance range.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._lower_bound = lower_bound
self._upper_bound = upper_bound
return self
def ef(self, ef: int) -> LanceVectorQueryBuilder:
"""Set the number of candidates to consider during search.
@@ -728,6 +757,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
metric=self._metric,
columns=self._columns,
nprobes=self._nprobes,
lower_bound=self._lower_bound,
upper_bound=self._upper_bound,
refine_factor=self._refine_factor,
vector_column=self._vector_column,
with_row_id=self._with_row_id,
@@ -1284,6 +1315,31 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._nprobes = nprobes
return self
def distance_range(
self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
) -> LanceHybridQueryBuilder:
"""
Set the distance range to use.
Only rows with distances within range [lower_bound, upper_bound)
will be returned.
Parameters
----------
lower: Optional[float]
The lower bound of the distance range.
upper_bound: Optional[float]
The upper bound of the distance range.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._lower_bound = lower_bound
self._upper_bound = upper_bound
return self
def ef(self, ef: int) -> LanceHybridQueryBuilder:
"""
Set the number of candidates to consider during search.
@@ -1855,6 +1911,29 @@ class AsyncVectorQuery(AsyncQueryBase):
self._inner.nprobes(nprobes)
return self
def distance_range(
self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
) -> AsyncVectorQuery:
"""Set the distance range to use.
Only rows with distances within range [lower_bound, upper_bound)
will be returned.
Parameters
----------
lower: Optional[float]
The lower bound of the distance range.
upper_bound: Optional[float]
The upper bound of the distance range.
Returns
-------
AsyncVectorQuery
The AsyncVectorQuery object.
"""
self._inner.distance_range(lower_bound, upper_bound)
return self
def ef(self, ef: int) -> AsyncVectorQuery:
"""
Set the number of candidates to consider during search

View File

@@ -2786,6 +2786,7 @@ class AsyncTable:
async_query.nearest_to(query.vector)
.distance_type(query.metric)
.nprobes(query.nprobes)
.distance_range(query.lower_bound, query.upper_bound)
)
if query.refine_factor:
async_query = async_query.refine_factor(query.refine_factor)

View File

@@ -94,6 +94,73 @@ def test_with_row_id(table: lancedb.table.Table):
assert rs["_rowid"].to_pylist() == [0, 1]
def test_distance_range(table: lancedb.table.Table):
q = [0, 0]
rs = table.search(q).to_arrow()
dists = rs["_distance"].to_pylist()
min_dist = dists[0]
max_dist = dists[-1]
res = table.search(q).distance_range(upper_bound=min_dist).to_arrow()
assert len(res) == 0
res = table.search(q).distance_range(lower_bound=max_dist).to_arrow()
assert len(res) == 1
assert res["_distance"].to_pylist() == [max_dist]
res = table.search(q).distance_range(upper_bound=max_dist).to_arrow()
assert len(res) == 1
assert res["_distance"].to_pylist() == [min_dist]
res = table.search(q).distance_range(lower_bound=min_dist).to_arrow()
assert len(res) == 2
assert res["_distance"].to_pylist() == [min_dist, max_dist]
@pytest.mark.asyncio
async def test_distance_range_async(table_async: AsyncTable):
q = [0, 0]
rs = await table_async.query().nearest_to(q).to_arrow()
dists = rs["_distance"].to_pylist()
min_dist = dists[0]
max_dist = dists[-1]
res = (
await table_async.query()
.nearest_to(q)
.distance_range(upper_bound=min_dist)
.to_arrow()
)
assert len(res) == 0
res = (
await table_async.query()
.nearest_to(q)
.distance_range(lower_bound=max_dist)
.to_arrow()
)
assert len(res) == 1
assert res["_distance"].to_pylist() == [max_dist]
res = (
await table_async.query()
.nearest_to(q)
.distance_range(upper_bound=max_dist)
.to_arrow()
)
assert len(res) == 1
assert res["_distance"].to_pylist() == [min_dist]
res = (
await table_async.query()
.nearest_to(q)
.distance_range(lower_bound=min_dist)
.to_arrow()
)
assert len(res) == 2
assert res["_distance"].to_pylist() == [min_dist, max_dist]
def test_vector_query_with_no_limit(table):
with pytest.raises(ValueError):
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select(

View File

@@ -306,6 +306,8 @@ def test_query_sync_minimal():
"k": 10,
"prefilter": False,
"refine_factor": None,
"lower_bound": None,
"upper_bound": None,
"ef": None,
"vector": [1.0, 2.0, 3.0],
"nprobes": 20,
@@ -348,6 +350,8 @@ def test_query_sync_maximal():
"refine_factor": 10,
"vector": [1.0, 2.0, 3.0],
"nprobes": 5,
"lower_bound": None,
"upper_bound": None,
"ef": None,
"filter": "id > 0",
"columns": ["id", "name"],
@@ -449,6 +453,8 @@ def test_query_sync_hybrid():
"refine_factor": None,
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
"nprobes": 20,
"lower_bound": None,
"upper_bound": None,
"ef": None,
"with_row_id": True,
"version": None,

View File

@@ -284,6 +284,11 @@ impl VectorQuery {
self.inner = self.inner.clone().nprobes(nprobe as usize);
}
#[pyo3(signature = (lower_bound=None, upper_bound=None))]
pub fn distance_range(&mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) {
self.inner = self.inner.clone().distance_range(lower_bound, upper_bound);
}
pub fn ef(&mut self, ef: u32) {
self.inner = self.inner.clone().ef(ef as usize);
}

View File

@@ -755,6 +755,10 @@ pub struct VectorQuery {
// IVF PQ - ANN search.
pub(crate) query_vector: Vec<Arc<dyn Array>>,
pub(crate) nprobes: usize,
// The lower bound (inclusive) of the distance to search for.
pub(crate) lower_bound: Option<f32>,
// The upper bound (exclusive) of the distance to search for.
pub(crate) upper_bound: Option<f32>,
// The number of candidates to return during the refine step for HNSW,
// defaults to 1.5 * limit.
pub(crate) ef: Option<usize>,
@@ -771,6 +775,8 @@ impl VectorQuery {
column: None,
query_vector: Vec::new(),
nprobes: 20,
lower_bound: None,
upper_bound: None,
ef: None,
refine_factor: None,
distance_type: None,
@@ -831,6 +837,14 @@ impl VectorQuery {
self
}
/// Set the distance range for vector search,
/// only rows with distances in the range [lower_bound, upper_bound) will be returned
pub fn distance_range(mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) -> Self {
self.lower_bound = lower_bound;
self.upper_bound = upper_bound;
self
}
/// Set the number of candidates to return during the refine step for HNSW
///
/// This argument is only used when the vector column has an HNSW index.
@@ -1350,6 +1364,30 @@ mod tests {
}
}
#[tokio::test]
async fn test_distance_range() {
let tmp_dir = tempdir().unwrap();
let table = make_test_table(&tmp_dir).await;
let results = table
.vector_search(&[0.1, 0.2, 0.3, 0.4])
.unwrap()
.distance_range(Some(0.0), Some(1.0))
.limit(10)
.execute()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
for batch in results {
let distances = batch["_distance"].as_primitive::<Float32Type>();
assert!(distances.iter().all(|d| {
let d = d.unwrap();
(0.0..1.0).contains(&d)
}));
}
}
#[tokio::test]
async fn test_multiple_query_vectors() {
let tmp_dir = tempdir().unwrap();

View File

@@ -210,6 +210,8 @@ impl<S: HttpSend> RemoteTable<S> {
body["prefilter"] = query.base.prefilter.into();
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
body["nprobes"] = query.nprobes.into();
body["lower_bound"] = query.lower_bound.into();
body["upper_bound"] = query.upper_bound.into();
body["ef"] = query.ef.into();
body["refine_factor"] = query.refine_factor.into();
if let Some(vector_column) = query.column.as_ref() {
@@ -1304,6 +1306,8 @@ mod tests {
"prefilter": true,
"distance_type": "l2",
"nprobes": 20,
"lower_bound": Option::<f32>::None,
"upper_bound": Option::<f32>::None,
"k": 10,
"ef": Option::<usize>::None,
"refine_factor": null,
@@ -1353,6 +1357,8 @@ mod tests {
"bypass_vector_index": true,
"columns": ["a", "b"],
"nprobes": 12,
"lower_bound": Option::<f32>::None,
"upper_bound": Option::<f32>::None,
"ef": Option::<usize>::None,
"refine_factor": 2,
"version": null,

View File

@@ -1944,6 +1944,7 @@ impl TableInternal for NativeTable {
if let Some(ef) = query.ef {
scanner.ef(ef);
}
scanner.distance_range(query.lower_bound, query.upper_bound);
scanner.use_index(query.use_index);
scanner.prefilter(query.base.prefilter);
match query.base.select {