diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index c0edc098..4323faf7 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -30,10 +30,10 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: "3.11"
+          python-version: "3.12"
       - name: Install ruff
         run: |
-          pip install ruff==0.5.4
+          pip install ruff==0.8.4
       - name: Format check
         run: ruff format --check .
       - name: Lint
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 62197bcd..d556024e 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -53,8 +53,9 @@ tests = [
     "pytz",
     "polars>=0.19, <=1.3.0",
     "tantivy",
+    "pyarrow-stubs"
 ]
-dev = ["ruff", "pre-commit"]
+dev = ["ruff", "pre-commit", "pyright"]
 docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
 clip = ["torch", "pillow", "open-clip"]
 embeddings = [
@@ -94,3 +95,7 @@ markers = [
     "asyncio",
     "s3_test",
 ]
+
+[tool.pyright]
+include = ["python/lancedb/table.py"]
+pythonVersion = "3.12"
diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi
index 617bdddf..b431e3be 100644
--- a/python/python/lancedb/_lancedb.pyi
+++ b/python/python/lancedb/_lancedb.pyi
@@ -1,7 +1,9 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Any, Union, Literal
 
 import pyarrow as pa
 
+from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
+
 class Connection(object):
     uri: str
     async def table_names(
@@ -31,16 +33,35 @@ class Connection(object):
 
 class Table:
     def name(self) -> str: ...
     def __repr__(self) -> str: ...
+    def is_open(self) -> bool: ...
+    def close(self) -> None: ...
     async def schema(self) -> pa.Schema: ...
-    async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ...
+    async def add(
+        self, data: pa.RecordBatchReader, mode: Literal["append", "overwrite"]
+    ) -> None: ...
     async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ...
     async def count_rows(self, filter: Optional[str]) -> int: ...
-    async def create_index(self, column: str, config, replace: Optional[bool]): ...
+    async def create_index(
+        self,
+        column: str,
+        index: Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS],
+        replace: Optional[bool],
+    ): ...
+    async def list_versions(self) -> List[Dict[str, Any]]: ...
     async def version(self) -> int: ...
     async def checkout(self, version: int): ...
     async def checkout_latest(self): ...
     async def restore(self): ...
-    async def list_indices(self) -> List[IndexConfig]: ...
+    async def list_indices(self) -> list[IndexConfig]: ...
+    async def delete(self, filter: str): ...
+    async def add_columns(self, columns: list[tuple[str, str]]) -> None: ...
+    async def alter_columns(self, columns: list[dict[str, Any]]) -> None: ...
+    async def optimize(
+        self,
+        *,
+        cleanup_since_ms: Optional[int] = None,
+        delete_unverified: Optional[bool] = None,
+    ) -> OptimizeStats: ...
     def query(self) -> Query: ...
     def vector_search(self) -> VectorQuery: ...
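The stub changes above are what pyright will actually enforce, so a quick illustration of the tightened surface may help review. This is a sketch, not code from the PR: `tbl` stands for the internal `_lancedb.Table` handle, `reader` for a `pa.RecordBatchReader`, and the `BTree`/`IvfPq` constructor arguments are assumed from `lancedb.index`.

```python
import pyarrow as pa
from lancedb.index import BTree, IvfPq

async def exercise_stub(tbl, reader: pa.RecordBatchReader) -> None:
    # `mode` is now Literal["append", "overwrite"]; pyright flags anything else.
    await tbl.add(reader, mode="append")
    # await tbl.add(reader, mode="create")  # rejected by the new Literal type

    # `index` replaces the untyped `config` parameter with a Union of index configs.
    await tbl.create_index("id", BTree(), True)
    await tbl.create_index("vector", IvfPq(num_partitions=256), True)

    # `optimize` is keyword-only, mirroring #[pyo3(signature = ...)] in table.rs.
    await tbl.optimize(cleanup_since_ms=24 * 3600 * 1000, delete_unverified=False)
```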
diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py
index 278ed0b2..80ea876e 100644
--- a/python/python/lancedb/db.py
+++ b/python/python/lancedb/db.py
@@ -603,7 +603,7 @@ class AsyncConnection(object):
         fill_value: Optional[float] = None,
         storage_options: Optional[Dict[str, str]] = None,
         *,
-        embedding_functions: List[EmbeddingFunctionConfig] = None,
+        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
         data_storage_version: Optional[str] = None,
         use_legacy_format: Optional[bool] = None,
         enable_v2_manifest_paths: Optional[bool] = None,
diff --git a/python/python/lancedb/fts.py b/python/python/lancedb/fts.py
index 1b7adde5..ab954116 100644
--- a/python/python/lancedb/fts.py
+++ b/python/python/lancedb/fts.py
@@ -1,20 +1,10 @@
-# Copyright 2023 LanceDB Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors
 
 """Full text search index using tantivy-py"""
 
 import os
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 
 import pyarrow as pa
@@ -31,7 +21,7 @@ from .table import LanceTable
 def create_index(
     index_path: str,
     text_fields: List[str],
-    ordering_fields: List[str] = None,
+    ordering_fields: Optional[List[str]] = None,
     tokenizer_name: str = "default",
 ) -> tantivy.Index:
     """
@@ -75,8 +65,8 @@ def populate_index(
     index: tantivy.Index,
     table: LanceTable,
     fields: List[str],
-    writer_heap_size: int = 1024 * 1024 * 1024,
-    ordering_fields: List[str] = None,
+    writer_heap_size: Optional[int] = None,
+    ordering_fields: Optional[List[str]] = None,
 ) -> int:
     """
     Populate an index with data from a LanceTable
@@ -99,6 +89,7 @@
     """
     if ordering_fields is None:
         ordering_fields = []
+    writer_heap_size = writer_heap_size or 1024 * 1024 * 1024
 
     # first check the fields exist and are string or large string type
     nested = []
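A pattern worth noting in the `fts.py` hunks: sentinel defaults like `List[str] = None` are replaced by `Optional[...] = None` plus an explicit fallback inside the body, which keeps the signature honest under pyright without changing runtime behavior. A minimal standalone sketch of the same pattern (the names here are illustrative, not from this PR):

```python
from typing import List, Optional

DEFAULT_HEAP_SIZE = 1024 * 1024 * 1024  # 1 GiB, mirroring populate_index's old default

def build(fields: Optional[List[str]] = None, heap_size: Optional[int] = None) -> int:
    # Resolve the sentinels at the top of the body, as the PR does in populate_index.
    fields = fields if fields is not None else []
    heap_size = heap_size or DEFAULT_HEAP_SIZE
    return heap_size // max(len(fields), 1)

assert build() == DEFAULT_HEAP_SIZE  # callers relying on the old default still get it
```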
diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py
index 0ef77d54..59f10107 100644
--- a/python/python/lancedb/table.py
+++ b/python/python/lancedb/table.py
@@ -61,11 +61,12 @@ from .index import lang_mapping
 
 if TYPE_CHECKING:
-    import PIL
-    from lance.dataset import CleanupStats, ReaderLike
-    from ._lancedb import Table as LanceDBTable, OptimizeStats
+    from ._lancedb import Table as LanceDBTable, OptimizeStats, CompactionStats
     from .db import LanceDBConnection
     from .index import IndexConfig
+    from lance.dataset import CleanupStats, ReaderLike
+    import pandas
+    import PIL
 
 pd = safe_import_pandas()
 pl = safe_import_polars()
@@ -84,7 +85,6 @@ def _pd_schema_without_embedding_funcs(
     )
     if not embedding_functions:
         return schema
-    columns = set(columns)
     return pa.schema([field for field in schema if field.name in columns])
@@ -119,7 +119,7 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
             return pa.Table.from_batches(data, schema=schema)
         else:
             return pa.Table.from_pylist(data, schema=schema)
-    elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
+    elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):  # type: ignore
         raw_schema = _pd_schema_without_embedding_funcs(schema, data.columns.to_list())
         table = pa.Table.from_pandas(data, preserve_index=False, schema=raw_schema)
         # Do not serialize Pandas metadata
@@ -160,7 +160,7 @@ def _sanitize_data(
     metadata: Optional[dict] = None,  # embedding metadata
     on_bad_vectors: str = "error",
     fill_value: float = 0.0,
-):
+) -> Tuple[pa.Table, pa.Schema]:
     data = _coerce_to_table(data, schema)
 
     if metadata:
@@ -178,13 +178,17 @@
 def sanitize_create_table(
-    data, schema, metadata=None, on_bad_vectors="error", fill_value=0.0
+    data,
+    schema: Union[pa.Schema, LanceModel],
+    metadata=None,
+    on_bad_vectors: str = "error",
+    fill_value: float = 0.0,
 ):
     if inspect.isclass(schema) and issubclass(schema, LanceModel):
         # convert LanceModel to pyarrow schema
         # note that it's possible this contains
         # embedding function metadata already
-        schema = schema.to_arrow_schema()
+        schema: pa.Schema = schema.to_arrow_schema()
 
     if data is not None:
         if metadata is None and schema is not None:
@@ -272,41 +276,6 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schema]
     return data
 
 
-def _generator_to_data_and_schema(
-    data: Iterable,
-) -> Tuple[Iterable[pa.RecordBatch], pa.Schema]:
-    def _with_first_generator(first, data):
-        yield first
-        yield from data
-
-    first = next(data, None)
-    schema = None
-    if isinstance(first, pa.RecordBatch):
-        schema = first.schema
-        data = _with_first_generator(first, data)
-    elif isinstance(first, pa.Table):
-        schema = first.schema
-        data = _with_first_generator(first.to_batches(), data)
-    return data, schema
-
-
-def _to_record_batch_generator(
-    data: Iterable,
-    schema,
-    metadata,
-    on_bad_vectors,
-    fill_value,
-):
-    for batch in data:
-        # always convert to table because we need to sanitize the data
-        # and do things like add the vector column etc
-        if isinstance(batch, pa.RecordBatch):
-            batch = pa.Table.from_batches([batch])
-        batch, _ = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
-        for b in batch.to_batches():
-            yield b
-
-
 def _table_path(base: str, table_name: str) -> str:
     """
     Get a table path that can be used in PyArrow FS.
@@ -404,7 +373,7 @@ class Table(ABC):
         """
         raise NotImplementedError
 
-    def to_pandas(self) -> "pd.DataFrame":
+    def to_pandas(self) -> "pandas.DataFrame":
         """Return the table as a pandas DataFrame.
 
         Returns
@@ -537,8 +506,8 @@ class Table(ABC):
     def create_fts_index(
         self,
         field_names: Union[str, List[str]],
-        ordering_field_names: Union[str, List[str]] = None,
         *,
+        ordering_field_names: Optional[Union[str, List[str]]] = None,
         replace: bool = False,
         writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
         use_tantivy: bool = True,
@@ -790,8 +759,7 @@ class Table(ABC):
     @abstractmethod
     def _execute_query(
         self, query: Query, batch_size: Optional[int] = None
-    ) -> pa.RecordBatchReader:
-        pass
+    ) -> pa.RecordBatchReader: ...
 
     @abstractmethod
     def _do_merge(
@@ -800,8 +768,7 @@
         self,
         new_data: DATA,
         on_bad_vectors: str,
         fill_value: float,
-    ):
-        pass
+    ): ...
 
     @abstractmethod
     def delete(self, where: str):
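For context on the `sanitize_create_table` hunk above: `schema` may now arrive either as a ready-made `pa.Schema` or as a `LanceModel` subclass, which the annotated assignment narrows to `pa.Schema` via `to_arrow_schema()`. A hedged sketch of the two forms the new `Union` annotation covers (the `Words` model is hypothetical, in the style of LanceDB's pydantic integration):

```python
import pyarrow as pa
from lancedb.pydantic import LanceModel, Vector

class Words(LanceModel):  # hypothetical model, for illustration only
    text: str
    vector: Vector(2)

# Both forms satisfy the new Union[pa.Schema, LanceModel] annotation; the
# function narrows the model form to a pa.Schema with the annotated assignment.
arrow_schema: pa.Schema = Words.to_arrow_schema()
assert isinstance(arrow_schema, pa.Schema)
```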
@@ -1121,7 +1088,7 @@ class Table(ABC):
         """
 
     @abstractmethod
-    def list_versions(self):
+    def list_versions(self) -> List[Dict[str, Any]]:
         """List all versions of the table"""
 
     @cached_property
@@ -1244,7 +1211,7 @@ class LanceTable(Table):
         A PyArrow schema object."""
         return LOOP.run(self._table.schema())
 
-    def list_versions(self):
+    def list_versions(self) -> List[Dict[str, Any]]:
         """List all versions of the table"""
         return LOOP.run(self._table.list_versions())
@@ -1297,7 +1264,7 @@ class LanceTable(Table):
         """
         LOOP.run(self._table.checkout_latest())
 
-    def restore(self, version: int = None):
+    def restore(self, version: Optional[int] = None):
         """Restore a version of the table. This is an in-place operation.
 
         This creates a new version where the data is equivalent to the
@@ -1338,7 +1305,7 @@ class LanceTable(Table):
     def count_rows(self, filter: Optional[str] = None) -> int:
         return LOOP.run(self._table.count_rows(filter))
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.count_rows()
 
     def __repr__(self) -> str:
@@ -1506,8 +1473,8 @@ class LanceTable(Table):
     def create_fts_index(
         self,
         field_names: Union[str, List[str]],
-        ordering_field_names: Union[str, List[str]] = None,
         *,
+        ordering_field_names: Optional[Union[str, List[str]]] = None,
         replace: bool = False,
         writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
         use_tantivy: bool = True,
@@ -1594,6 +1561,7 @@ class LanceTable(Table):
             writer_heap_size=writer_heap_size,
         )
 
+    @staticmethod
     def infer_tokenizer_configs(tokenizer_name: str) -> dict:
         if tokenizer_name == "default":
             return {
@@ -1759,7 +1727,7 @@ class LanceTable(Table):
         )
 
     @overload
-    def search(
+    def search(  # type: ignore
         self,
         query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
         vector_column_name: Optional[str] = None,
@@ -1895,11 +1863,11 @@ class LanceTable(Table):
         name: str,
         data: Optional[DATA] = None,
         schema: Optional[pa.Schema] = None,
-        mode: Literal["create", "overwrite", "append"] = "create",
+        mode: Literal["create", "overwrite"] = "create",
         exist_ok: bool = False,
         on_bad_vectors: str = "error",
         fill_value: float = 0.0,
-        embedding_functions: List[EmbeddingFunctionConfig] = None,
+        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
         *,
         storage_options: Optional[Dict[str, str]] = None,
         data_storage_version: Optional[str] = None,
@@ -2065,7 +2033,7 @@ class LanceTable(Table):
             older_than, delete_unverified=delete_unverified
         )
 
-    def compact_files(self, *args, **kwargs):
+    def compact_files(self, *args, **kwargs) -> CompactionStats:
         """
         Run the compaction process on the table.
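One caller-visible effect of the `create_fts_index` hunks above: `ordering_field_names` moved behind the bare `*` on both `Table` and `LanceTable`, so it is now keyword-only. A sketch of what that means at a call site (`table` and the column names are placeholders, not taken from the PR's tests):

```python
# The second positional slot no longer exists, so passing ordering fields
# positionally now raises a TypeError (and pyright flags it):
# table.create_fts_index("text", "rank")

# The keyword form is valid both before and after this change:
table.create_fts_index("text", ordering_field_names="rank", use_tantivy=True)
```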
@@ -2450,7 +2418,7 @@ def _process_iterator(data: Iterable, schema: Optional[pa.Schema] = None) -> pa.Table:
         if batch_table.schema != schema:
             try:
                 batch_table = batch_table.cast(schema)
-            except pa.lib.ArrowInvalid:
+            except pa.lib.ArrowInvalid:  # type: ignore
                 raise ValueError(
                     f"Input iterator yielded a batch with schema that "
                     f"does not match the expected schema.\nExpected:\n{schema}\n"
@@ -2710,16 +2678,17 @@ class AsyncTable:
             on_bad_vectors = "error"
         if fill_value is None:
             fill_value = 0.0
-        data, _ = _sanitize_data(
+        table_and_schema: Tuple[pa.Table, pa.Schema] = _sanitize_data(
             data,
             schema,
             metadata=schema.metadata,
             on_bad_vectors=on_bad_vectors,
             fill_value=fill_value,
         )
-        if isinstance(data, pa.Table):
-            data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
-        await self._inner.add(data, mode)
+        tbl, schema = table_and_schema
+        if isinstance(tbl, pa.Table):
+            data = pa.RecordBatchReader.from_batches(schema, tbl.to_batches())
+        await self._inner.add(data, mode or "append")
@@ -2977,7 +2946,7 @@ class AsyncTable:
 
         return await self._inner.update(updates_sql, where)
 
-    async def add_columns(self, transforms: Dict[str, str]):
+    async def add_columns(self, transforms: dict[str, str]):
         """
         Add new columns with defined values.
@@ -2990,7 +2959,7 @@
         """
         await self._inner.add_columns(list(transforms.items()))
 
-    async def alter_columns(self, *alterations: Iterable[Dict[str, str]]):
+    async def alter_columns(self, *alterations: Iterable[dict[str, Any]]):
         """
         Alter column names and nullability.
@@ -3148,9 +3117,12 @@ class AsyncTable:
         you have added or modified 100,000 or more records or run more than
         20 data modification operations.
         """
+        cleanup_since_ms: Optional[int] = None
         if cleanup_older_than is not None:
-            cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000)
-        return await self._inner.optimize(cleanup_older_than, delete_unverified)
+            cleanup_since_ms = round(cleanup_older_than.total_seconds() * 1000)
+        return await self._inner.optimize(
+            cleanup_since_ms=cleanup_since_ms, delete_unverified=delete_unverified
+        )
 
     async def list_indices(self) -> Iterable[IndexConfig]:
         """
diff --git a/python/src/table.rs b/python/src/table.rs
index c52f2a9f..59ac29ce 100644
--- a/python/src/table.rs
+++ b/python/src/table.rs
@@ -97,10 +97,12 @@ impl Table {
         self.name.clone()
     }
 
+    /// Returns True if the table is open, False if it is closed.
     pub fn is_open(&self) -> bool {
         self.inner.is_some()
     }
 
+    /// Closes the table, releasing any resources associated with it.
     pub fn close(&mut self) {
         self.inner.take();
     }
@@ -301,6 +303,7 @@ impl Table {
         Query::new(self.inner_ref().unwrap().query())
     }
 
+    /// Optimize the on-disk data by compacting and pruning old data, for better performance.
    #[pyo3(signature = (cleanup_since_ms=None, delete_unverified=None))]
    pub fn optimize(
        self_: PyRef<'_, Self>,
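Tying the Python and Rust ends of the `optimize` change together: the public `AsyncTable.optimize` still accepts a `timedelta`, converts it to milliseconds in a separate `cleanup_since_ms` variable instead of rebinding the `timedelta` parameter, and forwards it by keyword to match the `#[pyo3(signature = ...)]` attribute above. A hedged usage sketch (the URI and table name are placeholders):

```python
import asyncio
from datetime import timedelta

import lancedb

async def main() -> None:
    db = await lancedb.connect_async("data/example-db")  # placeholder URI
    table = await db.open_table("my_table")              # placeholder table name
    # timedelta(days=7) becomes round(7 * 24 * 3600 * 1000) ms on the wire.
    stats = await table.optimize(
        cleanup_older_than=timedelta(days=7),
        delete_unverified=False,
    )
    print(stats)

asyncio.run(main())
```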