chore: improve create_table API consistency between local and remote SDK (#627)

This commit is contained in:
Lei Xu
2023-11-03 13:15:11 -07:00
committed by GitHub
parent 567734dd6e
commit 554e068917
11 changed files with 96 additions and 38 deletions

View File

@@ -16,10 +16,11 @@ from typing import Optional
__version__ = importlib.metadata.version("lancedb")
from .db import URI, DBConnection, LanceDBConnection
from .common import URI
from .db import DBConnection, LanceDBConnection
from .remote.db import RemoteDBConnection
from .schema import vector
from .utils import sentry_log
from .schema import vector # noqa: F401
from .utils import sentry_log # noqa: F401
def connect(

View File

@@ -14,26 +14,39 @@
from __future__ import annotations
import os
from abc import ABC, abstractmethod
from abc import abstractmethod
from pathlib import Path
from typing import List, Optional, Union
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
import pyarrow as pa
from overrides import EnforceOverrides, override
from pyarrow import fs
from .common import DATA, URI
from .embeddings import EmbeddingFunctionConfig
from .pydantic import LanceModel
from .table import LanceTable, Table
from .util import fs_from_uri, get_uri_location, get_uri_scheme
if TYPE_CHECKING:
from .common import DATA, URI
from .embeddings import EmbeddingFunctionConfig
from .pydantic import LanceModel
class DBConnection(ABC):
class DBConnection(EnforceOverrides):
"""An active LanceDB connection interface."""
@abstractmethod
def table_names(self) -> list[str]:
"""List all table names in the database."""
def table_names(
self, page_token: Optional[str] = None, limit: int = 10
) -> Iterable[str]:
"""List all tables in this database
Parameters
----------
page_token: str, optional
The token to use for pagination. If not present, start from the beginning.
limit: int, default 10
The size of the page to return.
"""
pass
@abstractmethod
@@ -45,6 +58,7 @@ class DBConnection(ABC):
mode: str = "create",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
) -> Table:
"""Create a [Table][lancedb.table.Table] in the database.
@@ -262,12 +276,15 @@ class LanceDBConnection(DBConnection):
def uri(self) -> str:
return self._uri
def table_names(self) -> list[str]:
@override
def table_names(
self, page_token: Optional[str] = None, limit: int = 10
) -> Iterable[str]:
"""Get the names of all tables in the database. The names are sorted.
Returns
-------
list of str
Iterator of str.
A list of table names.
"""
try:
@@ -296,6 +313,7 @@ class LanceDBConnection(DBConnection):
def __contains__(self, name: str) -> bool:
return name in self.table_names()
@override
def create_table(
self,
name: str,
@@ -327,6 +345,7 @@ class LanceDBConnection(DBConnection):
)
return tbl
@override
def open_table(self, name: str) -> LanceTable:
"""Open a table in the database.
@@ -341,6 +360,7 @@ class LanceDBConnection(DBConnection):
"""
return LanceTable.open(self, name)
@override
def drop_table(self, name: str, ignore_missing: bool = False):
"""Drop a table from the database.
@@ -359,6 +379,7 @@ class LanceDBConnection(DBConnection):
if not ignore_missing:
raise
@override
def drop_database(self):
filesystem, path = fs_from_uri(self.uri)
filesystem.delete_dir(path)

View File

@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: F401
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
from .cohere import CohereEmbeddingFunction
from .open_clip import OpenClipEmbeddings

View File

@@ -31,7 +31,8 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
Parameters
----------
name: str, default "embed-multilingual-v2.0"
The name of the model to use. See the Cohere documentation for a list of available models.
The name of the model to use. See the Cohere documentation for
a list of available models.
Examples
--------
@@ -39,7 +40,10 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import EmbeddingFunctionRegistry
cohere = EmbeddingFunctionRegistry.get_instance().get("cohere").create(name="embed-multilingual-v2.0")
cohere = (
    EmbeddingFunctionRegistry
    .get_instance()
    .get("cohere")
    .create(name="embed-multilingual-v2.0")
)
class TextModel(LanceModel):
text: str = cohere.SourceField()

View File

@@ -14,7 +14,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Type, Union
from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union
import deprecation
import numpy as np
@@ -23,9 +23,11 @@ import pydantic
from . import __version__
from .common import VECTOR_COLUMN_NAME
from .pydantic import LanceModel
from .util import safe_import_pandas
if TYPE_CHECKING:
from .pydantic import LanceModel
pd = safe_import_pandas()

View File

@@ -13,7 +13,7 @@
import functools
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Iterable, Optional, Union
import aiohttp
import attrs
@@ -151,7 +151,9 @@ class RestfulLanceDBClient:
return await deserialize(resp)
@_check_not_closed
async def list_tables(self, limit: int, page_token: str):
async def list_tables(
self, limit: int, page_token: Optional[str] = None
) -> Iterable[str]:
"""List all tables in the database."""
try:
json = await self.get(

View File

@@ -12,14 +12,18 @@
# limitations under the License.
import asyncio
import inspect
import uuid
from typing import Iterator, Optional
from typing import Iterable, List, Optional, Union
from urllib.parse import urlparse
import pyarrow as pa
from overrides import override
from ..common import DATA
from ..db import DBConnection
from ..embeddings import EmbeddingFunctionConfig
from ..pydantic import LanceModel
from ..table import Table, _sanitize_data
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
@@ -52,8 +56,10 @@ class RemoteDBConnection(DBConnection):
def __repr__(self) -> str:
return f"RemoveConnect(name={self.db_name})"
def table_names(self, last_token: str, limit=10) -> Iterator[str]:
@override
def table_names(self, page_token: Optional[str] = None, limit=10) -> Iterable[str]:
"""List the names of all tables in the database.
Parameters
----------
page_token: str, optional
@@ -65,15 +71,16 @@ class RemoteDBConnection(DBConnection):
"""
while True:
result = self._loop.run_until_complete(
self._client.list_tables(limit, last_token)
self._client.list_tables(limit, page_token)
)
if len(result) > 0:
last_token = result[len(result) - 1]
page_token = result[len(result) - 1]
else:
break
for item in result:
yield result
yield item
@override
def open_table(self, name: str) -> Table:
"""Open a Lance Table in the database.
@@ -92,16 +99,31 @@ class RemoteDBConnection(DBConnection):
return RemoteTable(self, name)
@override
def create_table(
self,
name: str,
data: DATA = None,
schema: pa.Schema = None,
schema: Optional[Union[pa.Schema, LanceModel]] = None,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
) -> Table:
if data is None and schema is None:
raise ValueError("Either data or schema must be provided.")
if embedding_functions is not None:
raise NotImplementedError(
"embedding_functions is not supported for remote databases."
"Please vote https://github.com/lancedb/lancedb/issues/626 "
"for this feature."
)
if inspect.isclass(schema) and issubclass(schema, LanceModel):
# convert LanceModel to pyarrow schema
# note that it's possible this contains
# embedding function metadata already
schema = schema.to_arrow_schema()
if data is not None:
data = _sanitize_data(
data,
@@ -130,6 +152,7 @@ class RemoteDBConnection(DBConnection):
)
return RemoteTable(self, name)
@override
def drop_table(self, name: str):
"""Drop a table from the database.

View File

@@ -16,16 +16,14 @@ from __future__ import annotations
import inspect
import os
from abc import ABC, abstractmethod
from datetime import timedelta
from functools import cached_property
from typing import Any, Iterable, List, Optional, Union
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union
import lance
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
from lance import LanceDataset
from lance.dataset import CleanupStats, ReaderLike
from lance.vector import vec_to_table
from .common import DATA, VEC, VECTOR_COLUMN_NAME
@@ -35,6 +33,12 @@ from .query import LanceQueryBuilder, Query
from .util import fs_from_uri, safe_import_pandas
from .utils.events import register_event
if TYPE_CHECKING:
from datetime import timedelta
from lance.dataset import CleanupStats, ReaderLike
pd = safe_import_pandas()

View File

@@ -14,7 +14,8 @@ dependencies = [
"cachetools",
"pyyaml>=6.0",
"click>=8.1.7",
"requests>=2.31.0"
"requests>=2.31.0",
"overrides>=0.7"
]
description = "lancedb"
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
@@ -64,6 +65,9 @@ build-backend = "setuptools.build_meta"
[tool.isort]
profile = "black"
[tool.ruff]
select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
[tool.pytest.ini_options]
addopts = "--strict-markers"
markers = [

View File

@@ -129,7 +129,7 @@ def test_ingest_iterator(tmp_path):
[
PydanticSchema(vector=[3.1, 4.1], item="foo", price=10.0),
PydanticSchema(vector=[5.9, 26.5], item="bar", price=20.0),
]
],
# TODO: test pydict separately. it is unique column number and names constraint
]