mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-20 13:30:41 +00:00
## Summary
PyTorch's `DataLoader` uses fork-based multiprocessing by default on
Linux, but background threads do not survive `fork()` — only the calling
thread is copied into the child. LanceDB's Python bindings
drive async work through two threaded layers, both of which become inert
in a forked child:
- `BackgroundEventLoop` runs an asyncio loop on a Python
`threading.Thread`.
- `pyo3-async-runtimes::tokio` holds a global multi-threaded tokio
runtime whose worker threads also die on fork — and its runtime lives in
a `OnceLock` that cannot be replaced after first use.
As a result, any `Permutation` (or other async API) used inside a
fork-based `DataLoader` worker hangs indefinitely. This PR makes both
layers fork-safe so `Permutation` works as a `torch.utils.data.Dataset`
with `num_workers > 0`.
## Approach
### Rust — new `python/src/runtime.rs`
Mirrors the pattern used in [Lance's Python
bindings](456198cd6f/python/src/lib.rs#L139),
adapted for the async-bridge use case.
- `LanceRuntime` implements `pyo3_async_runtimes::generic::Runtime +
ContextExt`, backed by an `AtomicPtr<tokio::runtime::Runtime>` we own
(sidestepping `pyo3-async-runtimes`'s frozen `OnceLock` global).
- A `pthread_atfork(after_in_child)` handler nulls the pointer; the next
`spawn` rebuilds the runtime in the child. The previous runtime is
intentionally **leaked** — calling `Drop` would try to join now-dead
worker threads and hang.
- `runtime::future_into_py` is a drop-in for
`pyo3_async_runtimes::tokio::future_into_py`. All ~80 call sites in
`arrow.rs` / `connection.rs` / `permutation.rs` / `query.rs` /
`table.rs` are updated to route through it.
- `python/Cargo.toml` adds `libc = "0.2"` and the tokio
`rt-multi-thread` feature.
### Python — `lancedb/background_loop.py`
- Refactors `BackgroundEventLoop.__init__` to a reusable `_start()`
method.
- An `os.register_at_fork(after_in_child=…)` hook calls `LOOP._start()`
to give the singleton a fresh asyncio loop and thread **in place**. This
matters because the rest of the codebase imports `LOOP` via `from
.background_loop import LOOP` — rebinding the module attribute would
leave those references holding the dead loop.
### Python — `lancedb/__init__.py`
Removes the `__warn_on_fork` pre-fork warning (and the now-unused
`import warnings`). Fork is supported.
## Test plan
- [x] New `test_permutation_dataloader_fork_workers` in
`python/tests/test_torch.py`: runs a `Permutation` through
`torch.utils.data.DataLoader(num_workers=2,
multiprocessing_context="fork")` inside a spawn-isolated child with a
30s hang detector. **Pre-fix**: timed out at 36s. **Post-fix**: passes
in ~3.6s.
- [x] New `test_remote_connection_after_fork` in
`python/tests/test_remote_db.py`: forks a child that creates a fresh
`lancedb.connect(...)` against a mock HTTP server and calls
`table_names()`; passes in <1s, validates the runtime reset is
sufficient for fresh remote clients.
- [x] All 62 tests in `test_torch.py` + `test_permutation.py` pass.
- [x] All 35 tests in `test_remote_db.py` pass.
- [x] `test_table.py` (87) + `test_db.py` + `test_query.py` (157, minus
one unrelated `sentence_transformers` import skip) — 244 passing.
- [x] `cargo clippy -p lancedb-python --tests` clean.
- [x] `cargo fmt`, `ruff check`, `ruff format` all clean.
## Known limitation (follow-up)
This PR makes a **freshly-built** `lancedb.connect(...)` work in a
forked child. An **inherited** `Connection` from the parent still
carries an inherited `reqwest::Client` whose hyper connection pool
references socket FDs and TCP/TLS state shared with the parent — using
it from the child after fork is unsafe (especially with HTTP/1.1
keep-alive). The recommended pattern for fork-based `DataLoader` workers
that hit a remote DB is to construct a new connection inside the worker.
Auto-clearing inherited HTTP client pools on fork would require tracking
live `Connection` instances in `lancedb` core and is left for a
follow-up PR.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
---------
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1314 lines
43 KiB
Python
1314 lines
43 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
import re
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
import contextlib
|
|
from datetime import timedelta
|
|
import http.server
|
|
import json
|
|
import multiprocessing as mp
|
|
import sys
|
|
import threading
|
|
import time
|
|
from unittest.mock import MagicMock, patch
|
|
import uuid
|
|
from packaging.version import Version
|
|
|
|
import lancedb
|
|
from lancedb.conftest import MockTextEmbeddingFunction
|
|
from lancedb.remote import ClientConfig
|
|
from lancedb.remote.errors import HttpError, RetryError
|
|
import pytest
|
|
import pyarrow as pa
|
|
|
|
|
|
def make_mock_http_handler(handler):
    """Return a ``BaseHTTPRequestHandler`` subclass that delegates to ``handler``.

    GET and POST requests are both forwarded unchanged: ``handler`` receives
    the request-handler instance and must write the whole response itself
    (status line, headers and body). No other HTTP verbs are routed.
    """

    class MockLanceDBHandler(http.server.BaseHTTPRequestHandler):
        # GET and POST share a single forwarding implementation.
        def _forward(self):
            handler(self)

        do_GET = _forward
        do_POST = _forward

    return MockLanceDBHandler
|
|
|
|
|
|
@contextlib.contextmanager
def mock_lancedb_connection(handler):
    """Yield a sync remote connection that talks to an in-process mock server.

    Starts an ``http.server.HTTPServer`` on an ephemeral localhost port. Requests
    are served on a background thread by ``make_mock_http_handler(handler)``.
    The function then connects to the server as the remote database
    ``db://dev``, with 2 retries and a 1s connect timeout.

    The server is shut down and its thread joined on exit, including when
    ``lancedb.connect`` itself raises.
    """
    with http.server.HTTPServer(
        ("localhost", 0), make_mock_http_handler(handler)
    ) as server:
        port = server.server_address[1]
        handle = threading.Thread(target=server.serve_forever)
        handle.start()

        # connect() must run inside the try. If it raised before the try
        # began, shutdown() would never be called and the non-daemon
        # serve_forever thread would keep running, blocking interpreter exit.
        try:
            db = lancedb.connect(
                "db://dev",
                api_key="fake",
                host_override=f"http://localhost:{port}",
                client_config={
                    "retry_config": {"retries": 2},
                    "timeout_config": {
                        "connect_timeout": 1,
                    },
                },
            )
            yield db
        finally:
            server.shutdown()
            handle.join()
|
|
|
|
|
|
@contextlib.asynccontextmanager
async def mock_lancedb_connection_async(handler, **client_config):
    """Yield an async remote connection backed by an in-process mock HTTP server.

    Starts an ``http.server.HTTPServer`` on an ephemeral localhost port. Requests
    are served on a background thread by ``make_mock_http_handler(handler)``.
    The function then opens ``db://dev`` via ``lancedb.connect_async``. Extra
    keyword arguments are merged into ``client_config`` after the defaults
    (2 retries, 1s connect timeout), so a caller-supplied key replaces the
    default one.

    The server is shut down and its thread joined on exit, including when
    ``lancedb.connect_async`` itself raises.
    """
    with http.server.HTTPServer(
        ("localhost", 0), make_mock_http_handler(handler)
    ) as server:
        port = server.server_address[1]
        handle = threading.Thread(target=server.serve_forever)
        handle.start()

        # connect_async() must run inside the try. Otherwise a failure there
        # would skip shutdown() and leave the non-daemon serve_forever thread
        # running, blocking interpreter exit.
        try:
            db = await lancedb.connect_async(
                "db://dev",
                api_key="fake",
                host_override=f"http://localhost:{port}",
                client_config={
                    "retry_config": {"retries": 2},
                    "timeout_config": {
                        "connect_timeout": 1,
                    },
                    **client_config,
                },
            )
            yield db
        finally:
            server.shutdown()
            handle.join()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_remote_db():
    """Async remote requests carry a v4 request id and a versioned user agent."""
    expected_user_agent = f"LanceDB-Python-Client/{lancedb.__version__}"

    def handler(request):
        headers = request.headers
        # Each request is tagged with a randomly generated (v4) UUID.
        assert uuid.UUID(headers["x-request-id"]).version == 4
        # The user agent embeds the installed library version.
        assert headers["User-Agent"] == expected_user_agent

        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(b'{"tables": []}')

    async with mock_lancedb_connection_async(handler) as db:
        assert await db.table_names() == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_checkout():
    """``checkout``/``checkout_latest`` change the version sent by ``count_rows``."""

    def handler(request):
        if request.path == "/v1/table/test/describe/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            description = {"version": 42, "schema": {"fields": []}}
            request.wfile.write(json.dumps(description).encode())
            return

        length = int(request.headers.get("Content-Length"))
        body = json.loads(request.rfile.read(length))

        print("body is", body)

        # Answer with a distinct row count per requested version so the test
        # can tell which version the client asked for (None = latest).
        version = body["version"]
        if version is None:
            count = 300
        elif version == 1:
            count = 100
        elif version == 2:
            count = 200
        else:
            count = 0

        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(json.dumps(count).encode())

    async with mock_lancedb_connection_async(handler) as db:
        table = await db.open_table("test")
        assert await table.count_rows() == 300

        await table.checkout(1)
        assert await table.count_rows() == 100

        await table.checkout(2)
        assert await table.count_rows() == 200

        await table.checkout_latest()
        assert await table.count_rows() == 300
|
|
|
|
|
|
def test_table_len_sync():
    """``len(table)`` on a sync remote table returns the server's row count."""

    def handler(request):
        if request.path == "/v1/table/test/create/?mode=create":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            request.wfile.write(b"{}")
            # Stop here. Without this return the handler fell through and
            # wrote a second status line, headers and body onto the create
            # response.
            return

        # Every other request (presumably the row-count query behind len())
        # is answered with a count of 1.
        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(json.dumps(1).encode())

    with mock_lancedb_connection(handler) as db:
        table = db.create_table("test", [{"id": 1}])
        assert len(table) == 1
|
|
|
|
|
|
def test_create_table_exist_ok():
    """``exist_ok=True`` targets ``mode=exist_ok``, with or without mode="create"."""
    accepted_path = "/v1/table/test/create/?mode=exist_ok"

    def handler(request):
        # Only the exist_ok create endpoint succeeds; anything else is a 404.
        if request.path != accepted_path:
            request.send_response(404)
            request.end_headers()
            return
        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(b"{}")

    for create_kwargs in ({"exist_ok": True}, {"mode": "create", "exist_ok": True}):
        with mock_lancedb_connection(handler) as db:
            table = db.create_table("test", [{"id": 1}], **create_kwargs)
            assert table is not None
|
|
|
|
|
|
def test_create_table_exist_ok_with_mode_overwrite():
    """With ``mode="overwrite"``, ``exist_ok=True`` still sends ``mode=overwrite``."""
    accepted_path = "/v1/table/test/create/?mode=overwrite"

    def handler(request):
        # Any request other than the overwrite create endpoint gets a 404.
        if request.path != accepted_path:
            request.send_response(404)
            request.end_headers()
            return
        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(b"{}")

    with mock_lancedb_connection(handler) as db:
        created = db.create_table(
            "test", [{"id": 1}], mode="overwrite", exist_ok=True
        )
        assert created is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_http_error():
    """A 507 response surfaces as ``HttpError`` carrying the request id and body."""
    # Single-slot box holding the id of the most recent request (None until one arrives).
    last_request_id = [None]

    def handler(request):
        last_request_id[0] = request.headers["x-request-id"]
        request.send_response(507)
        request.end_headers()
        request.wfile.write(b"Internal Server Error")

    async with mock_lancedb_connection_async(handler) as db:
        with pytest.raises(HttpError) as exc_info:
            await db.table_names()

        error = exc_info.value
        assert error.request_id == last_request_id[0]
        assert "Internal Server Error" in str(error)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_retry_error():
    """Persistent 429s raise ``RetryError`` whose cause is the last ``HttpError``."""
    # Single-slot box holding the id of the most recent request.
    last_request_id = [None]

    def handler(request):
        last_request_id[0] = request.headers["x-request-id"]
        request.send_response(429)
        request.end_headers()
        request.wfile.write(b"Try again later")

    async with mock_lancedb_connection_async(handler) as db:
        with pytest.raises(RetryError) as exc_info:
            await db.table_names()

        retry_error = exc_info.value
        assert retry_error.request_id == last_request_id[0]

        # The underlying HTTP failure is chained as __cause__.
        http_error = retry_error.__cause__
        assert isinstance(http_error, HttpError)
        assert "Try again later" in str(http_error)
        assert http_error.request_id == last_request_id[0]
        assert http_error.status_code == 429
|
|
|
|
|
|
def test_table_unimplemented_functions():
    """``to_arrow``/``to_pandas`` raise ``NotImplementedError`` on remote tables."""

    def handler(request):
        # Only table creation is served; every other path is a 404.
        if request.path != "/v1/table/test/create/?mode=create":
            request.send_response(404)
            request.end_headers()
            return
        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(b"{}")

    with mock_lancedb_connection(handler) as db:
        table = db.create_table("test", [{"id": 1}])
        for method_name in ("to_arrow", "to_pandas"):
            with pytest.raises(NotImplementedError):
                getattr(table, method_name)()
|
|
|
|
|
|
def test_table_add_in_threadpool():
    """Ten concurrent ``table.add`` calls from a 3-thread pool all succeed."""
    describe_payload = json.dumps(
        {
            "version": 1,
            "schema": {
                "fields": [
                    {"name": "id", "type": {"type": "int64"}, "nullable": False},
                ]
            },
        }
    ).encode()

    def handler(request):
        path = request.path
        if path == "/v1/table/test/insert/":
            request.send_response(200)
            request.end_headers()
        elif path == "/v1/table/test/create/?mode=create":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            request.wfile.write(b"{}")
        elif path == "/v1/table/test/describe/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            request.wfile.write(describe_payload)
        else:
            request.send_response(404)
            request.end_headers()

    with mock_lancedb_connection(handler) as db:
        table = db.create_table("test", [{"id": 1}])
        with ThreadPoolExecutor(3) as executor:
            # Each submission gets its own fresh list of rows.
            pending = [executor.submit(table.add, [{"id": 1}]) for _ in range(10)]
            # result() re-raises any exception raised inside a worker.
            for future in pending:
                future.result()
|
|
|
|
|
|
def test_table_create_indices():
    """Smoke-test index create/wait/drop on a remote table with custom names.

    Index parameters are covered by the local and async tests. Here we only
    check that each ``name`` argument reaches the create_index request body.
    """
    # Decoded bodies of create_index requests, in arrival order.
    received_requests = []

    index_names = ("custom_scalar_idx", "custom_fts_idx", "custom_vector_idx")
    index_stats_payload = json.dumps(
        {"index_type": "IVF_PQ", "num_indexed_rows": 1000, "num_unindexed_rows": 0}
    ).encode()
    describe_payload = json.dumps(
        {
            "version": 1,
            "schema": {
                "fields": [
                    {"name": "id", "type": {"type": "int64"}, "nullable": False},
                ]
            },
        }
    ).encode()
    index_list_payload = json.dumps(
        {
            "indexes": [
                {"index_name": "custom_scalar_idx", "columns": ["id"]},
                {"index_name": "custom_fts_idx", "columns": ["text"]},
                {"index_name": "custom_vector_idx", "columns": ["vector"]},
            ]
        }
    ).encode()
    # Every custom index reports the same (fully indexed) stats.
    stats_paths = {f"/v1/table/test/index/{name}/stats/" for name in index_names}

    def send_json(request, payload):
        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(payload)

    def handler(request):
        path = request.path
        if path == "/v1/table/test/create_index/":
            content_len = int(request.headers.get("Content-Length", 0))
            if content_len > 0:
                received_requests.append(json.loads(request.rfile.read(content_len)))
            request.send_response(200)
            request.end_headers()
        elif path == "/v1/table/test/create/?mode=create":
            send_json(request, b"{}")
        elif path == "/v1/table/test/describe/":
            send_json(request, describe_payload)
        elif path == "/v1/table/test/index/list/":
            send_json(request, index_list_payload)
        elif path in stats_paths:
            send_json(request, index_stats_payload)
        elif "/drop/" in path:
            request.send_response(200)
            request.end_headers()
        else:
            request.send_response(404)
            request.end_headers()

    with mock_lancedb_connection(handler) as db:
        table = db.create_table("test", [{"id": 1}])

        table.create_scalar_index(
            "id", wait_timeout=timedelta(seconds=2), name="custom_scalar_idx"
        )
        table.create_fts_index(
            "text", wait_timeout=timedelta(seconds=2), name="custom_fts_idx"
        )
        table.create_index(
            vector_column_name="vector",
            wait_timeout=timedelta(seconds=10),
            name="custom_vector_idx",
        )

        # One create_index request per call above, each with its custom name.
        assert len(received_requests) == 3
        for body, expected_name in zip(received_requests, index_names):
            assert "name" in body
            assert body["name"] == expected_name

        table.wait_for_index(["custom_scalar_idx"], timedelta(seconds=2))
        table.wait_for_index(
            ["custom_fts_idx", "custom_vector_idx"], timedelta(seconds=2)
        )
        for index_name in ("custom_vector_idx", "custom_scalar_idx", "custom_fts_idx"):
            table.drop_index(index_name)
|
|
|
|
|
|
def test_table_wait_for_index_timeout():
    """``wait_for_index`` times out while the index still reports unindexed rows."""
    # num_unindexed_rows=1 presumably keeps the index from ever counting as
    # fully built, so the wait below must hit its timeout.
    index_stats = {
        "index_type": "BTREE",
        "num_indexed_rows": 1000,
        "num_unindexed_rows": 1,
    }
    describe_payload = json.dumps(
        {
            "version": 1,
            "schema": {
                "fields": [
                    {"name": "id", "type": {"type": "int64"}, "nullable": False},
                ]
            },
        }
    ).encode()
    index_list_payload = json.dumps(
        {"indexes": [{"index_name": "id_idx", "columns": ["id"]}]}
    ).encode()

    def send_json(request, payload):
        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(payload)

    def handler(request):
        path = request.path
        if path == "/v1/table/test/create/?mode=create":
            send_json(request, b"{}")
        elif path == "/v1/table/test/describe/":
            send_json(request, describe_payload)
        elif path == "/v1/table/test/index/list/":
            send_json(request, index_list_payload)
        elif path == "/v1/table/test/index/id_idx/stats/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.end_headers()
            print(f"{index_stats=}")
            request.wfile.write(json.dumps(index_stats).encode())
        else:
            request.send_response(404)
            request.end_headers()

    expected_message = re.escape(
        'Timeout error: timed out waiting for indices: ["id_idx"] after 1s'
    )
    with mock_lancedb_connection(handler) as db:
        table = db.create_table("test", [{"id": 1}])
        with pytest.raises(RuntimeError, match=expected_message):
            table.wait_for_index(["id_idx"], timedelta(seconds=1))
|
|
|
|
|
|
def test_stats():
    """``table.stats()`` returns the server's stats payload unchanged."""
    stats = {
        "total_bytes": 38,
        "num_rows": 2,
        "num_indices": 0,
        "fragment_stats": {
            "num_fragments": 1,
            "num_small_fragments": 1,
            "lengths": {
                "min": 2,
                "max": 2,
                "mean": 2,
                "p25": 2,
                "p50": 2,
                "p75": 2,
                "p99": 2,
            },
        },
    }
    stats_payload = json.dumps(stats).encode()

    def handler(request):
        if request.path == "/v1/table/test/create/?mode=create":
            payload = b"{}"
        elif request.path == "/v1/table/test/stats/":
            payload = stats_payload
        else:
            print(request.path)
            request.send_response(404)
            request.end_headers()
            return

        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(payload)

    with mock_lancedb_connection(handler) as db:
        table = db.create_table("test", [{"id": 1}])
        res = table.stats()
        print(f"{res=}")
        assert res == stats
|
|
|
|
|
|
@contextlib.contextmanager
def query_test_table(query_handler, *, server_version=Version("0.1.0")):
    """Yield remote table ``test`` whose query endpoint is served by ``query_handler``.

    ``query_handler`` receives the decoded JSON body of each
    ``/v1/table/test/query/`` request. It must return a ``pyarrow.Table``,
    which is written back in Arrow IPC file format. The describe endpoint
    advertises ``server_version`` through the ``phalanx-version`` header.
    Any other path is answered with 404.
    """

    def handler(request):
        path = request.path
        if path == "/v1/table/test/describe/":
            request.send_response(200)
            request.send_header("Content-Type", "application/json")
            request.send_header("phalanx-version", str(server_version))
            request.end_headers()
            request.wfile.write(b"{}")
            return
        if path != "/v1/table/test/query/":
            request.send_response(404)
            request.end_headers()
            return

        content_len = int(request.headers.get("Content-Length"))
        result = query_handler(json.loads(request.rfile.read(content_len)))

        request.send_response(200)
        request.send_header("Content-Type", "application/vnd.apache.arrow.file")
        request.end_headers()
        with pa.ipc.new_file(request.wfile, schema=result.schema) as writer:
            writer.write_table(result)

    with mock_lancedb_connection(handler) as db:
        assert repr(db) == "RemoteConnect(name=dev)"
        table = db.open_table("test")
        assert repr(table) == "RemoteTable(dev.test)"
        yield table
|
|
|
|
|
|
def test_head():
    """``head(n)`` issues an empty-vector query with ``k=n``."""
    expected_body = {"k": 5, "prefilter": True, "vector": [], "version": None}

    def handler(body):
        assert body == expected_body
        return pa.table({"id": [1, 2, 3]})

    with query_test_table(handler) as table:
        first_rows = table.head(5)
        assert first_rows == pa.table({"id": [1, 2, 3]})
|
|
|
|
|
|
def test_query_sync_minimal():
    """A bare vector search sends the default k/nprobes and no optional bounds."""
    expected_body = {
        "k": 10,
        "prefilter": True,
        "refine_factor": None,
        "lower_bound": None,
        "upper_bound": None,
        "ef": None,
        "vector": [1.0, 2.0, 3.0],
        "nprobes": 20,
        "minimum_nprobes": 20,
        "maximum_nprobes": 20,
        "version": None,
    }

    def handler(body):
        assert body == expected_body
        return pa.table({"id": [1, 2, 3]})

    with query_test_table(handler) as table:
        rows = table.search([1, 2, 3]).to_list()
        assert rows == [{"id": 1}, {"id": 2}, {"id": 3}]
|
|
|
|
|
|
def test_query_sync_empty_query():
    """``search(None)`` sends a filtered, projected query with an empty vector."""
    expected_body = {
        "k": 10,
        "filter": "true",
        "vector": [],
        "columns": ["id"],
        "prefilter": True,
        "version": None,
    }

    def handler(body):
        assert body == expected_body
        return pa.table({"id": [1, 2, 3]})

    with query_test_table(handler) as table:
        query = table.search(None).where("true").select(["id"]).limit(10)
        rows = query.to_list()
        assert rows == [{"id": 1}, {"id": 2}, {"id": 3}]
|
|
|
|
|
|
def test_query_sync_maximal():
    """Every vector-query builder option is forwarded in the request body."""
    expected_body = {
        "distance_type": "cosine",
        "k": 42,
        "offset": 10,
        "prefilter": True,
        "refine_factor": 10,
        "vector": [1.0, 2.0, 3.0],
        "nprobes": 5,
        "minimum_nprobes": 5,
        "maximum_nprobes": 5,
        "lower_bound": None,
        "upper_bound": None,
        "ef": None,
        "filter": "id > 0",
        "columns": ["id", "name"],
        "vector_column": "vector2",
        "fast_search": True,
        "with_row_id": True,
        "version": None,
    }

    def handler(body):
        assert body == expected_body
        return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})

    with query_test_table(handler) as table:
        query = table.search([1, 2, 3], vector_column_name="vector2", fast_search=True)
        query = query.distance_type("cosine").limit(42).offset(10).refine_factor(10)
        query = query.nprobes(5).where("id > 0", prefilter=True).with_row_id(True)
        query.select(["id", "name"]).to_list()
|
|
|
|
|
|
def test_query_sync_nprobes():
    """Minimum and maximum nprobes are sent separately; ``nprobes`` mirrors the minimum."""
    expected_body = {
        "k": 10,
        "prefilter": True,
        "fast_search": True,
        "vector_column": "vector2",
        "refine_factor": None,
        "lower_bound": None,
        "upper_bound": None,
        "ef": None,
        "vector": [1.0, 2.0, 3.0],
        "nprobes": 5,
        "minimum_nprobes": 5,
        "maximum_nprobes": 15,
        "version": None,
    }

    def handler(body):
        assert body == expected_body
        return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})

    with query_test_table(handler) as table:
        query = table.search([1, 2, 3], vector_column_name="vector2", fast_search=True)
        query.minimum_nprobes(5).maximum_nprobes(15).to_list()
|
|
|
|
|
|
def test_query_sync_no_max_nprobes():
    """``maximum_nprobes(0)`` is forwarded as-is alongside the minimum."""
    expected_body = {
        "k": 10,
        "prefilter": True,
        "fast_search": True,
        "vector_column": "vector2",
        "refine_factor": None,
        "lower_bound": None,
        "upper_bound": None,
        "ef": None,
        "vector": [1.0, 2.0, 3.0],
        "nprobes": 5,
        "minimum_nprobes": 5,
        "maximum_nprobes": 0,
        "version": None,
    }

    def handler(body):
        assert body == expected_body
        return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})

    with query_test_table(handler) as table:
        query = table.search([1, 2, 3], vector_column_name="vector2", fast_search=True)
        query.minimum_nprobes(5).maximum_nprobes(0).to_list()
|
|
|
|
|
|
@pytest.mark.parametrize("server_version", [Version("0.1.0"), Version("0.2.0")])
def test_query_sync_batch_queries(server_version):
    """A batch vector search yields one result row per query, tagged by query_index.

    For server versions >= 0.2.0 the mock expects both query vectors in a single
    request. For older servers it expects each request to carry one
    3-dimensional vector.
    """

    def handler(body):
        # TODO: we will add the ability to get the server version,
        # so that we can decide how to perform batch queries.
        vectors = body["vector"]
        if server_version >= Version("0.2.0"):
            # Since 0.2.0 the whole batch is handled in a single request.
            assert len(vectors) == 2
            return pa.Table.from_pylist(
                [{"id": 1, "query_index": i} for i in range(len(vectors))]
            )
        # Older servers: one request per query, each holding a single vector.
        assert len(vectors) == 3  # matching dim
        return pa.table({"id": [1]})

    with query_test_table(handler, server_version=server_version) as table:
        results = table.search([[1, 2, 3], [4, 5, 6]]).limit(1).to_list()
        assert len(results) == 2
        results.sort(key=lambda x: x["query_index"])
        assert results == [{"id": 1, "query_index": 0}, {"id": 1, "query_index": 1}]
|
|
|
|
|
|
def test_query_sync_fts():
    """FTS searches send a ``full_text_query`` naming the requested columns."""

    def default_columns_handler(body):
        # Without fts_columns the column list is empty.
        assert body == {
            "full_text_query": {"query": "puppy", "columns": []},
            "k": 10,
            "prefilter": True,
            "vector": [],
            "version": None,
        }
        return pa.table({"id": [1, 2, 3]})

    with query_test_table(default_columns_handler) as table:
        table.search("puppy", query_type="fts").to_list()

    def explicit_columns_handler(body):
        # The explicit column list may arrive in either order.
        accepted_bodies = [
            {
                "full_text_query": {"query": "puppy", "columns": columns},
                "k": 42,
                "vector": [],
                "prefilter": True,
                "with_row_id": True,
                "version": None,
            }
            for columns in (["name", "description"], ["description", "name"])
        ]
        assert body in accepted_bodies
        return pa.table({"id": [1, 2, 3]})

    with query_test_table(explicit_columns_handler) as table:
        query = table.search(
            "puppy", query_type="fts", fts_columns=["name", "description"]
        )
        query.with_row_id(True).limit(42).to_list()
|
|
|
|
|
|
def test_query_sync_hybrid():
    """Hybrid search issues one FTS query and one vector query, both with row ids."""
    expected_fts_body = {
        "full_text_query": {"query": "puppy", "columns": []},
        "k": 42,
        "vector": [],
        "prefilter": True,
        "with_row_id": True,
        "version": None,
    }
    # The mocked embedding function presumably maps "puppy" to ten zeros.
    expected_vector_body = {
        "k": 42,
        "prefilter": True,
        "refine_factor": None,
        "vector": [0.0] * 10,
        "nprobes": 20,
        "minimum_nprobes": 20,
        "maximum_nprobes": 20,
        "lower_bound": None,
        "upper_bound": None,
        "ef": None,
        "with_row_id": True,
        "version": None,
    }

    def handler(body):
        if "full_text_query" not in body:
            assert body == expected_vector_body
            return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]})
        assert body == expected_fts_body
        return pa.table({"_rowid": [1, 2, 3], "_score": [0.1, 0.2, 0.3]})

    with query_test_table(handler) as table:
        # Route every embedding-function lookup to the mock text embedder.
        embedding_config = MagicMock(function=MockTextEmbeddingFunction())
        table.embedding_functions = MagicMock(
            get=MagicMock(return_value=embedding_config)
        )

        table.search("puppy", query_type="hybrid").limit(42).to_list()
|
|
|
|
|
|
def test_create_client():
    """``lancedb.connect`` accepts ``client_config`` as a ClientConfig or a dict."""
    mandatory_args = {
        "uri": "db://dev",
        "api_key": "fake-api-key",
        "region": "us-east-1",
    }

    def connect(**kwargs):
        return lancedb.connect(**mandatory_args, **kwargs)

    # No config and an empty dict both produce a ClientConfig.
    for kwargs in ({}, {"client_config": {}}):
        assert isinstance(connect(**kwargs).client_config, ClientConfig)

    # connect_timeout may be given as plain seconds or as a timedelta.
    for client_config in (
        ClientConfig(timeout_config={"connect_timeout": 42}),
        {"timeout_config": {"connect_timeout": timedelta(seconds=42)}},
    ):
        config = connect(client_config=client_config).client_config
        assert isinstance(config, ClientConfig)
        assert config.timeout_config.connect_timeout == timedelta(seconds=42)

    # The overall request timeout accepts the same two forms.
    for client_config in (
        ClientConfig(timeout_config={"timeout": 60}),
        {"timeout_config": {"timeout": timedelta(seconds=60)}},
    ):
        config = connect(client_config=client_config).client_config
        assert isinstance(config, ClientConfig)
        assert config.timeout_config.timeout == timedelta(seconds=60)

    # Retry settings, as a ClientConfig or a dict.
    for client_config in (
        ClientConfig(retry_config={"retries": 42}),
        {"retry_config": {"retries": 42}},
    ):
        config = connect(client_config=client_config).client_config
        assert isinstance(config, ClientConfig)
        assert config.retry_config.retries == 42

    # Legacy keyword arguments still work but emit DeprecationWarning.
    with pytest.warns(DeprecationWarning):
        db = connect(connection_timeout=42)
    assert db.client_config.timeout_config.connect_timeout == timedelta(seconds=42)

    with pytest.warns(DeprecationWarning):
        db = connect(read_timeout=42)
    assert db.client_config.timeout_config.read_timeout == timedelta(seconds=42)

    with pytest.warns(DeprecationWarning):
        connect(request_thread_pool=10)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_pass_through_headers():
    """``extra_headers`` from the client config reach the server on each request."""

    def handler(request):
        assert request.headers["foo"] == "bar"
        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(b'{"tables": []}')

    connection = mock_lancedb_connection_async(handler, extra_headers={"foo": "bar"})
    async with connection as db:
        assert await db.table_names() == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_header_provider_with_static_headers():
    """Headers supplied by a ``StaticHeaderProvider`` are sent with requests."""
    from lancedb.remote.header import StaticHeaderProvider

    static_headers = {"X-API-Key": "test-api-key", "X-Custom-Header": "custom-value"}

    def handler(request):
        # Each provider header must arrive with its exact value.
        for header_name, header_value in static_headers.items():
            assert request.headers.get(header_name) == header_value

        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(b'{"tables": ["test_table"]}')

    # Give the provider its own copy so the expectations above stay fixed.
    provider = StaticHeaderProvider(dict(static_headers))

    async with mock_lancedb_connection_async(handler, header_provider=provider) as db:
        assert await db.table_names() == ["test_table"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_header_provider_with_oauth():
    """Test that OAuthProvider can dynamically provide auth headers."""
    from lancedb.remote.header import OAuthProvider

    fetch_calls = 0

    def token_fetcher():
        """Simulates fetching OAuth token."""
        nonlocal fetch_calls
        fetch_calls += 1
        return {
            "access_token": f"bearer-token-{fetch_calls}",
            "expires_in": 3600,
        }

    def handler(request):
        # Only the first fetched token should ever be presented to the server.
        assert request.headers.get("Authorization") == "Bearer bearer-token-1"

        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()

        is_describe = request.path == "/v1/table/test/describe/"
        body = (
            b'{"version": 1, "schema": {"fields": []}}'
            if is_describe
            else b'{"tables": ["test"]}'
        )
        request.wfile.write(body)

    provider = OAuthProvider(token_fetcher)

    async with mock_lancedb_connection_async(handler, header_provider=provider) as db:
        # Two separate requests; both must reuse the single cached token.
        await db.table_names()
        table = await db.open_table("test")
        assert table is not None
        assert fetch_calls == 1  # Token fetched only once
|
|
|
|
|
|
def test_header_provider_with_sync_connection():
    """Test header provider works with sync connections."""
    from lancedb.remote.header import StaticHeaderProvider

    seen = {"count": 0}
    describe_payload = {
        "version": 1,
        "schema": {
            "fields": [
                {"name": "id", "type": {"type": "int64"}, "nullable": False}
            ]
        },
    }

    def handler(request):
        seen["count"] += 1

        # Headers from the provider must be attached to every request.
        assert request.headers.get("X-Session-Id") == "sync-session-123"
        assert request.headers.get("X-Client-Version") == "1.0.0"

        path = request.path
        request.send_response(200)
        if path == "/v1/table/test/insert/":
            # Insert acknowledgement: no content type and no body.
            request.end_headers()
            return

        request.send_header("Content-Type", "application/json")
        request.end_headers()
        if path == "/v1/table/test/create/?mode=create":
            body = b"{}"
        elif path == "/v1/table/test/describe/":
            body = json.dumps(describe_payload).encode()
        else:
            body = b'{"count": 1}'
        request.wfile.write(body)

    provider = StaticHeaderProvider(
        {"X-Session-Id": "sync-session-123", "X-Client-Version": "1.0.0"}
    )

    with http.server.HTTPServer(
        ("localhost", 0), make_mock_http_handler(handler)
    ) as server:
        port = server.server_address[1]
        serve_thread = threading.Thread(target=server.serve_forever)
        serve_thread.start()

        try:
            config = {
                "retry_config": {"retries": 2},
                "timeout_config": {"connect_timeout": 1},
                "header_provider": provider,
            }
            db = lancedb.connect(
                "db://dev",
                api_key="fake",
                host_override=f"http://localhost:{port}",
                client_config=config,
            )

            table = db.create_table("test", [{"id": 1}])
            table.add([{"id": 2}])

            # At least the create and insert requests carried the headers.
            assert seen["count"] >= 2
        finally:
            server.shutdown()
            serve_thread.join()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_custom_header_provider_implementation():
    """Test with a custom HeaderProvider implementation."""
    from lancedb.remote import HeaderProvider

    class CustomAuthProvider(HeaderProvider):
        """Custom provider that generates request-specific headers."""

        def __init__(self):
            self.request_count = 0

        def get_headers(self):
            self.request_count += 1
            seq = self.request_count
            return {
                "X-Request-Id": f"req-{seq}",
                "X-Auth-Token": f"custom-token-{seq}",
                "X-Timestamp": str(int(time.time())),
            }

    tracked_keys = ("X-Request-Id", "X-Auth-Token", "X-Timestamp")
    captured = []

    def handler(request):
        # Record what the server actually received for later comparison.
        captured.append({key: request.headers.get(key) for key in tracked_keys})

        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(b'{"tables": []}')

    provider = CustomAuthProvider()

    async with mock_lancedb_connection_async(handler, header_provider=provider) as db:
        for _ in range(2):
            await db.table_names()

        # Each request must have received a freshly generated header set.
        assert len(captured) == 2
        for seq, headers in enumerate(captured, start=1):
            assert headers["X-Request-Id"] == f"req-{seq}"
            assert headers["X-Auth-Token"] == f"custom-token-{seq}"

        assert provider.request_count == 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_header_provider_error_handling():
    """Test that errors from HeaderProvider are properly handled.

    Two outcomes are accepted for ``table_names()``:

    * it raises, in which case the error must mention the provider failure;
    * it succeeds (presumably the provider error was not propagated for this
      request), in which case the mock server's empty listing must come back.

    The success-path assertion lives in ``else`` so that its failure is not
    swallowed by the ``except`` clause and re-reported as a confusing
    "error message does not mention the provider" failure.
    """
    from lancedb.remote import HeaderProvider

    class FailingProvider(HeaderProvider):
        """Provider that fails to get headers."""

        def get_headers(self):
            raise RuntimeError("Failed to fetch authentication token")

    def handler(request):
        # Reached only if the provider failure did not abort the request.
        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(b'{"tables": []}')

    provider = FailingProvider()

    # The connection itself should be created successfully.
    async with mock_lancedb_connection_async(handler, header_provider=provider) as db:
        try:
            result = await db.table_names()
        except Exception as e:
            # A raised error must be attributable to the header provider.
            message = str(e)
            assert (
                "Failed to fetch authentication token" in message
                or "get_headers" in message
            ), f"unexpected error from table_names(): {e!r}"
        else:
            # No error surfaced: the request must have completed normally.
            assert result == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_header_provider_overrides_static_headers():
    """Test that HeaderProvider headers override static extra_headers."""
    from lancedb.remote.header import StaticHeaderProvider

    static_headers = {"X-API-Key": "static-key", "X-Extra": "extra-value"}
    provider = StaticHeaderProvider({"X-API-Key": "provider-key"})

    def handler(request):
        received = request.headers
        # On a key collision, the provider's value wins...
        assert received.get("X-API-Key") == "provider-key"
        # ...while non-colliding static headers are still forwarded.
        assert received.get("X-Extra") == "extra-value"

        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(b'{"tables": []}')

    async with mock_lancedb_connection_async(
        handler, header_provider=provider, extra_headers=static_headers
    ) as db:
        await db.table_names()
|
|
|
|
|
|
def test_close():
    """Test that close() works without AttributeError."""
    import asyncio

    def handler(req):
        # Every request simply receives an empty 200 response.
        req.send_response(200)
        req.end_headers()

    with mock_lancedb_connection(handler) as db:
        closing = db.close()
        asyncio.run(closing)
|
|
|
|
|
|
@pytest.mark.parametrize("exception", [KeyboardInterrupt, SystemExit, GeneratorExit])
def test_background_loop_cancellation(exception):
    """Test that BackgroundEventLoop.run() cancels the future on interrupt."""
    from lancedb.background_loop import BackgroundEventLoop

    fake_future = MagicMock()
    fake_future.result.side_effect = exception()

    # Skip the real constructor (no thread/loop) and intercept submission.
    skip_init = patch.object(BackgroundEventLoop, "__init__", return_value=None)
    intercept_submit = patch(
        "asyncio.run_coroutine_threadsafe", return_value=fake_future
    )
    with skip_init, intercept_submit:
        bg = BackgroundEventLoop()
        bg.loop = MagicMock()
        with pytest.raises(exception):
            bg.run(None)

    # The interrupt propagated; the pending future must have been cancelled.
    fake_future.cancel.assert_called_once()
|
|
|
|
|
|
def _remote_fork_child(port: int, queue) -> None:
    """Child-process body for the fork test: list tables and report them.

    A fresh Connection is built here so the at-fork-child tokio runtime reset
    is exercised, rather than relying on an inherited reqwest client. The
    result of ``table_names()`` is sent back to the parent through *queue*.
    """
    client_config = {
        "retry_config": {"retries": 0},
        "timeout_config": {"connect_timeout": 2, "read_timeout": 2},
    }
    conn = lancedb.connect(
        "db://dev",
        api_key="fake",
        host_override=f"http://localhost:{port}",
        client_config=client_config,
    )
    names = conn.table_names()
    queue.put(names)
|
|
|
|
|
|
@pytest.mark.skipif(
    sys.platform != "linux",
    reason=(
        "fork() is unavailable on Windows and unsafe on macOS "
        "(Apple frameworks/TLS are not fork-safe)"
    ),
)
def test_remote_connection_after_fork():
    """A freshly-built remote Connection in a forked child should not hang.

    The pyo3-async-runtimes tokio runtime would otherwise be inherited from
    the parent with dead worker threads; the at-fork-child handler in our
    runtime module rebuilds it on first use in the child.
    """
    from queue import Empty

    def handler(request):
        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(b'{"tables": []}')

    server = http.server.HTTPServer(("localhost", 0), make_mock_http_handler(handler))
    port = server.server_address[1]
    server_thread = threading.Thread(target=server.serve_forever)
    server_thread.start()
    try:
        # Hit the server in the parent first so the runtime + LOOP are warm
        # before fork; a fresh child must still succeed.
        parent_db = lancedb.connect(
            "db://dev",
            api_key="fake",
            host_override=f"http://localhost:{port}",
            client_config={
                "retry_config": {"retries": 0},
                "timeout_config": {"connect_timeout": 2, "read_timeout": 2},
            },
        )
        assert parent_db.table_names() == []

        ctx = mp.get_context("fork")
        result_queue = ctx.Queue()
        proc = ctx.Process(target=_remote_fork_child, args=(port, result_queue))
        proc.start()
        proc.join(timeout=15)

        if proc.is_alive():
            # Escalate terminate -> kill so a hung child never outlives us.
            proc.terminate()
            proc.join(timeout=5)
            if proc.is_alive():
                proc.kill()
                proc.join()
            pytest.fail("Remote connection hung after fork")

        assert proc.exitcode == 0, f"child exited with code {proc.exitcode}"

        # multiprocessing.Queue.empty() is documented as unreliable, so wait
        # (bounded) for the child's result instead of polling for emptiness.
        try:
            child_result = result_queue.get(timeout=5)
        except Empty:
            pytest.fail("child produced no result")
        assert child_result == []

        # Parent connection must still be usable after the child returned.
        assert parent_db.table_names() == []
    finally:
        server.shutdown()
        server_thread.join()
        # shutdown() only stops serve_forever(); the listening socket must be
        # released explicitly to avoid leaking it.
        server.server_close()
|