Compare commits


5 Commits

a27c5cf12b  Lance Release  2025-01-06 05:34:27 +00:00
    Bump version: 0.17.2-beta.1 → 0.17.2-beta.2

f4dea72cc5  BubbleCal  2025-01-06 13:23:39 +08:00
    feat: support vector search with distance thresholds (#1993)
    Signed-off-by: BubbleCal <bubble-cal@outlook.com>

f76c4a5ce1  Lei Xu  2025-01-04 15:24:58 -08:00
    chore: add pyright static type checking and fix some of the table interface (#1996)
    * Enable `pyright` in the project
    * Fixed some pyright typing errors in `table.py`

164ce397c2  ahaapple  2025-01-03 13:36:10 -05:00
    docs: fix full-text search (Native FTS) TypeScript doc error (#1992)
    Fixes the TypeScript compile error:
    ```
    Cannot find name 'queryType'. ts(2304)
    ```

445a312667  BubbleCal  2025-01-03 13:08:12 +08:00
    fix: selecting columns failed on FTS and hybrid search (#1991)
    Selecting columns raised `AttributeError: 'builtins.FTSQuery' object has no
    attribute 'select_columns'` because the `select_columns` method was missing
    on the Rust side.
    Signed-off-by: BubbleCal <bubble-cal@outlook.com>
19 changed files with 323 additions and 103 deletions

View File

@@ -30,10 +30,10 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: "3.11"
+          python-version: "3.12"
       - name: Install ruff
         run: |
-          pip install ruff==0.5.4
+          pip install ruff==0.8.4
       - name: Format check
         run: ruff format --check .
       - name: Lint

View File

@@ -23,14 +23,14 @@ rust-version = "1.78.0"
 [workspace.dependencies]
 lance = { "version" = "=0.21.1", "features" = [
     "dynamodb",
-], git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
-lance-io = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
-lance-index = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
-lance-linalg = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
-lance-table = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
-lance-testing = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
-lance-datafusion = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
-lance-encoding = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
+], git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-io = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-index = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-linalg = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-table = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-testing = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-datafusion = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-encoding = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
 # Note that this one does not include pyarrow
 arrow = { version = "53.2", optional = false }
 arrow-array = "53.2"

View File

@@ -50,7 +50,7 @@ Consider that we have a LanceDB table named `my_table`, whose string column `text`
 });
 await tbl
-  .search("puppy", queryType="fts")
+  .search("puppy", "fts")
   .select(["text"])
   .limit(10)
   .toArray();

View File

@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.17.2-beta.1"
+current_version = "0.17.2-beta.2"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-python"
-version = "0.17.2-beta.1"
+version = "0.17.2-beta.2"
 edition.workspace = true
 description = "Python bindings for LanceDB"
 license.workspace = true

View File

@@ -53,8 +53,9 @@ tests = [
     "pytz",
     "polars>=0.19, <=1.3.0",
     "tantivy",
+    "pyarrow-stubs"
 ]
-dev = ["ruff", "pre-commit"]
+dev = ["ruff", "pre-commit", "pyright"]
 docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
 clip = ["torch", "pillow", "open-clip"]
 embeddings = [
@@ -94,3 +95,7 @@ markers = [
     "asyncio",
     "s3_test",
 ]
+
+[tool.pyright]
+include = ["python/lancedb/table.py"]
+pythonVersion = "3.12"

View File

@@ -1,7 +1,9 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Any, Union, Literal

 import pyarrow as pa

+from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
+
 class Connection(object):
     uri: str
     async def table_names(
@@ -31,16 +33,35 @@ class Connection(object):
 class Table:
     def name(self) -> str: ...
     def __repr__(self) -> str: ...
+    def is_open(self) -> bool: ...
+    def close(self) -> None: ...
     async def schema(self) -> pa.Schema: ...
-    async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ...
+    async def add(
+        self, data: pa.RecordBatchReader, mode: Literal["append", "overwrite"]
+    ) -> None: ...
     async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ...
     async def count_rows(self, filter: Optional[str]) -> int: ...
-    async def create_index(self, column: str, config, replace: Optional[bool]): ...
+    async def create_index(
+        self,
+        column: str,
+        index: Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS],
+        replace: Optional[bool],
+    ): ...
+    async def list_versions(self) -> List[Dict[str, Any]]: ...
     async def version(self) -> int: ...
     async def checkout(self, version: int): ...
     async def checkout_latest(self): ...
     async def restore(self): ...
-    async def list_indices(self) -> List[IndexConfig]: ...
+    async def list_indices(self) -> list[IndexConfig]: ...
+    async def delete(self, filter: str): ...
+    async def add_columns(self, columns: list[tuple[str, str]]) -> None: ...
+    async def alter_columns(self, columns: list[dict[str, Any]]) -> None: ...
+    async def optimize(
+        self,
+        *,
+        cleanup_since_ms: Optional[int] = None,
+        delete_unverified: Optional[bool] = None,
+    ) -> OptimizeStats: ...
     def query(self) -> Query: ...
     def vector_search(self) -> VectorQuery: ...

View File

@@ -603,7 +603,7 @@ class AsyncConnection(object):
         fill_value: Optional[float] = None,
         storage_options: Optional[Dict[str, str]] = None,
         *,
-        embedding_functions: List[EmbeddingFunctionConfig] = None,
+        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
         data_storage_version: Optional[str] = None,
         use_legacy_format: Optional[bool] = None,
         enable_v2_manifest_paths: Optional[bool] = None,

View File

@@ -1,20 +1,10 @@
-# Copyright 2023 LanceDB Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors

 """Full text search index using tantivy-py"""

 import os
-from typing import List, Tuple
+from typing import List, Tuple, Optional

 import pyarrow as pa
@@ -31,7 +21,7 @@ from .table import LanceTable
 def create_index(
     index_path: str,
     text_fields: List[str],
-    ordering_fields: List[str] = None,
+    ordering_fields: Optional[List[str]] = None,
     tokenizer_name: str = "default",
 ) -> tantivy.Index:
     """
@@ -75,8 +65,8 @@ def populate_index(
     index: tantivy.Index,
     table: LanceTable,
     fields: List[str],
-    writer_heap_size: int = 1024 * 1024 * 1024,
-    ordering_fields: List[str] = None,
+    writer_heap_size: Optional[int] = None,
+    ordering_fields: Optional[List[str]] = None,
 ) -> int:
     """
     Populate an index with data from a LanceTable
@@ -99,6 +89,7 @@ def populate_index(
     """
     if ordering_fields is None:
         ordering_fields = []
+    writer_heap_size = writer_heap_size or 1024 * 1024 * 1024
     # first check the fields exist and are string or large string type
     nested = []

View File

@@ -115,6 +115,9 @@ class Query(pydantic.BaseModel):
     # e.g. `{"nprobes": "10", "refine_factor": "10"}`
     nprobes: int = 10

+    lower_bound: Optional[float] = None
+    upper_bound: Optional[float] = None
+
     # Refine factor.
     refine_factor: Optional[int] = None
@@ -604,6 +607,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         self._query = query
         self._metric = "L2"
         self._nprobes = 20
+        self._lower_bound = None
+        self._upper_bound = None
         self._refine_factor = None
         self._vector_column = vector_column
         self._prefilter = False
@@ -649,6 +654,30 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         self._nprobes = nprobes
         return self

+    def distance_range(
+        self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
+    ) -> LanceVectorQueryBuilder:
+        """Set the distance range to use.
+
+        Only rows with distances within range [lower_bound, upper_bound)
+        will be returned.
+
+        Parameters
+        ----------
+        lower_bound: Optional[float]
+            The lower bound of the distance range.
+        upper_bound: Optional[float]
+            The upper bound of the distance range.
+
+        Returns
+        -------
+        LanceVectorQueryBuilder
+            The LanceQueryBuilder object.
+        """
+        self._lower_bound = lower_bound
+        self._upper_bound = upper_bound
+        return self
+
     def ef(self, ef: int) -> LanceVectorQueryBuilder:
         """Set the number of candidates to consider during search.
@@ -728,6 +757,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
             metric=self._metric,
             columns=self._columns,
             nprobes=self._nprobes,
+            lower_bound=self._lower_bound,
+            upper_bound=self._upper_bound,
             refine_factor=self._refine_factor,
             vector_column=self._vector_column,
             with_row_id=self._with_row_id,
@@ -1284,6 +1315,31 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
         self._nprobes = nprobes
         return self

+    def distance_range(
+        self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
+    ) -> LanceHybridQueryBuilder:
+        """
+        Set the distance range to use.
+
+        Only rows with distances within range [lower_bound, upper_bound)
+        will be returned.
+
+        Parameters
+        ----------
+        lower_bound: Optional[float]
+            The lower bound of the distance range.
+        upper_bound: Optional[float]
+            The upper bound of the distance range.
+
+        Returns
+        -------
+        LanceHybridQueryBuilder
+            The LanceHybridQueryBuilder object.
+        """
+        self._lower_bound = lower_bound
+        self._upper_bound = upper_bound
+        return self
+
     def ef(self, ef: int) -> LanceHybridQueryBuilder:
         """
         Set the number of candidates to consider during search.
@@ -1855,6 +1911,29 @@ class AsyncVectorQuery(AsyncQueryBase):
         self._inner.nprobes(nprobes)
         return self

+    def distance_range(
+        self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
+    ) -> AsyncVectorQuery:
+        """Set the distance range to use.
+
+        Only rows with distances within range [lower_bound, upper_bound)
+        will be returned.
+
+        Parameters
+        ----------
+        lower_bound: Optional[float]
+            The lower bound of the distance range.
+        upper_bound: Optional[float]
+            The upper bound of the distance range.
+
+        Returns
+        -------
+        AsyncVectorQuery
+            The AsyncVectorQuery object.
+        """
+        self._inner.distance_range(lower_bound, upper_bound)
+        return self
+
     def ef(self, ef: int) -> AsyncVectorQuery:
         """
         Set the number of candidates to consider during search
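The new `distance_range` builder composes with the rest of the sync query chain. A minimal usage sketch, assuming a local database at `./demo-db` and a table `my_vectors` with a 2-dimensional vector column (both names are hypothetical):

```python
import lancedb

db = lancedb.connect("./demo-db")   # hypothetical path
tbl = db.open_table("my_vectors")   # hypothetical table

# Keep only hits whose distance d satisfies 0.1 <= d < 0.5.
# Both bounds are optional: passing only upper_bound caps the distance,
# passing only lower_bound skips the nearest neighbors.
results = (
    tbl.search([0.0, 0.0])
    .distance_range(lower_bound=0.1, upper_bound=0.5)
    .limit(10)
    .to_arrow()
)
print(results["_distance"].to_pylist())
```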

View File

@@ -61,11 +61,12 @@ from .index import lang_mapping

 if TYPE_CHECKING:
-    import PIL
-    from lance.dataset import CleanupStats, ReaderLike
-    from ._lancedb import Table as LanceDBTable, OptimizeStats
+    from ._lancedb import Table as LanceDBTable, OptimizeStats, CompactionStats
     from .db import LanceDBConnection
     from .index import IndexConfig
+    from lance.dataset import CleanupStats, ReaderLike
+    import pandas
+    import PIL

 pd = safe_import_pandas()
 pl = safe_import_polars()
@@ -84,7 +85,6 @@ def _pd_schema_without_embedding_funcs(
     )
     if not embedding_functions:
         return schema
-    columns = set(columns)
     return pa.schema([field for field in schema if field.name in columns])
@@ -119,7 +119,7 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
             return pa.Table.from_batches(data, schema=schema)
         else:
             return pa.Table.from_pylist(data, schema=schema)
-    elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
+    elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):  # type: ignore
         raw_schema = _pd_schema_without_embedding_funcs(schema, data.columns.to_list())
         table = pa.Table.from_pandas(data, preserve_index=False, schema=raw_schema)
         # Do not serialize Pandas metadata
@@ -160,7 +160,7 @@ def _sanitize_data(
     metadata: Optional[dict] = None,  # embedding metadata
     on_bad_vectors: str = "error",
     fill_value: float = 0.0,
-):
+) -> Tuple[pa.Table, pa.Schema]:
     data = _coerce_to_table(data, schema)

     if metadata:
@@ -178,13 +178,17 @@

 def sanitize_create_table(
-    data, schema, metadata=None, on_bad_vectors="error", fill_value=0.0
+    data,
+    schema: Union[pa.Schema, LanceModel],
+    metadata=None,
+    on_bad_vectors: str = "error",
+    fill_value: float = 0.0,
 ):
     if inspect.isclass(schema) and issubclass(schema, LanceModel):
         # convert LanceModel to pyarrow schema
         # note that it's possible this contains
         # embedding function metadata already
-        schema = schema.to_arrow_schema()
+        schema: pa.Schema = schema.to_arrow_schema()

     if data is not None:
         if metadata is None and schema is not None:
@@ -272,41 +276,6 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
     return data

-def _generator_to_data_and_schema(
-    data: Iterable,
-) -> Tuple[Iterable[pa.RecordBatch], pa.Schema]:
-    def _with_first_generator(first, data):
-        yield first
-        yield from data
-
-    first = next(data, None)
-    schema = None
-    if isinstance(first, pa.RecordBatch):
-        schema = first.schema
-        data = _with_first_generator(first, data)
-    elif isinstance(first, pa.Table):
-        schema = first.schema
-        data = _with_first_generator(first.to_batches(), data)
-    return data, schema
-
-def _to_record_batch_generator(
-    data: Iterable,
-    schema,
-    metadata,
-    on_bad_vectors,
-    fill_value,
-):
-    for batch in data:
-        # always convert to table because we need to sanitize the data
-        # and do things like add the vector column etc
-        if isinstance(batch, pa.RecordBatch):
-            batch = pa.Table.from_batches([batch])
-        batch, _ = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
-        for b in batch.to_batches():
-            yield b
-
 def _table_path(base: str, table_name: str) -> str:
     """
     Get a table path that can be used in PyArrow FS.
@@ -404,7 +373,7 @@ class Table(ABC):
         """
         raise NotImplementedError

-    def to_pandas(self) -> "pd.DataFrame":
+    def to_pandas(self) -> "pandas.DataFrame":
         """Return the table as a pandas DataFrame.

         Returns
@@ -537,8 +506,8 @@ class Table(ABC):
     def create_fts_index(
         self,
         field_names: Union[str, List[str]],
-        ordering_field_names: Union[str, List[str]] = None,
         *,
+        ordering_field_names: Optional[Union[str, List[str]]] = None,
         replace: bool = False,
         writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
         use_tantivy: bool = True,
@@ -790,8 +759,7 @@
     @abstractmethod
     def _execute_query(
         self, query: Query, batch_size: Optional[int] = None
-    ) -> pa.RecordBatchReader:
-        pass
+    ) -> pa.RecordBatchReader: ...

     @abstractmethod
     def _do_merge(
@@ -800,8 +768,7 @@
         new_data: DATA,
         on_bad_vectors: str,
         fill_value: float,
-    ):
-        pass
+    ): ...

     @abstractmethod
     def delete(self, where: str):
@@ -1121,7 +1088,7 @@
         """

     @abstractmethod
-    def list_versions(self):
+    def list_versions(self) -> List[Dict[str, Any]]:
         """List all versions of the table"""

     @cached_property
@@ -1244,7 +1211,7 @@ class LanceTable(Table):
         A PyArrow schema object."""
         return LOOP.run(self._table.schema())

-    def list_versions(self):
+    def list_versions(self) -> List[Dict[str, Any]]:
         """List all versions of the table"""
         return LOOP.run(self._table.list_versions())
@@ -1297,7 +1264,7 @@
         """
         LOOP.run(self._table.checkout_latest())

-    def restore(self, version: int = None):
+    def restore(self, version: Optional[int] = None):
         """Restore a version of the table. This is an in-place operation.

         This creates a new version where the data is equivalent to the
@@ -1338,7 +1305,7 @@
     def count_rows(self, filter: Optional[str] = None) -> int:
         return LOOP.run(self._table.count_rows(filter))

-    def __len__(self):
+    def __len__(self) -> int:
         return self.count_rows()

     def __repr__(self) -> str:
@@ -1506,8 +1473,8 @@
     def create_fts_index(
         self,
         field_names: Union[str, List[str]],
-        ordering_field_names: Union[str, List[str]] = None,
         *,
+        ordering_field_names: Optional[Union[str, List[str]]] = None,
         replace: bool = False,
         writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
         use_tantivy: bool = True,
@@ -1594,6 +1561,7 @@
             writer_heap_size=writer_heap_size,
         )

+    @staticmethod
     def infer_tokenizer_configs(tokenizer_name: str) -> dict:
         if tokenizer_name == "default":
             return {
@@ -1759,7 +1727,7 @@
         )

     @overload
-    def search(
+    def search(  # type: ignore
         self,
         query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
         vector_column_name: Optional[str] = None,
@@ -1895,11 +1863,11 @@
         name: str,
         data: Optional[DATA] = None,
         schema: Optional[pa.Schema] = None,
-        mode: Literal["create", "overwrite", "append"] = "create",
+        mode: Literal["create", "overwrite"] = "create",
         exist_ok: bool = False,
         on_bad_vectors: str = "error",
         fill_value: float = 0.0,
-        embedding_functions: List[EmbeddingFunctionConfig] = None,
+        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
         *,
         storage_options: Optional[Dict[str, str]] = None,
         data_storage_version: Optional[str] = None,
@@ -2065,7 +2033,7 @@
             older_than, delete_unverified=delete_unverified
         )

-    def compact_files(self, *args, **kwargs):
+    def compact_files(self, *args, **kwargs) -> CompactionStats:
        """
        Run the compaction process on the table.
@@ -2450,7 +2418,7 @@ def _process_iterator(data: Iterable, schema: Optional[pa.Schema] = None) -> pa.
         if batch_table.schema != schema:
             try:
                 batch_table = batch_table.cast(schema)
-            except pa.lib.ArrowInvalid:
+            except pa.lib.ArrowInvalid:  # type: ignore
                 raise ValueError(
                     f"Input iterator yielded a batch with schema that "
                     f"does not match the expected schema.\nExpected:\n{schema}\n"
@@ -2710,16 +2678,17 @@ class AsyncTable:
             on_bad_vectors = "error"
         if fill_value is None:
             fill_value = 0.0
-        data, _ = _sanitize_data(
+        table_and_schema: Tuple[pa.Table, pa.Schema] = _sanitize_data(
             data,
             schema,
             metadata=schema.metadata,
             on_bad_vectors=on_bad_vectors,
             fill_value=fill_value,
         )
-        if isinstance(data, pa.Table):
-            data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
-        await self._inner.add(data, mode)
+        tbl, schema = table_and_schema
+        if isinstance(tbl, pa.Table):
+            data = pa.RecordBatchReader.from_batches(schema, tbl.to_batches())
+        await self._inner.add(data, mode or "append")

     def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
         """
@@ -2817,6 +2786,7 @@
                 async_query.nearest_to(query.vector)
                 .distance_type(query.metric)
                 .nprobes(query.nprobes)
+                .distance_range(query.lower_bound, query.upper_bound)
             )
             if query.refine_factor:
                 async_query = async_query.refine_factor(query.refine_factor)
@@ -2977,7 +2947,7 @@
         return await self._inner.update(updates_sql, where)

-    async def add_columns(self, transforms: Dict[str, str]):
+    async def add_columns(self, transforms: dict[str, str]):
         """
         Add new columns with defined values.
@@ -2990,7 +2960,7 @@
         """
         await self._inner.add_columns(list(transforms.items()))

-    async def alter_columns(self, *alterations: Iterable[Dict[str, str]]):
+    async def alter_columns(self, *alterations: Iterable[dict[str, Any]]):
         """
         Alter column names and nullability.
@@ -3148,9 +3118,12 @@
         you have added or modified 100,000 or more records or run more than 20 data
         modification operations.
         """
+        cleanup_since_ms: Optional[int] = None
         if cleanup_older_than is not None:
-            cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000)
-        return await self._inner.optimize(cleanup_older_than, delete_unverified)
+            cleanup_since_ms = round(cleanup_older_than.total_seconds() * 1000)
+        return await self._inner.optimize(
+            cleanup_since_ms=cleanup_since_ms, delete_unverified=delete_unverified
+        )

     async def list_indices(self) -> Iterable[IndexConfig]:
         """

View File

@@ -167,8 +167,24 @@ def test_search_index(tmp_path, table):

 @pytest.mark.parametrize("use_tantivy", [True, False])
 def test_search_fts(table, use_tantivy):
     table.create_fts_index("text", use_tantivy=use_tantivy)
-    results = table.search("puppy").limit(5).to_list()
+    results = table.search("puppy").select(["id", "text"]).limit(5).to_list()
     assert len(results) == 5
+    assert len(results[0]) == 3  # id, text, _score
+
+
+@pytest.mark.asyncio
+async def test_fts_select_async(async_table):
+    tbl = await async_table
+    await tbl.create_index("text", config=FTS())
+    results = (
+        await tbl.query()
+        .nearest_to_text("puppy")
+        .select(["id", "text"])
+        .limit(5)
+        .to_list()
+    )
+    assert len(results) == 5
+    assert len(results[0]) == 3  # id, text, _score

 def test_search_fts_phrase_query(table):

View File

@@ -94,6 +94,73 @@ def test_with_row_id(table: lancedb.table.Table):
     assert rs["_rowid"].to_pylist() == [0, 1]

+def test_distance_range(table: lancedb.table.Table):
+    q = [0, 0]
+    rs = table.search(q).to_arrow()
+    dists = rs["_distance"].to_pylist()
+    min_dist = dists[0]
+    max_dist = dists[-1]
+
+    res = table.search(q).distance_range(upper_bound=min_dist).to_arrow()
+    assert len(res) == 0
+
+    res = table.search(q).distance_range(lower_bound=max_dist).to_arrow()
+    assert len(res) == 1
+    assert res["_distance"].to_pylist() == [max_dist]
+
+    res = table.search(q).distance_range(upper_bound=max_dist).to_arrow()
+    assert len(res) == 1
+    assert res["_distance"].to_pylist() == [min_dist]
+
+    res = table.search(q).distance_range(lower_bound=min_dist).to_arrow()
+    assert len(res) == 2
+    assert res["_distance"].to_pylist() == [min_dist, max_dist]
+
+
+@pytest.mark.asyncio
+async def test_distance_range_async(table_async: AsyncTable):
+    q = [0, 0]
+    rs = await table_async.query().nearest_to(q).to_arrow()
+    dists = rs["_distance"].to_pylist()
+    min_dist = dists[0]
+    max_dist = dists[-1]
+
+    res = (
+        await table_async.query()
+        .nearest_to(q)
+        .distance_range(upper_bound=min_dist)
+        .to_arrow()
+    )
+    assert len(res) == 0
+
+    res = (
+        await table_async.query()
+        .nearest_to(q)
+        .distance_range(lower_bound=max_dist)
+        .to_arrow()
+    )
+    assert len(res) == 1
+    assert res["_distance"].to_pylist() == [max_dist]
+
+    res = (
+        await table_async.query()
+        .nearest_to(q)
+        .distance_range(upper_bound=max_dist)
+        .to_arrow()
+    )
+    assert len(res) == 1
+    assert res["_distance"].to_pylist() == [min_dist]
+
+    res = (
+        await table_async.query()
+        .nearest_to(q)
+        .distance_range(lower_bound=min_dist)
+        .to_arrow()
+    )
+    assert len(res) == 2
+    assert res["_distance"].to_pylist() == [min_dist, max_dist]
+
 def test_vector_query_with_no_limit(table):
     with pytest.raises(ValueError):
         LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select(

View File

@@ -306,6 +306,8 @@ def test_query_sync_minimal():
         "k": 10,
         "prefilter": False,
         "refine_factor": None,
+        "lower_bound": None,
+        "upper_bound": None,
         "ef": None,
         "vector": [1.0, 2.0, 3.0],
         "nprobes": 20,
@@ -348,6 +350,8 @@ def test_query_sync_maximal():
         "refine_factor": 10,
         "vector": [1.0, 2.0, 3.0],
         "nprobes": 5,
+        "lower_bound": None,
+        "upper_bound": None,
         "ef": None,
         "filter": "id > 0",
         "columns": ["id", "name"],
@@ -449,6 +453,8 @@ def test_query_sync_hybrid():
         "refine_factor": None,
         "vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
         "nprobes": 20,
+        "lower_bound": None,
+        "upper_bound": None,
         "ef": None,
         "with_row_id": True,
         "version": None,

View File

@@ -152,6 +152,10 @@ impl FTSQuery {
         self.inner = self.inner.clone().select(Select::dynamic(&columns));
     }

+    pub fn select_columns(&mut self, columns: Vec<String>) {
+        self.inner = self.inner.clone().select(Select::columns(&columns));
+    }
+
     pub fn limit(&mut self, limit: u32) {
         self.inner = self.inner.clone().limit(limit as usize);
     }
@@ -280,6 +284,11 @@ impl VectorQuery {
         self.inner = self.inner.clone().nprobes(nprobe as usize);
     }

+    #[pyo3(signature = (lower_bound=None, upper_bound=None))]
+    pub fn distance_range(&mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) {
+        self.inner = self.inner.clone().distance_range(lower_bound, upper_bound);
+    }
+
     pub fn ef(&mut self, ef: u32) {
         self.inner = self.inner.clone().ef(ef as usize);
     }
@@ -341,6 +350,11 @@ impl HybridQuery {
         self.inner_fts.select(columns);
     }

+    pub fn select_columns(&mut self, columns: Vec<String>) {
+        self.inner_vec.select_columns(columns.clone());
+        self.inner_fts.select_columns(columns);
+    }
+
     pub fn limit(&mut self, limit: u32) {
         self.inner_vec.limit(limit);
         self.inner_fts.limit(limit);
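These `select_columns` methods are what commit 445a312667 adds so that column selection on FTS and hybrid queries stops raising `AttributeError: 'builtins.FTSQuery' object has no attribute 'select_columns'`. A minimal Python-side sketch, assuming a hypothetical table `my_docs` with an FTS index on its `text` column:

```python
import lancedb

db = lancedb.connect("./demo-db")  # hypothetical path
tbl = db.open_table("my_docs")     # hypothetical table
tbl.create_fts_index("text", use_tantivy=False)

# Selecting a subset of columns on an FTS query previously failed with the
# AttributeError above; with select_columns in place it works as expected.
rows = (
    tbl.search("puppy", query_type="fts")
    .select(["id", "text"])
    .limit(5)
    .to_list()
)
```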

View File

@@ -97,10 +97,12 @@ impl Table {
         self.name.clone()
     }

+    /// Returns True if the table is open, False if it is closed.
     pub fn is_open(&self) -> bool {
         self.inner.is_some()
     }

+    /// Closes the table, releasing any resources associated with it.
     pub fn close(&mut self) {
         self.inner.take();
     }
@@ -301,6 +303,7 @@ impl Table {
         Query::new(self.inner_ref().unwrap().query())
     }

+    /// Optimize the on-disk data by compacting and pruning old data, for better performance.
     #[pyo3(signature = (cleanup_since_ms=None, delete_unverified=None))]
     pub fn optimize(
         self_: PyRef<'_, Self>,

View File

@@ -755,6 +755,10 @@ pub struct VectorQuery {
     // IVF PQ - ANN search.
     pub(crate) query_vector: Vec<Arc<dyn Array>>,
     pub(crate) nprobes: usize,
+    // The lower bound (inclusive) of the distance to search for.
+    pub(crate) lower_bound: Option<f32>,
+    // The upper bound (exclusive) of the distance to search for.
+    pub(crate) upper_bound: Option<f32>,
     // The number of candidates to return during the refine step for HNSW,
     // defaults to 1.5 * limit.
     pub(crate) ef: Option<usize>,
@@ -771,6 +775,8 @@ impl VectorQuery {
             column: None,
             query_vector: Vec::new(),
             nprobes: 20,
+            lower_bound: None,
+            upper_bound: None,
             ef: None,
             refine_factor: None,
             distance_type: None,
@@ -831,6 +837,14 @@ impl VectorQuery {
         self
     }

+    /// Set the distance range for vector search;
+    /// only rows with distances in the range [lower_bound, upper_bound) will be returned.
+    pub fn distance_range(mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) -> Self {
+        self.lower_bound = lower_bound;
+        self.upper_bound = upper_bound;
+        self
+    }
+
     /// Set the number of candidates to return during the refine step for HNSW
     ///
     /// This argument is only used when the vector column has an HNSW index.
@@ -1350,6 +1364,30 @@ mod tests {
         }
     }

+    #[tokio::test]
+    async fn test_distance_range() {
+        let tmp_dir = tempdir().unwrap();
+        let table = make_test_table(&tmp_dir).await;
+        let results = table
+            .vector_search(&[0.1, 0.2, 0.3, 0.4])
+            .unwrap()
+            .distance_range(Some(0.0), Some(1.0))
+            .limit(10)
+            .execute()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+        for batch in results {
+            let distances = batch["_distance"].as_primitive::<Float32Type>();
+            assert!(distances.iter().all(|d| {
+                let d = d.unwrap();
+                (0.0..1.0).contains(&d)
+            }));
+        }
+    }
+
     #[tokio::test]
     async fn test_multiple_query_vectors() {
         let tmp_dir = tempdir().unwrap();

View File

@@ -210,6 +210,8 @@ impl<S: HttpSend> RemoteTable<S> {
         body["prefilter"] = query.base.prefilter.into();
         body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
         body["nprobes"] = query.nprobes.into();
+        body["lower_bound"] = query.lower_bound.into();
+        body["upper_bound"] = query.upper_bound.into();
         body["ef"] = query.ef.into();
         body["refine_factor"] = query.refine_factor.into();
         if let Some(vector_column) = query.column.as_ref() {
@@ -1304,6 +1306,8 @@ mod tests {
             "prefilter": true,
             "distance_type": "l2",
             "nprobes": 20,
+            "lower_bound": Option::<f32>::None,
+            "upper_bound": Option::<f32>::None,
             "k": 10,
             "ef": Option::<usize>::None,
             "refine_factor": null,
@@ -1353,6 +1357,8 @@ mod tests {
             "bypass_vector_index": true,
             "columns": ["a", "b"],
             "nprobes": 12,
+            "lower_bound": Option::<f32>::None,
+            "upper_bound": Option::<f32>::None,
             "ef": Option::<usize>::None,
             "refine_factor": 2,
             "version": null,

View File

@@ -1944,6 +1944,7 @@ impl TableInternal for NativeTable {
         if let Some(ef) = query.ef {
             scanner.ef(ef);
         }
+        scanner.distance_range(query.lower_bound, query.upper_bound);
         scanner.use_index(query.use_index);
         scanner.prefilter(query.base.prefilter);
         match query.base.select {