mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 23:12:58 +00:00
feat: add support for add to async python API (#1037)
In order to add support for `add` we needed to migrate the rust `Table`
trait to a `Table` struct and `TableInternal` trait (similar to the way
the connection is designed).
While doing this we also cleaned up some inconsistencies between the
SDKs:
* Python and Node are garbage collected languages and it can be
difficult to trigger something to be freed. The convention for these
languages is to have some kind of close method. I added a close method
to both the table and connection which will drop the underlying rust
object.
* We made significant improvements to table creation in
cc5f2136a6
for the `node` SDK. I copied these changes to the `nodejs` SDK.
* The nodejs tables were using fs to create tmp directories and these
were not getting cleaned up. This is mostly harmless but annoying and so
I changed it up a bit to ensure we cleanup tmp directories.
* ~~countRows in the node SDK was returning `bigint`. I changed it to
return `number`~~ (this actually happened in a previous PR)
* Tables and connections now implement `std::fmt::Display` which is
hooked into python's `__repr__`. Node has no concept of a regular "to
string" function and so I added a `display` method.
* Python method signatures are changing so that optional parameters are
always `Optional[foo] = None` instead of something like `foo = False`.
This is because we want those defaults to be in rust whenever possible
(though we still need to mention the default in documentation).
* I changed the python `AsyncConnection/AsyncTable` classes from
abstract classes with a single implementation to just classes because we
no longer have the remote implementation in python.
Note: this does NOT add the `add` function to the remote table. This PR
was already large enough, and the remote implementation is unique
enough, that I am going to do all the remote stuff at a later date (we
should have the structure in place and correct so there shouldn't be any
refactor concerns)
---------
Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
@@ -21,7 +21,7 @@ __version__ = importlib.metadata.version("lancedb")
|
||||
|
||||
from ._lancedb import connect as lancedb_connect
|
||||
from .common import URI, sanitize_uri
|
||||
from .db import AsyncConnection, AsyncLanceDBConnection, DBConnection, LanceDBConnection
|
||||
from .db import AsyncConnection, DBConnection, LanceDBConnection
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .schema import vector # noqa: F401
|
||||
|
||||
@@ -167,8 +167,17 @@ async def connect_async(
|
||||
conn : DBConnection
|
||||
A connection to a LanceDB database.
|
||||
"""
|
||||
return AsyncLanceDBConnection(
|
||||
if read_consistency_interval is not None:
|
||||
read_consistency_interval_secs = read_consistency_interval.total_seconds()
|
||||
else:
|
||||
read_consistency_interval_secs = None
|
||||
|
||||
return AsyncConnection(
|
||||
await lancedb_connect(
|
||||
sanitize_uri(uri), api_key, region, host_override, read_consistency_interval
|
||||
sanitize_uri(uri),
|
||||
api_key,
|
||||
region,
|
||||
host_override,
|
||||
read_consistency_interval_secs,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -13,6 +13,7 @@ class Connection(object):
|
||||
|
||||
class Table(object):
|
||||
def name(self) -> str: ...
|
||||
def __repr__(self) -> str: ...
|
||||
async def schema(self) -> pa.Schema: ...
|
||||
|
||||
async def connect(
|
||||
|
||||
@@ -17,7 +17,7 @@ import inspect
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Iterable, List, Literal, Optional, Union
|
||||
|
||||
import pyarrow as pa
|
||||
from overrides import EnforceOverrides, override
|
||||
@@ -28,7 +28,7 @@ from lancedb.embeddings.registry import EmbeddingFunctionRegistry
|
||||
from lancedb.utils.events import register_event
|
||||
|
||||
from .pydantic import LanceModel
|
||||
from .table import AsyncLanceTable, LanceTable, Table, _sanitize_data
|
||||
from .table import AsyncTable, LanceTable, Table, _sanitize_data
|
||||
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -427,12 +427,64 @@ class LanceDBConnection(DBConnection):
|
||||
filesystem.delete_dir(path)
|
||||
|
||||
|
||||
class AsyncConnection(EnforceOverrides):
|
||||
"""An active LanceDB connection interface."""
|
||||
class AsyncConnection(object):
|
||||
"""An active LanceDB connection
|
||||
|
||||
To obtain a connection you can use the [connect] function.
|
||||
|
||||
This could be a native connection (using lance) or a remote connection (e.g. for
|
||||
connecting to LanceDb Cloud)
|
||||
|
||||
Local connections do not currently hold any open resources but they may do so in the
|
||||
future (for example, for shared cache or connections to catalog services) Remote
|
||||
connections represent an open connection to the remote server. The [close] method
|
||||
can be used to release any underlying resources eagerly. The connection can also
|
||||
be used as a context manager:
|
||||
|
||||
Connections can be shared on multiple threads and are expected to be long lived.
|
||||
Connections can also be used as a context manager, however, in many cases a single
|
||||
connection can be used for the lifetime of the application and so this is often
|
||||
not needed. Closing a connection is optional. If it is not closed then it will
|
||||
be automatically closed when the connection object is deleted.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> import asyncio
|
||||
>>> import lancedb
|
||||
>>> async def my_connect():
|
||||
... with await lancedb.connect("/tmp/my_dataset") as conn:
|
||||
... # do something with the connection
|
||||
... pass
|
||||
... # conn is closed here
|
||||
"""
|
||||
|
||||
def __init__(self, connection: LanceDbConnection):
|
||||
self._inner = connection
|
||||
|
||||
def __repr__(self):
|
||||
return self._inner.__repr__()
|
||||
|
||||
def __enter__(self):
|
||||
self
|
||||
|
||||
def __exit__(self, *_):
|
||||
self.close()
|
||||
|
||||
def is_open(self):
|
||||
"""Return True if the connection is open."""
|
||||
return self._inner.is_open()
|
||||
|
||||
def close(self):
|
||||
"""Close the connection, releasing any underlying resources.
|
||||
|
||||
It is safe to call this method multiple times.
|
||||
|
||||
Any attempt to use the connection after it is closed will result in an error."""
|
||||
self._inner.close()
|
||||
|
||||
@abstractmethod
|
||||
async def table_names(
|
||||
self, *, page_token: Optional[str] = None, limit: int = 10
|
||||
self, *, page_token: Optional[str] = None, limit: Optional[int] = None
|
||||
) -> Iterable[str]:
|
||||
"""List all tables in this database, in sorted order
|
||||
|
||||
@@ -450,18 +502,18 @@ class AsyncConnection(EnforceOverrides):
|
||||
-------
|
||||
Iterable of str
|
||||
"""
|
||||
pass
|
||||
# TODO: hook in page_token and limit
|
||||
return await self._inner.table_names()
|
||||
|
||||
@abstractmethod
|
||||
async def create_table(
|
||||
self,
|
||||
name: str,
|
||||
data: Optional[DATA] = None,
|
||||
schema: Optional[Union[pa.Schema, LanceModel]] = None,
|
||||
mode: str = "create",
|
||||
exist_ok: bool = False,
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
mode: Optional[Literal["create", "overwrite"]] = None,
|
||||
exist_ok: Optional[bool] = None,
|
||||
on_bad_vectors: Optional[str] = None,
|
||||
fill_value: Optional[float] = None,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
) -> Table:
|
||||
"""Create a [Table][lancedb.table.Table] in the database.
|
||||
@@ -485,7 +537,7 @@ class AsyncConnection(EnforceOverrides):
|
||||
- pyarrow.Schema
|
||||
|
||||
- [LanceModel][lancedb.pydantic.LanceModel]
|
||||
mode: str; default "create"
|
||||
mode: Literal["create", "overwrite"]; default "create"
|
||||
The mode to use when creating the table.
|
||||
Can be either "create" or "overwrite".
|
||||
By default, if the table already exists, an exception is raised.
|
||||
@@ -601,72 +653,6 @@ class AsyncConnection(EnforceOverrides):
|
||||
LanceTable(connection=..., name="table4")
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def open_table(self, name: str) -> Table:
|
||||
"""Open a Lance Table in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A LanceTable object representing the table.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def drop_table(self, name: str):
|
||||
"""Drop a table from the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def drop_database(self):
|
||||
"""
|
||||
Drop database
|
||||
This is the same thing as dropping all the tables
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AsyncLanceDBConnection(AsyncConnection):
|
||||
def __init__(self, connection: LanceDbConnection):
|
||||
self._inner = connection
|
||||
|
||||
async def __repr__(self) -> str:
|
||||
pass
|
||||
|
||||
@override
|
||||
async def table_names(
|
||||
self,
|
||||
*,
|
||||
page_token=None,
|
||||
limit=None,
|
||||
) -> Iterable[str]:
|
||||
# TODO: hook in page_token and limit
|
||||
return await self._inner.table_names()
|
||||
|
||||
@override
|
||||
async def create_table(
|
||||
self,
|
||||
name: str,
|
||||
data: Optional[DATA] = None,
|
||||
schema: Optional[Union[pa.Schema, LanceModel]] = None,
|
||||
mode: str = "create",
|
||||
exist_ok: bool = False,
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
) -> Table:
|
||||
if mode.lower() not in ["create", "overwrite"]:
|
||||
raise ValueError("mode must be either 'create' or 'overwrite'")
|
||||
|
||||
if inspect.isclass(schema) and issubclass(schema, LanceModel):
|
||||
# convert LanceModel to pyarrow schema
|
||||
# note that it's possible this contains
|
||||
@@ -681,6 +667,14 @@ class AsyncLanceDBConnection(AsyncConnection):
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
metadata = registry.get_table_metadata(embedding_functions)
|
||||
|
||||
# Defining defaults here and not in function prototype. In the future
|
||||
# these defaults will move into rust so better to keep them as None.
|
||||
if on_bad_vectors is None:
|
||||
on_bad_vectors = "error"
|
||||
|
||||
if fill_value is None:
|
||||
fill_value = 0.0
|
||||
|
||||
if data is not None:
|
||||
data = _sanitize_data(
|
||||
data,
|
||||
@@ -708,6 +702,10 @@ class AsyncLanceDBConnection(AsyncConnection):
|
||||
schema = schema.with_metadata(metadata)
|
||||
validate_schema(schema)
|
||||
|
||||
if exist_ok is None:
|
||||
exist_ok = False
|
||||
if mode is None:
|
||||
mode = "create"
|
||||
if mode == "create" and exist_ok:
|
||||
mode = "exist_ok"
|
||||
|
||||
@@ -722,16 +720,37 @@ class AsyncLanceDBConnection(AsyncConnection):
|
||||
)
|
||||
|
||||
register_event("create_table")
|
||||
return AsyncLanceTable(new_table)
|
||||
return AsyncTable(new_table)
|
||||
|
||||
@override
|
||||
async def open_table(self, name: str) -> LanceTable:
|
||||
async def open_table(self, name: str) -> Table:
|
||||
"""Open a Lance Table in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A LanceTable object representing the table.
|
||||
"""
|
||||
table = await self._inner.open_table(name)
|
||||
register_event("open_table")
|
||||
return AsyncTable(table)
|
||||
|
||||
async def drop_table(self, name: str):
|
||||
"""Drop a table from the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def drop_table(self, name: str, ignore_missing: bool = False):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def drop_database(self):
|
||||
"""
|
||||
Drop database
|
||||
This is the same thing as dropping all the tables
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -19,7 +19,17 @@ from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import lance
|
||||
import numpy as np
|
||||
@@ -28,7 +38,6 @@ import pyarrow.compute as pc
|
||||
import pyarrow.fs as pa_fs
|
||||
from lance import LanceDataset
|
||||
from lance.vector import vec_to_table
|
||||
from overrides import override
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
@@ -1776,9 +1785,23 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
return data
|
||||
|
||||
|
||||
class AsyncTable(ABC):
|
||||
class AsyncTable:
|
||||
"""
|
||||
A Table is a collection of Records in a LanceDB Database.
|
||||
An AsyncTable is a collection of Records in a LanceDB Database.
|
||||
|
||||
An AsyncTable can be obtained from the
|
||||
[AsyncConnection.create_table][lancedb.AsyncConnection.create_table] and
|
||||
[AsyncConnection.open_table][lancedb.AsyncConnection.open_table] methods.
|
||||
|
||||
An AsyncTable object is expected to be long lived and reused for multiple
|
||||
operations. AsyncTable objects will cache a certain amount of index data in memory.
|
||||
This cache will be freed when the Table is garbage collected. To eagerly free the
|
||||
cache you can call the [close][AsyncTable.close] method. Once the AsyncTable is
|
||||
closed, it cannot be used for any further operations.
|
||||
|
||||
An AsyncTable can also be used as a context manager, and will automatically close
|
||||
when the context is exited. Closing a table is optional. If you do not close the
|
||||
table, it will be closed when the AsyncTable object is garbage collected.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -1813,21 +1836,49 @@ class AsyncTable(ABC):
|
||||
[Table.create_index][lancedb.table.Table.create_index].
|
||||
"""
|
||||
|
||||
def __init__(self, table: LanceDBTable):
|
||||
"""Create a new Table object.
|
||||
|
||||
You should not create Table objects directly.
|
||||
|
||||
Use [AsyncConnection.create_table][lancedb.AsyncConnection.create_table] and
|
||||
[AsyncConnection.open_table][lancedb.AsyncConnection.open_table] to obtain
|
||||
Table objects."""
|
||||
self._inner = table
|
||||
|
||||
def __repr__(self):
|
||||
return self._inner.__repr__()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
self.close()
|
||||
|
||||
def is_open(self) -> bool:
|
||||
"""Return True if the table is closed."""
|
||||
return self._inner.is_open()
|
||||
|
||||
def close(self):
|
||||
"""Close the table and free any resources associated with it.
|
||||
|
||||
It is safe to call this method multiple times.
|
||||
|
||||
Any attempt to use the table after it has been closed will raise an error."""
|
||||
return self._inner.close()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""The name of the table."""
|
||||
raise NotImplementedError
|
||||
return self._inner.name()
|
||||
|
||||
@abstractmethod
|
||||
async def schema(self) -> pa.Schema:
|
||||
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
|
||||
of this Table
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
return await self._inner.schema()
|
||||
|
||||
@abstractmethod
|
||||
async def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
"""
|
||||
Count the number of rows in the table.
|
||||
@@ -1837,7 +1888,7 @@ class AsyncTable(ABC):
|
||||
filter: str, optional
|
||||
A SQL where clause to filter the rows to count.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
return await self._inner.count_rows(filter)
|
||||
|
||||
async def to_pandas(self) -> "pd.DataFrame":
|
||||
"""Return the table as a pandas DataFrame.
|
||||
@@ -1848,7 +1899,6 @@ class AsyncTable(ABC):
|
||||
"""
|
||||
return self.to_arrow().to_pandas()
|
||||
|
||||
@abstractmethod
|
||||
async def to_arrow(self) -> pa.Table:
|
||||
"""Return the table as a pyarrow Table.
|
||||
|
||||
@@ -1896,7 +1946,6 @@ class AsyncTable(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def create_scalar_index(
|
||||
self,
|
||||
column: str,
|
||||
@@ -1967,13 +2016,13 @@ class AsyncTable(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def add(
|
||||
self,
|
||||
data: DATA,
|
||||
mode: str = "append",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
*,
|
||||
mode: Optional[Literal["append", "overwrite"]] = "append",
|
||||
on_bad_vectors: Optional[str] = None,
|
||||
fill_value: Optional[float] = None,
|
||||
):
|
||||
"""Add more data to the [Table](Table).
|
||||
|
||||
@@ -1997,7 +2046,20 @@ class AsyncTable(ABC):
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
schema = await self.schema()
|
||||
if on_bad_vectors is None:
|
||||
on_bad_vectors = "error"
|
||||
if fill_value is None:
|
||||
fill_value = 0.0
|
||||
data = _sanitize_data(
|
||||
data,
|
||||
schema,
|
||||
metadata=schema.metadata,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
await self._inner.add(data, mode)
|
||||
register_event("add")
|
||||
|
||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||
"""
|
||||
@@ -2059,7 +2121,6 @@ class AsyncTable(ABC):
|
||||
|
||||
return LanceMergeInsertBuilder(self, on)
|
||||
|
||||
@abstractmethod
|
||||
async def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
@@ -2142,11 +2203,9 @@ class AsyncTable(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def _execute_query(self, query: Query) -> pa.Table:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def _do_merge(
|
||||
self,
|
||||
merge: LanceMergeInsertBuilder,
|
||||
@@ -2156,7 +2215,6 @@ class AsyncTable(ABC):
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, where: str):
|
||||
"""Delete rows from the table.
|
||||
|
||||
@@ -2207,7 +2265,6 @@ class AsyncTable(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def update(
|
||||
self,
|
||||
where: Optional[str] = None,
|
||||
@@ -2263,7 +2320,6 @@ class AsyncTable(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup_old_versions(
|
||||
self,
|
||||
older_than: Optional[timedelta] = None,
|
||||
@@ -2295,7 +2351,6 @@ class AsyncTable(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def compact_files(self, *args, **kwargs):
|
||||
"""
|
||||
Run the compaction process on the table.
|
||||
@@ -2311,7 +2366,6 @@ class AsyncTable(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def add_columns(self, transforms: Dict[str, str]):
|
||||
"""
|
||||
Add new columns with defined values.
|
||||
@@ -2327,7 +2381,6 @@ class AsyncTable(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def alter_columns(self, alterations: Iterable[Dict[str, str]]):
|
||||
"""
|
||||
Alter column names and nullability.
|
||||
@@ -2350,7 +2403,6 @@ class AsyncTable(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def drop_columns(self, columns: Iterable[str]):
|
||||
"""
|
||||
Drop columns from the table.
|
||||
@@ -2363,126 +2415,3 @@ class AsyncTable(ABC):
|
||||
The names of the columns to drop.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AsyncLanceTable(AsyncTable):
|
||||
def __init__(self, table: LanceDBTable):
|
||||
self._inner = table
|
||||
|
||||
@property
|
||||
@override
|
||||
def name(self) -> str:
|
||||
return self._inner.name()
|
||||
|
||||
@override
|
||||
async def schema(self) -> pa.Schema:
|
||||
return await self._inner.schema()
|
||||
|
||||
@override
|
||||
async def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
async def to_pandas(self) -> "pd.DataFrame":
|
||||
return self.to_arrow().to_pandas()
|
||||
|
||||
@override
|
||||
async def to_arrow(self) -> pa.Table:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_index(
|
||||
self,
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
replace: bool = True,
|
||||
accelerator: Optional[str] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def create_scalar_index(
|
||||
self,
|
||||
column: str,
|
||||
*,
|
||||
replace: bool = True,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def add(
|
||||
self,
|
||||
data: DATA,
|
||||
mode: str = "append",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||
on = [on] if isinstance(on, str) else list(on.iter())
|
||||
|
||||
return LanceMergeInsertBuilder(self, on)
|
||||
|
||||
@override
|
||||
async def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: str = "auto",
|
||||
) -> LanceQueryBuilder:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def _execute_query(self, query: Query) -> pa.Table:
|
||||
pass
|
||||
|
||||
@override
|
||||
async def _do_merge(
|
||||
self,
|
||||
merge: LanceMergeInsertBuilder,
|
||||
new_data: DATA,
|
||||
on_bad_vectors: str,
|
||||
fill_value: float,
|
||||
):
|
||||
pass
|
||||
|
||||
@override
|
||||
async def delete(self, where: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def update(
|
||||
self,
|
||||
where: Optional[str] = None,
|
||||
values: Optional[dict] = None,
|
||||
*,
|
||||
values_sql: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def cleanup_old_versions(
|
||||
self,
|
||||
older_than: Optional[timedelta] = None,
|
||||
*,
|
||||
delete_unverified: bool = False,
|
||||
) -> CleanupStats:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def compact_files(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def add_columns(self, transforms: Dict[str, str]):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def alter_columns(self, alterations: Iterable[Dict[str, str]]):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def drop_columns(self, columns: Iterable[str]):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -11,6 +11,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from datetime import timedelta
|
||||
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -250,6 +253,28 @@ def test_create_exist_ok(tmp_path):
|
||||
db.create_table("test", schema=bad_schema, exist_ok=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
assert str(db) == f"NativeDatabase(uri={tmp_path}, read_consistency_interval=None)"
|
||||
|
||||
db = await lancedb.connect_async(
|
||||
tmp_path, read_consistency_interval=timedelta(seconds=5)
|
||||
)
|
||||
assert str(db) == f"NativeDatabase(uri={tmp_path}, read_consistency_interval=5s)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
assert db.is_open()
|
||||
db.close()
|
||||
assert not db.is_open()
|
||||
|
||||
with pytest.raises(RuntimeError, match="is closed"):
|
||||
await db.table_names()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_mode_async(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
@@ -322,6 +347,39 @@ async def test_create_exist_ok_async(tmp_path):
|
||||
# await db.create_table("test", schema=bad_schema, exist_ok=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_open_table(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
"item": ["foo", "bar"],
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
await db.create_table("test", data=data)
|
||||
|
||||
tbl = await db.open_table("test")
|
||||
assert tbl.name == "test"
|
||||
assert (
|
||||
re.search(
|
||||
r"NativeTable\(test, uri=.*test\.lance, read_consistency_interval=None\)",
|
||||
str(tbl),
|
||||
)
|
||||
is not None
|
||||
)
|
||||
assert await tbl.schema() == pa.schema(
|
||||
{
|
||||
"vector": pa.list_(pa.float32(), list_size=2),
|
||||
"item": pa.utf8(),
|
||||
"price": pa.float64(),
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="was not found"):
|
||||
await db.open_table("does_not_exist")
|
||||
|
||||
|
||||
def test_delete_table(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
data = pd.DataFrame(
|
||||
|
||||
@@ -26,8 +26,9 @@ import pandas as pd
|
||||
import polars as pl
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.db import AsyncConnection, LanceDBConnection
|
||||
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.table import LanceTable
|
||||
@@ -49,6 +50,13 @@ def db(tmp_path) -> MockDB:
|
||||
return MockDB(tmp_path)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_async(tmp_path) -> AsyncConnection:
|
||||
return await lancedb.connect_async(
|
||||
tmp_path, read_consistency_interval=timedelta(seconds=0)
|
||||
)
|
||||
|
||||
|
||||
def test_basic(db):
|
||||
ds = LanceTable.create(
|
||||
db,
|
||||
@@ -65,6 +73,18 @@ def test_basic(db):
|
||||
assert table.to_lance().to_table() == ds.to_table()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close(db_async: AsyncConnection):
|
||||
table = await db_async.create_table("some_table", data=[{"id": 0}])
|
||||
assert table.is_open()
|
||||
table.close()
|
||||
assert not table.is_open()
|
||||
|
||||
with pytest.raises(Exception, match="Table some_table is closed"):
|
||||
await table.count_rows()
|
||||
assert str(table) == "ClosedTable(some_table)"
|
||||
|
||||
|
||||
def test_create_table(db):
|
||||
schema = pa.schema(
|
||||
[
|
||||
@@ -186,6 +206,25 @@ def test_add_pydantic_model(db):
|
||||
assert len(really_flattened.columns) == 7
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_async(db_async: AsyncConnection):
|
||||
table = await db_async.create_table(
|
||||
"test",
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
||||
],
|
||||
)
|
||||
assert await table.count_rows() == 2
|
||||
await table.add(
|
||||
data=[
|
||||
{"vector": [10.0, 11.0], "item": "baz", "price": 30.0},
|
||||
],
|
||||
)
|
||||
table = await db_async.open_table("test")
|
||||
assert await table.count_rows() == 3
|
||||
|
||||
|
||||
def test_polars(db):
|
||||
data = {
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
|
||||
Reference in New Issue
Block a user