feat(python): support remote tables in PyTorch dataloaders (#3432)

This PR makes remote LanceDB tables usable from PyTorch multiprocessing
workers. Remote tables now carry enough safe JSON connection state to
reopen themselves after pickle/spawn or fork, and permutations lazily
rebuild their reader from restored tables instead of trying to reuse
process-local handles.

This addresses the remote-table gap in the PyTorch dataset path while
preserving the explicit connection factory escape hatch for custom
worker-side credential loading or non-serializable header providers.

Validated with targeted remote table, permutation, and PyTorch
DataLoader tests.
This commit is contained in:
Xuanwo
2026-06-02 15:38:28 +08:00
committed by GitHub
parent f20ec99dec
commit a327044e2f
6 changed files with 705 additions and 66 deletions

View File

@@ -1,12 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import re
from concurrent.futures import ThreadPoolExecutor
import contextlib
from datetime import timedelta
import http.server
import json
import multiprocessing as mp
import pickle
import re
import sys
import threading
import time
@@ -171,6 +172,155 @@ def test_table_len_sync():
assert len(table) == 1
def test_remote_connection_serializes():
def handler(request):
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b'{"tables": []}')
with mock_lancedb_connection(handler) as db:
serialized = json.loads(db.serialize())
assert isinstance(serialized["client_config"], dict)
restored = lancedb.deserialize_conn(db.serialize())
assert restored.table_names() == []
def test_remote_table_is_picklable():
def handler(request):
request.close_connection = True
if request.path == "/v1/table/test/describe/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(
{
"version": 1,
"schema": {
"fields": [
{"name": "id", "type": {"type": "int64"}, "nullable": False}
]
},
}
)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/count_rows/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"3")
else:
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
table = db.open_table("test")
restored = pickle.loads(pickle.dumps(table))
assert restored.count_rows() == 3
def test_remote_table_open_does_not_require_picklable_client_config():
from lancedb.remote import HeaderProvider
class LocalHeaderProvider(HeaderProvider):
def get_headers(self):
return {"X-Test-Header": "present"}
def handler(request):
request.close_connection = True
assert request.headers.get("X-Test-Header") == "present"
if request.path == "/v1/table/test/describe/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b'{"version": 1, "schema": {"fields": []}}')
elif request.path == "/v1/table/test/count_rows/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"3")
else:
request.send_response(404)
request.end_headers()
with http.server.HTTPServer(
("localhost", 0), make_mock_http_handler(handler)
) as server:
port = server.server_address[1]
handle = threading.Thread(target=server.serve_forever)
handle.start()
try:
db = lancedb.connect(
"db://dev",
api_key="fake",
host_override=f"http://localhost:{port}",
client_config={
"retry_config": {"retries": 0},
"timeout_config": {"connect_timeout": 2, "read_timeout": 2},
"header_provider": LocalHeaderProvider(),
},
)
table = db.open_table("test")
assert table.count_rows() == 3
with pytest.raises(ValueError, match="header_provider"):
pickle.dumps(table)
finally:
server.shutdown()
handle.join()
def test_remote_permutation_is_picklable():
from lancedb.permutation import Permutation
rows = list(range(10))
def handler(request):
request.close_connection = True
if request.path == "/v1/table/test/describe/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(
{
"version": 1,
"schema": {
"fields": [
{"name": "a", "type": {"type": "int64"}, "nullable": False}
]
},
}
)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/count_rows/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(str(len(rows)).encode())
elif request.path == "/v1/table/test/query/":
content_len = int(request.headers.get("Content-Length"))
body = json.loads(request.rfile.read(content_len))
if "filter" in body:
match = re.search(r"_rowoffset in \((.*?)\)", body["filter"])
offsets = [int(offset.strip()) for offset in match.group(1).split(",")]
else:
offsets = rows
table = pa.table({"a": [rows[offset] for offset in offsets]})
request.send_response(200)
request.send_header("Content-Type", "application/vnd.apache.arrow.file")
request.end_headers()
with pa.ipc.new_file(request.wfile, schema=table.schema) as writer:
writer.write_table(table)
else:
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
permutation = Permutation.identity(db.open_table("test"))
restored = pickle.loads(pickle.dumps(permutation))
assert restored.__getitems__([0, 2, 4]) == [{"a": 0}, {"a": 2}, {"a": 4}]
def test_create_table_exist_ok():
def handler(request):
if request.path == "/v1/table/test/create/?mode=exist_ok":
@@ -1400,6 +1550,10 @@ def _remote_fork_child(port: int, queue) -> None:
queue.put(db.table_names())
def _remote_table_fork_child(table, queue) -> None:
queue.put(table.count_rows())
@pytest.mark.skipif(
sys.platform != "linux",
reason=(
@@ -1462,3 +1616,65 @@ def test_remote_connection_after_fork():
finally:
server.shutdown()
server_thread.join()
@pytest.mark.skipif(
sys.platform != "linux",
reason=(
"fork() is unavailable on Windows and unsafe on macOS "
"(Apple frameworks/TLS are not fork-safe)"
),
)
def test_inherited_remote_table_reopens_after_fork():
def handler(request):
if request.path == "/v1/table/test/describe/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b'{"version": 1, "schema": {"fields": []}}')
elif request.path == "/v1/table/test/count_rows/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"7")
else:
request.send_response(404)
request.end_headers()
server = http.server.HTTPServer(("localhost", 0), make_mock_http_handler(handler))
port = server.server_address[1]
server_thread = threading.Thread(target=server.serve_forever)
server_thread.start()
try:
db = lancedb.connect(
"db://dev",
api_key="fake",
host_override=f"http://localhost:{port}",
client_config={
"retry_config": {"retries": 0},
"timeout_config": {"connect_timeout": 2, "read_timeout": 2},
},
)
table = db.open_table("test")
assert table.count_rows() == 7
ctx = mp.get_context("fork")
queue = ctx.Queue()
proc = ctx.Process(target=_remote_table_fork_child, args=(table, queue))
proc.start()
proc.join(timeout=15)
if proc.is_alive():
proc.terminate()
proc.join(timeout=5)
if proc.is_alive():
proc.kill()
proc.join()
pytest.fail("Remote table hung after fork")
assert proc.exitcode == 0, f"child exited with code {proc.exitcode}"
assert not queue.empty(), "child produced no result"
assert queue.get() == 7
finally:
server.shutdown()
server_thread.join()