Compare commits

5 Commits

Author SHA1 Message Date
Lance Release
a27c5cf12b Bump version: 0.17.2-beta.1 → 0.17.2-beta.2 2025-01-06 05:34:27 +00:00
BubbleCal
f4dea72cc5 feat: support vector search with distance thresholds (#1993)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2025-01-06 13:23:39 +08:00
Lei Xu
f76c4a5ce1 chore: add pyright static type checking and fix some of the table interface (#1996)
* Enable `pyright` in the project
* Fixed some pyright typing errors in `table.py`
2025-01-04 15:24:58 -08:00
ahaapple
164ce397c2 docs: fix full-text search (Native FTS) TypeScript doc error (#1992)
Fixes the TypeScript error:

```
Cannot find name 'queryType'.ts(2304)
any
```
2025-01-03 13:36:10 -05:00
BubbleCal
445a312667 fix: selecting columns failed on FTS and hybrid search (#1991)
It reported `AttributeError: 'builtins.FTSQuery' object has no attribute 'select_columns'`
because the `select_columns` method was missing from the Rust bindings.

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2025-01-03 13:08:12 +08:00
19 changed files with 323 additions and 103 deletions

View File

@@ -30,10 +30,10 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
python-version: "3.12"
- name: Install ruff
run: |
- pip install ruff==0.5.4
+ pip install ruff==0.8.4
- name: Format check
run: ruff format --check .
- name: Lint

View File

@@ -23,14 +23,14 @@ rust-version = "1.78.0"
[workspace.dependencies]
lance = { "version" = "=0.21.1", "features" = [
"dynamodb",
], git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-io = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-index = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-linalg = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-table = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-testing = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-datafusion = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
lance-encoding = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
], git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-io = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-index = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-linalg = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-table = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-testing = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-datafusion = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
lance-encoding = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
# Note that this one does not include pyarrow
arrow = { version = "53.2", optional = false }
arrow-array = "53.2"

View File

@@ -50,7 +50,7 @@ Consider that we have a LanceDB table named `my_table`, whose string column `tex
});
await tbl
.search("puppy", queryType="fts")
.search("puppy", "fts")
.select(["text"])
.limit(10)
.toArray();

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.17.2-beta.1"
current_version = "0.17.2-beta.2"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.17.2-beta.1"
version = "0.17.2-beta.2"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true

View File

@@ -53,8 +53,9 @@ tests = [
"pytz",
"polars>=0.19, <=1.3.0",
"tantivy",
"pyarrow-stubs"
]
dev = ["ruff", "pre-commit"]
dev = ["ruff", "pre-commit", "pyright"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]
embeddings = [
@@ -94,3 +95,7 @@ markers = [
"asyncio",
"s3_test",
]
+ [tool.pyright]
+ include = ["python/lancedb/table.py"]
+ pythonVersion = "3.12"

View File

@@ -1,7 +1,9 @@
- from typing import Dict, List, Optional, Tuple
+ from typing import Dict, List, Optional, Tuple, Any, Union, Literal
import pyarrow as pa
+ from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
class Connection(object):
uri: str
async def table_names(
@@ -31,16 +33,35 @@ class Connection(object):
class Table:
def name(self) -> str: ...
def __repr__(self) -> str: ...
+ def is_open(self) -> bool: ...
+ def close(self) -> None: ...
async def schema(self) -> pa.Schema: ...
- async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ...
+ async def add(
+ self, data: pa.RecordBatchReader, mode: Literal["append", "overwrite"]
+ ) -> None: ...
async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ...
async def count_rows(self, filter: Optional[str]) -> int: ...
- async def create_index(self, column: str, config, replace: Optional[bool]): ...
+ async def create_index(
+ self,
+ column: str,
+ index: Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS],
+ replace: Optional[bool],
+ ): ...
+ async def list_versions(self) -> List[Dict[str, Any]]: ...
async def version(self) -> int: ...
async def checkout(self, version: int): ...
async def checkout_latest(self): ...
async def restore(self): ...
- async def list_indices(self) -> List[IndexConfig]: ...
+ async def list_indices(self) -> list[IndexConfig]: ...
async def delete(self, filter: str): ...
+ async def add_columns(self, columns: list[tuple[str, str]]) -> None: ...
+ async def alter_columns(self, columns: list[dict[str, Any]]) -> None: ...
+ async def optimize(
+ self,
+ *,
+ cleanup_since_ms: Optional[int] = None,
+ delete_unverified: Optional[bool] = None,
+ ) -> OptimizeStats: ...
def query(self) -> Query: ...
def vector_search(self) -> VectorQuery: ...

View File

@@ -603,7 +603,7 @@ class AsyncConnection(object):
fill_value: Optional[float] = None,
storage_options: Optional[Dict[str, str]] = None,
*,
- embedding_functions: List[EmbeddingFunctionConfig] = None,
+ embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
data_storage_version: Optional[str] = None,
use_legacy_format: Optional[bool] = None,
enable_v2_manifest_paths: Optional[bool] = None,

View File

@@ -1,20 +1,10 @@
- # Copyright 2023 LanceDB Developers
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""Full text search index using tantivy-py"""
import os
- from typing import List, Tuple
+ from typing import List, Tuple, Optional
import pyarrow as pa
@@ -31,7 +21,7 @@ from .table import LanceTable
def create_index(
index_path: str,
text_fields: List[str],
- ordering_fields: List[str] = None,
+ ordering_fields: Optional[List[str]] = None,
tokenizer_name: str = "default",
) -> tantivy.Index:
"""
@@ -75,8 +65,8 @@ def populate_index(
index: tantivy.Index,
table: LanceTable,
fields: List[str],
- writer_heap_size: int = 1024 * 1024 * 1024,
- ordering_fields: List[str] = None,
+ writer_heap_size: Optional[int] = None,
+ ordering_fields: Optional[List[str]] = None,
) -> int:
"""
Populate an index with data from a LanceTable
@@ -99,6 +89,7 @@ def populate_index(
"""
if ordering_fields is None:
ordering_fields = []
+ writer_heap_size = writer_heap_size or 1024 * 1024 * 1024
# first check the fields exist and are string or large string type
nested = []

View File

@@ -115,6 +115,9 @@ class Query(pydantic.BaseModel):
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
nprobes: int = 10
+ lower_bound: Optional[float] = None
+ upper_bound: Optional[float] = None
# Refine factor.
refine_factor: Optional[int] = None
@@ -604,6 +607,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._query = query
self._metric = "L2"
self._nprobes = 20
+ self._lower_bound = None
+ self._upper_bound = None
self._refine_factor = None
self._vector_column = vector_column
self._prefilter = False
@@ -649,6 +654,30 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._nprobes = nprobes
return self
+ def distance_range(
+ self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
+ ) -> LanceVectorQueryBuilder:
+ """Set the distance range to use.
+ Only rows with distances within range [lower_bound, upper_bound)
+ will be returned.
+ Parameters
+ ----------
+ lower_bound: Optional[float]
+ The lower bound of the distance range.
+ upper_bound: Optional[float]
+ The upper bound of the distance range.
+ Returns
+ -------
+ LanceVectorQueryBuilder
+ The LanceVectorQueryBuilder object.
+ """
+ self._lower_bound = lower_bound
+ self._upper_bound = upper_bound
+ return self
def ef(self, ef: int) -> LanceVectorQueryBuilder:
"""Set the number of candidates to consider during search.
@@ -728,6 +757,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
metric=self._metric,
columns=self._columns,
nprobes=self._nprobes,
+ lower_bound=self._lower_bound,
+ upper_bound=self._upper_bound,
refine_factor=self._refine_factor,
vector_column=self._vector_column,
with_row_id=self._with_row_id,
@@ -1284,6 +1315,31 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._nprobes = nprobes
return self
+ def distance_range(
+ self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
+ ) -> LanceHybridQueryBuilder:
+ """
+ Set the distance range to use.
+ Only rows with distances within range [lower_bound, upper_bound)
+ will be returned.
+ Parameters
+ ----------
+ lower_bound: Optional[float]
+ The lower bound of the distance range.
+ upper_bound: Optional[float]
+ The upper bound of the distance range.
+ Returns
+ -------
+ LanceHybridQueryBuilder
+ The LanceHybridQueryBuilder object.
+ """
+ self._lower_bound = lower_bound
+ self._upper_bound = upper_bound
+ return self
def ef(self, ef: int) -> LanceHybridQueryBuilder:
"""
Set the number of candidates to consider during search.
@@ -1855,6 +1911,29 @@ class AsyncVectorQuery(AsyncQueryBase):
self._inner.nprobes(nprobes)
return self
+ def distance_range(
+ self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
+ ) -> AsyncVectorQuery:
+ """Set the distance range to use.
+ Only rows with distances within range [lower_bound, upper_bound)
+ will be returned.
+ Parameters
+ ----------
+ lower_bound: Optional[float]
+ The lower bound of the distance range.
+ upper_bound: Optional[float]
+ The upper bound of the distance range.
+ Returns
+ -------
+ AsyncVectorQuery
+ The AsyncVectorQuery object.
+ """
+ self._inner.distance_range(lower_bound, upper_bound)
+ return self
def ef(self, ef: int) -> AsyncVectorQuery:
"""
Set the number of candidates to consider during search

View File

@@ -61,11 +61,12 @@ from .index import lang_mapping
if TYPE_CHECKING:
- import PIL
- from lance.dataset import CleanupStats, ReaderLike
- from ._lancedb import Table as LanceDBTable, OptimizeStats
+ from ._lancedb import Table as LanceDBTable, OptimizeStats, CompactionStats
from .db import LanceDBConnection
from .index import IndexConfig
+ from lance.dataset import CleanupStats, ReaderLike
+ import pandas
+ import PIL
pd = safe_import_pandas()
pl = safe_import_polars()
@@ -84,7 +85,6 @@ def _pd_schema_without_embedding_funcs(
)
if not embedding_functions:
return schema
columns = set(columns)
return pa.schema([field for field in schema if field.name in columns])
@@ -119,7 +119,7 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
return pa.Table.from_batches(data, schema=schema)
else:
return pa.Table.from_pylist(data, schema=schema)
- elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
+ elif _check_for_pandas(data) and isinstance(data, pd.DataFrame): # type: ignore
raw_schema = _pd_schema_without_embedding_funcs(schema, data.columns.to_list())
table = pa.Table.from_pandas(data, preserve_index=False, schema=raw_schema)
# Do not serialize Pandas metadata
@@ -160,7 +160,7 @@ def _sanitize_data(
metadata: Optional[dict] = None, # embedding metadata
on_bad_vectors: str = "error",
fill_value: float = 0.0,
- ):
+ ) -> Tuple[pa.Table, pa.Schema]:
data = _coerce_to_table(data, schema)
if metadata:
@@ -178,13 +178,17 @@ def _sanitize_data(
def sanitize_create_table(
- data, schema, metadata=None, on_bad_vectors="error", fill_value=0.0
+ data,
+ schema: Union[pa.Schema, LanceModel],
+ metadata=None,
+ on_bad_vectors: str = "error",
+ fill_value: float = 0.0,
):
if inspect.isclass(schema) and issubclass(schema, LanceModel):
# convert LanceModel to pyarrow schema
# note that it's possible this contains
# embedding function metadata already
- schema = schema.to_arrow_schema()
+ schema: pa.Schema = schema.to_arrow_schema()
if data is not None:
if metadata is None and schema is not None:
@@ -272,41 +276,6 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
return data
- def _generator_to_data_and_schema(
- data: Iterable,
- ) -> Tuple[Iterable[pa.RecordBatch], pa.Schema]:
- def _with_first_generator(first, data):
- yield first
- yield from data
- first = next(data, None)
- schema = None
- if isinstance(first, pa.RecordBatch):
- schema = first.schema
- data = _with_first_generator(first, data)
- elif isinstance(first, pa.Table):
- schema = first.schema
- data = _with_first_generator(first.to_batches(), data)
- return data, schema
- def _to_record_batch_generator(
- data: Iterable,
- schema,
- metadata,
- on_bad_vectors,
- fill_value,
- ):
- for batch in data:
- # always convert to table because we need to sanitize the data
- # and do things like add the vector column etc
- if isinstance(batch, pa.RecordBatch):
- batch = pa.Table.from_batches([batch])
- batch, _ = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
- for b in batch.to_batches():
- yield b
def _table_path(base: str, table_name: str) -> str:
"""
Get a table path that can be used in PyArrow FS.
@@ -404,7 +373,7 @@ class Table(ABC):
"""
raise NotImplementedError
def to_pandas(self) -> "pd.DataFrame":
def to_pandas(self) -> "pandas.DataFrame":
"""Return the table as a pandas DataFrame.
Returns
@@ -537,8 +506,8 @@ class Table(ABC):
def create_fts_index(
self,
field_names: Union[str, List[str]],
- ordering_field_names: Union[str, List[str]] = None,
*,
+ ordering_field_names: Optional[Union[str, List[str]]] = None,
replace: bool = False,
writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
use_tantivy: bool = True,
@@ -790,8 +759,7 @@ class Table(ABC):
@abstractmethod
def _execute_query(
self, query: Query, batch_size: Optional[int] = None
- ) -> pa.RecordBatchReader:
- pass
+ ) -> pa.RecordBatchReader: ...
@abstractmethod
def _do_merge(
@@ -800,8 +768,7 @@ class Table(ABC):
new_data: DATA,
on_bad_vectors: str,
fill_value: float,
- ):
- pass
+ ): ...
@abstractmethod
def delete(self, where: str):
@@ -1121,7 +1088,7 @@ class Table(ABC):
"""
@abstractmethod
- def list_versions(self):
+ def list_versions(self) -> List[Dict[str, Any]]:
"""List all versions of the table"""
@cached_property
@@ -1244,7 +1211,7 @@ class LanceTable(Table):
A PyArrow schema object."""
return LOOP.run(self._table.schema())
- def list_versions(self):
+ def list_versions(self) -> List[Dict[str, Any]]:
"""List all versions of the table"""
return LOOP.run(self._table.list_versions())
@@ -1297,7 +1264,7 @@ class LanceTable(Table):
"""
LOOP.run(self._table.checkout_latest())
- def restore(self, version: int = None):
+ def restore(self, version: Optional[int] = None):
"""Restore a version of the table. This is an in-place operation.
This creates a new version where the data is equivalent to the
@@ -1338,7 +1305,7 @@ class LanceTable(Table):
def count_rows(self, filter: Optional[str] = None) -> int:
return LOOP.run(self._table.count_rows(filter))
- def __len__(self):
+ def __len__(self) -> int:
return self.count_rows()
def __repr__(self) -> str:
@@ -1506,8 +1473,8 @@ class LanceTable(Table):
def create_fts_index(
self,
field_names: Union[str, List[str]],
- ordering_field_names: Union[str, List[str]] = None,
*,
+ ordering_field_names: Optional[Union[str, List[str]]] = None,
replace: bool = False,
writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
use_tantivy: bool = True,
@@ -1594,6 +1561,7 @@ class LanceTable(Table):
writer_heap_size=writer_heap_size,
)
+ @staticmethod
def infer_tokenizer_configs(tokenizer_name: str) -> dict:
if tokenizer_name == "default":
return {
@@ -1759,7 +1727,7 @@ class LanceTable(Table):
)
@overload
- def search(
+ def search( # type: ignore
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
@@ -1895,11 +1863,11 @@ class LanceTable(Table):
name: str,
data: Optional[DATA] = None,
schema: Optional[pa.Schema] = None,
mode: Literal["create", "overwrite", "append"] = "create",
mode: Literal["create", "overwrite"] = "create",
exist_ok: bool = False,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
- embedding_functions: List[EmbeddingFunctionConfig] = None,
+ embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
*,
storage_options: Optional[Dict[str, str]] = None,
data_storage_version: Optional[str] = None,
@@ -2065,7 +2033,7 @@ class LanceTable(Table):
older_than, delete_unverified=delete_unverified
)
- def compact_files(self, *args, **kwargs):
+ def compact_files(self, *args, **kwargs) -> CompactionStats:
"""
Run the compaction process on the table.
@@ -2450,7 +2418,7 @@ def _process_iterator(data: Iterable, schema: Optional[pa.Schema] = None) -> pa.
if batch_table.schema != schema:
try:
batch_table = batch_table.cast(schema)
- except pa.lib.ArrowInvalid:
+ except pa.lib.ArrowInvalid: # type: ignore
raise ValueError(
f"Input iterator yielded a batch with schema that "
f"does not match the expected schema.\nExpected:\n{schema}\n"
@@ -2710,16 +2678,17 @@ class AsyncTable:
on_bad_vectors = "error"
if fill_value is None:
fill_value = 0.0
- data, _ = _sanitize_data(
+ table_and_schema: Tuple[pa.Table, pa.Schema] = _sanitize_data(
data,
schema,
metadata=schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
- if isinstance(data, pa.Table):
- data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
- await self._inner.add(data, mode)
+ tbl, schema = table_and_schema
+ if isinstance(tbl, pa.Table):
+ data = pa.RecordBatchReader.from_batches(schema, tbl.to_batches())
+ await self._inner.add(data, mode or "append")
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
"""
@@ -2817,6 +2786,7 @@ class AsyncTable:
async_query.nearest_to(query.vector)
.distance_type(query.metric)
.nprobes(query.nprobes)
+ .distance_range(query.lower_bound, query.upper_bound)
)
if query.refine_factor:
async_query = async_query.refine_factor(query.refine_factor)
@@ -2977,7 +2947,7 @@ class AsyncTable:
return await self._inner.update(updates_sql, where)
- async def add_columns(self, transforms: Dict[str, str]):
+ async def add_columns(self, transforms: dict[str, str]):
"""
Add new columns with defined values.
@@ -2990,7 +2960,7 @@ class AsyncTable:
"""
await self._inner.add_columns(list(transforms.items()))
- async def alter_columns(self, *alterations: Iterable[Dict[str, str]]):
+ async def alter_columns(self, *alterations: Iterable[dict[str, Any]]):
"""
Alter column names and nullability.
@@ -3148,9 +3118,12 @@ class AsyncTable:
you have added or modified 100,000 or more records or run more than 20 data
modification operations.
"""
+ cleanup_since_ms: Optional[int] = None
if cleanup_older_than is not None:
- cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000)
- return await self._inner.optimize(cleanup_older_than, delete_unverified)
+ cleanup_since_ms = round(cleanup_older_than.total_seconds() * 1000)
+ return await self._inner.optimize(
+ cleanup_since_ms=cleanup_since_ms, delete_unverified=delete_unverified
+ )
async def list_indices(self) -> Iterable[IndexConfig]:
"""

View File

@@ -167,8 +167,24 @@ def test_search_index(tmp_path, table):
@pytest.mark.parametrize("use_tantivy", [True, False])
def test_search_fts(table, use_tantivy):
table.create_fts_index("text", use_tantivy=use_tantivy)
results = table.search("puppy").limit(5).to_list()
results = table.search("puppy").select(["id", "text"]).limit(5).to_list()
assert len(results) == 5
assert len(results[0]) == 3 # id, text, _score
@pytest.mark.asyncio
async def test_fts_select_async(async_table):
tbl = await async_table
await tbl.create_index("text", config=FTS())
results = (
await tbl.query()
.nearest_to_text("puppy")
.select(["id", "text"])
.limit(5)
.to_list()
)
assert len(results) == 5
assert len(results[0]) == 3 # id, text, _score
def test_search_fts_phrase_query(table):

View File

@@ -94,6 +94,73 @@ def test_with_row_id(table: lancedb.table.Table):
assert rs["_rowid"].to_pylist() == [0, 1]
+ def test_distance_range(table: lancedb.table.Table):
+ q = [0, 0]
+ rs = table.search(q).to_arrow()
+ dists = rs["_distance"].to_pylist()
+ min_dist = dists[0]
+ max_dist = dists[-1]
+ res = table.search(q).distance_range(upper_bound=min_dist).to_arrow()
+ assert len(res) == 0
+ res = table.search(q).distance_range(lower_bound=max_dist).to_arrow()
+ assert len(res) == 1
+ assert res["_distance"].to_pylist() == [max_dist]
+ res = table.search(q).distance_range(upper_bound=max_dist).to_arrow()
+ assert len(res) == 1
+ assert res["_distance"].to_pylist() == [min_dist]
+ res = table.search(q).distance_range(lower_bound=min_dist).to_arrow()
+ assert len(res) == 2
+ assert res["_distance"].to_pylist() == [min_dist, max_dist]
+ @pytest.mark.asyncio
+ async def test_distance_range_async(table_async: AsyncTable):
+ q = [0, 0]
+ rs = await table_async.query().nearest_to(q).to_arrow()
+ dists = rs["_distance"].to_pylist()
+ min_dist = dists[0]
+ max_dist = dists[-1]
+ res = (
+ await table_async.query()
+ .nearest_to(q)
+ .distance_range(upper_bound=min_dist)
+ .to_arrow()
+ )
+ assert len(res) == 0
+ res = (
+ await table_async.query()
+ .nearest_to(q)
+ .distance_range(lower_bound=max_dist)
+ .to_arrow()
+ )
+ assert len(res) == 1
+ assert res["_distance"].to_pylist() == [max_dist]
+ res = (
+ await table_async.query()
+ .nearest_to(q)
+ .distance_range(upper_bound=max_dist)
+ .to_arrow()
+ )
+ assert len(res) == 1
+ assert res["_distance"].to_pylist() == [min_dist]
+ res = (
+ await table_async.query()
+ .nearest_to(q)
+ .distance_range(lower_bound=min_dist)
+ .to_arrow()
+ )
+ assert len(res) == 2
+ assert res["_distance"].to_pylist() == [min_dist, max_dist]
def test_vector_query_with_no_limit(table):
with pytest.raises(ValueError):
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select(

View File

@@ -306,6 +306,8 @@ def test_query_sync_minimal():
"k": 10,
"prefilter": False,
"refine_factor": None,
"lower_bound": None,
"upper_bound": None,
"ef": None,
"vector": [1.0, 2.0, 3.0],
"nprobes": 20,
@@ -348,6 +350,8 @@ def test_query_sync_maximal():
"refine_factor": 10,
"vector": [1.0, 2.0, 3.0],
"nprobes": 5,
"lower_bound": None,
"upper_bound": None,
"ef": None,
"filter": "id > 0",
"columns": ["id", "name"],
@@ -449,6 +453,8 @@ def test_query_sync_hybrid():
"refine_factor": None,
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
"nprobes": 20,
"lower_bound": None,
"upper_bound": None,
"ef": None,
"with_row_id": True,
"version": None,

View File

@@ -152,6 +152,10 @@ impl FTSQuery {
self.inner = self.inner.clone().select(Select::dynamic(&columns));
}
+ pub fn select_columns(&mut self, columns: Vec<String>) {
+ self.inner = self.inner.clone().select(Select::columns(&columns));
+ }
pub fn limit(&mut self, limit: u32) {
self.inner = self.inner.clone().limit(limit as usize);
}
@@ -280,6 +284,11 @@ impl VectorQuery {
self.inner = self.inner.clone().nprobes(nprobe as usize);
}
+ #[pyo3(signature = (lower_bound=None, upper_bound=None))]
+ pub fn distance_range(&mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) {
+ self.inner = self.inner.clone().distance_range(lower_bound, upper_bound);
+ }
pub fn ef(&mut self, ef: u32) {
self.inner = self.inner.clone().ef(ef as usize);
}
@@ -341,6 +350,11 @@ impl HybridQuery {
self.inner_fts.select(columns);
}
+ pub fn select_columns(&mut self, columns: Vec<String>) {
+ self.inner_vec.select_columns(columns.clone());
+ self.inner_fts.select_columns(columns);
+ }
pub fn limit(&mut self, limit: u32) {
self.inner_vec.limit(limit);
self.inner_fts.limit(limit);

View File

@@ -97,10 +97,12 @@ impl Table {
self.name.clone()
}
+ /// Returns True if the table is open, False if it is closed.
pub fn is_open(&self) -> bool {
self.inner.is_some()
}
+ /// Closes the table, releasing any resources associated with it.
pub fn close(&mut self) {
self.inner.take();
}
@@ -301,6 +303,7 @@ impl Table {
Query::new(self.inner_ref().unwrap().query())
}
+ /// Optimize the on-disk data by compacting and pruning old data, for better performance.
#[pyo3(signature = (cleanup_since_ms=None, delete_unverified=None))]
pub fn optimize(
self_: PyRef<'_, Self>,

View File

@@ -755,6 +755,10 @@ pub struct VectorQuery {
// IVF PQ - ANN search.
pub(crate) query_vector: Vec<Arc<dyn Array>>,
pub(crate) nprobes: usize,
+ // The lower bound (inclusive) of the distance to search for.
+ pub(crate) lower_bound: Option<f32>,
+ // The upper bound (exclusive) of the distance to search for.
+ pub(crate) upper_bound: Option<f32>,
// The number of candidates to return during the refine step for HNSW,
// defaults to 1.5 * limit.
pub(crate) ef: Option<usize>,
@@ -771,6 +775,8 @@ impl VectorQuery {
column: None,
query_vector: Vec::new(),
nprobes: 20,
+ lower_bound: None,
+ upper_bound: None,
ef: None,
refine_factor: None,
distance_type: None,
@@ -831,6 +837,14 @@ impl VectorQuery {
self
}
+ /// Set the distance range for vector search,
+ /// only rows with distances in the range [lower_bound, upper_bound) will be returned
+ pub fn distance_range(mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) -> Self {
+ self.lower_bound = lower_bound;
+ self.upper_bound = upper_bound;
+ self
+ }
/// Set the number of candidates to return during the refine step for HNSW
///
/// This argument is only used when the vector column has an HNSW index.
@@ -1350,6 +1364,30 @@ mod tests {
}
}
+ #[tokio::test]
+ async fn test_distance_range() {
+ let tmp_dir = tempdir().unwrap();
+ let table = make_test_table(&tmp_dir).await;
+ let results = table
+ .vector_search(&[0.1, 0.2, 0.3, 0.4])
+ .unwrap()
+ .distance_range(Some(0.0), Some(1.0))
+ .limit(10)
+ .execute()
+ .await
+ .unwrap()
+ .try_collect::<Vec<_>>()
+ .await
+ .unwrap();
+ for batch in results {
+ let distances = batch["_distance"].as_primitive::<Float32Type>();
+ assert!(distances.iter().all(|d| {
+ let d = d.unwrap();
+ (0.0..1.0).contains(&d)
+ }));
+ }
+ }
#[tokio::test]
async fn test_multiple_query_vectors() {
let tmp_dir = tempdir().unwrap();

View File

@@ -210,6 +210,8 @@ impl<S: HttpSend> RemoteTable<S> {
body["prefilter"] = query.base.prefilter.into();
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
body["nprobes"] = query.nprobes.into();
body["lower_bound"] = query.lower_bound.into();
body["upper_bound"] = query.upper_bound.into();
body["ef"] = query.ef.into();
body["refine_factor"] = query.refine_factor.into();
if let Some(vector_column) = query.column.as_ref() {
@@ -1304,6 +1306,8 @@ mod tests {
"prefilter": true,
"distance_type": "l2",
"nprobes": 20,
"lower_bound": Option::<f32>::None,
"upper_bound": Option::<f32>::None,
"k": 10,
"ef": Option::<usize>::None,
"refine_factor": null,
@@ -1353,6 +1357,8 @@ mod tests {
"bypass_vector_index": true,
"columns": ["a", "b"],
"nprobes": 12,
"lower_bound": Option::<f32>::None,
"upper_bound": Option::<f32>::None,
"ef": Option::<usize>::None,
"refine_factor": 2,
"version": null,

View File

@@ -1944,6 +1944,7 @@ impl TableInternal for NativeTable {
if let Some(ef) = query.ef {
scanner.ef(ef);
}
+ scanner.distance_range(query.lower_bound, query.upper_bound);
scanner.use_index(query.use_index);
scanner.prefilter(query.base.prefilter);
match query.base.select {