diff --git a/python/python/lancedb/background_loop.py b/python/python/lancedb/background_loop.py index b39da229d..9e1fdfbd8 100644 --- a/python/python/lancedb/background_loop.py +++ b/python/python/lancedb/background_loop.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright The LanceDB Authors import asyncio +import concurrent.futures import os import threading import warnings @@ -37,6 +38,24 @@ class BackgroundEventLoop: LOOP = BackgroundEventLoop() + +def _new_embedding_executor() -> concurrent.futures.ThreadPoolExecutor: + return concurrent.futures.ThreadPoolExecutor(thread_name_prefix="lancedb-embedding") + + +# Embedding functions can block for a long time -- a heavy local model or an +# HTTP request to a remote embeddings API. Running them on asyncio's default +# executor lets them starve the unrelated blocking I/O that shares that pool, +# so they get a dedicated one. See +# https://github.com/lancedb/lancedb/issues/3310. +_EMBEDDING_EXECUTOR = _new_embedding_executor() + + +def embedding_executor() -> concurrent.futures.ThreadPoolExecutor: + """Return the executor dedicated to running blocking embedding calls.""" + return _EMBEDDING_EXECUTOR + + _FORK_WARNED = False @@ -47,6 +66,12 @@ def _reset_after_fork(): # the new state. The Rust-side tokio runtime is reset analogously by a # pthread_atfork hook installed in the _lancedb extension. LOOP._start() + # The embedding executor's worker threads are dead in the child as well. + # Replace it with a fresh pool (threads are spawned lazily, so this is + # cheap); we don't shut down the old one, since joining its dead workers + # could hang. + global _EMBEDDING_EXECUTOR + _EMBEDDING_EXECUTOR = _new_embedding_executor() global _FORK_WARNED if not _FORK_WARNED: _FORK_WARNED = True diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index eac48206b..893023060 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -30,7 +30,7 @@ from lancedb.scannable import _register_optional_converters, to_scannable from . import __version__ from lancedb.arrow import peek_reader -from lancedb.background_loop import LOOP +from lancedb.background_loop import LOOP, embedding_executor from .dependencies import ( _check_for_hugging_face, _check_for_lance, @@ -4908,10 +4908,13 @@ class AsyncTable: if embedding is not None: loop = asyncio.get_running_loop() # This function is likely to block, since it either calls an expensive - # function or makes an HTTP request to an embeddings REST API. + # function or makes an HTTP request to an embeddings REST API. Run it + # on a dedicated executor so it can't starve the default executor that + # other blocking I/O shares. See + # https://github.com/lancedb/lancedb/issues/3310. return ( await loop.run_in_executor( - None, + embedding_executor(), embedding.function.compute_query_embeddings_with_retry, query, ) diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index ade086a35..58c085f4e 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -4,6 +4,7 @@ import os import sys +import threading import warnings from datetime import date, datetime, timedelta from time import sleep @@ -2837,3 +2838,38 @@ def test_sanitize_data_metadata_not_stripped(): assert result_schema.metadata is not None assert result_schema.metadata[b"existing_key"] == b"existing_value" assert result_schema.metadata[b"new_key"] == b"new_value" + + +@pytest.mark.asyncio +async def test_async_search_runs_embedding_on_dedicated_executor( + mem_db_async: AsyncConnection, +): + # Regression test for #3310: AsyncTable.search() must run the (potentially + # blocking) query-embedding call on the dedicated embedding executor, not + # asyncio's default executor -- which is shared with other blocking I/O and + # can be starved by a slow embedding call under concurrent load. + func = MockTextEmbeddingFunction.create() + + class Schema(LanceModel): + text: str = func.SourceField() + vector: Vector(func.ndims()) = func.VectorField() + + table = await mem_db_async.create_table("embed_executor", schema=Schema) + await table.add([{"text": "hello world"}]) + + captured_threads: List[str] = [] + original = MockTextEmbeddingFunction.generate_embeddings + + def record_thread(self, texts): + captured_threads.append(threading.current_thread().name) + return original(self, texts) + + # Patch only around the search so we capture the query-embedding call, not + # the add-time source-embedding call. + with patch.object(MockTextEmbeddingFunction, "generate_embeddings", record_thread): + await (await table.search("a query string")).limit(1).to_list() + + assert captured_threads, "search did not invoke the embedding function" + assert all(name.startswith("lancedb-embedding") for name in captured_threads), ( + f"embedding ran off the dedicated executor: {captured_threads}" + )