fix(python): run remote SDK futures in background thread (#1856)

Users who call the remote SDK from code that uses futures (either
`ThreadPoolExecutor` or `asyncio`) can get odd errors like:

```
Traceback (most recent call last):
  File "/usr/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x7cfe94cdc900> is already entered
```

This PR fixes that by executing all LanceDB futures in a dedicated
thread pool running on a background thread. That way, it doesn't
interact with their threadpool.
This commit is contained in:
Will Jones
2024-11-25 13:12:47 -08:00
committed by GitHub
parent 3e9321fc40
commit 6826039575
5 changed files with 92 additions and 47 deletions

View File

@@ -3,7 +3,6 @@ name = "lancedb"
# version in Cargo.toml
dependencies = [
"deprecation",
"nest-asyncio~=1.0",
"pylance==0.20.0b2",
"tqdm>=4.27.0",
"pydantic>=1.10",

View File

@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import asyncio
import threading
class BackgroundEventLoop:
"""
A background event loop that can run futures.
Used to bridge sync and async code, without messing with users event loops.
"""
def __init__(self):
self.loop = asyncio.new_event_loop()
self.thread = threading.Thread(
target=self.loop.run_forever,
name="LanceDBBackgroundEventLoop",
daemon=True,
)
self.thread.start()
def run(self, future):
return asyncio.run_coroutine_threadsafe(future, self.loop).result()

View File

@@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from datetime import timedelta
import logging
from concurrent.futures import ThreadPoolExecutor
@@ -21,6 +20,7 @@ import warnings
from lancedb import connect_async
from lancedb.remote import ClientConfig
from lancedb.remote.background_loop import BackgroundEventLoop
import pyarrow as pa
from overrides import override
@@ -31,6 +31,8 @@ from ..pydantic import LanceModel
from ..table import Table
from ..util import validate_table_name
LOOP = BackgroundEventLoop()
class RemoteDBConnection(DBConnection):
"""A connection to a remote LanceDB database."""
@@ -86,18 +88,9 @@ class RemoteDBConnection(DBConnection):
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
self.db_name = parsed.netloc
import nest_asyncio
nest_asyncio.apply()
try:
self._loop = asyncio.get_running_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self.client_config = client_config
self._conn = self._loop.run_until_complete(
self._conn = LOOP.run(
connect_async(
db_url,
api_key=api_key,
@@ -127,9 +120,7 @@ class RemoteDBConnection(DBConnection):
-------
An iterator of table names.
"""
return self._loop.run_until_complete(
self._conn.table_names(start_after=page_token, limit=limit)
)
return LOOP.run(self._conn.table_names(start_after=page_token, limit=limit))
@override
def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
@@ -152,8 +143,8 @@ class RemoteDBConnection(DBConnection):
" (there is no local cache to configure)"
)
table = self._loop.run_until_complete(self._conn.open_table(name))
return RemoteTable(table, self.db_name, self._loop)
table = LOOP.run(self._conn.open_table(name))
return RemoteTable(table, self.db_name)
@override
def create_table(
@@ -268,7 +259,7 @@ class RemoteDBConnection(DBConnection):
from .table import RemoteTable
table = self._loop.run_until_complete(
table = LOOP.run(
self._conn.create_table(
name,
data,
@@ -278,7 +269,7 @@ class RemoteDBConnection(DBConnection):
fill_value=fill_value,
)
)
return RemoteTable(table, self.db_name, self._loop)
return RemoteTable(table, self.db_name)
@override
def drop_table(self, name: str):
@@ -289,7 +280,7 @@ class RemoteDBConnection(DBConnection):
name: str
The name of the table.
"""
self._loop.run_until_complete(self._conn.drop_table(name))
LOOP.run(self._conn.drop_table(name))
@override
def rename_table(self, cur_name: str, new_name: str):
@@ -302,7 +293,7 @@ class RemoteDBConnection(DBConnection):
new_name: str
The new name of the table.
"""
self._loop.run_until_complete(self._conn.rename_table(cur_name, new_name))
LOOP.run(self._conn.rename_table(cur_name, new_name))
async def close(self):
"""Close the connection to the database."""

View File

@@ -12,12 +12,12 @@
# limitations under the License.
from datetime import timedelta
import asyncio
import logging
from functools import cached_property
from typing import Dict, Iterable, List, Optional, Union, Literal
from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList
from lancedb.remote.db import LOOP
import pyarrow as pa
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
@@ -33,9 +33,7 @@ class RemoteTable(Table):
self,
table: AsyncTable,
db_name: str,
loop: Optional[asyncio.AbstractEventLoop] = None,
):
self._loop = loop
self._table = table
self.db_name = db_name
@@ -56,12 +54,12 @@ class RemoteTable(Table):
of this Table
"""
return self._loop.run_until_complete(self._table.schema())
return LOOP.run(self._table.schema())
@property
def version(self) -> int:
"""Get the current version of the table"""
return self._loop.run_until_complete(self._table.version())
return LOOP.run(self._table.version())
@cached_property
def embedding_functions(self) -> dict:
@@ -98,11 +96,11 @@ class RemoteTable(Table):
def list_indices(self):
"""List all the indices on the table"""
return self._loop.run_until_complete(self._table.list_indices())
return LOOP.run(self._table.list_indices())
def index_stats(self, index_uuid: str):
"""List all the stats of a specified index"""
return self._loop.run_until_complete(self._table.index_stats(index_uuid))
return LOOP.run(self._table.index_stats(index_uuid))
def create_scalar_index(
self,
@@ -132,9 +130,7 @@ class RemoteTable(Table):
else:
raise ValueError(f"Unknown index type: {index_type}")
self._loop.run_until_complete(
self._table.create_index(column, config=config, replace=replace)
)
LOOP.run(self._table.create_index(column, config=config, replace=replace))
def create_fts_index(
self,
@@ -144,9 +140,7 @@ class RemoteTable(Table):
with_position: bool = True,
):
config = FTS(with_position=with_position)
self._loop.run_until_complete(
self._table.create_index(column, config=config, replace=replace)
)
LOOP.run(self._table.create_index(column, config=config, replace=replace))
def create_index(
self,
@@ -227,9 +221,7 @@ class RemoteTable(Table):
" 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
)
self._loop.run_until_complete(
self._table.create_index(vector_column_name, config=config)
)
LOOP.run(self._table.create_index(vector_column_name, config=config))
def add(
self,
@@ -261,7 +253,7 @@ class RemoteTable(Table):
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
self._loop.run_until_complete(
LOOP.run(
self._table.add(
data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
@@ -349,9 +341,7 @@ class RemoteTable(Table):
def _execute_query(
self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader:
return self._loop.run_until_complete(
self._table._execute_query(query, batch_size=batch_size)
)
return LOOP.run(self._table._execute_query(query, batch_size=batch_size))
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
"""Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
@@ -368,9 +358,7 @@ class RemoteTable(Table):
on_bad_vectors: str,
fill_value: float,
):
self._loop.run_until_complete(
self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
)
LOOP.run(self._table._do_merge(merge, new_data, on_bad_vectors, fill_value))
def delete(self, predicate: str):
"""Delete rows from the table.
@@ -419,7 +407,7 @@ class RemoteTable(Table):
x vector _distance # doctest: +SKIP
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
"""
self._loop.run_until_complete(self._table.delete(predicate))
LOOP.run(self._table.delete(predicate))
def update(
self,
@@ -469,7 +457,7 @@ class RemoteTable(Table):
2 2 [10.0, 10.0] # doctest: +SKIP
"""
self._loop.run_until_complete(
LOOP.run(
self._table.update(where=where, updates=values, updates_sql=values_sql)
)
@@ -499,7 +487,7 @@ class RemoteTable(Table):
)
def count_rows(self, filter: Optional[str] = None) -> int:
return self._loop.run_until_complete(self._table.count_rows(filter))
return LOOP.run(self._table.count_rows(filter))
def add_columns(self, transforms: Dict[str, str]):
raise NotImplementedError(

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from concurrent.futures import ThreadPoolExecutor
import contextlib
from datetime import timedelta
import http.server
@@ -187,6 +188,47 @@ async def test_retry_error():
assert cause.status_code == 429
def test_table_add_in_threadpool():
def handler(request):
if request.path == "/v1/table/test/insert/":
request.send_response(200)
request.end_headers()
elif request.path == "/v1/table/test/create/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"{}")
elif request.path == "/v1/table/test/describe/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(
dict(
version=1,
schema=dict(
fields=[
dict(name="id", type={"type": "int64"}, nullable=False),
]
),
)
)
request.wfile.write(payload.encode())
else:
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
table = db.create_table("test", [{"id": 1}])
with ThreadPoolExecutor(3) as executor:
futures = []
for _ in range(10):
future = executor.submit(table.add, [{"id": 1}])
futures.append(future)
for future in futures:
future.result()
@contextlib.contextmanager
def query_test_table(query_handler):
def handler(request):