diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 0906dfb3..be1b6aa4 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -37,11 +37,9 @@ jobs: run: | pip install -e .[tests] pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985 - pip install pytest pytest-mock black isort - - name: Black - run: black --check --diff --no-color --quiet . - - name: isort - run: isort --check --diff --quiet . + pip install pytest pytest-mock ruff + - name: Lint + run: ruff format --check . - name: Run tests run: pytest -m "not slow" -x -v --durations=30 tests - name: doctest @@ -67,8 +65,6 @@ jobs: pip install -e .[tests] pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985 pip install pytest pytest-mock black - - name: Black - run: black --check --diff --no-color --quiet . - name: Run tests run: pytest -m "not slow" -x -v --durations=30 tests pydantic1x: @@ -100,4 +96,4 @@ jobs: - name: Run tests run: pytest -m "not slow" -x -v --durations=30 tests - name: doctest - run: pytest --doctest-modules lancedb \ No newline at end of file + run: pytest --doctest-modules lancedb diff --git a/python/lancedb/__init__.py b/python/lancedb/__init__.py index c72390e6..6980b710 100644 --- a/python/lancedb/__init__.py +++ b/python/lancedb/__init__.py @@ -16,10 +16,11 @@ from typing import Optional __version__ = importlib.metadata.version("lancedb") -from .db import URI, DBConnection, LanceDBConnection +from .common import URI +from .db import DBConnection, LanceDBConnection from .remote.db import RemoteDBConnection -from .schema import vector -from .utils import sentry_log +from .schema import vector # noqa: F401 +from .utils import sentry_log # noqa: F401 def connect( diff --git a/python/lancedb/db.py b/python/lancedb/db.py index 4d5f4bb8..bcfc73b8 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -14,26 +14,39 @@ from __future__ import annotations import os -from abc import ABC, abstractmethod +from abc import abstractmethod from pathlib import Path -from typing import List, Optional, Union +from typing import TYPE_CHECKING, Iterable, List, Optional, Union import pyarrow as pa +from overrides import EnforceOverrides, override from pyarrow import fs -from .common import DATA, URI -from .embeddings import EmbeddingFunctionConfig -from .pydantic import LanceModel from .table import LanceTable, Table from .util import fs_from_uri, get_uri_location, get_uri_scheme +if TYPE_CHECKING: + from .common import DATA, URI + from .embeddings import EmbeddingFunctionConfig + from .pydantic import LanceModel -class DBConnection(ABC): + +class DBConnection(EnforceOverrides): """An active LanceDB connection interface.""" @abstractmethod - def table_names(self) -> list[str]: - """List all table names in the database.""" + def table_names( + self, page_token: Optional[str] = None, limit: int = 10 + ) -> Iterable[str]: + """List all table in this database + + Parameters + ---------- + page_token: str, optional + The token to use for pagination. If not present, start from the beginning. + limit: int, default 10 + The size of the page to return. + """ pass @abstractmethod @@ -45,6 +58,7 @@ class DBConnection(ABC): mode: str = "create", on_bad_vectors: str = "error", fill_value: float = 0.0, + embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None, ) -> Table: """Create a [Table][lancedb.table.Table] in the database. @@ -262,12 +276,15 @@ class LanceDBConnection(DBConnection): def uri(self) -> str: return self._uri - def table_names(self) -> list[str]: + @override + def table_names( + self, page_token: Optional[str] = None, limit: int = 10 + ) -> Iterable[str]: """Get the names of all tables in the database. The names are sorted. Returns ------- - list of str + Iterator of str. A list of table names. """ try: @@ -296,6 +313,7 @@ class LanceDBConnection(DBConnection): def __contains__(self, name: str) -> bool: return name in self.table_names() + @override def create_table( self, name: str, @@ -327,6 +345,7 @@ class LanceDBConnection(DBConnection): ) return tbl + @override def open_table(self, name: str) -> LanceTable: """Open a table in the database. @@ -341,6 +360,7 @@ class LanceDBConnection(DBConnection): """ return LanceTable.open(self, name) + @override def drop_table(self, name: str, ignore_missing: bool = False): """Drop a table from the database. @@ -359,6 +379,7 @@ class LanceDBConnection(DBConnection): if not ignore_missing: raise + @override def drop_database(self): filesystem, path = fs_from_uri(self.uri) filesystem.delete_dir(path) diff --git a/python/lancedb/embeddings/__init__.py b/python/lancedb/embeddings/__init__.py index 2977f0b4..99e6d314 100644 --- a/python/lancedb/embeddings/__init__.py +++ b/python/lancedb/embeddings/__init__.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +# ruff: noqa: F401 from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction from .cohere import CohereEmbeddingFunction from .open_clip import OpenClipEmbeddings diff --git a/python/lancedb/embeddings/cohere.py b/python/lancedb/embeddings/cohere.py index 07881f69..0084c857 100644 --- a/python/lancedb/embeddings/cohere.py +++ b/python/lancedb/embeddings/cohere.py @@ -31,7 +31,8 @@ class CohereEmbeddingFunction(TextEmbeddingFunction): Parameters ---------- name: str, default "embed-multilingual-v2.0" - The name of the model to use. See the Cohere documentation for a list of available models. + The name of the model to use. See the Cohere documentation for + a list of available models. Examples -------- @@ -39,7 +40,10 @@ class CohereEmbeddingFunction(TextEmbeddingFunction): from lancedb.pydantic import LanceModel, Vector from lancedb.embeddings import EmbeddingFunctionRegistry - cohere = EmbeddingFunctionRegistry.get_instance().get("cohere").create(name="embed-multilingual-v2.0") + cohere = EmbeddingFunctionRegistry + .get_instance() + .get("cohere") + .create(name="embed-multilingual-v2.0") class TextModel(LanceModel): text: str = cohere.SourceField() diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 76e43728..c73e82f8 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -14,7 +14,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import List, Literal, Optional, Type, Union +from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union import deprecation import numpy as np @@ -23,9 +23,11 @@ import pydantic from . import __version__ from .common import VECTOR_COLUMN_NAME -from .pydantic import LanceModel from .util import safe_import_pandas +if TYPE_CHECKING: + from .pydantic import LanceModel + pd = safe_import_pandas() diff --git a/python/lancedb/remote/client.py b/python/lancedb/remote/client.py index 24aa90df..03fec380 100644 --- a/python/lancedb/remote/client.py +++ b/python/lancedb/remote/client.py @@ -13,7 +13,7 @@ import functools -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Iterable, Optional, Union import aiohttp import attrs @@ -151,7 +151,9 @@ class RestfulLanceDBClient: return await deserialize(resp) @_check_not_closed - async def list_tables(self, limit: int, page_token: str): + async def list_tables( + self, limit: int, page_token: Optional[str] = None + ) -> Iterable[str]: """List all tables in the database.""" try: json = await self.get( diff --git a/python/lancedb/remote/db.py b/python/lancedb/remote/db.py index 6b018062..60387c33 100644 --- a/python/lancedb/remote/db.py +++ b/python/lancedb/remote/db.py @@ -12,14 +12,18 @@ # limitations under the License. import asyncio +import inspect import uuid -from typing import Iterator, Optional +from typing import Iterable, List, Optional, Union from urllib.parse import urlparse import pyarrow as pa +from overrides import override from ..common import DATA from ..db import DBConnection +from ..embeddings import EmbeddingFunctionConfig +from ..pydantic import LanceModel from ..table import Table, _sanitize_data from .arrow import to_ipc_binary from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient @@ -52,8 +56,10 @@ class RemoteDBConnection(DBConnection): def __repr__(self) -> str: return f"RemoveConnect(name={self.db_name})" - def table_names(self, last_token: str, limit=10) -> Iterator[str]: + @override + def table_names(self, page_token: Optional[str] = None, limit=10) -> Iterable[str]: """List the names of all tables in the database. + Parameters ---------- last_token: str @@ -65,15 +71,16 @@ class RemoteDBConnection(DBConnection): """ while True: result = self._loop.run_until_complete( - self._client.list_tables(limit, last_token) + self._client.list_tables(limit, page_token) ) if len(result) > 0: - last_token = result[len(result) - 1] + page_token = result[len(result) - 1] else: break for item in result: - yield result + yield item + @override def open_table(self, name: str) -> Table: """Open a Lance Table in the database. @@ -92,16 +99,31 @@ class RemoteDBConnection(DBConnection): return RemoteTable(self, name) + @override def create_table( self, name: str, data: DATA = None, - schema: pa.Schema = None, + schema: Optional[Union[pa.Schema, LanceModel]] = None, on_bad_vectors: str = "error", fill_value: float = 0.0, + embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None, ) -> Table: if data is None and schema is None: raise ValueError("Either data or schema must be provided.") + if embedding_functions is not None: + raise NotImplementedError( + "embedding_functions is not supported for remote databases." + "Please vote https://github.com/lancedb/lancedb/issues/626 " + "for this feature." + ) + + if inspect.isclass(schema) and issubclass(schema, LanceModel): + # convert LanceModel to pyarrow schema + # note that it's possible this contains + # embedding function metadata already + schema = schema.to_arrow_schema() + if data is not None: data = _sanitize_data( data, @@ -130,6 +152,7 @@ class RemoteDBConnection(DBConnection): ) return RemoteTable(self, name) + @override def drop_table(self, name: str): """Drop a table from the database. diff --git a/python/lancedb/table.py b/python/lancedb/table.py index be92cd0b..548c11b5 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -16,16 +16,14 @@ from __future__ import annotations import inspect import os from abc import ABC, abstractmethod -from datetime import timedelta from functools import cached_property -from typing import Any, Iterable, List, Optional, Union +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union import lance import numpy as np import pyarrow as pa import pyarrow.compute as pc from lance import LanceDataset -from lance.dataset import CleanupStats, ReaderLike from lance.vector import vec_to_table from .common import DATA, VEC, VECTOR_COLUMN_NAME @@ -35,6 +33,12 @@ from .query import LanceQueryBuilder, Query from .util import fs_from_uri, safe_import_pandas from .utils.events import register_event +if TYPE_CHECKING: + from datetime import timedelta + + from lance.dataset import CleanupStats, ReaderLike + + pd = safe_import_pandas() diff --git a/python/pyproject.toml b/python/pyproject.toml index 104d8ba2..13bc8e58 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -14,7 +14,8 @@ dependencies = [ "cachetools", "pyyaml>=6.0", "click>=8.1.7", - "requests>=2.31.0" + "requests>=2.31.0", + "overrides>=0.7" ] description = "lancedb" authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }] @@ -64,6 +65,9 @@ build-backend = "setuptools.build_meta" [tool.isort] profile = "black" +[tool.ruff] +select = ["F", "E", "W", "I", "G", "TCH", "PERF"] + [tool.pytest.ini_options] addopts = "--strict-markers" markers = [ diff --git a/python/tests/test_db.py b/python/tests/test_db.py index a1f4f9c6..90719d85 100644 --- a/python/tests/test_db.py +++ b/python/tests/test_db.py @@ -129,7 +129,7 @@ def test_ingest_iterator(tmp_path): [ PydanticSchema(vector=[3.1, 4.1], item="foo", price=10.0), PydanticSchema(vector=[5.9, 26.5], item="bar", price=20.0), - ] + ], # TODO: test pydict separately. it is unique column number and names contraint ]