Compare commits

...

1 Commits

Author SHA1 Message Date
rmeng
cc7a503faa feat: connection pool for sync client 2024-11-25 14:20:37 -05:00
2 changed files with 112 additions and 5 deletions

View File

@@ -11,6 +11,8 @@ from datetime import date, datetime
from functools import singledispatch from functools import singledispatch
from typing import Tuple, Union, Optional, Any from typing import Tuple, Union, Optional, Any
from urllib.parse import urlparse from urllib.parse import urlparse
from threading import Lock
from contextlib import contextmanager
import numpy as np import numpy as np
import pyarrow as pa import pyarrow as pa
@@ -314,3 +316,27 @@ def deprecated(func):
def validate_table_name(name: str): def validate_table_name(name: str):
"""Verify the table name is valid.""" """Verify the table name is valid."""
native_validate_table_name(name) native_validate_table_name(name)
class ConnectionPool:
def __init__(self, connection_factory, *, max_size: Optional[int] = None):
self.max_size = max_size
self._connection_factory = connection_factory
self._pool = []
self._lock = Lock()
@contextmanager
def connection(self):
with self._lock:
if self._pool:
conn = self._pool.pop()
else:
conn = self._connection_factory()
# release the lock before yielding
try:
yield conn
finally:
with self._lock:
if self.max_size is None or len(self._pool) < self.max_size:
self._pool.append(conn)

View File

@@ -6,13 +6,16 @@ from datetime import timedelta
import http.server import http.server
import json import json
import threading import threading
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import MagicMock from unittest.mock import MagicMock
import uuid import uuid
import lancedb import lancedb
from lancedb.conftest import MockTextEmbeddingFunction from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.remote import ClientConfig from lancedb.remote import ClientConfig
from lancedb.util import ConnectionPool
from lancedb.remote.errors import HttpError, RetryError from lancedb.remote.errors import HttpError, RetryError
import lancedb.util
import pytest import pytest
import pyarrow as pa import pyarrow as pa
@@ -55,6 +58,34 @@ def mock_lancedb_connection(handler):
handle.join() handle.join()
@contextlib.contextmanager
def mock_lancedb_connection_pool(handler):
with http.server.HTTPServer(
("localhost", 8080), make_mock_http_handler(handler)
) as server:
handle = threading.Thread(target=server.serve_forever)
handle.start()
def conn_factory():
lancedb.connect(
"db://dev",
api_key="fake",
host_override="http://localhost:8080",
client_config={
"retry_config": {"retries": 2},
"timeout_config": {
"connect_timeout": 1,
},
},
)
try:
yield ConnectionPool(conn_factory)
finally:
server.shutdown()
handle.join()
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def mock_lancedb_connection_async(handler): async def mock_lancedb_connection_async(handler):
with http.server.HTTPServer( with http.server.HTTPServer(
@@ -187,8 +218,7 @@ async def test_retry_error():
assert cause.status_code == 429 assert cause.status_code == 429
@contextlib.contextmanager def http_handler(query_handler):
def query_test_table(query_handler):
def handler(request): def handler(request):
if request.path == "/v1/table/test/describe/": if request.path == "/v1/table/test/describe/":
request.send_response(200) request.send_response(200)
@@ -212,7 +242,12 @@ def query_test_table(query_handler):
request.send_response(404) request.send_response(404)
request.end_headers() request.end_headers()
with mock_lancedb_connection(handler) as db: return handler
@contextlib.contextmanager
def query_test_table(connection_ctx_mgr):
with connection_ctx_mgr as db:
assert repr(db) == "RemoteConnect(name=dev)" assert repr(db) == "RemoteConnect(name=dev)"
table = db.open_table("test") table = db.open_table("test")
assert repr(table) == "RemoteTable(dev.test)" assert repr(table) == "RemoteTable(dev.test)"
@@ -220,6 +255,7 @@ def query_test_table(query_handler):
def test_query_sync_minimal(): def test_query_sync_minimal():
@http_handler
def handler(body): def handler(body):
assert body == { assert body == {
"distance_type": "l2", "distance_type": "l2",
@@ -234,13 +270,53 @@ def test_query_sync_minimal():
return pa.table({"id": [1, 2, 3]}) return pa.table({"id": [1, 2, 3]})
with query_test_table(handler) as table: with query_test_table(mock_lancedb_connection(handler)) as table:
data = table.search([1, 2, 3]).to_list()
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
assert data == expected
with query_test_table(mock_lancedb_connection_pool(handler).connection()) as table:
data = table.search([1, 2, 3]).to_list() data = table.search([1, 2, 3]).to_list()
expected = [{"id": 1}, {"id": 2}, {"id": 3}] expected = [{"id": 1}, {"id": 2}, {"id": 3}]
assert data == expected assert data == expected
def test_query_sync_minimal_threaded():
num_query = 0
@http_handler
def handler(body):
assert body == {
"distance_type": "l2",
"k": 10,
"prefilter": False,
"refine_factor": None,
"ef": None,
"vector": [1.0, 2.0, 3.0],
"nprobes": 20,
"version": None,
}
nonlocal num_query
num_query += 1
return pa.table({"id": [1, 2, 3]})
pool = mock_lancedb_connection_pool(handler)
def _query(i):
with query_test_table(pool.connection()) as table:
data = table.search([1, 2, 3]).to_list()
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
assert data == expected
with ThreadPoolExecutor as exec:
exec.map(_query, range(1000))
assert num_query == 1000
def test_query_sync_empty_query(): def test_query_sync_empty_query():
@http_handler
def handler(body): def handler(body):
assert body == { assert body == {
"k": 10, "k": 10,
@@ -252,7 +328,12 @@ def test_query_sync_empty_query():
return pa.table({"id": [1, 2, 3]}) return pa.table({"id": [1, 2, 3]})
with query_test_table(handler) as table: with query_test_table(mock_lancedb_connection(handler)) as table:
data = table.search(None).where("true").select(["id"]).limit(10).to_list()
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
assert data == expected
with query_test_table(mock_lancedb_connection_pool(handler).connection()) as table:
data = table.search(None).where("true").select(["id"]).limit(10).to_list() data = table.search(None).where("true").select(["id"]).limit(10).to_list()
expected = [{"id": 1}, {"id": 2}, {"id": 3}] expected = [{"id": 1}, {"id": 2}, {"id": 3}]
assert data == expected assert data == expected