mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 22:59:57 +00:00
Compare commits
1 Commits
python-v0.
...
rmeng/pool
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cc7a503faa |
@@ -11,6 +11,8 @@ from datetime import date, datetime
|
|||||||
from functools import singledispatch
|
from functools import singledispatch
|
||||||
from typing import Tuple, Union, Optional, Any
|
from typing import Tuple, Union, Optional, Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
from threading import Lock
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
@@ -314,3 +316,27 @@ def deprecated(func):
|
|||||||
def validate_table_name(name: str):
|
def validate_table_name(name: str):
|
||||||
"""Verify the table name is valid."""
|
"""Verify the table name is valid."""
|
||||||
native_validate_table_name(name)
|
native_validate_table_name(name)
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionPool:
    """A small thread-safe pool of reusable connections.

    Connections are created on demand by ``connection_factory`` and handed
    out through the :meth:`connection` context manager. When the ``with``
    block ends the connection is put back on the idle list so a later
    caller can reuse it.

    Parameters
    ----------
    connection_factory:
        Zero-argument callable that returns a new connection object.
    max_size:
        Upper bound on the number of *idle* connections kept for reuse.
        ``None`` (the default) means unbounded. This does not limit how
        many connections may be checked out at the same time; a connection
        returned while the idle list is already full is simply not retained.
    """

    def __init__(self, connection_factory, *, max_size: Optional[int] = None):
        self.max_size = max_size
        self._connection_factory = connection_factory
        # Idle connections, used as a stack (most recently returned first).
        self._pool = []
        # Guards every access to ``self._pool``.
        self._lock = Lock()

    @contextmanager
    def connection(self):
        """Check out a connection for the duration of a ``with`` block.

        Yields an idle connection when one is available, otherwise a new one
        built by the factory. The connection is offered back to the pool when
        the block exits, whether it exits normally or by raising.
        """
        with self._lock:
            reuse_idle = bool(self._pool)
            if reuse_idle:
                conn = self._pool.pop()

        if not reuse_idle:
            # Build the connection outside the lock: the factory is caller
            # code that may block (e.g. on I/O) or touch this pool, and the
            # lock is not reentrant. Holding it here would serialize every
            # thread behind one slow connect, or deadlock outright.
            conn = self._connection_factory()

        # The lock is not held while the caller uses the connection.
        try:
            yield conn
        finally:
            with self._lock:
                if self.max_size is None or len(self._pool) < self.max_size:
                    self._pool.append(conn)
|
||||||
|
|||||||
@@ -6,13 +6,16 @@ from datetime import timedelta
|
|||||||
import http.server
|
import http.server
|
||||||
import json
|
import json
|
||||||
import threading
|
import threading
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import lancedb
|
import lancedb
|
||||||
from lancedb.conftest import MockTextEmbeddingFunction
|
from lancedb.conftest import MockTextEmbeddingFunction
|
||||||
from lancedb.remote import ClientConfig
|
from lancedb.remote import ClientConfig
|
||||||
|
from lancedb.util import ConnectionPool
|
||||||
from lancedb.remote.errors import HttpError, RetryError
|
from lancedb.remote.errors import HttpError, RetryError
|
||||||
|
import lancedb.util
|
||||||
import pytest
|
import pytest
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
|
|
||||||
@@ -55,6 +58,34 @@ def mock_lancedb_connection(handler):
|
|||||||
handle.join()
|
handle.join()
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
def mock_lancedb_connection_pool(handler):
    """Run a mock HTTP server and yield a ConnectionPool of clients for it.

    A local ``http.server.HTTPServer`` is started on ``localhost:8080`` in a
    background thread, serving requests through
    ``make_mock_http_handler(handler)``. The yielded ``ConnectionPool``
    creates remote LanceDB connections pointed at that server. When the
    ``with`` block exits the server is shut down and its thread joined.
    """
    with http.server.HTTPServer(
        ("localhost", 8080), make_mock_http_handler(handler)
    ) as server:
        handle = threading.Thread(target=server.serve_forever)
        handle.start()

        def conn_factory():
            # The pool hands out whatever this returns, so the connection
            # must be returned; without it every checkout would yield None.
            return lancedb.connect(
                "db://dev",
                api_key="fake",
                host_override="http://localhost:8080",
                client_config={
                    "retry_config": {"retries": 2},
                    "timeout_config": {
                        "connect_timeout": 1,
                    },
                },
            )

        try:
            yield ConnectionPool(conn_factory)
        finally:
            server.shutdown()
            handle.join()
|
||||||
|
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def mock_lancedb_connection_async(handler):
|
async def mock_lancedb_connection_async(handler):
|
||||||
with http.server.HTTPServer(
|
with http.server.HTTPServer(
|
||||||
@@ -187,8 +218,7 @@ async def test_retry_error():
|
|||||||
assert cause.status_code == 429
|
assert cause.status_code == 429
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
def http_handler(query_handler):
|
||||||
def query_test_table(query_handler):
|
|
||||||
def handler(request):
|
def handler(request):
|
||||||
if request.path == "/v1/table/test/describe/":
|
if request.path == "/v1/table/test/describe/":
|
||||||
request.send_response(200)
|
request.send_response(200)
|
||||||
@@ -212,7 +242,12 @@ def query_test_table(query_handler):
|
|||||||
request.send_response(404)
|
request.send_response(404)
|
||||||
request.end_headers()
|
request.end_headers()
|
||||||
|
|
||||||
with mock_lancedb_connection(handler) as db:
|
return handler
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def query_test_table(connection_ctx_mgr):
|
||||||
|
with connection_ctx_mgr as db:
|
||||||
assert repr(db) == "RemoteConnect(name=dev)"
|
assert repr(db) == "RemoteConnect(name=dev)"
|
||||||
table = db.open_table("test")
|
table = db.open_table("test")
|
||||||
assert repr(table) == "RemoteTable(dev.test)"
|
assert repr(table) == "RemoteTable(dev.test)"
|
||||||
@@ -220,6 +255,7 @@ def query_test_table(query_handler):
|
|||||||
|
|
||||||
|
|
||||||
def test_query_sync_minimal():
|
def test_query_sync_minimal():
|
||||||
|
@http_handler
|
||||||
def handler(body):
|
def handler(body):
|
||||||
assert body == {
|
assert body == {
|
||||||
"distance_type": "l2",
|
"distance_type": "l2",
|
||||||
@@ -234,13 +270,53 @@ def test_query_sync_minimal():
|
|||||||
|
|
||||||
return pa.table({"id": [1, 2, 3]})
|
return pa.table({"id": [1, 2, 3]})
|
||||||
|
|
||||||
with query_test_table(handler) as table:
|
with query_test_table(mock_lancedb_connection(handler)) as table:
|
||||||
|
data = table.search([1, 2, 3]).to_list()
|
||||||
|
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
|
||||||
|
assert data == expected
|
||||||
|
|
||||||
|
with query_test_table(mock_lancedb_connection_pool(handler).connection()) as table:
|
||||||
data = table.search([1, 2, 3]).to_list()
|
data = table.search([1, 2, 3]).to_list()
|
||||||
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
|
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
|
||||||
assert data == expected
|
assert data == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_sync_minimal_threaded():
    """Issue 1000 minimal vector queries from worker threads via a shared pool.

    Each worker checks a connection out of the pool, opens the test table,
    runs the query and verifies the rows. The handler counts the queries the
    mock server received so the total can be checked at the end.
    """
    num_query = 0

    @http_handler
    def handler(body):
        assert body == {
            "distance_type": "l2",
            "k": 10,
            "prefilter": False,
            "refine_factor": None,
            "ef": None,
            "vector": [1.0, 2.0, 3.0],
            "nprobes": 20,
            "version": None,
        }
        # NOTE(review): this unsynchronized increment relies on the mock
        # server handling requests one at a time (plain HTTPServer) - confirm
        # if the server is ever made multi-threaded.
        nonlocal num_query
        num_query += 1

        return pa.table({"id": [1, 2, 3]})

    # The helper is a context manager: it has to be entered to start the
    # mock server and to obtain the actual ConnectionPool.
    with mock_lancedb_connection_pool(handler) as pool:

        def _query(i):
            with query_test_table(pool.connection()) as table:
                data = table.search([1, 2, 3]).to_list()
                expected = [{"id": 1}, {"id": 2}, {"id": 3}]
                assert data == expected

        with ThreadPoolExecutor() as executor:
            # Drain the iterator so an exception raised in any worker is
            # re-raised here instead of being silently dropped.
            list(executor.map(_query, range(1000)))

    assert num_query == 1000
|
||||||
|
|
||||||
|
|
||||||
def test_query_sync_empty_query():
|
def test_query_sync_empty_query():
|
||||||
|
@http_handler
|
||||||
def handler(body):
|
def handler(body):
|
||||||
assert body == {
|
assert body == {
|
||||||
"k": 10,
|
"k": 10,
|
||||||
@@ -252,7 +328,12 @@ def test_query_sync_empty_query():
|
|||||||
|
|
||||||
return pa.table({"id": [1, 2, 3]})
|
return pa.table({"id": [1, 2, 3]})
|
||||||
|
|
||||||
with query_test_table(handler) as table:
|
with query_test_table(mock_lancedb_connection(handler)) as table:
|
||||||
|
data = table.search(None).where("true").select(["id"]).limit(10).to_list()
|
||||||
|
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
|
||||||
|
assert data == expected
|
||||||
|
|
||||||
|
with query_test_table(mock_lancedb_connection_pool(handler).connection()) as table:
|
||||||
data = table.search(None).where("true").select(["id"]).limit(10).to_list()
|
data = table.search(None).where("true").select(["id"]).limit(10).to_list()
|
||||||
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
|
expected = [{"id": 1}, {"id": 2}, {"id": 3}]
|
||||||
assert data == expected
|
assert data == expected
|
||||||
|
|||||||
Reference in New Issue
Block a user