Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-25 06:19:57 +00:00)
This moves the `__len__` method from `LanceTable` and `RemoteTable` to `Table` so that child classes don't need to implement their own. In the process, it fixes the implementation of `RemoteTable`'s length method, which was previously missing a return statement.

## Summary by CodeRabbit

- **Refactor**
  - Centralized the table length functionality in the base table class, simplifying subclass behavior.
  - Removed redundant or non-functional length methods from specific table classes.
- **Tests**
  - Added a new test to verify correct table length reporting for remote tables.
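For illustration, here is a minimal sketch of the shape of this refactor. The class and method names mirror the description above, but the bodies are hypothetical stand-ins, not the actual LanceDB implementation:

```python
# Hypothetical sketch of the refactor described above -- not the real
# LanceDB classes. Assumes each concrete table type exposes count_rows().
class Table:
    def count_rows(self, filter=None) -> int:
        raise NotImplementedError

    def __len__(self) -> int:
        # Defined once here, so LanceTable and RemoteTable no longer need
        # their own __len__. The old RemoteTable version called count_rows
        # but never returned the result; this one does.
        return self.count_rows(None)


class RemoteTable(Table):
    """Stand-in for a table backed by a remote server."""

    def __init__(self, row_count: int):
        self._row_count = row_count

    def count_rows(self, filter=None) -> int:
        return self._row_count


assert len(RemoteTable(3)) == 3
```

The point is simply that `__len__` lives once in the base class and delegates to whatever `count_rows` the subclass provides.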
754 lines · 25 KiB · Python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

import re
from concurrent.futures import ThreadPoolExecutor
import contextlib
from datetime import timedelta
import http.server
import json
import threading
from unittest.mock import MagicMock
import uuid
from packaging.version import Version

import lancedb
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.remote import ClientConfig
from lancedb.remote.errors import HttpError, RetryError
import pytest
import pyarrow as pa


def make_mock_http_handler(handler):
    # Build a request handler class that forwards every GET/POST to `handler`,
    # so individual tests can script the fake LanceDB Cloud responses.
    class MockLanceDBHandler(http.server.BaseHTTPRequestHandler):
        def do_GET(self):
            handler(self)

        def do_POST(self):
            handler(self)

    return MockLanceDBHandler


@contextlib.contextmanager
def mock_lancedb_connection(handler):
    # Serve the mock handler on an ephemeral localhost port and point a
    # synchronous LanceDB connection at it.
    with http.server.HTTPServer(
        ("localhost", 0), make_mock_http_handler(handler)
    ) as server:
        port = server.server_address[1]
        handle = threading.Thread(target=server.serve_forever)
        handle.start()

        db = lancedb.connect(
            "db://dev",
            api_key="fake",
            host_override=f"http://localhost:{port}",
            client_config={
                "retry_config": {"retries": 2},
                "timeout_config": {
                    "connect_timeout": 1,
                },
            },
        )

        try:
            yield db
        finally:
            server.shutdown()
            handle.join()


@contextlib.asynccontextmanager
async def mock_lancedb_connection_async(handler, **client_config):
    with http.server.HTTPServer(
        ("localhost", 0), make_mock_http_handler(handler)
    ) as server:
        port = server.server_address[1]
        handle = threading.Thread(target=server.serve_forever)
        handle.start()

        db = await lancedb.connect_async(
            "db://dev",
            api_key="fake",
            host_override=f"http://localhost:{port}",
            client_config={
                "retry_config": {"retries": 2},
                "timeout_config": {
                    "connect_timeout": 1,
                },
                **client_config,
            },
        )

        try:
            yield db
        finally:
            server.shutdown()
            handle.join()


@pytest.mark.asyncio
async def test_async_remote_db():
    def handler(request):
        # We created a UUID request id
        request_id = request.headers["x-request-id"]
        assert uuid.UUID(request_id).version == 4

        # We set a user agent with the current library version
        user_agent = request.headers["User-Agent"]
        assert user_agent == f"LanceDB-Python-Client/{lancedb.__version__}"

        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(b'{"tables": []}')

    async with mock_lancedb_connection_async(handler) as db:
        table_names = await db.table_names()
        assert table_names == []


@pytest.mark.asyncio
async def test_async_checkout():
    def handler(request):
        if request.path == "/v1/table/test/describe/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            response = json.dumps({"version": 42, "schema": {"fields": []}})
            request.wfile.write(response.encode())
            return

        content_len = int(request.headers.get("Content-Length"))
        body = request.rfile.read(content_len)
        body = json.loads(body)

        print("body is", body)

        count = 0
        if body["version"] == 1:
            count = 100
        elif body["version"] == 2:
            count = 200
        elif body["version"] is None:
            count = 300

        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(json.dumps(count).encode())

    async with mock_lancedb_connection_async(handler) as db:
        table = await db.open_table("test")
        assert await table.count_rows() == 300
        await table.checkout(1)
        assert await table.count_rows() == 100
        await table.checkout(2)
        assert await table.count_rows() == 200
        await table.checkout_latest()
        assert await table.count_rows() == 300


def test_table_len_sync():
    # Regression test: len(table) on a remote table goes through the shared
    # Table.__len__ and must return the row count reported by the server.
    def handler(request):
        if request.path == "/v1/table/test/create/?mode=create":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            request.wfile.write(b"{}")

        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(json.dumps(1).encode())

    with mock_lancedb_connection(handler) as db:
        table = db.create_table("test", [{"id": 1}])
        assert len(table) == 1


@pytest.mark.asyncio
async def test_http_error():
    request_id_holder = {"request_id": None}

    def handler(request):
        request_id_holder["request_id"] = request.headers["x-request-id"]

        request.send_response(507)
        request.end_headers()
        request.wfile.write(b"Internal Server Error")

    async with mock_lancedb_connection_async(handler) as db:
        with pytest.raises(HttpError) as exc_info:
            await db.table_names()

        assert exc_info.value.request_id == request_id_holder["request_id"]
        assert "Internal Server Error" in str(exc_info.value)


@pytest.mark.asyncio
async def test_retry_error():
    request_id_holder = {"request_id": None}

    def handler(request):
        request_id_holder["request_id"] = request.headers["x-request-id"]

        request.send_response(429)
        request.end_headers()
        request.wfile.write(b"Try again later")

    async with mock_lancedb_connection_async(handler) as db:
        with pytest.raises(RetryError) as exc_info:
            await db.table_names()

        assert exc_info.value.request_id == request_id_holder["request_id"]

        cause = exc_info.value.__cause__
        assert isinstance(cause, HttpError)
        assert "Try again later" in str(cause)
        assert cause.request_id == request_id_holder["request_id"]
        assert cause.status_code == 429


def test_table_add_in_threadpool():
    def handler(request):
        if request.path == "/v1/table/test/insert/":
            request.send_response(200)
            request.end_headers()
        elif request.path == "/v1/table/test/create/?mode=create":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            request.wfile.write(b"{}")
        elif request.path == "/v1/table/test/describe/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(
                dict(
                    version=1,
                    schema=dict(
                        fields=[
                            dict(name="id", type={"type": "int64"}, nullable=False),
                        ]
                    ),
                )
            )
            request.wfile.write(payload.encode())
        else:
            request.send_response(404)
            request.end_headers()

    with mock_lancedb_connection(handler) as db:
        table = db.create_table("test", [{"id": 1}])
        with ThreadPoolExecutor(3) as executor:
            futures = []
            for _ in range(10):
                future = executor.submit(table.add, [{"id": 1}])
                futures.append(future)

            for future in futures:
                future.result()


def test_table_create_indices():
    def handler(request):
        index_stats = dict(
            index_type="IVF_PQ", num_indexed_rows=1000, num_unindexed_rows=0
        )

        if request.path == "/v1/table/test/create_index/":
            request.send_response(200)
            request.end_headers()
        elif request.path == "/v1/table/test/create/?mode=create":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            request.wfile.write(b"{}")
        elif request.path == "/v1/table/test/describe/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(
                dict(
                    version=1,
                    schema=dict(
                        fields=[
                            dict(name="id", type={"type": "int64"}, nullable=False),
                        ]
                    ),
                )
            )
            request.wfile.write(payload.encode())
        elif request.path == "/v1/table/test/index/list/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(
                dict(
                    indexes=[
                        {
                            "index_name": "id_idx",
                            "columns": ["id"],
                        },
                        {
                            "index_name": "text_idx",
                            "columns": ["text"],
                        },
                        {
                            "index_name": "vector_idx",
                            "columns": ["vector"],
                        },
                    ]
                )
            )
            request.wfile.write(payload.encode())
        elif request.path == "/v1/table/test/index/id_idx/stats/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(index_stats)
            request.wfile.write(payload.encode())
        elif request.path == "/v1/table/test/index/text_idx/stats/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(index_stats)
            request.wfile.write(payload.encode())
        elif request.path == "/v1/table/test/index/vector_idx/stats/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(index_stats)
            request.wfile.write(payload.encode())
        elif "/drop/" in request.path:
            request.send_response(200)
            request.end_headers()
        else:
            request.send_response(404)
            request.end_headers()

    with mock_lancedb_connection(handler) as db:
        # Parameters are well-tested through local and async tests.
        # This is a smoke-test.
        table = db.create_table("test", [{"id": 1}])
        table.create_scalar_index("id", wait_timeout=timedelta(seconds=2))
        table.create_fts_index("text", wait_timeout=timedelta(seconds=2))
        table.create_index(
            vector_column_name="vector", wait_timeout=timedelta(seconds=10)
        )
        table.wait_for_index(["id_idx"], timedelta(seconds=2))
        table.wait_for_index(["text_idx", "vector_idx"], timedelta(seconds=2))
        table.drop_index("vector_idx")
        table.drop_index("id_idx")
        table.drop_index("text_idx")


def test_table_wait_for_index_timeout():
    def handler(request):
        index_stats = dict(
            index_type="BTREE", num_indexed_rows=1000, num_unindexed_rows=1
        )

        if request.path == "/v1/table/test/create/?mode=create":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            request.wfile.write(b"{}")
        elif request.path == "/v1/table/test/describe/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(
                dict(
                    version=1,
                    schema=dict(
                        fields=[
                            dict(name="id", type={"type": "int64"}, nullable=False),
                        ]
                    ),
                )
            )
            request.wfile.write(payload.encode())
        elif request.path == "/v1/table/test/index/list/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(
                dict(
                    indexes=[
                        {
                            "index_name": "id_idx",
                            "columns": ["id"],
                        },
                    ]
                )
            )
            request.wfile.write(payload.encode())
        elif request.path == "/v1/table/test/index/id_idx/stats/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(index_stats)
            print(f"{index_stats=}")
            request.wfile.write(payload.encode())
        else:
            request.send_response(404)
            request.end_headers()

    with mock_lancedb_connection(handler) as db:
        table = db.create_table("test", [{"id": 1}])
        with pytest.raises(
            RuntimeError,
            match=re.escape(
                'Timeout error: timed out waiting for indices: ["id_idx"] after 1s'
            ),
        ):
            table.wait_for_index(["id_idx"], timedelta(seconds=1))


def test_stats():
    stats = {
        "total_bytes": 38,
        "num_rows": 2,
        "num_indices": 0,
        "fragment_stats": {
            "num_fragments": 1,
            "num_small_fragments": 1,
            "lengths": {
                "min": 2,
                "max": 2,
                "mean": 2,
                "p25": 2,
                "p50": 2,
                "p75": 2,
                "p99": 2,
            },
        },
    }

    def handler(request):
        if request.path == "/v1/table/test/create/?mode=create":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            request.wfile.write(b"{}")
        elif request.path == "/v1/table/test/stats/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(stats)
            request.wfile.write(payload.encode())
        else:
            print(request.path)
            request.send_response(404)
            request.end_headers()

    with mock_lancedb_connection(handler) as db:
        table = db.create_table("test", [{"id": 1}])
        res = table.stats()
        print(f"{res=}")
        assert res == stats


@contextlib.contextmanager
def query_test_table(query_handler, *, server_version=Version("0.1.0")):
    def handler(request):
        if request.path == "/v1/table/test/describe/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.send_header("phalanx-version", str(server_version))
            request.end_headers()
            request.wfile.write(b"{}")
        elif request.path == "/v1/table/test/query/":
            content_len = int(request.headers.get("Content-Length"))
            body = request.rfile.read(content_len)
            body = json.loads(body)

            data = query_handler(body)

            request.send_response(200)
            request.send_header("Content-Type", "application/vnd.apache.arrow.file")
            request.end_headers()

            # Query results are returned to the client as an Arrow IPC file.
            with pa.ipc.new_file(request.wfile, schema=data.schema) as f:
                f.write_table(data)
        else:
            request.send_response(404)
            request.end_headers()

    with mock_lancedb_connection(handler) as db:
        assert repr(db) == "RemoteConnect(name=dev)"
        table = db.open_table("test")
        assert repr(table) == "RemoteTable(dev.test)"
        yield table


def test_query_sync_minimal():
    def handler(body):
        assert body == {
            "distance_type": "l2",
            "k": 10,
            "prefilter": True,
            "refine_factor": None,
            "lower_bound": None,
            "upper_bound": None,
            "ef": None,
            "vector": [1.0, 2.0, 3.0],
            "nprobes": 20,
            "version": None,
        }

        return pa.table({"id": [1, 2, 3]})

    with query_test_table(handler) as table:
        data = table.search([1, 2, 3]).to_list()
        expected = [{"id": 1}, {"id": 2}, {"id": 3}]
        assert data == expected


def test_query_sync_empty_query():
    def handler(body):
        assert body == {
            "k": 10,
            "filter": "true",
            "vector": [],
            "columns": ["id"],
            "prefilter": True,
            "version": None,
        }

        return pa.table({"id": [1, 2, 3]})

    with query_test_table(handler) as table:
        data = table.search(None).where("true").select(["id"]).limit(10).to_list()
        expected = [{"id": 1}, {"id": 2}, {"id": 3}]
        assert data == expected


def test_query_sync_maximal():
    def handler(body):
        assert body == {
            "distance_type": "cosine",
            "k": 42,
            "offset": 10,
            "prefilter": True,
            "refine_factor": 10,
            "vector": [1.0, 2.0, 3.0],
            "nprobes": 5,
            "lower_bound": None,
            "upper_bound": None,
            "ef": None,
            "filter": "id > 0",
            "columns": ["id", "name"],
            "vector_column": "vector2",
            "fast_search": True,
            "with_row_id": True,
            "version": None,
        }

        return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})

    with query_test_table(handler) as table:
        (
            table.search([1, 2, 3], vector_column_name="vector2", fast_search=True)
            .distance_type("cosine")
            .limit(42)
            .offset(10)
            .refine_factor(10)
            .nprobes(5)
            .where("id > 0", prefilter=True)
            .with_row_id(True)
            .select(["id", "name"])
            .to_list()
        )


@pytest.mark.parametrize("server_version", [Version("0.1.0"), Version("0.2.0")])
def test_query_sync_batch_queries(server_version):
    def handler(body):
        # TODO: we will add the ability to get the server version,
        # so that we can decide how to perform batch queries.
        vectors = body["vector"]
        if server_version >= Version("0.2.0"):
            # Batch queries can be handled in a single request since 0.2.0.
            assert len(vectors) == 2
            res = []
            for i, vector in enumerate(vectors):
                res.append({"id": 1, "query_index": i})
            return pa.Table.from_pylist(res)
        else:
            assert len(vectors) == 3  # matching dim
            return pa.table({"id": [1]})

    with query_test_table(handler, server_version=server_version) as table:
        results = table.search([[1, 2, 3], [4, 5, 6]]).limit(1).to_list()
        assert len(results) == 2
        results.sort(key=lambda x: x["query_index"])
        assert results == [{"id": 1, "query_index": 0}, {"id": 1, "query_index": 1}]


def test_query_sync_fts():
    def handler(body):
        assert body == {
            "full_text_query": {
                "query": "puppy",
                "columns": [],
            },
            "k": 10,
            "prefilter": True,
            "vector": [],
            "version": None,
        }

        return pa.table({"id": [1, 2, 3]})

    with query_test_table(handler) as table:
        (table.search("puppy", query_type="fts").to_list())

    def handler(body):
        assert body == {
            "full_text_query": {
                "query": "puppy",
                "columns": ["name", "description"],
            },
            "k": 42,
            "vector": [],
            "prefilter": True,
            "with_row_id": True,
            "version": None,
        } or body == {
            "full_text_query": {
                "query": "puppy",
                "columns": ["description", "name"],
            },
            "k": 42,
            "vector": [],
            "prefilter": True,
            "with_row_id": True,
            "version": None,
        }

        return pa.table({"id": [1, 2, 3]})

    with query_test_table(handler) as table:
        (
            table.search("puppy", query_type="fts", fts_columns=["name", "description"])
            .with_row_id(True)
            .limit(42)
            .to_list()
        )


def test_query_sync_hybrid():
    def handler(body):
        if "full_text_query" in body:
            # FTS query
            assert body == {
                "full_text_query": {
                    "query": "puppy",
                    "columns": [],
                },
                "k": 42,
                "vector": [],
                "prefilter": True,
                "with_row_id": True,
                "version": None,
            }
            return pa.table({"_rowid": [1, 2, 3], "_score": [0.1, 0.2, 0.3]})
        else:
            # Vector query
            assert body == {
                "distance_type": "l2",
                "k": 42,
                "prefilter": True,
                "refine_factor": None,
                "vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                "nprobes": 20,
                "lower_bound": None,
                "upper_bound": None,
                "ef": None,
                "with_row_id": True,
                "version": None,
            }
            return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]})

    with query_test_table(handler) as table:
        embedding_func = MockTextEmbeddingFunction()
        embedding_config = MagicMock()
        embedding_config.function = embedding_func

        embedding_funcs = MagicMock()
        embedding_funcs.get = MagicMock(return_value=embedding_config)
        table.embedding_functions = embedding_funcs

        (table.search("puppy", query_type="hybrid").limit(42).to_list())


def test_create_client():
    mandatory_args = {
        "uri": "db://dev",
        "api_key": "fake-api-key",
        "region": "us-east-1",
    }

    db = lancedb.connect(**mandatory_args)
    assert isinstance(db.client_config, ClientConfig)

    db = lancedb.connect(**mandatory_args, client_config={})
    assert isinstance(db.client_config, ClientConfig)

    db = lancedb.connect(
        **mandatory_args,
        client_config=ClientConfig(timeout_config={"connect_timeout": 42}),
    )
    assert isinstance(db.client_config, ClientConfig)
    assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)

    db = lancedb.connect(
        **mandatory_args,
        client_config={"timeout_config": {"connect_timeout": timedelta(seconds=42)}},
    )
    assert isinstance(db.client_config, ClientConfig)
    assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)

    db = lancedb.connect(
        **mandatory_args, client_config=ClientConfig(retry_config={"retries": 42})
    )
    assert isinstance(db.client_config, ClientConfig)
    assert db.client_config.retry_config.retries == 42

    db = lancedb.connect(
        **mandatory_args, client_config={"retry_config": {"retries": 42}}
    )
    assert isinstance(db.client_config, ClientConfig)
    assert db.client_config.retry_config.retries == 42

    with pytest.warns(DeprecationWarning):
        db = lancedb.connect(**mandatory_args, connection_timeout=42)
    assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)

    with pytest.warns(DeprecationWarning):
        db = lancedb.connect(**mandatory_args, read_timeout=42)
    assert db.client_config.timeout_config.read_timeout == timedelta(seconds=42)

    with pytest.warns(DeprecationWarning):
        lancedb.connect(**mandatory_args, request_thread_pool=10)


@pytest.mark.asyncio
async def test_pass_through_headers():
    def handler(request):
        assert request.headers["foo"] == "bar"
        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(b'{"tables": []}')

    async with mock_lancedb_connection_async(
        handler, extra_headers={"foo": "bar"}
    ) as db:
        table_names = await db.table_names()
        assert table_names == []