diff --git a/python/python/lancedb/background_loop.py b/python/python/lancedb/background_loop.py index 8dd2d00f2..d132dd82d 100644 --- a/python/python/lancedb/background_loop.py +++ b/python/python/lancedb/background_loop.py @@ -22,7 +22,12 @@ class BackgroundEventLoop: self.thread.start() def run(self, future): - return asyncio.run_coroutine_threadsafe(future, self.loop).result() + concurrent_future = asyncio.run_coroutine_threadsafe(future, self.loop) + try: + return concurrent_future.result() + except BaseException: + concurrent_future.cancel() + raise LOOP = BackgroundEventLoop() diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index 566e1fba4..1447bc1fa 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -8,7 +8,7 @@ import http.server import json import threading import time -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import uuid from packaging.version import Version @@ -1203,3 +1203,22 @@ async def test_header_provider_overrides_static_headers(): extra_headers={"X-API-Key": "static-key", "X-Extra": "extra-value"}, ) as db: await db.table_names() + + +@pytest.mark.parametrize("exception", [KeyboardInterrupt, SystemExit, GeneratorExit]) +def test_background_loop_cancellation(exception): + """Test that BackgroundEventLoop.run() cancels the future on interrupt.""" + from lancedb.background_loop import BackgroundEventLoop + + mock_future = MagicMock() + mock_future.result.side_effect = exception() + + with ( + patch.object(BackgroundEventLoop, "__init__", return_value=None), + patch("asyncio.run_coroutine_threadsafe", return_value=mock_future), + ): + loop = BackgroundEventLoop() + loop.loop = MagicMock() + with pytest.raises(exception): + loop.run(None) + mock_future.cancel.assert_called_once()