diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index efeed258f..be7a2b0fd 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -304,6 +304,15 @@ def deserialize_conn( manifest_enabled=parsed.get("manifest_enabled", False), namespace_client_properties=parsed.get("namespace_client_properties"), ) + elif connection_type == "remote": + return RemoteDBConnection( + parsed["db_url"], + parsed["api_key"], + parsed.get("region", "us-east-1"), + host_override=parsed.get("host_override"), + client_config=parsed.get("client_config"), + storage_options=storage_options, + ) else: raise ValueError(f"Unknown connection_type: {connection_type}") diff --git a/python/python/lancedb/permutation.py b/python/python/lancedb/permutation.py index fdcebc69e..ae7a56377 100644 --- a/python/python/lancedb/permutation.py +++ b/python/python/lancedb/permutation.py @@ -3,12 +3,13 @@ import copy import json +import os from deprecation import deprecated import pyarrow as pa from ._lancedb import async_permutation_builder, PermutationReader -from .table import LanceTable +from .table import LanceTable, Table from .background_loop import LOOP from .util import batch_to_tensor, batch_to_tensor_rows from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union @@ -354,6 +355,49 @@ class Transforms: DEFAULT_BATCH_SIZE = 100 +def _table_to_pickle_state(table: Table) -> dict[str, Any]: + from .remote.table import RemoteTable + + if isinstance(table, RemoteTable): + return { + "kind": "remote", + "table": table, + } + + if not isinstance(table, LanceTable): + raise ValueError(f"Cannot pickle table of type {type(table)!r}") + + base_uri = table._conn.uri + if base_uri.startswith("memory://"): + return { + "kind": "memory", + "name": table.name, + "data": table.to_arrow(), + } + + return { + "kind": "local", + "name": table.name, + "uri": base_uri, + "namespace": table._namespace_path, + "storage_options": table._conn.storage_options, + } + + +def _table_from_pickle_state(state: dict[str, Any]) -> Table: + from . import connect + + kind = state["kind"] + if kind == "remote": + return state["table"] + if kind == "memory": + return connect("memory://").create_table(state["name"], state["data"]) + if kind == "local": + db = connect(state["uri"], storage_options=state["storage_options"]) + return db.open_table(state["name"], namespace_path=state["namespace"] or None) + raise ValueError(f"Unknown table pickle state kind: {kind}") + + class Permutation: """ A Permutation is a view of a dataset that can be used as input to model training @@ -369,15 +413,15 @@ class Permutation: def __init__( self, - base_table: LanceTable, - permutation_table: Optional[LanceTable], + base_table: Table, + permutation_table: Optional[Table], split: int, selection: dict[str, str], batch_size: int, transform_fn: Callable[pa.RecordBatch, Any], offset: Optional[int] = None, limit: Optional[int] = None, - connection_factory: Optional[Callable[[str], LanceTable]] = None, + connection_factory: Optional[Callable[[str], Table]] = None, _reader: Optional[PermutationReader] = None, ): """ @@ -397,6 +441,7 @@ class Permutation: if _reader is None: _reader = LOOP.run(self._build_reader()) self.reader: PermutationReader = _reader + self._pid = os.getpid() async def _build_reader(self) -> PermutationReader: reader = await PermutationReader.from_tables( @@ -428,29 +473,25 @@ class Permutation: return new def with_connection_factory( - self, connection_factory: Callable[[str], LanceTable] + self, connection_factory: Callable[[str], Table] ) -> "Permutation": """ Creates a new permutation that will use ``connection_factory`` to reopen the base table when this permutation is unpickled in a worker process. - The factory is a callable that takes a single argument — the base table - name — and returns a [LanceTable]. It must be picklable; the worker + The factory is a callable that takes a single argument, the base table + name, and returns a LanceDB table. It must be picklable; the worker will pickle it via standard ``pickle`` and call it to recover the base table. Picklable callables in practice means top-level (module-level) functions, ``functools.partial`` of such functions, or instances of picklable classes implementing ``__call__``. Lambdas and closures over local variables don't pickle with the default protocol. - Setting a factory is necessary when the URI alone is not enough to - re-open the connection — most importantly for LanceDB Cloud (``db://``) - connections, where ``api_key`` and ``region`` aren't recoverable from - the connection object after construction. - - For local file or cloud-storage paths the factory is optional: if not - set, ``__getstate__`` falls back to capturing - ``(uri, storage_options, namespace_path)`` and re-opening via - ``lancedb.connect(uri, storage_options=...)``. + A factory is optional for normal local and remote LanceDB connections: + if not set, ``__getstate__`` captures the table's own picklable reopen + state. Use a factory when that default state is not enough, for example + when credentials should be loaded from the worker environment instead + of being embedded in the pickle. Examples -------- @@ -508,7 +549,7 @@ class Permutation: return new @classmethod - def identity(cls, table: LanceTable) -> "Permutation": + def identity(cls, table: Table) -> "Permutation": """ Creates an identity permutation for the given table. """ @@ -517,8 +558,8 @@ class Permutation: @classmethod def from_tables( cls, - base_table: LanceTable, - permutation_table: Optional[LanceTable] = None, + base_table: Table, + permutation_table: Optional[Table] = None, split: Optional[Union[str, int]] = None, ) -> "Permutation": """ @@ -594,19 +635,24 @@ class Permutation: The base table is captured either via a user-supplied ``connection_factory`` (see [with_connection_factory]) or, as a - fallback, by introspecting ``(uri, storage_options, namespace_path)`` - on the connection. The permutation table — always an in-memory - LanceDB table — is captured as a pyarrow Table (which pickles via - Arrow IPC natively). The reader is dropped from the wire format; - ``__setstate__`` rebuilds it from the restored tables. + fallback, by the table's own picklable reopen state. An in-memory + permutation table is captured as a pyarrow Table (which pickles via + Arrow IPC natively); otherwise, the permutation table uses its own + reopen state too. The reader is dropped from the wire format and + rebuilt lazily on first use. """ permutation_data: Optional[pa.Table] = None + permutation_table_state: Optional[dict[str, Any]] = None if self.permutation_table is not None: - permutation_data = self.permutation_table.to_arrow() + try: + permutation_data = self.permutation_table.to_arrow() + except NotImplementedError: + permutation_table_state = _table_to_pickle_state(self.permutation_table) common = { "base_table_name": self.base_table.name, "permutation_data": permutation_data, + "permutation_table_state": permutation_table_state, "split": self.split, "selection": self.selection, "batch_size": self.batch_size, @@ -622,39 +668,9 @@ class Permutation: # namespace from the existing connection. return common - # URI-introspection fallback: only viable for native (OSS) connections - # where (uri, storage_options) is enough to reopen. Remote / cloud - # connections don't expose recoverable api_key / region — those users - # must call with_connection_factory(). - try: - base_uri = self.base_table._conn.uri - storage_options = self.base_table._conn.storage_options - except AttributeError as e: - raise ValueError( - "Cannot pickle this Permutation: the base table's connection " - "does not expose a uri/storage_options, which usually means it " - "is a remote (LanceDB Cloud) connection. Call " - "Permutation.with_connection_factory(...) first to provide a " - "picklable callable that re-opens the base table from a worker " - "process." - ) from e - - if base_uri.startswith("memory://"): - # In-memory base tables don't exist in any worker process by - # default, so dump the entire base table into the pickle. This - # can be expensive for large datasets — users with large - # in-memory base tables should either persist them or set a - # connection_factory. - return { - **common, - "base_table_data": self.base_table.to_arrow(), - } - return { **common, - "base_table_uri": base_uri, - "base_table_namespace": self.base_table._namespace_path, - "base_table_storage_options": storage_options, + "base_table_state": _table_to_pickle_state(self.base_table), } def __setstate__(self, state: dict[str, Any]) -> None: @@ -663,6 +679,8 @@ class Permutation: connection_factory = state["connection_factory"] if connection_factory is not None: base_table = connection_factory(state["base_table_name"]) + elif "base_table_state" in state: + base_table = _table_from_pickle_state(state["base_table_state"]) elif "base_table_data" in state: # In-memory base table inlined into the pickle; rebuild the same # way we rebuild the in-memory permutation table. @@ -680,8 +698,12 @@ class Permutation: namespace_path=state["base_table_namespace"] or None, ) - permutation_table: Optional[LanceTable] = None - if state["permutation_data"] is not None: + permutation_table: Optional[Table] = None + if state.get("permutation_table_state") is not None: + permutation_table = _table_from_pickle_state( + state["permutation_table_state"] + ) + elif state["permutation_data"] is not None: mem_db = connect("memory://") permutation_table = mem_db.create_table( "permutation", state["permutation_data"] @@ -696,10 +718,26 @@ class Permutation: self.offset = state["offset"] self.limit = state["limit"] self.connection_factory = connection_factory + self.reader = None + self._pid = None + + def _ensure_open(self) -> None: + pid = os.getpid() + if self.reader is not None and getattr(self, "_pid", None) == pid: + return + if hasattr(self.base_table, "_ensure_open"): + self.base_table._ensure_open() + if self.permutation_table is not None and hasattr( + self.permutation_table, "_ensure_open" + ): + self.permutation_table._ensure_open() self.reader = LOOP.run(self._build_reader()) + self._pid = pid @property def schema(self) -> pa.Schema: + self._ensure_open() + async def do_output_schema(): return await self.reader.output_schema(self.selection) @@ -717,6 +755,7 @@ class Permutation: """ The number of rows in the permutation """ + self._ensure_open() return self.reader.count_rows() @property @@ -875,6 +914,7 @@ class Permutation: If skip_last_batch is True, the last batch will be skipped if it is not a multiple of batch_size. """ + self._ensure_open() async def get_iter(): return await self.reader.read(self.selection, batch_size=batch_size) @@ -976,6 +1016,7 @@ class Permutation: so `with_format` and `with_transform` affect this method in the same way they affect iteration. """ + self._ensure_open() async def do_take_offsets(): return await self.reader.take_offsets(offsets, selection=self.selection) @@ -1011,9 +1052,11 @@ class Permutation: """ Skip the first `skip` rows of the permutation """ + self._ensure_open() new = copy.copy(self) new.offset = skip new.reader = LOOP.run(new._build_reader()) + new._pid = os.getpid() return new @deprecated(details="Use with_take instead") @@ -1032,9 +1075,11 @@ class Permutation: """ Limit the permutation to `limit` rows (following any `skip`) """ + self._ensure_open() new = copy.copy(self) new.limit = limit new.reader = LOOP.run(new._build_reader()) + new._pid = os.getpid() return new @deprecated(details="Use with_repeat instead") diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index e110cdac1..8f1aeda66 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -3,6 +3,7 @@ from datetime import timedelta +import json import logging from concurrent.futures import ThreadPoolExecutor import sys @@ -17,7 +18,7 @@ else: # Remove this import to fix circular dependency # from lancedb import connect_async -from lancedb.remote import ClientConfig +from lancedb.remote import ClientConfig, RetryConfig, TimeoutConfig, TlsConfig import pyarrow as pa from ..common import DATA @@ -36,6 +37,64 @@ from ..table import Table from ..util import validate_table_name +def _duration_seconds(value: Optional[timedelta]) -> Optional[float]: + return value.total_seconds() if value is not None else None + + +def _timeout_config_to_dict( + config: Optional[TimeoutConfig], +) -> Optional[dict[str, Any]]: + if config is None: + return None + return { + "timeout": _duration_seconds(config.timeout), + "connect_timeout": _duration_seconds(config.connect_timeout), + "read_timeout": _duration_seconds(config.read_timeout), + "pool_idle_timeout": _duration_seconds(config.pool_idle_timeout), + } + + +def _retry_config_to_dict(config: RetryConfig) -> dict[str, Any]: + return { + "retries": config.retries, + "connect_retries": config.connect_retries, + "read_retries": config.read_retries, + "backoff_factor": config.backoff_factor, + "backoff_jitter": config.backoff_jitter, + "statuses": config.statuses, + } + + +def _tls_config_to_dict(config: Optional[TlsConfig]) -> Optional[dict[str, Any]]: + if config is None: + return None + return { + "cert_file": config.cert_file, + "key_file": config.key_file, + "ssl_ca_cert": config.ssl_ca_cert, + "assert_hostname": config.assert_hostname, + } + + +def _client_config_to_dict(config: ClientConfig) -> dict[str, Any]: + if config.header_provider is not None: + raise ValueError( + "Cannot serialize a remote connection with a header_provider. " + "Use static api_key/extra_headers or provide a worker-side " + "connection factory instead." + ) + return { + "user_agent": config.user_agent, + "retry_config": _retry_config_to_dict(config.retry_config), + "timeout_config": _timeout_config_to_dict(config.timeout_config), + "extra_headers": config.extra_headers, + "id_delimiter": config.id_delimiter, + "tls_config": _tls_config_to_dict(config.tls_config), + "header_provider": None, + "user_id": config.user_id, + } + + class RemoteDBConnection(DBConnection): """A connection to a remote LanceDB database.""" @@ -88,6 +147,11 @@ class RemoteDBConnection(DBConnection): parsed = urlparse(db_url) if parsed.scheme != "db": raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://") + self.db_url = db_url + self.api_key = api_key + self.region = region + self.host_override = host_override + self.storage_options = storage_options self.db_name = parsed.netloc self.client_config = client_config @@ -109,6 +173,20 @@ class RemoteDBConnection(DBConnection): def __repr__(self) -> str: return f"RemoteConnect(name={self.db_name})" + @override + def serialize(self) -> str: + return json.dumps( + { + "connection_type": "remote", + "db_url": self.db_url, + "api_key": self.api_key, + "region": self.region, + "host_override": self.host_override, + "client_config": _client_config_to_dict(self.client_config), + "storage_options": self.storage_options, + } + ) + @override def list_namespaces( self, @@ -329,7 +407,12 @@ class RemoteDBConnection(DBConnection): ) table = LOOP.run(self._conn.open_table(name, namespace_path=namespace_path)) - return RemoteTable(table, self.db_name) + return RemoteTable( + table, + self.db_name, + connection_state=self.serialize, + namespace_path=namespace_path, + ) def clone_table( self, @@ -378,7 +461,12 @@ class RemoteDBConnection(DBConnection): is_shallow=is_shallow, ) ) - return RemoteTable(table, self.db_name) + return RemoteTable( + table, + self.db_name, + connection_state=self.serialize, + namespace_path=target_namespace_path, + ) @override def create_table( @@ -523,7 +611,12 @@ class RemoteDBConnection(DBConnection): fill_value=fill_value, ) ) - return RemoteTable(table, self.db_name) + return RemoteTable( + table, + self.db_name, + connection_state=self.serialize, + namespace_path=namespace_path, + ) @override def drop_table(self, name: str, namespace_path: Optional[List[str]] = None): diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index c2fdcfae9..bba252f1c 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -4,6 +4,7 @@ from datetime import timedelta import logging from functools import cached_property +import os from typing import Any, Callable, Dict, Iterable, List, Optional, Union, Literal import warnings @@ -49,14 +50,78 @@ class RemoteTable(Table): self, table: AsyncTable, db_name: str, + *, + connection_state: Optional[Union[str, Callable[[], str]]] = None, + namespace_path: Optional[List[str]] = None, ): - self._table = table + self._table_handle = table + self._name = table.name self.db_name = db_name + self._connection_state = connection_state + self._namespace_path = list(namespace_path or []) + self._checkout_version: Optional[int] = None + self._pid = os.getpid() + + def _serialized_connection_state(self) -> str: + if self._connection_state is None: + raise RuntimeError( + "Cannot reopen this remote table because it does not carry " + "serialized connection state" + ) + if callable(self._connection_state): + self._connection_state = self._connection_state() + return self._connection_state + + @property + def _table(self) -> AsyncTable: + self._ensure_open() + assert self._table_handle is not None + return self._table_handle + + @_table.setter + def _table(self, table: AsyncTable) -> None: + self._table_handle = table + self._name = table.name + self._pid = os.getpid() + + def _ensure_open(self) -> None: + pid = os.getpid() + if self._table_handle is not None and self._pid == pid: + return + + from lancedb import deserialize_conn + + db = deserialize_conn(self._serialized_connection_state(), for_worker=True) + table = db.open_table(self._name, namespace_path=self._namespace_path) + if self._checkout_version is not None: + table.checkout(self._checkout_version) + + self._table_handle = table._table + self.db_name = table.db_name + self._pid = pid + + def __getstate__(self) -> dict: + return { + "connection_state": self._serialized_connection_state(), + "db_name": self.db_name, + "name": self.name, + "namespace_path": self._namespace_path, + "checkout_version": self._checkout_version, + } + + def __setstate__(self, state: dict) -> None: + self._table_handle = None + self._name = state["name"] + self.db_name = state["db_name"] + self._connection_state = state["connection_state"] + self._namespace_path = state["namespace_path"] + self._checkout_version = state["checkout_version"] + self._pid = None @property def name(self) -> str: """The name of the table""" - return self._table.name + return self._name def __repr__(self) -> str: return f"RemoteTable({self.db_name}.{self.name})" @@ -106,13 +171,19 @@ class RemoteTable(Table): raise NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.") def checkout(self, version: Union[int, str]): - return LOOP.run(self._table.checkout(version)) + result = LOOP.run(self._table.checkout(version)) + self._checkout_version = self.version + return result def checkout_latest(self): - return LOOP.run(self._table.checkout_latest()) + result = LOOP.run(self._table.checkout_latest()) + self._checkout_version = None + return result def restore(self, version: Optional[Union[int, str]] = None): - return LOOP.run(self._table.restore(version)) + result = LOOP.run(self._table.restore(version)) + self._checkout_version = None + return result def list_indices(self) -> Iterable[IndexConfig]: """List all the indices on the table""" diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index c50cf29f9..279f93658 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The LanceDB Authors -import re from concurrent.futures import ThreadPoolExecutor import contextlib from datetime import timedelta import http.server import json import multiprocessing as mp +import pickle +import re import sys import threading import time @@ -171,6 +172,155 @@ def test_table_len_sync(): assert len(table) == 1 +def test_remote_connection_serializes(): + def handler(request): + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(b'{"tables": []}') + + with mock_lancedb_connection(handler) as db: + serialized = json.loads(db.serialize()) + assert isinstance(serialized["client_config"], dict) + restored = lancedb.deserialize_conn(db.serialize()) + assert restored.table_names() == [] + + +def test_remote_table_is_picklable(): + def handler(request): + request.close_connection = True + if request.path == "/v1/table/test/describe/": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + payload = json.dumps( + { + "version": 1, + "schema": { + "fields": [ + {"name": "id", "type": {"type": "int64"}, "nullable": False} + ] + }, + } + ) + request.wfile.write(payload.encode()) + elif request.path == "/v1/table/test/count_rows/": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(b"3") + else: + request.send_response(404) + request.end_headers() + + with mock_lancedb_connection(handler) as db: + table = db.open_table("test") + restored = pickle.loads(pickle.dumps(table)) + assert restored.count_rows() == 3 + + +def test_remote_table_open_does_not_require_picklable_client_config(): + from lancedb.remote import HeaderProvider + + class LocalHeaderProvider(HeaderProvider): + def get_headers(self): + return {"X-Test-Header": "present"} + + def handler(request): + request.close_connection = True + assert request.headers.get("X-Test-Header") == "present" + if request.path == "/v1/table/test/describe/": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(b'{"version": 1, "schema": {"fields": []}}') + elif request.path == "/v1/table/test/count_rows/": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(b"3") + else: + request.send_response(404) + request.end_headers() + + with http.server.HTTPServer( + ("localhost", 0), make_mock_http_handler(handler) + ) as server: + port = server.server_address[1] + handle = threading.Thread(target=server.serve_forever) + handle.start() + try: + db = lancedb.connect( + "db://dev", + api_key="fake", + host_override=f"http://localhost:{port}", + client_config={ + "retry_config": {"retries": 0}, + "timeout_config": {"connect_timeout": 2, "read_timeout": 2}, + "header_provider": LocalHeaderProvider(), + }, + ) + table = db.open_table("test") + assert table.count_rows() == 3 + with pytest.raises(ValueError, match="header_provider"): + pickle.dumps(table) + finally: + server.shutdown() + handle.join() + + +def test_remote_permutation_is_picklable(): + from lancedb.permutation import Permutation + + rows = list(range(10)) + + def handler(request): + request.close_connection = True + if request.path == "/v1/table/test/describe/": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + payload = json.dumps( + { + "version": 1, + "schema": { + "fields": [ + {"name": "a", "type": {"type": "int64"}, "nullable": False} + ] + }, + } + ) + request.wfile.write(payload.encode()) + elif request.path == "/v1/table/test/count_rows/": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(str(len(rows)).encode()) + elif request.path == "/v1/table/test/query/": + content_len = int(request.headers.get("Content-Length")) + body = json.loads(request.rfile.read(content_len)) + if "filter" in body: + match = re.search(r"_rowoffset in \((.*?)\)", body["filter"]) + offsets = [int(offset.strip()) for offset in match.group(1).split(",")] + else: + offsets = rows + table = pa.table({"a": [rows[offset] for offset in offsets]}) + + request.send_response(200) + request.send_header("Content-Type", "application/vnd.apache.arrow.file") + request.end_headers() + with pa.ipc.new_file(request.wfile, schema=table.schema) as writer: + writer.write_table(table) + else: + request.send_response(404) + request.end_headers() + + with mock_lancedb_connection(handler) as db: + permutation = Permutation.identity(db.open_table("test")) + restored = pickle.loads(pickle.dumps(permutation)) + assert restored.__getitems__([0, 2, 4]) == [{"a": 0}, {"a": 2}, {"a": 4}] + + def test_create_table_exist_ok(): def handler(request): if request.path == "/v1/table/test/create/?mode=exist_ok": @@ -1305,6 +1455,10 @@ def _remote_fork_child(port: int, queue) -> None: queue.put(db.table_names()) +def _remote_table_fork_child(table, queue) -> None: + queue.put(table.count_rows()) + + @pytest.mark.skipif( sys.platform != "linux", reason=( @@ -1367,3 +1521,65 @@ def test_remote_connection_after_fork(): finally: server.shutdown() server_thread.join() + + +@pytest.mark.skipif( + sys.platform != "linux", + reason=( + "fork() is unavailable on Windows and unsafe on macOS " + "(Apple frameworks/TLS are not fork-safe)" + ), +) +def test_inherited_remote_table_reopens_after_fork(): + def handler(request): + if request.path == "/v1/table/test/describe/": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(b'{"version": 1, "schema": {"fields": []}}') + elif request.path == "/v1/table/test/count_rows/": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(b"7") + else: + request.send_response(404) + request.end_headers() + + server = http.server.HTTPServer(("localhost", 0), make_mock_http_handler(handler)) + port = server.server_address[1] + server_thread = threading.Thread(target=server.serve_forever) + server_thread.start() + try: + db = lancedb.connect( + "db://dev", + api_key="fake", + host_override=f"http://localhost:{port}", + client_config={ + "retry_config": {"retries": 0}, + "timeout_config": {"connect_timeout": 2, "read_timeout": 2}, + }, + ) + table = db.open_table("test") + assert table.count_rows() == 7 + + ctx = mp.get_context("fork") + queue = ctx.Queue() + proc = ctx.Process(target=_remote_table_fork_child, args=(table, queue)) + proc.start() + proc.join(timeout=15) + + if proc.is_alive(): + proc.terminate() + proc.join(timeout=5) + if proc.is_alive(): + proc.kill() + proc.join() + pytest.fail("Remote table hung after fork") + + assert proc.exitcode == 0, f"child exited with code {proc.exitcode}" + assert not queue.empty(), "child produced no result" + assert queue.get() == 7 + finally: + server.shutdown() + server_thread.join() diff --git a/python/python/tests/test_torch.py b/python/python/tests/test_torch.py index d17e60bbd..8337568a3 100644 --- a/python/python/tests/test_torch.py +++ b/python/python/tests/test_torch.py @@ -1,10 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The LanceDB Authors +import contextlib import functools +import http.server +import json import multiprocessing as mp import pickle +import re import sys +import threading import lancedb import pyarrow as pa @@ -15,6 +20,107 @@ from lancedb.util import tbl_to_tensor torch = pytest.importorskip("torch") +REMOTE_ROWS = list(range(100)) + + +def _make_mock_http_handler(handler): + class MockLanceDBHandler(http.server.BaseHTTPRequestHandler): + def do_GET(self): + handler(self) + + def do_POST(self): + handler(self) + + return MockLanceDBHandler + + +def _remote_schema_payload(): + return { + "version": 1, + "schema": { + "fields": [ + {"name": "a", "type": {"type": "int64"}, "nullable": False}, + ] + }, + } + + +def _offsets_from_filter(filter_sql: str | None) -> list[int]: + if filter_sql is None: + return REMOTE_ROWS + match = re.search(r"_rowoffset in \((.*?)\)", filter_sql) + if match is None: + return REMOTE_ROWS + raw_offsets = match.group(1).strip() + if raw_offsets == "": + return [] + return [int(offset.strip()) for offset in raw_offsets.split(",")] + + +def _remote_dataset_handler(request): + request.close_connection = True + if request.path == "/v1/table/test/describe/": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(json.dumps(_remote_schema_payload()).encode()) + elif request.path == "/v1/table/test/count_rows/": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(str(len(REMOTE_ROWS)).encode()) + elif request.path == "/v1/table/test/query/": + content_len = int(request.headers.get("Content-Length")) + body = json.loads(request.rfile.read(content_len)) + offsets = _offsets_from_filter(body.get("filter")) + requested_columns = body.get("columns") or ["a"] + if isinstance(requested_columns, dict): + requested_columns = list(requested_columns) + + data = {} + for column in requested_columns: + if column == "a": + data[column] = [REMOTE_ROWS[offset] for offset in offsets] + elif column == "_rowoffset": + data[column] = offsets + elif column == "_rowid": + data[column] = offsets + + table = pa.table(data) + request.send_response(200) + request.send_header("Content-Type", "application/vnd.apache.arrow.file") + request.end_headers() + with pa.ipc.new_file(request.wfile, schema=table.schema) as writer: + writer.write_table(table) + else: + request.send_response(404) + request.end_headers() + + +@contextlib.contextmanager +def _remote_dataset_table(): + with http.server.ThreadingHTTPServer( + ("localhost", 0), _make_mock_http_handler(_remote_dataset_handler) + ) as server: + port = server.server_address[1] + handle = threading.Thread(target=server.serve_forever) + handle.start() + try: + db = lancedb.connect( + "db://dev", + api_key="fake", + host_override=f"http://localhost:{port}", + client_config={ + "retry_config": {"retries": 0}, + "timeout_config": {"connect_timeout": 2, "read_timeout": 2}, + }, + ) + yield db.open_table("test") + finally: + server.shutdown() + handle.join() + + def _open_native_table(uri: str, table_name: str): """Top-level connection factory used by the explicit-factory pickle test. @@ -107,6 +213,39 @@ def test_permutation_dataloader_multiprocessing(tmp_db): assert seen == 1000 +def test_remote_table_dataloader_multiprocessing(): + with _remote_dataset_table() as table: + dataloader = torch.utils.data.DataLoader( + table, + collate_fn=tbl_to_tensor, + batch_size=10, + num_workers=2, + multiprocessing_context="spawn", + ) + seen = 0 + for batch in dataloader: + assert batch.size(0) == 1 + assert batch.size(1) == 10 + seen += batch.size(1) + assert seen == len(REMOTE_ROWS) + + +def test_remote_permutation_dataloader_multiprocessing(): + with _remote_dataset_table() as table: + permutation = Permutation.identity(table) + dataloader = torch.utils.data.DataLoader( + permutation, + batch_size=10, + num_workers=2, + multiprocessing_context="spawn", + ) + seen = 0 + for batch in dataloader: + assert batch["a"].size(0) == 10 + seen += batch["a"].size(0) + assert seen == len(REMOTE_ROWS) + + def test_permutation_pickle_with_connection_factory(tmp_path): """When the user provides a connection_factory, pickling should round-trip through that factory rather than introspecting the connection URI. Useful @@ -171,6 +310,35 @@ def _multiworker_dataloader_target(db_uri: str, result_queue): result_queue.put(count) +def _remote_multiworker_dataloader_target(port: int, result_queue): + import lancedb + from lancedb.permutation import Permutation + + db = lancedb.connect( + "db://dev", + api_key="fake", + host_override=f"http://localhost:{port}", + client_config={ + "retry_config": {"retries": 0}, + "timeout_config": {"connect_timeout": 2, "read_timeout": 2}, + }, + ) + table = db.open_table("test") + permutation = Permutation.identity(table) + + dataloader = torch.utils.data.DataLoader( + permutation, + batch_size=10, + num_workers=2, + multiprocessing_context="fork", + ) + count = 0 + for batch in dataloader: + assert batch["a"].size(0) == 10 + count += 1 + result_queue.put(count) + + @pytest.mark.skipif( sys.platform != "linux", reason=( @@ -208,3 +376,46 @@ def test_permutation_dataloader_fork_workers(tmp_path): assert proc.exitcode == 0, f"child exited with code {proc.exitcode}" assert not queue.empty(), "child produced no batches" assert queue.get() == 100 + + +@pytest.mark.skipif( + sys.platform != "linux", + reason=( + "fork() is unavailable on Windows and unsafe on macOS " + "(Apple frameworks/TLS are not fork-safe)" + ), +) +def test_remote_permutation_dataloader_fork_workers(): + with http.server.ThreadingHTTPServer( + ("localhost", 0), _make_mock_http_handler(_remote_dataset_handler) + ) as server: + port = server.server_address[1] + handle = threading.Thread(target=server.serve_forever) + handle.start() + try: + ctx = mp.get_context("spawn") + queue = ctx.Queue() + proc = ctx.Process( + target=_remote_multiworker_dataloader_target, + args=(port, queue), + ) + proc.start() + proc.join(timeout=30) + + if proc.is_alive(): + proc.terminate() + proc.join(timeout=5) + if proc.is_alive(): + proc.kill() + proc.join() + pytest.fail( + "Remote permutation hung when iterated in a fork-based " + "DataLoader worker" + ) + + assert proc.exitcode == 0, f"child exited with code {proc.exitcode}" + assert not queue.empty(), "child produced no batches" + assert queue.get() == 10 + finally: + server.shutdown() + handle.join()