mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-05 19:32:56 +00:00
feat: support vector search with distance thresholds (#1993)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
16
Cargo.toml
16
Cargo.toml
@@ -23,14 +23,14 @@ rust-version = "1.78.0"
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.21.1", "features" = [
|
||||
"dynamodb",
|
||||
], git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
|
||||
lance-io = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
|
||||
lance-index = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
|
||||
lance-linalg = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
|
||||
lance-table = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
|
||||
lance-testing = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
|
||||
lance-datafusion = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
|
||||
lance-encoding = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
|
||||
], git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
|
||||
lance-io = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
|
||||
lance-index = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
|
||||
lance-linalg = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
|
||||
lance-table = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
|
||||
lance-testing = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
|
||||
lance-datafusion = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
|
||||
lance-encoding = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "53.2", optional = false }
|
||||
arrow-array = "53.2"
|
||||
|
||||
@@ -115,6 +115,9 @@ class Query(pydantic.BaseModel):
|
||||
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
|
||||
nprobes: int = 10
|
||||
|
||||
lower_bound: Optional[float] = None
|
||||
upper_bound: Optional[float] = None
|
||||
|
||||
# Refine factor.
|
||||
refine_factor: Optional[int] = None
|
||||
|
||||
@@ -604,6 +607,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
self._query = query
|
||||
self._metric = "L2"
|
||||
self._nprobes = 20
|
||||
self._lower_bound = None
|
||||
self._upper_bound = None
|
||||
self._refine_factor = None
|
||||
self._vector_column = vector_column
|
||||
self._prefilter = False
|
||||
@@ -649,6 +654,30 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
self._nprobes = nprobes
|
||||
return self
|
||||
|
||||
def distance_range(
|
||||
self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
|
||||
) -> LanceVectorQueryBuilder:
|
||||
"""Set the distance range to use.
|
||||
|
||||
Only rows with distances within range [lower_bound, upper_bound)
|
||||
will be returned.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lower: Optional[float]
|
||||
The lower bound of the distance range.
|
||||
upper_bound: Optional[float]
|
||||
The upper bound of the distance range.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._lower_bound = lower_bound
|
||||
self._upper_bound = upper_bound
|
||||
return self
|
||||
|
||||
def ef(self, ef: int) -> LanceVectorQueryBuilder:
|
||||
"""Set the number of candidates to consider during search.
|
||||
|
||||
@@ -728,6 +757,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
metric=self._metric,
|
||||
columns=self._columns,
|
||||
nprobes=self._nprobes,
|
||||
lower_bound=self._lower_bound,
|
||||
upper_bound=self._upper_bound,
|
||||
refine_factor=self._refine_factor,
|
||||
vector_column=self._vector_column,
|
||||
with_row_id=self._with_row_id,
|
||||
@@ -1284,6 +1315,31 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._nprobes = nprobes
|
||||
return self
|
||||
|
||||
def distance_range(
|
||||
self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
|
||||
) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the distance range to use.
|
||||
|
||||
Only rows with distances within range [lower_bound, upper_bound)
|
||||
will be returned.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lower: Optional[float]
|
||||
The lower bound of the distance range.
|
||||
upper_bound: Optional[float]
|
||||
The upper bound of the distance range.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._lower_bound = lower_bound
|
||||
self._upper_bound = upper_bound
|
||||
return self
|
||||
|
||||
def ef(self, ef: int) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the number of candidates to consider during search.
|
||||
@@ -1855,6 +1911,29 @@ class AsyncVectorQuery(AsyncQueryBase):
|
||||
self._inner.nprobes(nprobes)
|
||||
return self
|
||||
|
||||
def distance_range(
|
||||
self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
|
||||
) -> AsyncVectorQuery:
|
||||
"""Set the distance range to use.
|
||||
|
||||
Only rows with distances within range [lower_bound, upper_bound)
|
||||
will be returned.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lower: Optional[float]
|
||||
The lower bound of the distance range.
|
||||
upper_bound: Optional[float]
|
||||
The upper bound of the distance range.
|
||||
|
||||
Returns
|
||||
-------
|
||||
AsyncVectorQuery
|
||||
The AsyncVectorQuery object.
|
||||
"""
|
||||
self._inner.distance_range(lower_bound, upper_bound)
|
||||
return self
|
||||
|
||||
def ef(self, ef: int) -> AsyncVectorQuery:
|
||||
"""
|
||||
Set the number of candidates to consider during search
|
||||
|
||||
@@ -2786,6 +2786,7 @@ class AsyncTable:
|
||||
async_query.nearest_to(query.vector)
|
||||
.distance_type(query.metric)
|
||||
.nprobes(query.nprobes)
|
||||
.distance_range(query.lower_bound, query.upper_bound)
|
||||
)
|
||||
if query.refine_factor:
|
||||
async_query = async_query.refine_factor(query.refine_factor)
|
||||
|
||||
@@ -94,6 +94,73 @@ def test_with_row_id(table: lancedb.table.Table):
|
||||
assert rs["_rowid"].to_pylist() == [0, 1]
|
||||
|
||||
|
||||
def test_distance_range(table: lancedb.table.Table):
|
||||
q = [0, 0]
|
||||
rs = table.search(q).to_arrow()
|
||||
dists = rs["_distance"].to_pylist()
|
||||
min_dist = dists[0]
|
||||
max_dist = dists[-1]
|
||||
|
||||
res = table.search(q).distance_range(upper_bound=min_dist).to_arrow()
|
||||
assert len(res) == 0
|
||||
|
||||
res = table.search(q).distance_range(lower_bound=max_dist).to_arrow()
|
||||
assert len(res) == 1
|
||||
assert res["_distance"].to_pylist() == [max_dist]
|
||||
|
||||
res = table.search(q).distance_range(upper_bound=max_dist).to_arrow()
|
||||
assert len(res) == 1
|
||||
assert res["_distance"].to_pylist() == [min_dist]
|
||||
|
||||
res = table.search(q).distance_range(lower_bound=min_dist).to_arrow()
|
||||
assert len(res) == 2
|
||||
assert res["_distance"].to_pylist() == [min_dist, max_dist]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_distance_range_async(table_async: AsyncTable):
|
||||
q = [0, 0]
|
||||
rs = await table_async.query().nearest_to(q).to_arrow()
|
||||
dists = rs["_distance"].to_pylist()
|
||||
min_dist = dists[0]
|
||||
max_dist = dists[-1]
|
||||
|
||||
res = (
|
||||
await table_async.query()
|
||||
.nearest_to(q)
|
||||
.distance_range(upper_bound=min_dist)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(res) == 0
|
||||
|
||||
res = (
|
||||
await table_async.query()
|
||||
.nearest_to(q)
|
||||
.distance_range(lower_bound=max_dist)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(res) == 1
|
||||
assert res["_distance"].to_pylist() == [max_dist]
|
||||
|
||||
res = (
|
||||
await table_async.query()
|
||||
.nearest_to(q)
|
||||
.distance_range(upper_bound=max_dist)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(res) == 1
|
||||
assert res["_distance"].to_pylist() == [min_dist]
|
||||
|
||||
res = (
|
||||
await table_async.query()
|
||||
.nearest_to(q)
|
||||
.distance_range(lower_bound=min_dist)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(res) == 2
|
||||
assert res["_distance"].to_pylist() == [min_dist, max_dist]
|
||||
|
||||
|
||||
def test_vector_query_with_no_limit(table):
|
||||
with pytest.raises(ValueError):
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select(
|
||||
|
||||
@@ -306,6 +306,8 @@ def test_query_sync_minimal():
|
||||
"k": 10,
|
||||
"prefilter": False,
|
||||
"refine_factor": None,
|
||||
"lower_bound": None,
|
||||
"upper_bound": None,
|
||||
"ef": None,
|
||||
"vector": [1.0, 2.0, 3.0],
|
||||
"nprobes": 20,
|
||||
@@ -348,6 +350,8 @@ def test_query_sync_maximal():
|
||||
"refine_factor": 10,
|
||||
"vector": [1.0, 2.0, 3.0],
|
||||
"nprobes": 5,
|
||||
"lower_bound": None,
|
||||
"upper_bound": None,
|
||||
"ef": None,
|
||||
"filter": "id > 0",
|
||||
"columns": ["id", "name"],
|
||||
@@ -449,6 +453,8 @@ def test_query_sync_hybrid():
|
||||
"refine_factor": None,
|
||||
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
"nprobes": 20,
|
||||
"lower_bound": None,
|
||||
"upper_bound": None,
|
||||
"ef": None,
|
||||
"with_row_id": True,
|
||||
"version": None,
|
||||
|
||||
@@ -284,6 +284,11 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().nprobes(nprobe as usize);
|
||||
}
|
||||
|
||||
#[pyo3(signature = (lower_bound=None, upper_bound=None))]
|
||||
pub fn distance_range(&mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) {
|
||||
self.inner = self.inner.clone().distance_range(lower_bound, upper_bound);
|
||||
}
|
||||
|
||||
pub fn ef(&mut self, ef: u32) {
|
||||
self.inner = self.inner.clone().ef(ef as usize);
|
||||
}
|
||||
|
||||
@@ -755,6 +755,10 @@ pub struct VectorQuery {
|
||||
// IVF PQ - ANN search.
|
||||
pub(crate) query_vector: Vec<Arc<dyn Array>>,
|
||||
pub(crate) nprobes: usize,
|
||||
// The lower bound (inclusive) of the distance to search for.
|
||||
pub(crate) lower_bound: Option<f32>,
|
||||
// The upper bound (exclusive) of the distance to search for.
|
||||
pub(crate) upper_bound: Option<f32>,
|
||||
// The number of candidates to return during the refine step for HNSW,
|
||||
// defaults to 1.5 * limit.
|
||||
pub(crate) ef: Option<usize>,
|
||||
@@ -771,6 +775,8 @@ impl VectorQuery {
|
||||
column: None,
|
||||
query_vector: Vec::new(),
|
||||
nprobes: 20,
|
||||
lower_bound: None,
|
||||
upper_bound: None,
|
||||
ef: None,
|
||||
refine_factor: None,
|
||||
distance_type: None,
|
||||
@@ -831,6 +837,14 @@ impl VectorQuery {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the distance range for vector search,
|
||||
/// only rows with distances in the range [lower_bound, upper_bound) will be returned
|
||||
pub fn distance_range(mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) -> Self {
|
||||
self.lower_bound = lower_bound;
|
||||
self.upper_bound = upper_bound;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the number of candidates to return during the refine step for HNSW
|
||||
///
|
||||
/// This argument is only used when the vector column has an HNSW index.
|
||||
@@ -1350,6 +1364,30 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_distance_range() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let table = make_test_table(&tmp_dir).await;
|
||||
let results = table
|
||||
.vector_search(&[0.1, 0.2, 0.3, 0.4])
|
||||
.unwrap()
|
||||
.distance_range(Some(0.0), Some(1.0))
|
||||
.limit(10)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
for batch in results {
|
||||
let distances = batch["_distance"].as_primitive::<Float32Type>();
|
||||
assert!(distances.iter().all(|d| {
|
||||
let d = d.unwrap();
|
||||
(0.0..1.0).contains(&d)
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_query_vectors() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
|
||||
@@ -210,6 +210,8 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
body["prefilter"] = query.base.prefilter.into();
|
||||
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
|
||||
body["nprobes"] = query.nprobes.into();
|
||||
body["lower_bound"] = query.lower_bound.into();
|
||||
body["upper_bound"] = query.upper_bound.into();
|
||||
body["ef"] = query.ef.into();
|
||||
body["refine_factor"] = query.refine_factor.into();
|
||||
if let Some(vector_column) = query.column.as_ref() {
|
||||
@@ -1304,6 +1306,8 @@ mod tests {
|
||||
"prefilter": true,
|
||||
"distance_type": "l2",
|
||||
"nprobes": 20,
|
||||
"lower_bound": Option::<f32>::None,
|
||||
"upper_bound": Option::<f32>::None,
|
||||
"k": 10,
|
||||
"ef": Option::<usize>::None,
|
||||
"refine_factor": null,
|
||||
@@ -1353,6 +1357,8 @@ mod tests {
|
||||
"bypass_vector_index": true,
|
||||
"columns": ["a", "b"],
|
||||
"nprobes": 12,
|
||||
"lower_bound": Option::<f32>::None,
|
||||
"upper_bound": Option::<f32>::None,
|
||||
"ef": Option::<usize>::None,
|
||||
"refine_factor": 2,
|
||||
"version": null,
|
||||
|
||||
@@ -1944,6 +1944,7 @@ impl TableInternal for NativeTable {
|
||||
if let Some(ef) = query.ef {
|
||||
scanner.ef(ef);
|
||||
}
|
||||
scanner.distance_range(query.lower_bound, query.upper_bound);
|
||||
scanner.use_index(query.use_index);
|
||||
scanner.prefilter(query.base.prefilter);
|
||||
match query.base.select {
|
||||
|
||||
Reference in New Issue
Block a user