mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-02 20:00:46 +00:00
Merge remote-tracking branch 'origin/main' into xuanwo/remote-pytorch-multiprocessing
# Conflicts: # python/python/lancedb/remote/table.py
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.33.0-beta.0"
|
||||
current_version = "0.33.0-beta.1"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.33.0-beta.0"
|
||||
version = "0.33.0-beta.1"
|
||||
publish = false
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
|
||||
@@ -94,7 +94,6 @@ def connect(
|
||||
host_override: str, optional
|
||||
The override url for LanceDB Cloud.
|
||||
read_consistency_interval: timedelta, default None
|
||||
(For LanceDB OSS only)
|
||||
The interval at which to check for updates to the table from other
|
||||
processes. If None, then consistency is not checked. For performance
|
||||
reasons, this is the default. For strong consistency, set this to
|
||||
@@ -104,6 +103,10 @@ def connect(
|
||||
the last check, then the table will be checked for updates. Note: this
|
||||
consistency only applies to read operations. Write operations are
|
||||
always consistent.
|
||||
|
||||
Stronger consistency is not free. The smaller the interval, the more
|
||||
often each read pays the cost of checking for updates against object
|
||||
storage, raising per-read latency and cost.
|
||||
client_config: ClientConfig or dict, optional
|
||||
Configuration options for the LanceDB Cloud HTTP client. If a dict, then
|
||||
the keys are the attributes of the ClientConfig class. If None, then the
|
||||
@@ -147,6 +150,13 @@ def connect(
|
||||
>>> db = lancedb.connect("s3://my-bucket/lancedb",
|
||||
... storage_options={"aws_access_key_id": "***"})
|
||||
|
||||
For tests and temporary data, use an in-memory database:
|
||||
|
||||
>>> db = lancedb.connect("memory://")
|
||||
|
||||
In-memory databases are not persisted. Tables are dropped when the last
|
||||
connection or table handle referencing them is closed.
|
||||
|
||||
Connect to LanceDB cloud:
|
||||
|
||||
>>> db = lancedb.connect("db://my_database", api_key="ldb_...",
|
||||
@@ -210,6 +220,7 @@ def connect(
|
||||
request_thread_pool=request_thread_pool,
|
||||
client_config=client_config,
|
||||
storage_options=storage_options,
|
||||
read_consistency_interval=read_consistency_interval,
|
||||
**kwargs,
|
||||
)
|
||||
_check_s3_bucket_with_dots(str(uri), storage_options)
|
||||
@@ -345,7 +356,6 @@ async def connect_async(
|
||||
host_override: str, optional
|
||||
The override url for LanceDB Cloud.
|
||||
read_consistency_interval: timedelta, default None
|
||||
(For LanceDB OSS only)
|
||||
The interval at which to check for updates to the table from other
|
||||
processes. If None, then consistency is not checked. For performance
|
||||
reasons, this is the default. For strong consistency, set this to
|
||||
@@ -355,6 +365,10 @@ async def connect_async(
|
||||
the last check, then the table will be checked for updates. Note: this
|
||||
consistency only applies to read operations. Write operations are
|
||||
always consistent.
|
||||
|
||||
Stronger consistency is not free. The smaller the interval, the more
|
||||
often each read pays the cost of checking for updates against object
|
||||
storage, raising per-read latency and cost.
|
||||
client_config: ClientConfig or dict, optional
|
||||
Configuration options for the LanceDB Cloud HTTP client. If a dict, then
|
||||
the keys are the attributes of the ClientConfig class. If None, then the
|
||||
@@ -387,6 +401,8 @@ async def connect_async(
|
||||
... db = await lancedb.connect_async("s3://my-bucket/lancedb",
|
||||
... storage_options={
|
||||
... "aws_access_key_id": "***"})
|
||||
... # For tests and temporary data, use an in-memory database
|
||||
... db = await lancedb.connect_async("memory://")
|
||||
... # Connect to LanceDB cloud
|
||||
... db = await lancedb.connect_async("db://my_database", api_key="ldb_...",
|
||||
... client_config={
|
||||
|
||||
@@ -220,6 +220,7 @@ class Table:
|
||||
async def set_unenforced_primary_key(self, columns: List[str]) -> None: ...
|
||||
async def set_lsm_write_spec(self, spec: LsmWriteSpec) -> None: ...
|
||||
async def unset_lsm_write_spec(self) -> None: ...
|
||||
async def close_lsm_writers(self) -> None: ...
|
||||
@property
|
||||
def tags(self) -> Tags: ...
|
||||
def query(self) -> Query: ...
|
||||
@@ -420,6 +421,7 @@ class MergeResult:
|
||||
num_inserted_rows: int
|
||||
num_deleted_rows: int
|
||||
num_attempts: int
|
||||
num_rows: int
|
||||
|
||||
class LsmWriteSpec:
|
||||
"""Specification selecting Lance's MemWAL LSM-style write path for
|
||||
|
||||
@@ -281,6 +281,9 @@ class HnswPq:
|
||||
m: int = 20
|
||||
ef_construction: int = 300
|
||||
target_partition_size: Optional[int] = None
|
||||
# Name of the accelerator (e.g. "cuda") to use for IVF training. When set,
|
||||
# create_index() dispatches to pylance to build the index on the accelerator.
|
||||
accelerator: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -386,6 +389,9 @@ class HnswSq:
|
||||
m: int = 20
|
||||
ef_construction: int = 300
|
||||
target_partition_size: Optional[int] = None
|
||||
# Name of the accelerator (e.g. "cuda") to use for IVF training. When set,
|
||||
# create_index() dispatches to pylance to build the index on the accelerator.
|
||||
accelerator: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -579,6 +585,9 @@ class IvfFlat:
|
||||
max_iterations: int = 50
|
||||
sample_rate: int = 256
|
||||
target_partition_size: Optional[int] = None
|
||||
# Name of the accelerator (e.g. "cuda") to use for IVF training. When set,
|
||||
# create_index() dispatches to pylance to build the index on the accelerator.
|
||||
accelerator: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -609,6 +618,9 @@ class IvfSq:
|
||||
max_iterations: int = 50
|
||||
sample_rate: int = 256
|
||||
target_partition_size: Optional[int] = None
|
||||
# Name of the accelerator (e.g. "cuda") to use for IVF training. When set,
|
||||
# create_index() dispatches to pylance to build the index on the accelerator.
|
||||
accelerator: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -739,6 +751,9 @@ class IvfPq:
|
||||
max_iterations: int = 50
|
||||
sample_rate: int = 256
|
||||
target_partition_size: Optional[int] = None
|
||||
# Name of the accelerator (e.g. "cuda") to use for IVF training. When set,
|
||||
# create_index() dispatches to pylance to build the index on the accelerator.
|
||||
accelerator: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -792,6 +807,9 @@ class IvfRq:
|
||||
max_iterations: int = 50
|
||||
sample_rate: int = 256
|
||||
target_partition_size: Optional[int] = None
|
||||
# Name of the accelerator (e.g. "cuda") to use for IVF training. When set,
|
||||
# create_index() dispatches to pylance to build the index on the accelerator.
|
||||
accelerator: Optional[str] = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -34,6 +34,8 @@ class LanceMergeInsertBuilder(object):
|
||||
self._when_not_matched_by_source_condition = None
|
||||
self._timeout = None
|
||||
self._use_index = True
|
||||
self._use_lsm_write = None
|
||||
self._validate_single_shard = None
|
||||
|
||||
def when_matched_update_all(
|
||||
self, *, where: Optional[str] = None
|
||||
@@ -96,6 +98,46 @@ class LanceMergeInsertBuilder(object):
|
||||
self._use_index = use_index
|
||||
return self
|
||||
|
||||
def use_lsm_write(self, use_lsm_write: bool) -> LanceMergeInsertBuilder:
|
||||
"""
|
||||
Controls whether the merge uses the MemWAL LSM write path.
|
||||
|
||||
By default (unset), a `merge_insert` on a table with an LSM write spec
|
||||
is routed through Lance's MemWAL shard writer, and a table without one
|
||||
uses the standard path. Pass `False` to force the standard path even
|
||||
when a spec is set. Pass `True` to require a spec — `merge_insert`
|
||||
raises an error if none is installed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
use_lsm_write: bool
|
||||
Whether to use the LSM write path.
|
||||
"""
|
||||
self._use_lsm_write = use_lsm_write
|
||||
return self
|
||||
|
||||
def validate_single_shard(
|
||||
self, validate_single_shard: bool
|
||||
) -> LanceMergeInsertBuilder:
|
||||
"""
|
||||
Controls how an LSM merge checks that its input targets a single shard.
|
||||
|
||||
When a table has an LSM write spec, every row in a `merge_insert` call
|
||||
must route to the same shard. When `True` (the default), every row is
|
||||
inspected to verify this. When `False`, only the first row is inspected
|
||||
and the shard it routes to is used for the whole input — a faster path
|
||||
for callers that have already pre-sharded their input.
|
||||
|
||||
Has no effect on tables without an LSM write spec.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
validate_single_shard: bool
|
||||
Whether to check every row routes to one shard. Defaults to `True`.
|
||||
"""
|
||||
self._validate_single_shard = validate_single_shard
|
||||
return self
|
||||
|
||||
def execute(
|
||||
self,
|
||||
new_data: DATA,
|
||||
|
||||
@@ -109,6 +109,7 @@ class RemoteDBConnection(DBConnection):
|
||||
connection_timeout: Optional[float] = None,
|
||||
read_timeout: Optional[float] = None,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
read_consistency_interval: Optional[timedelta] = None,
|
||||
):
|
||||
"""Connect to a remote LanceDB database."""
|
||||
if isinstance(client_config, dict):
|
||||
@@ -167,6 +168,7 @@ class RemoteDBConnection(DBConnection):
|
||||
host_override=host_override,
|
||||
client_config=client_config,
|
||||
storage_options=storage_options,
|
||||
read_consistency_interval=read_consistency_interval,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -2,12 +2,25 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from datetime import timedelta
|
||||
import deprecation
|
||||
import logging
|
||||
from functools import cached_property
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union, Literal
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
Literal,
|
||||
overload,
|
||||
)
|
||||
import warnings
|
||||
|
||||
from lancedb import __version__
|
||||
|
||||
from lancedb._lancedb import (
|
||||
AddColumnsResult,
|
||||
AddResult,
|
||||
@@ -33,6 +46,7 @@ from lancedb.index import (
|
||||
LabelList,
|
||||
)
|
||||
from lancedb.remote.db import LOOP
|
||||
from lancedb.table import IndexConfigType, KNOWN_METRICS
|
||||
import pyarrow as pa
|
||||
|
||||
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
@@ -195,6 +209,11 @@ class RemoteTable(Table):
|
||||
"""List all the stats of a specified index"""
|
||||
return LOOP.run(self._table.index_stats(index_uuid))
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.25.0",
|
||||
current_version=__version__,
|
||||
details="Use create_index() with config=BTree()/Bitmap()/LabelList() instead.",
|
||||
)
|
||||
def create_scalar_index(
|
||||
self,
|
||||
column: str,
|
||||
@@ -204,7 +223,12 @@ class RemoteTable(Table):
|
||||
wait_timeout: Optional[timedelta] = None,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Creates a scalar index
|
||||
"""Creates a scalar index.
|
||||
|
||||
.. deprecated:: 0.25.0
|
||||
Use :meth:`create_index` with a BTree, Bitmap, or LabelList config instead.
|
||||
Example: ``table.create_index("column", config=BTree())``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
column : str
|
||||
@@ -235,6 +259,11 @@ class RemoteTable(Table):
|
||||
)
|
||||
)
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.25.0",
|
||||
current_version=__version__,
|
||||
details="Use create_index() with config=FTS() instead.",
|
||||
)
|
||||
def create_fts_index(
|
||||
self,
|
||||
column: str,
|
||||
@@ -255,6 +284,12 @@ class RemoteTable(Table):
|
||||
prefix_only: bool = False,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Create a full-text search index on a column.
|
||||
|
||||
.. deprecated:: 0.25.0
|
||||
Use :meth:`create_index` with an FTS config instead.
|
||||
Example: ``table.create_index("text_column", config=FTS())``
|
||||
"""
|
||||
config = FTS(
|
||||
with_position=with_position,
|
||||
base_tokenizer=base_tokenizer,
|
||||
@@ -278,9 +313,43 @@ class RemoteTable(Table):
|
||||
)
|
||||
)
|
||||
|
||||
# New unified API overload
|
||||
@overload
|
||||
def create_index(
|
||||
self,
|
||||
metric="l2",
|
||||
column: str,
|
||||
/,
|
||||
*,
|
||||
config: IndexConfigType,
|
||||
wait_timeout: Optional[timedelta] = ...,
|
||||
name: Optional[str] = ...,
|
||||
train: bool = ...,
|
||||
) -> None: ...
|
||||
|
||||
# Legacy API overload (deprecated)
|
||||
@overload
|
||||
def create_index(
|
||||
self,
|
||||
metric: Literal["l2", "cosine", "dot", "hamming"] = ...,
|
||||
vector_column_name: str = ...,
|
||||
index_cache_size: Optional[int] = ...,
|
||||
num_partitions: Optional[int] = ...,
|
||||
num_sub_vectors: Optional[int] = ...,
|
||||
replace: Optional[bool] = ...,
|
||||
accelerator: Optional[str] = ...,
|
||||
index_type: Literal[
|
||||
"VECTOR", "IVF_FLAT", "IVF_SQ", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
|
||||
] = ...,
|
||||
wait_timeout: Optional[timedelta] = ...,
|
||||
*,
|
||||
num_bits: int = ...,
|
||||
name: Optional[str] = ...,
|
||||
train: bool = ...,
|
||||
) -> None: ...
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
metric: str = "l2",
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
index_cache_size: Optional[int] = None,
|
||||
num_partitions: Optional[int] = None,
|
||||
@@ -291,89 +360,113 @@ class RemoteTable(Table):
|
||||
wait_timeout: Optional[timedelta] = None,
|
||||
*,
|
||||
num_bits: int = 8,
|
||||
config: Optional[IndexConfigType] = None,
|
||||
name: Optional[str] = None,
|
||||
train: bool = True,
|
||||
):
|
||||
"""Create an index on the table.
|
||||
"""Create an index on a column.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric : str
|
||||
The metric to use for the index. Default is "l2".
|
||||
vector_column_name : str
|
||||
The name of the vector column. Default is "vector".
|
||||
This method supports both the new unified API and the legacy API
|
||||
for backwards compatibility. The new API takes the column name as the
|
||||
first positional argument and an index configuration object via
|
||||
``config``; the legacy API takes the distance metric as the first
|
||||
argument plus separate ``vector_column_name`` / ``num_partitions`` /
|
||||
etc. parameters, and emits a ``DeprecationWarning``.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> import uuid
|
||||
>>> from lancedb.schema import vector
|
||||
>>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
|
||||
... region="...") # doctest: +SKIP
|
||||
>>> table_name = uuid.uuid4().hex
|
||||
>>> schema = pa.schema(
|
||||
... [
|
||||
... pa.field("id", pa.uint32(), False),
|
||||
... pa.field("vector", vector(128), False),
|
||||
... pa.field("s", pa.string(), False),
|
||||
... ]
|
||||
New API (recommended):
|
||||
|
||||
>>> table.create_index( # doctest: +SKIP
|
||||
... "vector", config=IvfPq(distance_type="l2")
|
||||
... )
|
||||
>>> table = db.create_table( # doctest: +SKIP
|
||||
... table_name, # doctest: +SKIP
|
||||
... schema=schema, # doctest: +SKIP
|
||||
>>> table.create_index("category", config=BTree()) # doctest: +SKIP
|
||||
>>> table.create_index("content", config=FTS()) # doctest: +SKIP
|
||||
|
||||
Legacy API (deprecated):
|
||||
|
||||
>>> table.create_index( # doctest: +SKIP
|
||||
... "l2", vector_column_name="vector"
|
||||
... )
|
||||
>>> table.create_index("l2", "vector") # doctest: +SKIP
|
||||
"""
|
||||
# Detect whether this is a legacy API call
|
||||
is_legacy = self._is_legacy_create_index_call(
|
||||
metric,
|
||||
config,
|
||||
num_partitions,
|
||||
num_sub_vectors,
|
||||
vector_column_name,
|
||||
accelerator,
|
||||
index_cache_size,
|
||||
replace,
|
||||
)
|
||||
|
||||
if accelerator is not None:
|
||||
logging.warning(
|
||||
"GPU accelerator is not yet supported on LanceDB cloud."
|
||||
"If you have 100M+ vectors to index,"
|
||||
"please contact us at contact@lancedb.com"
|
||||
)
|
||||
if replace is not None:
|
||||
logging.warning(
|
||||
"replace is not supported on LanceDB cloud."
|
||||
"Existing indexes will always be replaced."
|
||||
if is_legacy:
|
||||
warnings.warn(
|
||||
"The create_index() API with metric/num_partitions parameters is "
|
||||
"deprecated and will be removed in a future version. "
|
||||
"Please migrate to the new unified API:\n"
|
||||
" # Old (deprecated):\n"
|
||||
" table.create_index('l2', vector_column_name='my_vector')\n"
|
||||
" # New (recommended):\n"
|
||||
" table.create_index('my_vector', config=IvfPq(distance_type='l2'))",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
index_type = index_type.upper()
|
||||
if index_type == "VECTOR" or index_type == "IVF_PQ":
|
||||
config = IvfPq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
num_bits=num_bits,
|
||||
)
|
||||
elif index_type == "IVF_RQ":
|
||||
config = IvfRq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_bits=num_bits,
|
||||
)
|
||||
elif index_type == "IVF_SQ":
|
||||
config = IvfSq(distance_type=metric, num_partitions=num_partitions)
|
||||
elif index_type == "IVF_HNSW_PQ":
|
||||
raise ValueError(
|
||||
"IVF_HNSW_PQ is not supported on LanceDB cloud."
|
||||
"Please use IVF_HNSW_SQ instead."
|
||||
)
|
||||
elif index_type == "IVF_HNSW_SQ":
|
||||
config = HnswSq(distance_type=metric, num_partitions=num_partitions)
|
||||
elif index_type == "IVF_HNSW_FLAT":
|
||||
config = HnswFlat(distance_type=metric, num_partitions=num_partitions)
|
||||
elif index_type == "IVF_FLAT":
|
||||
config = IvfFlat(distance_type=metric, num_partitions=num_partitions)
|
||||
column = vector_column_name
|
||||
|
||||
if accelerator is not None:
|
||||
logging.warning(
|
||||
"GPU accelerator is not yet supported on LanceDB cloud."
|
||||
"If you have 100M+ vectors to index,"
|
||||
"please contact us at contact@lancedb.com"
|
||||
)
|
||||
if replace is not None:
|
||||
logging.warning(
|
||||
"replace is not supported on LanceDB cloud."
|
||||
"Existing indexes will always be replaced."
|
||||
)
|
||||
|
||||
idx_type = index_type.upper()
|
||||
if idx_type == "VECTOR" or idx_type == "IVF_PQ":
|
||||
config = IvfPq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
num_bits=num_bits,
|
||||
)
|
||||
elif idx_type == "IVF_RQ":
|
||||
config = IvfRq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_bits=num_bits,
|
||||
)
|
||||
elif idx_type == "IVF_SQ":
|
||||
config = IvfSq(distance_type=metric, num_partitions=num_partitions)
|
||||
elif idx_type == "IVF_HNSW_PQ":
|
||||
raise ValueError(
|
||||
"IVF_HNSW_PQ is not supported on LanceDB cloud."
|
||||
"Please use IVF_HNSW_SQ instead."
|
||||
)
|
||||
elif idx_type == "IVF_HNSW_SQ":
|
||||
config = HnswSq(distance_type=metric, num_partitions=num_partitions)
|
||||
elif idx_type == "IVF_HNSW_FLAT":
|
||||
config = HnswFlat(distance_type=metric, num_partitions=num_partitions)
|
||||
elif idx_type == "IVF_FLAT":
|
||||
config = IvfFlat(distance_type=metric, num_partitions=num_partitions)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown vector index type: {idx_type}. Valid options are"
|
||||
" 'IVF_FLAT', 'IVF_PQ', 'IVF_RQ', 'IVF_SQ',"
|
||||
" 'IVF_HNSW_PQ', 'IVF_HNSW_SQ', 'IVF_HNSW_FLAT'"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown vector index type: {index_type}. Valid options are"
|
||||
" 'IVF_FLAT', 'IVF_PQ', 'IVF_RQ', 'IVF_SQ',"
|
||||
" 'IVF_HNSW_PQ', 'IVF_HNSW_SQ', 'IVF_HNSW_FLAT'"
|
||||
)
|
||||
column = metric
|
||||
|
||||
LOOP.run(
|
||||
self._table.create_index(
|
||||
vector_column_name,
|
||||
column,
|
||||
config=config,
|
||||
wait_timeout=wait_timeout,
|
||||
name=name,
|
||||
@@ -381,6 +474,37 @@ class RemoteTable(Table):
|
||||
)
|
||||
)
|
||||
|
||||
def _is_legacy_create_index_call(
|
||||
self,
|
||||
first_arg: str,
|
||||
config: Optional[IndexConfigType],
|
||||
num_partitions: Optional[int],
|
||||
num_sub_vectors: Optional[int],
|
||||
vector_column_name: str,
|
||||
accelerator: Optional[str],
|
||||
index_cache_size: Optional[int],
|
||||
replace: Optional[bool],
|
||||
) -> bool:
|
||||
"""Detect if this is a legacy create_index call."""
|
||||
if config is not None:
|
||||
return False
|
||||
if any(
|
||||
x is not None
|
||||
for x in (
|
||||
num_partitions,
|
||||
num_sub_vectors,
|
||||
accelerator,
|
||||
index_cache_size,
|
||||
replace,
|
||||
)
|
||||
):
|
||||
return True
|
||||
if vector_column_name != VECTOR_COLUMN_NAME:
|
||||
return True
|
||||
if first_arg.lower() in KNOWN_METRICS:
|
||||
return True
|
||||
return False
|
||||
|
||||
def add(
|
||||
self,
|
||||
data: DATA,
|
||||
@@ -741,6 +865,10 @@ class RemoteTable(Table):
|
||||
"""Not supported on LanceDB Cloud."""
|
||||
return LOOP.run(self._table.unset_lsm_write_spec())
|
||||
|
||||
def close_lsm_writers(self) -> None:
|
||||
"""No-op on LanceDB Cloud (no local shard writers)."""
|
||||
return LOOP.run(self._table.close_lsm_writers())
|
||||
|
||||
def drop_index(self, index_name: str):
|
||||
return LOOP.run(self._table.drop_index(index_name))
|
||||
|
||||
|
||||
@@ -102,8 +102,15 @@ class LinearCombinationReranker(Reranker):
|
||||
|
||||
combined_list = []
|
||||
for row_id, result in results.items():
|
||||
# Convert vector distance to a relevance score in [0, 1] where
|
||||
# higher is better. Missing vector entries are penalised with
|
||||
# `_invert_score(fill)` = 1 - fill (= 0.0 for the default fill=1).
|
||||
vector_score = self._invert_score(result.get("_distance", fill))
|
||||
fts_score = result.get("_score", fill)
|
||||
# FTS scores (BM25) are already in a "higher = more relevant" space.
|
||||
# Missing FTS entries are penalised symmetrically: we use
|
||||
# `1 - fill` so that the same `fill` value drives both missing-vector
|
||||
# and missing-FTS penalties in the same direction.
|
||||
fts_score = result.get("_score", 1 - fill)
|
||||
result["_relevance_score"] = self._combine_score(vector_score, fts_score)
|
||||
combined_list.append(result)
|
||||
|
||||
@@ -123,8 +130,12 @@ class LinearCombinationReranker(Reranker):
|
||||
return tbl
|
||||
|
||||
def _combine_score(self, vector_score, fts_score):
|
||||
# these scores represent distance
|
||||
return 1 - (self.weight * vector_score + (1 - self.weight) * fts_score)
|
||||
# Both vector_score (inverted distance) and fts_score are in a
|
||||
# "higher = more relevant" space. A straight weighted average gives
|
||||
# higher _relevance_score to better matches, as expected.
|
||||
# Previously this returned `1 - (...)` which inverted the final
|
||||
# ranking so that the *least* relevant document ranked first.
|
||||
return self.weight * vector_score + (1 - self.weight) * fts_score
|
||||
|
||||
def _invert_score(self, dist: float):
|
||||
# Invert the score between relevance and distance
|
||||
|
||||
@@ -174,6 +174,24 @@ if TYPE_CHECKING:
|
||||
DistanceType,
|
||||
)
|
||||
|
||||
# Type alias for index configuration objects
|
||||
IndexConfigType = Union[
|
||||
IvfFlat,
|
||||
IvfPq,
|
||||
IvfSq,
|
||||
IvfRq,
|
||||
HnswFlat,
|
||||
HnswPq,
|
||||
HnswSq,
|
||||
BTree,
|
||||
Bitmap,
|
||||
LabelList,
|
||||
FTS,
|
||||
]
|
||||
|
||||
# Known distance metrics for legacy API detection
|
||||
KNOWN_METRICS = {"l2", "cosine", "dot", "hamming"}
|
||||
|
||||
|
||||
def _into_pyarrow_reader(
|
||||
data, schema: Optional[pa.Schema] = None
|
||||
@@ -807,11 +825,49 @@ class Table(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# New unified API overload
|
||||
@overload
|
||||
def create_index(
|
||||
self,
|
||||
metric="l2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
column: str,
|
||||
/,
|
||||
*,
|
||||
config: IndexConfigType,
|
||||
replace: bool = ...,
|
||||
wait_timeout: Optional[timedelta] = ...,
|
||||
name: Optional[str] = ...,
|
||||
train: bool = ...,
|
||||
) -> None: ...
|
||||
|
||||
# Legacy API overload (deprecated)
|
||||
@overload
|
||||
def create_index(
|
||||
self,
|
||||
metric: Literal["l2", "cosine", "dot", "hamming"] = ...,
|
||||
num_partitions: Optional[int] = ...,
|
||||
num_sub_vectors: Optional[int] = ...,
|
||||
vector_column_name: str = ...,
|
||||
replace: bool = ...,
|
||||
accelerator: Optional[str] = ...,
|
||||
index_cache_size: Optional[int] = ...,
|
||||
*,
|
||||
index_type: VectorIndexType = ...,
|
||||
wait_timeout: Optional[timedelta] = ...,
|
||||
num_bits: int = ...,
|
||||
max_iterations: int = ...,
|
||||
sample_rate: int = ...,
|
||||
m: int = ...,
|
||||
ef_construction: int = ...,
|
||||
name: Optional[str] = ...,
|
||||
train: bool = ...,
|
||||
target_partition_size: Optional[int] = ...,
|
||||
) -> None: ...
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
metric: DistanceType = "l2",
|
||||
num_partitions: Optional[int] = None,
|
||||
num_sub_vectors: Optional[int] = None,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
replace: bool = True,
|
||||
accelerator: Optional[str] = None,
|
||||
@@ -824,46 +880,53 @@ class Table(ABC):
|
||||
sample_rate: int = 256,
|
||||
m: int = 20,
|
||||
ef_construction: int = 300,
|
||||
config: Optional[IndexConfigType] = None,
|
||||
name: Optional[str] = None,
|
||||
train: bool = True,
|
||||
target_partition_size: Optional[int] = None,
|
||||
):
|
||||
"""Create an index on the table.
|
||||
"""Create an index on a column.
|
||||
|
||||
This method supports both the new unified API and the legacy API
|
||||
for backwards compatibility. The new API takes the column name as the
|
||||
first positional argument and an index configuration object via
|
||||
``config``; the legacy API takes the distance metric as the first
|
||||
argument plus separate ``vector_column_name`` / ``num_partitions`` /
|
||||
etc. parameters, and emits a ``DeprecationWarning``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric: str, default "l2"
|
||||
The distance metric to use when creating the index.
|
||||
Valid values are "l2", "cosine", "dot", or "hamming".
|
||||
l2 is euclidean distance.
|
||||
Hamming is available only for binary vectors.
|
||||
num_partitions: int, default 256
|
||||
The number of IVF partitions to use when creating the index.
|
||||
Default is 256.
|
||||
num_sub_vectors: int, default 96
|
||||
The number of PQ sub-vectors to use when creating the index.
|
||||
Default is 96.
|
||||
vector_column_name: str, default "vector"
|
||||
The vector column name to create the index.
|
||||
replace: bool, default True
|
||||
- If True, replace the existing index if it exists.
|
||||
metric : str
|
||||
For new API: the column name to index.
|
||||
For legacy API: the distance metric ("l2", "cosine", "dot", "hamming").
|
||||
config : IndexConfigType, optional
|
||||
The index configuration object. If provided, uses the new unified API.
|
||||
Can be one of: IvfFlat, IvfPq, IvfSq, IvfRq, HnswPq, HnswSq,
|
||||
BTree, Bitmap, LabelList, FTS.
|
||||
replace : bool, default True
|
||||
Whether to replace an existing index on this column.
|
||||
wait_timeout : timedelta, optional
|
||||
Timeout to wait for async indexing to complete.
|
||||
name : str, optional
|
||||
Custom name for the index.
|
||||
train : bool, default True
|
||||
Whether to train the index with existing data.
|
||||
|
||||
- If False, raise an error if duplicate index exists.
|
||||
accelerator: str, default None
|
||||
If set, use the given accelerator to create the index.
|
||||
Only support "cuda" for now.
|
||||
index_cache_size : int, optional
|
||||
The size of the index cache in number of entries. Default value is 256.
|
||||
num_bits: int
|
||||
The number of bits to encode sub-vectors. Only used with the IVF_PQ index.
|
||||
Only 4 and 8 are supported.
|
||||
wait_timeout: timedelta, optional
|
||||
The timeout to wait if indexing is asynchronous.
|
||||
name: str, optional
|
||||
The name of the index. If not provided, a default name will be generated.
|
||||
train: bool, default True
|
||||
Whether to train the index with existing data. Vector indices always train
|
||||
with existing data.
|
||||
Examples
|
||||
--------
|
||||
New API (recommended):
|
||||
|
||||
>>> table.create_index( # doctest: +SKIP
|
||||
... "vector", config=IvfPq(distance_type="l2")
|
||||
... )
|
||||
>>> table.create_index("category", config=BTree()) # doctest: +SKIP
|
||||
>>> table.create_index("content", config=FTS()) # doctest: +SKIP
|
||||
|
||||
Legacy API (deprecated):
|
||||
|
||||
>>> table.create_index( # doctest: +SKIP
|
||||
... "l2", vector_column_name="vector"
|
||||
... )
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1188,7 +1251,7 @@ class Table(ABC):
|
||||
... .when_not_matched_insert_all() \\
|
||||
... .execute(new_data)
|
||||
>>> res
|
||||
MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1)
|
||||
MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1, num_rows=3)
|
||||
>>> # The order of new rows is non-deterministic since we use
|
||||
>>> # a hash-join as part of this operation and so we sort here
|
||||
>>> table.to_arrow().sort_by("a").to_pandas()
|
||||
@@ -2250,11 +2313,51 @@ class LanceTable(Table):
|
||||
dataset, allow_pyarrow_filter=False, batch_size=batch_size
|
||||
)
|
||||
|
||||
# New unified API overload
|
||||
@overload
|
||||
def create_index(
|
||||
self,
|
||||
metric: DistanceType = "l2",
|
||||
num_partitions=None,
|
||||
num_sub_vectors=None,
|
||||
column: str,
|
||||
/,
|
||||
*,
|
||||
config: IndexConfigType,
|
||||
replace: bool = ...,
|
||||
wait_timeout: Optional[timedelta] = ...,
|
||||
name: Optional[str] = ...,
|
||||
train: bool = ...,
|
||||
) -> None: ...
|
||||
|
||||
# Legacy API overload (deprecated)
|
||||
@overload
|
||||
def create_index(
|
||||
self,
|
||||
metric: Literal["l2", "cosine", "dot", "hamming"] = ...,
|
||||
num_partitions: Optional[int] = ...,
|
||||
num_sub_vectors: Optional[int] = ...,
|
||||
vector_column_name: str = ...,
|
||||
replace: bool = ...,
|
||||
accelerator: Optional[str] = ...,
|
||||
index_cache_size: Optional[int] = ...,
|
||||
num_bits: int = ...,
|
||||
index_type: Literal[
|
||||
"IVF_FLAT", "IVF_SQ", "IVF_PQ", "IVF_RQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
|
||||
] = ...,
|
||||
max_iterations: int = ...,
|
||||
sample_rate: int = ...,
|
||||
m: int = ...,
|
||||
ef_construction: int = ...,
|
||||
*,
|
||||
wait_timeout: Optional[timedelta] = ...,
|
||||
name: Optional[str] = ...,
|
||||
train: bool = ...,
|
||||
target_partition_size: Optional[int] = ...,
|
||||
) -> None: ...
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
metric: str = "l2",
|
||||
num_partitions: Optional[int] = None,
|
||||
num_sub_vectors: Optional[int] = None,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
replace: bool = True,
|
||||
accelerator: Optional[str] = None,
|
||||
@@ -2274,47 +2377,232 @@ class LanceTable(Table):
|
||||
m: int = 20,
|
||||
ef_construction: int = 300,
|
||||
*,
|
||||
config: Optional[IndexConfigType] = None,
|
||||
wait_timeout: Optional[timedelta] = None,
|
||||
name: Optional[str] = None,
|
||||
train: bool = True,
|
||||
target_partition_size: Optional[int] = None,
|
||||
):
|
||||
"""Create an index on the table."""
|
||||
if accelerator is not None:
|
||||
# accelerator is only supported through pylance.
|
||||
self.to_lance().create_index(
|
||||
column=vector_column_name,
|
||||
index_type=index_type,
|
||||
"""Create an index on a column.
|
||||
|
||||
This method supports both the new unified API and the legacy API
|
||||
for backwards compatibility. The new API takes the column name as the
|
||||
first positional argument and an index configuration object via
|
||||
``config``; the legacy API takes the distance metric as the first
|
||||
argument plus separate ``vector_column_name`` / ``num_partitions`` /
|
||||
etc. parameters, and emits a ``DeprecationWarning``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric : str
|
||||
For new API: the column name to index.
|
||||
For legacy API: the distance metric ("l2", "cosine", "dot", "hamming").
|
||||
config : IndexConfigType, optional
|
||||
The index configuration object. If provided, uses the new unified API.
|
||||
Can be one of: IvfFlat, IvfPq, IvfSq, IvfRq, HnswPq, HnswSq,
|
||||
BTree, Bitmap, LabelList, FTS.
|
||||
replace : bool, default True
|
||||
Whether to replace an existing index on this column.
|
||||
wait_timeout : timedelta, optional
|
||||
Timeout to wait for async indexing to complete.
|
||||
name : str, optional
|
||||
Custom name for the index.
|
||||
train : bool, default True
|
||||
Whether to train the index with existing data.
|
||||
|
||||
Examples
|
||||
--------
|
||||
New API (recommended):
|
||||
|
||||
>>> table.create_index( # doctest: +SKIP
|
||||
... "vector", config=IvfPq(distance_type="l2")
|
||||
... )
|
||||
>>> table.create_index("category", config=BTree()) # doctest: +SKIP
|
||||
>>> table.create_index("content", config=FTS()) # doctest: +SKIP
|
||||
|
||||
Legacy API (deprecated):
|
||||
|
||||
>>> table.create_index( # doctest: +SKIP
|
||||
... "l2", vector_column_name="vector"
|
||||
... )
|
||||
"""
|
||||
# Detect whether this is a legacy API call
|
||||
is_legacy = self._is_legacy_create_index_call(
|
||||
metric,
|
||||
config,
|
||||
num_partitions,
|
||||
num_sub_vectors,
|
||||
vector_column_name,
|
||||
accelerator,
|
||||
index_cache_size,
|
||||
)
|
||||
|
||||
if is_legacy:
|
||||
warnings.warn(
|
||||
"The create_index() API with metric/num_partitions parameters is "
|
||||
"deprecated and will be removed in a future version. "
|
||||
"Please migrate to the new unified API:\n"
|
||||
" # Old (deprecated):\n"
|
||||
" table.create_index('l2', vector_column_name='my_vector')\n"
|
||||
" # New (recommended):\n"
|
||||
" table.create_index('my_vector', config=IvfPq(distance_type='l2'))",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Legacy API: first arg is the distance metric
|
||||
column = vector_column_name
|
||||
|
||||
# Build config from legacy parameters
|
||||
config = self._build_vector_config_from_legacy_params(
|
||||
metric=metric,
|
||||
index_type=index_type,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
replace=replace,
|
||||
accelerator=accelerator,
|
||||
index_cache_size=index_cache_size,
|
||||
num_bits=num_bits,
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
m=m,
|
||||
ef_construction=ef_construction,
|
||||
target_partition_size=target_partition_size,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
self.checkout_latest()
|
||||
return
|
||||
elif index_type == "IVF_FLAT":
|
||||
config = IvfFlat(
|
||||
|
||||
# Handle accelerator through pylance
|
||||
if accelerator is not None:
|
||||
self.to_lance().create_index(
|
||||
column=column,
|
||||
index_type=index_type,
|
||||
metric=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
replace=replace,
|
||||
accelerator=accelerator,
|
||||
index_cache_size=index_cache_size,
|
||||
num_bits=num_bits,
|
||||
m=m,
|
||||
ef_construction=ef_construction,
|
||||
target_partition_size=target_partition_size,
|
||||
)
|
||||
self.checkout_latest()
|
||||
return
|
||||
else:
|
||||
# New API: metric is the column name
|
||||
column = metric
|
||||
|
||||
# Check if config has accelerator set and dispatch to pylance
|
||||
if config is not None and hasattr(config, "accelerator"):
|
||||
acc = getattr(config, "accelerator", None)
|
||||
if acc is not None:
|
||||
# Dispatch to pylance for GPU acceleration
|
||||
index_type_map = {
|
||||
"IvfFlat": "IVF_FLAT",
|
||||
"IvfSq": "IVF_SQ",
|
||||
"IvfPq": "IVF_PQ",
|
||||
"IvfRq": "IVF_RQ",
|
||||
"HnswPq": "IVF_HNSW_PQ",
|
||||
"HnswSq": "IVF_HNSW_SQ",
|
||||
}
|
||||
cfg_type = type(config).__name__
|
||||
lance_index_type = index_type_map.get(cfg_type, "IVF_PQ")
|
||||
|
||||
self.to_lance().create_index(
|
||||
column=column,
|
||||
index_type=lance_index_type,
|
||||
metric=getattr(config, "distance_type", "l2"),
|
||||
num_partitions=getattr(config, "num_partitions", None),
|
||||
num_sub_vectors=getattr(config, "num_sub_vectors", None),
|
||||
replace=replace,
|
||||
accelerator=acc,
|
||||
num_bits=getattr(config, "num_bits", 8),
|
||||
m=getattr(config, "m", 20),
|
||||
ef_construction=getattr(config, "ef_construction", 300),
|
||||
target_partition_size=getattr(
|
||||
config, "target_partition_size", None
|
||||
),
|
||||
)
|
||||
self.checkout_latest()
|
||||
return
|
||||
|
||||
return LOOP.run(
|
||||
self._table.create_index(
|
||||
column,
|
||||
replace=replace,
|
||||
config=config,
|
||||
wait_timeout=wait_timeout,
|
||||
name=name,
|
||||
train=train,
|
||||
)
|
||||
)
|
||||
|
||||
def _is_legacy_create_index_call(
|
||||
self,
|
||||
first_arg: str,
|
||||
config: Optional[IndexConfigType],
|
||||
num_partitions: Optional[int],
|
||||
num_sub_vectors: Optional[int],
|
||||
vector_column_name: str,
|
||||
accelerator: Optional[str],
|
||||
index_cache_size: Optional[int],
|
||||
) -> bool:
|
||||
"""Detect if this is a legacy create_index call."""
|
||||
# If config is provided, it's definitely the new API
|
||||
if config is not None:
|
||||
return False
|
||||
|
||||
# If old-style parameters were explicitly set, it's legacy
|
||||
if any(
|
||||
x is not None
|
||||
for x in (num_partitions, num_sub_vectors, accelerator, index_cache_size)
|
||||
):
|
||||
return True
|
||||
|
||||
# If vector_column_name differs from default, it's legacy
|
||||
if vector_column_name != VECTOR_COLUMN_NAME:
|
||||
return True
|
||||
|
||||
# If first arg is a known metric, assume legacy
|
||||
if first_arg.lower() in KNOWN_METRICS:
|
||||
return True
|
||||
|
||||
# Otherwise assume new API
|
||||
return False
|
||||
|
||||
def _build_vector_config_from_legacy_params(
|
||||
self,
|
||||
metric: str,
|
||||
index_type: str,
|
||||
num_partitions: Optional[int],
|
||||
num_sub_vectors: Optional[int],
|
||||
num_bits: int,
|
||||
max_iterations: int,
|
||||
sample_rate: int,
|
||||
m: int,
|
||||
ef_construction: int,
|
||||
target_partition_size: Optional[int],
|
||||
accelerator: Optional[str],
|
||||
) -> IndexConfigType:
|
||||
"""Build an index config object from legacy parameters."""
|
||||
if index_type == "IVF_FLAT":
|
||||
return IvfFlat(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
target_partition_size=target_partition_size,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
elif index_type == "IVF_SQ":
|
||||
config = IvfSq(
|
||||
return IvfSq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
target_partition_size=target_partition_size,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
elif index_type == "IVF_PQ":
|
||||
config = IvfPq(
|
||||
return IvfPq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
@@ -2322,18 +2610,20 @@ class LanceTable(Table):
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
target_partition_size=target_partition_size,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
elif index_type == "IVF_RQ":
|
||||
config = IvfRq(
|
||||
return IvfRq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_bits=num_bits,
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
target_partition_size=target_partition_size,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
elif index_type == "IVF_HNSW_PQ":
|
||||
config = HnswPq(
|
||||
return HnswPq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
@@ -2343,9 +2633,10 @@ class LanceTable(Table):
|
||||
m=m,
|
||||
ef_construction=ef_construction,
|
||||
target_partition_size=target_partition_size,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
elif index_type == "IVF_HNSW_SQ":
|
||||
config = HnswSq(
|
||||
return HnswSq(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
max_iterations=max_iterations,
|
||||
@@ -2353,9 +2644,10 @@ class LanceTable(Table):
|
||||
m=m,
|
||||
ef_construction=ef_construction,
|
||||
target_partition_size=target_partition_size,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
elif index_type == "IVF_HNSW_FLAT":
|
||||
config = HnswFlat(
|
||||
return HnswFlat(
|
||||
distance_type=metric,
|
||||
num_partitions=num_partitions,
|
||||
max_iterations=max_iterations,
|
||||
@@ -2367,16 +2659,6 @@ class LanceTable(Table):
|
||||
else:
|
||||
raise ValueError(f"Unknown index type {index_type}")
|
||||
|
||||
return LOOP.run(
|
||||
self._table.create_index(
|
||||
vector_column_name,
|
||||
replace=replace,
|
||||
config=config,
|
||||
name=name,
|
||||
train=train,
|
||||
)
|
||||
)
|
||||
|
||||
def drop_index(self, name: str) -> None:
|
||||
"""
|
||||
Drops an index from the table
|
||||
@@ -2476,6 +2758,11 @@ class LanceTable(Table):
|
||||
"""
|
||||
return LOOP.run(self._table.latest_storage_options())
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.25.0",
|
||||
current_version=__version__,
|
||||
details="Use create_index() with config=BTree()/Bitmap()/LabelList() instead.",
|
||||
)
|
||||
def create_scalar_index(
|
||||
self,
|
||||
column: str,
|
||||
@@ -2484,6 +2771,12 @@ class LanceTable(Table):
|
||||
index_type: ScalarIndexType = "BTREE",
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Create a scalar index on a column.
|
||||
|
||||
.. deprecated:: 0.25.0
|
||||
Use :meth:`create_index` with a BTree, Bitmap, or LabelList config instead.
|
||||
Example: ``table.create_index("column", config=BTree())``
|
||||
"""
|
||||
if index_type == "BTREE":
|
||||
config = BTree()
|
||||
elif index_type == "BITMAP":
|
||||
@@ -2496,6 +2789,11 @@ class LanceTable(Table):
|
||||
self._table.create_index(column, replace=replace, config=config, name=name)
|
||||
)
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.25.0",
|
||||
current_version=__version__,
|
||||
details="Use create_index() with config=FTS() instead.",
|
||||
)
|
||||
def create_fts_index(
|
||||
self,
|
||||
field_names: Union[str, List[str]],
|
||||
@@ -2519,6 +2817,12 @@ class LanceTable(Table):
|
||||
prefix_only: bool = False,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Create a full-text search index on a column.
|
||||
|
||||
.. deprecated:: 0.25.0
|
||||
Use :meth:`create_index` with an FTS config instead.
|
||||
Example: ``table.create_index("text_column", config=FTS())``
|
||||
"""
|
||||
self._ensure_no_legacy_fts_index()
|
||||
|
||||
if use_tantivy:
|
||||
@@ -3297,6 +3601,11 @@ class LanceTable(Table):
|
||||
[`AsyncTable.unset_lsm_write_spec`][lancedb.AsyncTable.unset_lsm_write_spec]."""
|
||||
return LOOP.run(self._table.unset_lsm_write_spec())
|
||||
|
||||
def close_lsm_writers(self) -> None:
|
||||
"""Close cached MemWAL shard writers. See
|
||||
[`AsyncTable.close_lsm_writers`][lancedb.AsyncTable.close_lsm_writers]."""
|
||||
return LOOP.run(self._table.close_lsm_writers())
|
||||
|
||||
def uses_v2_manifest_paths(self) -> bool:
|
||||
"""
|
||||
Check if the table is using the new v2 manifest paths.
|
||||
@@ -3905,6 +4214,16 @@ class AsyncTable:
|
||||
"""
|
||||
await self._inner.unset_lsm_write_spec()
|
||||
|
||||
async def close_lsm_writers(self) -> None:
|
||||
"""Drain and close any cached MemWAL shard writers for this table.
|
||||
|
||||
When an LSM write spec is installed, `merge_insert` opens MemWAL shard
|
||||
writers and caches them for reuse across calls. This closes them,
|
||||
flushing pending data; writers reopen lazily on the next
|
||||
`merge_insert`. It is a no-op when no writers are cached.
|
||||
"""
|
||||
await self._inner.close_lsm_writers()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of the table."""
|
||||
@@ -4355,7 +4674,7 @@ class AsyncTable:
|
||||
... .when_not_matched_insert_all() \\
|
||||
... .execute(new_data)
|
||||
>>> res
|
||||
MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1)
|
||||
MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0, num_attempts=1, num_rows=3)
|
||||
>>> # The order of new rows is non-deterministic since we use
|
||||
>>> # a hash-join as part of this operation and so we sort here
|
||||
>>> table.to_arrow().sort_by("a").to_pandas()
|
||||
@@ -4735,6 +5054,8 @@ class AsyncTable:
|
||||
when_not_matched_by_source_condition=merge._when_not_matched_by_source_condition,
|
||||
timeout=merge._timeout,
|
||||
use_index=merge._use_index,
|
||||
use_lsm_write=merge._use_lsm_write,
|
||||
validate_single_shard=merge._validate_single_shard,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ async def test_upsert_async(mem_db_async):
|
||||
await table.count_rows() # 3
|
||||
res
|
||||
# MergeResult(version=2, num_updated_rows=1,
|
||||
# num_inserted_rows=1, num_deleted_rows=0)
|
||||
# num_inserted_rows=1, num_deleted_rows=0, num_rows=2)
|
||||
# --8<-- [end:upsert_basic_async]
|
||||
assert await table.count_rows() == 3
|
||||
assert res.version == 2
|
||||
@@ -86,7 +86,7 @@ def test_insert_if_not_exists(mem_db):
|
||||
table.count_rows() # 3
|
||||
res
|
||||
# MergeResult(version=2, num_updated_rows=0,
|
||||
# num_inserted_rows=1, num_deleted_rows=0)
|
||||
# num_inserted_rows=1, num_deleted_rows=0, num_rows=1)
|
||||
# --8<-- [end:insert_if_not_exists]
|
||||
assert table.count_rows() == 3
|
||||
assert res.version == 2
|
||||
@@ -116,7 +116,7 @@ async def test_insert_if_not_exists_async(mem_db_async):
|
||||
await table.count_rows() # 3
|
||||
res
|
||||
# MergeResult(version=2, num_updated_rows=0,
|
||||
# num_inserted_rows=1, num_deleted_rows=0)
|
||||
# num_inserted_rows=1, num_deleted_rows=0, num_rows=1)
|
||||
# --8<-- [end:insert_if_not_exists]
|
||||
assert await table.count_rows() == 3
|
||||
assert res.version == 2
|
||||
@@ -150,7 +150,7 @@ def test_replace_range(mem_db):
|
||||
table.count_rows("doc_id = 1") # 1
|
||||
res
|
||||
# MergeResult(version=2, num_updated_rows=1,
|
||||
# num_inserted_rows=0, num_deleted_rows=1)
|
||||
# num_inserted_rows=0, num_deleted_rows=1, num_rows=1)
|
||||
# --8<-- [end:insert_if_not_exists]
|
||||
assert table.count_rows("doc_id = 1") == 1
|
||||
assert res.version == 2
|
||||
@@ -185,7 +185,7 @@ async def test_replace_range_async(mem_db_async):
|
||||
await table.count_rows("doc_id = 1") # 1
|
||||
res
|
||||
# MergeResult(version=2, num_updated_rows=1,
|
||||
# num_inserted_rows=0, num_deleted_rows=1)
|
||||
# num_inserted_rows=0, num_deleted_rows=1, num_rows=1)
|
||||
# --8<-- [end:insert_if_not_exists]
|
||||
assert await table.count_rows("doc_id = 1") == 1
|
||||
assert res.version == 2
|
||||
|
||||
@@ -466,7 +466,8 @@ async def test_create_table_v2_manifest_paths_async(tmp_path):
|
||||
assert await tbl.uses_v2_manifest_paths()
|
||||
manifests_dir = tmp_path / "test_v2_manifest_paths.lance" / "_versions"
|
||||
for manifest in os.listdir(manifests_dir):
|
||||
assert re.match(r"\d{20}\.manifest", manifest)
|
||||
if manifest.endswith(".manifest"):
|
||||
assert re.match(r"\d{20}\.manifest", manifest)
|
||||
|
||||
# Start a table in V1 mode then migrate
|
||||
tbl = await db_no_v2_paths.create_table(
|
||||
@@ -476,13 +477,15 @@ async def test_create_table_v2_manifest_paths_async(tmp_path):
|
||||
assert not await tbl.uses_v2_manifest_paths()
|
||||
manifests_dir = tmp_path / "test_v2_migration.lance" / "_versions"
|
||||
for manifest in os.listdir(manifests_dir):
|
||||
assert re.match(r"\d\.manifest", manifest)
|
||||
if manifest.endswith(".manifest"):
|
||||
assert re.match(r"\d\.manifest", manifest)
|
||||
|
||||
await tbl.migrate_manifest_paths_v2()
|
||||
assert await tbl.uses_v2_manifest_paths()
|
||||
|
||||
for manifest in os.listdir(manifests_dir):
|
||||
assert re.match(r"\d{20}\.manifest", manifest)
|
||||
if manifest.endswith(".manifest"):
|
||||
assert re.match(r"\d{20}\.manifest", manifest)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -215,11 +215,12 @@ def test_reject_legacy_tantivy_index(table):
|
||||
|
||||
@pytest.mark.parametrize("with_position", [True, False])
|
||||
def test_create_inverted_index(table, with_position):
|
||||
table.create_fts_index(
|
||||
"text",
|
||||
with_position=with_position,
|
||||
name="custom_fts_index",
|
||||
)
|
||||
with pytest.warns(DeprecationWarning, match="create_fts_index"):
|
||||
table.create_fts_index(
|
||||
"text",
|
||||
with_position=with_position,
|
||||
name="custom_fts_index",
|
||||
)
|
||||
indices = table.list_indices()
|
||||
fts_indices = [i for i in indices if i.index_type == "FTS"]
|
||||
assert any(i.name == "custom_fts_index" for i in fts_indices)
|
||||
|
||||
@@ -162,12 +162,13 @@ async def test_create_bitmap_index(some_table: AsyncTable):
|
||||
await some_table.create_index("data", config=Bitmap())
|
||||
indices = await some_table.list_indices()
|
||||
assert len(indices) == 3
|
||||
# list_indices returns indices in alphabetical order by name
|
||||
assert indices[0].index_type == "Bitmap"
|
||||
assert indices[0].columns == ["id"]
|
||||
assert indices[0].columns == ["data"]
|
||||
assert indices[1].index_type == "Bitmap"
|
||||
assert indices[1].columns == ["is_active"]
|
||||
assert indices[1].columns == ["id"]
|
||||
assert indices[2].index_type == "Bitmap"
|
||||
assert indices[2].columns == ["data"]
|
||||
assert indices[2].columns == ["is_active"]
|
||||
|
||||
index_name = indices[0].name
|
||||
stats = await some_table.index_stats(index_name)
|
||||
|
||||
@@ -40,16 +40,6 @@ def _make_table(tmp_path):
|
||||
def test_set_lsm_write_spec_validates(tmp_path):
|
||||
_db, table = _make_table(tmp_path)
|
||||
|
||||
# No PK set yet.
|
||||
with pytest.raises(Exception, match="primary key"):
|
||||
table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 4))
|
||||
|
||||
table.set_unenforced_primary_key("id")
|
||||
|
||||
# Column mismatch.
|
||||
with pytest.raises(Exception, match="match"):
|
||||
table.set_lsm_write_spec(LsmWriteSpec.bucket("v", 4))
|
||||
|
||||
# Out-of-range num_buckets.
|
||||
with pytest.raises(Exception, match="num_buckets"):
|
||||
table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 0))
|
||||
@@ -70,7 +60,6 @@ def test_unset_lsm_write_spec(tmp_path):
|
||||
table.unset_lsm_write_spec()
|
||||
|
||||
# Install a spec, then remove it; afterwards a fresh spec can be set.
|
||||
table.set_unenforced_primary_key("id")
|
||||
table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 4))
|
||||
table.unset_lsm_write_spec()
|
||||
# A second unset errors — there is no spec left to remove.
|
||||
|
||||
196
python/python/tests/test_merge_insert_lsm.py
Normal file
196
python/python/tests/test_merge_insert_lsm.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
"""Tests for the MemWAL LSM ``merge_insert`` dispatch."""
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from lancedb._lancedb import LsmWriteSpec
|
||||
|
||||
SCHEMA = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64(), nullable=False),
|
||||
pa.field("value", pa.int64(), nullable=False),
|
||||
]
|
||||
)
|
||||
|
||||
REGION_SCHEMA = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64(), nullable=False),
|
||||
pa.field("region", pa.utf8(), nullable=False),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _reader(ids):
|
||||
batch = pa.RecordBatch.from_arrays(
|
||||
[
|
||||
pa.array(ids, type=pa.int64()),
|
||||
pa.array(list(range(len(ids))), type=pa.int64()),
|
||||
],
|
||||
schema=SCHEMA,
|
||||
)
|
||||
return pa.RecordBatchReader.from_batches(SCHEMA, [batch])
|
||||
|
||||
|
||||
def _region_reader(rows):
|
||||
batch = pa.RecordBatch.from_arrays(
|
||||
[
|
||||
pa.array([row[0] for row in rows], type=pa.int64()),
|
||||
pa.array([row[1] for row in rows], type=pa.utf8()),
|
||||
],
|
||||
schema=REGION_SCHEMA,
|
||||
)
|
||||
return pa.RecordBatchReader.from_batches(REGION_SCHEMA, [batch])
|
||||
|
||||
|
||||
def _bucket_table(tmp_path):
|
||||
"""A table with ``id`` as the primary key and a single-bucket LSM spec."""
|
||||
db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
|
||||
table = db.create_table("t", _reader([1, 2, 3]))
|
||||
table.set_unenforced_primary_key("id")
|
||||
# num_buckets = 1: every row routes to the single bucket.
|
||||
table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 1))
|
||||
return table
|
||||
|
||||
|
||||
def test_lsm_merge_insert_bucket(tmp_path):
|
||||
table = _bucket_table(tmp_path)
|
||||
# Empty `on` defaults to the primary key.
|
||||
result = (
|
||||
table.merge_insert([])
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(_reader([3, 4, 5]))
|
||||
)
|
||||
# LSM path: rows go to the MemWAL, so only num_rows is populated.
|
||||
assert result.num_rows == 3
|
||||
assert result.version == 0
|
||||
assert result.num_inserted_rows == 0
|
||||
assert result.num_updated_rows == 0
|
||||
|
||||
|
||||
def test_lsm_merge_insert_unsharded(tmp_path):
|
||||
db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
|
||||
table = db.create_table("t", _reader([1, 2, 3]))
|
||||
table.set_unenforced_primary_key("id")
|
||||
table.set_lsm_write_spec(LsmWriteSpec.unsharded())
|
||||
result = (
|
||||
table.merge_insert("id")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(_reader([10, 11, 12, 13]))
|
||||
)
|
||||
assert result.num_rows == 4
|
||||
|
||||
|
||||
def test_lsm_merge_insert_identity(tmp_path):
|
||||
db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
|
||||
table = db.create_table("t", _region_reader([(1, "us"), (2, "us")]))
|
||||
table.set_unenforced_primary_key("id")
|
||||
table.set_lsm_write_spec(LsmWriteSpec.identity("region"))
|
||||
# All rows share one identity value, so they route to one shard.
|
||||
result = (
|
||||
table.merge_insert([])
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(_region_reader([(3, "us"), (4, "us")]))
|
||||
)
|
||||
assert result.num_rows == 2
|
||||
|
||||
|
||||
def test_lsm_merge_insert_use_lsm_write_false(tmp_path):
|
||||
table = _bucket_table(tmp_path) # rows id = 1, 2, 3
|
||||
# use_lsm_write(False) opts out: the standard path runs and commits.
|
||||
result = (
|
||||
table.merge_insert("id")
|
||||
.when_not_matched_insert_all()
|
||||
.use_lsm_write(False)
|
||||
.execute(_reader([3, 4, 5]))
|
||||
)
|
||||
assert result.num_inserted_rows == 2
|
||||
assert table.count_rows() == 5
|
||||
|
||||
|
||||
def test_lsm_merge_insert_validate_single_shard_off(tmp_path):
|
||||
table = _bucket_table(tmp_path)
|
||||
result = (
|
||||
table.merge_insert([])
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.validate_single_shard(False)
|
||||
.execute(_reader([6, 7, 8]))
|
||||
)
|
||||
assert result.num_rows == 3
|
||||
|
||||
|
||||
def test_lsm_merge_insert_use_lsm_write_true_requires_spec(tmp_path):
|
||||
# A table with a primary key but no LSM write spec installed.
|
||||
db = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
|
||||
table = db.create_table("t", _reader([1, 2, 3]))
|
||||
table.set_unenforced_primary_key("id")
|
||||
with pytest.raises(Exception, match="use_lsm_write"):
|
||||
(
|
||||
table.merge_insert("id")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.use_lsm_write(True)
|
||||
.execute(_reader([4]))
|
||||
)
|
||||
|
||||
|
||||
def test_lsm_merge_insert_rejects_on_not_primary_key(tmp_path):
|
||||
table = _bucket_table(tmp_path)
|
||||
with pytest.raises(Exception, match="primary key"):
|
||||
(
|
||||
table.merge_insert("value")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(_reader([1]))
|
||||
)
|
||||
|
||||
|
||||
def test_lsm_merge_insert_rejects_non_upsert(tmp_path):
|
||||
table = _bucket_table(tmp_path)
|
||||
# Insert-only (no when_matched_update_all) is not the upsert shape.
|
||||
with pytest.raises(Exception, match="upsert"):
|
||||
table.merge_insert([]).when_not_matched_insert_all().execute(_reader([4]))
|
||||
|
||||
|
||||
def test_lsm_close_writers(tmp_path):
|
||||
table = _bucket_table(tmp_path)
|
||||
(
|
||||
table.merge_insert([])
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(_reader([7, 8]))
|
||||
)
|
||||
table.close_lsm_writers()
|
||||
# The writer reopens lazily on the next merge_insert.
|
||||
result = (
|
||||
table.merge_insert([])
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(_reader([9]))
|
||||
)
|
||||
assert result.num_rows == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_lsm_merge_insert(tmp_path):
|
||||
db = await lancedb.connect_async(
|
||||
tmp_path, read_consistency_interval=timedelta(seconds=0)
|
||||
)
|
||||
table = await db.create_table("t", _reader([1, 2, 3]))
|
||||
await table.set_unenforced_primary_key("id")
|
||||
await table.set_lsm_write_spec(LsmWriteSpec.bucket("id", 1))
|
||||
|
||||
builder = (
|
||||
table.merge_insert([]).when_matched_update_all().when_not_matched_insert_all()
|
||||
)
|
||||
result = await builder.execute(_reader([3, 4, 5]))
|
||||
assert result.num_rows == 3
|
||||
await table.close_lsm_writers()
|
||||
@@ -586,22 +586,25 @@ def test_table_create_indices():
|
||||
# This is a smoke-test.
|
||||
table = db.create_table("test", [{"id": 1}])
|
||||
|
||||
# Test create_scalar_index with custom name
|
||||
table.create_scalar_index(
|
||||
"id", wait_timeout=timedelta(seconds=2), name="custom_scalar_idx"
|
||||
)
|
||||
# Test create_scalar_index with custom name (legacy method)
|
||||
with pytest.warns(DeprecationWarning, match="create_scalar_index"):
|
||||
table.create_scalar_index(
|
||||
"id", wait_timeout=timedelta(seconds=2), name="custom_scalar_idx"
|
||||
)
|
||||
|
||||
# Test create_fts_index with custom name
|
||||
table.create_fts_index(
|
||||
"text", wait_timeout=timedelta(seconds=2), name="custom_fts_idx"
|
||||
)
|
||||
# Test create_fts_index with custom name (legacy method)
|
||||
with pytest.warns(DeprecationWarning, match="create_fts_index"):
|
||||
table.create_fts_index(
|
||||
"text", wait_timeout=timedelta(seconds=2), name="custom_fts_idx"
|
||||
)
|
||||
|
||||
# Test create_index with custom name
|
||||
table.create_index(
|
||||
vector_column_name="vector",
|
||||
wait_timeout=timedelta(seconds=10),
|
||||
name="custom_vector_idx",
|
||||
)
|
||||
# Test create_index with custom name (legacy form: vector_column_name kwarg)
|
||||
with pytest.warns(DeprecationWarning, match="create_index"):
|
||||
table.create_index(
|
||||
vector_column_name="vector",
|
||||
wait_timeout=timedelta(seconds=10),
|
||||
name="custom_vector_idx",
|
||||
)
|
||||
|
||||
# Validate that the name parameter was passed correctly in requests
|
||||
assert len(received_requests) == 3
|
||||
@@ -630,6 +633,98 @@ def test_table_create_indices():
|
||||
table.drop_index("custom_fts_idx")
|
||||
|
||||
|
||||
def test_remote_create_index_new_api():
|
||||
received_requests = []
|
||||
|
||||
def handler(request):
|
||||
if request.path == "/v1/table/test/create_index/":
|
||||
content_len = int(request.headers.get("Content-Length", 0))
|
||||
body = request.rfile.read(content_len) if content_len > 0 else b""
|
||||
received_requests.append(json.loads(body) if body else {})
|
||||
request.send_response(200)
|
||||
request.end_headers()
|
||||
elif request.path == "/v1/table/test/create/?mode=create":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b"{}")
|
||||
elif request.path == "/v1/table/test/describe/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(
|
||||
json.dumps(
|
||||
dict(
|
||||
version=1,
|
||||
schema=dict(
|
||||
fields=[
|
||||
dict(name="id", type={"type": "int64"}, nullable=False),
|
||||
dict(
|
||||
name="category",
|
||||
type={"type": "string"},
|
||||
nullable=False,
|
||||
),
|
||||
dict(
|
||||
name="text", type={"type": "string"}, nullable=False
|
||||
),
|
||||
dict(
|
||||
name="vector",
|
||||
type={
|
||||
"type": "fixed_size_list",
|
||||
"fields": [
|
||||
dict(
|
||||
name="item",
|
||||
type={"type": "float"},
|
||||
nullable=True,
|
||||
)
|
||||
],
|
||||
"length": 2,
|
||||
},
|
||||
nullable=False,
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
).encode()
|
||||
)
|
||||
else:
|
||||
request.send_response(404)
|
||||
request.end_headers()
|
||||
|
||||
from lancedb.index import BTree, FTS, IvfPq, IvfRq
|
||||
|
||||
with mock_lancedb_connection(handler) as db:
|
||||
table = db.create_table("test", [{"id": 1}])
|
||||
|
||||
# New API: column-first, config= kwarg. Should NOT emit DeprecationWarning.
|
||||
import warnings as _warnings
|
||||
|
||||
with _warnings.catch_warnings():
|
||||
_warnings.simplefilter("error", DeprecationWarning)
|
||||
table.create_index("vector", config=IvfPq(distance_type="l2"))
|
||||
table.create_index("category", config=BTree())
|
||||
table.create_index("text", config=FTS())
|
||||
# IvfRq via new API
|
||||
table.create_index("vector", config=IvfRq(distance_type="l2"))
|
||||
|
||||
# Legacy index_type="IVF_RQ" routes to IvfRq config under the hood.
|
||||
with pytest.warns(DeprecationWarning, match="create_index"):
|
||||
table.create_index(
|
||||
vector_column_name="vector",
|
||||
index_type="IVF_RQ",
|
||||
num_partitions=8,
|
||||
)
|
||||
|
||||
assert len(received_requests) == 5
|
||||
assert [req["column"] for req in received_requests] == [
|
||||
"vector",
|
||||
"category",
|
||||
"text",
|
||||
"vector",
|
||||
"vector",
|
||||
]
|
||||
|
||||
|
||||
def test_table_wait_for_index_timeout():
|
||||
def handler(request):
|
||||
index_stats = dict(
|
||||
|
||||
@@ -603,3 +603,89 @@ def test_cross_encoder_reranker_return_all(tmp_path):
|
||||
assert "_relevance_score" in result.column_names
|
||||
assert "_score" in result.column_names
|
||||
assert "_distance" in result.column_names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression tests for LinearCombinationReranker scoring bugs (issue #3154)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_linear_combination_best_match_ranks_first():
|
||||
"""
|
||||
The document that is BOTH the closest vector match AND the only FTS match
|
||||
must rank first. Previously _combine_score subtracted from 1, inverting
|
||||
the ranking so the worst document ranked highest.
|
||||
"""
|
||||
reranker = LinearCombinationReranker(weight=0.7, return_score="all")
|
||||
|
||||
# rowid 0: perfect vector match, sole FTS match → should rank 1st
|
||||
# rowid 1: mediocre vector, no FTS match
|
||||
# rowid 2: bad vector, no FTS match
|
||||
vector_results = pa.Table.from_pydict(
|
||||
{
|
||||
"_rowid": [0, 1, 2],
|
||||
"_distance": [0.0, 0.5, 0.9],
|
||||
}
|
||||
)
|
||||
fts_results = pa.Table.from_pydict(
|
||||
{
|
||||
"_rowid": [0],
|
||||
"_score": [1.0],
|
||||
}
|
||||
)
|
||||
|
||||
combined = reranker.merge_results(vector_results, fts_results, fill=1.0)
|
||||
scores = dict(
|
||||
zip(
|
||||
combined["_rowid"].to_pylist(),
|
||||
combined["_relevance_score"].to_pylist(),
|
||||
)
|
||||
)
|
||||
|
||||
# rowid 0 must have the highest relevance score
|
||||
assert scores[0] > scores[1], (
|
||||
f"Best match (rowid 0, score={scores[0]:.4f}) should beat "
|
||||
f"mid match (rowid 1, score={scores[1]:.4f})"
|
||||
)
|
||||
assert scores[1] > scores[2], (
|
||||
f"Mid match (rowid 1, score={scores[1]:.4f}) should beat "
|
||||
f"bad match (rowid 2, score={scores[2]:.4f})"
|
||||
)
|
||||
|
||||
|
||||
def test_linear_combination_missing_fts_is_penalised():
|
||||
"""
|
||||
A document with no FTS match must score *lower* than a document that
|
||||
has a mediocre FTS match, everything else being equal. Previously
|
||||
missing-FTS entries used fill=1.0 directly, which gave them a reward
|
||||
(via the 1-(...) inversion) instead of a penalty.
|
||||
"""
|
||||
reranker = LinearCombinationReranker(weight=0.5, return_score="all")
|
||||
|
||||
vector_results = pa.Table.from_pydict(
|
||||
{
|
||||
"_rowid": [0, 1],
|
||||
"_distance": [0.2, 0.2], # identical vector scores
|
||||
}
|
||||
)
|
||||
fts_results = pa.Table.from_pydict(
|
||||
{
|
||||
"_rowid": [0], # rowid 1 has no FTS match
|
||||
"_score": [0.3], # small FTS score
|
||||
}
|
||||
)
|
||||
|
||||
combined = reranker.merge_results(vector_results, fts_results, fill=1.0)
|
||||
scores = dict(
|
||||
zip(
|
||||
combined["_rowid"].to_pylist(),
|
||||
combined["_relevance_score"].to_pylist(),
|
||||
)
|
||||
)
|
||||
|
||||
# rowid 0 has a small FTS score; rowid 1 has none.
|
||||
# Even a small FTS contribution should beat having none at all.
|
||||
assert scores[0] > scores[1], (
|
||||
f"Document with FTS score (rowid 0, {scores[0]:.4f}) should beat "
|
||||
f"document with no FTS match (rowid 1, {scores[1]:.4f})"
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from datetime import date, datetime, timedelta
|
||||
from time import sleep
|
||||
from typing import List
|
||||
@@ -11,7 +12,7 @@ from unittest.mock import patch
|
||||
|
||||
import lancedb
|
||||
from lancedb.dependencies import _PANDAS_AVAILABLE
|
||||
from lancedb.index import HnswFlat, HnswPq, HnswSq, IvfPq
|
||||
from lancedb.index import BTree, FTS, HnswFlat, HnswPq, HnswSq, IvfPq
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
import pyarrow as pa
|
||||
@@ -928,7 +929,12 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
|
||||
num_bits=4,
|
||||
)
|
||||
mock_create_index.assert_called_with(
|
||||
"vector", replace=True, config=expected_config, name=None, train=True
|
||||
"vector",
|
||||
replace=True,
|
||||
config=expected_config,
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=True,
|
||||
)
|
||||
|
||||
# Test with target_partition_size
|
||||
@@ -948,7 +954,12 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
|
||||
target_partition_size=8192,
|
||||
)
|
||||
mock_create_index.assert_called_with(
|
||||
"vector", replace=True, config=expected_config, name=None, train=True
|
||||
"vector",
|
||||
replace=True,
|
||||
config=expected_config,
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=True,
|
||||
)
|
||||
|
||||
# target_partition_size has a default value,
|
||||
@@ -967,7 +978,12 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
|
||||
num_bits=4,
|
||||
)
|
||||
mock_create_index.assert_called_with(
|
||||
"vector", replace=True, config=expected_config, name=None, train=True
|
||||
"vector",
|
||||
replace=True,
|
||||
config=expected_config,
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=True,
|
||||
)
|
||||
|
||||
table.create_index(
|
||||
@@ -978,7 +994,12 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
|
||||
)
|
||||
expected_config = HnswPq(distance_type="dot")
|
||||
mock_create_index.assert_called_with(
|
||||
"my_vector", replace=False, config=expected_config, name=None, train=True
|
||||
"my_vector",
|
||||
replace=False,
|
||||
config=expected_config,
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=True,
|
||||
)
|
||||
|
||||
table.create_index(
|
||||
@@ -993,7 +1014,12 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
|
||||
distance_type="cosine", sample_rate=0.1, m=29, ef_construction=10
|
||||
)
|
||||
mock_create_index.assert_called_with(
|
||||
"my_vector", replace=True, config=expected_config, name=None, train=True
|
||||
"my_vector",
|
||||
replace=True,
|
||||
config=expected_config,
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=True,
|
||||
)
|
||||
|
||||
table.create_index(
|
||||
@@ -1008,7 +1034,12 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
|
||||
distance_type="cosine", sample_rate=0.1, m=29, ef_construction=10
|
||||
)
|
||||
mock_create_index.assert_called_with(
|
||||
"my_vector", replace=True, config=expected_config, name=None, train=True
|
||||
"my_vector",
|
||||
replace=True,
|
||||
config=expected_config,
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -1032,6 +1063,7 @@ def test_create_index_name_and_train_parameters(
|
||||
"vector",
|
||||
replace=True,
|
||||
config=expected_config,
|
||||
wait_timeout=None,
|
||||
name="my_custom_index",
|
||||
train=True,
|
||||
)
|
||||
@@ -1039,13 +1071,82 @@ def test_create_index_name_and_train_parameters(
|
||||
# Test with train=False
|
||||
table.create_index(vector_column_name="vector", train=False)
|
||||
mock_create_index.assert_called_with(
|
||||
"vector", replace=True, config=expected_config, name=None, train=False
|
||||
"vector",
|
||||
replace=True,
|
||||
config=expected_config,
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=False,
|
||||
)
|
||||
|
||||
# Test with both name and train
|
||||
table.create_index(vector_column_name="vector", name="my_index_name", train=True)
|
||||
mock_create_index.assert_called_with(
|
||||
"vector", replace=True, config=expected_config, name="my_index_name", train=True
|
||||
"vector",
|
||||
replace=True,
|
||||
config=expected_config,
|
||||
wait_timeout=None,
|
||||
name="my_index_name",
|
||||
train=True,
|
||||
)
|
||||
|
||||
|
||||
@patch("lancedb.table.AsyncTable.create_index")
|
||||
def test_create_index_legacy_emits_deprecation_warning(
|
||||
mock_create_index, mem_db: DBConnection
|
||||
):
|
||||
table = mem_db.create_table(
|
||||
"test",
|
||||
data=[{"vector": [3.1, 4.1]}, {"vector": [5.9, 26.5]}],
|
||||
)
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="create_index"):
|
||||
table.create_index(metric="l2", num_partitions=8, vector_column_name="vector")
|
||||
|
||||
|
||||
@patch("lancedb.table.AsyncTable.create_index")
|
||||
def test_create_index_new_api(mock_create_index, mem_db: DBConnection):
|
||||
table = mem_db.create_table(
|
||||
"test",
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "category": "a", "text": "hello world"},
|
||||
{"vector": [5.9, 26.5], "category": "b", "text": "goodbye"},
|
||||
],
|
||||
)
|
||||
|
||||
# Vector index via new API should not warn
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error", DeprecationWarning)
|
||||
table.create_index("vector", config=IvfPq(distance_type="l2"))
|
||||
mock_create_index.assert_called_with(
|
||||
"vector",
|
||||
replace=True,
|
||||
config=IvfPq(distance_type="l2"),
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=True,
|
||||
)
|
||||
|
||||
# Scalar index via new API
|
||||
table.create_index("category", config=BTree())
|
||||
mock_create_index.assert_called_with(
|
||||
"category",
|
||||
replace=True,
|
||||
config=BTree(),
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=True,
|
||||
)
|
||||
|
||||
# FTS index via new API
|
||||
table.create_index("text", config=FTS(with_position=True))
|
||||
mock_create_index.assert_called_with(
|
||||
"text",
|
||||
replace=True,
|
||||
config=FTS(with_position=True),
|
||||
wait_timeout=None,
|
||||
name=None,
|
||||
train=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -1861,8 +1962,9 @@ def test_create_scalar_index(mem_db: DBConnection):
|
||||
"my_table",
|
||||
data=test_data,
|
||||
)
|
||||
# Test with default name
|
||||
table.create_scalar_index("x")
|
||||
# Test with default name; confirm DeprecationWarning fires
|
||||
with pytest.warns(DeprecationWarning, match="create_scalar_index"):
|
||||
table.create_scalar_index("x")
|
||||
indices = table.list_indices()
|
||||
assert len(indices) == 1
|
||||
scalar_index = indices[0]
|
||||
|
||||
@@ -143,18 +143,20 @@ pub struct MergeResult {
|
||||
pub num_inserted_rows: u64,
|
||||
pub num_deleted_rows: u64,
|
||||
pub num_attempts: u32,
|
||||
pub num_rows: u64,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl MergeResult {
|
||||
pub fn __repr__(&self) -> String {
|
||||
format!(
|
||||
"MergeResult(version={}, num_updated_rows={}, num_inserted_rows={}, num_deleted_rows={}, num_attempts={})",
|
||||
"MergeResult(version={}, num_updated_rows={}, num_inserted_rows={}, num_deleted_rows={}, num_attempts={}, num_rows={})",
|
||||
self.version,
|
||||
self.num_updated_rows,
|
||||
self.num_inserted_rows,
|
||||
self.num_deleted_rows,
|
||||
self.num_attempts
|
||||
self.num_attempts,
|
||||
self.num_rows
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -167,6 +169,7 @@ impl From<lancedb::table::MergeResult> for MergeResult {
|
||||
num_inserted_rows: result.num_inserted_rows,
|
||||
num_deleted_rows: result.num_deleted_rows,
|
||||
num_attempts: result.num_attempts,
|
||||
num_rows: result.num_rows,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -194,6 +197,12 @@ impl LsmWriteSpec {
|
||||
}
|
||||
|
||||
/// Identity sharding — shard by the raw value of `column`.
|
||||
///
|
||||
/// `column` must be a deterministic function of the unenforced primary
|
||||
/// key: every row with a given primary key must always produce the same
|
||||
/// `column` value, or upserts of that key can land in different shards
|
||||
/// and a stale version can win. Typically `column` is the primary key
|
||||
/// itself or a stable attribute of it.
|
||||
#[staticmethod]
|
||||
pub fn identity(column: String) -> Self {
|
||||
Self {
|
||||
@@ -933,6 +942,12 @@ impl Table {
|
||||
if let Some(use_index) = parameters.use_index {
|
||||
builder.use_index(use_index);
|
||||
}
|
||||
if let Some(use_lsm_write) = parameters.use_lsm_write {
|
||||
builder.use_lsm_write(use_lsm_write);
|
||||
}
|
||||
if let Some(validate_single_shard) = parameters.validate_single_shard {
|
||||
builder.validate_single_shard(validate_single_shard);
|
||||
}
|
||||
|
||||
future_into_py(self_.py(), async move {
|
||||
let res = builder.execute(Box::new(batches)).await.infer_error()?;
|
||||
@@ -971,6 +986,13 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn close_lsm_writers(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.close_lsm_writers().await.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn uses_v2_manifest_paths(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
@@ -1124,6 +1146,8 @@ pub struct MergeInsertParams {
|
||||
when_not_matched_by_source_condition: Option<String>,
|
||||
timeout: Option<std::time::Duration>,
|
||||
use_index: Option<bool>,
|
||||
use_lsm_write: Option<bool>,
|
||||
validate_single_shard: Option<bool>,
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
|
||||
Reference in New Issue
Block a user