From 980aa70e2d1c8e1df04e3ec0ff12daabf4d563ef Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 13 Dec 2024 12:56:44 -0800 Subject: [PATCH] feat(python): async-sync feature parity on Table (#1914) ### Changes to sync API * Updated `LanceTable` and `LanceDBConnection` reprs * Add `storage_options`, `data_storage_version`, and `enable_v2_manifest_paths` to sync create table API. * Add `storage_options` to `open_table` in sync API. * Add `list_indices()` and `index_stats()` to sync API * `create_table()` will now create only 1 version when data is passed. Previously it would always create two versions: 1 to create an empty table and 1 to add data to it. ### Changes to async API * Add `embedding_functions` to async `create_table()` API. * Added `head()` to async API ### Refactors * Refactor index parameters into dataclasses so they are easier to use from Python * Moved most tests to use an in-memory DB so we don't need to create so many temp directories Closes #1792 Closes #1932 --------- Co-authored-by: Weston Pace --- docs/src/basic.md | 8 - docs/src/migration.md | 79 +-- docs/src/python/python.md | 2 + python/python/lancedb/__init__.py | 10 +- python/python/lancedb/_lancedb.pyi | 17 +- python/python/lancedb/background_loop.py | 3 + python/python/lancedb/db.py | 96 ++- python/python/lancedb/index.py | 342 +++++----- python/python/lancedb/remote/db.py | 8 +- python/python/lancedb/remote/table.py | 20 +- python/python/lancedb/table.py | 787 ++++++++++++----------- python/python/lancedb/util.py | 12 + python/python/tests/conftest.py | 32 + python/python/tests/test_db.py | 214 +++--- python/python/tests/test_fts.py | 43 ++ python/python/tests/test_table.py | 566 ++++++++-------- python/python/tests/utils.py | 11 + python/src/connection.rs | 5 + python/src/index.rs | 335 ++++------ python/src/lib.rs | 3 +- python/src/table.rs | 10 +- rust/lancedb/src/connection.rs | 15 +- rust/lancedb/src/index/scalar.rs | 2 +- 23 files changed, 1296 insertions(+), 1324 deletions(-) create mode 100644 python/python/tests/conftest.py create mode 100644 python/python/tests/utils.py diff --git a/docs/src/basic.md b/docs/src/basic.md index 8e4a7402..1faf206f 100644 --- a/docs/src/basic.md +++ b/docs/src/basic.md @@ -141,14 +141,6 @@ recommend switching to stable releases. --8<-- "python/python/tests/docs/test_basic.py:connect_async" ``` - !!! note "Asynchronous Python API" - - The asynchronous Python API is new and has some slight differences compared - to the synchronous API. Feel free to start using the asynchronous version. - Once all features have migrated we will start to move the synchronous API to - use the same syntax as the asynchronous API. To help with this migration we - have created a [migration guide](migration.md) detailing the differences. - === "Typescript[^1]" === "@lancedb/lancedb" diff --git a/docs/src/migration.md b/docs/src/migration.md index b44748da..de77d705 100644 --- a/docs/src/migration.md +++ b/docs/src/migration.md @@ -1,81 +1,14 @@ # Rust-backed Client Migration Guide -In an effort to ensure all clients have the same set of capabilities we have begun migrating the -python and node clients onto a common Rust base library. In python, this new client is part of -the same lancedb package, exposed as an asynchronous client. Once the asynchronous client has -reached full functionality we will begin migrating the synchronous library to be a thin wrapper -around the asynchronous client. 
+In an effort to ensure all clients have the same set of capabilities we have +migrated the Python and Node clients onto a common Rust base library. In Python, +both the synchronous and asynchronous clients are based on this implementation. +In Node, the new client is available as `@lancedb/lancedb`, which replaces +the existing `vectordb` package. -This guide describes the differences between the two APIs and will hopefully assist users +This guide describes the differences between the two Node APIs and will hopefully assist users that would like to migrate to the new API. -## Python -### Closeable Connections - -The Connection now has a `close` method. You can call this when -you are done with the connection to eagerly free resources. Currently -this is limited to freeing/closing the HTTP connection for remote -connections. In the future we may add caching or other resources to -native connections so this is probably a good practice even if you -aren't using remote connections. - -In addition, the connection can be used as a context manager which may -be a more convenient way to ensure the connection is closed. - -```python -import lancedb - -async def my_async_fn(): - with await lancedb.connect_async("my_uri") as db: - print(await db.table_names()) -``` - -It is not mandatory to call the `close` method. If you do not call it -then the connection will be closed when the object is garbage collected. - -### Closeable Table - -The Table now also has a `close` method, similar to the connection. This -can be used to eagerly free the cache used by a Table object. Similar to -the connection, it can be used as a context manager and it is not mandatory -to call the `close` method. - -#### Changes to Table APIs - -- Previously `Table.schema` was a property. Now it is an async method. -- The method `Table.__len__` was removed and `len(table)` will no longer - work. Use `Table.count_rows` instead. - -#### Creating Indices - -The `Table.create_index` method is now used for creating both vector indices -and scalar indices. It currently requires a column name to be specified (the -column to index). Vector index defaults are now smarter and scale better with -the size of the data. - -To specify index configuration details you will need to specify which kind of -index you are using. - -#### Querying - -The `Table.search` method has been renamed to `AsyncTable.vector_search` for -clarity. - -### Features not yet supported - -The following features are not yet supported by the asynchronous API. However, -we plan to support them soon. - -- You cannot specify an embedding function when creating or opening a table. - You must calculate embeddings yourself if using the asynchronous API -- The merge insert operation is not supported in the asynchronous API -- Cleanup / compact / optimize indices are not supported in the asynchronous API -- add / alter columns is not supported in the asynchronous API -- The asynchronous API does not yet support any full text search or reranking - search -- Remote connections to LanceDb Cloud are not yet supported. -- The method Table.head is not yet supported. - ## TypeScript/JavaScript For JS/TS users, we offer a brand new SDK [@lancedb/lancedb](https://www.npmjs.com/package/@lancedb/lancedb) diff --git a/docs/src/python/python.md b/docs/src/python/python.md index eba6e1fd..c250c5f3 100644 --- a/docs/src/python/python.md +++ b/docs/src/python/python.md @@ -47,6 +47,8 @@ is also an [asynchronous API client](#connections-asynchronous). 
::: lancedb.embeddings.registry.EmbeddingFunctionRegistry +::: lancedb.embeddings.base.EmbeddingFunctionConfig + ::: lancedb.embeddings.base.EmbeddingFunction ::: lancedb.embeddings.base.TextEmbeddingFunction diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index 49149bf1..93365073 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -70,7 +70,7 @@ def connect( default configuration is used. storage_options: dict, optional Additional options for the storage backend. See available options at - https://lancedb.github.io/lancedb/guides/storage/ + Examples -------- @@ -82,11 +82,13 @@ def connect( For object storage, use a URI prefix: - >>> db = lancedb.connect("s3://my-bucket/lancedb") + >>> db = lancedb.connect("s3://my-bucket/lancedb", + ... storage_options={"aws_access_key_id": "***"}) Connect to LanceDB cloud: - >>> db = lancedb.connect("db://my_database", api_key="ldb_...") + >>> db = lancedb.connect("db://my_database", api_key="ldb_...", + ... client_config={"retry_config": {"retries": 5}}) Returns ------- @@ -164,7 +166,7 @@ async def connect_async( default configuration is used. storage_options: dict, optional Additional options for the storage backend. See available options at - https://lancedb.github.io/lancedb/guides/storage/ + Examples -------- diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index a6f29dc5..3d87d0cf 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -2,19 +2,8 @@ from typing import Dict, List, Optional, Tuple import pyarrow as pa -class Index: - @staticmethod - def ivf_pq( - distance_type: Optional[str], - num_partitions: Optional[int], - num_sub_vectors: Optional[int], - max_iterations: Optional[int], - sample_rate: Optional[int], - ) -> Index: ... - @staticmethod - def btree() -> Index: ... - class Connection(object): + uri: str async def table_names( self, start_after: Optional[str], limit: Optional[int] ) -> list[str]: ... @@ -46,9 +35,7 @@ class Table: async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ... async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ... async def count_rows(self, filter: Optional[str]) -> int: ... - async def create_index( - self, column: str, config: Optional[Index], replace: Optional[bool] - ): ... + async def create_index(self, column: str, config, replace: Optional[bool]): ... async def version(self) -> int: ... async def checkout(self, version): ... async def checkout_latest(self): ... 
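The `storage_options` parameter described in the docstrings above is accepted at both the connection and table level. A minimal sketch of the intended sync usage (the bucket name and credential values are placeholders, not part of this patch):

```python
import lancedb

# Placeholder bucket and credentials, purely illustrative.
db = lancedb.connect(
    "s3://my-bucket/lancedb",
    storage_options={"aws_access_key_id": "***", "aws_secret_access_key": "***"},
)

# Table-level options are merged with, and override, the connection-level ones.
tbl = db.open_table("my_table", storage_options={"aws_region": "us-east-1"})
```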
diff --git a/python/python/lancedb/background_loop.py b/python/python/lancedb/background_loop.py index 4b7b8632..8dd2d00f 100644 --- a/python/python/lancedb/background_loop.py +++ b/python/python/lancedb/background_loop.py @@ -23,3 +23,6 @@ class BackgroundEventLoop: def run(self, future): return asyncio.run_coroutine_threadsafe(future, self.loop).result() + + +LOOP = BackgroundEventLoop() diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index ab75ec1f..18d4076c 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -17,10 +17,11 @@ from abc import abstractmethod from pathlib import Path from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union +from lancedb.embeddings.registry import EmbeddingFunctionRegistry from overrides import EnforceOverrides, override from lancedb.common import data_to_reader, sanitize_uri, validate_schema -from lancedb.background_loop import BackgroundEventLoop +from lancedb.background_loop import LOOP from ._lancedb import connect as lancedb_connect from .table import ( @@ -43,8 +44,6 @@ if TYPE_CHECKING: from .common import DATA, URI from .embeddings import EmbeddingFunctionConfig -LOOP = BackgroundEventLoop() - class DBConnection(EnforceOverrides): """An active LanceDB connection interface.""" @@ -82,6 +81,10 @@ class DBConnection(EnforceOverrides): on_bad_vectors: str = "error", fill_value: float = 0.0, embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None, + *, + storage_options: Optional[Dict[str, str]] = None, + data_storage_version: Optional[str] = None, + enable_v2_manifest_paths: Optional[bool] = None, ) -> Table: """Create a [Table][lancedb.table.Table] in the database. @@ -119,6 +122,24 @@ class DBConnection(EnforceOverrides): One of "error", "drop", "fill". fill_value: float The value to use when filling vectors. Only used if on_bad_vectors="fill". + storage_options: dict, optional + Additional options for the storage backend. Options already set on the + connection will be inherited by the table, but can be overridden here. + See available options at + + data_storage_version: optional, str, default "stable" + The version of the data storage format to use. Newer versions are more + efficient but require newer versions of lance to read. The default is + "stable" which will use the legacy v2 version. See the user guide + for more details. + enable_v2_manifest_paths: bool, optional, default False + Use the new V2 manifest paths. These paths provide more efficient + opening of datasets with many versions on object stores. WARNING: + turning this on will make the dataset unreadable for older versions + of LanceDB (prior to 0.13.0). To migrate an existing dataset, instead + use the + [Table.migrate_manifest_paths_v2][lancedb.table.Table.migrate_v2_manifest_paths] + method. Returns ------- @@ -140,7 +161,7 @@ class DBConnection(EnforceOverrides): >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7}, ... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}] >>> db.create_table("my_table", data) - LanceTable(connection=..., name="my_table") + LanceTable(name='my_table', version=1, ...) >>> db["my_table"].head() pyarrow.Table vector: fixed_size_list[2] @@ -161,7 +182,7 @@ class DBConnection(EnforceOverrides): ... "long": [-122.7, -74.1] ... }) >>> db.create_table("table2", data) - LanceTable(connection=..., name="table2") + LanceTable(name='table2', version=1, ...) 
>>> db["table2"].head() pyarrow.Table vector: fixed_size_list[2] @@ -184,7 +205,7 @@ class DBConnection(EnforceOverrides): ... pa.field("long", pa.float32()) ... ]) >>> db.create_table("table3", data, schema = custom_schema) - LanceTable(connection=..., name="table3") + LanceTable(name='table3', version=1, ...) >>> db["table3"].head() pyarrow.Table vector: fixed_size_list[2] @@ -218,7 +239,7 @@ class DBConnection(EnforceOverrides): ... pa.field("price", pa.float32()), ... ]) >>> db.create_table("table4", make_batches(), schema=schema) - LanceTable(connection=..., name="table4") + LanceTable(name='table4', version=1, ...) """ raise NotImplementedError @@ -226,7 +247,13 @@ class DBConnection(EnforceOverrides): def __getitem__(self, name: str) -> LanceTable: return self.open_table(name) - def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table: + def open_table( + self, + name: str, + *, + storage_options: Optional[Dict[str, str]] = None, + index_cache_size: Optional[int] = None, + ) -> Table: """Open a Lance Table in the database. Parameters @@ -243,6 +270,11 @@ class DBConnection(EnforceOverrides): This cache applies to the entire opened table, across all indices. Setting this value higher will increase performance on larger datasets at the expense of more RAM + storage_options: dict, optional + Additional options for the storage backend. Options already set on the + connection will be inherited by the table, but can be overridden here. + See available options at + Returns ------- @@ -309,15 +341,15 @@ class LanceDBConnection(DBConnection): >>> db = lancedb.connect("./.lancedb") >>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2}, ... {"vector": [0.5, 1.3], "b": 4}]) - LanceTable(connection=..., name="my_table") + LanceTable(name='my_table', version=1, ...) >>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}]) - LanceTable(connection=..., name="another_table") + LanceTable(name='another_table', version=1, ...) >>> sorted(db.table_names()) ['another_table', 'my_table'] >>> len(db) 2 >>> db["my_table"] - LanceTable(connection=..., name="my_table") + LanceTable(name='my_table', version=1, ...) >>> "my_table" in db True >>> db.drop_table("my_table") @@ -363,7 +395,7 @@ class LanceDBConnection(DBConnection): self._conn = AsyncConnection(LOOP.run(do_connect())) def __repr__(self) -> str: - val = f"{self.__class__.__name__}({self._uri}" + val = f"{self.__class__.__name__}(uri={self._uri!r}" if self.read_consistency_interval is not None: val += f", read_consistency_interval={repr(self.read_consistency_interval)}" val += ")" @@ -403,6 +435,10 @@ class LanceDBConnection(DBConnection): on_bad_vectors: str = "error", fill_value: float = 0.0, embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None, + *, + storage_options: Optional[Dict[str, str]] = None, + data_storage_version: Optional[str] = None, + enable_v2_manifest_paths: Optional[bool] = None, ) -> LanceTable: """Create a table in the database. 
@@ -424,12 +460,19 @@ class LanceDBConnection(DBConnection): on_bad_vectors=on_bad_vectors, fill_value=fill_value, embedding_functions=embedding_functions, + storage_options=storage_options, + data_storage_version=data_storage_version, + enable_v2_manifest_paths=enable_v2_manifest_paths, ) return tbl @override def open_table( - self, name: str, *, index_cache_size: Optional[int] = None + self, + name: str, + *, + storage_options: Optional[Dict[str, str]] = None, + index_cache_size: Optional[int] = None, ) -> LanceTable: """Open a table in the database. @@ -442,7 +485,12 @@ class LanceDBConnection(DBConnection): ------- A LanceTable object representing the table. """ - return LanceTable.open(self, name, index_cache_size=index_cache_size) + return LanceTable.open( + self, + name, + storage_options=storage_options, + index_cache_size=index_cache_size, + ) @override def drop_table(self, name: str, ignore_missing: bool = False): @@ -524,6 +572,10 @@ class AsyncConnection(object): Any attempt to use the connection after it is closed will result in an error.""" self._inner.close() + @property + def uri(self) -> str: + return self._inner.uri + async def table_names( self, *, start_after: Optional[str] = None, limit: Optional[int] = None ) -> Iterable[str]: @@ -557,6 +609,7 @@ class AsyncConnection(object): fill_value: Optional[float] = None, storage_options: Optional[Dict[str, str]] = None, *, + embedding_functions: List[EmbeddingFunctionConfig] = None, data_storage_version: Optional[str] = None, use_legacy_format: Optional[bool] = None, enable_v2_manifest_paths: Optional[bool] = None, @@ -601,7 +654,7 @@ class AsyncConnection(object): Additional options for the storage backend. Options already set on the connection will be inherited by the table, but can be overridden here. See available options at - https://lancedb.github.io/lancedb/guides/storage/ + data_storage_version: optional, str, default "stable" The version of the data storage format to use. Newer versions are more efficient but require newer versions of lance to read. The default is @@ -730,6 +783,17 @@ class AsyncConnection(object): """ metadata = None + if embedding_functions is not None: + # If we passed in embedding functions explicitly + # then we'll override any schema metadata that + # may was implicitly specified by the LanceModel schema + registry = EmbeddingFunctionRegistry.get_instance() + metadata = registry.get_table_metadata(embedding_functions) + + data, schema = sanitize_create_table( + data, schema, metadata, on_bad_vectors, fill_value + ) + # Defining defaults here and not in function prototype. In the future # these defaults will move into rust so better to keep them as None. if on_bad_vectors is None: @@ -791,7 +855,7 @@ class AsyncConnection(object): Additional options for the storage backend. Options already set on the connection will be inherited by the table, but can be overridden here. 
See available options at - https://lancedb.github.io/lancedb/guides/storage/ + index_cache_size: int, default 256 Set the size of the index cache, specified as a number of entries diff --git a/python/python/lancedb/index.py b/python/python/lancedb/index.py index 55fa0e82..c34d6ad8 100644 --- a/python/python/lancedb/index.py +++ b/python/python/lancedb/index.py @@ -1,8 +1,6 @@ -from typing import Optional +from dataclasses import dataclass +from typing import Literal, Optional -from ._lancedb import ( - Index as LanceDbIndex, -) from ._lancedb import ( IndexConfig, ) @@ -29,6 +27,7 @@ lang_mapping = { } +@dataclass class BTree: """Describes a btree index configuration @@ -50,10 +49,10 @@ class BTree: the block size may be added in the future. """ - def __init__(self): - self._inner = LanceDbIndex.btree() + pass +@dataclass class Bitmap: """Describe a Bitmap index configuration. @@ -73,10 +72,10 @@ class Bitmap: requires 128 / 8 * 1Bi bytes on disk. """ - def __init__(self): - self._inner = LanceDbIndex.bitmap() + pass +@dataclass class LabelList: """Describe a LabelList index configuration. @@ -87,41 +86,57 @@ class LabelList: For example, it works with `tags`, `categories`, `keywords`, etc. """ - def __init__(self): - self._inner = LanceDbIndex.label_list() + pass +@dataclass class FTS: """Describe a FTS index configuration. `FTS` is a full-text search index that can be used on `String` columns For example, it works with `title`, `description`, `content`, etc. + + Attributes + ---------- + with_position : bool, default True + Whether to store the position of the token in the document. Setting this + to False can reduce the size of the index and improve indexing speed, + but it will disable support for phrase queries. + base_tokenizer : str, default "simple" + The base tokenizer to use for tokenization. Options are: + - "simple": Splits text by whitespace and punctuation. + - "whitespace": Split text by whitespace, but not punctuation. + - "raw": No tokenization. The entire text is treated as a single token. + language : str, default "English" + The language to use for tokenization. + max_token_length : int, default 40 + The maximum token length to index. Tokens longer than this length will be + ignored. + lower_case : bool, default True + Whether to convert the token to lower case. This makes queries case-insensitive. + stem : bool, default False + Whether to stem the token. Stemming reduces words to their root form. + For example, in English "running" and "runs" would both be reduced to "run". + remove_stop_words : bool, default False + Whether to remove stop words. Stop words are common words that are often + removed from text before indexing. For example, in English "the" and "and". + ascii_folding : bool, default False + Whether to fold ASCII characters. This converts accented characters to + their ASCII equivalent. For example, "café" would be converted to "cafe". 
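+
+    Examples
+    --------
+    A minimal sketch of configuring this index (``tbl`` here is a
+    hypothetical async table with a string column named ``text``):
+
+    >>> from lancedb.index import FTS  # doctest: +SKIP
+    >>> config = FTS(stem=True, remove_stop_words=True)  # doctest: +SKIP
+    >>> await tbl.create_index("text", config=config)  # doctest: +SKIP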
""" - def __init__( - self, - with_position: bool = True, - base_tokenizer: str = "simple", - language: str = "English", - max_token_length: Optional[int] = 40, - lower_case: bool = True, - stem: bool = False, - remove_stop_words: bool = False, - ascii_folding: bool = False, - ): - self._inner = LanceDbIndex.fts( - with_position=with_position, - base_tokenizer=base_tokenizer, - language=language, - max_token_length=max_token_length, - lower_case=lower_case, - stem=stem, - remove_stop_words=remove_stop_words, - ascii_folding=ascii_folding, - ) + with_position: bool = True + base_tokenizer: Literal["simple", "raw", "whitespace"] = "simple" + language: str = "English" + max_token_length: Optional[int] = 40 + lower_case: bool = True + stem: bool = False + remove_stop_words: bool = False + ascii_folding: bool = False +@dataclass class HnswPq: """Describe a HNSW-PQ index configuration. @@ -232,30 +247,17 @@ class HnswPq: search phase. """ - def __init__( - self, - *, - distance_type: Optional[str] = None, - num_partitions: Optional[int] = None, - num_sub_vectors: Optional[int] = None, - num_bits: Optional[int] = None, - max_iterations: Optional[int] = None, - sample_rate: Optional[int] = None, - m: Optional[int] = None, - ef_construction: Optional[int] = None, - ): - self._inner = LanceDbIndex.hnsw_pq( - distance_type=distance_type, - num_partitions=num_partitions, - num_sub_vectors=num_sub_vectors, - num_bits=num_bits, - max_iterations=max_iterations, - sample_rate=sample_rate, - m=m, - ef_construction=ef_construction, - ) + distance_type: Literal["l2", "cosine", "dot"] = "l2" + num_partitions: Optional[int] = None + num_sub_vectors: Optional[int] = None + num_bits: int = 8 + max_iterations: int = 50 + sample_rate: int = 256 + m: int = 20 + ef_construction: int = 300 +@dataclass class HnswSq: """Describe a HNSW-SQ index configuration. @@ -345,26 +347,15 @@ class HnswSq: """ - def __init__( - self, - *, - distance_type: Optional[str] = None, - num_partitions: Optional[int] = None, - max_iterations: Optional[int] = None, - sample_rate: Optional[int] = None, - m: Optional[int] = None, - ef_construction: Optional[int] = None, - ): - self._inner = LanceDbIndex.hnsw_sq( - distance_type=distance_type, - num_partitions=num_partitions, - max_iterations=max_iterations, - sample_rate=sample_rate, - m=m, - ef_construction=ef_construction, - ) + distance_type: Literal["l2", "cosine", "dot"] = "l2" + num_partitions: Optional[int] = None + max_iterations: int = 50 + sample_rate: int = 256 + m: int = 20 + ef_construction: int = 300 +@dataclass class IvfPq: """Describes an IVF PQ Index @@ -387,120 +378,103 @@ class IvfPq: Note that training an IVF PQ index on a large dataset is a slow operation and currently is also a memory intensive operation. + + Attributes + ---------- + distance_type: str, default "L2" + The distance metric used to train the index + + This is used when training the index to calculate the IVF partitions + (vectors are grouped in partitions with similar vectors according to this + distance type) and to calculate a subvector's code during quantization. + + The distance type used to train an index MUST match the distance type used + to search the index. Failure to do so will yield inaccurate results. + + The following distance types are available: + + "l2" - Euclidean distance. This is a very common distance metric that + accounts for both magnitude and direction when determining the distance + between vectors. L2 distance has a range of [0, ∞). + + "cosine" - Cosine distance. 
Cosine distance is a distance metric
+        calculated from the cosine similarity between two vectors. Cosine
+        similarity is a measure of similarity between two non-zero vectors of an
+        inner product space. It is defined to equal the cosine of the angle
+        between them. Unlike L2, the cosine distance is not affected by the
+        magnitude of the vectors. Cosine distance has a range of [0, 2].
+
+        Note: the cosine distance is undefined when one (or both) of the vectors
+        are all zeros (there is no direction). These vectors are invalid and may
+        never be returned from a vector search.
+
+        "dot" - Dot product. Dot distance is the dot product of two vectors. Dot
+        distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
+        L2 norm is 1), then dot distance is equivalent to the cosine distance.
+    num_partitions: int, default sqrt(num_rows)
+        The number of IVF partitions to create.
+
+        This value should generally scale with the number of rows in the dataset.
+        By default the number of partitions is the square root of the number of
+        rows.
+
+        If this value is too large then the first part of the search (picking the
+        right partition) will be slow. If this value is too small then the second
+        part of the search (searching within a partition) will be slow.
+    num_sub_vectors: int, default is vector dimension / 16
+        Number of sub-vectors of PQ.
+
+        This value controls how much the vector is compressed during the
+        quantization step. The more sub-vectors there are the less the vector is
+        compressed. The default is the dimension of the vector divided by 16. If
+        the dimension is not evenly divisible by 16 we use the dimension divided by
+        8.
+
+        The above two cases are highly preferred. Having 8 or 16 values per
+        subvector allows us to use efficient SIMD instructions.
+
+        If the dimension is not divisible by 8 then we use 1 subvector. This is not
+        ideal and will likely result in poor performance.
+    num_bits: int, default 8
+        Number of bits to encode each sub-vector.
+
+        This value controls how much the sub-vectors are compressed. The more bits
+        the more accurate the index but the slower search. The default is 8
+        bits. Only 4 and 8 are supported.
+    max_iterations: int, default 50
+        Max iterations to train kmeans.
+
+        When training an IVF PQ index we use kmeans to calculate the partitions.
+        This parameter controls how many iterations of kmeans to run.
+
+        Increasing this might improve the quality of the index but in most cases
+        these extra iterations have diminishing returns.
+
+        The default value is 50.
+    sample_rate: int, default 256
+        The rate used to calculate the number of training vectors for kmeans.
+
+        When an IVF PQ index is trained, we need to calculate partitions. These
+        are groups of vectors that are similar to each other. To do this we use an
+        algorithm called kmeans.
+
+        Running kmeans on a large dataset can be slow. To speed this up we run
+        kmeans on a random sample of the data. This parameter controls the size of
+        the sample. The total number of vectors used to train the index is
+        `sample_rate * num_partitions`.
+
+        Increasing this value might improve the quality of the index but in most
+        cases the default should be sufficient.
+
+        The default value is 256.
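+
+    Examples
+    --------
+    A minimal sketch (``tbl`` is a hypothetical async table with a
+    ``vector`` column; the parameter values are illustrative, not tuned):
+
+    >>> from lancedb.index import IvfPq  # doctest: +SKIP
+    >>> config = IvfPq(distance_type="cosine", num_partitions=128)  # doctest: +SKIP
+    >>> await tbl.create_index("vector", config=config)  # doctest: +SKIP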
""" - def __init__( - self, - *, - distance_type: Optional[str] = None, - num_partitions: Optional[int] = None, - num_sub_vectors: Optional[int] = None, - num_bits: Optional[int] = None, - max_iterations: Optional[int] = None, - sample_rate: Optional[int] = None, - ): - """ - Create an IVF PQ index config - - Parameters - ---------- - distance_type: str, default "L2" - The distance metric used to train the index - - This is used when training the index to calculate the IVF partitions - (vectors are grouped in partitions with similar vectors according to this - distance type) and to calculate a subvector's code during quantization. - - The distance type used to train an index MUST match the distance type used - to search the index. Failure to do so will yield inaccurate results. - - The following distance types are available: - - "l2" - Euclidean distance. This is a very common distance metric that - accounts for both magnitude and direction when determining the distance - between vectors. L2 distance has a range of [0, ∞). - - "cosine" - Cosine distance. Cosine distance is a distance metric - calculated from the cosine similarity between two vectors. Cosine - similarity is a measure of similarity between two non-zero vectors of an - inner product space. It is defined to equal the cosine of the angle - between them. Unlike L2, the cosine distance is not affected by the - magnitude of the vectors. Cosine distance has a range of [0, 2]. - - Note: the cosine distance is undefined when one (or both) of the vectors - are all zeros (there is no direction). These vectors are invalid and may - never be returned from a vector search. - - "dot" - Dot product. Dot distance is the dot product of two vectors. Dot - distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their - L2 norm is 1), then dot distance is equivalent to the cosine distance. - num_partitions: int, default sqrt(num_rows) - The number of IVF partitions to create. - - This value should generally scale with the number of rows in the dataset. - By default the number of partitions is the square root of the number of - rows. - - If this value is too large then the first part of the search (picking the - right partition) will be slow. If this value is too small then the second - part of the search (searching within a partition) will be slow. - num_sub_vectors: int, default is vector dimension / 16 - Number of sub-vectors of PQ. - - This value controls how much the vector is compressed during the - quantization step. The more sub vectors there are the less the vector is - compressed. The default is the dimension of the vector divided by 16. If - the dimension is not evenly divisible by 16 we use the dimension divded by - 8. - - The above two cases are highly preferred. Having 8 or 16 values per - subvector allows us to use efficient SIMD instructions. - - If the dimension is not visible by 8 then we use 1 subvector. This is not - ideal and will likely result in poor performance. - num_bits: int, default 8 - Number of bits to encode each sub-vector. - - This value controls how much the sub-vectors are compressed. The more bits - the more accurate the index but the slower search. The default is 8 - bits. Only 4 and 8 are supported. - max_iterations: int, default 50 - Max iteration to train kmeans. - - When training an IVF PQ index we use kmeans to calculate the partitions. - This parameter controls how many iterations of kmeans to run. 
- - Increasing this might improve the quality of the index but in most cases - these extra iterations have diminishing returns. - - The default value is 50. - sample_rate: int, default 256 - The rate used to calculate the number of training vectors for kmeans. - - When an IVF PQ index is trained, we need to calculate partitions. These - are groups of vectors that are similar to each other. To do this we use an - algorithm called kmeans. - - Running kmeans on a large dataset can be slow. To speed this up we run - kmeans on a random sample of the data. This parameter controls the size of - the sample. The total number of vectors used to train the index is - `sample_rate * num_partitions`. - - Increasing this value might improve the quality of the index but in most - cases the default should be sufficient. - - The default value is 256. - """ - if distance_type is not None: - distance_type = distance_type.lower() - self._inner = LanceDbIndex.ivf_pq( - distance_type=distance_type, - num_partitions=num_partitions, - num_sub_vectors=num_sub_vectors, - num_bits=num_bits, - max_iterations=max_iterations, - sample_rate=sample_rate, - ) + distance_type: Literal["l2", "cosine", "dot"] = "l2" + num_partitions: Optional[int] = None + num_sub_vectors: Optional[int] = None + num_bits: int = 8 + max_iterations: int = 50 + sample_rate: int = 256 -__all__ = ["BTree", "IvfPq", "IndexConfig"] +__all__ = ["BTree", "IvfPq", "HnswPq", "HnswSq", "IndexConfig"] diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index 1dfa1c9a..f53ab10a 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -121,7 +121,13 @@ class RemoteDBConnection(DBConnection): return LOOP.run(self._conn.table_names(start_after=page_token, limit=limit)) @override - def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table: + def open_table( + self, + name: str, + *, + storage_options: Optional[Dict[str, str]] = None, + index_cache_size: Optional[int] = None, + ) -> Table: """Open a Lance Table in the database. 
Parameters diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index 6f5399f6..5801ba52 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -16,6 +16,8 @@ import logging from functools import cached_property from typing import Dict, Iterable, List, Optional, Union, Literal +from lancedb._lancedb import IndexConfig +from lancedb.embeddings.base import EmbeddingFunctionConfig from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList from lancedb.remote.db import LOOP import pyarrow as pa @@ -25,7 +27,7 @@ from lancedb.merge import LanceMergeInsertBuilder from lancedb.embeddings import EmbeddingFunctionRegistry from ..query import LanceVectorQueryBuilder, LanceQueryBuilder -from ..table import AsyncTable, Query, Table +from ..table import AsyncTable, IndexStatistics, Query, Table class RemoteTable(Table): @@ -62,7 +64,7 @@ class RemoteTable(Table): return LOOP.run(self._table.version()) @cached_property - def embedding_functions(self) -> dict: + def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]: """ Get the embedding functions for the table @@ -94,11 +96,11 @@ class RemoteTable(Table): def checkout_latest(self): return LOOP.run(self._table.checkout_latest()) - def list_indices(self): + def list_indices(self) -> Iterable[IndexConfig]: """List all the indices on the table""" return LOOP.run(self._table.list_indices()) - def index_stats(self, index_uuid: str): + def index_stats(self, index_uuid: str) -> Optional[IndexStatistics]: """List all the stats of a specified index""" return LOOP.run(self._table.index_stats(index_uuid)) @@ -515,6 +517,16 @@ class RemoteTable(Table): def drop_columns(self, columns: Iterable[str]): return LOOP.run(self._table.drop_columns(columns)) + def uses_v2_manifest_paths(self) -> bool: + raise NotImplementedError( + "uses_v2_manifest_paths() is not supported on the LanceDB Cloud" + ) + + def migrate_v2_manifest_paths(self): + raise NotImplementedError( + "migrate_v2_manifest_paths() is not supported on the LanceDB Cloud" + ) + def add_index(tbl: pa.Table, i: int) -> pa.Table: return tbl.add_column( diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 3a82f9fa..4ba73405 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -3,9 +3,7 @@ from __future__ import annotations -import asyncio import inspect -import time from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime, timedelta @@ -25,6 +23,7 @@ from typing import ( from urllib.parse import urlparse import lance +from lancedb.background_loop import LOOP from .dependencies import _check_for_pandas import numpy as np import pyarrow as pa @@ -33,8 +32,9 @@ import pyarrow.fs as pa_fs from lance import LanceDataset from lance.dependencies import _check_for_hugging_face -from .common import DATA, VEC, VECTOR_COLUMN_NAME, sanitize_uri +from .common import DATA, VEC, VECTOR_COLUMN_NAME from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry +from .index import BTree, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS from .merge import LanceMergeInsertBuilder from .pydantic import LanceModel, model_to_dict from .query import ( @@ -48,6 +48,7 @@ from .query import ( Query, ) from .util import ( + add_note, fs_from_uri, get_uri_scheme, infer_vector_column_name, @@ -58,14 +59,13 @@ from .util import ( ) from .index import lang_mapping -from ._lancedb import connect as lancedb_connect if 
TYPE_CHECKING: import PIL from lance.dataset import CleanupStats, ReaderLike from ._lancedb import Table as LanceDBTable, OptimizeStats from .db import LanceDBConnection - from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS, HnswPq, HnswSq + from .index import IndexConfig pd = safe_import_pandas() pl = safe_import_polars() @@ -364,6 +364,18 @@ class Table(ABC): [Table.create_index][lancedb.table.Table.create_index]. """ + @property + @abstractmethod + def name(self) -> str: + """The name of this Table""" + raise NotImplementedError + + @property + @abstractmethod + def version(self) -> int: + """The version of this Table""" + raise NotImplementedError + @property @abstractmethod def schema(self) -> pa.Schema: @@ -373,6 +385,13 @@ class Table(ABC): """ raise NotImplementedError + @property + @abstractmethod + def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]: + """ + Get a mapping from vector column name to it's configured embedding function. + """ + @abstractmethod def count_rows(self, filter: Optional[str] = None) -> int: """ @@ -414,7 +433,12 @@ class Table(ABC): accelerator: Optional[str] = None, index_cache_size: Optional[int] = None, *, + index_type: Literal["IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"] = "IVF_PQ", num_bits: int = 8, + max_iterations: int = 50, + sample_rate: int = 256, + m: int = 20, + ef_construction: int = 300, ): """Create an index on the table. @@ -457,28 +481,39 @@ class Table(ABC): ): """Create a scalar index on a column. + Parameters + ---------- + column : str + The column to be indexed. Must be a boolean, integer, float, + or string column. + replace : bool, default True + Replace the existing index if it exists. + index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"], default "BTREE" + The type of index to create. + + Examples + -------- + Scalar indices, like vector indices, can be used to speed up scans. A scalar index can speed up scans that contain filter expressions on the indexed column. For example, the following scan will be faster if the column ``my_col`` has a scalar index: - - import lancedb - - db = lancedb.connect("/data/lance") - img_table = db.open_table("images") - my_df = img_table.search().where("my_col = 7", prefilter=True).to_pandas() + >>> import lancedb # doctest: +SKIP + >>> db = lancedb.connect("/data/lance") # doctest: +SKIP + >>> img_table = db.open_table("images") # doctest: +SKIP + >>> my_df = img_table.search().where("my_col = 7", # doctest: +SKIP + ... prefilter=True).to_pandas() Scalar indices can also speed up scans containing a vector search and a prefilter: - import lancedb - - db = lancedb.connect("/data/lance") - img_table = db.open_table("images") - img_table.search([1, 2, 3, 4], vector_column_name="vector") - .where("my_col != 7", prefilter=True) - .to_pandas() + >>> import lancedb # doctest: +SKIP + >>> db = lancedb.connect("/data/lance") # doctest: +SKIP + >>> img_table = db.open_table("images") # doctest: +SKIP + >>> img_table.search([1, 2, 3, 4], vector_column_name="vector") # doctest: +SKIP + ... .where("my_col != 7", prefilter=True) + ... .to_pandas() Scalar indices can only speed up scans for basic filters using equality, comparison, range (e.g. ``my_col BETWEEN 0 AND 100``), and set @@ -493,27 +528,6 @@ class Table(ABC): if the column ``not_indexed`` does not have a scalar index then the filter ``my_col = 0 OR not_indexed = 1`` will not be able to use any scalar index on ``my_col``. - - **Experimental API** - - Parameters - ---------- - column : str - The column to be indexed. 
Must be a boolean, integer, float, - or string column. - replace : bool, default True - Replace the existing index if it exists. - index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"], default "BTREE" - The type of index to create. - - Examples - -------- - - - import lance - - dataset = lance.dataset("./images.lance") - dataset.create_scalar_index("category") """ raise NotImplementedError @@ -528,7 +542,7 @@ class Table(ABC): tokenizer_name: Optional[str] = None, with_position: bool = True, # tokenizer configs: - base_tokenizer: str = "simple", + base_tokenizer: Literal["simple", "raw", "whitespace"] = "simple", language: str = "English", max_token_length: Optional[int] = 40, lower_case: bool = True, @@ -568,7 +582,28 @@ class Table(ABC): If False, do not store the positions of the terms in the text. This can reduce the size of the index and improve indexing speed. But it will raise an exception for phrase queries. - + base_tokenizer : str, default "simple" + The base tokenizer to use for tokenization. Options are: + - "simple": Splits text by whitespace and punctuation. + - "whitespace": Split text by whitespace, but not punctuation. + - "raw": No tokenization. The entire text is treated as a single token. + language : str, default "English" + The language to use for tokenization. + max_token_length : int, default 40 + The maximum token length to index. Tokens longer than this length will be + ignored. + lower_case : bool, default True + Whether to convert the token to lower case. This makes queries + case-insensitive. + stem : bool, default False + Whether to stem the token. Stemming reduces words to their root form. + For example, in English "running" and "runs" would both be reduced to "run". + remove_stop_words : bool, default False + Whether to remove stop words. Stop words are common words that are often + removed from text before indexing. For example, in English "the" and "and". + ascii_folding : bool, default False + Whether to fold ASCII characters. This converts accented characters to + their ASCII equivalent. For example, "café" would be converted to "cafe". """ raise NotImplementedError @@ -967,6 +1002,29 @@ class Table(ABC): modification operations. """ + @abstractmethod + def list_indices(self) -> Iterable[IndexConfig]: + """ + List all indices that have been created with + [Table.create_index][lancedb.table.Table.create_index] + """ + + @abstractmethod + def index_stats(self, index_name: str) -> Optional[IndexStatistics]: + """ + Retrieve statistics about an index + + Parameters + ---------- + index_name: str + The name of the index to retrieve statistics for + + Returns + ------- + IndexStatistics or None + The statistics about the index. Returns None if the index does not exist. + """ + @abstractmethod def add_columns(self, transforms: Dict[str, str]): """ @@ -985,6 +1043,8 @@ class Table(ABC): """ Alter column names and nullability. + Parameters + ---------- alterations : Iterable[Dict[str, Any]] A sequence of dictionaries, each with the following keys: - "path": str @@ -1061,93 +1121,36 @@ class Table(ABC): index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound return (path, fs, index_exists) - -class _LanceDatasetRef(ABC): - @property @abstractmethod - def dataset(self) -> LanceDataset: - pass + def uses_v2_manifest_paths(self) -> bool: + """ + Check if the table is using the new v2 manifest paths. + + Returns + ------- + bool + True if the table is using the new v2 manifest paths, False otherwise. 
+ """ - @property @abstractmethod - def dataset_mut(self) -> LanceDataset: - pass + def migrate_v2_manifest_paths(self): + """ + Migrate the manifest paths to the new format. + This will update the manifest to use the new v2 format for paths. -@dataclass -class _LanceLatestDatasetRef(_LanceDatasetRef): - """Reference to the latest version of a LanceDataset.""" + This function is idempotent, and can be run multiple times without + changing the state of the object store. - uri: str - index_cache_size: Optional[int] = None - read_consistency_interval: Optional[timedelta] = None - last_consistency_check: Optional[float] = None - storage_options: Optional[Dict[str, str]] = None - _dataset: Optional[LanceDataset] = None + !!! danger - @property - def dataset(self) -> LanceDataset: - if not self._dataset: - self._dataset = lance.dataset( - self.uri, - index_cache_size=self.index_cache_size, - storage_options=self.storage_options, - ) - self.last_consistency_check = time.monotonic() - elif self.read_consistency_interval is not None: - now = time.monotonic() - diff = timedelta(seconds=now - self.last_consistency_check) - if ( - self.last_consistency_check is None - or diff > self.read_consistency_interval - ): - self._dataset = self._dataset.checkout_version( - self._dataset.latest_version - ) - self.last_consistency_check = time.monotonic() - return self._dataset + This should not be run while other concurrent operations are happening. + And it should also run until completion before resuming other operations. - @dataset.setter - def dataset(self, value: LanceDataset): - self._dataset = value - self.last_consistency_check = time.monotonic() - - @property - def dataset_mut(self) -> LanceDataset: - return self.dataset - - -@dataclass -class _LanceTimeTravelRef(_LanceDatasetRef): - uri: str - version: int - index_cache_size: Optional[int] = None - storage_options: Optional[Dict[str, str]] = None - _dataset: Optional[LanceDataset] = None - - @property - def dataset(self) -> LanceDataset: - if not self._dataset: - self._dataset = lance.dataset( - self.uri, - version=self.version, - index_cache_size=self.index_cache_size, - storage_options=self.storage_options, - ) - return self._dataset - - @dataset.setter - def dataset(self, value: LanceDataset): - self._dataset = value - self.version = value.version - - @property - def dataset_mut(self) -> LanceDataset: - raise ValueError( - "Cannot mutate table reference fixed at version " - f"{self.version}. Call checkout_latest() to get a mutable " - "table reference." - ) + You can use + [Table.uses_v2_manifest_paths][lancedb.table.Table.uses_v2_manifest_paths] + to check if the table is already using the new path style. 
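+
+        Examples
+        --------
+        A minimal sketch (``table`` is a hypothetical open table):
+
+        >>> if not table.uses_v2_manifest_paths():  # doctest: +SKIP
+        ...     table.migrate_v2_manifest_paths()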
+ """ class LanceTable(Table): @@ -1169,27 +1172,22 @@ class LanceTable(Table): self, connection: "LanceDBConnection", name: str, - version: Optional[int] = None, *, + storage_options: Optional[Dict[str, str]] = None, index_cache_size: Optional[int] = None, ): self._conn = connection - self.name = name + self._table = LOOP.run( + connection._conn.open_table( + name, + storage_options=storage_options, + index_cache_size=index_cache_size, + ) + ) - if version is not None: - self._ref = _LanceTimeTravelRef( - uri=self._dataset_uri, - version=version, - index_cache_size=index_cache_size, - storage_options=connection.storage_options, - ) - else: - self._ref = _LanceLatestDatasetRef( - uri=self._dataset_uri, - read_consistency_interval=connection.read_consistency_interval, - index_cache_size=index_cache_size, - storage_options=connection.storage_options, - ) + @property + def name(self) -> str: + return self._table.name @classmethod def open(cls, db, name, **kwargs): @@ -1210,17 +1208,14 @@ class LanceTable(Table): # Cacheable since it's deterministic return _table_path(self._conn.uri, self.name) - @property - def _dataset(self) -> LanceDataset: - return self._ref.dataset - - @property - def _dataset_mut(self) -> LanceDataset: - return self._ref.dataset_mut - - def to_lance(self) -> LanceDataset: + def to_lance(self, **kwargs) -> LanceDataset: """Return the LanceDataset backing this table.""" - return self._dataset + return lance.dataset( + self._dataset_path, + version=self.version, + storage_options=self._conn.storage_options, + **kwargs, + ) @property def schema(self) -> pa.Schema: @@ -1230,16 +1225,16 @@ class LanceTable(Table): ------- pa.Schema A PyArrow schema object.""" - return self._dataset.schema + return LOOP.run(self._table.schema()) def list_versions(self): """List all versions of the table""" - return self._dataset.versions() + return LOOP.run(self._table.list_versions()) @property def version(self) -> int: """Get the current version of the table""" - return self._dataset.version + return LOOP.run(self._table.version()) def checkout(self, version: int): """Checkout a version of the table. This is an in-place operation. @@ -1263,38 +1258,19 @@ class LanceTable(Table): >>> table = db.create_table("my_table", ... [{"vector": [1.1, 0.9], "type": "vector"}]) >>> table.version - 2 + 1 >>> table.to_pandas() vector type 0 [1.1, 0.9] vector >>> table.add([{"vector": [0.5, 0.2], "type": "vector"}]) >>> table.version - 3 - >>> table.checkout(2) + 2 + >>> table.checkout(1) >>> table.to_pandas() vector type 0 [1.1, 0.9] vector """ - max_ver = self._dataset.latest_version - if version < 1 or version > max_ver: - raise ValueError(f"Invalid version {version}") - - try: - ds = self._dataset.checkout_version(version) - except IOError as e: - if "not found" in str(e): - raise ValueError( - f"Version {version} no longer exists. Was it cleaned up?" - ) - else: - raise e - - self._ref = _LanceTimeTravelRef( - uri=self._dataset_uri, - version=version, - ) - # We've already loaded the version so we can populate it directly. - self._ref.dataset = ds + LOOP.run(self._table.checkout(version)) def checkout_latest(self): """Checkout the latest version of the table. This is an in-place operation. @@ -1302,13 +1278,7 @@ class LanceTable(Table): The table will be set back into standard mode, and will track the latest version of the table. 
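+
+        Examples
+        --------
+        A minimal sketch: pin to an older version, then resume following
+        the latest (``table`` is a hypothetical open table):
+
+        >>> table.checkout(1)  # doctest: +SKIP
+        >>> table.checkout_latest()  # doctest: +SKIP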
""" - self.checkout(self._dataset.latest_version) - ds = self._ref.dataset - self._ref = _LanceLatestDatasetRef( - uri=self._dataset_uri, - read_consistency_interval=self._conn.read_consistency_interval, - ) - self._ref.dataset = ds + LOOP.run(self._table.checkout_latest()) def restore(self, version: int = None): """Restore a version of the table. This is an in-place operation. @@ -1330,51 +1300,37 @@ class LanceTable(Table): >>> table = db.create_table("my_table", [ ... {"vector": [1.1, 0.9], "type": "vector"}]) >>> table.version - 2 + 1 >>> table.to_pandas() vector type 0 [1.1, 0.9] vector >>> table.add([{"vector": [0.5, 0.2], "type": "vector"}]) >>> table.version - 3 - >>> table.restore(2) + 2 + >>> table.restore(1) >>> table.to_pandas() vector type 0 [1.1, 0.9] vector >>> len(table.list_versions()) - 4 + 3 """ - max_ver = self._dataset.latest_version - if version is None: - version = self.version - elif version < 1 or version > max_ver: - raise ValueError(f"Invalid version {version}") - else: - self.checkout(version) - - ds = self._dataset - - # no-op if restoring the latest version - if version != max_ver: - ds.restore() - - self._ref = _LanceLatestDatasetRef( - uri=self._dataset_uri, - read_consistency_interval=self._conn.read_consistency_interval, - ) - self._ref.dataset = ds + if version is not None: + LOOP.run(self._table.checkout(version)) + LOOP.run(self._table.restore()) def count_rows(self, filter: Optional[str] = None) -> int: - return self._dataset.count_rows(filter) + return LOOP.run(self._table.count_rows(filter)) def __len__(self): return self.count_rows() def __repr__(self) -> str: - val = f'{self.__class__.__name__}(connection={self._conn!r}, name="{self.name}"' - if isinstance(self._ref, _LanceTimeTravelRef): - val += f", version={self._ref.version}" - val += ")" + val = f"{self.__class__.__name__}(name={self.name!r}, version={self.version}" + if self._conn.read_consistency_interval is not None: + val += ", read_consistency_interval={!r}".format( + self._conn.read_consistency_interval + ) + val += f", _conn={self._conn!r})" return val def __str__(self) -> str: @@ -1382,7 +1338,7 @@ class LanceTable(Table): def head(self, n=5) -> pa.Table: """Return the first n rows of the table.""" - return self._dataset.head(n) + return LOOP.run(self._table.head(n)) def to_pandas(self) -> "pd.DataFrame": """Return the table as a pandas DataFrame. @@ -1399,7 +1355,7 @@ class LanceTable(Table): Returns ------- pa.Table""" - return self._dataset.to_table() + return LOOP.run(self._table.to_arrow()) def to_polars(self, batch_size=None) -> "pl.LazyFrame": """Return the table as a polars LazyFrame. 
@@ -1421,34 +1377,85 @@ class LanceTable(Table): ------- pl.LazyFrame """ + from lancedb.integrations.pyarrow import PyarrowDatasetAdapter + + dataset = PyarrowDatasetAdapter(self) return pl.scan_pyarrow_dataset( - self.to_lance(), allow_pyarrow_filter=False, batch_size=batch_size + dataset, allow_pyarrow_filter=False, batch_size=batch_size ) def create_index( self, metric="L2", - num_partitions=256, - num_sub_vectors=96, + num_partitions=None, + num_sub_vectors=None, vector_column_name=VECTOR_COLUMN_NAME, replace: bool = True, accelerator: Optional[str] = None, index_cache_size: Optional[int] = None, - index_type="IVF_PQ", - *, num_bits: int = 8, + index_type: Literal["IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"] = "IVF_PQ", + max_iterations: int = 50, + sample_rate: int = 256, + m: int = 20, + ef_construction: int = 300, ): """Create an index on the table.""" - self._dataset_mut.create_index( - column=vector_column_name, - index_type=index_type, - metric=metric, - num_partitions=num_partitions, - num_sub_vectors=num_sub_vectors, - replace=replace, - accelerator=accelerator, - index_cache_size=index_cache_size, - num_bits=num_bits, + if accelerator is not None: + # accelerator is only supported through pylance. + self.to_lance().create_index( + column=vector_column_name, + index_type=index_type, + metric=metric, + num_partitions=num_partitions, + num_sub_vectors=num_sub_vectors, + replace=replace, + accelerator=accelerator, + index_cache_size=index_cache_size, + num_bits=num_bits, + m=m, + ef_construction=ef_construction, + ) + self.checkout_latest() + return + elif index_type == "IVF_PQ": + config = IvfPq( + distance_type=metric, + num_partitions=num_partitions, + num_sub_vectors=num_sub_vectors, + num_bits=num_bits, + max_iterations=max_iterations, + sample_rate=sample_rate, + ) + elif index_type == "IVF_HNSW_PQ": + config = HnswPq( + distance_type=metric, + num_partitions=num_partitions, + num_sub_vectors=num_sub_vectors, + num_bits=num_bits, + max_iterations=max_iterations, + sample_rate=sample_rate, + m=m, + ef_construction=ef_construction, + ) + elif index_type == "IVF_HNSW_SQ": + config = HnswSq( + distance_type=metric, + num_partitions=num_partitions, + max_iterations=max_iterations, + sample_rate=sample_rate, + m=m, + ef_construction=ef_construction, + ) + else: + raise ValueError(f"Unknown index type {index_type}") + + return LOOP.run( + self._table.create_index( + vector_column_name, + replace=replace, + config=config, + ) ) def create_scalar_index( @@ -1458,8 +1465,16 @@ class LanceTable(Table): replace: bool = True, index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE", ): - self._dataset_mut.create_scalar_index( - column, index_type=index_type, replace=replace + if index_type == "BTREE": + config = BTree() + elif index_type == "BITMAP": + config = Bitmap() + elif index_type == "LABEL_LIST": + config = LabelList() + else: + raise ValueError(f"Unknown index type {index_type}") + return LOOP.run( + self._table.create_index(column, replace=replace, config=config) ) def create_fts_index( @@ -1484,28 +1499,37 @@ class LanceTable(Table): if not use_tantivy: if not isinstance(field_names, str): raise ValueError("field_names must be a string when use_tantivy=False") - tokenizer_configs = { - "base_tokenizer": base_tokenizer, - "language": language, - "max_token_length": max_token_length, - "lower_case": lower_case, - "stem": stem, - "remove_stop_words": remove_stop_words, - "ascii_folding": ascii_folding, - } - if tokenizer_name is not None: + + if tokenizer_name is None: + 
tokenizer_configs = { + "base_tokenizer": base_tokenizer, + "language": language, + "max_token_length": max_token_length, + "lower_case": lower_case, + "stem": stem, + "remove_stop_words": remove_stop_words, + "ascii_folding": ascii_folding, + } + else: tokenizer_configs = self.infer_tokenizer_configs(tokenizer_name) + + config = FTS( + with_position=with_position, + **tokenizer_configs, + ) + # delete the existing legacy index if it exists if replace: path, fs, exist = self._get_fts_index_path() if exist: fs.delete_dir(path) - self._dataset_mut.create_scalar_index( - field_names, - index_type="INVERTED", - replace=replace, - with_position=with_position, - **tokenizer_configs, + + LOOP.run( + self._table.create_index( + field_names, + replace=replace, + config=config, + ) ) return @@ -1624,15 +1648,11 @@ class LanceTable(Table): int The number of vectors in the table. """ - # TODO: manage table listing and metadata separately - data, _ = _sanitize_data( - data, - self.schema, - metadata=self.schema.metadata, - on_bad_vectors=on_bad_vectors, - fill_value=fill_value, + LOOP.run( + self._table.add( + data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value + ) ) - self._ref.dataset_mut.insert(data, mode=mode, schema=self.schema) def merge( self, @@ -1692,18 +1712,19 @@ class LanceTable(Table): other_table = other_table.to_lance() if isinstance(other_table, LanceDataset): other_table = other_table.to_table() - self._ref.dataset = self._dataset_mut.merge( + self.to_lance().merge( other_table, left_on=left_on, right_on=right_on, schema=schema ) + self.checkout_latest() @cached_property - def embedding_functions(self) -> dict: + def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]: """ Get the embedding functions for the table Returns ------- - funcs: dict + funcs: Dict[str, EmbeddingFunctionConfig] A mapping of the vector column to the embedding function or empty dict if not configured. """ @@ -1844,15 +1865,19 @@ class LanceTable(Table): @classmethod def create( cls, - db, - name, - data=None, - schema=None, - mode="create", - exist_ok=False, + db: LanceDBConnection, + name: str, + data: Optional[DATA] = None, + schema: Optional[pa.Schema] = None, + mode: Literal["create", "overwrite", "append"] = "create", + exist_ok: bool = False, on_bad_vectors: str = "error", fill_value: float = 0.0, embedding_functions: List[EmbeddingFunctionConfig] = None, + *, + storage_options: Optional[Dict[str, str]] = None, + data_storage_version: Optional[str] = None, + enable_v2_manifest_paths: Optional[bool] = None, ): """ Create a new table. @@ -1901,46 +1926,28 @@ class LanceTable(Table): embedding_functions: list of EmbeddingFunctionModel, default None The embedding functions to use when creating the table. 
""" - tbl = LanceTable(db, name) - metadata = None - if embedding_functions is not None: - # If we passed in embedding functions explicitly - # then we'll override any schema metadata that - # may was implicitly specified by the LanceModel schema - registry = EmbeddingFunctionRegistry.get_instance() - metadata = registry.get_table_metadata(embedding_functions) + self = cls.__new__(cls) + self._conn = db - data, schema = sanitize_create_table( - data, schema, metadata, on_bad_vectors, fill_value - ) - - empty = pa.Table.from_batches([], schema=schema) - try: - lance.write_dataset( - empty, - tbl._dataset_uri, + self._table = LOOP.run( + self._conn._conn.create_table( + name, + data, schema=schema, mode=mode, - storage_options=db.storage_options, + exist_ok=exist_ok, + on_bad_vectors=on_bad_vectors, + fill_value=fill_value, + embedding_functions=embedding_functions, + storage_options=storage_options, + data_storage_version=data_storage_version, + enable_v2_manifest_paths=enable_v2_manifest_paths, ) - except OSError as err: - if "Dataset already exists" in str(err) and exist_ok: - if tbl.schema != schema: - raise ValueError( - f"Table {name} already exists with a different schema" - ) - return tbl - raise - - new_table = LanceTable(db, name) - - if data is not None: - new_table.add(data) - - return new_table + ) + return self def delete(self, where: str): - self._dataset_mut.delete(where) + LOOP.run(self._table.delete(where)) def update( self, @@ -1986,42 +1993,12 @@ class LanceTable(Table): 2 2 [10.0, 10.0] """ - if values is not None and values_sql is not None: - raise ValueError("Only one of values or values_sql can be provided") - if values is None and values_sql is None: - raise ValueError("Either values or values_sql must be provided") - - if values is not None: - values_sql = {k: value_to_sql(v) for k, v in values.items()} - - self._dataset_mut.update(values_sql, where) + LOOP.run(self._table.update(values, where=where, updates_sql=values_sql)) def _execute_query( self, query: Query, batch_size: Optional[int] = None ) -> pa.RecordBatchReader: - ds = self.to_lance() - nearest = None - if len(query.vector) > 0: - nearest = { - "column": query.vector_column, - "q": query.vector, - "k": query.k, - "metric": query.metric, - "nprobes": query.nprobes, - "refine_factor": query.refine_factor, - "ef": query.ef, - } - return ds.scanner( - columns=query.columns, - limit=query.k, - filter=query.filter, - prefilter=query.prefilter, - nearest=nearest, - full_text_query=query.full_text_query, - with_row_id=query.with_row_id, - batch_size=batch_size, - offset=query.offset, - ).to_reader() + return LOOP.run(self._table._execute_query(query, batch_size)) def _do_merge( self, @@ -2030,23 +2007,7 @@ class LanceTable(Table): on_bad_vectors: str, fill_value: float, ): - new_data, _ = _sanitize_data( - new_data, - self.schema, - metadata=self.schema.metadata, - on_bad_vectors=on_bad_vectors, - fill_value=fill_value, - ) - ds = self.to_lance() - builder = ds.merge_insert(merge._on) - if merge._when_matched_update_all: - builder.when_matched_update_all(merge._when_matched_update_all_condition) - if merge._when_not_matched_insert_all: - builder.when_not_matched_insert_all() - if merge._when_not_matched_by_source_delete: - cond = merge._when_not_matched_by_source_condition - builder.when_not_matched_by_source_delete(cond) - builder.execute(new_data) + LOOP.run(self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)) def cleanup_old_versions( self, @@ -2089,7 +2050,9 @@ class LanceTable(Table): (see 
Lance documentation for more details) For most cases, the default should be fine. """ - return self.to_lance().optimize.compact_files(*args, **kwargs) + stats = self.to_lance().optimize.compact_files(*args, **kwargs) + self.checkout_latest() + return stats def optimize( self, @@ -2138,51 +2101,74 @@ class LanceTable(Table): you have added or modified 100,000 or more records or run more than 20 data modification operations. """ - try: - asyncio.get_running_loop() - raise AssertionError( - "Synchronous method called in asynchronous context. " - "If you are writing an asynchronous application " - "then please use the asynchronous APIs" + LOOP.run( + self._table.optimize( + cleanup_older_than=cleanup_older_than, + delete_unverified=delete_unverified, ) - - except RuntimeError: - asyncio.run( - self._async_optimize( - cleanup_older_than=cleanup_older_than, - delete_unverified=delete_unverified, - ) - ) - self.checkout_latest() - - async def _async_optimize( - self, - cleanup_older_than: Optional[timedelta] = None, - delete_unverified: bool = False, - ): - conn = await lancedb_connect( - sanitize_uri(self._conn.uri), - ) - table = AsyncTable(await conn.open_table(self.name)) - return await table.optimize( - cleanup_older_than=cleanup_older_than, delete_unverified=delete_unverified ) + def list_indices(self) -> Iterable[IndexConfig]: + """ + List all indices that have been created with Self::create_index + """ + return LOOP.run(self._table.list_indices()) + + def index_stats(self, index_name: str) -> Optional[IndexStatistics]: + """ + Retrieve statistics about an index + + Parameters + ---------- + index_name: str + The name of the index to retrieve statistics for + + Returns + ------- + IndexStatistics or None + The statistics about the index. Returns None if the index does not exist. + """ + return LOOP.run(self._table.index_stats(index_name)) + def add_columns(self, transforms: Dict[str, str]): - self._dataset_mut.add_columns(transforms) + LOOP.run(self._table.add_columns(transforms)) def alter_columns(self, *alterations: Iterable[Dict[str, str]]): - modified = [] - # I called this name in pylance, but I think I regret that now. So we - # allow both name and rename. - for alter in alterations: - if "rename" in alter: - alter["name"] = alter.pop("rename") - modified.append(alter) - self._dataset_mut.alter_columns(*modified) + LOOP.run(self._table.alter_columns(*alterations)) def drop_columns(self, columns: Iterable[str]): - self._dataset_mut.drop_columns(columns) + LOOP.run(self._table.drop_columns(columns)) + + def uses_v2_manifest_paths(self) -> bool: + """ + Check if the table is using the new v2 manifest paths. + + Returns + ------- + bool + True if the table is using the new v2 manifest paths, False otherwise. + """ + return LOOP.run(self._table.uses_v2_manifest_paths()) + + def migrate_v2_manifest_paths(self): + """ + Migrate the manifest paths to the new format. + + This will update the manifest to use the new v2 format for paths. + + This function is idempotent, and can be run multiple times without + changing the state of the object store. + + !!! danger + + This should not be run while other concurrent operations are happening. + And it should also run until completion before resuming other operations. + + You can use + [LanceTable.uses_v2_manifest_paths][lancedb.table.LanceTable.uses_v2_manifest_paths] + to check if the table is already using the new path style. 
+ """ + LOOP.run(self._table.migrate_v2_manifest_paths()) def _sanitize_schema( @@ -2573,6 +2559,17 @@ class AsyncTable: """ return await self._inner.count_rows(filter) + async def head(self, n=5) -> pa.Table: + """ + Return the first `n` rows of the table. + + Parameters + ---------- + n: int, default 5 + The number of rows to return. + """ + return await self.query().limit(n).to_arrow() + def query(self) -> AsyncQuery: """ Returns an [AsyncQuery][lancedb.query.AsyncQuery] that can be used @@ -2630,15 +2627,27 @@ class AsyncTable: that index is out of date. The default is True - config: Union[IvfPq, BTree], default None + config: default None For advanced configuration you can specify the type of index you would like to create. You can also specify index-specific parameters when creating an index object. """ - index = None if config is not None: - index = config._inner - await self._inner.create_index(column, index=index, replace=replace) + if not isinstance( + config, (IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS) + ): + raise TypeError( + "config must be an instance of IvfPq, HnswPq, HnswSq, BTree," + " Bitmap, LabelList, or FTS" + ) + try: + await self._inner.create_index(column, index=config, replace=replace) + except ValueError as e: + if "not support the requested language" in str(e): + supported_langs = ", ".join(lang_mapping.values()) + help_msg = f"Supported languages: {supported_langs}" + add_note(e, help_msg) + raise e async def add( self, @@ -3029,7 +3038,15 @@ class AsyncTable: To return the table to a normal state use `[Self::checkout_latest]` """ - await self._inner.checkout(version) + try: + await self._inner.checkout(version) + except RuntimeError as e: + if "not found" in str(e): + raise ValueError( + f"Version {version} no longer exists. Was it cleaned up?" + ) + else: + raise async def checkout_latest(self): """ diff --git a/python/python/lancedb/util.py b/python/python/lancedb/util.py index 3e12efd3..eda7ce06 100644 --- a/python/python/lancedb/util.py +++ b/python/python/lancedb/util.py @@ -314,3 +314,15 @@ def deprecated(func): def validate_table_name(name: str): """Verify the table name is valid.""" native_validate_table_name(name) + + +def add_note(base_exception: BaseException, note: str): + if hasattr(base_exception, "add_note"): + base_exception.add_note(note) + elif isinstance(base_exception.args[0], str): + base_exception.args = ( + base_exception.args[0] + "\n" + note, + *base_exception.args[1:], + ) + else: + raise ValueError("Cannot add note to exception") diff --git a/python/python/tests/conftest.py b/python/python/tests/conftest.py new file mode 100644 index 00000000..e3205216 --- /dev/null +++ b/python/python/tests/conftest.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors + +from datetime import timedelta +from lancedb.db import AsyncConnection, DBConnection +import lancedb +import pytest +import pytest_asyncio + + +# Use an in-memory database for most tests. +@pytest.fixture +def mem_db() -> DBConnection: + return lancedb.connect("memory://") + + +# Use a temporary directory when we need to inspect the database files. 
+@pytest.fixture +def tmp_db(tmp_path) -> DBConnection: + return lancedb.connect(tmp_path) + + +@pytest_asyncio.fixture +async def mem_db_async() -> AsyncConnection: + return await lancedb.connect_async("memory://") + + +@pytest_asyncio.fixture +async def tmp_db_async(tmp_path) -> AsyncConnection: + return await lancedb.connect_async( + tmp_path, read_consistency_interval=timedelta(seconds=0) + ) diff --git a/python/python/tests/test_db.py b/python/python/tests/test_db.py index 93cd2aa8..893ae467 100644 --- a/python/python/tests/test_db.py +++ b/python/python/tests/test_db.py @@ -98,7 +98,7 @@ def test_ingest_pd(tmp_path): assert db.open_table("test").name == db["test"].name -def test_ingest_iterator(tmp_path): +def test_ingest_iterator(mem_db: lancedb.DBConnection): class PydanticSchema(LanceModel): vector: Vector(2) item: str @@ -156,8 +156,7 @@ def test_ingest_iterator(tmp_path): ] def run_tests(schema): - db = lancedb.connect(tmp_path) - tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite") + tbl = mem_db.create_table("table2", make_batches(), schema=schema) tbl.to_pandas() assert tbl.search([3.1, 4.1]).limit(1).to_pandas()["_distance"][0] == 0.0 assert tbl.search([5.9, 26.5]).limit(1).to_pandas()["_distance"][0] == 0.0 @@ -165,15 +164,14 @@ def test_ingest_iterator(tmp_path): tbl.add(make_batches()) assert tbl_len == 50 assert len(tbl) == tbl_len * 2 - assert len(tbl.list_versions()) == 3 - db.drop_database() + assert len(tbl.list_versions()) == 2 + mem_db.drop_database() run_tests(arrow_schema) run_tests(PydanticSchema) -def test_table_names(tmp_path): - db = lancedb.connect(tmp_path) +def test_table_names(tmp_db: lancedb.DBConnection): data = pd.DataFrame( { "vector": [[3.1, 4.1], [5.9, 26.5]], @@ -181,10 +179,10 @@ def test_table_names(tmp_path): "price": [10.0, 20.0], } ) - db.create_table("test2", data=data) - db.create_table("test1", data=data) - db.create_table("test3", data=data) - assert db.table_names() == ["test1", "test2", "test3"] + tmp_db.create_table("test2", data=data) + tmp_db.create_table("test1", data=data) + tmp_db.create_table("test3", data=data) + assert tmp_db.table_names() == ["test1", "test2", "test3"] @pytest.mark.asyncio @@ -209,8 +207,7 @@ async def test_table_names_async(tmp_path): assert await db.table_names(start_after="test1") == ["test2", "test3"] -def test_create_mode(tmp_path): - db = lancedb.connect(tmp_path) +def test_create_mode(tmp_db: lancedb.DBConnection): data = pd.DataFrame( { "vector": [[3.1, 4.1], [5.9, 26.5]], @@ -218,10 +215,10 @@ def test_create_mode(tmp_path): "price": [10.0, 20.0], } ) - db.create_table("test", data=data) + tmp_db.create_table("test", data=data) with pytest.raises(Exception): - db.create_table("test", data=data) + tmp_db.create_table("test", data=data) new_data = pd.DataFrame( { @@ -230,13 +227,11 @@ def test_create_mode(tmp_path): "price": [10.0, 20.0], } ) - tbl = db.create_table("test", data=new_data, mode="overwrite") + tbl = tmp_db.create_table("test", data=new_data, mode="overwrite") assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"] -def test_create_table_from_iterator(tmp_path): - db = lancedb.connect(tmp_path) - +def test_create_table_from_iterator(mem_db: lancedb.DBConnection): def gen_data(): for _ in range(10): yield pa.RecordBatch.from_arrays( @@ -248,14 +243,12 @@ def test_create_table_from_iterator(tmp_path): ["vector", "item", "price"], ) - table = db.create_table("test", data=gen_data()) + table = mem_db.create_table("test", data=gen_data()) assert table.count_rows() == 
10 @pytest.mark.asyncio -async def test_create_table_from_iterator_async(tmp_path): - db = await lancedb.connect_async(tmp_path) - +async def test_create_table_from_iterator_async(mem_db_async: lancedb.AsyncConnection): def gen_data(): for _ in range(10): yield pa.RecordBatch.from_arrays( @@ -267,12 +260,11 @@ async def test_create_table_from_iterator_async(tmp_path): ["vector", "item", "price"], ) - table = await db.create_table("test", data=gen_data()) + table = await mem_db_async.create_table("test", data=gen_data()) assert await table.count_rows() == 10 -def test_create_exist_ok(tmp_path): - db = lancedb.connect(tmp_path) +def test_create_exist_ok(tmp_db: lancedb.DBConnection): data = pd.DataFrame( { "vector": [[3.1, 4.1], [5.9, 26.5]], @@ -280,13 +272,13 @@ def test_create_exist_ok(tmp_path): "price": [10.0, 20.0], } ) - tbl = db.create_table("test", data=data) + tbl = tmp_db.create_table("test", data=data) - with pytest.raises(OSError): - db.create_table("test", data=data) + with pytest.raises(ValueError): + tmp_db.create_table("test", data=data) # open the table but don't add more rows - tbl2 = db.create_table("test", data=data, exist_ok=True) + tbl2 = tmp_db.create_table("test", data=data, exist_ok=True) assert tbl.name == tbl2.name assert tbl.schema == tbl2.schema assert len(tbl) == len(tbl2) @@ -298,7 +290,7 @@ def test_create_exist_ok(tmp_path): pa.field("price", pa.float64()), ] ) - tbl3 = db.create_table("test", schema=schema, exist_ok=True) + tbl3 = tmp_db.create_table("test", schema=schema, exist_ok=True) assert tbl3.schema == schema bad_schema = pa.schema( @@ -310,7 +302,7 @@ def test_create_exist_ok(tmp_path): ] ) with pytest.raises(ValueError): - db.create_table("test", schema=bad_schema, exist_ok=True) + tmp_db.create_table("test", schema=bad_schema, exist_ok=True) @pytest.mark.asyncio @@ -325,26 +317,24 @@ async def test_connect(tmp_path): @pytest.mark.asyncio -async def test_close(tmp_path): - db = await lancedb.connect_async(tmp_path) - assert db.is_open() - db.close() - assert not db.is_open() +async def test_close(mem_db_async: lancedb.AsyncConnection): + assert mem_db_async.is_open() + mem_db_async.close() + assert not mem_db_async.is_open() with pytest.raises(RuntimeError, match="is closed"): - await db.table_names() + await mem_db_async.table_names() @pytest.mark.asyncio -async def test_context_manager(tmp_path): - with await lancedb.connect_async(tmp_path) as db: +async def test_context_manager(): + with await lancedb.connect_async("memory://") as db: assert db.is_open() assert not db.is_open() @pytest.mark.asyncio -async def test_create_mode_async(tmp_path): - db = await lancedb.connect_async(tmp_path) +async def test_create_mode_async(tmp_db_async: lancedb.AsyncConnection): data = pd.DataFrame( { "vector": [[3.1, 4.1], [5.9, 26.5]], @@ -352,10 +342,10 @@ async def test_create_mode_async(tmp_path): "price": [10.0, 20.0], } ) - await db.create_table("test", data=data) + await tmp_db_async.create_table("test", data=data) with pytest.raises(ValueError, match="already exists"): - await db.create_table("test", data=data) + await tmp_db_async.create_table("test", data=data) new_data = pd.DataFrame( { @@ -364,15 +354,14 @@ async def test_create_mode_async(tmp_path): "price": [10.0, 20.0], } ) - _tbl = await db.create_table("test", data=new_data, mode="overwrite") + _tbl = await tmp_db_async.create_table("test", data=new_data, mode="overwrite") # MIGRATION: to_pandas() is not available in async # assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"] 
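The `MIGRATION` comment above notes that `to_pandas()` is not available on the async table. Since `AsyncTable.to_arrow()` exists (it is used elsewhere in this patch), a test that needs a DataFrame can make the conversion itself; a small sketch using only standard pyarrow:

```python
# Hypothetical helper for async tests: AsyncTable has no to_pandas(), so
# go through Arrow and use pyarrow's built-in conversion.
async def to_pandas_async(tbl):
    arrow_table = await tbl.to_arrow()
    return arrow_table.to_pandas()

# Usage inside an async test:
#     df = await to_pandas_async(tbl)
#     assert df.item.tolist() == ["fizz", "buzz"]
```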
@pytest.mark.asyncio -async def test_create_exist_ok_async(tmp_path): - db = await lancedb.connect_async(tmp_path) +async def test_create_exist_ok_async(tmp_db_async: lancedb.AsyncConnection): data = pd.DataFrame( { "vector": [[3.1, 4.1], [5.9, 26.5]], @@ -380,13 +369,13 @@ async def test_create_exist_ok_async(tmp_path): "price": [10.0, 20.0], } ) - tbl = await db.create_table("test", data=data) + tbl = await tmp_db_async.create_table("test", data=data) with pytest.raises(ValueError, match="already exists"): - await db.create_table("test", data=data) + await tmp_db_async.create_table("test", data=data) # open the table but don't add more rows - tbl2 = await db.create_table("test", data=data, exist_ok=True) + tbl2 = await tmp_db_async.create_table("test", data=data, exist_ok=True) assert tbl.name == tbl2.name assert await tbl.schema() == await tbl2.schema() @@ -397,7 +386,7 @@ async def test_create_exist_ok_async(tmp_path): pa.field("price", pa.float64()), ] ) - tbl3 = await db.create_table("test", schema=schema, exist_ok=True) + tbl3 = await tmp_db_async.create_table("test", schema=schema, exist_ok=True) assert await tbl3.schema() == schema # Migration: When creating a table, but the table already exists, but @@ -448,13 +437,12 @@ async def test_create_table_v2_manifest_paths_async(tmp_path): assert re.match(r"\d{20}\.manifest", manifest) -def test_open_table_sync(tmp_path): - db = lancedb.connect(tmp_path) - db.create_table("test", data=[{"id": 0}]) - assert db.open_table("test").count_rows() == 1 - assert db.open_table("test", index_cache_size=0).count_rows() == 1 - with pytest.raises(FileNotFoundError, match="does not exist"): - db.open_table("does_not_exist") +def test_open_table_sync(tmp_db: lancedb.DBConnection): + tmp_db.create_table("test", data=[{"id": 0}]) + assert tmp_db.open_table("test").count_rows() == 1 + assert tmp_db.open_table("test", index_cache_size=0).count_rows() == 1 + with pytest.raises(ValueError, match="Table 'does_not_exist' was not found"): + tmp_db.open_table("does_not_exist") @pytest.mark.asyncio @@ -494,8 +482,7 @@ async def test_open_table(tmp_path): await db.open_table("does_not_exist") -def test_delete_table(tmp_path): - db = lancedb.connect(tmp_path) +def test_delete_table(tmp_db: lancedb.DBConnection): data = pd.DataFrame( { "vector": [[3.1, 4.1], [5.9, 26.5]], @@ -503,26 +490,25 @@ def test_delete_table(tmp_path): "price": [10.0, 20.0], } ) - db.create_table("test", data=data) + tmp_db.create_table("test", data=data) with pytest.raises(Exception): - db.create_table("test", data=data) + tmp_db.create_table("test", data=data) - assert db.table_names() == ["test"] + assert tmp_db.table_names() == ["test"] - db.drop_table("test") - assert db.table_names() == [] + tmp_db.drop_table("test") + assert tmp_db.table_names() == [] - db.create_table("test", data=data) - assert db.table_names() == ["test"] + tmp_db.create_table("test", data=data) + assert tmp_db.table_names() == ["test"] # dropping a table that does not exist should pass # if ignore_missing=True - db.drop_table("does_not_exist", ignore_missing=True) + tmp_db.drop_table("does_not_exist", ignore_missing=True) -def test_drop_database(tmp_path): - db = lancedb.connect(tmp_path) +def test_drop_database(tmp_db: lancedb.DBConnection): data = pd.DataFrame( { "vector": [[3.1, 4.1], [5.9, 26.5]], @@ -537,51 +523,50 @@ def test_drop_database(tmp_path): "price": [12.0, 17.0], } ) - db.create_table("test", data=data) + tmp_db.create_table("test", data=data) with pytest.raises(Exception): - 
db.create_table("test", data=data) + tmp_db.create_table("test", data=data) - assert db.table_names() == ["test"] + assert tmp_db.table_names() == ["test"] - db.create_table("new_test", data=new_data) - db.drop_database() - assert db.table_names() == [] + tmp_db.create_table("new_test", data=new_data) + tmp_db.drop_database() + assert tmp_db.table_names() == [] # it should pass when no tables are present - db.create_table("test", data=new_data) - db.drop_table("test") - assert db.table_names() == [] - db.drop_database() - assert db.table_names() == [] + tmp_db.create_table("test", data=new_data) + tmp_db.drop_table("test") + assert tmp_db.table_names() == [] + tmp_db.drop_database() + assert tmp_db.table_names() == [] # creating an empty database with schema schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), list_size=2))]) - db.create_table("empty_table", schema=schema) + tmp_db.create_table("empty_table", schema=schema) # dropping a empty database should pass - db.drop_database() - assert db.table_names() == [] + tmp_db.drop_database() + assert tmp_db.table_names() == [] -def test_empty_or_nonexistent_table(tmp_path): - db = lancedb.connect(tmp_path) +def test_empty_or_nonexistent_table(mem_db: lancedb.DBConnection): with pytest.raises(Exception): - db.create_table("test_with_no_data") + mem_db.create_table("test_with_no_data") with pytest.raises(Exception): - db.open_table("does_not_exist") + mem_db.open_table("does_not_exist") schema = pa.schema([pa.field("a", pa.int64(), nullable=False)]) - test = db.create_table("test", schema=schema) + test = mem_db.create_table("test", schema=schema) class TestModel(LanceModel): a: int - test2 = db.create_table("test2", schema=TestModel) + test2 = mem_db.create_table("test2", schema=TestModel) assert test.schema == test2.schema @pytest.mark.asyncio -async def test_create_in_v2_mode(tmp_path): +async def test_create_in_v2_mode(mem_db_async: lancedb.AsyncConnection): def make_data(): for i in range(10): yield pa.record_batch([pa.array([x for x in range(1024)])], names=["x"]) @@ -591,10 +576,8 @@ async def test_create_in_v2_mode(tmp_path): schema = pa.schema([pa.field("x", pa.int64())]) - db = await lancedb.connect_async(tmp_path) - # Create table in v1 mode - tbl = await db.create_table( + tbl = await mem_db_async.create_table( "test", data=make_data(), schema=schema, data_storage_version="legacy" ) @@ -610,7 +593,7 @@ async def test_create_in_v2_mode(tmp_path): assert not await is_in_v2_mode(tbl) # Create table in v2 mode - tbl = await db.create_table( + tbl = await mem_db_async.create_table( "test_v2", data=make_data(), schema=schema, use_legacy_format=False ) @@ -622,7 +605,7 @@ async def test_create_in_v2_mode(tmp_path): assert await is_in_v2_mode(tbl) # Create empty table in v2 mode and add data - tbl = await db.create_table( + tbl = await mem_db_async.create_table( "test_empty_v2", data=None, schema=schema, use_legacy_format=False ) await tbl.add(make_table()) @@ -630,7 +613,7 @@ async def test_create_in_v2_mode(tmp_path): assert await is_in_v2_mode(tbl) # Create empty table uses v1 mode by default - tbl = await db.create_table( + tbl = await mem_db_async.create_table( "test_empty_v2_default", data=None, schema=schema, data_storage_version="legacy" ) await tbl.add(make_table()) @@ -638,18 +621,17 @@ async def test_create_in_v2_mode(tmp_path): assert not await is_in_v2_mode(tbl) -def test_replace_index(tmp_path): - db = lancedb.connect(uri=tmp_path) - table = db.create_table( +def test_replace_index(mem_db: lancedb.DBConnection): + 
table = mem_db.create_table( "test", [ - {"vector": np.random.rand(128), "item": "foo", "price": float(i)} - for i in range(1000) + {"vector": np.random.rand(32), "item": "foo", "price": float(i)} + for i in range(512) ], ) table.create_index( num_partitions=2, - num_sub_vectors=4, + num_sub_vectors=2, ) with pytest.raises(Exception): @@ -660,27 +642,26 @@ def test_replace_index(tmp_path): ) table.create_index( - num_partitions=2, - num_sub_vectors=4, + num_partitions=1, + num_sub_vectors=2, replace=True, index_cache_size=10, ) -def test_prefilter_with_index(tmp_path): - db = lancedb.connect(uri=tmp_path) +def test_prefilter_with_index(mem_db: lancedb.DBConnection): data = [ - {"vector": np.random.rand(128), "item": "foo", "price": float(i)} - for i in range(1000) + {"vector": np.random.rand(32), "item": "foo", "price": float(i)} + for i in range(512) ] sample_key = data[100]["vector"] - table = db.create_table( + table = mem_db.create_table( "test", data, ) table.create_index( num_partitions=2, - num_sub_vectors=4, + num_sub_vectors=2, ) table = ( table.search(sample_key) @@ -691,13 +672,12 @@ def test_prefilter_with_index(tmp_path): assert table.num_rows == 1 -def test_create_table_with_invalid_names(tmp_path): - db = lancedb.connect(uri=tmp_path) +def test_create_table_with_invalid_names(tmp_db: lancedb.DBConnection): data = [{"vector": np.random.rand(128), "item": "foo"} for i in range(10)] with pytest.raises(ValueError): - db.create_table("foo/bar", data) + tmp_db.create_table("foo/bar", data) with pytest.raises(ValueError): - db.create_table("foo bar", data) + tmp_db.create_table("foo bar", data) with pytest.raises(ValueError): - db.create_table("foo$$bar", data) - db.create_table("foo.bar", data) + tmp_db.create_table("foo$$bar", data) + tmp_db.create_table("foo.bar", data) diff --git a/python/python/tests/test_fts.py b/python/python/tests/test_fts.py index 594552a0..162a2f11 100644 --- a/python/python/tests/test_fts.py +++ b/python/python/tests/test_fts.py @@ -15,10 +15,12 @@ import random from unittest import mock import lancedb as ldb +from lancedb.db import DBConnection from lancedb.index import FTS import numpy as np import pandas as pd import pytest +from utils import exception_output pytest.importorskip("lancedb.fts") tantivy = pytest.importorskip("tantivy") @@ -458,3 +460,44 @@ def test_syntax(table): table.search('the cats OR dogs were not really "pets" at all').phrase_query().limit( 10 ).to_list() + + +def test_language(mem_db: DBConnection): + sentences = [ + "Il n'y a que trois routes qui traversent la ville.", + "Je veux prendre la route vers l'est.", + "Je te retrouve au café au bout de la route.", + ] + data = [{"text": s} for s in sentences] + table = mem_db.create_table("test", data=data) + + with pytest.raises(ValueError) as e: + table.create_fts_index("text", use_tantivy=False, language="klingon") + + assert exception_output(e) == ( + "ValueError: LanceDB does not support the requested language: 'klingon'\n" + "Supported languages: Arabic, Danish, Dutch, English, Finnish, French, " + "German, Greek, Hungarian, Italian, Norwegian, Portuguese, Romanian, " + "Russian, Spanish, Swedish, Tamil, Turkish" + ) + + table.create_fts_index( + "text", + use_tantivy=False, + language="French", + stem=True, + ascii_folding=True, + remove_stop_words=True, + ) + + # Can get "routes" and "route" from the same root + results = table.search("route", query_type="fts").limit(5).to_list() + assert len(results) == 3 + + # Can find "café", without needing to provide accent + results = 
table.search("cafe", query_type="fts").limit(5).to_list() + assert len(results) == 1 + + # Stop words -> no results + results = table.search("la", query_type="fts").limit(5).to_list() + assert len(results) == 0 diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 7a8bb552..1809a5c5 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -1,77 +1,57 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The Lance Authors -import functools import os -from copy import copy from datetime import date, datetime, timedelta -from pathlib import Path from time import sleep from typing import List -from unittest.mock import PropertyMock, patch +from unittest.mock import patch import lance import lancedb +from lancedb.index import HnswPq, HnswSq, IvfPq import numpy as np import pandas as pd import polars as pl import pyarrow as pa import pytest -import pytest_asyncio from lancedb.conftest import MockTextEmbeddingFunction -from lancedb.db import AsyncConnection, LanceDBConnection +from lancedb.db import AsyncConnection, DBConnection from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from lancedb.pydantic import LanceModel, Vector from lancedb.table import LanceTable from pydantic import BaseModel -class MockDB: - def __init__(self, uri: Path): - self.uri = str(uri) - self.read_consistency_interval = None - self.storage_options = None +def test_basic(mem_db: DBConnection): + data = [ + {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, + {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, + ] + table = mem_db.create_table("test", data=data) - @functools.cached_property - def is_managed_remote(self) -> bool: - return False - - -@pytest.fixture -def db(tmp_path) -> MockDB: - return MockDB(tmp_path) - - -@pytest_asyncio.fixture -async def db_async(tmp_path) -> AsyncConnection: - return await lancedb.connect_async( - tmp_path, read_consistency_interval=timedelta(seconds=0) - ) - - -def test_basic(db): - ds = LanceTable.create( - db, - "test", - data=[ - {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, - {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, - ], - ).to_lance() - - table = LanceTable(db, "test") assert table.name == "test" - assert table.schema == ds.schema - assert table.to_lance().to_table() == ds.to_table() + assert "LanceTable(name='test', version=1, _conn=LanceDBConnection(" in repr(table) + expected_schema = pa.schema( + { + "vector": pa.list_(pa.float32(), 2), + "item": pa.string(), + "price": pa.float64(), + } + ) + assert table.schema == expected_schema + + expected_data = pa.Table.from_pylist(data, schema=expected_schema) + assert table.to_arrow() == expected_data -def test_input_data_type(db, tmp_path): +def test_input_data_type(mem_db: DBConnection, tmp_path): schema = pa.schema( - [ - pa.field("id", pa.int64()), - pa.field("name", pa.string()), - pa.field("age", pa.int32()), - ] + { + "id": pa.int64(), + "name": pa.string(), + "age": pa.int32(), + } ) data = { @@ -100,23 +80,17 @@ def test_input_data_type(db, tmp_path): ] for input_type, input_data in input_types: table_name = f"test_{input_type.lower()}" - ds = LanceTable.create(db, table_name, data=input_data).to_lance() - assert ds.schema == schema - assert ds.count_rows() == 5 + table = mem_db.create_table(table_name, data=input_data) + assert table.schema == schema + assert table.count_rows() == 5 - assert ds.schema.field("id").type == pa.int64() - assert ds.schema.field("name").type == 
pa.string() - assert ds.schema.field("age").type == pa.int32() - - result_table = ds.to_table() - assert result_table.column("id").to_pylist() == data["id"] - assert result_table.column("name").to_pylist() == data["name"] - assert result_table.column("age").to_pylist() == data["age"] + assert table.schema == schema + assert table.to_arrow() == pa_table @pytest.mark.asyncio -async def test_close(db_async: AsyncConnection): - table = await db_async.create_table("some_table", data=[{"id": 0}]) +async def test_close(mem_db_async: AsyncConnection): + table = await mem_db_async.create_table("some_table", data=[{"id": 0}]) assert table.is_open() table.close() assert not table.is_open() @@ -127,8 +101,8 @@ async def test_close(db_async: AsyncConnection): @pytest.mark.asyncio -async def test_update_async(db_async: AsyncConnection): - table = await db_async.create_table("some_table", data=[{"id": 0}]) +async def test_update_async(mem_db_async: AsyncConnection): + table = await mem_db_async.create_table("some_table", data=[{"id": 0}]) assert await table.count_rows("id == 0") == 1 assert await table.count_rows("id == 7") == 0 await table.update({"id": 7}) @@ -143,42 +117,40 @@ async def test_update_async(db_async: AsyncConnection): assert await table.count_rows("id == 10") == 1 -def test_create_table(db): +def test_create_table(mem_db: DBConnection): schema = pa.schema( - [ - pa.field("vector", pa.list_(pa.float32(), 2)), - pa.field("item", pa.string()), - pa.field("price", pa.float32()), - ] + { + "vector": pa.list_(pa.float32(), 2), + "item": pa.string(), + "price": pa.float64(), + } ) - expected = pa.Table.from_arrays( - [ - pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2), - pa.array(["foo", "bar"]), - pa.array([10.0, 20.0]), - ], + expected = pa.table( + { + "vector": [[3.1, 4.1], [5.9, 26.5]], + "item": ["foo", "bar"], + "price": [10.0, 20.0], + }, schema=schema, ) - data = [ - [ - {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, - {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, - ] + rows = [ + {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, + {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, + ] + df = pd.DataFrame(rows) + pa_table = pa.Table.from_pandas(df, schema=schema) + data = [ + ("Rows", rows), + ("pd_DataFrame", df), + ("pa_Table", pa_table), ] - df = pd.DataFrame(data[0]) - data.append(df) - data.append(pa.Table.from_pandas(df, schema=schema)) - for i, d in enumerate(data): - tbl = ( - LanceTable.create(db, f"test_{i}", data=d, schema=schema) - .to_lance() - .to_table() - ) + for name, d in data: + tbl = mem_db.create_table(name, data=d, schema=schema).to_arrow() assert expected == tbl -def test_empty_table(db): +def test_empty_table(mem_db: DBConnection): schema = pa.schema( [ pa.field("vector", pa.list_(pa.float32(), 2)), @@ -186,7 +158,7 @@ def test_empty_table(db): pa.field("price", pa.float32()), ] ) - tbl = LanceTable.create(db, "test", schema=schema) + tbl = mem_db.create_table("test", schema=schema) data = [ {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, @@ -194,7 +166,7 @@ def test_empty_table(db): tbl.add(data=data) -def test_add_dictionary(db): +def test_add_dictionary(mem_db: DBConnection): schema = pa.schema( [ pa.field("vector", pa.list_(pa.float32(), 2)), @@ -202,7 +174,7 @@ def test_add_dictionary(db): pa.field("price", pa.float32()), ] ) - tbl = LanceTable.create(db, "test", schema=schema) + tbl = mem_db.create_table("test", schema=schema) data = {"vector": [3.1, 
4.1], "item": "foo", "price": 10.0} with pytest.raises(ValueError) as excep_info: tbl.add(data=data) @@ -212,7 +184,7 @@ def test_add_dictionary(db): ) -def test_add(db): +def test_add(mem_db: DBConnection): schema = pa.schema( [ pa.field("vector", pa.list_(pa.float32(), 2)), @@ -221,8 +193,24 @@ def test_add(db): ] ) - table = LanceTable.create( - db, + def _add(table, schema): + assert len(table) == 2 + + table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}]) + assert len(table) == 3 + + expected = pa.table( + { + "vector": [[3.1, 4.1], [5.9, 26.5], [6.3, 100.5]], + "item": ["foo", "bar", "new"], + "price": [10.0, 20.0, 30.0], + }, + schema=schema, + ) + assert expected == table.to_arrow() + + # Append to table created with data + table = mem_db.create_table( "test", data=[ {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, @@ -231,7 +219,8 @@ def test_add(db): ) _add(table, schema) - table = LanceTable.create(db, "test2", schema=schema) + # Append to table created empty with schema + table = mem_db.create_table("test2", schema=schema) table.add( data=[ {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, @@ -241,8 +230,7 @@ def test_add(db): _add(table, schema) -def test_add_subschema(tmp_path): - db = lancedb.connect(tmp_path) +def test_add_subschema(mem_db: DBConnection): schema = pa.schema( [ pa.field("vector", pa.list_(pa.float32(), 2), nullable=True), @@ -250,7 +238,7 @@ def test_add_subschema(tmp_path): pa.field("price", pa.float64(), nullable=False), ] ) - table = db.create_table("test", schema=schema) + table = mem_db.create_table("test", schema=schema) data = {"price": 10.0, "item": "foo"} table.add([data]) @@ -271,7 +259,7 @@ def test_add_subschema(tmp_path): data = {"item": "foo"} # We can't omit a column if it's not nullable - with pytest.raises(OSError, match="Invalid user input"): + with pytest.raises(RuntimeError, match="Invalid user input"): table.add([data]) # We can add it if we make the column nullable @@ -296,15 +284,14 @@ def test_add_subschema(tmp_path): assert table.to_arrow() == expected -def test_add_nullability(tmp_path): - db = lancedb.connect(tmp_path) +def test_add_nullability(mem_db: DBConnection): schema = pa.schema( [ pa.field("vector", pa.list_(pa.float32(), 2), nullable=False), pa.field("id", pa.string(), nullable=False), ] ) - table = db.create_table("test", schema=schema) + table = mem_db.create_table("test", schema=schema) nullable_schema = pa.schema( [ @@ -356,7 +343,7 @@ def test_add_nullability(tmp_path): assert table.to_arrow() == expected -def test_add_pydantic_model(db): +def test_add_pydantic_model(mem_db: DBConnection): # https://github.com/lancedb/lancedb/issues/562 class Metadata(BaseModel): @@ -373,7 +360,7 @@ def test_add_pydantic_model(db): li: List[int] payload: Document - tbl = LanceTable.create(db, "mytable", schema=LanceSchema, mode="overwrite") + tbl = mem_db.create_table("mytable", schema=LanceSchema, mode="overwrite") assert tbl.schema == LanceSchema.to_arrow_schema() # add works @@ -398,8 +385,8 @@ def test_add_pydantic_model(db): @pytest.mark.asyncio -async def test_add_async(db_async: AsyncConnection): - table = await db_async.create_table( +async def test_add_async(mem_db_async: AsyncConnection): + table = await mem_db_async.create_table( "test", data=[ {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, @@ -412,18 +399,17 @@ async def test_add_async(db_async: AsyncConnection): {"vector": [10.0, 11.0], "item": "baz", "price": 30.0}, ], ) - table = await db_async.open_table("test") assert await 
table.count_rows() == 3 -def test_polars(db): +def test_polars(mem_db: DBConnection): data = { "vector": [[3.1, 4.1], [5.9, 26.5]], "item": ["foo", "bar"], "price": [10.0, 20.0], } # Ingest polars dataframe - table = LanceTable.create(db, "test", data=pl.DataFrame(data)) + table = mem_db.create_table("test", data=pl.DataFrame(data)) assert len(table) == 2 result = table.to_pandas() @@ -456,28 +442,8 @@ def test_polars(db): assert len(filtered_result) == 2 -def _add(table, schema): - assert len(table) == 2 - - table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}]) - assert len(table) == 3 - - expected = pa.Table.from_arrays( - [ - pa.FixedSizeListArray.from_arrays( - pa.array([3.1, 4.1, 5.9, 26.5, 6.3, 100.5]), 2 - ), - pa.array(["foo", "bar", "new"]), - pa.array([10.0, 20.0, 30.0]), - ], - schema=schema, - ) - assert expected == table.to_arrow() - - -def test_versioning(db): - table = LanceTable.create( - db, +def test_versioning(mem_db: DBConnection): + table = mem_db.create_table( "test", data=[ {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, @@ -485,56 +451,74 @@ def test_versioning(db): ], ) - assert len(table.list_versions()) == 2 - assert table.version == 2 + assert len(table.list_versions()) == 1 + assert table.version == 1 table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}]) - assert len(table.list_versions()) == 3 - assert table.version == 3 + assert len(table.list_versions()) == 2 + assert table.version == 2 assert len(table) == 3 - table.checkout(2) - assert table.version == 2 + table.checkout(1) + assert table.version == 1 assert len(table) == 2 -def test_create_index_method(): - with patch.object( - LanceTable, "_dataset_mut", new_callable=PropertyMock - ) as mock_dataset: - # Setup mock responses - mock_dataset.return_value.create_index.return_value = None +@patch("lancedb.table.AsyncTable.create_index") +def test_create_index_method(mock_create_index, mem_db: DBConnection): + table = mem_db.create_table( + "test", + data=[ + {"vector": [3.1, 4.1]}, + {"vector": [5.9, 26.5]}, + ], + ) - # Create a LanceTable object - connection = LanceDBConnection(uri="mock.uri") - table = LanceTable(connection, "test_table") + table.create_index( + metric="L2", + num_partitions=256, + num_sub_vectors=96, + vector_column_name="vector", + replace=True, + index_cache_size=256, + num_bits=4, + ) + expected_config = IvfPq( + distance_type="L2", + num_partitions=256, + num_sub_vectors=96, + num_bits=4, + ) + mock_create_index.assert_called_with("vector", replace=True, config=expected_config) - # Call the create_index method - table.create_index( - metric="L2", - num_partitions=256, - num_sub_vectors=96, - vector_column_name="vector", - replace=True, - index_cache_size=256, - ) + table.create_index( + vector_column_name="my_vector", + metric="dot", + index_type="IVF_HNSW_PQ", + replace=False, + ) + expected_config = HnswPq(distance_type="dot") + mock_create_index.assert_called_with( + "my_vector", replace=False, config=expected_config + ) - # Check that the _dataset.create_index method was called - # with the right parameters - mock_dataset.return_value.create_index.assert_called_once_with( - column="vector", - index_type="IVF_PQ", - metric="L2", - num_partitions=256, - num_sub_vectors=96, - replace=True, - accelerator=None, - index_cache_size=256, - num_bits=8, - ) + table.create_index( + vector_column_name="my_vector", + metric="cosine", + index_type="IVF_HNSW_SQ", + sample_rate=0.1, + m=29, + ef_construction=10, + ) + expected_config = HnswSq( + 
distance_type="cosine", sample_rate=0.1, m=29, ef_construction=10 + ) + mock_create_index.assert_called_with( + "my_vector", replace=True, config=expected_config + ) -def test_add_with_nans(db): +def test_add_with_nans(mem_db: DBConnection): # by default we raise an error on bad input vectors bad_data = [ {"vector": [np.nan], "item": "bar", "price": 20.0}, @@ -544,14 +528,12 @@ def test_add_with_nans(db): ] for row in bad_data: with pytest.raises(ValueError): - LanceTable.create( - db, + mem_db.create_table( "error_test", data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row], ) - table = LanceTable.create( - db, + table = mem_db.create_table( "drop_test", data=[ {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, @@ -564,8 +546,7 @@ def test_add_with_nans(db): assert len(table) == 1 # We can fill bad input with some value - table = LanceTable.create( - db, + table = mem_db.create_table( "fill_test", data=[ {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, @@ -576,29 +557,28 @@ def test_add_with_nans(db): fill_value=0.0, ) assert len(table) == 3 - arrow_tbl = table.to_lance().to_table(filter="item == 'bar'") + arrow_tbl = table.search().where("item == 'bar'").to_arrow() v = arrow_tbl["vector"].to_pylist()[0] assert np.allclose(v, np.array([0.0, 0.0])) -def test_restore(db): - table = LanceTable.create( - db, +def test_restore(mem_db: DBConnection): + table = mem_db.create_table( "my_table", data=[{"vector": [1.1, 0.9], "type": "vector"}], ) table.add([{"vector": [0.5, 0.2], "type": "vector"}]) - table.restore(2) - assert len(table.list_versions()) == 4 + table.restore(1) + assert len(table.list_versions()) == 3 assert len(table) == 1 expected = table.to_arrow() - table.checkout(2) + table.checkout(1) table.restore() - assert len(table.list_versions()) == 5 + assert len(table.list_versions()) == 4 assert table.to_arrow() == expected - table.restore(5) # latest version should be no-op + table.restore(4) # latest version should be no-op assert len(table.list_versions()) == 5 with pytest.raises(ValueError): @@ -608,12 +588,17 @@ def test_restore(db): table.restore(0) -def test_merge(db, tmp_path): - table = LanceTable.create( - db, +def test_merge(tmp_db: DBConnection, tmp_path): + table = tmp_db.create_table( "my_table", - data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}], + schema=pa.schema( + { + "vector": pa.list_(pa.float32(), 2), + "id": pa.int64(), + } + ), ) + table.add([{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}]) other_table = pa.table({"document": ["foo", "bar"], "id": [0, 1]}) table.merge(other_table, left_on="id") assert len(table.list_versions()) == 3 @@ -628,41 +613,38 @@ def test_merge(db, tmp_path): table.merge(other_dataset, left_on="id") -def test_delete(db): - table = LanceTable.create( - db, +def test_delete(mem_db: DBConnection): + table = mem_db.create_table( "my_table", data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}], ) assert len(table) == 2 - assert len(table.list_versions()) == 2 + assert len(table.list_versions()) == 1 table.delete("id=0") - assert len(table.list_versions()) == 3 - assert table.version == 3 + assert len(table.list_versions()) == 2 + assert table.version == 2 assert len(table) == 1 assert table.to_pandas()["id"].tolist() == [1] -def test_update(db): - table = LanceTable.create( - db, +def test_update(mem_db: DBConnection): + table = mem_db.create_table( "my_table", data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}], ) assert len(table) == 2 - 
assert len(table.list_versions()) == 2 + assert len(table.list_versions()) == 1 table.update(where="id=0", values={"vector": [1.1, 1.1]}) - assert len(table.list_versions()) == 3 - assert table.version == 3 + assert len(table.list_versions()) == 2 + assert table.version == 2 assert len(table) == 2 v = table.to_arrow()["vector"].combine_chunks() v = v.values.to_numpy().reshape(2, 2) assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]])) -def test_update_types(db): - table = LanceTable.create( - db, +def test_update_types(mem_db: DBConnection): + table = mem_db.create_table( "my_table", data=[ { @@ -730,9 +712,8 @@ def test_update_types(db): assert actual == expected -def test_merge_insert(db): - table = LanceTable.create( - db, +def test_merge_insert(mem_db: DBConnection): + table = mem_db.create_table( "my_table", data=pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}), ) @@ -796,9 +777,9 @@ def test_merge_insert(db): @pytest.mark.asyncio -async def test_merge_insert_async(db_async: AsyncConnection): +async def test_merge_insert_async(mem_db_async: AsyncConnection): data = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}) - table = await db_async.create_table("some_table", data=data) + table = await mem_db_async.create_table("some_table", data=data) assert await table.count_rows() == 3 version = await table.version() @@ -864,7 +845,7 @@ async def test_merge_insert_async(db_async: AsyncConnection): assert (await table.to_arrow()).sort_by("a") == expected -def test_create_with_embedding_function(db): +def test_create_with_embedding_function(mem_db: DBConnection): class MyTable(LanceModel): text: str vector: Vector(10) @@ -876,8 +857,7 @@ def test_create_with_embedding_function(db): conf = EmbeddingFunctionConfig( source_column="text", vector_column="vector", function=func ) - table = LanceTable.create( - db, + table = mem_db.create_table( "my_table", schema=MyTable, embedding_functions=[conf], @@ -892,24 +872,23 @@ def test_create_with_embedding_function(db): assert actual == expected -def test_create_f16_table(db): +def test_create_f16_table(mem_db: DBConnection): class MyTable(LanceModel): text: str - vector: Vector(128, value_type=pa.float16()) + vector: Vector(32, value_type=pa.float16()) df = pd.DataFrame( { - "text": [f"s-{i}" for i in range(10000)], - "vector": [np.random.randn(128).astype(np.float16) for _ in range(10000)], + "text": [f"s-{i}" for i in range(512)], + "vector": [np.random.randn(32).astype(np.float16) for _ in range(512)], } ) - table = LanceTable.create( - db, + table = mem_db.create_table( "f16_tbl", schema=MyTable, ) table.add(df) - table.create_index(num_partitions=2, num_sub_vectors=8) + table.create_index(num_partitions=2, num_sub_vectors=2) query = df.vector.iloc[2] expected = table.search(query).limit(2).to_arrow() @@ -917,14 +896,14 @@ def test_create_f16_table(db): assert "s-2" in expected["text"].to_pylist() -def test_add_with_embedding_function(db): +def test_add_with_embedding_function(mem_db: DBConnection): emb = EmbeddingFunctionRegistry.get_instance().get("test")() class MyTable(LanceModel): text: str = emb.SourceField() vector: Vector(emb.ndims()) = emb.VectorField() - table = LanceTable.create(db, "my_table", schema=MyTable) + table = mem_db.create_table("my_table", schema=MyTable) texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"] df = pd.DataFrame({"text": texts}) @@ -941,14 +920,13 @@ def test_add_with_embedding_function(db): assert actual == expected -def test_multiple_vector_columns(db): +def test_multiple_vector_columns(mem_db: 
DBConnection): class MyTable(LanceModel): text: str vector1: Vector(10) vector2: Vector(10) - table = LanceTable.create( - db, + table = mem_db.create_table( "my_table", schema=MyTable, ) @@ -969,23 +947,22 @@ def test_multiple_vector_columns(db): assert result1["text"].iloc[0] != result2["text"].iloc[0] -def test_create_scalar_index(db): +def test_create_scalar_index(mem_db: DBConnection): vec_array = pa.array( [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]], pa.list_(pa.float32(), 2) ) test_data = pa.Table.from_pydict( {"x": ["c", "b", "a", "e", "b"], "y": [1, 2, 3, 4, 5], "vector": vec_array} ) - table = LanceTable.create( - db, + table = mem_db.create_table( "my_table", data=test_data, ) table.create_scalar_index("x") - indices = table.to_lance().list_indices() + indices = table.list_indices() assert len(indices) == 1 scalar_index = indices[0] - assert scalar_index["type"] == "BTree" + assert scalar_index.index_type == "BTree" # Confirm that prefiltering still works with the scalar index column results = table.search().where("x = 'c'").to_arrow() @@ -996,9 +973,8 @@ def test_create_scalar_index(db): assert results["_distance"][0].as_py() > 0 -def test_empty_query(db): - table = LanceTable.create( - db, +def test_empty_query(mem_db: DBConnection): + table = mem_db.create_table( "my_table", data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}], ) @@ -1006,7 +982,7 @@ def test_empty_query(db): val = df.id.iloc[0] assert val == 1 - table = LanceTable.create(db, "my_table2", data=[{"id": i} for i in range(100)]) + table = mem_db.create_table("my_table2", data=[{"id": i} for i in range(100)]) df = table.search().select(["id"]).to_pandas() assert len(df) == 10 # None is the same as default @@ -1020,13 +996,12 @@ def test_empty_query(db): assert len(df) == 42 -def test_search_with_schema_inf_single_vector(db): +def test_search_with_schema_inf_single_vector(mem_db: DBConnection): class MyTable(LanceModel): text: str vector_col: Vector(10) - table = LanceTable.create( - db, + table = mem_db.create_table( "my_table", schema=MyTable, ) @@ -1047,14 +1022,13 @@ def test_search_with_schema_inf_single_vector(db): assert result1["text"].iloc[0] == result2["text"].iloc[0] -def test_search_with_schema_inf_multiple_vector(db): +def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection): class MyTable(LanceModel): text: str vector1: Vector(10) vector2: Vector(10) - table = LanceTable.create( - db, + table = mem_db.create_table( "my_table", schema=MyTable, ) @@ -1073,21 +1047,20 @@ def test_search_with_schema_inf_multiple_vector(db): table.search(q).limit(1).to_pandas() -def test_compact_cleanup(db): - table = LanceTable.create( - db, +def test_compact_cleanup(tmp_db: DBConnection): + table = tmp_db.create_table( "my_table", data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}], ) table.add([{"text": "baz", "id": 2}]) assert len(table) == 3 - assert table.version == 3 + assert table.version == 2 stats = table.compact_files() assert len(table) == 3 # Compact_files bump 2 versions. 
- assert table.version == 5 + assert table.version == 4 assert stats.fragments_removed > 0 assert stats.fragments_added == 1 @@ -1096,15 +1069,14 @@ def test_compact_cleanup(db): stats = table.cleanup_old_versions(older_than=timedelta(0), delete_unverified=True) assert stats.bytes_removed > 0 - assert table.version == 5 + assert table.version == 4 with pytest.raises(Exception, match="Version 3 no longer exists"): table.checkout(3) -def test_count_rows(db): - table = LanceTable.create( - db, +def test_count_rows(mem_db: DBConnection): + table = mem_db.create_table( "my_table", data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}], ) @@ -1113,8 +1085,7 @@ def test_count_rows(db): assert table.count_rows(filter="text='bar'") == 1 -def setup_hybrid_search_table(tmp_path, embedding_func): - db = MockDB(str(tmp_path)) +def setup_hybrid_search_table(db: DBConnection, embedding_func): # Create a LanceDB table schema with a vector and a text column emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func)() @@ -1123,8 +1094,7 @@ def setup_hybrid_search_table(tmp_path, embedding_func): vector: Vector(emb.ndims()) = emb.VectorField() # Initialize the table using the schema - table = LanceTable.create( - db, + table = db.create_table( "my_table", schema=MyTable, ) @@ -1152,11 +1122,11 @@ def setup_hybrid_search_table(tmp_path, embedding_func): return table, MyTable, emb -def test_hybrid_search(tmp_path): +def test_hybrid_search(tmp_db: DBConnection): # This test uses an FTS index pytest.importorskip("lancedb.fts") - table, MyTable, emb = setup_hybrid_search_table(tmp_path, "test") + table, MyTable, emb = setup_hybrid_search_table(tmp_db, "test") result1 = ( table.search("Our father who art in heaven", query_type="hybrid") @@ -1222,13 +1192,13 @@ def test_hybrid_search(tmp_path): table.search(query_type="hybrid").text("Arrrrggghhhhhhh").to_list() -def test_hybrid_search_metric_type(db, tmp_path): +def test_hybrid_search_metric_type(tmp_db: DBConnection): # This test uses an FTS index pytest.importorskip("lancedb.fts") # Need to use nonnorm as the embedding function so L2 and dot results # are different - table, _, _ = setup_hybrid_search_table(tmp_path, "nonnorm") + table, _, _ = setup_hybrid_search_table(tmp_db, "nonnorm") # with custom metric result_dot = ( @@ -1245,10 +1215,13 @@ def test_hybrid_search_metric_type(db, tmp_path): ) def test_consistency(tmp_path, consistency_interval): db = lancedb.connect(tmp_path) - table = LanceTable.create(db, "my_table", data=[{"id": 0}]) + table = db.create_table("my_table", data=[{"id": 0}]) db2 = lancedb.connect(tmp_path, read_consistency_interval=consistency_interval) table2 = db2.open_table("my_table") + if consistency_interval is not None: + assert "read_consistency_interval=datetime.timedelta(" in repr(db2) + assert "read_consistency_interval=datetime.timedelta(" in repr(table2) assert table2.version == table.version table.add([{"id": 1}]) @@ -1268,28 +1241,26 @@ def test_consistency(tmp_path, consistency_interval): def test_restore_consistency(tmp_path): db = lancedb.connect(tmp_path) - table = LanceTable.create(db, "my_table", data=[{"id": 0}]) + table = db.create_table("my_table", data=[{"id": 0}]) + assert table.version == 1 db2 = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0)) table2 = db2.open_table("my_table") assert table2.version == table.version # If we call checkout, it should lose consistency - table_fixed = copy(table2) - table_fixed.checkout(table.version) - # But if we call checkout_latest, it should be 
consistent again - table_ref_latest = copy(table_fixed) - table_ref_latest.checkout_latest() + table2.checkout(table.version) table.add([{"id": 2}]) - assert table_fixed.version == table.version - 1 - assert table_ref_latest.version == table.version + assert table2.version == 1 + # But if we call checkout_latest, it should be consistent again + table2.checkout_latest() + assert table2.version == table.version # Schema evolution -def test_add_columns(tmp_path): - db = lancedb.connect(tmp_path) +def test_add_columns(mem_db: DBConnection): data = pa.table({"id": [0, 1]}) - table = LanceTable.create(db, "my_table", data=data) + table = LanceTable.create(mem_db, "my_table", data=data) table.add_columns({"new_col": "id + 2"}) assert table.to_arrow().column_names == ["id", "new_col"] assert table.to_arrow()["new_col"].to_pylist() == [2, 3] @@ -1299,27 +1270,26 @@ def test_add_columns(tmp_path): @pytest.mark.asyncio -async def test_add_columns_async(db_async: AsyncConnection): +async def test_add_columns_async(mem_db_async: AsyncConnection): data = pa.table({"id": [0, 1]}) - table = await db_async.create_table("my_table", data=data) + table = await mem_db_async.create_table("my_table", data=data) await table.add_columns({"new_col": "id + 2"}) data = await table.to_arrow() assert data.column_names == ["id", "new_col"] assert data["new_col"].to_pylist() == [2, 3] -def test_alter_columns(tmp_path): - db = lancedb.connect(tmp_path) +def test_alter_columns(mem_db: DBConnection): data = pa.table({"id": [0, 1]}) - table = LanceTable.create(db, "my_table", data=data) + table = mem_db.create_table("my_table", data=data) table.alter_columns({"path": "id", "rename": "new_id"}) assert table.to_arrow().column_names == ["new_id"] @pytest.mark.asyncio -async def test_alter_columns_async(db_async: AsyncConnection): +async def test_alter_columns_async(mem_db_async: AsyncConnection): data = pa.table({"id": [0, 1]}) - table = await db_async.create_table("my_table", data=data) + table = await mem_db_async.create_table("my_table", data=data) await table.alter_columns({"path": "id", "rename": "new_id"}) assert (await table.to_arrow()).column_names == ["new_id"] await table.alter_columns(dict(path="new_id", data_type=pa.int16(), nullable=True)) @@ -1328,26 +1298,25 @@ async def test_alter_columns_async(db_async: AsyncConnection): assert data.schema.field(0).nullable -def test_drop_columns(tmp_path): - db = lancedb.connect(tmp_path) +def test_drop_columns(mem_db: DBConnection): data = pa.table({"id": [0, 1], "category": ["a", "b"]}) - table = LanceTable.create(db, "my_table", data=data) + table = mem_db.create_table("my_table", data=data) table.drop_columns(["category"]) assert table.to_arrow().column_names == ["id"] @pytest.mark.asyncio -async def test_drop_columns_async(db_async: AsyncConnection): +async def test_drop_columns_async(mem_db_async: AsyncConnection): data = pa.table({"id": [0, 1], "category": ["a", "b"]}) - table = await db_async.create_table("my_table", data=data) + table = await mem_db_async.create_table("my_table", data=data) await table.drop_columns(["category"]) assert (await table.to_arrow()).column_names == ["id"] @pytest.mark.asyncio -async def test_time_travel(db_async: AsyncConnection): +async def test_time_travel(mem_db_async: AsyncConnection): # Setup - table = await db_async.create_table("some_table", data=[{"id": 0}]) + table = await mem_db_async.create_table("some_table", data=[{"id": 0}]) version = await table.version() await table.add([{"id": 1}]) assert await table.count_rows() == 2 @@ 
-1378,9 +1347,8 @@ async def test_time_travel(db_async: AsyncConnection):
     await table.restore()
 
 
-def test_sync_optimize(db):
-    table = LanceTable.create(
-        db,
+def test_sync_optimize(mem_db: DBConnection):
+    table = mem_db.create_table(
         "test",
         data=[
             {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
@@ -1389,20 +1357,19 @@ def test_sync_optimize(db):
     )
 
     table.create_scalar_index("price", index_type="BTREE")
-    stats = table.to_lance().stats.index_stats("price_idx")
+    stats = table.index_stats("price_idx")
     assert stats["num_indexed_rows"] == 2
 
     table.add([{"vector": [2.0, 2.0], "item": "baz", "price": 30.0}])
     assert table.count_rows() == 3
 
     table.optimize()
-    stats = table.to_lance().stats.index_stats("price_idx")
+    stats = table.index_stats("price_idx")
     assert stats["num_indexed_rows"] == 3
 
 
 @pytest.mark.asyncio
-async def test_sync_optimize_in_async(db):
-    table = LanceTable.create(
-        db,
+async def test_sync_optimize_in_async(mem_db: DBConnection):
+    table = mem_db.create_table(
         "test",
         data=[
             {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
@@ -1411,24 +1378,17 @@ async def test_sync_optimize_in_async(db):
     )
 
     table.create_scalar_index("price", index_type="BTREE")
-    stats = table.to_lance().stats.index_stats("price_idx")
+    stats = table.index_stats("price_idx")
     assert stats["num_indexed_rows"] == 2
 
     table.add([{"vector": [2.0, 2.0], "item": "baz", "price": 30.0}])
     assert table.count_rows() == 3
 
-    try:
-        table.optimize()
-    except Exception as e:
-        assert (
-            "Synchronous method called in asynchronous context. "
-            "If you are writing an asynchronous application "
-            "then please use the asynchronous APIs" in str(e)
-        )
+    table.optimize()
 
 
 @pytest.mark.asyncio
-async def test_optimize(db_async: AsyncConnection):
-    table = await db_async.create_table(
+async def test_optimize(mem_db_async: AsyncConnection):
+    table = await mem_db_async.create_table(
         "test",
         data=[{"x": [1]}],
     )
@@ -1459,8 +1419,8 @@ async def test_optimize(db_async: AsyncConnection):
 
 
 @pytest.mark.asyncio
-async def test_optimize_delete_unverified(db_async: AsyncConnection, tmp_path):
-    table = await db_async.create_table(
+async def test_optimize_delete_unverified(tmp_db_async: AsyncConnection, tmp_path):
+    table = await tmp_db_async.create_table(
         "test",
         data=[{"x": [1]}],
     )
diff --git a/python/python/tests/utils.py b/python/python/tests/utils.py
new file mode 100644
index 00000000..62ec7449
--- /dev/null
+++ b/python/python/tests/utils.py
@@ -0,0 +1,11 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors
+import pytest
+
+
+def exception_output(e_info: pytest.ExceptionInfo):
+    import traceback
+
+    # skip traceback part, since it's not worth checking in tests
+    lines = traceback.format_exception_only(e_info.type, e_info.value)
+    return "".join(lines).strip()
diff --git a/python/src/connection.rs b/python/src/connection.rs
index 5648dfd9..d1263f24 100644
--- a/python/src/connection.rs
+++ b/python/src/connection.rs
@@ -58,6 +58,11 @@ impl Connection {
         self.inner.take();
     }
 
+    #[getter]
+    pub fn uri(&self) -> PyResult<String> {
+        self.get_inner().map(|inner| inner.uri().to_string())
+    }
+
     #[pyo3(signature = (start_after=None, limit=None))]
     pub fn table_names(
         self_: PyRef<'_, Self>,
diff --git a/python/src/index.rs b/python/src/index.rs
index 1e9ff260..be6c2269 100644
--- a/python/src/index.rs
+++ b/python/src/index.rs
@@ -12,224 +12,153 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-use std::sync::Mutex;
-
-use lancedb::index::scalar::FtsIndexBuilder;
-use lancedb::{
-    index::{
-        scalar::BTreeIndexBuilder,
-        vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
-        Index as LanceDbIndex,
-    },
-    DistanceType,
+use lancedb::index::{
+    scalar::{BTreeIndexBuilder, FtsIndexBuilder, TokenizerConfig},
+    vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
+    Index as LanceDbIndex,
 };
 use pyo3::{
-    exceptions::{PyKeyError, PyRuntimeError, PyValueError},
-    pyclass, pymethods, IntoPy, PyObject, PyResult, Python,
+    exceptions::{PyKeyError, PyValueError},
+    intern, pyclass, pymethods,
+    types::PyAnyMethods,
+    Bound, FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python,
 };
 
 use crate::util::parse_distance_type;
 
-#[pyclass]
-pub struct Index {
-    inner: Mutex<Option<LanceDbIndex>>,
-}
-
-impl Index {
-    pub fn consume(&self) -> PyResult<LanceDbIndex> {
-        self.inner
-            .lock()
-            .unwrap()
-            .take()
-            .ok_or_else(|| PyRuntimeError::new_err("cannot use an Index more than once"))
+pub fn class_name<'a>(ob: &'a Bound<'_, PyAny>) -> PyResult<&'a str> {
+    let full_name: &str = ob
+        .getattr(intern!(ob.py(), "__class__"))?
+        .getattr(intern!(ob.py(), "__name__"))?
+        .extract()?;
+    match full_name.rsplit_once('.') {
+        Some((_, name)) => Ok(name),
+        None => Ok(full_name),
     }
 }
 
-#[pymethods]
-impl Index {
-    #[pyo3(signature = (distance_type=None, num_partitions=None, num_sub_vectors=None, num_bits=None, max_iterations=None, sample_rate=None))]
-    #[staticmethod]
-    pub fn ivf_pq(
-        distance_type: Option<String>,
-        num_partitions: Option<u32>,
-        num_sub_vectors: Option<u32>,
-        num_bits: Option<u32>,
-        max_iterations: Option<u32>,
-        sample_rate: Option<u32>,
-    ) -> PyResult<Self> {
-        let mut ivf_pq_builder = IvfPqIndexBuilder::default();
-        if let Some(distance_type) = distance_type {
-            let distance_type = match distance_type.as_str() {
-                "l2" => Ok(DistanceType::L2),
-                "cosine" => Ok(DistanceType::Cosine),
-                "dot" => Ok(DistanceType::Dot),
-                _ => Err(PyValueError::new_err(format!(
-                    "Invalid distance type '{}'. Must be one of l2, cosine, or dot",
-                    distance_type
-                ))),
-            }?;
-            ivf_pq_builder = ivf_pq_builder.distance_type(distance_type);
+pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<LanceDbIndex> {
+    if let Some(source) = source {
+        match class_name(source)? {
+            "BTree" => Ok(LanceDbIndex::BTree(BTreeIndexBuilder::default())),
+            "Bitmap" => Ok(LanceDbIndex::Bitmap(Default::default())),
+            "LabelList" => Ok(LanceDbIndex::LabelList(Default::default())),
+            "FTS" => {
+                let params = source.extract::<FtsParams>()?;
+                let inner_opts = TokenizerConfig::default()
+                    .base_tokenizer(params.base_tokenizer)
+                    .language(&params.language)
+                    .map_err(|_| PyValueError::new_err(format!("LanceDB does not support the requested language: '{}'", params.language)))?
+                    .lower_case(params.lower_case)
+                    .max_token_length(params.max_token_length)
+                    .remove_stop_words(params.remove_stop_words)
+                    .stem(params.stem)
+                    .ascii_folding(params.ascii_folding);
+                let mut opts = FtsIndexBuilder::default()
+                    .with_position(params.with_position);
+                opts.tokenizer_configs = inner_opts;
+                Ok(LanceDbIndex::FTS(opts))
+            },
+            "IvfPq" => {
+                let params = source.extract::<IvfPqParams>()?;
+                let distance_type = parse_distance_type(params.distance_type)?;
+                let mut ivf_pq_builder = IvfPqIndexBuilder::default()
+                    .distance_type(distance_type)
+                    .max_iterations(params.max_iterations)
+                    .sample_rate(params.sample_rate)
+                    .num_bits(params.num_bits);
+                if let Some(num_partitions) = params.num_partitions {
+                    ivf_pq_builder = ivf_pq_builder.num_partitions(num_partitions);
+                }
+                if let Some(num_sub_vectors) = params.num_sub_vectors {
+                    ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
+                }
+                Ok(LanceDbIndex::IvfPq(ivf_pq_builder))
+            },
+            "HnswPq" => {
+                let params = source.extract::<IvfHnswPqParams>()?;
+                let distance_type = parse_distance_type(params.distance_type)?;
+                let mut hnsw_pq_builder = IvfHnswPqIndexBuilder::default()
+                    .distance_type(distance_type)
+                    .max_iterations(params.max_iterations)
+                    .sample_rate(params.sample_rate)
+                    .num_edges(params.m)
+                    .ef_construction(params.ef_construction)
+                    .num_bits(params.num_bits);
+                if let Some(num_partitions) = params.num_partitions {
+                    hnsw_pq_builder = hnsw_pq_builder.num_partitions(num_partitions);
+                }
+                if let Some(num_sub_vectors) = params.num_sub_vectors {
+                    hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
+                }
+                Ok(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))
+            },
+            "HnswSq" => {
+                let params = source.extract::<IvfHnswSqParams>()?;
+                let distance_type = parse_distance_type(params.distance_type)?;
+                let mut hnsw_sq_builder = IvfHnswSqIndexBuilder::default()
+                    .distance_type(distance_type)
+                    .max_iterations(params.max_iterations)
+                    .sample_rate(params.sample_rate)
+                    .num_edges(params.m)
+                    .ef_construction(params.ef_construction);
+                if let Some(num_partitions) = params.num_partitions {
+                    hnsw_sq_builder = hnsw_sq_builder.num_partitions(num_partitions);
+                }
+                Ok(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))
+            },
+            not_supported => Err(PyValueError::new_err(format!(
+                "Invalid index type '{}'. Must be one of BTree, Bitmap, LabelList, FTS, IvfPq, IvfHnswPq, or IvfHnswSq",
+                not_supported
+            ))),
         }
-        if let Some(num_partitions) = num_partitions {
-            ivf_pq_builder = ivf_pq_builder.num_partitions(num_partitions);
-        }
-        if let Some(num_sub_vectors) = num_sub_vectors {
-            ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
-        }
-        if let Some(num_bits) = num_bits {
-            ivf_pq_builder = ivf_pq_builder.num_bits(num_bits);
-        }
-        if let Some(max_iterations) = max_iterations {
-            ivf_pq_builder = ivf_pq_builder.max_iterations(max_iterations);
-        }
-        if let Some(sample_rate) = sample_rate {
-            ivf_pq_builder = ivf_pq_builder.sample_rate(sample_rate);
-        }
-        Ok(Self {
-            inner: Mutex::new(Some(LanceDbIndex::IvfPq(ivf_pq_builder))),
-        })
+    } else {
+        Ok(LanceDbIndex::Auto)
     }
+}
 
-    #[staticmethod]
-    pub fn btree() -> PyResult<Self> {
-        Ok(Self {
-            inner: Mutex::new(Some(LanceDbIndex::BTree(BTreeIndexBuilder::default()))),
-        })
-    }
+#[derive(FromPyObject)]
+struct FtsParams {
+    with_position: bool,
+    base_tokenizer: String,
+    language: String,
+    max_token_length: Option<usize>,
+    lower_case: bool,
+    stem: bool,
+    remove_stop_words: bool,
+    ascii_folding: bool,
+}
 
-    #[staticmethod]
-    pub fn bitmap() -> PyResult<Self> {
-        Ok(Self {
-            inner: Mutex::new(Some(LanceDbIndex::Bitmap(Default::default()))),
-        })
-    }
+#[derive(FromPyObject)]
+struct IvfPqParams {
+    distance_type: String,
+    num_partitions: Option<u32>,
+    num_sub_vectors: Option<u32>,
+    num_bits: u32,
+    max_iterations: u32,
+    sample_rate: u32,
+}
 
-    #[staticmethod]
-    pub fn label_list() -> PyResult<Self> {
-        Ok(Self {
-            inner: Mutex::new(Some(LanceDbIndex::LabelList(Default::default()))),
-        })
-    }
+#[derive(FromPyObject)]
+struct IvfHnswPqParams {
+    distance_type: String,
+    num_partitions: Option<u32>,
+    num_sub_vectors: Option<u32>,
+    num_bits: u32,
+    max_iterations: u32,
+    sample_rate: u32,
+    m: u32,
+    ef_construction: u32,
+}
 
-    #[pyo3(signature = (with_position=None, base_tokenizer=None, language=None, max_token_length=None, lower_case=None, stem=None, remove_stop_words=None, ascii_folding=None))]
-    #[allow(clippy::too_many_arguments)]
-    #[staticmethod]
-    pub fn fts(
-        with_position: Option<bool>,
-        base_tokenizer: Option<String>,
-        language: Option<String>,
-        max_token_length: Option<usize>,
-        lower_case: Option<bool>,
-        stem: Option<bool>,
-        remove_stop_words: Option<bool>,
-        ascii_folding: Option<bool>,
-    ) -> Self {
-        let mut opts = FtsIndexBuilder::default();
-        if let Some(with_position) = with_position {
-            opts = opts.with_position(with_position);
-        }
-        if let Some(base_tokenizer) = base_tokenizer {
-            opts.tokenizer_configs = opts.tokenizer_configs.base_tokenizer(base_tokenizer);
-        }
-        if let Some(language) = language {
-            opts.tokenizer_configs = opts.tokenizer_configs.language(&language).unwrap();
-        }
-        opts.tokenizer_configs = opts.tokenizer_configs.max_token_length(max_token_length);
-        if let Some(lower_case) = lower_case {
-            opts.tokenizer_configs = opts.tokenizer_configs.lower_case(lower_case);
-        }
-        if let Some(stem) = stem {
-            opts.tokenizer_configs = opts.tokenizer_configs.stem(stem);
-        }
-        if let Some(remove_stop_words) = remove_stop_words {
-            opts.tokenizer_configs = opts.tokenizer_configs.remove_stop_words(remove_stop_words);
-        }
-        if let Some(ascii_folding) = ascii_folding {
-            opts.tokenizer_configs = opts.tokenizer_configs.ascii_folding(ascii_folding);
-        }
-        Self {
-            inner: Mutex::new(Some(LanceDbIndex::FTS(opts))),
-        }
-    }
-
-    #[pyo3(signature = (distance_type=None, num_partitions=None, num_sub_vectors=None, num_bits=None, max_iterations=None, sample_rate=None, m=None, ef_construction=None))]
-    #[staticmethod]
-    #[allow(clippy::too_many_arguments)]
-    pub fn hnsw_pq(
-        distance_type: Option<String>,
-        num_partitions: Option<u32>,
-        num_sub_vectors: Option<u32>,
-        num_bits: Option<u32>,
-        max_iterations: Option<u32>,
-        sample_rate: Option<u32>,
-        m: Option<u32>,
-        ef_construction: Option<u32>,
-    ) -> PyResult<Self> {
-        let mut hnsw_pq_builder = IvfHnswPqIndexBuilder::default();
-        if let Some(distance_type) = distance_type {
-            let distance_type = parse_distance_type(distance_type)?;
-            hnsw_pq_builder = hnsw_pq_builder.distance_type(distance_type);
-        }
-        if let Some(num_partitions) = num_partitions {
-            hnsw_pq_builder = hnsw_pq_builder.num_partitions(num_partitions);
-        }
-        if let Some(num_sub_vectors) = num_sub_vectors {
-            hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
-        }
-        if let Some(num_bits) = num_bits {
-            hnsw_pq_builder = hnsw_pq_builder.num_bits(num_bits);
-        }
-        if let Some(max_iterations) = max_iterations {
-            hnsw_pq_builder = hnsw_pq_builder.max_iterations(max_iterations);
-        }
-        if let Some(sample_rate) = sample_rate {
-            hnsw_pq_builder = hnsw_pq_builder.sample_rate(sample_rate);
-        }
-        if let Some(m) = m {
-            hnsw_pq_builder = hnsw_pq_builder.num_edges(m);
-        }
-        if let Some(ef_construction) = ef_construction {
-            hnsw_pq_builder = hnsw_pq_builder.ef_construction(ef_construction);
-        }
-        Ok(Self {
-            inner: Mutex::new(Some(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))),
-        })
-    }
-
-    #[pyo3(signature = (distance_type=None, num_partitions=None, max_iterations=None, sample_rate=None, m=None, ef_construction=None))]
-    #[staticmethod]
-    pub fn hnsw_sq(
-        distance_type: Option<String>,
-        num_partitions: Option<u32>,
-        max_iterations: Option<u32>,
-        sample_rate: Option<u32>,
-        m: Option<u32>,
-        ef_construction: Option<u32>,
-    ) -> PyResult<Self> {
-        let mut hnsw_sq_builder = IvfHnswSqIndexBuilder::default();
-        if let Some(distance_type) = distance_type {
-            let distance_type = parse_distance_type(distance_type)?;
-            hnsw_sq_builder = hnsw_sq_builder.distance_type(distance_type);
-        }
-        if let Some(num_partitions) = num_partitions {
-            hnsw_sq_builder = hnsw_sq_builder.num_partitions(num_partitions);
-        }
-        if let Some(max_iterations) = max_iterations {
-            hnsw_sq_builder = hnsw_sq_builder.max_iterations(max_iterations);
-        }
-        if let Some(sample_rate) = sample_rate {
-            hnsw_sq_builder = hnsw_sq_builder.sample_rate(sample_rate);
-        }
-        if let Some(m) = m {
-            hnsw_sq_builder = hnsw_sq_builder.num_edges(m);
-        }
-        if let Some(ef_construction) = ef_construction {
-            hnsw_sq_builder = hnsw_sq_builder.ef_construction(ef_construction);
-        }
-        Ok(Self {
-            inner: Mutex::new(Some(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))),
-        })
-    }
+#[derive(FromPyObject)]
+struct IvfHnswSqParams {
+    distance_type: String,
+    num_partitions: Option<u32>,
+    max_iterations: u32,
+    sample_rate: u32,
+    m: u32,
+    ef_construction: u32,
 }
 
 #[pyclass(get_all)]
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 01e39cae..a68e7711 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -15,7 +15,7 @@
 use arrow::RecordBatchStream;
 use connection::{connect, Connection};
 use env_logger::Env;
-use index::{Index, IndexConfig};
+use index::IndexConfig;
 use pyo3::{
     pymodule,
     types::{PyModule, PyModuleMethods},
@@ -40,7 +40,6 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
     env_logger::init_from_env(env);
     m.add_class::<Connection>()?;
     m.add_class::<Table>()?;
-    m.add_class::<Index>()?;
     m.add_class::<IndexConfig>()?;
     m.add_class::<Query>()?;
     m.add_class::<RecordBatchStream>()?;
diff --git a/python/src/table.rs b/python/src/table.rs
index a5f446ec..c52f2a9f 100644
--- a/python/src/table.rs
+++ b/python/src/table.rs
@@ -19,7 +19,7 @@ use pyo3_async_runtimes::tokio::future_into_py;
 
 use crate::{
     error::PythonErrorExt,
-    index::{Index, IndexConfig},
+    index::{extract_index_params, IndexConfig},
     query::Query,
 };
 
@@ -177,14 +177,10 @@ impl Table {
     pub fn create_index<'a>(
         self_: PyRef<'a, Self>,
         column: String,
-        index: Option<&Index>,
+        index: Option<Bound<'_, PyAny>>,
         replace: Option<bool>,
     ) -> PyResult<Bound<'a, PyAny>> {
-        let index = if let Some(index) = index {
-            index.consume()?
-        } else {
-            lancedb::index::Index::Auto
-        };
+        let index = extract_index_params(&index)?;
         let mut op = self_.inner_ref()?.create_index(&[column], index);
         if let Some(replace) = replace {
             op = op.replace(replace);
diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs
index 990693b9..ebe29910 100644
--- a/rust/lancedb/src/connection.rs
+++ b/rust/lancedb/src/connection.rs
@@ -1050,6 +1050,8 @@ impl ConnectionInternal for Database {
         write_params.enable_v2_manifest_paths =
             options.enable_v2_manifest_paths.unwrap_or_default();
 
+        let data_schema = data.schema();
+
         match NativeTable::create(
             &table_uri,
             &options.name,
@@ -1069,7 +1071,18 @@ impl ConnectionInternal for Database {
                 CreateTableMode::ExistOk(callback) => {
                     let builder = OpenTableBuilder::new(options.parent, options.name);
                     let builder = (callback)(builder);
-                    builder.execute().await
+                    let table = builder.execute().await?;
+
+                    let table_schema = table.schema().await?;
+
+                    if table_schema != data_schema {
+                        return Err(Error::Schema {
+                            message: "Provided schema does not match existing table schema"
+                                .to_string(),
+                        });
+                    }
+
+                    Ok(table)
                 }
                 CreateTableMode::Overwrite => unreachable!(),
             },
diff --git a/rust/lancedb/src/index/scalar.rs b/rust/lancedb/src/index/scalar.rs
index 8003688a..b2f27c7d 100644
--- a/rust/lancedb/src/index/scalar.rs
+++ b/rust/lancedb/src/index/scalar.rs
@@ -77,5 +77,5 @@ impl FtsIndexBuilder {
     }
 }
 
-use lance_index::scalar::inverted::TokenizerConfig;
+pub use lance_index::scalar::inverted::TokenizerConfig;
 pub use lance_index::scalar::FullTextSearchQuery;
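
---

For reviewers, here is a minimal usage sketch of the dataclass-based index configs above. It assumes the dataclasses are exposed as `lancedb.index.BTree` and `lancedb.index.IvfPq` (the class names `extract_index_params` dispatches on) and that `memory://` URIs work with `connect_async`, as the new in-memory test fixtures suggest; the 512-row toy dataset exists only so PQ training has enough vectors:

```python
import asyncio
import random

import lancedb
from lancedb.index import BTree, IvfPq


async def main():
    db = await lancedb.connect_async("memory://")
    data = [
        {"vector": [random.random() for _ in range(16)], "price": float(i)}
        for i in range(512)
    ]
    table = await db.create_table("items", data=data)

    # Scalar index: pass a config dataclass instead of calling a builder method.
    await table.create_index("price", config=BTree())

    # Vector index: set only the fields you care about; anything left unset
    # keeps the data-size-aware defaults.
    await table.create_index(
        "vector", config=IvfPq(distance_type="cosine", num_partitions=2)
    )

    print(await table.list_indices())


asyncio.run(main())
```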