mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-22 21:09:58 +00:00
feat: connection pool for sync client
This commit is contained in:
@@ -11,6 +11,8 @@ from datetime import date, datetime
|
||||
from functools import singledispatch
|
||||
from typing import Tuple, Union, Optional, Any
|
||||
from urllib.parse import urlparse
|
||||
from threading import Lock
|
||||
from contextlib import contextmanager
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
@@ -314,3 +316,27 @@ def deprecated(func):
|
||||
def validate_table_name(name: str):
|
||||
"""Verify the table name is valid."""
|
||||
native_validate_table_name(name)
|
||||
|
||||
|
||||
class ConnectionPool:
|
||||
def __init__(self, connection_factory, *, max_size: Optional[int] = None):
|
||||
self.max_size = max_size
|
||||
self._connection_factory = connection_factory
|
||||
self._pool = []
|
||||
self._lock = Lock()
|
||||
|
||||
@contextmanager
|
||||
def connection(self):
|
||||
with self._lock:
|
||||
if self._pool:
|
||||
conn = self._pool.pop()
|
||||
else:
|
||||
conn = self._connection_factory()
|
||||
|
||||
# release the lock before yielding
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
with self._lock:
|
||||
if self.max_size is None or len(self._pool) < self.max_size:
|
||||
self._pool.append(conn)
|
||||
|
||||
@@ -6,13 +6,16 @@ from datetime import timedelta
|
||||
import http.server
|
||||
import json
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import MagicMock
|
||||
import uuid
|
||||
|
||||
import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.remote import ClientConfig
|
||||
from lancedb.util import ConnectionPool
|
||||
from lancedb.remote.errors import HttpError, RetryError
|
||||
import lancedb.util
|
||||
import pytest
|
||||
import pyarrow as pa
|
||||
|
||||
@@ -55,6 +58,34 @@ def mock_lancedb_connection(handler):
|
||||
handle.join()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mock_lancedb_connection_pool(handler):
|
||||
with http.server.HTTPServer(
|
||||
("localhost", 8080), make_mock_http_handler(handler)
|
||||
) as server:
|
||||
handle = threading.Thread(target=server.serve_forever)
|
||||
handle.start()
|
||||
|
||||
def conn_factory():
|
||||
lancedb.connect(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override="http://localhost:8080",
|
||||
client_config={
|
||||
"retry_config": {"retries": 2},
|
||||
"timeout_config": {
|
||||
"connect_timeout": 1,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
yield ConnectionPool(conn_factory)
|
||||
finally:
|
||||
server.shutdown()
|
||||
handle.join()
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def mock_lancedb_connection_async(handler):
|
||||
with http.server.HTTPServer(
|
||||
@@ -187,8 +218,7 @@ async def test_retry_error():
|
||||
assert cause.status_code == 429
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def query_test_table(query_handler):
|
||||
def http_handler(query_handler):
|
||||
def handler(request):
|
||||
if request.path == "/v1/table/test/describe/":
|
||||
request.send_response(200)
|
||||
@@ -212,7 +242,12 @@ def query_test_table(query_handler):
|
||||
request.send_response(404)
|
||||
request.end_headers()
|
||||
|
||||
with mock_lancedb_connection(handler) as db:
|
||||
return handler
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def query_test_table(connection_ctx_mgr):
|
||||
with connection_ctx_mgr as db:
|
||||
assert repr(db) == "RemoteConnect(name=dev)"
|
||||
table = db.open_table("test")
|
||||
assert repr(table) == "RemoteTable(dev.test)"
|
||||
@@ -220,6 +255,7 @@ def query_test_table(query_handler):
|
||||
|
||||
|
||||
def test_query_sync_minimal():
|
||||
@http_handler
|
||||
def handler(body):
|
||||
assert body == {
|
||||
"distance_type": "l2",
|
||||
@@ -234,13 +270,53 @@ def test_query_sync_minimal():
|
||||
|
||||
return pa.table({"id": [1, 2, 3]})
|
||||
|
||||
with query_test_table(handler) as table:
|
||||
with query_test_table(mock_lancedb_connection(handler)) as table:
|
||||
data = table.search([1, 2, 3]).to_list()
|
||||
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
|
||||
assert data == expected
|
||||
|
||||
with query_test_table(mock_lancedb_connection_pool(handler).connection()) as table:
|
||||
data = table.search([1, 2, 3]).to_list()
|
||||
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
|
||||
assert data == expected
|
||||
|
||||
|
||||
def test_query_sync_minimal_threaded():
|
||||
num_query = 0
|
||||
|
||||
@http_handler
|
||||
def handler(body):
|
||||
assert body == {
|
||||
"distance_type": "l2",
|
||||
"k": 10,
|
||||
"prefilter": False,
|
||||
"refine_factor": None,
|
||||
"ef": None,
|
||||
"vector": [1.0, 2.0, 3.0],
|
||||
"nprobes": 20,
|
||||
"version": None,
|
||||
}
|
||||
nonlocal num_query
|
||||
num_query += 1
|
||||
|
||||
return pa.table({"id": [1, 2, 3]})
|
||||
|
||||
pool = mock_lancedb_connection_pool(handler)
|
||||
|
||||
def _query(i):
|
||||
with query_test_table(pool.connection()) as table:
|
||||
data = table.search([1, 2, 3]).to_list()
|
||||
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
|
||||
assert data == expected
|
||||
|
||||
with ThreadPoolExecutor as exec:
|
||||
exec.map(_query, range(1000))
|
||||
|
||||
assert num_query == 1000
|
||||
|
||||
|
||||
def test_query_sync_empty_query():
|
||||
@http_handler
|
||||
def handler(body):
|
||||
assert body == {
|
||||
"k": 10,
|
||||
@@ -252,7 +328,12 @@ def test_query_sync_empty_query():
|
||||
|
||||
return pa.table({"id": [1, 2, 3]})
|
||||
|
||||
with query_test_table(handler) as table:
|
||||
with query_test_table(mock_lancedb_connection(handler)) as table:
|
||||
data = table.search(None).where("true").select(["id"]).limit(10).to_list()
|
||||
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
|
||||
assert data == expected
|
||||
|
||||
with query_test_table(mock_lancedb_connection_pool(handler).connection()) as table:
|
||||
data = table.search(None).where("true").select(["id"]).limit(10).to_list()
|
||||
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
|
||||
assert data == expected
|
||||
|
||||
Reference in New Issue
Block a user