feat: add page_token / limit to native table_names function. Use async table_names function from sync table_names function (#1059)

The synchronous table_names function in python lancedb relies on Arrow's
filesystem, which behaves slightly differently from object_store. As a
result, the function would not work properly on GCS.

However, the async table_names function uses object_store directly and
is therefore accurate. In most cases we can fall back to using the async
table_names function, and so this PR does so. The one case where we cannot
is when the user is already in an async context (we can't start a new
event loop while one is already running). Soon, we can just redirect those
users to the async API instead of the sync API, and so that case will
eventually go away. For now, we fall back to the old behavior in that case.
This commit is contained in:
Weston Pace
2024-03-05 08:38:18 -08:00
committed by GitHub
parent 47dbb988bf
commit 9148cd6d47
21 changed files with 250 additions and 83 deletions

View File

@@ -3,7 +3,9 @@ from typing import Optional
import pyarrow as pa
class Connection(object):
async def table_names(self) -> list[str]: ...
async def table_names(
self, start_after: Optional[str], limit: Optional[int]
) -> list[str]: ...
async def create_table(
self, name: str, mode: str, data: pa.RecordBatchReader
) -> Table: ...

View File

@@ -13,6 +13,7 @@
from __future__ import annotations
import asyncio
import inspect
import os
from abc import abstractmethod
@@ -27,6 +28,7 @@ from lancedb.common import data_to_reader, validate_schema
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
from lancedb.utils.events import register_event
from ._lancedb import connect as lancedb_connect
from .pydantic import LanceModel
from .table import AsyncTable, LanceTable, Table, _sanitize_data
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
@@ -317,6 +319,10 @@ class LanceDBConnection(DBConnection):
def uri(self) -> str:
return self._uri
async def _async_get_table_names(self, start_after: Optional[str], limit: int):
conn = AsyncConnection(await lancedb_connect(self.uri))
return await conn.table_names(start_after=start_after, limit=limit)
@override
def table_names(
self, page_token: Optional[str] = None, limit: int = 10
@@ -329,23 +335,31 @@ class LanceDBConnection(DBConnection):
A list of table names.
"""
try:
filesystem = fs_from_uri(self.uri)[0]
except pa.ArrowInvalid:
raise NotImplementedError("Unsupported scheme: " + self.uri)
asyncio.get_running_loop()
# User application is async. Soon we will just tell them to use the
# async version. Until then fallback to the old sync implementation.
try:
filesystem = fs_from_uri(self.uri)[0]
except pa.ArrowInvalid:
raise NotImplementedError("Unsupported scheme: " + self.uri)
try:
loc = get_uri_location(self.uri)
paths = filesystem.get_file_info(fs.FileSelector(loc))
except FileNotFoundError:
# It is ok if the file does not exist since it will be created
paths = []
tables = [
os.path.splitext(file_info.base_name)[0]
for file_info in paths
if file_info.extension == "lance"
]
tables.sort()
return tables
try:
loc = get_uri_location(self.uri)
paths = filesystem.get_file_info(fs.FileSelector(loc))
except FileNotFoundError:
# It is ok if the file does not exist since it will be created
paths = []
tables = [
os.path.splitext(file_info.base_name)[0]
for file_info in paths
if file_info.extension == "lance"
]
tables.sort()
return tables
except RuntimeError:
# User application is sync. It is safe to use the async implementation
# under the hood.
return asyncio.run(self._async_get_table_names(page_token, limit))
def __len__(self) -> int:
return len(self.table_names())
@@ -484,26 +498,26 @@ class AsyncConnection(object):
self._inner.close()
async def table_names(
self, *, page_token: Optional[str] = None, limit: Optional[int] = None
self, *, start_after: Optional[str] = None, limit: Optional[int] = None
) -> Iterable[str]:
"""List all tables in this database, in sorted order
Parameters
----------
page_token: str, optional
The token to use for pagination. If not present, start from the beginning.
Typically, this token is last table name from the previous page.
Only supported by LanceDb Cloud.
start_after: str, optional
If present, only return names that come lexicographically after the supplied
value.
This can be combined with limit to implement pagination by setting this to
the last table name from the previous page.
limit: int, default 10
The size of the page to return.
Only supported by LanceDb Cloud.
The number of results to return.
Returns
-------
Iterable of str
"""
# TODO: hook in page_token and limit
return await self._inner.table_names()
return await self._inner.table_names(start_after=start_after, limit=limit)
async def create_table(
self,

View File

@@ -185,6 +185,10 @@ async def test_table_names_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
assert await db.table_names() == ["test1", "test2", "test3"]
assert await db.table_names(limit=1) == ["test1"]
assert await db.table_names(start_after="test1", limit=1) == ["test2"]
assert await db.table_names(start_after="test1") == ["test2", "test3"]
def test_create_mode(tmp_path):
db = lancedb.connect(tmp_path)

View File

@@ -69,11 +69,20 @@ impl Connection {
self.inner.take();
}
pub fn table_names(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
pub fn table_names(
self_: PyRef<'_, Self>,
start_after: Option<String>,
limit: Option<u32>,
) -> PyResult<&PyAny> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
inner.table_names().await.infer_error()
})
let mut op = inner.table_names();
if let Some(start_after) = start_after {
op = op.start_after(start_after);
}
if let Some(limit) = limit {
op = op.limit(limit);
}
future_into_py(self_.py(), async move { op.execute().await.infer_error() })
}
pub fn create_table<'a>(