Merge remote-tracking branch 'origin/main' into xuanwo/remote-pytorch-multiprocessing

# Conflicts:
#	python/python/lancedb/remote/table.py
This commit is contained in:
Xuanwo
2026-06-01 17:54:45 +08:00
68 changed files with 4423 additions and 672 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.33.0-beta.0"
current_version = "0.33.0-beta.1"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.33.0-beta.0"
version = "0.33.0-beta.1"
publish = false
edition.workspace = true
description = "Python bindings for LanceDB"

View File

@@ -94,7 +94,6 @@ def connect(
host_override: str, optional
The override url for LanceDB Cloud.
read_consistency_interval: timedelta, default None
(For LanceDB OSS only)
The interval at which to check for updates to the table from other
processes. If None, then consistency is not checked. For performance
reasons, this is the default. For strong consistency, set this to
@@ -104,6 +103,10 @@ def connect(
the last check, then the table will be checked for updates. Note: this
consistency only applies to read operations. Write operations are
always consistent.
Stronger consistency is not free. The smaller the interval, the more
often each read pays the cost of checking for updates against object
storage, raising per-read latency and cost.
client_config: ClientConfig or dict, optional
Configuration options for the LanceDB Cloud HTTP client. If a dict, then
the keys are the attributes of the ClientConfig class. If None, then the
@@ -147,6 +150,13 @@ def connect(
>>> db = lancedb.connect("s3://my-bucket/lancedb",
... storage_options={"aws_access_key_id": "***"})
For tests and temporary data, use an in-memory database:
>>> db = lancedb.connect("memory://")
In-memory databases are not persisted. Tables are dropped when the last
connection or table handle referencing them is closed.
Connect to LanceDB cloud:
>>> db = lancedb.connect("db://my_database", api_key="ldb_...",
@@ -210,6 +220,7 @@ def connect(
request_thread_pool=request_thread_pool,
client_config=client_config,
storage_options=storage_options,
read_consistency_interval=read_consistency_interval,
**kwargs,
)
_check_s3_bucket_with_dots(str(uri), storage_options)
@@ -345,7 +356,6 @@ async def connect_async(
host_override: str, optional
The override url for LanceDB Cloud.
read_consistency_interval: timedelta, default None
(For LanceDB OSS only)
The interval at which to check for updates to the table from other
processes. If None, then consistency is not checked. For performance
reasons, this is the default. For strong consistency, set this to
@@ -355,6 +365,10 @@ async def connect_async(
the last check, then the table will be checked for updates. Note: this
consistency only applies to read operations. Write operations are
always consistent.
Stronger consistency is not free. The smaller the interval, the more
often each read pays the cost of checking for updates against object
storage, raising per-read latency and cost.
client_config: ClientConfig or dict, optional
Configuration options for the LanceDB Cloud HTTP client. If a dict, then
the keys are the attributes of the ClientConfig class. If None, then the
@@ -387,6 +401,8 @@ async def connect_async(
... db = await lancedb.connect_async("s3://my-bucket/lancedb",
... storage_options={
... "aws_access_key_id": "***"})
... # For tests and temporary data, use an in-memory database
... db = await lancedb.connect_async("memory://")
... # Connect to LanceDB cloud
... db = await lancedb.connect_async("db://my_database", api_key="ldb_...",
... client_config={

View File

@@ -220,6 +220,7 @@ class Table:
async def set_unenforced_primary_key(self, columns: List[str]) -> None: ...
async def set_lsm_write_spec(self, spec: LsmWriteSpec) -> None: ...
async def unset_lsm_write_spec(self) -> None: ...
async def close_lsm_writers(self) -> None: ...
@property
def tags(self) -> Tags: ...
def query(self) -> Query: ...
@@ -420,6 +421,7 @@ class MergeResult:
num_inserted_rows: int
num_deleted_rows: int
num_attempts: int
num_rows: int
class LsmWriteSpec:
"""Specification selecting Lance's MemWAL LSM-style write path for

View File

@@ -281,6 +281,9 @@ class HnswPq:
m: int = 20
ef_construction: int = 300
target_partition_size: Optional[int] = None
# Name of the accelerator (e.g. "cuda") to use for IVF training. When set,
# create_index() dispatches to pylance to build the index on the accelerator.
accelerator: Optional[str] = None
@dataclass
@@ -386,6 +389,9 @@ class HnswSq:
m: int = 20
ef_construction: int = 300
target_partition_size: Optional[int] = None
# Name of the accelerator (e.g. "cuda") to use for IVF training. When set,
# create_index() dispatches to pylance to build the index on the accelerator.
accelerator: Optional[str] = None
@dataclass
@@ -579,6 +585,9 @@ class IvfFlat:
max_iterations: int = 50
sample_rate: int = 256
target_partition_size: Optional[int] = None
# Name of the accelerator (e.g. "cuda") to use for IVF training. When set,
# create_index() dispatches to pylance to build the index on the accelerator.
accelerator: Optional[str] = None
@dataclass
@@ -609,6 +618,9 @@ class IvfSq:
max_iterations: int = 50
sample_rate: int = 256
target_partition_size: Optional[int] = None
# Name of the accelerator (e.g. "cuda") to use for IVF training. When set,
# create_index() dispatches to pylance to build the index on the accelerator.
accelerator: Optional[str] = None
@dataclass
@@ -739,6 +751,9 @@ class IvfPq:
max_iterations: int = 50
sample_rate: int = 256
target_partition_size: Optional[int] = None
# Name of the accelerator (e.g. "cuda") to use for IVF training. When set,
# create_index() dispatches to pylance to build the index on the accelerator.
accelerator: Optional[str] = None
@dataclass
@@ -792,6 +807,9 @@ class IvfRq:
max_iterations: int = 50
sample_rate: int = 256
target_partition_size: Optional[int] = None
# Name of the accelerator (e.g. "cuda") to use for IVF training. When set,
# create_index() dispatches to pylance to build the index on the accelerator.
accelerator: Optional[str] = None
__all__ = [

View File

@@ -34,6 +34,8 @@ class LanceMergeInsertBuilder(object):
self._when_not_matched_by_source_condition = None
self._timeout = None
self._use_index = True
self._use_lsm_write = None
self._validate_single_shard = None
def when_matched_update_all(
self, *, where: Optional[str] = None
@@ -96,6 +98,46 @@ class LanceMergeInsertBuilder(object):
self._use_index = use_index
return self
def use_lsm_write(self, use_lsm_write: bool) -> LanceMergeInsertBuilder:
"""
Controls whether the merge uses the MemWAL LSM write path.
By default (unset), a `merge_insert` on a table with an LSM write spec
is routed through Lance's MemWAL shard writer, and a table without one
uses the standard path. Pass `False` to force the standard path even
when a spec is set. Pass `True` to require a spec — `merge_insert`
raises an error if none is installed.
Parameters
----------
use_lsm_write: bool
Whether to use the LSM write path.
"""
self._use_lsm_write = use_lsm_write
return self
def validate_single_shard(
self, validate_single_shard: bool
) -> LanceMergeInsertBuilder:
"""
Controls how an LSM merge checks that its input targets a single shard.
When a table has an LSM write spec, every row in a `merge_insert` call
must route to the same shard. When `True` (the default), every row is
inspected to verify this. When `False`, only the first row is inspected
and the shard it routes to is used for the whole input — a faster path
for callers that have already pre-sharded their input.
Has no effect on tables without an LSM write spec.
Parameters
----------
validate_single_shard: bool
Whether to check every row routes to one shard. Defaults to `True`.
"""
self._validate_single_shard = validate_single_shard
return self
def execute(
self,
new_data: DATA,

View File

@@ -109,6 +109,7 @@ class RemoteDBConnection(DBConnection):
connection_timeout: Optional[float] = None,
read_timeout: Optional[float] = None,
storage_options: Optional[Dict[str, str]] = None,
read_consistency_interval: Optional[timedelta] = None,
):
"""Connect to a remote LanceDB database."""
if isinstance(client_config, dict):
@@ -167,6 +168,7 @@ class RemoteDBConnection(DBConnection):
host_override=host_override,
client_config=client_config,
storage_options=storage_options,
read_consistency_interval=read_consistency_interval,
)
)

View File

@@ -2,12 +2,25 @@
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from datetime import timedelta
import deprecation
import logging
from functools import cached_property
import os
from typing import Any, Callable, Dict, Iterable, List, Optional, Union, Literal
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Union,
Literal,
overload,
)
import warnings
from lancedb import __version__
from lancedb._lancedb import (
AddColumnsResult,
AddResult,
@@ -33,6 +46,7 @@ from lancedb.index import (
LabelList,
)
from lancedb.remote.db import LOOP
from lancedb.table import IndexConfigType, KNOWN_METRICS
import pyarrow as pa
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
@@ -195,6 +209,11 @@ class RemoteTable(Table):
"""List all the stats of a specified index"""
return LOOP.run(self._table.index_stats(index_uuid))
@deprecation.deprecated(
deprecated_in="0.25.0",
current_version=__version__,
details="Use create_index() with config=BTree()/Bitmap()/LabelList() instead.",
)
def create_scalar_index(
self,
column: str,
@@ -204,7 +223,12 @@ class RemoteTable(Table):
wait_timeout: Optional[timedelta] = None,
name: Optional[str] = None,
):
"""Creates a scalar index
"""Creates a scalar index.
.. deprecated:: 0.25.0
Use :meth:`create_index` with a BTree, Bitmap, or LabelList config instead.
Example: ``table.create_index("column", config=BTree())``
Parameters
----------
column : str
@@ -235,6 +259,11 @@ class RemoteTable(Table):
)
)
@deprecation.deprecated(
deprecated_in="0.25.0",
current_version=__version__,
details="Use create_index() with config=FTS() instead.",
)
def create_fts_index(
self,
column: str,
@@ -255,6 +284,12 @@ class RemoteTable(Table):
prefix_only: bool = False,
name: Optional[str] = None,
):
"""Create a full-text search index on a column.
.. deprecated:: 0.25.0
Use :meth:`create_index` with an FTS config instead.
Example: ``table.create_index("text_column", config=FTS())``
"""
config = FTS(
with_position=with_position,
base_tokenizer=base_tokenizer,
@@ -278,9 +313,43 @@ class RemoteTable(Table):
)
)
# New unified API overload
@overload
def create_index(
self,
metric="l2",
column: str,
/,
*,
config: IndexConfigType,
wait_timeout: Optional[timedelta] = ...,
name: Optional[str] = ...,
train: bool = ...,
) -> None: ...
# Legacy API overload (deprecated)
@overload
def create_index(
self,
metric: Literal["l2", "cosine", "dot", "hamming"] = ...,
vector_column_name: str = ...,
index_cache_size: Optional[int] = ...,
num_partitions: Optional[int] = ...,
num_sub_vectors: Optional[int] = ...,
replace: Optional[bool] = ...,
accelerator: Optional[str] = ...,
index_type: Literal[
"VECTOR", "IVF_FLAT", "IVF_SQ", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
] = ...,
wait_timeout: Optional[timedelta] = ...,
*,
num_bits: int = ...,
name: Optional[str] = ...,
train: bool = ...,
) -> None: ...
def create_index(
self,
metric: str = "l2",
vector_column_name: str = VECTOR_COLUMN_NAME,
index_cache_size: Optional[int] = None,
num_partitions: Optional[int] = None,
@@ -291,89 +360,113 @@ class RemoteTable(Table):
wait_timeout: Optional[timedelta] = None,
*,
num_bits: int = 8,
config: Optional[IndexConfigType] = None,
name: Optional[str] = None,
train: bool = True,
):
"""Create an index on the table.
"""Create an index on a column.
Parameters
----------
metric : str
The metric to use for the index. Default is "l2".
vector_column_name : str
The name of the vector column. Default is "vector".
This method supports both the new unified API and the legacy API
for backwards compatibility. The new API takes the column name as the
first positional argument and an index configuration object via
``config``; the legacy API takes the distance metric as the first
argument plus separate ``vector_column_name`` / ``num_partitions`` /
etc. parameters, and emits a ``DeprecationWarning``.
Examples
--------
>>> import lancedb
>>> import uuid
>>> from lancedb.schema import vector
>>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
... region="...") # doctest: +SKIP
>>> table_name = uuid.uuid4().hex
>>> schema = pa.schema(
... [
... pa.field("id", pa.uint32(), False),
... pa.field("vector", vector(128), False),
... pa.field("s", pa.string(), False),
... ]
New API (recommended):
>>> table.create_index( # doctest: +SKIP
... "vector", config=IvfPq(distance_type="l2")
... )
>>> table = db.create_table( # doctest: +SKIP
... table_name, # doctest: +SKIP
... schema=schema, # doctest: +SKIP
>>> table.create_index("category", config=BTree()) # doctest: +SKIP
>>> table.create_index("content", config=FTS()) # doctest: +SKIP
Legacy API (deprecated):
>>> table.create_index( # doctest: +SKIP
... "l2", vector_column_name="vector"
... )
>>> table.create_index("l2", "vector") # doctest: +SKIP
"""
# Detect whether this is a legacy API call
is_legacy = self._is_legacy_create_index_call(
metric,
config,
num_partitions,
num_sub_vectors,
vector_column_name,
accelerator,
index_cache_size,
replace,
)
if accelerator is not None:
logging.warning(
"GPU accelerator is not yet supported on LanceDB cloud."
"If you have 100M+ vectors to index,"
"please contact us at contact@lancedb.com"
)
if replace is not None:
logging.warning(
"replace is not supported on LanceDB cloud."
"Existing indexes will always be replaced."
if is_legacy:
warnings.warn(
"The create_index() API with metric/num_partitions parameters is "
"deprecated and will be removed in a future version. "
"Please migrate to the new unified API:\n"
" # Old (deprecated):\n"
" table.create_index('l2', vector_column_name='my_vector')\n"
" # New (recommended):\n"
" table.create_index('my_vector', config=IvfPq(distance_type='l2'))",
DeprecationWarning,
stacklevel=2,
)
index_type = index_type.upper()
if index_type == "VECTOR" or index_type == "IVF_PQ":
config = IvfPq(
distance_type=metric,
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
num_bits=num_bits,
)
elif index_type == "IVF_RQ":
config = IvfRq(
distance_type=metric,
num_partitions=num_partitions,
num_bits=num_bits,
)
elif index_type == "IVF_SQ":
config = IvfSq(distance_type=metric, num_partitions=num_partitions)
elif index_type == "IVF_HNSW_PQ":
raise ValueError(
"IVF_HNSW_PQ is not supported on LanceDB cloud."
"Please use IVF_HNSW_SQ instead."
)
elif index_type == "IVF_HNSW_SQ":
config = HnswSq(distance_type=metric, num_partitions=num_partitions)
elif index_type == "IVF_HNSW_FLAT":
config = HnswFlat(distance_type=metric, num_partitions=num_partitions)
elif index_type == "IVF_FLAT":
config = IvfFlat(distance_type=metric, num_partitions=num_partitions)
column = vector_column_name
if accelerator is not None:
logging.warning(
"GPU accelerator is not yet supported on LanceDB cloud."
"If you have 100M+ vectors to index,"
"please contact us at contact@lancedb.com"
)
if replace is not None:
logging.warning(
"replace is not supported on LanceDB cloud."
"Existing indexes will always be replaced."
)
idx_type = index_type.upper()
if idx_type == "VECTOR" or idx_type == "IVF_PQ":
config = IvfPq(
distance_type=metric,
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
num_bits=num_bits,
)
elif idx_type == "IVF_RQ":
config = IvfRq(
distance_type=metric,
num_partitions=num_partitions,
num_bits=num_bits,
)
elif idx_type == "IVF_SQ":
config = IvfSq(distance_type=metric, num_partitions=num_partitions)
elif idx_type == "IVF_HNSW_PQ":
raise ValueError(
"IVF_HNSW_PQ is not supported on LanceDB cloud."
"Please use IVF_HNSW_SQ instead."
)
elif idx_type == "IVF_HNSW_SQ":
config = HnswSq(distance_type=metric, num_partitions=num_partitions)
elif idx_type == "IVF_HNSW_FLAT":
config = HnswFlat(distance_type=metric, num_partitions=num_partitions)
elif idx_type == "IVF_FLAT":
config = IvfFlat(distance_type=metric, num_partitions=num_partitions)
else:
raise ValueError(
f"Unknown vector index type: {idx_type}. Valid options are"
" 'IVF_FLAT', 'IVF_PQ', 'IVF_RQ', 'IVF_SQ',"
" 'IVF_HNSW_PQ', 'IVF_HNSW_SQ', 'IVF_HNSW_FLAT'"
)
else:
raise ValueError(
f"Unknown vector index type: {index_type}. Valid options are"
" 'IVF_FLAT', 'IVF_PQ', 'IVF_RQ', 'IVF_SQ',"
" 'IVF_HNSW_PQ', 'IVF_HNSW_SQ', 'IVF_HNSW_FLAT'"
)
column = metric
LOOP.run(
self._table.create_index(
vector_column_name,
column,
config=config,
wait_timeout=wait_timeout,
name=name,
@@ -381,6 +474,37 @@ class RemoteTable(Table):
)
)
def _is_legacy_create_index_call(
self,
first_arg: str,
config: Optional[IndexConfigType],
num_partitions: Optional[int],
num_sub_vectors: Optional[int],
vector_column_name: str,
accelerator: Optional[str],
index_cache_size: Optional[int],
replace: Optional[bool],
) -> bool:
"""Detect if this is a legacy create_index call."""
if config is not None:
return False
if any(
x is not None
for x in (
num_partitions,
num_sub_vectors,
accelerator,
index_cache_size,
replace,
)
):
return True
if vector_column_name != VECTOR_COLUMN_NAME:
return True
if first_arg.lower() in KNOWN_METRICS:
return True
return False
def add(
self,
data: DATA,
@@ -741,6 +865,10 @@ class RemoteTable(Table):
"""Not supported on LanceDB Cloud."""
return LOOP.run(self._table.unset_lsm_write_spec())
def close_lsm_writers(self) -> None:
"""No-op on LanceDB Cloud (no local shard writers)."""
return LOOP.run(self._table.close_lsm_writers())
def drop_index(self, index_name: str):
return LOOP.run(self._table.drop_index(index_name))

View File

@@ -102,8 +102,15 @@ class LinearCombinationReranker(Reranker):
combined_list = []
for row_id, result in results.items():
# Convert vector distance to a relevance score in [0, 1] where
# higher is better. Missing vector entries are penalised with
# `_invert_score(fill)` = 1 - fill (= 0.0 for the default fill=1).
vector_score = self._invert_score(result.get("_distance", fill))
fts_score = result.get("_score", fill)
# FTS scores (BM25) are already in a "higher = more relevant" space.
# Missing FTS entries are penalised symmetrically: we use
# `1 - fill` so that the same `fill` value drives both missing-vector
# and missing-FTS penalties in the same direction.
fts_score = result.get("_score", 1 - fill)
result["_relevance_score"] = self._combine_score(vector_score, fts_score)
combined_list.append(result)
@@ -123,8 +130,12 @@ class LinearCombinationReranker(Reranker):
return tbl
def _combine_score(self, vector_score, fts_score):
# these scores represent distance
return 1 - (self.weight * vector_score + (1 - self.weight) * fts_score)
# Both vector_score (inverted distance) and fts_score are in a
# "higher = more relevant" space. A straight weighted average gives
# higher _relevance_score to better matches, as expected.
# Previously this returned `1 - (...)` which inverted the final
# ranking so that the *least* relevant document ranked first.
return self.weight * vector_score + (1 - self.weight) * fts_score
def _invert_score(self, dist: float):
# Invert the score between relevance and distance

View File

@@ -174,6 +174,24 @@ if TYPE_CHECKING:
DistanceType,
)
# Type alias for index configuration objects
IndexConfigType = Union[
IvfFlat,
IvfPq,
IvfSq,
IvfRq,
HnswFlat,
HnswPq,
HnswSq,
BTree,
Bitmap,
LabelList,
FTS,
]
# Known distance metrics for legacy API detection
KNOWN_METRICS = {"l2", "cosine", "dot", "hamming"}
def _into_pyarrow_reader(
data, schema: Optional[pa.Schema] = None
@@ -807,11 +825,49 @@ class Table(ABC):
"""
raise NotImplementedError
# New unified API overload
@overload
def create_index(
self,
metric="l2",
num_partitions=256,
num_sub_vectors=96,
column: str,
/,
*,
config: IndexConfigType,
replace: bool = ...,
wait_timeout: Optional[timedelta] = ...,
name: Optional[str] = ...,
train: bool = ...,
) -> None: ...
# Legacy API overload (deprecated)
@overload
def create_index(
self,
metric: Literal["l2", "cosine", "dot", "hamming"] = ...,
num_partitions: Optional[int] = ...,
num_sub_vectors: Optional[int] = ...,
vector_column_name: str = ...,
replace: bool = ...,
accelerator: Optional[str] = ...,
index_cache_size: Optional[int] = ...,
*,
index_type: VectorIndexType = ...,
wait_timeout: Optional[timedelta] = ...,
num_bits: int = ...,
max_iterations: int = ...,
sample_rate: int = ...,
m: int = ...,
ef_construction: int = ...,
name: Optional[str] = ...,
train: bool = ...,
target_partition_size: Optional[int] = ...,
) -> None: ...
def create_index(
self,
metric: DistanceType = "l2",
num_partitions: Optional[int] = None,
num_sub_vectors: Optional[int] = None,
vector_column_name: str = VECTOR_COLUMN_NAME,
replace: bool = True,
accelerator: Optional[str] = None,
@@ -824,46 +880,53 @@ class Table(ABC):
sample_rate: int = 256,
m: int = 20,
ef_construction: int = 300,
config: Optional[IndexConfigType] = None,
name: Optional[str] = None,
train: bool = True,
target_partition_size: Optional[int] = None,
):
"""Create an index on the table.
"""Create an index on a column.
This method supports both the new unified API and the legacy API
for backwards compatibility. The new API takes the column name as the
first positional argument and an index configuration object via
``config``; the legacy API takes the distance metric as the first
argument plus separate ``vector_column_name`` / ``num_partitions`` /
etc. parameters, and emits a ``DeprecationWarning``.
Parameters
----------
metric: str, default "l2"
The distance metric to use when creating the index.
Valid values are "l2", "cosine", "dot", or "hamming".
l2 is euclidean distance.
Hamming is available only for binary vectors.
num_partitions: int, default 256
The number of IVF partitions to use when creating the index.
Default is 256.
num_sub_vectors: int, default 96
The number of PQ sub-vectors to use when creating the index.
Default is 96.
vector_column_name: str, default "vector"
The vector column name to create the index.
replace: bool, default True
- If True, replace the existing index if it exists.
metric : str
For new API: the column name to index.
For legacy API: the distance metric ("l2", "cosine", "dot", "hamming").
config : IndexConfigType, optional
The index configuration object. If provided, uses the new unified API.
Can be one of: IvfFlat, IvfPq, IvfSq, IvfRq, HnswPq, HnswSq,
BTree, Bitmap, LabelList, FTS.
replace : bool, default True
Whether to replace an existing index on this column.
wait_timeout : timedelta, optional
Timeout to wait for async indexing to complete.
name : str, optional
Custom name for the index.
train : bool, default True
Whether to train the index with existing data.
- If False, raise an error if duplicate index exists.
accelerator: str, default None
If set, use the given accelerator to create the index.
Only support "cuda" for now.
index_cache_size : int, optional
The size of the index cache in number of entries. Default value is 256.
num_bits: int
The number of bits to encode sub-vectors. Only used with the IVF_PQ index.
Only 4 and 8 are supported.
wait_timeout: timedelta, optional
The timeout to wait if indexing is asynchronous.
name: str, optional
The name of the index. If not provided, a default name will be generated.
train: bool, default True
Whether to train the index with existing data. Vector indices always train
with existing data.
Examples
--------
New API (recommended):
>>> table.create_index( # doctest: +SKIP
... "vector", config=IvfPq(distance_type="l2")
... )
>>> table.create_index("category", config=BTree()) # doctest: +SKIP
>>> table.create_index("content", config=FTS()) # doctest: +SKIP
Legacy API (deprecated):
>>> table.create_index( # doctest: +SKIP
... "l2", vector_column_name="vector"
... )
"""
raise NotImplementedError
@@ -1188,7 +1251,7 @@ class Table(ABC):
... .when_not_matched_insert_all() \\
... .execute(new_data)
>>> res
MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1)
MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1, num_rows=3)
>>> # The order of new rows is non-deterministic since we use
>>> # a hash-join as part of this operation and so we sort here
>>> table.to_arrow().sort_by("a").to_pandas()
@@ -2250,11 +2313,51 @@ class LanceTable(Table):
dataset, allow_pyarrow_filter=False, batch_size=batch_size
)
# New unified API overload
@overload
def create_index(
self,
metric: DistanceType = "l2",
num_partitions=None,
num_sub_vectors=None,
column: str,
/,
*,
config: IndexConfigType,
replace: bool = ...,
wait_timeout: Optional[timedelta] = ...,
name: Optional[str] = ...,
train: bool = ...,
) -> None: ...
# Legacy API overload (deprecated)
@overload
def create_index(
self,
metric: Literal["l2", "cosine", "dot", "hamming"] = ...,
num_partitions: Optional[int] = ...,
num_sub_vectors: Optional[int] = ...,
vector_column_name: str = ...,
replace: bool = ...,
accelerator: Optional[str] = ...,
index_cache_size: Optional[int] = ...,
num_bits: int = ...,
index_type: Literal[
"IVF_FLAT", "IVF_SQ", "IVF_PQ", "IVF_RQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
] = ...,
max_iterations: int = ...,
sample_rate: int = ...,
m: int = ...,
ef_construction: int = ...,
*,
wait_timeout: Optional[timedelta] = ...,
name: Optional[str] = ...,
train: bool = ...,
target_partition_size: Optional[int] = ...,
) -> None: ...
def create_index(
self,
metric: str = "l2",
num_partitions: Optional[int] = None,
num_sub_vectors: Optional[int] = None,
vector_column_name: str = VECTOR_COLUMN_NAME,
replace: bool = True,
accelerator: Optional[str] = None,
@@ -2274,47 +2377,232 @@ class LanceTable(Table):
m: int = 20,
ef_construction: int = 300,
*,
config: Optional[IndexConfigType] = None,
wait_timeout: Optional[timedelta] = None,
name: Optional[str] = None,
train: bool = True,
target_partition_size: Optional[int] = None,
):
"""Create an index on the table."""
if accelerator is not None:
# accelerator is only supported through pylance.
self.to_lance().create_index(
column=vector_column_name,
index_type=index_type,
"""Create an index on a column.
This method supports both the new unified API and the legacy API
for backwards compatibility. The new API takes the column name as the
first positional argument and an index configuration object via
``config``; the legacy API takes the distance metric as the first
argument plus separate ``vector_column_name`` / ``num_partitions`` /
etc. parameters, and emits a ``DeprecationWarning``.
Parameters
----------
metric : str
For new API: the column name to index.
For legacy API: the distance metric ("l2", "cosine", "dot", "hamming").
config : IndexConfigType, optional
The index configuration object. If provided, uses the new unified API.
Can be one of: IvfFlat, IvfPq, IvfSq, IvfRq, HnswPq, HnswSq,
BTree, Bitmap, LabelList, FTS.
replace : bool, default True
Whether to replace an existing index on this column.
wait_timeout : timedelta, optional
Timeout to wait for async indexing to complete.
name : str, optional
Custom name for the index.
train : bool, default True
Whether to train the index with existing data.
Examples
--------
New API (recommended):
>>> table.create_index( # doctest: +SKIP
... "vector", config=IvfPq(distance_type="l2")
... )
>>> table.create_index("category", config=BTree()) # doctest: +SKIP
>>> table.create_index("content", config=FTS()) # doctest: +SKIP
Legacy API (deprecated):
>>> table.create_index( # doctest: +SKIP
... "l2", vector_column_name="vector"
... )
"""
# Detect whether this is a legacy API call
is_legacy = self._is_legacy_create_index_call(
metric,
config,
num_partitions,
num_sub_vectors,
vector_column_name,
accelerator,
index_cache_size,
)
if is_legacy:
warnings.warn(
"The create_index() API with metric/num_partitions parameters is "
"deprecated and will be removed in a future version. "
"Please migrate to the new unified API:\n"
" # Old (deprecated):\n"
" table.create_index('l2', vector_column_name='my_vector')\n"
" # New (recommended):\n"
" table.create_index('my_vector', config=IvfPq(distance_type='l2'))",
DeprecationWarning,
stacklevel=2,
)
# Legacy API: first arg is the distance metric
column = vector_column_name
# Build config from legacy parameters
config = self._build_vector_config_from_legacy_params(
metric=metric,
index_type=index_type,
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
replace=replace,
accelerator=accelerator,
index_cache_size=index_cache_size,
num_bits=num_bits,
max_iterations=max_iterations,
sample_rate=sample_rate,
m=m,
ef_construction=ef_construction,
target_partition_size=target_partition_size,
accelerator=accelerator,
)
self.checkout_latest()
return
elif index_type == "IVF_FLAT":
config = IvfFlat(
# Handle accelerator through pylance
if accelerator is not None:
self.to_lance().create_index(
column=column,
index_type=index_type,
metric=metric,
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
replace=replace,
accelerator=accelerator,
index_cache_size=index_cache_size,
num_bits=num_bits,
m=m,
ef_construction=ef_construction,
target_partition_size=target_partition_size,
)
self.checkout_latest()
return
else:
# New API: metric is the column name
column = metric
# Check if config has accelerator set and dispatch to pylance
if config is not None and hasattr(config, "accelerator"):
acc = getattr(config, "accelerator", None)
if acc is not None:
# Dispatch to pylance for GPU acceleration
index_type_map = {
"IvfFlat": "IVF_FLAT",
"IvfSq": "IVF_SQ",
"IvfPq": "IVF_PQ",
"IvfRq": "IVF_RQ",
"HnswPq": "IVF_HNSW_PQ",
"HnswSq": "IVF_HNSW_SQ",
}
cfg_type = type(config).__name__
lance_index_type = index_type_map.get(cfg_type, "IVF_PQ")
self.to_lance().create_index(
column=column,
index_type=lance_index_type,
metric=getattr(config, "distance_type", "l2"),
num_partitions=getattr(config, "num_partitions", None),
num_sub_vectors=getattr(config, "num_sub_vectors", None),
replace=replace,
accelerator=acc,
num_bits=getattr(config, "num_bits", 8),
m=getattr(config, "m", 20),
ef_construction=getattr(config, "ef_construction", 300),
target_partition_size=getattr(
config, "target_partition_size", None
),
)
self.checkout_latest()
return
return LOOP.run(
self._table.create_index(
column,
replace=replace,
config=config,
wait_timeout=wait_timeout,
name=name,
train=train,
)
)
def _is_legacy_create_index_call(
self,
first_arg: str,
config: Optional[IndexConfigType],
num_partitions: Optional[int],
num_sub_vectors: Optional[int],
vector_column_name: str,
accelerator: Optional[str],
index_cache_size: Optional[int],
) -> bool:
"""Detect if this is a legacy create_index call."""
# If config is provided, it's definitely the new API
if config is not None:
return False
# If old-style parameters were explicitly set, it's legacy
if any(
x is not None
for x in (num_partitions, num_sub_vectors, accelerator, index_cache_size)
):
return True
# If vector_column_name differs from default, it's legacy
if vector_column_name != VECTOR_COLUMN_NAME:
return True
# If first arg is a known metric, assume legacy
if first_arg.lower() in KNOWN_METRICS:
return True
# Otherwise assume new API
return False
def _build_vector_config_from_legacy_params(
self,
metric: str,
index_type: str,
num_partitions: Optional[int],
num_sub_vectors: Optional[int],
num_bits: int,
max_iterations: int,
sample_rate: int,
m: int,
ef_construction: int,
target_partition_size: Optional[int],
accelerator: Optional[str],
) -> IndexConfigType:
"""Build an index config object from legacy parameters."""
if index_type == "IVF_FLAT":
return IvfFlat(
distance_type=metric,
num_partitions=num_partitions,
max_iterations=max_iterations,
sample_rate=sample_rate,
target_partition_size=target_partition_size,
accelerator=accelerator,
)
elif index_type == "IVF_SQ":
config = IvfSq(
return IvfSq(
distance_type=metric,
num_partitions=num_partitions,
max_iterations=max_iterations,
sample_rate=sample_rate,
target_partition_size=target_partition_size,
accelerator=accelerator,
)
elif index_type == "IVF_PQ":
config = IvfPq(
return IvfPq(
distance_type=metric,
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
@@ -2322,18 +2610,20 @@ class LanceTable(Table):
max_iterations=max_iterations,
sample_rate=sample_rate,
target_partition_size=target_partition_size,
accelerator=accelerator,
)
elif index_type == "IVF_RQ":
config = IvfRq(
return IvfRq(
distance_type=metric,
num_partitions=num_partitions,
num_bits=num_bits,
max_iterations=max_iterations,
sample_rate=sample_rate,
target_partition_size=target_partition_size,
accelerator=accelerator,
)
elif index_type == "IVF_HNSW_PQ":
config = HnswPq(
return HnswPq(
distance_type=metric,
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
@@ -2343,9 +2633,10 @@ class LanceTable(Table):
m=m,
ef_construction=ef_construction,
target_partition_size=target_partition_size,
accelerator=accelerator,
)
elif index_type == "IVF_HNSW_SQ":
config = HnswSq(
return HnswSq(
distance_type=metric,
num_partitions=num_partitions,
max_iterations=max_iterations,
@@ -2353,9 +2644,10 @@ class LanceTable(Table):
m=m,
ef_construction=ef_construction,
target_partition_size=target_partition_size,
accelerator=accelerator,
)
elif index_type == "IVF_HNSW_FLAT":
config = HnswFlat(
return HnswFlat(
distance_type=metric,
num_partitions=num_partitions,
max_iterations=max_iterations,
@@ -2367,16 +2659,6 @@ class LanceTable(Table):
else:
raise ValueError(f"Unknown index type {index_type}")
return LOOP.run(
self._table.create_index(
vector_column_name,
replace=replace,
config=config,
name=name,
train=train,
)
)
def drop_index(self, name: str) -> None:
"""
Drops an index from the table
@@ -2476,6 +2758,11 @@ class LanceTable(Table):
"""
return LOOP.run(self._table.latest_storage_options())
@deprecation.deprecated(
deprecated_in="0.25.0",
current_version=__version__,
details="Use create_index() with config=BTree()/Bitmap()/LabelList() instead.",
)
def create_scalar_index(
self,
column: str,
@@ -2484,6 +2771,12 @@ class LanceTable(Table):
index_type: ScalarIndexType = "BTREE",
name: Optional[str] = None,
):
"""Create a scalar index on a column.
.. deprecated:: 0.25.0
Use :meth:`create_index` with a BTree, Bitmap, or LabelList config instead.
Example: ``table.create_index("column", config=BTree())``
"""
if index_type == "BTREE":
config = BTree()
elif index_type == "BITMAP":
@@ -2496,6 +2789,11 @@ class LanceTable(Table):
self._table.create_index(column, replace=replace, config=config, name=name)
)
@deprecation.deprecated(
deprecated_in="0.25.0",
current_version=__version__,
details="Use create_index() with config=FTS() instead.",
)
def create_fts_index(
self,
field_names: Union[str, List[str]],
@@ -2519,6 +2817,12 @@ class LanceTable(Table):
prefix_only: bool = False,
name: Optional[str] = None,
):
"""Create a full-text search index on a column.
.. deprecated:: 0.25.0
Use :meth:`create_index` with an FTS config instead.
Example: ``table.create_index("text_column", config=FTS())``
"""
self._ensure_no_legacy_fts_index()
if use_tantivy:
@@ -3297,6 +3601,11 @@ class LanceTable(Table):
[`AsyncTable.unset_lsm_write_spec`][lancedb.AsyncTable.unset_lsm_write_spec]."""
return LOOP.run(self._table.unset_lsm_write_spec())
def close_lsm_writers(self) -> None:
"""Close cached MemWAL shard writers. See
[`AsyncTable.close_lsm_writers`][lancedb.AsyncTable.close_lsm_writers]."""
return LOOP.run(self._table.close_lsm_writers())
def uses_v2_manifest_paths(self) -> bool:
"""
Check if the table is using the new v2 manifest paths.
@@ -3905,6 +4214,16 @@ class AsyncTable:
"""
await self._inner.unset_lsm_write_spec()
async def close_lsm_writers(self) -> None:
"""Drain and close any cached MemWAL shard writers for this table.
When an LSM write spec is installed, `merge_insert` opens MemWAL shard
writers and caches them for reuse across calls. This closes them,
flushing pending data; writers reopen lazily on the next
`merge_insert`. It is a no-op when no writers are cached.
"""
await self._inner.close_lsm_writers()
@property
def name(self) -> str:
"""The name of the table."""
@@ -4355,7 +4674,7 @@ class AsyncTable:
... .when_not_matched_insert_all() \\
... .execute(new_data)
>>> res
MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1)
MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1, num_rows=3)
>>> # The order of new rows is non-deterministic since we use
>>> # a hash-join as part of this operation and so we sort here
>>> table.to_arrow().sort_by("a").to_pandas()
@@ -4735,6 +5054,8 @@ class AsyncTable:
when_not_matched_by_source_condition=merge._when_not_matched_by_source_condition,
timeout=merge._timeout,
use_index=merge._use_index,
use_lsm_write=merge._use_lsm_write,
validate_single_shard=merge._validate_single_shard,
),
)

View File

@@ -57,7 +57,7 @@ async def test_upsert_async(mem_db_async):
await table.count_rows() # 3
res
# MergeResult(version=2, num_updated_rows=1,
# num_inserted_rows=1, num_deleted_rows=0)
# num_inserted_rows=1, num_deleted_rows=0, num_rows=2)
# --8<-- [end:upsert_basic_async]
assert await table.count_rows() == 3
assert res.version == 2
@@ -86,7 +86,7 @@ def test_insert_if_not_exists(mem_db):
table.count_rows() # 3
res
# MergeResult(version=2, num_updated_rows=0,
# num_inserted_rows=1, num_deleted_rows=0)
# num_inserted_rows=1, num_deleted_rows=0, num_rows=1)
# --8<-- [end:insert_if_not_exists]
assert table.count_rows() == 3
assert res.version == 2
@@ -116,7 +116,7 @@ async def test_insert_if_not_exists_async(mem_db_async):
await table.count_rows() # 3
res
# MergeResult(version=2, num_updated_rows=0,
# num_inserted_rows=1, num_deleted_rows=0)
# num_inserted_rows=1, num_deleted_rows=0, num_rows=1)
# --8<-- [end:insert_if_not_exists]
assert await table.count_rows() == 3
assert res.version == 2
@@ -150,7 +150,7 @@ def test_replace_range(mem_db):
table.count_rows("doc_id = 1") # 1
res
# MergeResult(version=2, num_updated_rows=1,
# num_inserted_rows=0, num_deleted_rows=1)
# num_inserted_rows=0, num_deleted_rows=1, num_rows=1)
# --8<-- [end:insert_if_not_exists]
assert table.count_rows("doc_id = 1") == 1
assert res.version == 2
@@ -185,7 +185,7 @@ async def test_replace_range_async(mem_db_async):
await table.count_rows("doc_id = 1") # 1
res
# MergeResult(version=2, num_updated_rows=1,
# num_inserted_rows=0, num_deleted_rows=1)
# num_inserted_rows=0, num_deleted_rows=1, num_rows=1)
# --8<-- [end:insert_if_not_exists]
assert await table.count_rows("doc_id = 1") == 1
assert res.version == 2

View File

@@ -466,7 +466,8 @@ async def test_create_table_v2_manifest_paths_async(tmp_path):
assert await tbl.uses_v2_manifest_paths()
manifests_dir = tmp_path / "test_v2_manifest_paths.lance" / "_versions"
for manifest in os.listdir(manifests_dir):
assert re.match(r"\d{20}\.manifest", manifest)
if manifest.endswith(".manifest"):
assert re.match(r"\d{20}\.manifest", manifest)
# Start a table in V1 mode then migrate
tbl = await db_no_v2_paths.create_table(
@@ -476,13 +477,15 @@ async def test_create_table_v2_manifest_paths_async(tmp_path):
assert not await tbl.uses_v2_manifest_paths()
manifests_dir = tmp_path / "test_v2_migration.lance" / "_versions"
for manifest in os.listdir(manifests_dir):
assert re.match(r"\d\.manifest", manifest)
if manifest.endswith(".manifest"):
assert re.match(r"\d\.manifest", manifest)
await tbl.migrate_manifest_paths_v2()
assert await tbl.uses_v2_manifest_paths()
for manifest in os.listdir(manifests_dir):
assert re.match(r"\d{20}\.manifest", manifest)
if manifest.endswith(".manifest"):
assert re.match(r"\d{20}\.manifest", manifest)
@pytest.mark.asyncio

View File

@@ -215,11 +215,12 @@ def test_reject_legacy_tantivy_index(table):
@pytest.mark.parametrize("with_position", [True, False])
def test_create_inverted_index(table, with_position):
table.create_fts_index(
"text",
with_position=with_position,
name="custom_fts_index",
)
with pytest.warns(DeprecationWarning, match="create_fts_index"):
table.create_fts_index(
"text",
with_position=with_position,
name="custom_fts_index",
)
indices = table.list_indices()
fts_indices = [i for i in indices if i.index_type == "FTS"]
assert any(i.name == "custom_fts_index" for i in fts_indices)

View File

@@ -162,12 +162,13 @@ async def test_create_bitmap_index(some_table: AsyncTable):
await some_table.create_index("data", config=Bitmap())
indices = await some_table.list_indices()
assert len(indices) == 3
# list_indices returns indices in alphabetical order by name
assert indices[0].index_type == "Bitmap"
assert indices[0].columns == ["id"]
assert indices[0].columns == ["data"]
assert indices[1].index_type == "Bitmap"
assert indices[1].columns == ["is_active"]
assert indices[1].columns == ["id"]
assert indices[2].index_type == "Bitmap"
assert indices[2].columns == ["data"]
assert indices[2].columns == ["is_active"]
index_name = indices[0].name
stats = await some_table.index_stats(index_name)

View File

@@ -40,16 +40,6 @@ def _make_table(tmp_path):
def test_set_lsm_write_spec_validates(tmp_path):
_db, table = _make_table(tmp_path)
# No PK set yet.
with pytest.raises(Exception, match="primary key"):
table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 4))
table.set_unenforced_primary_key("id")
# Column mismatch.
with pytest.raises(Exception, match="match"):
table.set_lsm_write_spec(LsmWriteSpec.bucket("v", 4))
# Out-of-range num_buckets.
with pytest.raises(Exception, match="num_buckets"):
table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 0))
@@ -70,7 +60,6 @@ def test_unset_lsm_write_spec(tmp_path):
table.unset_lsm_write_spec()
# Install a spec, then remove it; afterwards a fresh spec can be set.
table.set_unenforced_primary_key("id")
table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 4))
table.unset_lsm_write_spec()
# A second unset errors — there is no spec left to remove.

View File

@@ -0,0 +1,196 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""Tests for the MemWAL LSM ``merge_insert`` dispatch."""
from datetime import timedelta
import lancedb
import pyarrow as pa
import pytest
from lancedb._lancedb import LsmWriteSpec
SCHEMA = pa.schema(
[
pa.field("id", pa.int64(), nullable=False),
pa.field("value", pa.int64(), nullable=False),
]
)
REGION_SCHEMA = pa.schema(
[
pa.field("id", pa.int64(), nullable=False),
pa.field("region", pa.utf8(), nullable=False),
]
)
def _reader(ids):
batch = pa.RecordBatch.from_arrays(
[
pa.array(ids, type=pa.int64()),
pa.array(list(range(len(ids))), type=pa.int64()),
],
schema=SCHEMA,
)
return pa.RecordBatchReader.from_batches(SCHEMA, [batch])
def _region_reader(rows):
batch = pa.RecordBatch.from_arrays(
[
pa.array([row[0] for row in rows], type=pa.int64()),
pa.array([row[1] for row in rows], type=pa.utf8()),
],
schema=REGION_SCHEMA,
)
return pa.RecordBatchReader.from_batches(REGION_SCHEMA, [batch])
def _bucket_table(tmp_path):
"""A table with ``id`` as the primary key and a single-bucket LSM spec."""
db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
table = db.create_table("t", _reader([1, 2, 3]))
table.set_unenforced_primary_key("id")
# num_buckets = 1: every row routes to the single bucket.
table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 1))
return table
def test_lsm_merge_insert_bucket(tmp_path):
table = _bucket_table(tmp_path)
# Empty `on` defaults to the primary key.
result = (
table.merge_insert([])
.when_matched_update_all()
.when_not_matched_insert_all()
.execute(_reader([3, 4, 5]))
)
# LSM path: rows go to the MemWAL, so only num_rows is populated.
assert result.num_rows == 3
assert result.version == 0
assert result.num_inserted_rows == 0
assert result.num_updated_rows == 0
def test_lsm_merge_insert_unsharded(tmp_path):
db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
table = db.create_table("t", _reader([1, 2, 3]))
table.set_unenforced_primary_key("id")
table.set_lsm_write_spec(LsmWriteSpec.unsharded())
result = (
table.merge_insert("id")
.when_matched_update_all()
.when_not_matched_insert_all()
.execute(_reader([10, 11, 12, 13]))
)
assert result.num_rows == 4
def test_lsm_merge_insert_identity(tmp_path):
db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
table = db.create_table("t", _region_reader([(1, "us"), (2, "us")]))
table.set_unenforced_primary_key("id")
table.set_lsm_write_spec(LsmWriteSpec.identity("region"))
# All rows share one identity value, so they route to one shard.
result = (
table.merge_insert([])
.when_matched_update_all()
.when_not_matched_insert_all()
.execute(_region_reader([(3, "us"), (4, "us")]))
)
assert result.num_rows == 2
def test_lsm_merge_insert_use_lsm_write_false(tmp_path):
table = _bucket_table(tmp_path) # rows id = 1, 2, 3
# use_lsm_write(False) opts out: the standard path runs and commits.
result = (
table.merge_insert("id")
.when_not_matched_insert_all()
.use_lsm_write(False)
.execute(_reader([3, 4, 5]))
)
assert result.num_inserted_rows == 2
assert table.count_rows() == 5
def test_lsm_merge_insert_validate_single_shard_off(tmp_path):
table = _bucket_table(tmp_path)
result = (
table.merge_insert([])
.when_matched_update_all()
.when_not_matched_insert_all()
.validate_single_shard(False)
.execute(_reader([6, 7, 8]))
)
assert result.num_rows == 3
def test_lsm_merge_insert_use_lsm_write_true_requires_spec(tmp_path):
# A table with a primary key but no LSM write spec installed.
db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
table = db.create_table("t", _reader([1, 2, 3]))
table.set_unenforced_primary_key("id")
with pytest.raises(Exception, match="use_lsm_write"):
(
table.merge_insert("id")
.when_matched_update_all()
.when_not_matched_insert_all()
.use_lsm_write(True)
.execute(_reader([4]))
)
def test_lsm_merge_insert_rejects_on_not_primary_key(tmp_path):
table = _bucket_table(tmp_path)
with pytest.raises(Exception, match="primary key"):
(
table.merge_insert("value")
.when_matched_update_all()
.when_not_matched_insert_all()
.execute(_reader([1]))
)
def test_lsm_merge_insert_rejects_non_upsert(tmp_path):
table = _bucket_table(tmp_path)
# Insert-only (no when_matched_update_all) is not the upsert shape.
with pytest.raises(Exception, match="upsert"):
table.merge_insert([]).when_not_matched_insert_all().execute(_reader([4]))
def test_lsm_close_writers(tmp_path):
table = _bucket_table(tmp_path)
(
table.merge_insert([])
.when_matched_update_all()
.when_not_matched_insert_all()
.execute(_reader([7, 8]))
)
table.close_lsm_writers()
# The writer reopens lazily on the next merge_insert.
result = (
table.merge_insert([])
.when_matched_update_all()
.when_not_matched_insert_all()
.execute(_reader([9]))
)
assert result.num_rows == 1
@pytest.mark.asyncio
async def test_async_lsm_merge_insert(tmp_path):
db = await lancedb.connect_async(
tmp_path, read_consistency_interval=timedelta(seconds=0)
)
table = await db.create_table("t", _reader([1, 2, 3]))
await table.set_unenforced_primary_key("id")
await table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 1))
builder = (
table.merge_insert([]).when_matched_update_all().when_not_matched_insert_all()
)
result = await builder.execute(_reader([3, 4, 5]))
assert result.num_rows == 3
await table.close_lsm_writers()

View File

@@ -586,22 +586,25 @@ def test_table_create_indices():
# This is a smoke-test.
table = db.create_table("test", [{"id": 1}])
# Test create_scalar_index with custom name
table.create_scalar_index(
"id", wait_timeout=timedelta(seconds=2), name="custom_scalar_idx"
)
# Test create_scalar_index with custom name (legacy method)
with pytest.warns(DeprecationWarning, match="create_scalar_index"):
table.create_scalar_index(
"id", wait_timeout=timedelta(seconds=2), name="custom_scalar_idx"
)
# Test create_fts_index with custom name
table.create_fts_index(
"text", wait_timeout=timedelta(seconds=2), name="custom_fts_idx"
)
# Test create_fts_index with custom name (legacy method)
with pytest.warns(DeprecationWarning, match="create_fts_index"):
table.create_fts_index(
"text", wait_timeout=timedelta(seconds=2), name="custom_fts_idx"
)
# Test create_index with custom name
table.create_index(
vector_column_name="vector",
wait_timeout=timedelta(seconds=10),
name="custom_vector_idx",
)
# Test create_index with custom name (legacy form: vector_column_name kwarg)
with pytest.warns(DeprecationWarning, match="create_index"):
table.create_index(
vector_column_name="vector",
wait_timeout=timedelta(seconds=10),
name="custom_vector_idx",
)
# Validate that the name parameter was passed correctly in requests
assert len(received_requests) == 3
@@ -630,6 +633,98 @@ def test_table_create_indices():
table.drop_index("custom_fts_idx")
def test_remote_create_index_new_api():
received_requests = []
def handler(request):
if request.path == "/v1/table/test/create_index/":
content_len = int(request.headers.get("Content-Length", 0))
body = request.rfile.read(content_len) if content_len > 0 else b""
received_requests.append(json.loads(body) if body else {})
request.send_response(200)
request.end_headers()
elif request.path == "/v1/table/test/create/?mode=create":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"{}")
elif request.path == "/v1/table/test/describe/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(
json.dumps(
dict(
version=1,
schema=dict(
fields=[
dict(name="id", type={"type": "int64"}, nullable=False),
dict(
name="category",
type={"type": "string"},
nullable=False,
),
dict(
name="text", type={"type": "string"}, nullable=False
),
dict(
name="vector",
type={
"type": "fixed_size_list",
"fields": [
dict(
name="item",
type={"type": "float"},
nullable=True,
)
],
"length": 2,
},
nullable=False,
),
]
),
)
).encode()
)
else:
request.send_response(404)
request.end_headers()
from lancedb.index import BTree, FTS, IvfPq, IvfRq
with mock_lancedb_connection(handler) as db:
table = db.create_table("test", [{"id": 1}])
# New API: column-first, config= kwarg. Should NOT emit DeprecationWarning.
import warnings as _warnings
with _warnings.catch_warnings():
_warnings.simplefilter("error", DeprecationWarning)
table.create_index("vector", config=IvfPq(distance_type="l2"))
table.create_index("category", config=BTree())
table.create_index("text", config=FTS())
# IvfRq via new API
table.create_index("vector", config=IvfRq(distance_type="l2"))
# Legacy index_type="IVF_RQ" routes to IvfRq config under the hood.
with pytest.warns(DeprecationWarning, match="create_index"):
table.create_index(
vector_column_name="vector",
index_type="IVF_RQ",
num_partitions=8,
)
assert len(received_requests) == 5
assert [req["column"] for req in received_requests] == [
"vector",
"category",
"text",
"vector",
"vector",
]
def test_table_wait_for_index_timeout():
def handler(request):
index_stats = dict(

View File

@@ -603,3 +603,89 @@ def test_cross_encoder_reranker_return_all(tmp_path):
assert "_relevance_score" in result.column_names
assert "_score" in result.column_names
assert "_distance" in result.column_names
# ---------------------------------------------------------------------------
# Regression tests for LinearCombinationReranker scoring bugs (issue #3154)
# ---------------------------------------------------------------------------
def test_linear_combination_best_match_ranks_first():
"""
The document that is BOTH the closest vector match AND the only FTS match
must rank first. Previously _combine_score subtracted from 1, inverting
the ranking so the worst document ranked highest.
"""
reranker = LinearCombinationReranker(weight=0.7, return_score="all")
# rowid 0: perfect vector match, sole FTS match → should rank 1st
# rowid 1: mediocre vector, no FTS match
# rowid 2: bad vector, no FTS match
vector_results = pa.Table.from_pydict(
{
"_rowid": [0, 1, 2],
"_distance": [0.0, 0.5, 0.9],
}
)
fts_results = pa.Table.from_pydict(
{
"_rowid": [0],
"_score": [1.0],
}
)
combined = reranker.merge_results(vector_results, fts_results, fill=1.0)
scores = dict(
zip(
combined["_rowid"].to_pylist(),
combined["_relevance_score"].to_pylist(),
)
)
# rowid 0 must have the highest relevance score
assert scores[0] > scores[1], (
f"Best match (rowid 0, score={scores[0]:.4f}) should beat "
f"mid match (rowid 1, score={scores[1]:.4f})"
)
assert scores[1] > scores[2], (
f"Mid match (rowid 1, score={scores[1]:.4f}) should beat "
f"bad match (rowid 2, score={scores[2]:.4f})"
)
def test_linear_combination_missing_fts_is_penalised():
"""
A document with no FTS match must score *lower* than a document that
has a mediocre FTS match, everything else being equal. Previously
missing-FTS entries used fill=1.0 directly, which gave them a reward
(via the 1-(...) inversion) instead of a penalty.
"""
reranker = LinearCombinationReranker(weight=0.5, return_score="all")
vector_results = pa.Table.from_pydict(
{
"_rowid": [0, 1],
"_distance": [0.2, 0.2], # identical vector scores
}
)
fts_results = pa.Table.from_pydict(
{
"_rowid": [0], # rowid 1 has no FTS match
"_score": [0.3], # small FTS score
}
)
combined = reranker.merge_results(vector_results, fts_results, fill=1.0)
scores = dict(
zip(
combined["_rowid"].to_pylist(),
combined["_relevance_score"].to_pylist(),
)
)
# rowid 0 has a small FTS score; rowid 1 has none.
# Even a small FTS contribution should beat having none at all.
assert scores[0] > scores[1], (
f"Document with FTS score (rowid 0, {scores[0]:.4f}) should beat "
f"document with no FTS match (rowid 1, {scores[1]:.4f})"
)

View File

@@ -4,6 +4,7 @@
import os
import sys
import warnings
from datetime import date, datetime, timedelta
from time import sleep
from typing import List
@@ -11,7 +12,7 @@ from unittest.mock import patch
import lancedb
from lancedb.dependencies import _PANDAS_AVAILABLE
from lancedb.index import HnswFlat, HnswPq, HnswSq, IvfPq
from lancedb.index import BTree, FTS, HnswFlat, HnswPq, HnswSq, IvfPq
import numpy as np
import polars as pl
import pyarrow as pa
@@ -928,7 +929,12 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
num_bits=4,
)
mock_create_index.assert_called_with(
"vector", replace=True, config=expected_config, name=None, train=True
"vector",
replace=True,
config=expected_config,
wait_timeout=None,
name=None,
train=True,
)
# Test with target_partition_size
@@ -948,7 +954,12 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
target_partition_size=8192,
)
mock_create_index.assert_called_with(
"vector", replace=True, config=expected_config, name=None, train=True
"vector",
replace=True,
config=expected_config,
wait_timeout=None,
name=None,
train=True,
)
# target_partition_size has a default value,
@@ -967,7 +978,12 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
num_bits=4,
)
mock_create_index.assert_called_with(
"vector", replace=True, config=expected_config, name=None, train=True
"vector",
replace=True,
config=expected_config,
wait_timeout=None,
name=None,
train=True,
)
table.create_index(
@@ -978,7 +994,12 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
)
expected_config = HnswPq(distance_type="dot")
mock_create_index.assert_called_with(
"my_vector", replace=False, config=expected_config, name=None, train=True
"my_vector",
replace=False,
config=expected_config,
wait_timeout=None,
name=None,
train=True,
)
table.create_index(
@@ -993,7 +1014,12 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
distance_type="cosine", sample_rate=0.1, m=29, ef_construction=10
)
mock_create_index.assert_called_with(
"my_vector", replace=True, config=expected_config, name=None, train=True
"my_vector",
replace=True,
config=expected_config,
wait_timeout=None,
name=None,
train=True,
)
table.create_index(
@@ -1008,7 +1034,12 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
distance_type="cosine", sample_rate=0.1, m=29, ef_construction=10
)
mock_create_index.assert_called_with(
"my_vector", replace=True, config=expected_config, name=None, train=True
"my_vector",
replace=True,
config=expected_config,
wait_timeout=None,
name=None,
train=True,
)
@@ -1032,6 +1063,7 @@ def test_create_index_name_and_train_parameters(
"vector",
replace=True,
config=expected_config,
wait_timeout=None,
name="my_custom_index",
train=True,
)
@@ -1039,13 +1071,82 @@ def test_create_index_name_and_train_parameters(
# Test with train=False
table.create_index(vector_column_name="vector", train=False)
mock_create_index.assert_called_with(
"vector", replace=True, config=expected_config, name=None, train=False
"vector",
replace=True,
config=expected_config,
wait_timeout=None,
name=None,
train=False,
)
# Test with both name and train
table.create_index(vector_column_name="vector", name="my_index_name", train=True)
mock_create_index.assert_called_with(
"vector", replace=True, config=expected_config, name="my_index_name", train=True
"vector",
replace=True,
config=expected_config,
wait_timeout=None,
name="my_index_name",
train=True,
)
@patch("lancedb.table.AsyncTable.create_index")
def test_create_index_legacy_emits_deprecation_warning(
mock_create_index, mem_db: DBConnection
):
table = mem_db.create_table(
"test",
data=[{"vector": [3.1, 4.1]}, {"vector": [5.9, 26.5]}],
)
with pytest.warns(DeprecationWarning, match="create_index"):
table.create_index(metric="l2", num_partitions=8, vector_column_name="vector")
@patch("lancedb.table.AsyncTable.create_index")
def test_create_index_new_api(mock_create_index, mem_db: DBConnection):
table = mem_db.create_table(
"test",
data=[
{"vector": [3.1, 4.1], "category": "a", "text": "hello world"},
{"vector": [5.9, 26.5], "category": "b", "text": "goodbye"},
],
)
# Vector index via new API should not warn
with warnings.catch_warnings():
warnings.simplefilter("error", DeprecationWarning)
table.create_index("vector", config=IvfPq(distance_type="l2"))
mock_create_index.assert_called_with(
"vector",
replace=True,
config=IvfPq(distance_type="l2"),
wait_timeout=None,
name=None,
train=True,
)
# Scalar index via new API
table.create_index("category", config=BTree())
mock_create_index.assert_called_with(
"category",
replace=True,
config=BTree(),
wait_timeout=None,
name=None,
train=True,
)
# FTS index via new API
table.create_index("text", config=FTS(with_position=True))
mock_create_index.assert_called_with(
"text",
replace=True,
config=FTS(with_position=True),
wait_timeout=None,
name=None,
train=True,
)
@@ -1861,8 +1962,9 @@ def test_create_scalar_index(mem_db: DBConnection):
"my_table",
data=test_data,
)
# Test with default name
table.create_scalar_index("x")
# Test with default name; confirm DeprecationWarning fires
with pytest.warns(DeprecationWarning, match="create_scalar_index"):
table.create_scalar_index("x")
indices = table.list_indices()
assert len(indices) == 1
scalar_index = indices[0]

View File

@@ -143,18 +143,20 @@ pub struct MergeResult {
pub num_inserted_rows: u64,
pub num_deleted_rows: u64,
pub num_attempts: u32,
pub num_rows: u64,
}
#[pymethods]
impl MergeResult {
pub fn __repr__(&self) -> String {
format!(
"MergeResult(version={}, num_updated_rows={}, num_inserted_rows={}, num_deleted_rows={}, num_attempts={})",
"MergeResult(version={}, num_updated_rows={}, num_inserted_rows={}, num_deleted_rows={}, num_attempts={}, num_rows={})",
self.version,
self.num_updated_rows,
self.num_inserted_rows,
self.num_deleted_rows,
self.num_attempts
self.num_attempts,
self.num_rows
)
}
}
@@ -167,6 +169,7 @@ impl From<lancedb::table::MergeResult> for MergeResult {
num_inserted_rows: result.num_inserted_rows,
num_deleted_rows: result.num_deleted_rows,
num_attempts: result.num_attempts,
num_rows: result.num_rows,
}
}
}
@@ -194,6 +197,12 @@ impl LsmWriteSpec {
}
/// Identity sharding — shard by the raw value of `column`.
///
/// `column` must be a deterministic function of the unenforced primary
/// key: every row with a given primary key must always produce the same
/// `column` value, or upserts of that key can land in different shards
/// and a stale version can win. Typically `column` is the primary key
/// itself or a stable attribute of it.
#[staticmethod]
pub fn identity(column: String) -> Self {
Self {
@@ -933,6 +942,12 @@ impl Table {
if let Some(use_index) = parameters.use_index {
builder.use_index(use_index);
}
if let Some(use_lsm_write) = parameters.use_lsm_write {
builder.use_lsm_write(use_lsm_write);
}
if let Some(validate_single_shard) = parameters.validate_single_shard {
builder.validate_single_shard(validate_single_shard);
}
future_into_py(self_.py(), async move {
let res = builder.execute(Box::new(batches)).await.infer_error()?;
@@ -971,6 +986,13 @@ impl Table {
})
}
pub fn close_lsm_writers(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
inner.close_lsm_writers().await.infer_error()
})
}
pub fn uses_v2_manifest_paths(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
@@ -1124,6 +1146,8 @@ pub struct MergeInsertParams {
when_not_matched_by_source_condition: Option<String>,
timeout: Option<std::time::Duration>,
use_index: Option<bool>,
use_lsm_write: Option<bool>,
validate_single_shard: Option<bool>,
}
#[pyclass]