mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 07:09:57 +00:00
chore: improve create_table API consistency between local and remote SDK (#627)
This commit is contained in:
@@ -16,10 +16,11 @@ from typing import Optional
|
||||
|
||||
__version__ = importlib.metadata.version("lancedb")
|
||||
|
||||
from .db import URI, DBConnection, LanceDBConnection
|
||||
from .common import URI
|
||||
from .db import DBConnection, LanceDBConnection
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .schema import vector
|
||||
from .utils import sentry_log
|
||||
from .schema import vector # noqa: F401
|
||||
from .utils import sentry_log # noqa: F401
|
||||
|
||||
|
||||
def connect(
|
||||
|
||||
@@ -14,26 +14,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
|
||||
|
||||
import pyarrow as pa
|
||||
from overrides import EnforceOverrides, override
|
||||
from pyarrow import fs
|
||||
|
||||
from .common import DATA, URI
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
from .pydantic import LanceModel
|
||||
from .table import LanceTable, Table
|
||||
from .util import fs_from_uri, get_uri_location, get_uri_scheme
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .common import DATA, URI
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
from .pydantic import LanceModel
|
||||
|
||||
class DBConnection(ABC):
|
||||
|
||||
class DBConnection(EnforceOverrides):
|
||||
"""An active LanceDB connection interface."""
|
||||
|
||||
@abstractmethod
|
||||
def table_names(self) -> list[str]:
|
||||
"""List all table names in the database."""
|
||||
def table_names(
|
||||
self, page_token: Optional[str] = None, limit: int = 10
|
||||
) -> Iterable[str]:
|
||||
"""List all table in this database
|
||||
|
||||
Parameters
|
||||
----------
|
||||
page_token: str, optional
|
||||
The token to use for pagination. If not present, start from the beginning.
|
||||
limit: int, default 10
|
||||
The size of the page to return.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -45,6 +58,7 @@ class DBConnection(ABC):
|
||||
mode: str = "create",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
) -> Table:
|
||||
"""Create a [Table][lancedb.table.Table] in the database.
|
||||
|
||||
@@ -262,12 +276,15 @@ class LanceDBConnection(DBConnection):
|
||||
def uri(self) -> str:
|
||||
return self._uri
|
||||
|
||||
def table_names(self) -> list[str]:
|
||||
@override
|
||||
def table_names(
|
||||
self, page_token: Optional[str] = None, limit: int = 10
|
||||
) -> Iterable[str]:
|
||||
"""Get the names of all tables in the database. The names are sorted.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list of str
|
||||
Iterator of str.
|
||||
A list of table names.
|
||||
"""
|
||||
try:
|
||||
@@ -296,6 +313,7 @@ class LanceDBConnection(DBConnection):
|
||||
def __contains__(self, name: str) -> bool:
|
||||
return name in self.table_names()
|
||||
|
||||
@override
|
||||
def create_table(
|
||||
self,
|
||||
name: str,
|
||||
@@ -327,6 +345,7 @@ class LanceDBConnection(DBConnection):
|
||||
)
|
||||
return tbl
|
||||
|
||||
@override
|
||||
def open_table(self, name: str) -> LanceTable:
|
||||
"""Open a table in the database.
|
||||
|
||||
@@ -341,6 +360,7 @@ class LanceDBConnection(DBConnection):
|
||||
"""
|
||||
return LanceTable.open(self, name)
|
||||
|
||||
@override
|
||||
def drop_table(self, name: str, ignore_missing: bool = False):
|
||||
"""Drop a table from the database.
|
||||
|
||||
@@ -359,6 +379,7 @@ class LanceDBConnection(DBConnection):
|
||||
if not ignore_missing:
|
||||
raise
|
||||
|
||||
@override
|
||||
def drop_database(self):
|
||||
filesystem, path = fs_from_uri(self.uri)
|
||||
filesystem.delete_dir(path)
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F401
|
||||
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
|
||||
from .cohere import CohereEmbeddingFunction
|
||||
from .open_clip import OpenClipEmbeddings
|
||||
|
||||
@@ -31,7 +31,8 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
|
||||
Parameters
|
||||
----------
|
||||
name: str, default "embed-multilingual-v2.0"
|
||||
The name of the model to use. See the Cohere documentation for a list of available models.
|
||||
The name of the model to use. See the Cohere documentation for
|
||||
a list of available models.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -39,7 +40,10 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
|
||||
cohere = EmbeddingFunctionRegistry.get_instance().get("cohere").create(name="embed-multilingual-v2.0")
|
||||
cohere = EmbeddingFunctionRegistry
|
||||
.get_instance()
|
||||
.get("cohere")
|
||||
.create(name="embed-multilingual-v2.0")
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = cohere.SourceField()
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Literal, Optional, Type, Union
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union
|
||||
|
||||
import deprecation
|
||||
import numpy as np
|
||||
@@ -23,9 +23,11 @@ import pydantic
|
||||
|
||||
from . import __version__
|
||||
from .common import VECTOR_COLUMN_NAME
|
||||
from .pydantic import LanceModel
|
||||
from .util import safe_import_pandas
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .pydantic import LanceModel
|
||||
|
||||
pd = safe_import_pandas()
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
|
||||
import functools
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
from typing import Any, Callable, Dict, Iterable, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
import attrs
|
||||
@@ -151,7 +151,9 @@ class RestfulLanceDBClient:
|
||||
return await deserialize(resp)
|
||||
|
||||
@_check_not_closed
|
||||
async def list_tables(self, limit: int, page_token: str):
|
||||
async def list_tables(
|
||||
self, limit: int, page_token: Optional[str] = None
|
||||
) -> Iterable[str]:
|
||||
"""List all tables in the database."""
|
||||
try:
|
||||
json = await self.get(
|
||||
|
||||
@@ -12,14 +12,18 @@
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import uuid
|
||||
from typing import Iterator, Optional
|
||||
from typing import Iterable, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pyarrow as pa
|
||||
from overrides import override
|
||||
|
||||
from ..common import DATA
|
||||
from ..db import DBConnection
|
||||
from ..embeddings import EmbeddingFunctionConfig
|
||||
from ..pydantic import LanceModel
|
||||
from ..table import Table, _sanitize_data
|
||||
from .arrow import to_ipc_binary
|
||||
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
|
||||
@@ -52,8 +56,10 @@ class RemoteDBConnection(DBConnection):
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoveConnect(name={self.db_name})"
|
||||
|
||||
def table_names(self, last_token: str, limit=10) -> Iterator[str]:
|
||||
@override
|
||||
def table_names(self, page_token: Optional[str] = None, limit=10) -> Iterable[str]:
|
||||
"""List the names of all tables in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
last_token: str
|
||||
@@ -65,15 +71,16 @@ class RemoteDBConnection(DBConnection):
|
||||
"""
|
||||
while True:
|
||||
result = self._loop.run_until_complete(
|
||||
self._client.list_tables(limit, last_token)
|
||||
self._client.list_tables(limit, page_token)
|
||||
)
|
||||
if len(result) > 0:
|
||||
last_token = result[len(result) - 1]
|
||||
page_token = result[len(result) - 1]
|
||||
else:
|
||||
break
|
||||
for item in result:
|
||||
yield result
|
||||
yield item
|
||||
|
||||
@override
|
||||
def open_table(self, name: str) -> Table:
|
||||
"""Open a Lance Table in the database.
|
||||
|
||||
@@ -92,16 +99,31 @@ class RemoteDBConnection(DBConnection):
|
||||
|
||||
return RemoteTable(self, name)
|
||||
|
||||
@override
|
||||
def create_table(
|
||||
self,
|
||||
name: str,
|
||||
data: DATA = None,
|
||||
schema: pa.Schema = None,
|
||||
schema: Optional[Union[pa.Schema, LanceModel]] = None,
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
) -> Table:
|
||||
if data is None and schema is None:
|
||||
raise ValueError("Either data or schema must be provided.")
|
||||
if embedding_functions is not None:
|
||||
raise NotImplementedError(
|
||||
"embedding_functions is not supported for remote databases."
|
||||
"Please vote https://github.com/lancedb/lancedb/issues/626 "
|
||||
"for this feature."
|
||||
)
|
||||
|
||||
if inspect.isclass(schema) and issubclass(schema, LanceModel):
|
||||
# convert LanceModel to pyarrow schema
|
||||
# note that it's possible this contains
|
||||
# embedding function metadata already
|
||||
schema = schema.to_arrow_schema()
|
||||
|
||||
if data is not None:
|
||||
data = _sanitize_data(
|
||||
data,
|
||||
@@ -130,6 +152,7 @@ class RemoteDBConnection(DBConnection):
|
||||
)
|
||||
return RemoteTable(self, name)
|
||||
|
||||
@override
|
||||
def drop_table(self, name: str):
|
||||
"""Drop a table from the database.
|
||||
|
||||
|
||||
@@ -16,16 +16,14 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import timedelta
|
||||
from functools import cached_property
|
||||
from typing import Any, Iterable, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union
|
||||
|
||||
import lance
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
from lance import LanceDataset
|
||||
from lance.dataset import CleanupStats, ReaderLike
|
||||
from lance.vector import vec_to_table
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
@@ -35,6 +33,12 @@ from .query import LanceQueryBuilder, Query
|
||||
from .util import fs_from_uri, safe_import_pandas
|
||||
from .utils.events import register_event
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
from lance.dataset import CleanupStats, ReaderLike
|
||||
|
||||
|
||||
pd = safe_import_pandas()
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,8 @@ dependencies = [
|
||||
"cachetools",
|
||||
"pyyaml>=6.0",
|
||||
"click>=8.1.7",
|
||||
"requests>=2.31.0"
|
||||
"requests>=2.31.0",
|
||||
"overrides>=0.7"
|
||||
]
|
||||
description = "lancedb"
|
||||
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
|
||||
@@ -64,6 +65,9 @@ build-backend = "setuptools.build_meta"
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
|
||||
[tool.ruff]
|
||||
select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--strict-markers"
|
||||
markers = [
|
||||
|
||||
@@ -129,7 +129,7 @@ def test_ingest_iterator(tmp_path):
|
||||
[
|
||||
PydanticSchema(vector=[3.1, 4.1], item="foo", price=10.0),
|
||||
PydanticSchema(vector=[5.9, 26.5], item="bar", price=20.0),
|
||||
]
|
||||
],
|
||||
# TODO: test pydict separately. it is unique column number and names contraint
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user