mirror of https://github.com/lancedb/lancedb.git
synced 2025-12-24 13:59:58 +00:00

Compare commits: python-v0.… ... python-v0.…

5 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | a27c5cf12b |  |
|  | f4dea72cc5 |  |
|  | f76c4a5ce1 |  |
|  | 164ce397c2 |  |
|  | 445a312667 |  |
.github/workflows/python.yml (vendored): 4 changed lines

```diff
@@ -30,10 +30,10 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: "3.11"
+          python-version: "3.12"
       - name: Install ruff
         run: |
-          pip install ruff==0.5.4
+          pip install ruff==0.8.4
       - name: Format check
         run: ruff format --check .
       - name: Lint
```
Cargo.toml: 16 changed lines

```diff
@@ -23,14 +23,14 @@ rust-version = "1.78.0"
 [workspace.dependencies]
 lance = { "version" = "=0.21.1", "features" = [
   "dynamodb",
-], git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
-lance-io = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
-lance-index = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
-lance-linalg = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
-lance-table = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
-lance-testing = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
-lance-datafusion = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
-lance-encoding = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.1" }
+], git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-io = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-index = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-linalg = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-table = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-testing = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-datafusion = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
+lance-encoding = { version = "=0.21.1", git = "https://github.com/lancedb/lance.git", tag = "v0.21.1-beta.2" }
 # Note that this one does not include pyarrow
 arrow = { version = "53.2", optional = false }
 arrow-array = "53.2"
```
```diff
@@ -50,7 +50,7 @@ Consider that we have a LanceDB table named `my_table`, whose string column `tex…
 });

 await tbl
-  .search("puppy", queryType="fts")
+  .search("puppy", "fts")
   .select(["text"])
   .limit(10)
   .toArray();
```
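For reference, the same query on the sync Python API passes the query type as a keyword; a rough equivalent sketch (assumes `tbl` is an open table with an FTS index on its `text` column):

```python
# Sketch of the equivalent sync Python full-text query:
results = (
    tbl.search("puppy", query_type="fts")
    .select(["text"])
    .limit(10)
    .to_list()
)
```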
```diff
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.17.2-beta.1"
+current_version = "0.17.2-beta.2"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.
```
```diff
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-python"
-version = "0.17.2-beta.1"
+version = "0.17.2-beta.2"
 edition.workspace = true
 description = "Python bindings for LanceDB"
 license.workspace = true
```
```diff
@@ -53,8 +53,9 @@ tests = [
     "pytz",
     "polars>=0.19, <=1.3.0",
     "tantivy",
+    "pyarrow-stubs"
 ]
-dev = ["ruff", "pre-commit"]
+dev = ["ruff", "pre-commit", "pyright"]
 docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
 clip = ["torch", "pillow", "open-clip"]
 embeddings = [
@@ -94,3 +95,7 @@ markers = [
     "asyncio",
     "s3_test",
 ]
+
+[tool.pyright]
+include = ["python/lancedb/table.py"]
+pythonVersion = "3.12"
```
```diff
@@ -1,7 +1,9 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Any, Union, Literal

 import pyarrow as pa

+from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
+
 class Connection(object):
     uri: str
     async def table_names(
@@ -31,16 +33,35 @@ class Connection(object):
 class Table:
     def name(self) -> str: ...
     def __repr__(self) -> str: ...
     def is_open(self) -> bool: ...
     def close(self) -> None: ...
     async def schema(self) -> pa.Schema: ...
-    async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ...
+    async def add(
+        self, data: pa.RecordBatchReader, mode: Literal["append", "overwrite"]
+    ) -> None: ...
     async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ...
     async def count_rows(self, filter: Optional[str]) -> int: ...
-    async def create_index(self, column: str, config, replace: Optional[bool]): ...
+    async def create_index(
+        self,
+        column: str,
+        index: Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS],
+        replace: Optional[bool],
+    ): ...
+    async def list_versions(self) -> List[Dict[str, Any]]: ...
     async def version(self) -> int: ...
     async def checkout(self, version: int): ...
     async def checkout_latest(self): ...
     async def restore(self): ...
-    async def list_indices(self) -> List[IndexConfig]: ...
+    async def list_indices(self) -> list[IndexConfig]: ...
     async def delete(self, filter: str): ...
+    async def add_columns(self, columns: list[tuple[str, str]]) -> None: ...
+    async def alter_columns(self, columns: list[dict[str, Any]]) -> None: ...
     async def optimize(
         self,
         *,
         cleanup_since_ms: Optional[int] = None,
         delete_unverified: Optional[bool] = None,
     ) -> OptimizeStats: ...
     def query(self) -> Query: ...
     def vector_search(self) -> VectorQuery: ...
```
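Together, the tightened stubs describe calls like the following sketch against the public async API (the URI and table name are hypothetical; the `config=FTS()` form matches the async test later in this diff):

```python
import asyncio

import lancedb
from lancedb.index import FTS


async def main():
    db = await lancedb.connect_async("data/sample-lancedb")  # hypothetical path
    tbl = await db.open_table("my_table")
    # `mode` is now typed as Literal["append", "overwrite"]:
    await tbl.add([{"id": 1, "text": "puppy"}], mode="append")
    # `create_index` now takes a typed index config instead of an untyped `config`:
    await tbl.create_index("text", config=FTS(), replace=True)


asyncio.run(main())
```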
```diff
@@ -603,7 +603,7 @@ class AsyncConnection(object):
         fill_value: Optional[float] = None,
         storage_options: Optional[Dict[str, str]] = None,
         *,
-        embedding_functions: List[EmbeddingFunctionConfig] = None,
+        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
         data_storage_version: Optional[str] = None,
         use_legacy_format: Optional[bool] = None,
         enable_v2_manifest_paths: Optional[bool] = None,
```
```diff
@@ -1,20 +1,10 @@
-# Copyright 2023 LanceDB Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors

 """Full text search index using tantivy-py"""

 import os
-from typing import List, Tuple
+from typing import List, Tuple, Optional

 import pyarrow as pa

@@ -31,7 +21,7 @@ from .table import LanceTable
 def create_index(
     index_path: str,
     text_fields: List[str],
-    ordering_fields: List[str] = None,
+    ordering_fields: Optional[List[str]] = None,
     tokenizer_name: str = "default",
 ) -> tantivy.Index:
     """
@@ -75,8 +65,8 @@ def populate_index(
     index: tantivy.Index,
     table: LanceTable,
     fields: List[str],
-    writer_heap_size: int = 1024 * 1024 * 1024,
-    ordering_fields: List[str] = None,
+    writer_heap_size: Optional[int] = None,
+    ordering_fields: Optional[List[str]] = None,
 ) -> int:
     """
     Populate an index with data from a LanceTable
@@ -99,6 +89,7 @@ def populate_index(
     """
     if ordering_fields is None:
         ordering_fields = []
+    writer_heap_size = writer_heap_size or 1024 * 1024 * 1024
     # first check the fields exist and are string or large string type
     nested = []
```
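With the new signatures, both helpers can be called without the optional arguments, and the 1 GiB heap default is applied inside `populate_index` instead of in the signature. A sketch of driving these tantivy helpers directly (the index path is hypothetical; these are internal helpers normally reached via `create_fts_index`):

```python
from lancedb.fts import create_index, populate_index

# `ordering_fields` now defaults to None rather than an unhinted bare None,
# and `writer_heap_size=None` falls back to 1024 * 1024 * 1024 internally:
index = create_index("/tmp/fts_index", ["text"])  # hypothetical path
row_count = populate_index(index, table, ["text"])
```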
```diff
@@ -115,6 +115,9 @@ class Query(pydantic.BaseModel):
     # e.g. `{"nprobes": "10", "refine_factor": "10"}`
     nprobes: int = 10

+    lower_bound: Optional[float] = None
+    upper_bound: Optional[float] = None
+
     # Refine factor.
     refine_factor: Optional[int] = None

@@ -604,6 +607,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         self._query = query
         self._metric = "L2"
         self._nprobes = 20
+        self._lower_bound = None
+        self._upper_bound = None
         self._refine_factor = None
         self._vector_column = vector_column
         self._prefilter = False
@@ -649,6 +654,30 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         self._nprobes = nprobes
         return self

+    def distance_range(
+        self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
+    ) -> LanceVectorQueryBuilder:
+        """Set the distance range to use.
+
+        Only rows with distances within range [lower_bound, upper_bound)
+        will be returned.
+
+        Parameters
+        ----------
+        lower_bound: Optional[float]
+            The lower bound of the distance range.
+        upper_bound: Optional[float]
+            The upper bound of the distance range.
+
+        Returns
+        -------
+        LanceVectorQueryBuilder
+            The LanceQueryBuilder object.
+        """
+        self._lower_bound = lower_bound
+        self._upper_bound = upper_bound
+        return self
+
     def ef(self, ef: int) -> LanceVectorQueryBuilder:
         """Set the number of candidates to consider during search.

@@ -728,6 +757,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
             metric=self._metric,
             columns=self._columns,
             nprobes=self._nprobes,
+            lower_bound=self._lower_bound,
+            upper_bound=self._upper_bound,
             refine_factor=self._refine_factor,
             vector_column=self._vector_column,
             with_row_id=self._with_row_id,
```
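Usage of the new builder method on the sync path, mirroring `test_distance_range` further down in this diff (a sketch; assumes `table` has a 2-d vector column):

```python
# Only rows whose _distance falls in [0.1, 0.5) come back:
hits = (
    table.search([0.0, 0.0])
    .distance_range(lower_bound=0.1, upper_bound=0.5)
    .to_arrow()
)
assert all(0.1 <= d < 0.5 for d in hits["_distance"].to_pylist())
```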
```diff
@@ -1284,6 +1315,31 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
         self._nprobes = nprobes
         return self

+    def distance_range(
+        self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
+    ) -> LanceHybridQueryBuilder:
+        """
+        Set the distance range to use.
+
+        Only rows with distances within range [lower_bound, upper_bound)
+        will be returned.
+
+        Parameters
+        ----------
+        lower_bound: Optional[float]
+            The lower bound of the distance range.
+        upper_bound: Optional[float]
+            The upper bound of the distance range.
+
+        Returns
+        -------
+        LanceHybridQueryBuilder
+            The LanceHybridQueryBuilder object.
+        """
+        self._lower_bound = lower_bound
+        self._upper_bound = upper_bound
+        return self
+
     def ef(self, ef: int) -> LanceHybridQueryBuilder:
         """
         Set the number of candidates to consider during search.

@@ -1855,6 +1911,29 @@ class AsyncVectorQuery(AsyncQueryBase):
         self._inner.nprobes(nprobes)
         return self

+    def distance_range(
+        self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None
+    ) -> AsyncVectorQuery:
+        """Set the distance range to use.
+
+        Only rows with distances within range [lower_bound, upper_bound)
+        will be returned.
+
+        Parameters
+        ----------
+        lower_bound: Optional[float]
+            The lower bound of the distance range.
+        upper_bound: Optional[float]
+            The upper bound of the distance range.
+
+        Returns
+        -------
+        AsyncVectorQuery
+            The AsyncVectorQuery object.
+        """
+        self._inner.distance_range(lower_bound, upper_bound)
+        return self
+
     def ef(self, ef: int) -> AsyncVectorQuery:
         """
         Set the number of candidates to consider during search
```
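The async builder delegates straight to the native `distance_range` binding; the pattern from `test_distance_range_async` below looks like this (sketch; run inside an async function with `table_async` an open `AsyncTable`):

```python
res = (
    await table_async.query()
    .nearest_to([0.0, 0.0])
    .distance_range(lower_bound=0.1, upper_bound=0.5)
    .to_arrow()
)
```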
```diff
@@ -61,11 +61,12 @@ from .index import lang_mapping


 if TYPE_CHECKING:
-    import PIL
-    from lance.dataset import CleanupStats, ReaderLike
-    from ._lancedb import Table as LanceDBTable, OptimizeStats
+    from ._lancedb import Table as LanceDBTable, OptimizeStats, CompactionStats
     from .db import LanceDBConnection
     from .index import IndexConfig
+    from lance.dataset import CleanupStats, ReaderLike
+    import pandas
+    import PIL

 pd = safe_import_pandas()
 pl = safe_import_polars()
@@ -84,7 +85,6 @@ def _pd_schema_without_embedding_funcs(
     )
     if not embedding_functions:
         return schema
-    columns = set(columns)
     return pa.schema([field for field in schema if field.name in columns])
```
```diff
@@ -119,7 +119,7 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
             return pa.Table.from_batches(data, schema=schema)
         else:
             return pa.Table.from_pylist(data, schema=schema)
-    elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
+    elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):  # type: ignore
         raw_schema = _pd_schema_without_embedding_funcs(schema, data.columns.to_list())
         table = pa.Table.from_pandas(data, preserve_index=False, schema=raw_schema)
         # Do not serialize Pandas metadata
@@ -160,7 +160,7 @@ def _sanitize_data(
     metadata: Optional[dict] = None,  # embedding metadata
     on_bad_vectors: str = "error",
     fill_value: float = 0.0,
-):
+) -> Tuple[pa.Table, pa.Schema]:
     data = _coerce_to_table(data, schema)

     if metadata:
@@ -178,13 +178,17 @@ def _sanitize_data(
 def sanitize_create_table(
-    data, schema, metadata=None, on_bad_vectors="error", fill_value=0.0
+    data,
+    schema: Union[pa.Schema, LanceModel],
+    metadata=None,
+    on_bad_vectors: str = "error",
+    fill_value: float = 0.0,
 ):
     if inspect.isclass(schema) and issubclass(schema, LanceModel):
         # convert LanceModel to pyarrow schema
         # note that it's possible this contains
         # embedding function metadata already
-        schema = schema.to_arrow_schema()
+        schema: pa.Schema = schema.to_arrow_schema()

     if data is not None:
         if metadata is None and schema is not None:
```
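The widened `schema` parameter matches how `create_table` is commonly called with a `LanceModel`; a minimal sketch (model, path, and table name are hypothetical):

```python
import lancedb
from lancedb.pydantic import LanceModel, Vector


class Item(LanceModel):
    text: str
    vector: Vector(2)  # fixed-size 2-d vector field


db = lancedb.connect("data/sample-lancedb")  # hypothetical path
# sanitize_create_table converts the LanceModel subclass into a pyarrow schema:
tbl = db.create_table("my_table", schema=Item)
```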
```diff
@@ -272,41 +276,6 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem…
     return data


-def _generator_to_data_and_schema(
-    data: Iterable,
-) -> Tuple[Iterable[pa.RecordBatch], pa.Schema]:
-    def _with_first_generator(first, data):
-        yield first
-        yield from data
-
-    first = next(data, None)
-    schema = None
-    if isinstance(first, pa.RecordBatch):
-        schema = first.schema
-        data = _with_first_generator(first, data)
-    elif isinstance(first, pa.Table):
-        schema = first.schema
-        data = _with_first_generator(first.to_batches(), data)
-    return data, schema
-
-
-def _to_record_batch_generator(
-    data: Iterable,
-    schema,
-    metadata,
-    on_bad_vectors,
-    fill_value,
-):
-    for batch in data:
-        # always convert to table because we need to sanitize the data
-        # and do things like add the vector column etc
-        if isinstance(batch, pa.RecordBatch):
-            batch = pa.Table.from_batches([batch])
-        batch, _ = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
-        for b in batch.to_batches():
-            yield b
-
-
 def _table_path(base: str, table_name: str) -> str:
     """
     Get a table path that can be used in PyArrow FS.
```
```diff
@@ -404,7 +373,7 @@ class Table(ABC):
         """
         raise NotImplementedError

-    def to_pandas(self) -> "pd.DataFrame":
+    def to_pandas(self) -> "pandas.DataFrame":
         """Return the table as a pandas DataFrame.

         Returns
@@ -537,8 +506,8 @@ class Table(ABC):
     def create_fts_index(
         self,
         field_names: Union[str, List[str]],
-        ordering_field_names: Union[str, List[str]] = None,
         *,
+        ordering_field_names: Optional[Union[str, List[str]]] = None,
         replace: bool = False,
         writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
         use_tantivy: bool = True,
@@ -790,8 +759,7 @@ class Table(ABC):
     @abstractmethod
     def _execute_query(
         self, query: Query, batch_size: Optional[int] = None
-    ) -> pa.RecordBatchReader:
-        pass
+    ) -> pa.RecordBatchReader: ...

     @abstractmethod
     def _do_merge(
@@ -800,8 +768,7 @@ class Table(ABC):
         new_data: DATA,
         on_bad_vectors: str,
         fill_value: float,
-    ):
-        pass
+    ): ...

     @abstractmethod
     def delete(self, where: str):
@@ -1121,7 +1088,7 @@ class Table(ABC):
         """

     @abstractmethod
-    def list_versions(self):
+    def list_versions(self) -> List[Dict[str, Any]]:
         """List all versions of the table"""

     @cached_property
```
```diff
@@ -1244,7 +1211,7 @@ class LanceTable(Table):
         A PyArrow schema object."""
         return LOOP.run(self._table.schema())

-    def list_versions(self):
+    def list_versions(self) -> List[Dict[str, Any]]:
         """List all versions of the table"""
         return LOOP.run(self._table.list_versions())
@@ -1297,7 +1264,7 @@ class LanceTable(Table):
         """
         LOOP.run(self._table.checkout_latest())

-    def restore(self, version: int = None):
+    def restore(self, version: Optional[int] = None):
         """Restore a version of the table. This is an in-place operation.

         This creates a new version where the data is equivalent to the
@@ -1338,7 +1305,7 @@ class LanceTable(Table):
     def count_rows(self, filter: Optional[str] = None) -> int:
         return LOOP.run(self._table.count_rows(filter))

-    def __len__(self):
+    def __len__(self) -> int:
         return self.count_rows()

     def __repr__(self) -> str:
@@ -1506,8 +1473,8 @@ class LanceTable(Table):
     def create_fts_index(
         self,
         field_names: Union[str, List[str]],
-        ordering_field_names: Union[str, List[str]] = None,
         *,
+        ordering_field_names: Optional[Union[str, List[str]]] = None,
         replace: bool = False,
         writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
         use_tantivy: bool = True,
@@ -1594,6 +1561,7 @@ class LanceTable(Table):
             writer_heap_size=writer_heap_size,
         )

+    @staticmethod
     def infer_tokenizer_configs(tokenizer_name: str) -> dict:
         if tokenizer_name == "default":
             return {
```
```diff
@@ -1759,7 +1727,7 @@ class LanceTable(Table):
         )

     @overload
-    def search(
+    def search(  # type: ignore
         self,
         query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
         vector_column_name: Optional[str] = None,
@@ -1895,11 +1863,11 @@ class LanceTable(Table):
         name: str,
         data: Optional[DATA] = None,
         schema: Optional[pa.Schema] = None,
-        mode: Literal["create", "overwrite", "append"] = "create",
+        mode: Literal["create", "overwrite"] = "create",
         exist_ok: bool = False,
         on_bad_vectors: str = "error",
         fill_value: float = 0.0,
-        embedding_functions: List[EmbeddingFunctionConfig] = None,
+        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
         *,
         storage_options: Optional[Dict[str, str]] = None,
         data_storage_version: Optional[str] = None,
@@ -2065,7 +2033,7 @@ class LanceTable(Table):
             older_than, delete_unverified=delete_unverified
         )

-    def compact_files(self, *args, **kwargs):
+    def compact_files(self, *args, **kwargs) -> CompactionStats:
         """
         Run the compaction process on the table.
```
```diff
@@ -2450,7 +2418,7 @@ def _process_iterator(data: Iterable, schema: Optional[pa.Schema] = None) -> pa.…
         if batch_table.schema != schema:
             try:
                 batch_table = batch_table.cast(schema)
-            except pa.lib.ArrowInvalid:
+            except pa.lib.ArrowInvalid:  # type: ignore
                 raise ValueError(
                     f"Input iterator yielded a batch with schema that "
                     f"does not match the expected schema.\nExpected:\n{schema}\n"
@@ -2710,16 +2678,17 @@ class AsyncTable:
             on_bad_vectors = "error"
         if fill_value is None:
             fill_value = 0.0
-        data, _ = _sanitize_data(
+        table_and_schema: Tuple[pa.Table, pa.Schema] = _sanitize_data(
             data,
             schema,
             metadata=schema.metadata,
             on_bad_vectors=on_bad_vectors,
             fill_value=fill_value,
         )
-        if isinstance(data, pa.Table):
-            data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
-        await self._inner.add(data, mode)
+        tbl, schema = table_and_schema
+        if isinstance(tbl, pa.Table):
+            data = pa.RecordBatchReader.from_batches(schema, tbl.to_batches())
+        await self._inner.add(data, mode or "append")

     def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
         """
@@ -2817,6 +2786,7 @@ class AsyncTable:
             async_query.nearest_to(query.vector)
             .distance_type(query.metric)
             .nprobes(query.nprobes)
+            .distance_range(query.lower_bound, query.upper_bound)
         )
         if query.refine_factor:
             async_query = async_query.refine_factor(query.refine_factor)
@@ -2977,7 +2947,7 @@ class AsyncTable:

         return await self._inner.update(updates_sql, where)

-    async def add_columns(self, transforms: Dict[str, str]):
+    async def add_columns(self, transforms: dict[str, str]):
         """
         Add new columns with defined values.
@@ -2990,7 +2960,7 @@ class AsyncTable:
         """
         await self._inner.add_columns(list(transforms.items()))

-    async def alter_columns(self, *alterations: Iterable[Dict[str, str]]):
+    async def alter_columns(self, *alterations: Iterable[dict[str, Any]]):
         """
         Alter column names and nullability.
@@ -3148,9 +3118,12 @@ class AsyncTable:
         you have added or modified 100,000 or more records or run more than 20 data
         modification operations.
         """
+        cleanup_since_ms: Optional[int] = None
         if cleanup_older_than is not None:
-            cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000)
-        return await self._inner.optimize(cleanup_older_than, delete_unverified)
+            cleanup_since_ms = round(cleanup_older_than.total_seconds() * 1000)
+        return await self._inner.optimize(
+            cleanup_since_ms=cleanup_since_ms, delete_unverified=delete_unverified
+        )

     async def list_indices(self) -> Iterable[IndexConfig]:
         """
```
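With this fix, the `timedelta` argument is converted once into a separate `cleanup_since_ms` variable instead of being reassigned and passed positionally; the call site is unchanged. A sketch (assumes `tbl` is an open `AsyncTable`):

```python
from datetime import timedelta

# Prune versions older than a week; the timedelta is converted to whole
# milliseconds and passed as cleanup_since_ms to the native optimize call.
stats = await tbl.optimize(cleanup_older_than=timedelta(days=7))
```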
```diff
@@ -167,8 +167,24 @@ def test_search_index(tmp_path, table):
 @pytest.mark.parametrize("use_tantivy", [True, False])
 def test_search_fts(table, use_tantivy):
     table.create_fts_index("text", use_tantivy=use_tantivy)
-    results = table.search("puppy").limit(5).to_list()
+    results = table.search("puppy").select(["id", "text"]).limit(5).to_list()
     assert len(results) == 5
+    assert len(results[0]) == 3  # id, text, _score
+
+
+@pytest.mark.asyncio
+async def test_fts_select_async(async_table):
+    tbl = await async_table
+    await tbl.create_index("text", config=FTS())
+    results = (
+        await tbl.query()
+        .nearest_to_text("puppy")
+        .select(["id", "text"])
+        .limit(5)
+        .to_list()
+    )
+    assert len(results) == 5
+    assert len(results[0]) == 3  # id, text, _score


 def test_search_fts_phrase_query(table):
```
```diff
@@ -94,6 +94,73 @@ def test_with_row_id(table: lancedb.table.Table):
     assert rs["_rowid"].to_pylist() == [0, 1]


+def test_distance_range(table: lancedb.table.Table):
+    q = [0, 0]
+    rs = table.search(q).to_arrow()
+    dists = rs["_distance"].to_pylist()
+    min_dist = dists[0]
+    max_dist = dists[-1]
+
+    res = table.search(q).distance_range(upper_bound=min_dist).to_arrow()
+    assert len(res) == 0
+
+    res = table.search(q).distance_range(lower_bound=max_dist).to_arrow()
+    assert len(res) == 1
+    assert res["_distance"].to_pylist() == [max_dist]
+
+    res = table.search(q).distance_range(upper_bound=max_dist).to_arrow()
+    assert len(res) == 1
+    assert res["_distance"].to_pylist() == [min_dist]
+
+    res = table.search(q).distance_range(lower_bound=min_dist).to_arrow()
+    assert len(res) == 2
+    assert res["_distance"].to_pylist() == [min_dist, max_dist]
+
+
+@pytest.mark.asyncio
+async def test_distance_range_async(table_async: AsyncTable):
+    q = [0, 0]
+    rs = await table_async.query().nearest_to(q).to_arrow()
+    dists = rs["_distance"].to_pylist()
+    min_dist = dists[0]
+    max_dist = dists[-1]
+
+    res = (
+        await table_async.query()
+        .nearest_to(q)
+        .distance_range(upper_bound=min_dist)
+        .to_arrow()
+    )
+    assert len(res) == 0
+
+    res = (
+        await table_async.query()
+        .nearest_to(q)
+        .distance_range(lower_bound=max_dist)
+        .to_arrow()
+    )
+    assert len(res) == 1
+    assert res["_distance"].to_pylist() == [max_dist]
+
+    res = (
+        await table_async.query()
+        .nearest_to(q)
+        .distance_range(upper_bound=max_dist)
+        .to_arrow()
+    )
+    assert len(res) == 1
+    assert res["_distance"].to_pylist() == [min_dist]
+
+    res = (
+        await table_async.query()
+        .nearest_to(q)
+        .distance_range(lower_bound=min_dist)
+        .to_arrow()
+    )
+    assert len(res) == 2
+    assert res["_distance"].to_pylist() == [min_dist, max_dist]
+
+
 def test_vector_query_with_no_limit(table):
     with pytest.raises(ValueError):
         LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select(
```
```diff
@@ -306,6 +306,8 @@ def test_query_sync_minimal():
         "k": 10,
         "prefilter": False,
         "refine_factor": None,
+        "lower_bound": None,
+        "upper_bound": None,
         "ef": None,
         "vector": [1.0, 2.0, 3.0],
         "nprobes": 20,
@@ -348,6 +350,8 @@ def test_query_sync_maximal():
         "refine_factor": 10,
         "vector": [1.0, 2.0, 3.0],
         "nprobes": 5,
+        "lower_bound": None,
+        "upper_bound": None,
         "ef": None,
         "filter": "id > 0",
         "columns": ["id", "name"],
@@ -449,6 +453,8 @@ def test_query_sync_hybrid():
         "refine_factor": None,
         "vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
         "nprobes": 20,
+        "lower_bound": None,
+        "upper_bound": None,
         "ef": None,
         "with_row_id": True,
         "version": None,
```
```diff
@@ -152,6 +152,10 @@ impl FTSQuery {
         self.inner = self.inner.clone().select(Select::dynamic(&columns));
     }

+    pub fn select_columns(&mut self, columns: Vec<String>) {
+        self.inner = self.inner.clone().select(Select::columns(&columns));
+    }
+
     pub fn limit(&mut self, limit: u32) {
         self.inner = self.inner.clone().limit(limit as usize);
     }
@@ -280,6 +284,11 @@ impl VectorQuery {
         self.inner = self.inner.clone().nprobes(nprobe as usize);
     }

+    #[pyo3(signature = (lower_bound=None, upper_bound=None))]
+    pub fn distance_range(&mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) {
+        self.inner = self.inner.clone().distance_range(lower_bound, upper_bound);
+    }
+
     pub fn ef(&mut self, ef: u32) {
         self.inner = self.inner.clone().ef(ef as usize);
     }
@@ -341,6 +350,11 @@ impl HybridQuery {
         self.inner_fts.select(columns);
     }

+    pub fn select_columns(&mut self, columns: Vec<String>) {
+        self.inner_vec.select_columns(columns.clone());
+        self.inner_fts.select_columns(columns);
+    }
+
     pub fn limit(&mut self, limit: u32) {
         self.inner_vec.limit(limit);
         self.inner_fts.limit(limit);
```
```diff
@@ -97,10 +97,12 @@ impl Table {
         self.name.clone()
     }

+    /// Returns True if the table is open, False if it is closed.
     pub fn is_open(&self) -> bool {
         self.inner.is_some()
     }

+    /// Closes the table, releasing any resources associated with it.
     pub fn close(&mut self) {
         self.inner.take();
     }
@@ -301,6 +303,7 @@ impl Table {
         Query::new(self.inner_ref().unwrap().query())
     }

+    /// Optimize the on-disk data by compacting and pruning old data, for better performance.
     #[pyo3(signature = (cleanup_since_ms=None, delete_unverified=None))]
     pub fn optimize(
         self_: PyRef<'_, Self>,
```
```diff
@@ -755,6 +755,10 @@ pub struct VectorQuery {
     // IVF PQ - ANN search.
     pub(crate) query_vector: Vec<Arc<dyn Array>>,
     pub(crate) nprobes: usize,
+    // The lower bound (inclusive) of the distance to search for.
+    pub(crate) lower_bound: Option<f32>,
+    // The upper bound (exclusive) of the distance to search for.
+    pub(crate) upper_bound: Option<f32>,
     // The number of candidates to return during the refine step for HNSW,
     // defaults to 1.5 * limit.
     pub(crate) ef: Option<usize>,
@@ -771,6 +775,8 @@ impl VectorQuery {
             column: None,
             query_vector: Vec::new(),
             nprobes: 20,
+            lower_bound: None,
+            upper_bound: None,
             ef: None,
             refine_factor: None,
             distance_type: None,
@@ -831,6 +837,14 @@ impl VectorQuery {
         self
     }

+    /// Set the distance range for vector search,
+    /// only rows with distances in the range [lower_bound, upper_bound) will be returned
+    pub fn distance_range(mut self, lower_bound: Option<f32>, upper_bound: Option<f32>) -> Self {
+        self.lower_bound = lower_bound;
+        self.upper_bound = upper_bound;
+        self
+    }
+
     /// Set the number of candidates to return during the refine step for HNSW
     ///
     /// This argument is only used when the vector column has an HNSW index.
@@ -1350,6 +1364,30 @@ mod tests {
         }
     }

+    #[tokio::test]
+    async fn test_distance_range() {
+        let tmp_dir = tempdir().unwrap();
+        let table = make_test_table(&tmp_dir).await;
+        let results = table
+            .vector_search(&[0.1, 0.2, 0.3, 0.4])
+            .unwrap()
+            .distance_range(Some(0.0), Some(1.0))
+            .limit(10)
+            .execute()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+        for batch in results {
+            let distances = batch["_distance"].as_primitive::<Float32Type>();
+            assert!(distances.iter().all(|d| {
+                let d = d.unwrap();
+                (0.0..1.0).contains(&d)
+            }));
+        }
+    }
+
     #[tokio::test]
     async fn test_multiple_query_vectors() {
         let tmp_dir = tempdir().unwrap();
```
```diff
@@ -210,6 +210,8 @@ impl<S: HttpSend> RemoteTable<S> {
         body["prefilter"] = query.base.prefilter.into();
         body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
         body["nprobes"] = query.nprobes.into();
+        body["lower_bound"] = query.lower_bound.into();
+        body["upper_bound"] = query.upper_bound.into();
         body["ef"] = query.ef.into();
         body["refine_factor"] = query.refine_factor.into();
         if let Some(vector_column) = query.column.as_ref() {
@@ -1304,6 +1306,8 @@ mod tests {
             "prefilter": true,
             "distance_type": "l2",
             "nprobes": 20,
+            "lower_bound": Option::<f32>::None,
+            "upper_bound": Option::<f32>::None,
             "k": 10,
             "ef": Option::<usize>::None,
             "refine_factor": null,
@@ -1353,6 +1357,8 @@ mod tests {
             "bypass_vector_index": true,
             "columns": ["a", "b"],
             "nprobes": 12,
+            "lower_bound": Option::<f32>::None,
+            "upper_bound": Option::<f32>::None,
             "ef": Option::<usize>::None,
             "refine_factor": 2,
             "version": null,
```
```diff
@@ -1944,6 +1944,7 @@ impl TableInternal for NativeTable {
         if let Some(ef) = query.ef {
             scanner.ef(ef);
         }
+        scanner.distance_range(query.lower_bound, query.upper_bound);
         scanner.use_index(query.use_index);
         scanner.prefilter(query.base.prefilter);
         match query.base.select {
```