From 4d10d22e922abe45bf1574a24fed866938e5a870 Mon Sep 17 00:00:00 2001 From: Xuanwo Date: Tue, 2 Jun 2026 18:15:21 +0800 Subject: [PATCH] fix(python): move table pickle state onto tables --- python/python/lancedb/_lancedb.pyi | 1 + python/python/lancedb/permutation.py | 28 ++++------ python/python/lancedb/remote/table.py | 33 ++++++++++-- python/python/lancedb/table.py | 75 +++++++++++++++++++++++++++ python/python/tests/test_remote_db.py | 41 +++++++++++++++ python/python/tests/test_table.py | 31 +++++++++++ python/src/table.rs | 17 ++++++ 7 files changed, 203 insertions(+), 23 deletions(-) diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 0148f6575..ad4f29ed6 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -217,6 +217,7 @@ class Table: async def uri(self) -> str: ... async def initial_storage_options(self) -> Optional[Dict[str, str]]: ... async def latest_storage_options(self) -> Optional[Dict[str, str]]: ... + async def _table_reopen_state(self) -> Dict[str, Any]: ... async def set_unenforced_primary_key(self, columns: List[str]) -> None: ... async def set_lsm_write_spec(self, spec: LsmWriteSpec) -> None: ... async def unset_lsm_write_spec(self) -> None: ... 
diff --git a/python/python/lancedb/permutation.py b/python/python/lancedb/permutation.py index c57b96630..b95a9acf0 100644 --- a/python/python/lancedb/permutation.py +++ b/python/python/lancedb/permutation.py @@ -358,36 +358,28 @@ DEFAULT_BATCH_SIZE = 100 def _table_to_pickle_state(table: Table) -> dict[str, Any]: from .remote.table import RemoteTable - if isinstance(table, RemoteTable): - return { - "kind": "remote", - "table": table, - } - - if not isinstance(table, LanceTable): - raise ValueError(f"Cannot pickle table of type {type(table)!r}") - - base_uri = table._conn.uri - if base_uri.startswith("memory://"): + if isinstance(table, LanceTable) and table._conn.uri.startswith("memory://"): return { "kind": "memory", "name": table.name, "data": table.to_arrow(), } - return { - "kind": "local", - "name": table.name, - "uri": base_uri, - "namespace": table._namespace_path, - "storage_options": table._conn.storage_options, - } + if isinstance(table, (LanceTable, RemoteTable)): + return { + "kind": "table", + "table": table, + } + + raise ValueError(f"Cannot pickle table of type {type(table)!r}") def _table_from_pickle_state(state: dict[str, Any]) -> Table: from . 
import connect kind = state["kind"] + if kind == "table": + return state["table"] if kind == "remote": return state["table"] if kind == "memory": diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index 019f91044..1a0309d93 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -74,6 +74,7 @@ class RemoteTable(Table): self._connection_state = connection_state self._namespace_path = list(namespace_path or []) self._checkout_version: Optional[int] = None + self._table_state: Optional[dict[str, Any]] = None self._pid = os.getpid() def _serialized_connection_state(self) -> str: @@ -86,6 +87,16 @@ class RemoteTable(Table): self._connection_state = self._connection_state() return self._connection_state + def _reopen_state(self) -> dict[str, Any]: + if self._table_state is not None: + return self._table_state + self._table_state = { + "name": self._name, + "namespace_path": self._namespace_path, + "storage_options": None, + } + return self._table_state + @property def _table(self) -> AsyncTable: self._ensure_open() @@ -96,6 +107,7 @@ class RemoteTable(Table): def _table(self, table: AsyncTable) -> None: self._table_handle = table self._name = table.name + self._table_state = None self._pid = os.getpid() def _ensure_open(self) -> None: @@ -108,7 +120,11 @@ class RemoteTable(Table): from lancedb import deserialize_conn db = deserialize_conn(self._serialized_connection_state(), for_worker=True) - table = db.open_table(self._name, namespace_path=self._namespace_path) + table_state = self._reopen_state() + table = db.open_table( + table_state["name"], + namespace_path=table_state["namespace_path"] or None, + ) if self._checkout_version is not None: table.checkout(self._checkout_version) @@ -120,17 +136,24 @@ class RemoteTable(Table): return { "connection_state": self._serialized_connection_state(), "db_name": self.db_name, - "name": self.name, - "namespace_path": self._namespace_path, + 
"table_state": self._reopen_state(), "checkout_version": self._checkout_version, } def __setstate__(self, state: dict) -> None: self._table_handle = None - self._name = state["name"] + table_state = state.get("table_state") + if table_state is None: + table_state = { + "name": state["name"], + "namespace_path": state["namespace_path"], + "storage_options": None, + } + self._table_state = table_state + self._name = table_state["name"] self.db_name = state["db_name"] self._connection_state = state["connection_state"] - self._namespace_path = state["namespace_path"] + self._namespace_path = table_state["namespace_path"] self._checkout_version = state["checkout_version"] self._pid = None diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 2de369419..d326fb99e 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -5,6 +5,7 @@ from __future__ import annotations import asyncio import inspect +import os import deprecation import warnings from abc import ABC, abstractmethod @@ -757,8 +758,12 @@ class Table(ABC): """ raise NotImplementedError + def _ensure_open(self) -> None: + pass + def __len__(self) -> int: """The number of rows in this Table""" + self._ensure_open() return self.count_rows(None) @property @@ -1404,6 +1409,7 @@ class Table(ABC): pa.RecordBatch A record batch containing the rows at the given offsets. """ + self._ensure_open() # We don't know the order of the results at all. So we calculate a permutation # for ordering the given offsets. Then we load the data with the _rowoffset # column. 
Then we sort by _rowoffset and apply the inverse of the permutation @@ -1962,6 +1968,7 @@ class LanceTable(Table): self._location = location # Store location for use in _dataset_path self._namespace_client = namespace_client self._pushdown_operations = pushdown_operations or set() + self._init_reopen_tracking() if _async is not None: self._table = _async else: @@ -1977,6 +1984,66 @@ class LanceTable(Table): ) ) + def _init_reopen_tracking(self) -> None: + self._checkout_version: Optional[int] = None + self._table_state: Optional[dict[str, Any]] = None + self._pid = os.getpid() + + def _reopen_state(self) -> dict[str, Any]: + state = LOOP.run(self._table._table_reopen_state()) + if get_uri_scheme(self._conn.uri) == "memory": + raise ValueError( + "Cannot pickle an in-memory LanceTable. Use a persisted table " + "or provide a worker-side connection factory." + ) + return state + + def _copy_reopened_table(self, table: "LanceTable") -> None: + self._conn = table._conn + self._namespace_path = table._namespace_path + self._location = table._location + self._namespace_client = table._namespace_client + self._pushdown_operations = table._pushdown_operations + self._table = table._table + self._pid = os.getpid() + + def _ensure_open(self) -> None: + pid = os.getpid() + if getattr(self, "_table", None) is not None and self._pid == pid: + return + if self._table_state is None: + self._table_state = self._reopen_state() + + table = self._conn.open_table( + self._table_state["name"], + namespace_path=self._table_state["namespace_path"] or None, + storage_options=self._table_state["storage_options"], + ) + if self._checkout_version is not None: + table.checkout(self._checkout_version) + self._copy_reopened_table(table) + + def __getstate__(self) -> dict[str, Any]: + return { + "connection_state": self._conn.serialize(), + "table_state": self._reopen_state(), + "checkout_version": self._checkout_version, + } + + def __setstate__(self, state: dict[str, Any]) -> None: + from . 
import deserialize_conn + + self._conn = deserialize_conn(state["connection_state"], for_worker=True) + self._namespace_path = list(state["table_state"]["namespace_path"] or []) + self._location = None + self._namespace_client = None + self._pushdown_operations = set() + self._checkout_version = state["checkout_version"] + self._table_state = state["table_state"] + self._table = None + self._pid = None + self._ensure_open() + @property def name(self) -> str: return self._table.name @@ -2180,6 +2247,7 @@ class LanceTable(Table): 0 [1.1, 0.9] vector """ LOOP.run(self._table.checkout(version)) + self._checkout_version = self.version def checkout_latest(self): """Checkout the latest version of the table. This is an in-place operation. @@ -2188,6 +2256,7 @@ class LanceTable(Table): version of the table. """ LOOP.run(self._table.checkout_latest()) + self._checkout_version = None def restore(self, version: Optional[Union[int, str]] = None): """Restore a version of the table. This is an in-place operation. 
@@ -2236,6 +2305,7 @@ class LanceTable(Table): if version is not None: LOOP.run(self._table.checkout(version)) LOOP.run(self._table.restore()) + self._checkout_version = None def count_rows(self, filter: Optional[str] = None) -> int: return LOOP.run(self._table.count_rows(filter)) @@ -3294,6 +3364,7 @@ class LanceTable(Table): self._location = location self._namespace_client = namespace_client self._pushdown_operations = pushdown_operations or set() + self._init_reopen_tracking() if data_storage_version is not None: warnings.warn( @@ -4551,6 +4622,10 @@ class AsyncTable: """ return await self._inner.latest_storage_options() + async def _table_reopen_state(self) -> dict[str, Any]: + """Get the Rust-side table state needed to reopen this table.""" + return await self._inner._table_reopen_state() + async def add( self, data: DATA, diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index 4cc184c77..11a1fdfc2 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -215,10 +215,51 @@ def test_remote_table_is_picklable(): with mock_lancedb_connection(handler) as db: table = db.open_table("test") + state = table.__getstate__() + assert state["table_state"] == { + "name": "test", + "namespace_path": [], + "storage_options": None, + } restored = pickle.loads(pickle.dumps(table)) assert restored.count_rows() == 3 +def test_remote_table_reopens_when_pid_changes_without_cached_state(): + def handler(request): + request.close_connection = True + if request.path == "/v1/table/test/describe/": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + payload = json.dumps( + { + "version": 1, + "schema": { + "fields": [ + {"name": "id", "type": {"type": "int64"}, "nullable": False} + ] + }, + } + ) + request.wfile.write(payload.encode()) + elif request.path == "/v1/table/test/count_rows/": + request.send_response(200) + 
request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(b"3") + else: + request.send_response(404) + request.end_headers() + + with mock_lancedb_connection(handler) as db: + table = db.open_table("test") + table._pid = -1 + table._table_state = None + + assert table.count_rows() == 3 + + def test_remote_table_open_does_not_require_picklable_client_config(): from lancedb.remote import HeaderProvider diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 2a07c2df6..63659f97b 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -3,6 +3,7 @@ import os +import pickle import sys import warnings from datetime import date, datetime, timedelta @@ -48,6 +49,36 @@ def test_basic(mem_db: DBConnection): assert table.to_arrow() == expected_data +def test_lance_table_is_picklable(tmp_db: DBConnection): + table = tmp_db.create_table("pickle_table", pa.table({"id": [1, 2, 3]})) + + restored = pickle.loads(pickle.dumps(table)) + + assert restored.name == "pickle_table" + assert restored.count_rows() == 3 + assert restored.to_arrow().column("id").to_pylist() == [1, 2, 3] + + +def test_lance_table_pickle_preserves_checkout(tmp_db: DBConnection): + table = tmp_db.create_table("pickle_checkout", pa.table({"id": [1]})) + table.add(pa.table({"id": [2]})) + table.checkout(1) + + restored = pickle.loads(pickle.dumps(table)) + + assert restored.count_rows() == 1 + assert restored.to_arrow().column("id").to_pylist() == [1] + restored.checkout_latest() + assert restored.count_rows() == 2 + + +def test_memory_lance_table_pickle_is_unsupported(mem_db: DBConnection): + table = mem_db.create_table("memory_pickle", pa.table({"id": [1]})) + + with pytest.raises(ValueError, match="in-memory LanceTable"): + pickle.dumps(table) + + def test_table_to_pandas_default_matches_arrow(tmp_db: DBConnection): pd = pytest.importorskip("pandas") data = pa.table({"id": [1, 2], "text": ["one", 
"two"]}) diff --git a/python/src/table.rs b/python/src/table.rs index 302c2bb46..7cbe17019 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -755,6 +755,23 @@ impl Table { }) } + pub fn _table_reopen_state(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> { + let inner = self_.inner_ref()?.clone(); + future_into_py(self_.py(), async move { + let name = inner.name().to_string(); + let namespace_path = inner.namespace().to_vec(); + let storage_options = inner.initial_storage_options().await; + + Python::attach(|py| { + let dict = PyDict::new(py); + dict.set_item("name", name)?; + dict.set_item("namespace_path", namespace_path)?; + dict.set_item("storage_options", storage_options)?; + Ok(dict.unbind()) + }) + }) + } + pub fn __repr__(&self) -> String { match &self.inner { None => format!("ClosedTable({})", self.name),