From 68260395751c4789ce38f1f93bea46598f6f32c0 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 25 Nov 2024 13:12:47 -0800 Subject: [PATCH] fix(python): run remote SDK futures in background thread (#1856) Users who call the remote SDK from code that uses futures (either `ThreadPoolExecutor` or `asyncio`) can get odd errors like: ``` Traceback (most recent call last): File "/usr/lib/python3.12/asyncio/events.py", line 88, in _run self._context.run(self._callback, *self._args) RuntimeError: cannot enter context: <_contextvars.Context object at 0x7cfe94cdc900> is already entered ``` This PR fixes that by executing all LanceDB futures on a dedicated asyncio event loop running in a background thread. That way, LanceDB never touches the caller's own event loop or thread pool. --- python/pyproject.toml | 1 - .../python/lancedb/remote/background_loop.py | 25 +++++++++++ python/python/lancedb/remote/db.py | 31 +++++--------- python/python/lancedb/remote/table.py | 40 +++++++----------- python/python/tests/test_remote_db.py | 42 +++++++++++++++++++ 5 files changed, 92 insertions(+), 47 deletions(-) create mode 100644 python/python/lancedb/remote/background_loop.py diff --git a/python/pyproject.toml b/python/pyproject.toml index a60f5baa..b96047d0 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -3,7 +3,6 @@ name = "lancedb" # version in Cargo.toml dependencies = [ "deprecation", - "nest-asyncio~=1.0", "pylance==0.20.0b2", "tqdm>=4.27.0", "pydantic>=1.10", diff --git a/python/python/lancedb/remote/background_loop.py b/python/python/lancedb/remote/background_loop.py new file mode 100644 index 00000000..4b7b8632 --- /dev/null +++ b/python/python/lancedb/remote/background_loop.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors + +import asyncio +import threading + + +class BackgroundEventLoop: + """ + A background event loop that can run futures. + + Used to bridge sync and async code, without messing with users event loops. 
+ """ + + def __init__(self): + self.loop = asyncio.new_event_loop() + self.thread = threading.Thread( + target=self.loop.run_forever, + name="LanceDBBackgroundEventLoop", + daemon=True, + ) + self.thread.start() + + def run(self, future): + return asyncio.run_coroutine_threadsafe(future, self.loop).result() diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index 51ef389e..a1281739 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -11,7 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio from datetime import timedelta import logging from concurrent.futures import ThreadPoolExecutor @@ -21,6 +20,7 @@ import warnings from lancedb import connect_async from lancedb.remote import ClientConfig +from lancedb.remote.background_loop import BackgroundEventLoop import pyarrow as pa from overrides import override @@ -31,6 +31,8 @@ from ..pydantic import LanceModel from ..table import Table from ..util import validate_table_name +LOOP = BackgroundEventLoop() + class RemoteDBConnection(DBConnection): """A connection to a remote LanceDB database.""" @@ -86,18 +88,9 @@ class RemoteDBConnection(DBConnection): raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://") self.db_name = parsed.netloc - import nest_asyncio - - nest_asyncio.apply() - try: - self._loop = asyncio.get_running_loop() - except RuntimeError: - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - self.client_config = client_config - self._conn = self._loop.run_until_complete( + self._conn = LOOP.run( connect_async( db_url, api_key=api_key, @@ -127,9 +120,7 @@ class RemoteDBConnection(DBConnection): ------- An iterator of table names. 
""" - return self._loop.run_until_complete( - self._conn.table_names(start_after=page_token, limit=limit) - ) + return LOOP.run(self._conn.table_names(start_after=page_token, limit=limit)) @override def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table: @@ -152,8 +143,8 @@ class RemoteDBConnection(DBConnection): " (there is no local cache to configure)" ) - table = self._loop.run_until_complete(self._conn.open_table(name)) - return RemoteTable(table, self.db_name, self._loop) + table = LOOP.run(self._conn.open_table(name)) + return RemoteTable(table, self.db_name) @override def create_table( @@ -268,7 +259,7 @@ class RemoteDBConnection(DBConnection): from .table import RemoteTable - table = self._loop.run_until_complete( + table = LOOP.run( self._conn.create_table( name, data, @@ -278,7 +269,7 @@ class RemoteDBConnection(DBConnection): fill_value=fill_value, ) ) - return RemoteTable(table, self.db_name, self._loop) + return RemoteTable(table, self.db_name) @override def drop_table(self, name: str): @@ -289,7 +280,7 @@ class RemoteDBConnection(DBConnection): name: str The name of the table. """ - self._loop.run_until_complete(self._conn.drop_table(name)) + LOOP.run(self._conn.drop_table(name)) @override def rename_table(self, cur_name: str, new_name: str): @@ -302,7 +293,7 @@ class RemoteDBConnection(DBConnection): new_name: str The new name of the table. """ - self._loop.run_until_complete(self._conn.rename_table(cur_name, new_name)) + LOOP.run(self._conn.rename_table(cur_name, new_name)) async def close(self): """Close the connection to the database.""" diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index c1eea6e1..1dc0fbaa 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -12,12 +12,12 @@ # limitations under the License. 
from datetime import timedelta -import asyncio import logging from functools import cached_property from typing import Dict, Iterable, List, Optional, Union, Literal from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList +from lancedb.remote.db import LOOP import pyarrow as pa from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME @@ -33,9 +33,7 @@ class RemoteTable(Table): self, table: AsyncTable, db_name: str, - loop: Optional[asyncio.AbstractEventLoop] = None, ): - self._loop = loop self._table = table self.db_name = db_name @@ -56,12 +54,12 @@ class RemoteTable(Table): of this Table """ - return self._loop.run_until_complete(self._table.schema()) + return LOOP.run(self._table.schema()) @property def version(self) -> int: """Get the current version of the table""" - return self._loop.run_until_complete(self._table.version()) + return LOOP.run(self._table.version()) @cached_property def embedding_functions(self) -> dict: @@ -98,11 +96,11 @@ class RemoteTable(Table): def list_indices(self): """List all the indices on the table""" - return self._loop.run_until_complete(self._table.list_indices()) + return LOOP.run(self._table.list_indices()) def index_stats(self, index_uuid: str): """List all the stats of a specified index""" - return self._loop.run_until_complete(self._table.index_stats(index_uuid)) + return LOOP.run(self._table.index_stats(index_uuid)) def create_scalar_index( self, @@ -132,9 +130,7 @@ class RemoteTable(Table): else: raise ValueError(f"Unknown index type: {index_type}") - self._loop.run_until_complete( - self._table.create_index(column, config=config, replace=replace) - ) + LOOP.run(self._table.create_index(column, config=config, replace=replace)) def create_fts_index( self, @@ -144,9 +140,7 @@ class RemoteTable(Table): with_position: bool = True, ): config = FTS(with_position=with_position) - self._loop.run_until_complete( - self._table.create_index(column, config=config, replace=replace) - ) + 
LOOP.run(self._table.create_index(column, config=config, replace=replace)) def create_index( self, @@ -227,9 +221,7 @@ class RemoteTable(Table): " 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'" ) - self._loop.run_until_complete( - self._table.create_index(vector_column_name, config=config) - ) + LOOP.run(self._table.create_index(vector_column_name, config=config)) def add( self, @@ -261,7 +253,7 @@ class RemoteTable(Table): The value to use when filling vectors. Only used if on_bad_vectors="fill". """ - self._loop.run_until_complete( + LOOP.run( self._table.add( data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value ) @@ -349,9 +341,7 @@ class RemoteTable(Table): def _execute_query( self, query: Query, batch_size: Optional[int] = None ) -> pa.RecordBatchReader: - return self._loop.run_until_complete( - self._table._execute_query(query, batch_size=batch_size) - ) + return LOOP.run(self._table._execute_query(query, batch_size=batch_size)) def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder: """Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder] @@ -368,9 +358,7 @@ class RemoteTable(Table): on_bad_vectors: str, fill_value: float, ): - self._loop.run_until_complete( - self._table._do_merge(merge, new_data, on_bad_vectors, fill_value) - ) + LOOP.run(self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)) def delete(self, predicate: str): """Delete rows from the table. 
@@ -419,7 +407,7 @@ class RemoteTable(Table): x vector _distance # doctest: +SKIP 0 2 [3.0, 4.0] 85.0 # doctest: +SKIP """ - self._loop.run_until_complete(self._table.delete(predicate)) + LOOP.run(self._table.delete(predicate)) def update( self, @@ -469,7 +457,7 @@ class RemoteTable(Table): 2 2 [10.0, 10.0] # doctest: +SKIP """ - self._loop.run_until_complete( + LOOP.run( self._table.update(where=where, updates=values, updates_sql=values_sql) ) @@ -499,7 +487,7 @@ class RemoteTable(Table): ) def count_rows(self, filter: Optional[str] = None) -> int: - return self._loop.run_until_complete(self._table.count_rows(filter)) + return LOOP.run(self._table.count_rows(filter)) def add_columns(self, transforms: Dict[str, str]): raise NotImplementedError( diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index fbf432b1..8da9bda4 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The LanceDB Authors +from concurrent.futures import ThreadPoolExecutor import contextlib from datetime import timedelta import http.server @@ -187,6 +188,47 @@ async def test_retry_error(): assert cause.status_code == 429 +def test_table_add_in_threadpool(): + def handler(request): + if request.path == "/v1/table/test/insert/": + request.send_response(200) + request.end_headers() + elif request.path == "/v1/table/test/create/": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(b"{}") + elif request.path == "/v1/table/test/describe/": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + payload = json.dumps( + dict( + version=1, + schema=dict( + fields=[ + dict(name="id", type={"type": "int64"}, nullable=False), + ] + ), + ) + ) + request.wfile.write(payload.encode()) + else: + 
request.send_response(404) + request.end_headers() + + with mock_lancedb_connection(handler) as db: + table = db.create_table("test", [{"id": 1}]) + with ThreadPoolExecutor(3) as executor: + futures = [] + for _ in range(10): + future = executor.submit(table.add, [{"id": 1}]) + futures.append(future) + + for future in futures: + future.result() + + @contextlib.contextmanager def query_test_table(query_handler): def handler(request):