feat: add support for add to async python API (#1037)

In order to add support for `add` we needed to migrate the rust `Table`
trait to a `Table` struct and `TableInternal` trait (similar to the way
the connection is designed).

While doing this we also cleaned up some inconsistencies between the
SDKs:

* Python and Node are garbage collected languages and it can be
difficult to trigger something to be freed. The convention for these
languages is to have some kind of close method. I added a close method
to both the table and connection which will drop the underlying rust
object.
* We made significant improvements to table creation in
cc5f2136a6
for the `node` SDK. I copied these changes to the `nodejs` SDK.
* The nodejs tables were using fs to create tmp directories and these
were not getting cleaned up. This is mostly harmless but annoying and so
I changed it up a bit to ensure we clean up tmp directories.
* ~~countRows in the node SDK was returning `bigint`. I changed it to
return `number`~~ (this actually happened in a previous PR)
* Tables and connections now implement `std::fmt::Display` which is
hooked into python's `__repr__`. Node has no concept of a regular "to
string" function and so I added a `display` method.
* Python method signatures are changing so that optional parameters are
always `Optional[foo] = None` instead of something like `foo = False`.
This is because we want those defaults to be in rust whenever possible
(though we still need to mention the default in documentation).
* I changed the python `AsyncConnection/AsyncTable` classes from
abstract classes with a single implementation to just classes because we
no longer have the remote implementation in python.

Note: this does NOT add the `add` function to the remote table. This PR
was already large enough, and the remote implementation is unique
enough, that I am going to do all the remote stuff at a later date (we
should have the structure in place and correct so there shouldn't be any
refactor concerns)

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
Weston Pace
2024-03-04 09:27:41 -08:00
parent 3bbcaba65b
commit 8033a44d68
42 changed files with 2822 additions and 1122 deletions

View File

@@ -21,7 +21,7 @@ __version__ = importlib.metadata.version("lancedb")
from ._lancedb import connect as lancedb_connect
from .common import URI, sanitize_uri
from .db import AsyncConnection, AsyncLanceDBConnection, DBConnection, LanceDBConnection
from .db import AsyncConnection, DBConnection, LanceDBConnection
from .remote.db import RemoteDBConnection
from .schema import vector # noqa: F401
@@ -167,8 +167,17 @@ async def connect_async(
conn : DBConnection
A connection to a LanceDB database.
"""
return AsyncLanceDBConnection(
if read_consistency_interval is not None:
read_consistency_interval_secs = read_consistency_interval.total_seconds()
else:
read_consistency_interval_secs = None
return AsyncConnection(
await lancedb_connect(
sanitize_uri(uri), api_key, region, host_override, read_consistency_interval
sanitize_uri(uri),
api_key,
region,
host_override,
read_consistency_interval_secs,
)
)

View File

@@ -13,6 +13,7 @@ class Connection(object):
class Table(object):
def name(self) -> str: ...
def __repr__(self) -> str: ...
async def schema(self) -> pa.Schema: ...
async def connect(

View File

@@ -17,7 +17,7 @@ import inspect
import os
from abc import abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
from typing import TYPE_CHECKING, Iterable, List, Literal, Optional, Union
import pyarrow as pa
from overrides import EnforceOverrides, override
@@ -28,7 +28,7 @@ from lancedb.embeddings.registry import EmbeddingFunctionRegistry
from lancedb.utils.events import register_event
from .pydantic import LanceModel
from .table import AsyncLanceTable, LanceTable, Table, _sanitize_data
from .table import AsyncTable, LanceTable, Table, _sanitize_data
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
if TYPE_CHECKING:
@@ -427,12 +427,64 @@ class LanceDBConnection(DBConnection):
filesystem.delete_dir(path)
class AsyncConnection(EnforceOverrides):
"""An active LanceDB connection interface."""
class AsyncConnection(object):
"""An active LanceDB connection
To obtain a connection you can use the [connect_async] function.
This could be a native connection (using lance) or a remote connection (e.g. for
connecting to LanceDB Cloud)
Local connections do not currently hold any open resources but they may do so in the
future (for example, for shared cache or connections to catalog services). Remote
connections represent an open connection to the remote server. The [close] method
can be used to release any underlying resources eagerly.
Connections can be shared on multiple threads and are expected to be long lived.
Connections can also be used as a context manager, however, in many cases a single
connection can be used for the lifetime of the application and so this is often
not needed. Closing a connection is optional. If it is not closed then it will
be automatically closed when the connection object is deleted.
Examples
--------
>>> import asyncio
>>> import lancedb
>>> async def my_connect():
... with await lancedb.connect_async("/tmp/my_dataset") as conn:
... # do something with the connection
... pass
... # conn is closed here
"""
def __init__(self, connection: LanceDbConnection):
self._inner = connection
def __repr__(self):
return self._inner.__repr__()
def __enter__(self):
    """Enter the context manager and hand back this connection.

    The value returned here is what ``with ... as conn:`` binds to ``conn``,
    so it must be the connection itself; a bare ``self`` expression without
    ``return`` would bind ``None`` instead.
    """
    return self
def __exit__(self, *_):
self.close()
def is_open(self):
"""Return True if the connection is open."""
return self._inner.is_open()
def close(self):
"""Close the connection, releasing any underlying resources.
It is safe to call this method multiple times.
Any attempt to use the connection after it is closed will result in an error."""
self._inner.close()
@abstractmethod
async def table_names(
self, *, page_token: Optional[str] = None, limit: int = 10
self, *, page_token: Optional[str] = None, limit: Optional[int] = None
) -> Iterable[str]:
"""List all tables in this database, in sorted order
@@ -450,18 +502,18 @@ class AsyncConnection(EnforceOverrides):
-------
Iterable of str
"""
pass
# TODO: hook in page_token and limit
return await self._inner.table_names()
@abstractmethod
async def create_table(
self,
name: str,
data: Optional[DATA] = None,
schema: Optional[Union[pa.Schema, LanceModel]] = None,
mode: str = "create",
exist_ok: bool = False,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
mode: Optional[Literal["create", "overwrite"]] = None,
exist_ok: Optional[bool] = None,
on_bad_vectors: Optional[str] = None,
fill_value: Optional[float] = None,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
) -> Table:
"""Create a [Table][lancedb.table.Table] in the database.
@@ -485,7 +537,7 @@ class AsyncConnection(EnforceOverrides):
- pyarrow.Schema
- [LanceModel][lancedb.pydantic.LanceModel]
mode: str; default "create"
mode: Literal["create", "overwrite"]; default "create"
The mode to use when creating the table.
Can be either "create" or "overwrite".
By default, if the table already exists, an exception is raised.
@@ -601,72 +653,6 @@ class AsyncConnection(EnforceOverrides):
LanceTable(connection=..., name="table4")
"""
raise NotImplementedError
async def open_table(self, name: str) -> Table:
"""Open a Lance Table in the database.
Parameters
----------
name: str
The name of the table.
Returns
-------
A LanceTable object representing the table.
"""
raise NotImplementedError
async def drop_table(self, name: str):
"""Drop a table from the database.
Parameters
----------
name: str
The name of the table.
"""
raise NotImplementedError
async def drop_database(self):
"""
Drop database
This is the same thing as dropping all the tables
"""
raise NotImplementedError
class AsyncLanceDBConnection(AsyncConnection):
def __init__(self, connection: LanceDbConnection):
self._inner = connection
async def __repr__(self) -> str:
pass
@override
async def table_names(
self,
*,
page_token=None,
limit=None,
) -> Iterable[str]:
# TODO: hook in page_token and limit
return await self._inner.table_names()
@override
async def create_table(
self,
name: str,
data: Optional[DATA] = None,
schema: Optional[Union[pa.Schema, LanceModel]] = None,
mode: str = "create",
exist_ok: bool = False,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
) -> Table:
if mode.lower() not in ["create", "overwrite"]:
raise ValueError("mode must be either 'create' or 'overwrite'")
if inspect.isclass(schema) and issubclass(schema, LanceModel):
# convert LanceModel to pyarrow schema
# note that it's possible this contains
@@ -681,6 +667,14 @@ class AsyncLanceDBConnection(AsyncConnection):
registry = EmbeddingFunctionRegistry.get_instance()
metadata = registry.get_table_metadata(embedding_functions)
# Defining defaults here and not in function prototype. In the future
# these defaults will move into rust so better to keep them as None.
if on_bad_vectors is None:
on_bad_vectors = "error"
if fill_value is None:
fill_value = 0.0
if data is not None:
data = _sanitize_data(
data,
@@ -708,6 +702,10 @@ class AsyncLanceDBConnection(AsyncConnection):
schema = schema.with_metadata(metadata)
validate_schema(schema)
if exist_ok is None:
exist_ok = False
if mode is None:
mode = "create"
if mode == "create" and exist_ok:
mode = "exist_ok"
@@ -722,16 +720,37 @@ class AsyncLanceDBConnection(AsyncConnection):
)
register_event("create_table")
return AsyncLanceTable(new_table)
return AsyncTable(new_table)
@override
async def open_table(self, name: str) -> LanceTable:
async def open_table(self, name: str) -> Table:
"""Open a Lance Table in the database.
Parameters
----------
name: str
The name of the table.
Returns
-------
An AsyncTable object representing the table.
"""
table = await self._inner.open_table(name)
register_event("open_table")
return AsyncTable(table)
async def drop_table(self, name: str):
"""Drop a table from the database.
Parameters
----------
name: str
The name of the table.
"""
raise NotImplementedError
@override
async def drop_table(self, name: str, ignore_missing: bool = False):
raise NotImplementedError
@override
async def drop_database(self):
"""
Drop database
This is the same thing as dropping all the tables
"""
raise NotImplementedError

View File

@@ -19,7 +19,17 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Literal,
Optional,
Tuple,
Union,
)
import lance
import numpy as np
@@ -28,7 +38,6 @@ import pyarrow.compute as pc
import pyarrow.fs as pa_fs
from lance import LanceDataset
from lance.vector import vec_to_table
from overrides import override
from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
@@ -1776,9 +1785,23 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name
return data
class AsyncTable(ABC):
class AsyncTable:
"""
A Table is a collection of Records in a LanceDB Database.
An AsyncTable is a collection of Records in a LanceDB Database.
An AsyncTable can be obtained from the
[AsyncConnection.create_table][lancedb.AsyncConnection.create_table] and
[AsyncConnection.open_table][lancedb.AsyncConnection.open_table] methods.
An AsyncTable object is expected to be long lived and reused for multiple
operations. AsyncTable objects will cache a certain amount of index data in memory.
This cache will be freed when the Table is garbage collected. To eagerly free the
cache you can call the [close][AsyncTable.close] method. Once the AsyncTable is
closed, it cannot be used for any further operations.
An AsyncTable can also be used as a context manager, and will automatically close
when the context is exited. Closing a table is optional. If you do not close the
table, it will be closed when the AsyncTable object is garbage collected.
Examples
--------
@@ -1813,21 +1836,49 @@ class AsyncTable(ABC):
[Table.create_index][lancedb.table.Table.create_index].
"""
def __init__(self, table: LanceDBTable):
"""Create a new Table object.
You should not create Table objects directly.
Use [AsyncConnection.create_table][lancedb.AsyncConnection.create_table] and
[AsyncConnection.open_table][lancedb.AsyncConnection.open_table] to obtain
Table objects."""
self._inner = table
def __repr__(self):
return self._inner.__repr__()
def __enter__(self):
return self
def __exit__(self, *_):
self.close()
def is_open(self) -> bool:
"""Return True if the table is open."""
return self._inner.is_open()
def close(self):
"""Close the table and free any resources associated with it.
It is safe to call this method multiple times.
Any attempt to use the table after it has been closed will raise an error."""
return self._inner.close()
@property
@abstractmethod
def name(self) -> str:
"""The name of the table."""
raise NotImplementedError
return self._inner.name()
@abstractmethod
async def schema(self) -> pa.Schema:
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
of this Table
"""
raise NotImplementedError
return await self._inner.schema()
@abstractmethod
async def count_rows(self, filter: Optional[str] = None) -> int:
"""
Count the number of rows in the table.
@@ -1837,7 +1888,7 @@ class AsyncTable(ABC):
filter: str, optional
A SQL where clause to filter the rows to count.
"""
raise NotImplementedError
return await self._inner.count_rows(filter)
async def to_pandas(self) -> "pd.DataFrame":
"""Return the table as a pandas DataFrame.
@@ -1848,7 +1899,6 @@ class AsyncTable(ABC):
"""
return self.to_arrow().to_pandas()
@abstractmethod
async def to_arrow(self) -> pa.Table:
"""Return the table as a pyarrow Table.
@@ -1896,7 +1946,6 @@ class AsyncTable(ABC):
"""
raise NotImplementedError
@abstractmethod
async def create_scalar_index(
self,
column: str,
@@ -1967,13 +2016,13 @@ class AsyncTable(ABC):
"""
raise NotImplementedError
@abstractmethod
async def add(
self,
data: DATA,
mode: str = "append",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
*,
mode: Optional[Literal["append", "overwrite"]] = "append",
on_bad_vectors: Optional[str] = None,
fill_value: Optional[float] = None,
):
"""Add more data to the [Table](Table).
@@ -1997,7 +2046,20 @@ class AsyncTable(ABC):
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
raise NotImplementedError
schema = await self.schema()
if on_bad_vectors is None:
on_bad_vectors = "error"
if fill_value is None:
fill_value = 0.0
data = _sanitize_data(
data,
schema,
metadata=schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
await self._inner.add(data, mode)
register_event("add")
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
"""
@@ -2059,7 +2121,6 @@ class AsyncTable(ABC):
return LanceMergeInsertBuilder(self, on)
@abstractmethod
async def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
@@ -2142,11 +2203,9 @@ class AsyncTable(ABC):
"""
raise NotImplementedError
@abstractmethod
async def _execute_query(self, query: Query) -> pa.Table:
pass
@abstractmethod
async def _do_merge(
self,
merge: LanceMergeInsertBuilder,
@@ -2156,7 +2215,6 @@ class AsyncTable(ABC):
):
pass
@abstractmethod
async def delete(self, where: str):
"""Delete rows from the table.
@@ -2207,7 +2265,6 @@ class AsyncTable(ABC):
"""
raise NotImplementedError
@abstractmethod
async def update(
self,
where: Optional[str] = None,
@@ -2263,7 +2320,6 @@ class AsyncTable(ABC):
"""
raise NotImplementedError
@abstractmethod
async def cleanup_old_versions(
self,
older_than: Optional[timedelta] = None,
@@ -2295,7 +2351,6 @@ class AsyncTable(ABC):
"""
raise NotImplementedError
@abstractmethod
async def compact_files(self, *args, **kwargs):
"""
Run the compaction process on the table.
@@ -2311,7 +2366,6 @@ class AsyncTable(ABC):
"""
raise NotImplementedError
@abstractmethod
async def add_columns(self, transforms: Dict[str, str]):
"""
Add new columns with defined values.
@@ -2327,7 +2381,6 @@ class AsyncTable(ABC):
"""
raise NotImplementedError
@abstractmethod
async def alter_columns(self, alterations: Iterable[Dict[str, str]]):
"""
Alter column names and nullability.
@@ -2350,7 +2403,6 @@ class AsyncTable(ABC):
"""
raise NotImplementedError
@abstractmethod
async def drop_columns(self, columns: Iterable[str]):
"""
Drop columns from the table.
@@ -2363,126 +2415,3 @@ class AsyncTable(ABC):
The names of the columns to drop.
"""
raise NotImplementedError
class AsyncLanceTable(AsyncTable):
def __init__(self, table: LanceDBTable):
self._inner = table
@property
@override
def name(self) -> str:
return self._inner.name()
@override
async def schema(self) -> pa.Schema:
return await self._inner.schema()
@override
async def count_rows(self, filter: Optional[str] = None) -> int:
raise NotImplementedError
async def to_pandas(self) -> "pd.DataFrame":
return self.to_arrow().to_pandas()
@override
async def to_arrow(self) -> pa.Table:
raise NotImplementedError
async def create_index(
self,
metric="L2",
num_partitions=256,
num_sub_vectors=96,
vector_column_name: str = VECTOR_COLUMN_NAME,
replace: bool = True,
accelerator: Optional[str] = None,
index_cache_size: Optional[int] = None,
):
raise NotImplementedError
@override
async def create_scalar_index(
self,
column: str,
*,
replace: bool = True,
):
raise NotImplementedError
@override
async def add(
self,
data: DATA,
mode: str = "append",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
):
raise NotImplementedError
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
on = [on] if isinstance(on, str) else list(on.iter())
return LanceMergeInsertBuilder(self, on)
@override
async def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: str = "auto",
) -> LanceQueryBuilder:
raise NotImplementedError
@override
async def _execute_query(self, query: Query) -> pa.Table:
pass
@override
async def _do_merge(
self,
merge: LanceMergeInsertBuilder,
new_data: DATA,
on_bad_vectors: str,
fill_value: float,
):
pass
@override
async def delete(self, where: str):
raise NotImplementedError
@override
async def update(
self,
where: Optional[str] = None,
values: Optional[dict] = None,
*,
values_sql: Optional[Dict[str, str]] = None,
):
raise NotImplementedError
@override
async def cleanup_old_versions(
self,
older_than: Optional[timedelta] = None,
*,
delete_unverified: bool = False,
) -> CleanupStats:
raise NotImplementedError
@override
async def compact_files(self, *args, **kwargs):
raise NotImplementedError
@override
async def add_columns(self, transforms: Dict[str, str]):
raise NotImplementedError
@override
async def alter_columns(self, alterations: Iterable[Dict[str, str]]):
raise NotImplementedError
@override
async def drop_columns(self, columns: Iterable[str]):
raise NotImplementedError

View File

@@ -11,6 +11,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from datetime import timedelta
import lancedb
import numpy as np
import pandas as pd
@@ -250,6 +253,28 @@ def test_create_exist_ok(tmp_path):
db.create_table("test", schema=bad_schema, exist_ok=True)
@pytest.mark.asyncio
async def test_connect(tmp_path):
db = await lancedb.connect_async(tmp_path)
assert str(db) == f"NativeDatabase(uri={tmp_path}, read_consistency_interval=None)"
db = await lancedb.connect_async(
tmp_path, read_consistency_interval=timedelta(seconds=5)
)
assert str(db) == f"NativeDatabase(uri={tmp_path}, read_consistency_interval=5s)"
@pytest.mark.asyncio
async def test_close(tmp_path):
db = await lancedb.connect_async(tmp_path)
assert db.is_open()
db.close()
assert not db.is_open()
with pytest.raises(RuntimeError, match="is closed"):
await db.table_names()
@pytest.mark.asyncio
async def test_create_mode_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
@@ -322,6 +347,39 @@ async def test_create_exist_ok_async(tmp_path):
# await db.create_table("test", schema=bad_schema, exist_ok=True)
@pytest.mark.asyncio
async def test_open_table(tmp_path):
db = await lancedb.connect_async(tmp_path)
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
)
await db.create_table("test", data=data)
tbl = await db.open_table("test")
assert tbl.name == "test"
assert (
re.search(
r"NativeTable\(test, uri=.*test\.lance, read_consistency_interval=None\)",
str(tbl),
)
is not None
)
assert await tbl.schema() == pa.schema(
{
"vector": pa.list_(pa.float32(), list_size=2),
"item": pa.utf8(),
"price": pa.float64(),
}
)
with pytest.raises(ValueError, match="was not found"):
await db.open_table("does_not_exist")
def test_delete_table(tmp_path):
db = lancedb.connect(tmp_path)
data = pd.DataFrame(

View File

@@ -26,8 +26,9 @@ import pandas as pd
import polars as pl
import pyarrow as pa
import pytest
import pytest_asyncio
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import LanceDBConnection
from lancedb.db import AsyncConnection, LanceDBConnection
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.table import LanceTable
@@ -49,6 +50,13 @@ def db(tmp_path) -> MockDB:
return MockDB(tmp_path)
@pytest_asyncio.fixture
async def db_async(tmp_path) -> AsyncConnection:
return await lancedb.connect_async(
tmp_path, read_consistency_interval=timedelta(seconds=0)
)
def test_basic(db):
ds = LanceTable.create(
db,
@@ -65,6 +73,18 @@ def test_basic(db):
assert table.to_lance().to_table() == ds.to_table()
@pytest.mark.asyncio
async def test_close(db_async: AsyncConnection):
table = await db_async.create_table("some_table", data=[{"id": 0}])
assert table.is_open()
table.close()
assert not table.is_open()
with pytest.raises(Exception, match="Table some_table is closed"):
await table.count_rows()
assert str(table) == "ClosedTable(some_table)"
def test_create_table(db):
schema = pa.schema(
[
@@ -186,6 +206,25 @@ def test_add_pydantic_model(db):
assert len(really_flattened.columns) == 7
@pytest.mark.asyncio
async def test_add_async(db_async: AsyncConnection):
table = await db_async.create_table(
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
assert await table.count_rows() == 2
await table.add(
data=[
{"vector": [10.0, 11.0], "item": "baz", "price": 30.0},
],
)
table = await db_async.open_table("test")
assert await table.count_rows() == 3
def test_polars(db):
data = {
"vector": [[3.1, 4.1], [5.9, 26.5]],