diff --git a/python/python/lancedb/util.py b/python/python/lancedb/util.py index 3e12efd3..80dffcea 100644 --- a/python/python/lancedb/util.py +++ b/python/python/lancedb/util.py @@ -11,6 +11,8 @@ from datetime import date, datetime from functools import singledispatch from typing import Tuple, Union, Optional, Any from urllib.parse import urlparse +from threading import Lock +from contextlib import contextmanager import numpy as np import pyarrow as pa @@ -314,3 +316,27 @@ def deprecated(func): def validate_table_name(name: str): """Verify the table name is valid.""" native_validate_table_name(name) + + +class ConnectionPool: + def __init__(self, connection_factory, *, max_size: Optional[int] = None): + self.max_size = max_size + self._connection_factory = connection_factory + self._pool = [] + self._lock = Lock() + + @contextmanager + def connection(self): + with self._lock: + if self._pool: + conn = self._pool.pop() + else: + conn = self._connection_factory() + + # release the lock before yielding + try: + yield conn + finally: + with self._lock: + if self.max_size is None or len(self._pool) < self.max_size: + self._pool.append(conn) diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index fbf432b1..81798eb7 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -6,13 +6,16 @@ from datetime import timedelta import http.server import json import threading +from concurrent.futures import ThreadPoolExecutor from unittest.mock import MagicMock import uuid import lancedb from lancedb.conftest import MockTextEmbeddingFunction from lancedb.remote import ClientConfig +from lancedb.util import ConnectionPool from lancedb.remote.errors import HttpError, RetryError +import lancedb.util import pytest import pyarrow as pa @@ -55,6 +58,34 @@ def mock_lancedb_connection(handler): handle.join() +@contextlib.contextmanager +def mock_lancedb_connection_pool(handler): + with http.server.HTTPServer( + ("localhost", 8080), make_mock_http_handler(handler) + ) as server: + handle = threading.Thread(target=server.serve_forever) + handle.start() + + def conn_factory(): + lancedb.connect( + "db://dev", + api_key="fake", + host_override="http://localhost:8080", + client_config={ + "retry_config": {"retries": 2}, + "timeout_config": { + "connect_timeout": 1, + }, + }, + ) + + try: + yield ConnectionPool(conn_factory) + finally: + server.shutdown() + handle.join() + + @contextlib.asynccontextmanager async def mock_lancedb_connection_async(handler): with http.server.HTTPServer( @@ -187,8 +218,7 @@ async def test_retry_error(): assert cause.status_code == 429 -@contextlib.contextmanager -def query_test_table(query_handler): +def http_handler(query_handler): def handler(request): if request.path == "/v1/table/test/describe/": request.send_response(200) @@ -212,7 +242,12 @@ def query_test_table(query_handler): request.send_response(404) request.end_headers() - with mock_lancedb_connection(handler) as db: + return handler + + +@contextlib.contextmanager +def query_test_table(connection_ctx_mgr): + with connection_ctx_mgr as db: assert repr(db) == "RemoteConnect(name=dev)" table = db.open_table("test") assert repr(table) == "RemoteTable(dev.test)" @@ -220,6 +255,7 @@ def query_test_table(query_handler): def test_query_sync_minimal(): + @http_handler def handler(body): assert body == { "distance_type": "l2", @@ -234,13 +270,53 @@ def test_query_sync_minimal(): return pa.table({"id": [1, 2, 3]}) - with query_test_table(handler) as table: + with query_test_table(mock_lancedb_connection(handler)) as table: + data = table.search([1, 2, 3]).to_list() + expected = [{"id": 1}, {"id": 2}, {"id": 3}] + assert data == expected + + with query_test_table(mock_lancedb_connection_pool(handler).connection()) as table: data = table.search([1, 2, 3]).to_list() expected = [{"id": 1}, {"id": 2}, {"id": 3}] assert data == expected +def test_query_sync_minimal_threaded(): + num_query = 0 + + @http_handler + def handler(body): + assert body == { + "distance_type": "l2", + "k": 10, + "prefilter": False, + "refine_factor": None, + "ef": None, + "vector": [1.0, 2.0, 3.0], + "nprobes": 20, + "version": None, + } + nonlocal num_query + num_query += 1 + + return pa.table({"id": [1, 2, 3]}) + + pool = mock_lancedb_connection_pool(handler) + + def _query(i): + with query_test_table(pool.connection()) as table: + data = table.search([1, 2, 3]).to_list() + expected = [{"id": 1}, {"id": 2}, {"id": 3}] + assert data == expected + + with ThreadPoolExecutor as exec: + exec.map(_query, range(1000)) + + assert num_query == 1000 + + def test_query_sync_empty_query(): + @http_handler def handler(body): assert body == { "k": 10, @@ -252,7 +328,12 @@ def test_query_sync_empty_query(): return pa.table({"id": [1, 2, 3]}) - with query_test_table(handler) as table: + with query_test_table(mock_lancedb_connection(handler)) as table: + data = table.search(None).where("true").select(["id"]).limit(10).to_list() + expected = [{"id": 1}, {"id": 2}, {"id": 3}] + assert data == expected + + with query_test_table(mock_lancedb_connection_pool(handler).connection()) as table: data = table.search(None).where("true").select(["id"]).limit(10).to_list() expected = [{"id": 1}, {"id": 2}, {"id": 3}] assert data == expected