feat(python): async-sync feature parity on Table (#1914)

### Changes to sync API
* Update `LanceTable` and `LanceDBConnection` reprs.
* Add `storage_options`, `data_storage_version`, and
`enable_v2_manifest_paths` to the sync create table API (usage sketch
below).
* Add `storage_options` to `open_table` in the sync API.
* Add `list_indices()` and `index_stats()` to the sync API.
* `create_table()` will now create only one version when data is passed.
Previously it would always create two versions: one to create an empty
table and one to add data to it.
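
A minimal sketch of the new sync surface (bucket name and credential values
are placeholders; the index name passed to `index_stats()` is hypothetical):

```python
import lancedb

db = lancedb.connect(
    "s3://my-bucket/lancedb",  # placeholder bucket
    storage_options={"aws_access_key_id": "***"},
)
tbl = db.create_table(
    "my_table",
    data=[{"vector": [1.1, 1.2], "id": 0}],
    storage_options={"aws_secret_access_key": "***"},  # overrides connection-level options
    data_storage_version="stable",
    enable_v2_manifest_paths=True,
)
print(tbl.list_indices())          # empty until an index is created
# tbl.index_stats("my_idx")        # stats for a named index, once one exists
```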

### Changes to async API
* Add `embedding_functions` to the async `create_table()` API.
* Add `head()` to the async API (see the sketch below).
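
A short sketch of the async additions (assuming `head(n)` mirrors the sync
signature; `embedding_functions` takes the same `EmbeddingFunctionConfig`
list as the sync API):

```python
import asyncio
import lancedb

async def main():
    db = await lancedb.connect_async("memory://")
    tbl = await db.create_table(
        "my_table",
        data=[{"vector": [1.1, 1.2], "id": 0}],
        embedding_functions=None,  # or a list of EmbeddingFunctionConfig
    )
    print(await tbl.head(1))  # first rows as a pyarrow.Table

asyncio.run(main())
```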

### Refactors
* Refactor index parameters into dataclasses so they are easier to use
from Python (example below).
* Moved most tests to use an in-memory DB so we don't need to create so
many temp directories
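
For example, with the dataclass-based configs (assuming `config` is accepted
as a keyword by the async `create_index`, per the stub changes below):

```python
from lancedb.index import BTree, IvfPq

async def build_indices(tbl):
    # Scalar index on a plain column
    await tbl.create_index("id", config=BTree())
    # Vector index; unset fields fall back to the dataclass defaults
    await tbl.create_index(
        "vector",
        config=IvfPq(distance_type="cosine", num_partitions=1, num_sub_vectors=2),
    )
```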

Closes #1792
Closes #1932

---------

Co-authored-by: Weston Pace <weston.pace@gmail.com>
Author: Will Jones
Committed: 2024-12-13 12:56:44 -08:00 (via GitHub)
Parent: d83e5a0208
Commit: 980aa70e2d
23 changed files with 1296 additions and 1324 deletions


@@ -70,7 +70,7 @@ def connect(
default configuration is used.
storage_options: dict, optional
Additional options for the storage backend. See available options at
https://lancedb.github.io/lancedb/guides/storage/
<https://lancedb.github.io/lancedb/guides/storage/>
Examples
--------
@@ -82,11 +82,13 @@ def connect(
For object storage, use a URI prefix:
>>> db = lancedb.connect("s3://my-bucket/lancedb")
>>> db = lancedb.connect("s3://my-bucket/lancedb",
... storage_options={"aws_access_key_id": "***"})
Connect to LanceDB Cloud:
>>> db = lancedb.connect("db://my_database", api_key="ldb_...")
>>> db = lancedb.connect("db://my_database", api_key="ldb_...",
... client_config={"retry_config": {"retries": 5}})
Returns
-------
@@ -164,7 +166,7 @@ async def connect_async(
default configuration is used.
storage_options: dict, optional
Additional options for the storage backend. See available options at
https://lancedb.github.io/lancedb/guides/storage/
<https://lancedb.github.io/lancedb/guides/storage/>
Examples
--------


@@ -2,19 +2,8 @@ from typing import Dict, List, Optional, Tuple
import pyarrow as pa
class Index:
@staticmethod
def ivf_pq(
distance_type: Optional[str],
num_partitions: Optional[int],
num_sub_vectors: Optional[int],
max_iterations: Optional[int],
sample_rate: Optional[int],
) -> Index: ...
@staticmethod
def btree() -> Index: ...
class Connection(object):
uri: str
async def table_names(
self, start_after: Optional[str], limit: Optional[int]
) -> list[str]: ...
@@ -46,9 +35,7 @@ class Table:
async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ...
async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ...
async def count_rows(self, filter: Optional[str]) -> int: ...
async def create_index(
self, column: str, config: Optional[Index], replace: Optional[bool]
): ...
async def create_index(self, column: str, config, replace: Optional[bool]): ...
async def version(self) -> int: ...
async def checkout(self, version): ...
async def checkout_latest(self): ...


@@ -23,3 +23,6 @@ class BackgroundEventLoop:
def run(self, future):
return asyncio.run_coroutine_threadsafe(future, self.loop).result()
LOOP = BackgroundEventLoop()
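
This module-level loop is how the sync API drives the async implementation;
the pattern used throughout the sync wrappers looks like:

```python
# Inside a sync wrapper class that holds an async connection in `self._conn`:
def table_names(self):
    # Submit the coroutine to the shared background loop and block on the result
    return LOOP.run(self._conn.table_names())
```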


@@ -17,10 +17,11 @@ from abc import abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
from overrides import EnforceOverrides, override
from lancedb.common import data_to_reader, sanitize_uri, validate_schema
from lancedb.background_loop import BackgroundEventLoop
from lancedb.background_loop import LOOP
from ._lancedb import connect as lancedb_connect
from .table import (
@@ -43,8 +44,6 @@ if TYPE_CHECKING:
from .common import DATA, URI
from .embeddings import EmbeddingFunctionConfig
LOOP = BackgroundEventLoop()
class DBConnection(EnforceOverrides):
"""An active LanceDB connection interface."""
@@ -82,6 +81,10 @@ class DBConnection(EnforceOverrides):
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
*,
storage_options: Optional[Dict[str, str]] = None,
data_storage_version: Optional[str] = None,
enable_v2_manifest_paths: Optional[bool] = None,
) -> Table:
"""Create a [Table][lancedb.table.Table] in the database.
@@ -119,6 +122,24 @@ class DBConnection(EnforceOverrides):
One of "error", "drop", "fill".
fill_value: float
The value to use when filling vectors. Only used if on_bad_vectors="fill".
storage_options: dict, optional
Additional options for the storage backend. Options already set on the
connection will be inherited by the table, but can be overridden here.
See available options at
<https://lancedb.github.io/lancedb/guides/storage/>
data_storage_version: str, optional, default "stable"
The version of the data storage format to use. Newer versions are more
efficient but require newer versions of lance to read. The default is
"stable", which uses the v2 format. See the user guide
for more details.
enable_v2_manifest_paths: bool, optional, default False
Use the new V2 manifest paths. These paths provide more efficient
opening of datasets with many versions on object stores. WARNING:
turning this on will make the dataset unreadable for older versions
of LanceDB (prior to 0.13.0). To migrate an existing dataset, instead
use the
[Table.migrate_v2_manifest_paths][lancedb.table.Table.migrate_v2_manifest_paths]
method.
Returns
-------
@@ -140,7 +161,7 @@ class DBConnection(EnforceOverrides):
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
>>> db.create_table("my_table", data)
LanceTable(connection=..., name="my_table")
LanceTable(name='my_table', version=1, ...)
>>> db["my_table"].head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
@@ -161,7 +182,7 @@ class DBConnection(EnforceOverrides):
... "long": [-122.7, -74.1]
... })
>>> db.create_table("table2", data)
LanceTable(connection=..., name="table2")
LanceTable(name='table2', version=1, ...)
>>> db["table2"].head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
@@ -184,7 +205,7 @@ class DBConnection(EnforceOverrides):
... pa.field("long", pa.float32())
... ])
>>> db.create_table("table3", data, schema = custom_schema)
LanceTable(connection=..., name="table3")
LanceTable(name='table3', version=1, ...)
>>> db["table3"].head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
@@ -218,7 +239,7 @@ class DBConnection(EnforceOverrides):
... pa.field("price", pa.float32()),
... ])
>>> db.create_table("table4", make_batches(), schema=schema)
LanceTable(connection=..., name="table4")
LanceTable(name='table4', version=1, ...)
"""
raise NotImplementedError
@@ -226,7 +247,13 @@ class DBConnection(EnforceOverrides):
def __getitem__(self, name: str) -> LanceTable:
return self.open_table(name)
def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
def open_table(
self,
name: str,
*,
storage_options: Optional[Dict[str, str]] = None,
index_cache_size: Optional[int] = None,
) -> Table:
"""Open a Lance Table in the database.
Parameters
@@ -243,6 +270,11 @@ class DBConnection(EnforceOverrides):
This cache applies to the entire opened table, across all indices.
Setting this value higher will increase performance on larger datasets
at the expense of more RAM
storage_options: dict, optional
Additional options for the storage backend. Options already set on the
connection will be inherited by the table, but can be overridden here.
See available options at
<https://lancedb.github.io/lancedb/guides/storage/>
Returns
-------
@@ -309,15 +341,15 @@ class LanceDBConnection(DBConnection):
>>> db = lancedb.connect("./.lancedb")
>>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
... {"vector": [0.5, 1.3], "b": 4}])
LanceTable(connection=..., name="my_table")
LanceTable(name='my_table', version=1, ...)
>>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
LanceTable(connection=..., name="another_table")
LanceTable(name='another_table', version=1, ...)
>>> sorted(db.table_names())
['another_table', 'my_table']
>>> len(db)
2
>>> db["my_table"]
LanceTable(connection=..., name="my_table")
LanceTable(name='my_table', version=1, ...)
>>> "my_table" in db
True
>>> db.drop_table("my_table")
@@ -363,7 +395,7 @@ class LanceDBConnection(DBConnection):
self._conn = AsyncConnection(LOOP.run(do_connect()))
def __repr__(self) -> str:
val = f"{self.__class__.__name__}({self._uri}"
val = f"{self.__class__.__name__}(uri={self._uri!r}"
if self.read_consistency_interval is not None:
val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
val += ")"
@@ -403,6 +435,10 @@ class LanceDBConnection(DBConnection):
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
*,
storage_options: Optional[Dict[str, str]] = None,
data_storage_version: Optional[str] = None,
enable_v2_manifest_paths: Optional[bool] = None,
) -> LanceTable:
"""Create a table in the database.
@@ -424,12 +460,19 @@ class LanceDBConnection(DBConnection):
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
embedding_functions=embedding_functions,
storage_options=storage_options,
data_storage_version=data_storage_version,
enable_v2_manifest_paths=enable_v2_manifest_paths,
)
return tbl
@override
def open_table(
self, name: str, *, index_cache_size: Optional[int] = None
self,
name: str,
*,
storage_options: Optional[Dict[str, str]] = None,
index_cache_size: Optional[int] = None,
) -> LanceTable:
"""Open a table in the database.
@@ -442,7 +485,12 @@ class LanceDBConnection(DBConnection):
-------
A LanceTable object representing the table.
"""
return LanceTable.open(self, name, index_cache_size=index_cache_size)
return LanceTable.open(
self,
name,
storage_options=storage_options,
index_cache_size=index_cache_size,
)
@override
def drop_table(self, name: str, ignore_missing: bool = False):
@@ -524,6 +572,10 @@ class AsyncConnection(object):
Any attempt to use the connection after it is closed will result in an error."""
self._inner.close()
@property
def uri(self) -> str:
return self._inner.uri
async def table_names(
self, *, start_after: Optional[str] = None, limit: Optional[int] = None
) -> Iterable[str]:
@@ -557,6 +609,7 @@ class AsyncConnection(object):
fill_value: Optional[float] = None,
storage_options: Optional[Dict[str, str]] = None,
*,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
data_storage_version: Optional[str] = None,
use_legacy_format: Optional[bool] = None,
enable_v2_manifest_paths: Optional[bool] = None,
@@ -601,7 +654,7 @@ class AsyncConnection(object):
Additional options for the storage backend. Options already set on the
connection will be inherited by the table, but can be overridden here.
See available options at
https://lancedb.github.io/lancedb/guides/storage/
<https://lancedb.github.io/lancedb/guides/storage/>
data_storage_version: str, optional, default "stable"
The version of the data storage format to use. Newer versions are more
efficient but require newer versions of lance to read. The default is
@@ -730,6 +783,17 @@ class AsyncConnection(object):
"""
metadata = None
if embedding_functions is not None:
# If we passed in embedding functions explicitly
# then we'll override any schema metadata that
# may have been implicitly specified by the LanceModel schema
registry = EmbeddingFunctionRegistry.get_instance()
metadata = registry.get_table_metadata(embedding_functions)
data, schema = sanitize_create_table(
data, schema, metadata, on_bad_vectors, fill_value
)
# Defining defaults here and not in the function prototype. In the future
# these defaults will move into Rust, so it's better to keep them as None.
if on_bad_vectors is None:
@@ -791,7 +855,7 @@ class AsyncConnection(object):
Additional options for the storage backend. Options already set on the
connection will be inherited by the table, but can be overridden here.
See available options at
https://lancedb.github.io/lancedb/guides/storage/
<https://lancedb.github.io/lancedb/guides/storage/>
index_cache_size: int, default 256
Set the size of the index cache, specified as a number of entries
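
Connection-level options are inherited and can be overridden per table; a
sketch with placeholder option values:

```python
import lancedb

db = lancedb.connect(
    "s3://my-bucket/lancedb",  # placeholder bucket
    storage_options={"region": "us-east-1", "aws_access_key_id": "***"},
)
# Inherits the region and key from the connection; overrides only the timeout
tbl = db.open_table(
    "my_table",
    storage_options={"timeout": "60s"},
    index_cache_size=512,
)
```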


@@ -1,8 +1,6 @@
from typing import Optional
from dataclasses import dataclass
from typing import Literal, Optional
from ._lancedb import (
Index as LanceDbIndex,
)
from ._lancedb import (
IndexConfig,
)
@@ -29,6 +27,7 @@ lang_mapping = {
}
@dataclass
class BTree:
"""Describes a btree index configuration
@@ -50,10 +49,10 @@ class BTree:
the block size may be added in the future.
"""
def __init__(self):
self._inner = LanceDbIndex.btree()
pass
@dataclass
class Bitmap:
"""Describe a Bitmap index configuration.
@@ -73,10 +72,10 @@ class Bitmap:
requires 128 / 8 * 1Bi bytes on disk.
"""
def __init__(self):
self._inner = LanceDbIndex.bitmap()
pass
@dataclass
class LabelList:
"""Describe a LabelList index configuration.
@@ -87,41 +86,57 @@ class LabelList:
For example, it works with `tags`, `categories`, `keywords`, etc.
"""
def __init__(self):
self._inner = LanceDbIndex.label_list()
pass
@dataclass
class FTS:
"""Describe a FTS index configuration.
`FTS` is a full-text search index that can be used on `String` columns
For example, it works with `title`, `description`, `content`, etc.
Attributes
----------
with_position : bool, default True
Whether to store the position of the token in the document. Setting this
to False can reduce the size of the index and improve indexing speed,
but it will disable support for phrase queries.
base_tokenizer : str, default "simple"
The base tokenizer to use for tokenization. Options are:
- "simple": Splits text by whitespace and punctuation.
- "whitespace": Split text by whitespace, but not punctuation.
- "raw": No tokenization. The entire text is treated as a single token.
language : str, default "English"
The language to use for tokenization.
max_token_length : int, default 40
The maximum token length to index. Tokens longer than this length will be
ignored.
lower_case : bool, default True
Whether to convert the token to lower case. This makes queries case-insensitive.
stem : bool, default False
Whether to stem the token. Stemming reduces words to their root form.
For example, in English "running" and "runs" would both be reduced to "run".
remove_stop_words : bool, default False
Whether to remove stop words. Stop words are common words that are often
removed from text before indexing. For example, in English "the" and "and".
ascii_folding : bool, default False
Whether to fold ASCII characters. This converts accented characters to
their ASCII equivalent. For example, "café" would be converted to "cafe".
"""
def __init__(
self,
with_position: bool = True,
base_tokenizer: str = "simple",
language: str = "English",
max_token_length: Optional[int] = 40,
lower_case: bool = True,
stem: bool = False,
remove_stop_words: bool = False,
ascii_folding: bool = False,
):
self._inner = LanceDbIndex.fts(
with_position=with_position,
base_tokenizer=base_tokenizer,
language=language,
max_token_length=max_token_length,
lower_case=lower_case,
stem=stem,
remove_stop_words=remove_stop_words,
ascii_folding=ascii_folding,
)
with_position: bool = True
base_tokenizer: Literal["simple", "raw", "whitespace"] = "simple"
language: str = "English"
max_token_length: Optional[int] = 40
lower_case: bool = True
stem: bool = False
remove_stop_words: bool = False
ascii_folding: bool = False
@dataclass
class HnswPq:
"""Describe a HNSW-PQ index configuration.
@@ -232,30 +247,17 @@ class HnswPq:
search phase.
"""
def __init__(
self,
*,
distance_type: Optional[str] = None,
num_partitions: Optional[int] = None,
num_sub_vectors: Optional[int] = None,
num_bits: Optional[int] = None,
max_iterations: Optional[int] = None,
sample_rate: Optional[int] = None,
m: Optional[int] = None,
ef_construction: Optional[int] = None,
):
self._inner = LanceDbIndex.hnsw_pq(
distance_type=distance_type,
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
num_bits=num_bits,
max_iterations=max_iterations,
sample_rate=sample_rate,
m=m,
ef_construction=ef_construction,
)
distance_type: Literal["l2", "cosine", "dot"] = "l2"
num_partitions: Optional[int] = None
num_sub_vectors: Optional[int] = None
num_bits: int = 8
max_iterations: int = 50
sample_rate: int = 256
m: int = 20
ef_construction: int = 300
@dataclass
class HnswSq:
"""Describe a HNSW-SQ index configuration.
@@ -345,26 +347,15 @@ class HnswSq:
"""
def __init__(
self,
*,
distance_type: Optional[str] = None,
num_partitions: Optional[int] = None,
max_iterations: Optional[int] = None,
sample_rate: Optional[int] = None,
m: Optional[int] = None,
ef_construction: Optional[int] = None,
):
self._inner = LanceDbIndex.hnsw_sq(
distance_type=distance_type,
num_partitions=num_partitions,
max_iterations=max_iterations,
sample_rate=sample_rate,
m=m,
ef_construction=ef_construction,
)
distance_type: Literal["l2", "cosine", "dot"] = "l2"
num_partitions: Optional[int] = None
max_iterations: int = 50
sample_rate: int = 256
m: int = 20
ef_construction: int = 300
@dataclass
class IvfPq:
"""Describes an IVF PQ Index
@@ -387,120 +378,103 @@ class IvfPq:
Note that training an IVF PQ index on a large dataset is a slow operation and
currently is also a memory intensive operation.
Attributes
----------
distance_type: str, default "L2"
The distance metric used to train the index
This is used when training the index to calculate the IVF partitions
(vectors are grouped in partitions with similar vectors according to this
distance type) and to calculate a subvector's code during quantization.
The distance type used to train an index MUST match the distance type used
to search the index. Failure to do so will yield inaccurate results.
The following distance types are available:
"l2" - Euclidean distance. This is a very common distance metric that
accounts for both magnitude and direction when determining the distance
between vectors. L2 distance has a range of [0, ∞).
"cosine" - Cosine distance. Cosine distance is a distance metric
calculated from the cosine similarity between two vectors. Cosine
similarity is a measure of similarity between two non-zero vectors of an
inner product space. It is defined to equal the cosine of the angle
between them. Unlike L2, the cosine distance is not affected by the
magnitude of the vectors. Cosine distance has a range of [0, 2].
Note: the cosine distance is undefined when one (or both) of the vectors
are all zeros (there is no direction). These vectors are invalid and may
never be returned from a vector search.
"dot" - Dot product. Dot distance is the dot product of two vectors. Dot
distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
L2 norm is 1), then dot distance is equivalent to the cosine distance.
num_partitions: int, default sqrt(num_rows)
The number of IVF partitions to create.
This value should generally scale with the number of rows in the dataset.
By default the number of partitions is the square root of the number of
rows.
If this value is too large then the first part of the search (picking the
right partition) will be slow. If this value is too small then the second
part of the search (searching within a partition) will be slow.
num_sub_vectors: int, default is vector dimension / 16
Number of sub-vectors of PQ.
This value controls how much the vector is compressed during the
quantization step. The more sub vectors there are the less the vector is
compressed. The default is the dimension of the vector divided by 16. If
the dimension is not evenly divisible by 16, we use the dimension divided by
8.
The above two cases are highly preferred. Having 8 or 16 values per
subvector allows us to use efficient SIMD instructions.
If the dimension is not divisible by 8, then we use 1 subvector. This is not
ideal and will likely result in poor performance.
num_bits: int, default 8
Number of bits to encode each sub-vector.
This value controls how much the sub-vectors are compressed. The more bits
the more accurate the index but the slower search. The default is 8
bits. Only 4 and 8 are supported.
max_iterations: int, default 50
Max iteration to train kmeans.
When training an IVF PQ index we use kmeans to calculate the partitions.
This parameter controls how many iterations of kmeans to run.
Increasing this might improve the quality of the index but in most cases
these extra iterations have diminishing returns.
The default value is 50.
sample_rate: int, default 256
The rate used to calculate the number of training vectors for kmeans.
When an IVF PQ index is trained, we need to calculate partitions. These
are groups of vectors that are similar to each other. To do this we use an
algorithm called kmeans.
Running kmeans on a large dataset can be slow. To speed this up we run
kmeans on a random sample of the data. This parameter controls the size of
the sample. The total number of vectors used to train the index is
`sample_rate * num_partitions`.
Increasing this value might improve the quality of the index but in most
cases the default should be sufficient.
The default value is 256.
"""
def __init__(
self,
*,
distance_type: Optional[str] = None,
num_partitions: Optional[int] = None,
num_sub_vectors: Optional[int] = None,
num_bits: Optional[int] = None,
max_iterations: Optional[int] = None,
sample_rate: Optional[int] = None,
):
"""
Create an IVF PQ index config
Parameters
----------
distance_type: str, default "L2"
The distance metric used to train the index
This is used when training the index to calculate the IVF partitions
(vectors are grouped in partitions with similar vectors according to this
distance type) and to calculate a subvector's code during quantization.
The distance type used to train an index MUST match the distance type used
to search the index. Failure to do so will yield inaccurate results.
The following distance types are available:
"l2" - Euclidean distance. This is a very common distance metric that
accounts for both magnitude and direction when determining the distance
between vectors. L2 distance has a range of [0, ∞).
"cosine" - Cosine distance. Cosine distance is a distance metric
calculated from the cosine similarity between two vectors. Cosine
similarity is a measure of similarity between two non-zero vectors of an
inner product space. It is defined to equal the cosine of the angle
between them. Unlike L2, the cosine distance is not affected by the
magnitude of the vectors. Cosine distance has a range of [0, 2].
Note: the cosine distance is undefined when one (or both) of the vectors
are all zeros (there is no direction). These vectors are invalid and may
never be returned from a vector search.
"dot" - Dot product. Dot distance is the dot product of two vectors. Dot
distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
L2 norm is 1), then dot distance is equivalent to the cosine distance.
num_partitions: int, default sqrt(num_rows)
The number of IVF partitions to create.
This value should generally scale with the number of rows in the dataset.
By default the number of partitions is the square root of the number of
rows.
If this value is too large then the first part of the search (picking the
right partition) will be slow. If this value is too small then the second
part of the search (searching within a partition) will be slow.
num_sub_vectors: int, default is vector dimension / 16
Number of sub-vectors of PQ.
This value controls how much the vector is compressed during the
quantization step. The more sub vectors there are the less the vector is
compressed. The default is the dimension of the vector divided by 16. If
the dimension is not evenly divisible by 16, we use the dimension divided by
8.
The above two cases are highly preferred. Having 8 or 16 values per
subvector allows us to use efficient SIMD instructions.
If the dimension is not divisible by 8, then we use 1 subvector. This is not
ideal and will likely result in poor performance.
num_bits: int, default 8
Number of bits to encode each sub-vector.
This value controls how much the sub-vectors are compressed. The more bits
the more accurate the index but the slower search. The default is 8
bits. Only 4 and 8 are supported.
max_iterations: int, default 50
Max iteration to train kmeans.
When training an IVF PQ index we use kmeans to calculate the partitions.
This parameter controls how many iterations of kmeans to run.
Increasing this might improve the quality of the index but in most cases
these extra iterations have diminishing returns.
The default value is 50.
sample_rate: int, default 256
The rate used to calculate the number of training vectors for kmeans.
When an IVF PQ index is trained, we need to calculate partitions. These
are groups of vectors that are similar to each other. To do this we use an
algorithm called kmeans.
Running kmeans on a large dataset can be slow. To speed this up we run
kmeans on a random sample of the data. This parameter controls the size of
the sample. The total number of vectors used to train the index is
`sample_rate * num_partitions`.
Increasing this value might improve the quality of the index but in most
cases the default should be sufficient.
The default value is 256.
"""
if distance_type is not None:
distance_type = distance_type.lower()
self._inner = LanceDbIndex.ivf_pq(
distance_type=distance_type,
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
num_bits=num_bits,
max_iterations=max_iterations,
sample_rate=sample_rate,
)
distance_type: Literal["l2", "cosine", "dot"] = "l2"
num_partitions: Optional[int] = None
num_sub_vectors: Optional[int] = None
num_bits: int = 8
max_iterations: int = 50
sample_rate: int = 256
__all__ = ["BTree", "IvfPq", "IndexConfig"]
__all__ = ["BTree", "IvfPq", "HnswPq", "HnswSq", "IndexConfig"]
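
The `FTS` dataclass gathers the tokenizer options in one place; for instance,
a language-aware config (values taken from the French-language test below,
passed to `create_index` like the vector configs):

```python
from lancedb.index import FTS

config = FTS(
    language="French",
    stem=True,             # "routes" and "route" share a root
    ascii_folding=True,    # "café" matches "cafe"
    remove_stop_words=True,
)
```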


@@ -121,7 +121,13 @@ class RemoteDBConnection(DBConnection):
return LOOP.run(self._conn.table_names(start_after=page_token, limit=limit))
@override
def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
def open_table(
self,
name: str,
*,
storage_options: Optional[Dict[str, str]] = None,
index_cache_size: Optional[int] = None,
) -> Table:
"""Open a Lance Table in the database.
Parameters


@@ -16,6 +16,8 @@ import logging
from functools import cached_property
from typing import Dict, Iterable, List, Optional, Union, Literal
from lancedb._lancedb import IndexConfig
from lancedb.embeddings.base import EmbeddingFunctionConfig
from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList
from lancedb.remote.db import LOOP
import pyarrow as pa
@@ -25,7 +27,7 @@ from lancedb.merge import LanceMergeInsertBuilder
from lancedb.embeddings import EmbeddingFunctionRegistry
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder
from ..table import AsyncTable, Query, Table
from ..table import AsyncTable, IndexStatistics, Query, Table
class RemoteTable(Table):
@@ -62,7 +64,7 @@ class RemoteTable(Table):
return LOOP.run(self._table.version())
@cached_property
def embedding_functions(self) -> dict:
def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]:
"""
Get the embedding functions for the table
@@ -94,11 +96,11 @@ class RemoteTable(Table):
def checkout_latest(self):
return LOOP.run(self._table.checkout_latest())
def list_indices(self):
def list_indices(self) -> Iterable[IndexConfig]:
"""List all the indices on the table"""
return LOOP.run(self._table.list_indices())
def index_stats(self, index_uuid: str):
def index_stats(self, index_uuid: str) -> Optional[IndexStatistics]:
"""List all the stats of a specified index"""
return LOOP.run(self._table.index_stats(index_uuid))
@@ -515,6 +517,16 @@ class RemoteTable(Table):
def drop_columns(self, columns: Iterable[str]):
return LOOP.run(self._table.drop_columns(columns))
def uses_v2_manifest_paths(self) -> bool:
raise NotImplementedError(
"uses_v2_manifest_paths() is not supported on the LanceDB Cloud"
)
def migrate_v2_manifest_paths(self):
raise NotImplementedError(
"migrate_v2_manifest_paths() is not supported on the LanceDB Cloud"
)
def add_index(tbl: pa.Table, i: int) -> pa.Table:
return tbl.add_column(

File diff suppressed because it is too large.

@@ -314,3 +314,15 @@ def deprecated(func):
def validate_table_name(name: str):
"""Verify the table name is valid."""
native_validate_table_name(name)
def add_note(base_exception: BaseException, note: str):
if hasattr(base_exception, "add_note"):
base_exception.add_note(note)
elif isinstance(base_exception.args[0], str):
base_exception.args = (
base_exception.args[0] + "\n" + note,
*base_exception.args[1:],
)
else:
raise ValueError("Cannot add note to exception")
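
A usage sketch of this helper (caller and message are hypothetical): on
Python 3.11+ it uses PEP 678 `add_note`, otherwise it falls back to appending
the note to the exception message:

```python
try:
    raise ValueError("Table 'does_not_exist' was not found")
except ValueError as exc:
    add_note(exc, "Hint: use db.table_names() to list existing tables.")
    raise
```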


@@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from datetime import timedelta
from lancedb.db import AsyncConnection, DBConnection
import lancedb
import pytest
import pytest_asyncio
# Use an in-memory database for most tests.
@pytest.fixture
def mem_db() -> DBConnection:
return lancedb.connect("memory://")
# Use a temporary directory when we need to inspect the database files.
@pytest.fixture
def tmp_db(tmp_path) -> DBConnection:
return lancedb.connect(tmp_path)
@pytest_asyncio.fixture
async def mem_db_async() -> AsyncConnection:
return await lancedb.connect_async("memory://")
@pytest_asyncio.fixture
async def tmp_db_async(tmp_path) -> AsyncConnection:
return await lancedb.connect_async(
tmp_path, read_consistency_interval=timedelta(seconds=0)
)
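
Tests then request the fixtures by parameter name; a sketch of both flavors
(test bodies are illustrative):

```python
import lancedb
import pytest

def test_count_rows(mem_db: lancedb.DBConnection):
    # Fresh in-memory database per test; no temp directory needed
    tbl = mem_db.create_table("example", data=[{"id": 0}, {"id": 1}])
    assert tbl.count_rows() == 2

@pytest.mark.asyncio
async def test_count_rows_async(mem_db_async: lancedb.AsyncConnection):
    tbl = await mem_db_async.create_table("example", data=[{"id": 0}])
    assert await tbl.count_rows() == 1
```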


@@ -98,7 +98,7 @@ def test_ingest_pd(tmp_path):
assert db.open_table("test").name == db["test"].name
def test_ingest_iterator(tmp_path):
def test_ingest_iterator(mem_db: lancedb.DBConnection):
class PydanticSchema(LanceModel):
vector: Vector(2)
item: str
@@ -156,8 +156,7 @@ def test_ingest_iterator(tmp_path):
]
def run_tests(schema):
db = lancedb.connect(tmp_path)
tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
tbl = mem_db.create_table("table2", make_batches(), schema=schema)
tbl.to_pandas()
assert tbl.search([3.1, 4.1]).limit(1).to_pandas()["_distance"][0] == 0.0
assert tbl.search([5.9, 26.5]).limit(1).to_pandas()["_distance"][0] == 0.0
@@ -165,15 +164,14 @@ def test_ingest_iterator(tmp_path):
tbl.add(make_batches())
assert tbl_len == 50
assert len(tbl) == tbl_len * 2
assert len(tbl.list_versions()) == 3
db.drop_database()
assert len(tbl.list_versions()) == 2
mem_db.drop_database()
run_tests(arrow_schema)
run_tests(PydanticSchema)
def test_table_names(tmp_path):
db = lancedb.connect(tmp_path)
def test_table_names(tmp_db: lancedb.DBConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -181,10 +179,10 @@ def test_table_names(tmp_path):
"price": [10.0, 20.0],
}
)
db.create_table("test2", data=data)
db.create_table("test1", data=data)
db.create_table("test3", data=data)
assert db.table_names() == ["test1", "test2", "test3"]
tmp_db.create_table("test2", data=data)
tmp_db.create_table("test1", data=data)
tmp_db.create_table("test3", data=data)
assert tmp_db.table_names() == ["test1", "test2", "test3"]
@pytest.mark.asyncio
@@ -209,8 +207,7 @@ async def test_table_names_async(tmp_path):
assert await db.table_names(start_after="test1") == ["test2", "test3"]
def test_create_mode(tmp_path):
db = lancedb.connect(tmp_path)
def test_create_mode(tmp_db: lancedb.DBConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -218,10 +215,10 @@ def test_create_mode(tmp_path):
"price": [10.0, 20.0],
}
)
db.create_table("test", data=data)
tmp_db.create_table("test", data=data)
with pytest.raises(Exception):
db.create_table("test", data=data)
tmp_db.create_table("test", data=data)
new_data = pd.DataFrame(
{
@@ -230,13 +227,11 @@ def test_create_mode(tmp_path):
"price": [10.0, 20.0],
}
)
tbl = db.create_table("test", data=new_data, mode="overwrite")
tbl = tmp_db.create_table("test", data=new_data, mode="overwrite")
assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]
def test_create_table_from_iterator(tmp_path):
db = lancedb.connect(tmp_path)
def test_create_table_from_iterator(mem_db: lancedb.DBConnection):
def gen_data():
for _ in range(10):
yield pa.RecordBatch.from_arrays(
@@ -248,14 +243,12 @@ def test_create_table_from_iterator(tmp_path):
["vector", "item", "price"],
)
table = db.create_table("test", data=gen_data())
table = mem_db.create_table("test", data=gen_data())
assert table.count_rows() == 10
@pytest.mark.asyncio
async def test_create_table_from_iterator_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
async def test_create_table_from_iterator_async(mem_db_async: lancedb.AsyncConnection):
def gen_data():
for _ in range(10):
yield pa.RecordBatch.from_arrays(
@@ -267,12 +260,11 @@ async def test_create_table_from_iterator_async(tmp_path):
["vector", "item", "price"],
)
table = await db.create_table("test", data=gen_data())
table = await mem_db_async.create_table("test", data=gen_data())
assert await table.count_rows() == 10
def test_create_exist_ok(tmp_path):
db = lancedb.connect(tmp_path)
def test_create_exist_ok(tmp_db: lancedb.DBConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -280,13 +272,13 @@ def test_create_exist_ok(tmp_path):
"price": [10.0, 20.0],
}
)
tbl = db.create_table("test", data=data)
tbl = tmp_db.create_table("test", data=data)
with pytest.raises(OSError):
db.create_table("test", data=data)
with pytest.raises(ValueError):
tmp_db.create_table("test", data=data)
# open the table but don't add more rows
tbl2 = db.create_table("test", data=data, exist_ok=True)
tbl2 = tmp_db.create_table("test", data=data, exist_ok=True)
assert tbl.name == tbl2.name
assert tbl.schema == tbl2.schema
assert len(tbl) == len(tbl2)
@@ -298,7 +290,7 @@ def test_create_exist_ok(tmp_path):
pa.field("price", pa.float64()),
]
)
tbl3 = db.create_table("test", schema=schema, exist_ok=True)
tbl3 = tmp_db.create_table("test", schema=schema, exist_ok=True)
assert tbl3.schema == schema
bad_schema = pa.schema(
@@ -310,7 +302,7 @@ def test_create_exist_ok(tmp_path):
]
)
with pytest.raises(ValueError):
db.create_table("test", schema=bad_schema, exist_ok=True)
tmp_db.create_table("test", schema=bad_schema, exist_ok=True)
@pytest.mark.asyncio
@@ -325,26 +317,24 @@ async def test_connect(tmp_path):
@pytest.mark.asyncio
async def test_close(tmp_path):
db = await lancedb.connect_async(tmp_path)
assert db.is_open()
db.close()
assert not db.is_open()
async def test_close(mem_db_async: lancedb.AsyncConnection):
assert mem_db_async.is_open()
mem_db_async.close()
assert not mem_db_async.is_open()
with pytest.raises(RuntimeError, match="is closed"):
await db.table_names()
await mem_db_async.table_names()
@pytest.mark.asyncio
async def test_context_manager(tmp_path):
with await lancedb.connect_async(tmp_path) as db:
async def test_context_manager():
with await lancedb.connect_async("memory://") as db:
assert db.is_open()
assert not db.is_open()
@pytest.mark.asyncio
async def test_create_mode_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
async def test_create_mode_async(tmp_db_async: lancedb.AsyncConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -352,10 +342,10 @@ async def test_create_mode_async(tmp_path):
"price": [10.0, 20.0],
}
)
await db.create_table("test", data=data)
await tmp_db_async.create_table("test", data=data)
with pytest.raises(ValueError, match="already exists"):
await db.create_table("test", data=data)
await tmp_db_async.create_table("test", data=data)
new_data = pd.DataFrame(
{
@@ -364,15 +354,14 @@ async def test_create_mode_async(tmp_path):
"price": [10.0, 20.0],
}
)
_tbl = await db.create_table("test", data=new_data, mode="overwrite")
_tbl = await tmp_db_async.create_table("test", data=new_data, mode="overwrite")
# MIGRATION: to_pandas() is not available in async
# assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]
@pytest.mark.asyncio
async def test_create_exist_ok_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
async def test_create_exist_ok_async(tmp_db_async: lancedb.AsyncConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -380,13 +369,13 @@ async def test_create_exist_ok_async(tmp_path):
"price": [10.0, 20.0],
}
)
tbl = await db.create_table("test", data=data)
tbl = await tmp_db_async.create_table("test", data=data)
with pytest.raises(ValueError, match="already exists"):
await db.create_table("test", data=data)
await tmp_db_async.create_table("test", data=data)
# open the table but don't add more rows
tbl2 = await db.create_table("test", data=data, exist_ok=True)
tbl2 = await tmp_db_async.create_table("test", data=data, exist_ok=True)
assert tbl.name == tbl2.name
assert await tbl.schema() == await tbl2.schema()
@@ -397,7 +386,7 @@ async def test_create_exist_ok_async(tmp_path):
pa.field("price", pa.float64()),
]
)
tbl3 = await db.create_table("test", schema=schema, exist_ok=True)
tbl3 = await tmp_db_async.create_table("test", schema=schema, exist_ok=True)
assert await tbl3.schema() == schema
# Migration: When creating a table, but the table already exists, but
@@ -448,13 +437,12 @@ async def test_create_table_v2_manifest_paths_async(tmp_path):
assert re.match(r"\d{20}\.manifest", manifest)
def test_open_table_sync(tmp_path):
db = lancedb.connect(tmp_path)
db.create_table("test", data=[{"id": 0}])
assert db.open_table("test").count_rows() == 1
assert db.open_table("test", index_cache_size=0).count_rows() == 1
with pytest.raises(FileNotFoundError, match="does not exist"):
db.open_table("does_not_exist")
def test_open_table_sync(tmp_db: lancedb.DBConnection):
tmp_db.create_table("test", data=[{"id": 0}])
assert tmp_db.open_table("test").count_rows() == 1
assert tmp_db.open_table("test", index_cache_size=0).count_rows() == 1
with pytest.raises(ValueError, match="Table 'does_not_exist' was not found"):
tmp_db.open_table("does_not_exist")
@pytest.mark.asyncio
@@ -494,8 +482,7 @@ async def test_open_table(tmp_path):
await db.open_table("does_not_exist")
def test_delete_table(tmp_path):
db = lancedb.connect(tmp_path)
def test_delete_table(tmp_db: lancedb.DBConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -503,26 +490,25 @@ def test_delete_table(tmp_path):
"price": [10.0, 20.0],
}
)
db.create_table("test", data=data)
tmp_db.create_table("test", data=data)
with pytest.raises(Exception):
db.create_table("test", data=data)
tmp_db.create_table("test", data=data)
assert db.table_names() == ["test"]
assert tmp_db.table_names() == ["test"]
db.drop_table("test")
assert db.table_names() == []
tmp_db.drop_table("test")
assert tmp_db.table_names() == []
db.create_table("test", data=data)
assert db.table_names() == ["test"]
tmp_db.create_table("test", data=data)
assert tmp_db.table_names() == ["test"]
# dropping a table that does not exist should pass
# if ignore_missing=True
db.drop_table("does_not_exist", ignore_missing=True)
tmp_db.drop_table("does_not_exist", ignore_missing=True)
def test_drop_database(tmp_path):
db = lancedb.connect(tmp_path)
def test_drop_database(tmp_db: lancedb.DBConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -537,51 +523,50 @@ def test_drop_database(tmp_path):
"price": [12.0, 17.0],
}
)
db.create_table("test", data=data)
tmp_db.create_table("test", data=data)
with pytest.raises(Exception):
db.create_table("test", data=data)
tmp_db.create_table("test", data=data)
assert db.table_names() == ["test"]
assert tmp_db.table_names() == ["test"]
db.create_table("new_test", data=new_data)
db.drop_database()
assert db.table_names() == []
tmp_db.create_table("new_test", data=new_data)
tmp_db.drop_database()
assert tmp_db.table_names() == []
# it should pass when no tables are present
db.create_table("test", data=new_data)
db.drop_table("test")
assert db.table_names() == []
db.drop_database()
assert db.table_names() == []
tmp_db.create_table("test", data=new_data)
tmp_db.drop_table("test")
assert tmp_db.table_names() == []
tmp_db.drop_database()
assert tmp_db.table_names() == []
# creating an empty database with schema
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), list_size=2))])
db.create_table("empty_table", schema=schema)
tmp_db.create_table("empty_table", schema=schema)
# dropping an empty database should pass
db.drop_database()
assert db.table_names() == []
tmp_db.drop_database()
assert tmp_db.table_names() == []
def test_empty_or_nonexistent_table(tmp_path):
db = lancedb.connect(tmp_path)
def test_empty_or_nonexistent_table(mem_db: lancedb.DBConnection):
with pytest.raises(Exception):
db.create_table("test_with_no_data")
mem_db.create_table("test_with_no_data")
with pytest.raises(Exception):
db.open_table("does_not_exist")
mem_db.open_table("does_not_exist")
schema = pa.schema([pa.field("a", pa.int64(), nullable=False)])
test = db.create_table("test", schema=schema)
test = mem_db.create_table("test", schema=schema)
class TestModel(LanceModel):
a: int
test2 = db.create_table("test2", schema=TestModel)
test2 = mem_db.create_table("test2", schema=TestModel)
assert test.schema == test2.schema
@pytest.mark.asyncio
async def test_create_in_v2_mode(tmp_path):
async def test_create_in_v2_mode(mem_db_async: lancedb.AsyncConnection):
def make_data():
for i in range(10):
yield pa.record_batch([pa.array([x for x in range(1024)])], names=["x"])
@@ -591,10 +576,8 @@ async def test_create_in_v2_mode(tmp_path):
schema = pa.schema([pa.field("x", pa.int64())])
db = await lancedb.connect_async(tmp_path)
# Create table in v1 mode
tbl = await db.create_table(
tbl = await mem_db_async.create_table(
"test", data=make_data(), schema=schema, data_storage_version="legacy"
)
@@ -610,7 +593,7 @@ async def test_create_in_v2_mode(tmp_path):
assert not await is_in_v2_mode(tbl)
# Create table in v2 mode
tbl = await db.create_table(
tbl = await mem_db_async.create_table(
"test_v2", data=make_data(), schema=schema, use_legacy_format=False
)
@@ -622,7 +605,7 @@ async def test_create_in_v2_mode(tmp_path):
assert await is_in_v2_mode(tbl)
# Create empty table in v2 mode and add data
tbl = await db.create_table(
tbl = await mem_db_async.create_table(
"test_empty_v2", data=None, schema=schema, use_legacy_format=False
)
await tbl.add(make_table())
@@ -630,7 +613,7 @@ async def test_create_in_v2_mode(tmp_path):
assert await is_in_v2_mode(tbl)
# Create empty table explicitly in legacy (v1) mode
tbl = await db.create_table(
tbl = await mem_db_async.create_table(
"test_empty_v2_default", data=None, schema=schema, data_storage_version="legacy"
)
await tbl.add(make_table())
@@ -638,18 +621,17 @@ async def test_create_in_v2_mode(tmp_path):
assert not await is_in_v2_mode(tbl)
def test_replace_index(tmp_path):
db = lancedb.connect(uri=tmp_path)
table = db.create_table(
def test_replace_index(mem_db: lancedb.DBConnection):
table = mem_db.create_table(
"test",
[
{"vector": np.random.rand(128), "item": "foo", "price": float(i)}
for i in range(1000)
{"vector": np.random.rand(32), "item": "foo", "price": float(i)}
for i in range(512)
],
)
table.create_index(
num_partitions=2,
num_sub_vectors=4,
num_sub_vectors=2,
)
with pytest.raises(Exception):
@@ -660,27 +642,26 @@ def test_replace_index(tmp_path):
)
table.create_index(
num_partitions=2,
num_sub_vectors=4,
num_partitions=1,
num_sub_vectors=2,
replace=True,
index_cache_size=10,
)
def test_prefilter_with_index(tmp_path):
db = lancedb.connect(uri=tmp_path)
def test_prefilter_with_index(mem_db: lancedb.DBConnection):
data = [
{"vector": np.random.rand(128), "item": "foo", "price": float(i)}
for i in range(1000)
{"vector": np.random.rand(32), "item": "foo", "price": float(i)}
for i in range(512)
]
sample_key = data[100]["vector"]
table = db.create_table(
table = mem_db.create_table(
"test",
data,
)
table.create_index(
num_partitions=2,
num_sub_vectors=4,
num_sub_vectors=2,
)
table = (
table.search(sample_key)
@@ -691,13 +672,12 @@ def test_prefilter_with_index(tmp_path):
assert table.num_rows == 1
def test_create_table_with_invalid_names(tmp_path):
db = lancedb.connect(uri=tmp_path)
def test_create_table_with_invalid_names(tmp_db: lancedb.DBConnection):
data = [{"vector": np.random.rand(128), "item": "foo"} for i in range(10)]
with pytest.raises(ValueError):
db.create_table("foo/bar", data)
tmp_db.create_table("foo/bar", data)
with pytest.raises(ValueError):
db.create_table("foo bar", data)
tmp_db.create_table("foo bar", data)
with pytest.raises(ValueError):
db.create_table("foo$$bar", data)
db.create_table("foo.bar", data)
tmp_db.create_table("foo$$bar", data)
tmp_db.create_table("foo.bar", data)


@@ -15,10 +15,12 @@ import random
from unittest import mock
import lancedb as ldb
from lancedb.db import DBConnection
from lancedb.index import FTS
import numpy as np
import pandas as pd
import pytest
from utils import exception_output
pytest.importorskip("lancedb.fts")
tantivy = pytest.importorskip("tantivy")
@@ -458,3 +460,44 @@ def test_syntax(table):
table.search('the cats OR dogs were not really "pets" at all').phrase_query().limit(
10
).to_list()
def test_language(mem_db: DBConnection):
sentences = [
"Il n'y a que trois routes qui traversent la ville.",
"Je veux prendre la route vers l'est.",
"Je te retrouve au café au bout de la route.",
]
data = [{"text": s} for s in sentences]
table = mem_db.create_table("test", data=data)
with pytest.raises(ValueError) as e:
table.create_fts_index("text", use_tantivy=False, language="klingon")
assert exception_output(e) == (
"ValueError: LanceDB does not support the requested language: 'klingon'\n"
"Supported languages: Arabic, Danish, Dutch, English, Finnish, French, "
"German, Greek, Hungarian, Italian, Norwegian, Portuguese, Romanian, "
"Russian, Spanish, Swedish, Tamil, Turkish"
)
table.create_fts_index(
"text",
use_tantivy=False,
language="French",
stem=True,
ascii_folding=True,
remove_stop_words=True,
)
# Can get "routes" and "route" from the same root
results = table.search("route", query_type="fts").limit(5).to_list()
assert len(results) == 3
# Can find "café", without needing to provide accent
results = table.search("cafe", query_type="fts").limit(5).to_list()
assert len(results) == 1
# Stop words -> no results
results = table.search("la", query_type="fts").limit(5).to_list()
assert len(results) == 0

File diff suppressed because it is too large.

@@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import pytest
def exception_output(e_info: pytest.ExceptionInfo):
import traceback
# skip traceback part, since it's not worth checking in tests
lines = traceback.format_exception_only(e_info.type, e_info.value)
return "".join(lines).strip()


@@ -58,6 +58,11 @@ impl Connection {
self.inner.take();
}
#[getter]
pub fn uri(&self) -> PyResult<String> {
self.get_inner().map(|inner| inner.uri().to_string())
}
#[pyo3(signature = (start_after=None, limit=None))]
pub fn table_names(
self_: PyRef<'_, Self>,


@@ -12,224 +12,153 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Mutex;
use lancedb::index::scalar::FtsIndexBuilder;
use lancedb::{
index::{
scalar::BTreeIndexBuilder,
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
Index as LanceDbIndex,
},
DistanceType,
use lancedb::index::{
scalar::{BTreeIndexBuilder, FtsIndexBuilder, TokenizerConfig},
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
Index as LanceDbIndex,
};
use pyo3::{
exceptions::{PyKeyError, PyRuntimeError, PyValueError},
pyclass, pymethods, IntoPy, PyObject, PyResult, Python,
exceptions::{PyKeyError, PyValueError},
intern, pyclass, pymethods,
types::PyAnyMethods,
Bound, FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python,
};
use crate::util::parse_distance_type;
#[pyclass]
pub struct Index {
inner: Mutex<Option<LanceDbIndex>>,
}
impl Index {
pub fn consume(&self) -> PyResult<LanceDbIndex> {
self.inner
.lock()
.unwrap()
.take()
.ok_or_else(|| PyRuntimeError::new_err("cannot use an Index more than once"))
pub fn class_name<'a>(ob: &'a Bound<'_, PyAny>) -> PyResult<&'a str> {
let full_name: &str = ob
.getattr(intern!(ob.py(), "__class__"))?
.getattr(intern!(ob.py(), "__name__"))?
.extract()?;
match full_name.rsplit_once('.') {
Some((_, name)) => Ok(name),
None => Ok(full_name),
}
}
#[pymethods]
impl Index {
#[pyo3(signature = (distance_type=None, num_partitions=None, num_sub_vectors=None,num_bits=None, max_iterations=None, sample_rate=None))]
#[staticmethod]
pub fn ivf_pq(
distance_type: Option<String>,
num_partitions: Option<u32>,
num_sub_vectors: Option<u32>,
num_bits: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
) -> PyResult<Self> {
let mut ivf_pq_builder = IvfPqIndexBuilder::default();
if let Some(distance_type) = distance_type {
let distance_type = match distance_type.as_str() {
"l2" => Ok(DistanceType::L2),
"cosine" => Ok(DistanceType::Cosine),
"dot" => Ok(DistanceType::Dot),
_ => Err(PyValueError::new_err(format!(
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
distance_type
))),
}?;
ivf_pq_builder = ivf_pq_builder.distance_type(distance_type);
pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<LanceDbIndex> {
if let Some(source) = source {
match class_name(source)? {
"BTree" => Ok(LanceDbIndex::BTree(BTreeIndexBuilder::default())),
"Bitmap" => Ok(LanceDbIndex::Bitmap(Default::default())),
"LabelList" => Ok(LanceDbIndex::LabelList(Default::default())),
"FTS" => {
let params = source.extract::<FtsParams>()?;
let inner_opts = TokenizerConfig::default()
.base_tokenizer(params.base_tokenizer)
.language(&params.language)
.map_err(|_| PyValueError::new_err(format!("LanceDB does not support the requested language: '{}'", params.language)))?
.lower_case(params.lower_case)
.max_token_length(params.max_token_length)
.remove_stop_words(params.remove_stop_words)
.stem(params.stem)
.ascii_folding(params.ascii_folding);
let mut opts = FtsIndexBuilder::default()
.with_position(params.with_position);
opts.tokenizer_configs = inner_opts;
Ok(LanceDbIndex::FTS(opts))
},
"IvfPq" => {
let params = source.extract::<IvfPqParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
let mut ivf_pq_builder = IvfPqIndexBuilder::default()
.distance_type(distance_type)
.max_iterations(params.max_iterations)
.sample_rate(params.sample_rate)
.num_bits(params.num_bits);
if let Some(num_partitions) = params.num_partitions {
ivf_pq_builder = ivf_pq_builder.num_partitions(num_partitions);
}
if let Some(num_sub_vectors) = params.num_sub_vectors {
ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
}
Ok(LanceDbIndex::IvfPq(ivf_pq_builder))
},
"HnswPq" => {
let params = source.extract::<IvfHnswPqParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
let mut hnsw_pq_builder = IvfHnswPqIndexBuilder::default()
.distance_type(distance_type)
.max_iterations(params.max_iterations)
.sample_rate(params.sample_rate)
.num_edges(params.m)
.ef_construction(params.ef_construction)
.num_bits(params.num_bits);
if let Some(num_partitions) = params.num_partitions {
hnsw_pq_builder = hnsw_pq_builder.num_partitions(num_partitions);
}
if let Some(num_sub_vectors) = params.num_sub_vectors {
hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
}
Ok(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))
},
"HnswSq" => {
let params = source.extract::<IvfHnswSqParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
let mut hnsw_sq_builder = IvfHnswSqIndexBuilder::default()
.distance_type(distance_type)
.max_iterations(params.max_iterations)
.sample_rate(params.sample_rate)
.num_edges(params.m)
.ef_construction(params.ef_construction);
if let Some(num_partitions) = params.num_partitions {
hnsw_sq_builder = hnsw_sq_builder.num_partitions(num_partitions);
}
Ok(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))
},
not_supported => Err(PyValueError::new_err(format!(
"Invalid index type '{}'. Must be one of BTree, Bitmap, LabelList, FTS, IvfPq, IvfHnswPq, or IvfHnswSq",
not_supported
))),
}
if let Some(num_partitions) = num_partitions {
ivf_pq_builder = ivf_pq_builder.num_partitions(num_partitions);
}
if let Some(num_sub_vectors) = num_sub_vectors {
ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
}
if let Some(num_bits) = num_bits {
ivf_pq_builder = ivf_pq_builder.num_bits(num_bits);
}
if let Some(max_iterations) = max_iterations {
ivf_pq_builder = ivf_pq_builder.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
ivf_pq_builder = ivf_pq_builder.sample_rate(sample_rate);
}
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::IvfPq(ivf_pq_builder))),
})
} else {
Ok(LanceDbIndex::Auto)
}
}
#[staticmethod]
pub fn btree() -> PyResult<Self> {
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::BTree(BTreeIndexBuilder::default()))),
})
}
#[derive(FromPyObject)]
struct FtsParams {
with_position: bool,
base_tokenizer: String,
language: String,
max_token_length: Option<usize>,
lower_case: bool,
stem: bool,
remove_stop_words: bool,
ascii_folding: bool,
}
#[staticmethod]
pub fn bitmap() -> PyResult<Self> {
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::Bitmap(Default::default()))),
})
}
#[derive(FromPyObject)]
struct IvfPqParams {
distance_type: String,
num_partitions: Option<u32>,
num_sub_vectors: Option<u32>,
num_bits: u32,
max_iterations: u32,
sample_rate: u32,
}
#[staticmethod]
pub fn label_list() -> PyResult<Self> {
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::LabelList(Default::default()))),
})
}
#[derive(FromPyObject)]
struct IvfHnswPqParams {
distance_type: String,
num_partitions: Option<u32>,
num_sub_vectors: Option<u32>,
num_bits: u32,
max_iterations: u32,
sample_rate: u32,
m: u32,
ef_construction: u32,
}
#[pyo3(signature = (with_position=None, base_tokenizer=None, language=None, max_token_length=None, lower_case=None, stem=None, remove_stop_words=None, ascii_folding=None))]
#[allow(clippy::too_many_arguments)]
#[staticmethod]
pub fn fts(
with_position: Option<bool>,
base_tokenizer: Option<String>,
language: Option<String>,
max_token_length: Option<usize>,
lower_case: Option<bool>,
stem: Option<bool>,
remove_stop_words: Option<bool>,
ascii_folding: Option<bool>,
) -> Self {
let mut opts = FtsIndexBuilder::default();
if let Some(with_position) = with_position {
opts = opts.with_position(with_position);
}
if let Some(base_tokenizer) = base_tokenizer {
opts.tokenizer_configs = opts.tokenizer_configs.base_tokenizer(base_tokenizer);
}
if let Some(language) = language {
opts.tokenizer_configs = opts.tokenizer_configs.language(&language).unwrap();
}
opts.tokenizer_configs = opts.tokenizer_configs.max_token_length(max_token_length);
if let Some(lower_case) = lower_case {
opts.tokenizer_configs = opts.tokenizer_configs.lower_case(lower_case);
}
if let Some(stem) = stem {
opts.tokenizer_configs = opts.tokenizer_configs.stem(stem);
}
if let Some(remove_stop_words) = remove_stop_words {
opts.tokenizer_configs = opts.tokenizer_configs.remove_stop_words(remove_stop_words);
}
if let Some(ascii_folding) = ascii_folding {
opts.tokenizer_configs = opts.tokenizer_configs.ascii_folding(ascii_folding);
}
Self {
inner: Mutex::new(Some(LanceDbIndex::FTS(opts))),
}
}
#[pyo3(signature = (distance_type=None, num_partitions=None, num_sub_vectors=None,num_bits=None, max_iterations=None, sample_rate=None, m=None, ef_construction=None))]
#[staticmethod]
#[allow(clippy::too_many_arguments)]
pub fn hnsw_pq(
distance_type: Option<String>,
num_partitions: Option<u32>,
num_sub_vectors: Option<u32>,
num_bits: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
m: Option<u32>,
ef_construction: Option<u32>,
) -> PyResult<Self> {
let mut hnsw_pq_builder = IvfHnswPqIndexBuilder::default();
if let Some(distance_type) = distance_type {
let distance_type = parse_distance_type(distance_type)?;
hnsw_pq_builder = hnsw_pq_builder.distance_type(distance_type);
}
if let Some(num_partitions) = num_partitions {
hnsw_pq_builder = hnsw_pq_builder.num_partitions(num_partitions);
}
if let Some(num_sub_vectors) = num_sub_vectors {
hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
}
if let Some(num_bits) = num_bits {
hnsw_pq_builder = hnsw_pq_builder.num_bits(num_bits);
}
if let Some(max_iterations) = max_iterations {
hnsw_pq_builder = hnsw_pq_builder.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
hnsw_pq_builder = hnsw_pq_builder.sample_rate(sample_rate);
}
if let Some(m) = m {
hnsw_pq_builder = hnsw_pq_builder.num_edges(m);
}
if let Some(ef_construction) = ef_construction {
hnsw_pq_builder = hnsw_pq_builder.ef_construction(ef_construction);
}
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))),
})
}
#[pyo3(signature = (distance_type=None, num_partitions=None, max_iterations=None, sample_rate=None, m=None, ef_construction=None))]
#[staticmethod]
pub fn hnsw_sq(
distance_type: Option<String>,
num_partitions: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
m: Option<u32>,
ef_construction: Option<u32>,
) -> PyResult<Self> {
let mut hnsw_sq_builder = IvfHnswSqIndexBuilder::default();
if let Some(distance_type) = distance_type {
let distance_type = parse_distance_type(distance_type)?;
hnsw_sq_builder = hnsw_sq_builder.distance_type(distance_type);
}
if let Some(num_partitions) = num_partitions {
hnsw_sq_builder = hnsw_sq_builder.num_partitions(num_partitions);
}
if let Some(max_iterations) = max_iterations {
hnsw_sq_builder = hnsw_sq_builder.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
hnsw_sq_builder = hnsw_sq_builder.sample_rate(sample_rate);
}
if let Some(m) = m {
hnsw_sq_builder = hnsw_sq_builder.num_edges(m);
}
if let Some(ef_construction) = ef_construction {
hnsw_sq_builder = hnsw_sq_builder.ef_construction(ef_construction);
}
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))),
})
}
#[derive(FromPyObject)]
struct IvfHnswSqParams {
distance_type: String,
num_partitions: Option<u32>,
max_iterations: u32,
sample_rate: u32,
m: u32,
ef_construction: u32,
}
#[pyclass(get_all)]
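
Because dispatch happens on the config object's class name, the Python
dataclasses cross the binding boundary without a wrapper type, and omitting
the config selects `Index::Auto`; from the Python side (sketch):

```python
from lancedb.index import IvfPq

async def create_indices(tbl):
    # No config: the Rust side receives None and falls back to LanceDbIndex::Auto
    await tbl.create_index("vector")
    # Dataclass config: matched by class name ("IvfPq"), fields read via FromPyObject
    await tbl.create_index("vector", config=IvfPq(distance_type="cosine"), replace=True)
```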


@@ -15,7 +15,7 @@
use arrow::RecordBatchStream;
use connection::{connect, Connection};
use env_logger::Env;
use index::{Index, IndexConfig};
use index::IndexConfig;
use pyo3::{
pymodule,
types::{PyModule, PyModuleMethods},
@@ -40,7 +40,6 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
env_logger::init_from_env(env);
m.add_class::<Connection>()?;
m.add_class::<Table>()?;
m.add_class::<Index>()?;
m.add_class::<IndexConfig>()?;
m.add_class::<Query>()?;
m.add_class::<VectorQuery>()?;


@@ -19,7 +19,7 @@ use pyo3_async_runtimes::tokio::future_into_py;
use crate::{
error::PythonErrorExt,
index::{Index, IndexConfig},
index::{extract_index_params, IndexConfig},
query::Query,
};
@@ -177,14 +177,10 @@ impl Table {
pub fn create_index<'a>(
self_: PyRef<'a, Self>,
column: String,
index: Option<&Index>,
index: Option<Bound<'_, PyAny>>,
replace: Option<bool>,
) -> PyResult<Bound<'a, PyAny>> {
let index = if let Some(index) = index {
index.consume()?
} else {
lancedb::index::Index::Auto
};
let index = extract_index_params(&index)?;
let mut op = self_.inner_ref()?.create_index(&[column], index);
if let Some(replace) = replace {
op = op.replace(replace);