# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

import re
from concurrent.futures import ThreadPoolExecutor
import contextlib
from datetime import timedelta
import http.server
import json
import threading
from unittest.mock import MagicMock
import uuid
from packaging.version import Version

import lancedb
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.remote import ClientConfig
from lancedb.remote.errors import HttpError, RetryError
import pytest
import pyarrow as pa


def make_mock_http_handler(handler):
    """Build a request-handler class that forwards GET and POST to ``handler``.

    ``handler`` is called with the ``BaseHTTPRequestHandler`` instance, so it
    can read ``path`` / ``headers`` / ``rfile`` and write the response itself.
    """

    class MockLanceDBHandler(http.server.BaseHTTPRequestHandler):
        def do_GET(self):
            handler(self)

        def do_POST(self):
            handler(self)

    return MockLanceDBHandler


@contextlib.contextmanager
def mock_lancedb_connection(handler):
    """Yield a sync remote connection pointed at a local mock HTTP server.

    The server listens on an ephemeral localhost port and runs in a background
    thread; every request is delegated to ``handler``.
    """
    with http.server.HTTPServer(
        ("localhost", 0), make_mock_http_handler(handler)
    ) as server:
        port = server.server_address[1]
        handle = threading.Thread(target=server.serve_forever)
        handle.start()

        # Connect inside the try block: if connecting raises, the server
        # thread must still be shut down and joined, otherwise the non-daemon
        # thread keeps serving after the failure.
        try:
            db = lancedb.connect(
                "db://dev",
                api_key="fake",
                host_override=f"http://localhost:{port}",
                client_config={
                    "retry_config": {"retries": 2},
                    "timeout_config": {
                        "connect_timeout": 1,
                    },
                },
            )
            yield db
        finally:
            server.shutdown()
            handle.join()


@contextlib.asynccontextmanager
async def mock_lancedb_connection_async(handler, **client_config):
    """Yield an async remote connection pointed at a local mock HTTP server.

    Extra keyword arguments are merged into the client config after the
    default retry/timeout settings, so they can override them.
    """
    with http.server.HTTPServer(
        ("localhost", 0), make_mock_http_handler(handler)
    ) as server:
        port = server.server_address[1]
        handle = threading.Thread(target=server.serve_forever)
        handle.start()

        # Same reasoning as the sync variant: a failing connect must not leak
        # the running server thread.
        try:
            db = await lancedb.connect_async(
                "db://dev",
                api_key="fake",
                host_override=f"http://localhost:{port}",
                client_config={
                    "retry_config": {"retries": 2},
                    "timeout_config": {
                        "connect_timeout": 1,
                    },
                    **client_config,
                },
            )
            yield db
        finally:
            server.shutdown()
            handle.join()


@pytest.mark.asyncio
async def test_async_remote_db():
    """Requests carry a UUID4 request id and the versioned user agent."""

    def handler(request):
        # We created a UUID request id
        request_id = request.headers["x-request-id"]
        assert uuid.UUID(request_id).version == 4

        # We set a user agent with the current library version
        user_agent = request.headers["User-Agent"]
        assert user_agent == f"LanceDB-Python-Client/{lancedb.__version__}"

        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(b'{"tables": []}')

    async with mock_lancedb_connection_async(handler) as db:
        table_names = await db.table_names()
        assert table_names == []


@pytest.mark.asyncio
async def test_async_checkout():
    """``checkout`` / ``checkout_latest`` change the version sent with requests."""

    def handler(request):
        if request.path == "/v1/table/test/describe/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            response = json.dumps({"version": 42, "schema": {"fields": []}})
            request.wfile.write(response.encode())
            return

        content_len = int(request.headers.get("Content-Length"))
        body = request.rfile.read(content_len)
        body = json.loads(body)

        print("body is", body)

        # The returned row count encodes which version the client asked for.
        count = 0
        if body["version"] == 1:
            count = 100
        elif body["version"] == 2:
            count = 200
        elif body["version"] is None:
            count = 300

        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(json.dumps(count).encode())

    async with mock_lancedb_connection_async(handler) as db:
        table = await db.open_table("test")
        assert await table.count_rows() == 300
        await table.checkout(1)
        assert await table.count_rows() == 100
        await table.checkout(2)
        assert await table.count_rows() == 200
        await table.checkout_latest()
        assert await table.count_rows() == 300


def test_table_len_sync():
    """``len(table)`` on a sync remote table returns the server's row count."""

    def handler(request):
        if request.path == "/v1/table/test/create/?mode=create":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            request.wfile.write(b"{}")
            # Stop here so a second response is not appended to this one.
            return

        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(json.dumps(1).encode())

    with mock_lancedb_connection(handler) as db:
        table = db.create_table("test", [{"id": 1}])
        assert len(table) == 1


@pytest.mark.asyncio
async def test_http_error():
    """A 507 response raises ``HttpError`` with the request id and body text."""
    request_id_holder = {"request_id": None}

    def handler(request):
        request_id_holder["request_id"] = request.headers["x-request-id"]

        request.send_response(507)
        request.end_headers()
        request.wfile.write(b"Internal Server Error")

    async with mock_lancedb_connection_async(handler) as db:
        with pytest.raises(HttpError) as exc_info:
            await db.table_names()

        assert exc_info.value.request_id == request_id_holder["request_id"]
        assert "Internal Server Error" in str(exc_info.value)


@pytest.mark.asyncio
async def test_retry_error():
    """Repeated 429 responses raise ``RetryError`` caused by an ``HttpError``."""
    request_id_holder = {"request_id": None}

    def handler(request):
        request_id_holder["request_id"] = request.headers["x-request-id"]

        request.send_response(429)
        request.end_headers()
        request.wfile.write(b"Try again later")

    async with mock_lancedb_connection_async(handler) as db:
        with pytest.raises(RetryError) as exc_info:
            await db.table_names()

        assert exc_info.value.request_id == request_id_holder["request_id"]

        cause = exc_info.value.__cause__
        assert isinstance(cause, HttpError)
        assert "Try again later" in str(cause)
        assert cause.request_id == request_id_holder["request_id"]
        assert cause.status_code == 429


def test_table_add_in_threadpool():
    """``table.add`` can be called concurrently from worker threads."""

    def handler(request):
        if request.path == "/v1/table/test/insert/":
            request.send_response(200)
            request.end_headers()
        elif request.path == "/v1/table/test/create/?mode=create":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            request.wfile.write(b"{}")
        elif request.path == "/v1/table/test/describe/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(
                dict(
                    version=1,
                    schema=dict(
                        fields=[
                            dict(name="id", type={"type": "int64"}, nullable=False),
                        ]
                    ),
                )
            )
            request.wfile.write(payload.encode())
        else:
            request.send_response(404)
            request.end_headers()

    with mock_lancedb_connection(handler) as db:
        table = db.create_table("test", [{"id": 1}])
        with ThreadPoolExecutor(3) as executor:
            futures = [executor.submit(table.add, [{"id": 1}]) for _ in range(10)]
            # result() re-raises any exception from the worker thread.
            for future in futures:
                future.result()


def test_table_create_indices():
    """Smoke-test creating, waiting for and dropping indices on a remote table."""

    def handler(request):
        index_stats = dict(
            index_type="IVF_PQ", num_indexed_rows=1000, num_unindexed_rows=0
        )

        if request.path == "/v1/table/test/create_index/":
            request.send_response(200)
            request.end_headers()
        elif request.path == "/v1/table/test/create/?mode=create":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            request.wfile.write(b"{}")
        elif request.path == "/v1/table/test/describe/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(
                dict(
                    version=1,
                    schema=dict(
                        fields=[
                            dict(name="id", type={"type": "int64"}, nullable=False),
                        ]
                    ),
                )
            )
            request.wfile.write(payload.encode())
        elif request.path == "/v1/table/test/index/list/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(
                dict(
                    indexes=[
                        {
                            "index_name": "id_idx",
                            "columns": ["id"],
                        },
                        {
                            "index_name": "text_idx",
                            "columns": ["text"],
                        },
                        {
                            "index_name": "vector_idx",
                            "columns": ["vector"],
                        },
                    ]
                )
            )
            request.wfile.write(payload.encode())
        elif request.path in (
            "/v1/table/test/index/id_idx/stats/",
            "/v1/table/test/index/text_idx/stats/",
            "/v1/table/test/index/vector_idx/stats/",
        ):
            # Every index reports the same stats payload.
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(index_stats)
            request.wfile.write(payload.encode())
        elif "/drop/" in request.path:
            request.send_response(200)
            request.end_headers()
        else:
            request.send_response(404)
            request.end_headers()

    with mock_lancedb_connection(handler) as db:
        # Parameters are well-tested through local and async tests.
        # This is a smoke-test.
        table = db.create_table("test", [{"id": 1}])
        table.create_scalar_index("id", wait_timeout=timedelta(seconds=2))
        table.create_fts_index("text", wait_timeout=timedelta(seconds=2))
        table.create_index(
            vector_column_name="vector", wait_timeout=timedelta(seconds=10)
        )
        table.wait_for_index(["id_idx"], timedelta(seconds=2))
        table.wait_for_index(["text_idx", "vector_idx"], timedelta(seconds=2))
        table.drop_index("vector_idx")
        table.drop_index("id_idx")
        table.drop_index("text_idx")


def test_table_wait_for_index_timeout():
    """``wait_for_index`` raises a timeout error while rows remain unindexed."""

    def handler(request):
        # num_unindexed_rows stays at 1 on every stats request.
        index_stats = dict(
            index_type="BTREE", num_indexed_rows=1000, num_unindexed_rows=1
        )

        if request.path == "/v1/table/test/create/?mode=create":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            request.wfile.write(b"{}")
        elif request.path == "/v1/table/test/describe/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(
                dict(
                    version=1,
                    schema=dict(
                        fields=[
                            dict(name="id", type={"type": "int64"}, nullable=False),
                        ]
                    ),
                )
            )
            request.wfile.write(payload.encode())
        elif request.path == "/v1/table/test/index/list/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(
                dict(
                    indexes=[
                        {
                            "index_name": "id_idx",
                            "columns": ["id"],
                        },
                    ]
                )
            )
            request.wfile.write(payload.encode())
        elif request.path == "/v1/table/test/index/id_idx/stats/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(index_stats)
            print(f"{index_stats=}")
            request.wfile.write(payload.encode())
        else:
            request.send_response(404)
            request.end_headers()

    with mock_lancedb_connection(handler) as db:
        table = db.create_table("test", [{"id": 1}])
        with pytest.raises(
            RuntimeError,
            match=re.escape(
                'Timeout error: timed out waiting for indices: ["id_idx"] after 1s'
            ),
        ):
            table.wait_for_index(["id_idx"], timedelta(seconds=1))


def test_stats():
    """``table.stats()`` returns the server's stats payload unchanged."""
    stats = {
        "total_bytes": 38,
        "num_rows": 2,
        "num_indices": 0,
        "fragment_stats": {
            "num_fragments": 1,
            "num_small_fragments": 1,
            "lengths": {
                "min": 2,
                "max": 2,
                "mean": 2,
                "p25": 2,
                "p50": 2,
                "p75": 2,
                "p99": 2,
            },
        },
    }

    def handler(request):
        if request.path == "/v1/table/test/create/?mode=create":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            request.wfile.write(b"{}")
        elif request.path == "/v1/table/test/stats/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            payload = json.dumps(stats)
            request.wfile.write(payload.encode())
        else:
            print(request.path)
            request.send_response(404)
            request.end_headers()

    with mock_lancedb_connection(handler) as db:
        table = db.create_table("test", [{"id": 1}])
        res = table.stats()
        print(f"{res=}")
        assert res == stats


@contextlib.contextmanager
def query_test_table(query_handler, *, server_version=Version("0.1.0")):
    """Yield a remote table named ``test`` whose queries go to ``query_handler``.

    ``query_handler`` receives the decoded JSON body of each query request and
    returns a ``pyarrow.Table``, which is written back as an Arrow IPC file.
    ``server_version`` is sent in the ``phalanx-version`` header of the
    describe response.
    """

    def handler(request):
        if request.path == "/v1/table/test/describe/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.send_header("phalanx-version", str(server_version))
            request.end_headers()
            request.wfile.write(b"{}")
        elif request.path == "/v1/table/test/query/":
            content_len = int(request.headers.get("Content-Length"))
            body = request.rfile.read(content_len)
            body = json.loads(body)

            data = query_handler(body)

            request.send_response(200)
            request.send_header("Content-Type", "application/vnd.apache.arrow.file")
            request.end_headers()

            with pa.ipc.new_file(request.wfile, schema=data.schema) as f:
                f.write_table(data)
        else:
            request.send_response(404)
            request.end_headers()

    with mock_lancedb_connection(handler) as db:
        assert repr(db) == "RemoteConnect(name=dev)"
        table = db.open_table("test")
        assert repr(table) == "RemoteTable(dev.test)"
        yield table


def test_query_sync_minimal():
    """A bare vector search sends the default query parameters."""

    def handler(body):
        assert body == {
            "distance_type": "l2",
            "k": 10,
            "prefilter": True,
            "refine_factor": None,
            "lower_bound": None,
            "upper_bound": None,
            "ef": None,
            "vector": [1.0, 2.0, 3.0],
            "nprobes": 20,
            "version": None,
        }

        return pa.table({"id": [1, 2, 3]})

    with query_test_table(handler) as table:
        data = table.search([1, 2, 3]).to_list()
        expected = [{"id": 1}, {"id": 2}, {"id": 3}]
        assert data == expected


def test_query_sync_empty_query():
    """``search(None)`` with a filter sends an empty vector and the filter."""

    def handler(body):
        assert body == {
            "k": 10,
            "filter": "true",
            "vector": [],
            "columns": ["id"],
            "prefilter": True,
            "version": None,
        }

        return pa.table({"id": [1, 2, 3]})

    with query_test_table(handler) as table:
        data = table.search(None).where("true").select(["id"]).limit(10).to_list()
        expected = [{"id": 1}, {"id": 2}, {"id": 3}]
        assert data == expected


def test_query_sync_maximal():
    """Every builder option on a vector search reaches the request body."""

    def handler(body):
        assert body == {
            "distance_type": "cosine",
            "k": 42,
            "offset": 10,
            "prefilter": True,
            "refine_factor": 10,
            "vector": [1.0, 2.0, 3.0],
            "nprobes": 5,
            "lower_bound": None,
            "upper_bound": None,
            "ef": None,
            "filter": "id > 0",
            "columns": ["id", "name"],
            "vector_column": "vector2",
            "fast_search": True,
            "with_row_id": True,
            "version": None,
        }

        return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})

    with query_test_table(handler) as table:
        (
            table.search([1, 2, 3], vector_column_name="vector2", fast_search=True)
            .distance_type("cosine")
            .limit(42)
            .offset(10)
            .refine_factor(10)
            .nprobes(5)
            .where("id > 0", prefilter=True)
            .with_row_id(True)
            .select(["id", "name"])
            .to_list()
        )


@pytest.mark.parametrize("server_version", [Version("0.1.0"), Version("0.2.0")])
def test_query_sync_batch_queries(server_version):
    """Batch vector queries return one result per query vector."""

    def handler(body):
        # TODO: we will add the ability to get the server version,
        # so that we can decide how to perform batch queries.
        vectors = body["vector"]
        if server_version >= Version("0.2.0"):
            # we can handle batch queries in single request since 0.2.0
            assert len(vectors) == 2
            res = [{"id": 1, "query_index": i} for i, _ in enumerate(vectors)]
            return pa.Table.from_pylist(res)
        else:
            assert len(vectors) == 3  # matching dim
            return pa.table({"id": [1]})

    with query_test_table(handler, server_version=server_version) as table:
        results = table.search([[1, 2, 3], [4, 5, 6]]).limit(1).to_list()
        assert len(results) == 2
        results.sort(key=lambda x: x["query_index"])
        assert results == [{"id": 1, "query_index": 0}, {"id": 1, "query_index": 1}]


def test_query_sync_fts():
    """Full-text searches send a ``full_text_query`` with the chosen columns."""

    def handler(body):
        assert body == {
            "full_text_query": {
                "query": "puppy",
                "columns": [],
            },
            "k": 10,
            "prefilter": True,
            "vector": [],
            "version": None,
        }

        return pa.table({"id": [1, 2, 3]})

    with query_test_table(handler) as table:
        (table.search("puppy", query_type="fts").to_list())

    def handler(body):
        # Either ordering of the FTS columns is accepted.
        assert body == {
            "full_text_query": {
                "query": "puppy",
                "columns": ["name", "description"],
            },
            "k": 42,
            "vector": [],
            "prefilter": True,
            "with_row_id": True,
            "version": None,
        } or body == {
            "full_text_query": {
                "query": "puppy",
                "columns": ["description", "name"],
            },
            "k": 42,
            "vector": [],
            "prefilter": True,
            "with_row_id": True,
            "version": None,
        }

        return pa.table({"id": [1, 2, 3]})

    with query_test_table(handler) as table:
        (
            table.search("puppy", query_type="fts", fts_columns=["name", "description"])
            .with_row_id(True)
            .limit(42)
            .to_list()
        )


def test_query_sync_hybrid():
    """A hybrid search issues one FTS request and one vector request."""

    def handler(body):
        if "full_text_query" in body:
            # FTS query
            assert body == {
                "full_text_query": {
                    "query": "puppy",
                    "columns": [],
                },
                "k": 42,
                "vector": [],
                "prefilter": True,
                "with_row_id": True,
                "version": None,
            }
            return pa.table({"_rowid": [1, 2, 3], "_score": [0.1, 0.2, 0.3]})
        else:
            # Vector query
            assert body == {
                "distance_type": "l2",
                "k": 42,
                "prefilter": True,
                "refine_factor": None,
                "vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                "nprobes": 20,
                "lower_bound": None,
                "upper_bound": None,
                "ef": None,
                "with_row_id": True,
                "version": None,
            }
            return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]})

    with query_test_table(handler) as table:
        # Stub the table's embedding registry so the text query can be embedded.
        embedding_func = MockTextEmbeddingFunction()
        embedding_config = MagicMock()
        embedding_config.function = embedding_func

        embedding_funcs = MagicMock()
        embedding_funcs.get = MagicMock(return_value=embedding_config)
        table.embedding_functions = embedding_funcs

        (table.search("puppy", query_type="hybrid").limit(42).to_list())


def test_create_client():
    """``lancedb.connect`` accepts client config as an object, a dict or legacy kwargs."""
    mandatory_args = {
        "uri": "db://dev",
        "api_key": "fake-api-key",
        "region": "us-east-1",
    }

    db = lancedb.connect(**mandatory_args)
    assert isinstance(db.client_config, ClientConfig)

    db = lancedb.connect(**mandatory_args, client_config={})
    assert isinstance(db.client_config, ClientConfig)

    db = lancedb.connect(
        **mandatory_args,
        client_config=ClientConfig(timeout_config={"connect_timeout": 42}),
    )
    assert isinstance(db.client_config, ClientConfig)
    assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)

    db = lancedb.connect(
        **mandatory_args,
        client_config={"timeout_config": {"connect_timeout": timedelta(seconds=42)}},
    )
    assert isinstance(db.client_config, ClientConfig)
    assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)

    db = lancedb.connect(
        **mandatory_args, client_config=ClientConfig(retry_config={"retries": 42})
    )
    assert isinstance(db.client_config, ClientConfig)
    assert db.client_config.retry_config.retries == 42

    db = lancedb.connect(
        **mandatory_args, client_config={"retry_config": {"retries": 42}}
    )
    assert isinstance(db.client_config, ClientConfig)
    assert db.client_config.retry_config.retries == 42

    # The legacy keyword arguments still work but emit a DeprecationWarning.
    with pytest.warns(DeprecationWarning):
        db = lancedb.connect(**mandatory_args, connection_timeout=42)
        assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)

    with pytest.warns(DeprecationWarning):
        db = lancedb.connect(**mandatory_args, read_timeout=42)
        assert db.client_config.timeout_config.read_timeout == timedelta(seconds=42)

    with pytest.warns(DeprecationWarning):
        lancedb.connect(**mandatory_args, request_thread_pool=10)


@pytest.mark.asyncio
async def test_pass_through_headers():
    """``extra_headers`` from the client config are sent with each request."""

    def handler(request):
        assert request.headers["foo"] == "bar"
        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(b'{"tables": []}')

    async with mock_lancedb_connection_async(
        handler, extra_headers={"foo": "bar"}
    ) as db:
        table_names = await db.table_names()
        assert table_names == []