Mirror of https://github.com/lancedb/lancedb.git, synced 2026-01-03 02:12:56 +00:00
chore: add pyright static type checking and fix some of the table interface (#1996)

* Enable `pyright` in the project
* Fix some pyright typing errors in `table.py`
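Most of the signature fixes in this diff address the same pyright complaint: a parameter annotated `List[...]` (or similar) with a default of `None` is an implicit `Optional`, which pyright rejects under its default strictness. A minimal standalone sketch of the error and the fix (function and parameter names here are illustrative, not taken from the codebase):

```python
from typing import List, Optional

# Before: pyright reports an implicit-Optional error, e.g.
#   def create_index(fields: List[str] = None): ...
#   -> "Expression of type None cannot be assigned to parameter of type List[str]"

# After: the annotation admits None, and the body normalizes it.
def create_index(fields: Optional[List[str]] = None) -> None:
    if fields is None:
        fields = []
    print(f"indexing fields: {fields}")

create_index()                  # fields defaults to []
create_index(["title", "body"])
```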
@@ -1,7 +1,9 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Any, Union, Literal

 import pyarrow as pa

+from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
+
 class Connection(object):
     uri: str
     async def table_names(
@@ -31,16 +33,35 @@ class Connection(object):
 class Table:
     def name(self) -> str: ...
     def __repr__(self) -> str: ...
     def is_open(self) -> bool: ...
     def close(self) -> None: ...
     async def schema(self) -> pa.Schema: ...
-    async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ...
+    async def add(
+        self, data: pa.RecordBatchReader, mode: Literal["append", "overwrite"]
+    ) -> None: ...
+    async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ...
+    async def count_rows(self, filter: Optional[str]) -> int: ...
-    async def create_index(self, column: str, config, replace: Optional[bool]): ...
+    async def create_index(
+        self,
+        column: str,
+        index: Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS],
+        replace: Optional[bool],
+    ): ...
+    async def list_versions(self) -> List[Dict[str, Any]]: ...
     async def version(self) -> int: ...
     async def checkout(self, version: int): ...
     async def checkout_latest(self): ...
     async def restore(self): ...
-    async def list_indices(self) -> List[IndexConfig]: ...
+    async def list_indices(self) -> list[IndexConfig]: ...
     async def delete(self, filter: str): ...
+    async def add_columns(self, columns: list[tuple[str, str]]) -> None: ...
+    async def alter_columns(self, columns: list[dict[str, Any]]) -> None: ...
+    async def optimize(
+        self,
+        *,
+        cleanup_since_ms: Optional[int] = None,
+        delete_unverified: Optional[bool] = None,
+    ) -> OptimizeStats: ...
     def query(self) -> Query: ...
     def vector_search(self) -> VectorQuery: ...
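A brief aside on the `Literal["append", "overwrite"]` annotation in the `add` stub above: `Literal` constrains the accepted string values at type-check time while remaining an ordinary `str` at runtime. A self-contained illustration (the function here is a toy, not the real binding):

```python
from typing import Literal

def add(mode: Literal["append", "overwrite"] = "append") -> None:
    # At runtime, mode is just a str; the constraint is purely static.
    print(f"writing with mode={mode}")

add("overwrite")  # accepted
# add("upsert")   # pyright: argument of type 'Literal["upsert"]' is not assignable
```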
@@ -603,7 +603,7 @@ class AsyncConnection(object):
         fill_value: Optional[float] = None,
         storage_options: Optional[Dict[str, str]] = None,
         *,
-        embedding_functions: List[EmbeddingFunctionConfig] = None,
+        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
         data_storage_version: Optional[str] = None,
         use_legacy_format: Optional[bool] = None,
         enable_v2_manifest_paths: Optional[bool] = None,
@@ -1,20 +1,10 @@
-# Copyright 2023 LanceDB Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors

 """Full text search index using tantivy-py"""

 import os
-from typing import List, Tuple
+from typing import List, Tuple, Optional

 import pyarrow as pa
@@ -31,7 +21,7 @@ from .table import LanceTable
 def create_index(
     index_path: str,
     text_fields: List[str],
-    ordering_fields: List[str] = None,
+    ordering_fields: Optional[List[str]] = None,
     tokenizer_name: str = "default",
 ) -> tantivy.Index:
     """
@@ -75,8 +65,8 @@ def populate_index(
     index: tantivy.Index,
     table: LanceTable,
     fields: List[str],
-    writer_heap_size: int = 1024 * 1024 * 1024,
-    ordering_fields: List[str] = None,
+    writer_heap_size: Optional[int] = None,
+    ordering_fields: Optional[List[str]] = None,
 ) -> int:
     """
     Populate an index with data from a LanceTable
@@ -99,6 +89,7 @@ def populate_index(
     """
     if ordering_fields is None:
         ordering_fields = []
+    writer_heap_size = writer_heap_size or 1024 * 1024 * 1024
    # first check the fields exist and are string or large string type
    nested = []

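The `writer_heap_size or 1024 * 1024 * 1024` fallback above coalesces the new `None` default to 1 GiB. Worth noting: `or` replaces any falsy value, not just `None`, so a caller passing `0` would also get the default; that is harmless for a heap size but matters for parameters where `0` is meaningful. A standalone illustration:

```python
# `or` coalesces None to the default...
writer_heap_size = None
writer_heap_size = writer_heap_size or 1024 * 1024 * 1024
assert writer_heap_size == 1024**3

# ...but it also swallows other falsy values, unlike an explicit None check:
batch_size = 0
assert (batch_size or 64) == 64                         # 0 silently replaced
assert (64 if batch_size is None else batch_size) == 0  # 0 preserved
```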
@@ -61,11 +61,12 @@ from .index import lang_mapping

 if TYPE_CHECKING:
-    import PIL
-    from lance.dataset import CleanupStats, ReaderLike
-    from ._lancedb import Table as LanceDBTable, OptimizeStats
+    from ._lancedb import Table as LanceDBTable, OptimizeStats, CompactionStats
     from .db import LanceDBConnection
     from .index import IndexConfig
+    from lance.dataset import CleanupStats, ReaderLike
+    import pandas
+    import PIL

 pd = safe_import_pandas()
 pl = safe_import_polars()

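The reshuffled `TYPE_CHECKING` block above is presumably also why annotations elsewhere in the diff change from `"pd.DataFrame"` to `"pandas.DataFrame"`: `pd` is a runtime alias produced by `safe_import_pandas()`, which pyright cannot resolve to the real module, whereas the `import pandas` inside `TYPE_CHECKING` gives the checker a proper target without importing pandas at runtime. A minimal sketch of the pattern:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by the type checker; pandas never loads at runtime.
    import pandas

def to_pandas() -> "pandas.DataFrame":  # string annotation, resolved statically
    raise NotImplementedError("toy example")
```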
@@ -84,7 +85,6 @@ def _pd_schema_without_embedding_funcs(
     )
     if not embedding_functions:
         return schema
-    columns = set(columns)
     return pa.schema([field for field in schema if field.name in columns])


@@ -119,7 +119,7 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
             return pa.Table.from_batches(data, schema=schema)
         else:
             return pa.Table.from_pylist(data, schema=schema)
-    elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
+    elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):  # type: ignore
         raw_schema = _pd_schema_without_embedding_funcs(schema, data.columns.to_list())
         table = pa.Table.from_pandas(data, preserve_index=False, schema=raw_schema)
         # Do not serialize Pandas metadata
@@ -160,7 +160,7 @@ def _sanitize_data(
     metadata: Optional[dict] = None,  # embedding metadata
     on_bad_vectors: str = "error",
     fill_value: float = 0.0,
-):
+) -> Tuple[pa.Table, pa.Schema]:
     data = _coerce_to_table(data, schema)

     if metadata:
@@ -178,13 +178,17 @@ def _sanitize_data(


 def sanitize_create_table(
-    data, schema, metadata=None, on_bad_vectors="error", fill_value=0.0
+    data,
+    schema: Union[pa.Schema, LanceModel],
+    metadata=None,
+    on_bad_vectors: str = "error",
+    fill_value: float = 0.0,
 ):
     if inspect.isclass(schema) and issubclass(schema, LanceModel):
         # convert LanceModel to pyarrow schema
         # note that it's possible this contains
         # embedding function metadata already
-        schema = schema.to_arrow_schema()
+        schema: pa.Schema = schema.to_arrow_schema()

     if data is not None:
         if metadata is None and schema is not None:
@@ -272,41 +276,6 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
     return data


-def _generator_to_data_and_schema(
-    data: Iterable,
-) -> Tuple[Iterable[pa.RecordBatch], pa.Schema]:
-    def _with_first_generator(first, data):
-        yield first
-        yield from data
-
-    first = next(data, None)
-    schema = None
-    if isinstance(first, pa.RecordBatch):
-        schema = first.schema
-        data = _with_first_generator(first, data)
-    elif isinstance(first, pa.Table):
-        schema = first.schema
-        data = _with_first_generator(first.to_batches(), data)
-    return data, schema
-
-
-def _to_record_batch_generator(
-    data: Iterable,
-    schema,
-    metadata,
-    on_bad_vectors,
-    fill_value,
-):
-    for batch in data:
-        # always convert to table because we need to sanitize the data
-        # and do things like add the vector column etc
-        if isinstance(batch, pa.RecordBatch):
-            batch = pa.Table.from_batches([batch])
-        batch, _ = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
-        for b in batch.to_batches():
-            yield b
-
-
 def _table_path(base: str, table_name: str) -> str:
     """
     Get a table path that can be used in PyArrow FS.
@@ -404,7 +373,7 @@ class Table(ABC):
         """
         raise NotImplementedError

-    def to_pandas(self) -> "pd.DataFrame":
+    def to_pandas(self) -> "pandas.DataFrame":
         """Return the table as a pandas DataFrame.

         Returns
@@ -537,8 +506,8 @@ class Table(ABC):
     def create_fts_index(
         self,
         field_names: Union[str, List[str]],
-        ordering_field_names: Union[str, List[str]] = None,
+        *,
+        ordering_field_names: Optional[Union[str, List[str]]] = None,
         replace: bool = False,
         writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
         use_tantivy: bool = True,
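The inserted bare `*` in `create_fts_index` makes the parameters after it keyword-only, so the ordering fields can no longer be passed positionally. A quick standalone demonstration:

```python
def create_fts_index(field_names, *, ordering_field_names=None, replace=False):
    return field_names, ordering_field_names, replace

create_fts_index("text", ordering_field_names="rank")  # OK
# create_fts_index("text", "rank")  # TypeError: too many positional arguments
```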
@@ -790,8 +759,7 @@ class Table(ABC):
     @abstractmethod
     def _execute_query(
         self, query: Query, batch_size: Optional[int] = None
-    ) -> pa.RecordBatchReader:
-        pass
+    ) -> pa.RecordBatchReader: ...

     @abstractmethod
     def _do_merge(
@@ -800,8 +768,7 @@ class Table(ABC):
         new_data: DATA,
         on_bad_vectors: str,
         fill_value: float,
-    ):
-        pass
+    ): ...

     @abstractmethod
     def delete(self, where: str):
@@ -1121,7 +1088,7 @@ class Table(ABC):
         """

     @abstractmethod
-    def list_versions(self):
+    def list_versions(self) -> List[Dict[str, Any]]:
         """List all versions of the table"""

     @cached_property
@@ -1244,7 +1211,7 @@ class LanceTable(Table):
         A PyArrow schema object."""
         return LOOP.run(self._table.schema())

-    def list_versions(self):
+    def list_versions(self) -> List[Dict[str, Any]]:
         """List all versions of the table"""
         return LOOP.run(self._table.list_versions())

@@ -1297,7 +1264,7 @@ class LanceTable(Table):
         """
         LOOP.run(self._table.checkout_latest())

-    def restore(self, version: int = None):
+    def restore(self, version: Optional[int] = None):
         """Restore a version of the table. This is an in-place operation.

         This creates a new version where the data is equivalent to the
@@ -1338,7 +1305,7 @@ class LanceTable(Table):
     def count_rows(self, filter: Optional[str] = None) -> int:
         return LOOP.run(self._table.count_rows(filter))

-    def __len__(self):
+    def __len__(self) -> int:
         return self.count_rows()

     def __repr__(self) -> str:
@@ -1506,8 +1473,8 @@ class LanceTable(Table):
     def create_fts_index(
         self,
         field_names: Union[str, List[str]],
-        ordering_field_names: Union[str, List[str]] = None,
+        *,
+        ordering_field_names: Optional[Union[str, List[str]]] = None,
         replace: bool = False,
         writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
         use_tantivy: bool = True,
@@ -1594,6 +1561,7 @@ class LanceTable(Table):
             writer_heap_size=writer_heap_size,
         )

+    @staticmethod
     def infer_tokenizer_configs(tokenizer_name: str) -> dict:
         if tokenizer_name == "default":
             return {
@@ -1759,7 +1727,7 @@ class LanceTable(Table):
         )

     @overload
-    def search(
+    def search(  # type: ignore
         self,
         query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
         vector_column_name: Optional[str] = None,
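On the `# type: ignore` added to the overloaded `search` above: pyright checks `@overload` sets more strictly than mypy (overlapping overloads, compatibility with the implementation and the base class), and silencing one overload is a common escape hatch. A self-contained sketch of the overload mechanics (toy types, not the real `search` signature):

```python
from typing import List, Union, overload

@overload
def search(query: str) -> List[str]: ...
@overload
def search(query: int) -> List[int]: ...
def search(query: Union[str, int]) -> list:
    # Single runtime implementation behind both static signatures.
    return [query]

assert search("vec") == ["vec"]
assert search(3) == [3]
```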
@@ -1895,11 +1863,11 @@ class LanceTable(Table):
         name: str,
         data: Optional[DATA] = None,
         schema: Optional[pa.Schema] = None,
-        mode: Literal["create", "overwrite", "append"] = "create",
+        mode: Literal["create", "overwrite"] = "create",
         exist_ok: bool = False,
         on_bad_vectors: str = "error",
         fill_value: float = 0.0,
-        embedding_functions: List[EmbeddingFunctionConfig] = None,
+        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
         *,
         storage_options: Optional[Dict[str, str]] = None,
         data_storage_version: Optional[str] = None,
@@ -2065,7 +2033,7 @@ class LanceTable(Table):
             older_than, delete_unverified=delete_unverified
         )

-    def compact_files(self, *args, **kwargs):
+    def compact_files(self, *args, **kwargs) -> CompactionStats:
         """
         Run the compaction process on the table.

@@ -2450,7 +2418,7 @@ def _process_iterator(data: Iterable, schema: Optional[pa.Schema] = None) -> pa.
     if batch_table.schema != schema:
         try:
             batch_table = batch_table.cast(schema)
-        except pa.lib.ArrowInvalid:
+        except pa.lib.ArrowInvalid:  # type: ignore
             raise ValueError(
                 f"Input iterator yielded a batch with schema that "
                 f"does not match the expected schema.\nExpected:\n{schema}\n"
@@ -2710,16 +2678,17 @@ class AsyncTable:
             on_bad_vectors = "error"
         if fill_value is None:
             fill_value = 0.0
-        data, _ = _sanitize_data(
+        table_and_schema: Tuple[pa.Table, pa.Schema] = _sanitize_data(
             data,
             schema,
             metadata=schema.metadata,
             on_bad_vectors=on_bad_vectors,
             fill_value=fill_value,
         )
-        if isinstance(data, pa.Table):
-            data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
-        await self._inner.add(data, mode)
+        tbl, schema = table_and_schema
+        if isinstance(tbl, pa.Table):
+            data = pa.RecordBatchReader.from_batches(schema, tbl.to_batches())
+        await self._inner.add(data, mode or "append")

     def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
         """
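The rewritten `AsyncTable.add` body above unpacks the `(table, schema)` tuple that `_sanitize_data` now returns and wraps the table in the `pa.RecordBatchReader` the native layer expects. The conversion itself is plain pyarrow:

```python
import pyarrow as pa

# Wrap a materialized Table in a streaming RecordBatchReader.
tbl = pa.table({"id": [1, 2, 3], "text": ["a", "b", "c"]})
reader = pa.RecordBatchReader.from_batches(tbl.schema, tbl.to_batches())
assert reader.read_all().num_rows == 3
```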
@@ -2977,7 +2946,7 @@ class AsyncTable:

         return await self._inner.update(updates_sql, where)

-    async def add_columns(self, transforms: Dict[str, str]):
+    async def add_columns(self, transforms: dict[str, str]):
         """
         Add new columns with defined values.

@@ -2990,7 +2959,7 @@ class AsyncTable:
         """
         await self._inner.add_columns(list(transforms.items()))

-    async def alter_columns(self, *alterations: Iterable[Dict[str, str]]):
+    async def alter_columns(self, *alterations: Iterable[dict[str, Any]]):
         """
         Alter column names and nullability.

@@ -3148,9 +3117,12 @@ class AsyncTable:
         you have added or modified 100,000 or more records or run more than 20 data
         modification operations.
         """
+        cleanup_since_ms: Optional[int] = None
         if cleanup_older_than is not None:
-            cleanup_older_than = round(cleanup_older_than.total_seconds() * 1000)
-        return await self._inner.optimize(cleanup_older_than, delete_unverified)
+            cleanup_since_ms = round(cleanup_older_than.total_seconds() * 1000)
+        return await self._inner.optimize(
+            cleanup_since_ms=cleanup_since_ms, delete_unverified=delete_unverified
+        )

     async def list_indices(self) -> Iterable[IndexConfig]:
         """
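The reworked `optimize` above converts the caller's `timedelta` into the whole-millisecond `cleanup_since_ms` value the inner API takes, instead of reassigning `cleanup_older_than` in place (a reassignment that changes the variable's type, which pyright flags). The conversion arithmetic:

```python
from datetime import timedelta

cleanup_older_than = timedelta(days=1, milliseconds=250)
cleanup_since_ms = round(cleanup_older_than.total_seconds() * 1000)
assert cleanup_since_ms == 86_400_250  # 86_400_000 ms/day + 250 ms
```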