feat: connection pool for sync client

This commit is contained in:
rmeng
2024-11-25 14:18:00 -05:00
parent 2ded17452b
commit cc7a503faa
2 changed files with 112 additions and 5 deletions

View File

@@ -11,6 +11,8 @@ from datetime import date, datetime
from functools import singledispatch
from typing import Tuple, Union, Optional, Any
from urllib.parse import urlparse
from threading import Lock
from contextlib import contextmanager
import numpy as np
import pyarrow as pa
@@ -314,3 +316,27 @@ def deprecated(func):
def validate_table_name(name: str):
"""Verify the table name is valid."""
native_validate_table_name(name)
class ConnectionPool:
def __init__(self, connection_factory, *, max_size: Optional[int] = None):
self.max_size = max_size
self._connection_factory = connection_factory
self._pool = []
self._lock = Lock()
@contextmanager
def connection(self):
with self._lock:
if self._pool:
conn = self._pool.pop()
else:
conn = self._connection_factory()
# release the lock before yielding
try:
yield conn
finally:
with self._lock:
if self.max_size is None or len(self._pool) < self.max_size:
self._pool.append(conn)

View File

@@ -6,13 +6,16 @@ from datetime import timedelta
import http.server
import json
import threading
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import MagicMock
import uuid
import lancedb
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.remote import ClientConfig
from lancedb.util import ConnectionPool
from lancedb.remote.errors import HttpError, RetryError
import lancedb.util
import pytest
import pyarrow as pa
@@ -55,6 +58,34 @@ def mock_lancedb_connection(handler):
handle.join()
@contextlib.contextmanager
def mock_lancedb_connection_pool(handler):
with http.server.HTTPServer(
("localhost", 8080), make_mock_http_handler(handler)
) as server:
handle = threading.Thread(target=server.serve_forever)
handle.start()
def conn_factory():
lancedb.connect(
"db://dev",
api_key="fake",
host_override="http://localhost:8080",
client_config={
"retry_config": {"retries": 2},
"timeout_config": {
"connect_timeout": 1,
},
},
)
try:
yield ConnectionPool(conn_factory)
finally:
server.shutdown()
handle.join()
@contextlib.asynccontextmanager
async def mock_lancedb_connection_async(handler):
with http.server.HTTPServer(
@@ -187,8 +218,7 @@ async def test_retry_error():
assert cause.status_code == 429
@contextlib.contextmanager
def query_test_table(query_handler):
def http_handler(query_handler):
def handler(request):
if request.path == "/v1/table/test/describe/":
request.send_response(200)
@@ -212,7 +242,12 @@ def query_test_table(query_handler):
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
return handler
@contextlib.contextmanager
def query_test_table(connection_ctx_mgr):
with connection_ctx_mgr as db:
assert repr(db) == "RemoteConnect(name=dev)"
table = db.open_table("test")
assert repr(table) == "RemoteTable(dev.test)"
@@ -220,6 +255,7 @@ def query_test_table(query_handler):
def test_query_sync_minimal():
@http_handler
def handler(body):
assert body == {
"distance_type": "l2",
@@ -234,13 +270,53 @@ def test_query_sync_minimal():
return pa.table({"id": [1, 2, 3]})
with query_test_table(handler) as table:
with query_test_table(mock_lancedb_connection(handler)) as table:
data = table.search([1, 2, 3]).to_list()
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
assert data == expected
with query_test_table(mock_lancedb_connection_pool(handler).connection()) as table:
data = table.search([1, 2, 3]).to_list()
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
assert data == expected
def test_query_sync_minimal_threaded():
num_query = 0
@http_handler
def handler(body):
assert body == {
"distance_type": "l2",
"k": 10,
"prefilter": False,
"refine_factor": None,
"ef": None,
"vector": [1.0, 2.0, 3.0],
"nprobes": 20,
"version": None,
}
nonlocal num_query
num_query += 1
return pa.table({"id": [1, 2, 3]})
pool = mock_lancedb_connection_pool(handler)
def _query(i):
with query_test_table(pool.connection()) as table:
data = table.search([1, 2, 3]).to_list()
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
assert data == expected
with ThreadPoolExecutor as exec:
exec.map(_query, range(1000))
assert num_query == 1000
def test_query_sync_empty_query():
@http_handler
def handler(body):
assert body == {
"k": 10,
@@ -252,7 +328,12 @@ def test_query_sync_empty_query():
return pa.table({"id": [1, 2, 3]})
with query_test_table(handler) as table:
with query_test_table(mock_lancedb_connection(handler)) as table:
data = table.search(None).where("true").select(["id"]).limit(10).to_list()
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
assert data == expected
with query_test_table(mock_lancedb_connection_pool(handler).connection()) as table:
data = table.search(None).where("true").select(["id"]).limit(10).to_list()
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
assert data == expected