chore: improve create_table API consistency between local and remote SDK (#627)

This commit is contained in:
Lei Xu
2023-11-03 13:15:11 -07:00
committed by GitHub
parent 567734dd6e
commit 554e068917
11 changed files with 96 additions and 38 deletions

View File

@@ -14,26 +14,39 @@
from __future__ import annotations
import os
from abc import ABC, abstractmethod
from abc import abstractmethod
from pathlib import Path
from typing import List, Optional, Union
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
import pyarrow as pa
from overrides import EnforceOverrides, override
from pyarrow import fs
from .common import DATA, URI
from .embeddings import EmbeddingFunctionConfig
from .pydantic import LanceModel
from .table import LanceTable, Table
from .util import fs_from_uri, get_uri_location, get_uri_scheme
if TYPE_CHECKING:
from .common import DATA, URI
from .embeddings import EmbeddingFunctionConfig
from .pydantic import LanceModel
class DBConnection(ABC):
class DBConnection(EnforceOverrides):
"""An active LanceDB connection interface."""
@abstractmethod
def table_names(self) -> list[str]:
"""List all table names in the database."""
def table_names(
self, page_token: Optional[str] = None, limit: int = 10
) -> Iterable[str]:
"""List all table in this database
Parameters
----------
page_token: str, optional
The token to use for pagination. If not present, start from the beginning.
limit: int, default 10
The size of the page to return.
"""
pass
@abstractmethod
@@ -45,6 +58,7 @@ class DBConnection(ABC):
mode: str = "create",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
) -> Table:
"""Create a [Table][lancedb.table.Table] in the database.
@@ -262,12 +276,15 @@ class LanceDBConnection(DBConnection):
def uri(self) -> str:
return self._uri
def table_names(self) -> list[str]:
@override
def table_names(
self, page_token: Optional[str] = None, limit: int = 10
) -> Iterable[str]:
"""Get the names of all tables in the database. The names are sorted.
Returns
-------
list of str
Iterator of str.
A list of table names.
"""
try:
@@ -296,6 +313,7 @@ class LanceDBConnection(DBConnection):
def __contains__(self, name: str) -> bool:
return name in self.table_names()
@override
def create_table(
self,
name: str,
@@ -327,6 +345,7 @@ class LanceDBConnection(DBConnection):
)
return tbl
@override
def open_table(self, name: str) -> LanceTable:
"""Open a table in the database.
@@ -341,6 +360,7 @@ class LanceDBConnection(DBConnection):
"""
return LanceTable.open(self, name)
@override
def drop_table(self, name: str, ignore_missing: bool = False):
"""Drop a table from the database.
@@ -359,6 +379,7 @@ class LanceDBConnection(DBConnection):
if not ignore_missing:
raise
@override
def drop_database(self):
filesystem, path = fs_from_uri(self.uri)
filesystem.delete_dir(path)