chore: add pyright static type checking and fix some of the table interface (#1996)

* Enable `pyright` in the project
* Fixed some pyright typing errors in `table.py`
Lei Xu authored 2025-01-04 15:24:58 -08:00, committed by GitHub
parent 164ce397c2
commit f76c4a5ce1
7 changed files with 83 additions and 91 deletions


@@ -1,7 +1,9 @@
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Any, Union, Literal
import pyarrow as pa
from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
class Connection(object):
uri: str
async def table_names(
@@ -31,16 +33,35 @@ class Connection(object):
class Table:
def name(self) -> str: ...
def __repr__(self) -> str: ...
def is_open(self) -> bool: ...
def close(self) -> None: ...
async def schema(self) -> pa.Schema: ...
async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ...
async def add(
self, data: pa.RecordBatchReader, mode: Literal["append", "overwrite"]
) -> None: ...
async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ...
async def count_rows(self, filter: Optional[str]) -> int: ...
async def create_index(self, column: str, config, replace: Optional[bool]): ...
async def create_index(
self,
column: str,
index: Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS],
replace: Optional[bool],
): ...
async def list_versions(self) -> List[Dict[str, Any]]: ...
async def version(self) -> int: ...
async def checkout(self, version: int): ...
async def checkout_latest(self): ...
async def restore(self): ...
async def list_indices(self) -> List[IndexConfig]: ...
async def list_indices(self) -> list[IndexConfig]: ...
async def delete(self, filter: str): ...
async def add_columns(self, columns: list[tuple[str, str]]) -> None: ...
async def alter_columns(self, columns: list[dict[str, Any]]) -> None: ...
async def optimize(
self,
*,
cleanup_since_ms: Optional[int] = None,
delete_unverified: Optional[bool] = None,
) -> OptimizeStats: ...
def query(self) -> Query: ...
def vector_search(self) -> VectorQuery: ...
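The `Literal["append", "overwrite"]` annotation on `add` is what lets pyright reject a misspelled mode at the call site rather than at runtime. A minimal standalone sketch of the pattern (a toy `add`, not the actual stub):

    from typing import Literal

    def add(data: object, mode: Literal["append", "overwrite"] = "append") -> None:
        # Toy stand-in for Table.add; the body is irrelevant to the check.
        pass

    add(None, mode="overwrite")   # accepted: a valid Literal member
    # add(None, mode="apend")     # pyright error: "apend" is not assignable
    #                             # to Literal["append", "overwrite"]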


@@ -603,7 +603,7 @@ class AsyncConnection(object):
fill_value: Optional[float] = None,
storage_options: Optional[Dict[str, str]] = None,
*,
embedding_functions: List[EmbeddingFunctionConfig] = None,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
data_storage_version: Optional[str] = None,
use_legacy_format: Optional[bool] = None,
enable_v2_manifest_paths: Optional[bool] = None,
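This one-line change, repeated across several signatures in the commit, is driven by pyright disallowing implicit `Optional`: a `None` default does not make a `List[...]` parameter accept `None`. A minimal sketch:

    from typing import List, Optional

    # def create(fns: List[str] = None): ...   # pyright: None is not a List[str]

    def create(fns: Optional[List[str]] = None) -> int:
        # The explicit Optional makes the None default well-typed.
        return len(fns or [])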


@@ -1,20 +1,10 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""Full text search index using tantivy-py"""
import os
from typing import List, Tuple
from typing import List, Tuple, Optional
import pyarrow as pa
@@ -31,7 +21,7 @@ from .table import LanceTable
def create_index(
index_path: str,
text_fields: List[str],
ordering_fields: List[str] = None,
ordering_fields: Optional[List[str]] = None,
tokenizer_name: str = "default",
) -> tantivy.Index:
"""
@@ -75,8 +65,8 @@ def populate_index(
index: tantivy.Index,
table: LanceTable,
fields: List[str],
writer_heap_size: int = 1024 * 1024 * 1024,
ordering_fields: List[str] = None,
writer_heap_size: Optional[int] = None,
ordering_fields: Optional[List[str]] = None,
) -> int:
"""
Populate an index with data from a LanceTable
@@ -99,6 +89,7 @@ def populate_index(
"""
if ordering_fields is None:
ordering_fields = []
writer_heap_size = writer_heap_size or 1024 * 1024 * 1024
# first check the fields exist and are string or large string type
nested = []
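Moving the 1 GiB default out of the signature and into the body is the same implicit-`Optional` fix: the annotation becomes `Optional[int]` and the concrete default is resolved up front. A sketch of the idiom:

    from typing import Optional

    def populate(writer_heap_size: Optional[int] = None) -> int:
        # Resolve the default in the body; callers passing None still get 1 GiB.
        writer_heap_size = writer_heap_size or 1024 * 1024 * 1024
        return writer_heap_size

One wrinkle: `x or default` also replaces an explicit `0`; `if writer_heap_size is None:` would be the stricter spelling if zero were a meaningful value (it is not for a heap size).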


@@ -61,11 +61,12 @@ from .index import lang_mapping
if TYPE_CHECKING:
import PIL
from lance.dataset import CleanupStats, ReaderLike
from ._lancedb import Table as LanceDBTable, OptimizeStats
from ._lancedb import Table as LanceDBTable, OptimizeStats, CompactionStats
from .db import LanceDBConnection
from .index import IndexConfig
from lance.dataset import CleanupStats, ReaderLike
import pandas
import PIL
pd = safe_import_pandas()
pl = safe_import_polars()
@@ -84,7 +85,6 @@ def _pd_schema_without_embedding_funcs(
)
if not embedding_functions:
return schema
columns = set(columns)
return pa.schema([field for field in schema if field.name in columns])
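Dropping `columns = set(columns)` avoids rebinding a `List[str]` parameter to a `set`, a mid-function type change that pyright flags against the declared annotation; list membership is equivalent here, just linear. Where the set lookup is worth keeping, binding a new name preserves both the speed and the declared types, roughly:

    from typing import List

    import pyarrow as pa

    def keep_columns(schema: pa.Schema, columns: List[str]) -> pa.Schema:
        # Bind the set to a new name so `columns` keeps its declared List type.
        wanted = set(columns)
        return pa.schema([field for field in schema if field.name in wanted])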
@@ -119,7 +119,7 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
return pa.Table.from_batches(data, schema=schema)
else:
return pa.Table.from_pylist(data, schema=schema)
elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
elif _check_for_pandas(data) and isinstance(data, pd.DataFrame): # type: ignore
raw_schema = _pd_schema_without_embedding_funcs(schema, data.columns.to_list())
table = pa.Table.from_pandas(data, preserve_index=False, schema=raw_schema)
# Do not serialize Pandas metadata
@@ -160,7 +160,7 @@ def _sanitize_data(
metadata: Optional[dict] = None, # embedding metadata
on_bad_vectors: str = "error",
fill_value: float = 0.0,
):
) -> Tuple[pa.Table, pa.Schema]:
data = _coerce_to_table(data, schema)
if metadata:
@@ -178,13 +178,17 @@ def _sanitize_data(
def sanitize_create_table(
data, schema, metadata=None, on_bad_vectors="error", fill_value=0.0
data,
schema: Union[pa.Schema, LanceModel],
metadata=None,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
):
if inspect.isclass(schema) and issubclass(schema, LanceModel):
# convert LanceModel to pyarrow schema
# note that it's possible this contains
# embedding function metadata already
schema = schema.to_arrow_schema()
schema: pa.Schema = schema.to_arrow_schema()
if data is not None:
if metadata is None and schema is not None:
@@ -272,41 +276,6 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
return data
def _generator_to_data_and_schema(
data: Iterable,
) -> Tuple[Iterable[pa.RecordBatch], pa.Schema]:
def _with_first_generator(first, data):
yield first
yield from data
first = next(data, None)
schema = None
if isinstance(first, pa.RecordBatch):
schema = first.schema
data = _with_first_generator(first, data)
elif isinstance(first, pa.Table):
schema = first.schema
data = _with_first_generator(first.to_batches(), data)
return data, schema
def _to_record_batch_generator(
data: Iterable,
schema,
metadata,
on_bad_vectors,
fill_value,
):
for batch in data:
# always convert to table because we need to sanitize the data
# and do things like add the vector column etc
if isinstance(batch, pa.RecordBatch):
batch = pa.Table.from_batches([batch])
batch, _ = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
for b in batch.to_batches():
yield b
def _table_path(base: str, table_name: str) -> str:
"""
Get a table path that can be used in PyArrow FS.
@@ -404,7 +373,7 @@ class Table(ABC):
"""
raise NotImplementedError
def to_pandas(self) -> "pd.DataFrame":
def to_pandas(self) -> "pandas.DataFrame":
"""Return the table as a pandas DataFrame.
Returns
@@ -537,8 +506,8 @@ class Table(ABC):
def create_fts_index(
self,
field_names: Union[str, List[str]],
ordering_field_names: Union[str, List[str]] = None,
*,
ordering_field_names: Optional[Union[str, List[str]]] = None,
replace: bool = False,
writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
use_tantivy: bool = True,
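The bare `*` inserted before `ordering_field_names` makes every parameter after it keyword-only, so these options can no longer be passed positionally; keyword call sites are unaffected. The mechanics in miniature:

    def create_fts_index(field_names, *, ordering_field_names=None, replace=False):
        pass

    create_fts_index("text", replace=True)       # OK: keyword argument
    # create_fts_index("text", ["rank"], True)   # TypeError: takes 1 positional
    #                                            # argument but 3 were given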
@@ -790,8 +759,7 @@ class Table(ABC):
@abstractmethod
def _execute_query(
self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader:
pass
) -> pa.RecordBatchReader: ...
@abstractmethod
def _do_merge(
@@ -800,8 +768,7 @@ class Table(ABC):
new_data: DATA,
on_bad_vectors: str,
fill_value: float,
):
pass
): ...
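Replacing `pass` with `...` in these abstract methods is stub convention: an ellipsis body reads as a declaration with no implementation, matching the `.pyi` style used elsewhere in this commit, rather than an empty body that silently returns `None`. In miniature:

    from abc import ABC, abstractmethod

    import pyarrow as pa

    class BaseTable(ABC):
        @abstractmethod
        def _execute_query(self) -> pa.RecordBatchReader: ...
        # `...` reads as "declaration only"; concrete subclasses supply the body.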
@abstractmethod
def delete(self, where: str):
@@ -1121,7 +1088,7 @@ class Table(ABC):
"""
@abstractmethod
def list_versions(self):
def list_versions(self) -> List[Dict[str, Any]]:
"""List all versions of the table"""
@cached_property
@@ -1244,7 +1211,7 @@ class LanceTable(Table):
A PyArrow schema object."""
return LOOP.run(self._table.schema())
def list_versions(self):
def list_versions(self) -> List[Dict[str, Any]]:
"""List all versions of the table"""
return LOOP.run(self._table.list_versions())
@@ -1297,7 +1264,7 @@ class LanceTable(Table):
"""
LOOP.run(self._table.checkout_latest())
def restore(self, version: int = None):
def restore(self, version: Optional[int] = None):
"""Restore a version of the table. This is an in-place operation.
This creates a new version where the data is equivalent to the
@@ -1338,7 +1305,7 @@ class LanceTable(Table):
def count_rows(self, filter: Optional[str] = None) -> int:
return LOOP.run(self._table.count_rows(filter))
def __len__(self):
def __len__(self) -> int:
return self.count_rows()
def __repr__(self) -> str:
@@ -1506,8 +1473,8 @@ class LanceTable(Table):
def create_fts_index(
self,
field_names: Union[str, List[str]],
ordering_field_names: Union[str, List[str]] = None,
*,
ordering_field_names: Optional[Union[str, List[str]]] = None,
replace: bool = False,
writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
use_tantivy: bool = True,
@@ -1594,6 +1561,7 @@ class LanceTable(Table):
writer_heap_size=writer_heap_size,
)
@staticmethod
def infer_tokenizer_configs(tokenizer_name: str) -> dict:
if tokenizer_name == "default":
return {
@@ -1759,7 +1727,7 @@ class LanceTable(Table):
)
@overload
def search(
def search( # type: ignore
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
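`search` is declared with `@overload` so the return type can follow the query argument; the added `# type: ignore` concedes that the overload set and the implementation signature do not line up perfectly for pyright. The general shape of the pattern (toy types, not the LanceDB signatures):

    from typing import Union, overload

    @overload
    def search(query: str) -> list[str]: ...
    @overload
    def search(query: bytes) -> list[bytes]: ...

    def search(query: Union[str, bytes]):
        # One implementation; the checker matches each call site to an overload.
        return [query]

    print(search("hello"))  # ['hello']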
@@ -1895,11 +1863,11 @@ class LanceTable(Table):
name: str,
data: Optional[DATA] = None,
schema: Optional[pa.Schema] = None,
mode: Literal["create", "overwrite", "append"] = "create",
mode: Literal["create", "overwrite"] = "create",
exist_ok: bool = False,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: List[EmbeddingFunctionConfig] = None,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
*,
storage_options: Optional[Dict[str, str]] = None,
data_storage_version: Optional[str] = None,
@@ -2065,7 +2033,7 @@ class LanceTable(Table):
older_than, delete_unverified=delete_unverified
)
def compact_files(self, *args, **kwargs):
def compact_files(self, *args, **kwargs) -> CompactionStats:
"""
Run the compaction process on the table.
@@ -2450,7 +2418,7 @@ def _process_iterator(data: Iterable, schema: Optional[pa.Schema] = None) -> pa.
if batch_table.schema != schema:
try:
batch_table = batch_table.cast(schema)
except pa.lib.ArrowInvalid:
except pa.lib.ArrowInvalid: # type: ignore
raise ValueError(
f"Input iterator yielded a batch with schema that "
f"does not match the expected schema.\nExpected:\n{schema}\n"
@@ -2710,16 +2678,17 @@ class AsyncTable:
on_bad_vectors = "error"
if fill_value is None:
fill_value = 0.0
data, _ = _sanitize_data(
table_and_schema: Tuple[pa.Table, pa.Schema] = _sanitize_data(
data,
schema,
metadata=schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
if isinstance(data, pa.Table):
data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
await self._inner.add(data, mode)
tbl, schema = table_and_schema
if isinstance(tbl, pa.Table):
data = pa.RecordBatchReader.from_batches(schema, tbl.to_batches())
await self._inner.add(data, mode or "append")
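The rewrite unpacks `_sanitize_data`'s `(table, schema)` result into named variables so pyright can track each type, and falls back to `"append"` when `mode` is `None`, matching the `Literal["append", "overwrite"]` stub above. The Table-to-reader re-wrap it performs:

    import pyarrow as pa

    tbl = pa.table({"x": [1, 2, 3]})
    # Re-wrap the in-memory Table as the streaming reader the native add() expects.
    reader = pa.RecordBatchReader.from_batches(tbl.schema, tbl.to_batches())
    print(reader.read_all().num_rows)  # 3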
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
"""
@@ -2977,7 +2946,7 @@ class AsyncTable:
return await self._inner.update(updates_sql, where)
async def add_columns(self, transforms: Dict[str, str]):
async def add_columns(self, transforms: dict[str, str]):
"""
Add new columns with defined values.
@@ -2990,7 +2959,7 @@ class AsyncTable:
"""
await self._inner.add_columns(list(transforms.items()))
async def alter_columns(self, *alterations: Iterable[Dict[str, str]]):
async def alter_columns(self, *alterations: Iterable[dict[str, Any]]):
"""
Alter column names and nullability.
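Both `add_columns` and `alter_columns` switch from `typing.Dict` to the PEP 585 builtin generics (`dict[str, str]`, `dict[str, Any]`), which are valid annotations on Python 3.9+. The conversion `add_columns` applies to its argument is a one-liner; a toy helper for illustration:

    def to_column_pairs(transforms: dict[str, str]) -> list[tuple[str, str]]:
        # dict.items() yields (column name, SQL expression) pairs in order.
        return list(transforms.items())

    print(to_column_pairs({"double_x": "x * 2"}))  # [('double_x', 'x * 2')]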
@@ -3148,9 +3117,12 @@ class AsyncTable:
you have added or modified 100,000 or more records or run more than 20 data
modification operations.
"""
cleanup_since_ms: Optional[int] = None
if cleanup_older_than is not None:
cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000)
return await self._inner.optimize(cleanup_older_than, delete_unverified)
cleanup_since_ms = round(cleanup_older_than.total_seconds() * 1000)
return await self._inner.optimize(
cleanup_since_ms=cleanup_since_ms, delete_unverified=delete_unverified
)
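The fix introduces a dedicated `cleanup_since_ms` variable instead of reassigning the `timedelta`-typed `cleanup_older_than` parameter to an `int`, so each name keeps a single type. The conversion itself:

    from datetime import timedelta

    cleanup_older_than = timedelta(days=7)
    cleanup_since_ms = round(cleanup_older_than.total_seconds() * 1000)
    print(cleanup_since_ms)  # 604800000 -- seven days in milliseconds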
async def list_indices(self) -> Iterable[IndexConfig]:
"""