fix(python): move table pickle state onto tables

This commit is contained in:
Xuanwo
2026-06-02 18:15:21 +08:00
parent a327044e2f
commit 4d10d22e92
7 changed files with 203 additions and 23 deletions

View File

@@ -217,6 +217,7 @@ class Table:
async def uri(self) -> str: ...
async def initial_storage_options(self) -> Optional[Dict[str, str]]: ...
async def latest_storage_options(self) -> Optional[Dict[str, str]]: ...
async def _table_reopen_state(self) -> Dict[str, Any]: ...
async def set_unenforced_primary_key(self, columns: List[str]) -> None: ...
async def set_lsm_write_spec(self, spec: LsmWriteSpec) -> None: ...
async def unset_lsm_write_spec(self) -> None: ...

View File

@@ -358,36 +358,28 @@ DEFAULT_BATCH_SIZE = 100
def _table_to_pickle_state(table: Table) -> dict[str, Any]:
from .remote.table import RemoteTable
if isinstance(table, RemoteTable):
return {
"kind": "remote",
"table": table,
}
if not isinstance(table, LanceTable):
raise ValueError(f"Cannot pickle table of type {type(table)!r}")
base_uri = table._conn.uri
if base_uri.startswith("memory://"):
if isinstance(table, LanceTable) and table._conn.uri.startswith("memory://"):
return {
"kind": "memory",
"name": table.name,
"data": table.to_arrow(),
}
return {
"kind": "local",
"name": table.name,
"uri": base_uri,
"namespace": table._namespace_path,
"storage_options": table._conn.storage_options,
}
if isinstance(table, (LanceTable, RemoteTable)):
return {
"kind": "table",
"table": table,
}
raise ValueError(f"Cannot pickle table of type {type(table)!r}")
def _table_from_pickle_state(state: dict[str, Any]) -> Table:
from . import connect
kind = state["kind"]
if kind == "table":
return state["table"]
if kind == "remote":
return state["table"]
if kind == "memory":

View File

@@ -74,6 +74,7 @@ class RemoteTable(Table):
self._connection_state = connection_state
self._namespace_path = list(namespace_path or [])
self._checkout_version: Optional[int] = None
self._table_state: Optional[dict[str, Any]] = None
self._pid = os.getpid()
def _serialized_connection_state(self) -> str:
@@ -86,6 +87,16 @@ class RemoteTable(Table):
self._connection_state = self._connection_state()
return self._connection_state
def _reopen_state(self) -> dict[str, Any]:
if self._table_state is not None:
return self._table_state
self._table_state = {
"name": self._name,
"namespace_path": self._namespace_path,
"storage_options": None,
}
return self._table_state
@property
def _table(self) -> AsyncTable:
self._ensure_open()
@@ -96,6 +107,7 @@ class RemoteTable(Table):
def _table(self, table: AsyncTable) -> None:
self._table_handle = table
self._name = table.name
self._table_state = None
self._pid = os.getpid()
def _ensure_open(self) -> None:
@@ -108,7 +120,11 @@ class RemoteTable(Table):
from lancedb import deserialize_conn
db = deserialize_conn(self._serialized_connection_state(), for_worker=True)
table = db.open_table(self._name, namespace_path=self._namespace_path)
table_state = self._reopen_state()
table = db.open_table(
table_state["name"],
namespace_path=table_state["namespace_path"] or None,
)
if self._checkout_version is not None:
table.checkout(self._checkout_version)
@@ -120,17 +136,24 @@ class RemoteTable(Table):
return {
"connection_state": self._serialized_connection_state(),
"db_name": self.db_name,
"name": self.name,
"namespace_path": self._namespace_path,
"table_state": self._reopen_state(),
"checkout_version": self._checkout_version,
}
def __setstate__(self, state: dict) -> None:
self._table_handle = None
self._name = state["name"]
table_state = state.get("table_state")
if table_state is None:
table_state = {
"name": state["name"],
"namespace_path": state["namespace_path"],
"storage_options": None,
}
self._table_state = table_state
self._name = table_state["name"]
self.db_name = state["db_name"]
self._connection_state = state["connection_state"]
self._namespace_path = state["namespace_path"]
self._namespace_path = table_state["namespace_path"]
self._checkout_version = state["checkout_version"]
self._pid = None

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
import inspect
import os
import deprecation
import warnings
from abc import ABC, abstractmethod
@@ -757,8 +758,12 @@ class Table(ABC):
"""
raise NotImplementedError
def _ensure_open(self) -> None:
pass
def __len__(self) -> int:
"""The number of rows in this Table"""
self._ensure_open()
return self.count_rows(None)
@property
@@ -1404,6 +1409,7 @@ class Table(ABC):
pa.RecordBatch
A record batch containing the rows at the given offsets.
"""
self._ensure_open()
# We don't know the order of the results at all. So we calculate a permutation
# for ordering the given offsets. Then we load the data with the _rowoffset
# column. Then we sort by _rowoffset and apply the inverse of the permutation
@@ -1962,6 +1968,7 @@ class LanceTable(Table):
self._location = location # Store location for use in _dataset_path
self._namespace_client = namespace_client
self._pushdown_operations = pushdown_operations or set()
self._init_reopen_tracking()
if _async is not None:
self._table = _async
else:
@@ -1977,6 +1984,66 @@ class LanceTable(Table):
)
)
def _init_reopen_tracking(self) -> None:
self._checkout_version: Optional[int] = None
self._table_state: Optional[dict[str, Any]] = None
self._pid = os.getpid()
def _reopen_state(self) -> dict[str, Any]:
state = LOOP.run(self._table._table_reopen_state())
if get_uri_scheme(self._conn.uri) == "memory":
raise ValueError(
"Cannot pickle an in-memory LanceTable. Use a persisted table "
"or provide a worker-side connection factory."
)
return state
def _copy_reopened_table(self, table: "LanceTable") -> None:
self._conn = table._conn
self._namespace_path = table._namespace_path
self._location = table._location
self._namespace_client = table._namespace_client
self._pushdown_operations = table._pushdown_operations
self._table = table._table
self._pid = os.getpid()
def _ensure_open(self) -> None:
pid = os.getpid()
if getattr(self, "_table", None) is not None and self._pid == pid:
return
if self._table_state is None:
self._table_state = self._reopen_state()
table = self._conn.open_table(
self._table_state["name"],
namespace_path=self._table_state["namespace_path"] or None,
storage_options=self._table_state["storage_options"],
)
if self._checkout_version is not None:
table.checkout(self._checkout_version)
self._copy_reopened_table(table)
def __getstate__(self) -> dict[str, Any]:
return {
"connection_state": self._conn.serialize(),
"table_state": self._reopen_state(),
"checkout_version": self._checkout_version,
}
def __setstate__(self, state: dict[str, Any]) -> None:
from . import deserialize_conn
self._conn = deserialize_conn(state["connection_state"], for_worker=True)
self._namespace_path = list(state["table_state"]["namespace_path"] or [])
self._location = None
self._namespace_client = None
self._pushdown_operations = set()
self._checkout_version = state["checkout_version"]
self._table_state = state["table_state"]
self._table = None
self._pid = None
self._ensure_open()
@property
def name(self) -> str:
return self._table.name
@@ -2180,6 +2247,7 @@ class LanceTable(Table):
0 [1.1, 0.9] vector
"""
LOOP.run(self._table.checkout(version))
self._checkout_version = self.version
def checkout_latest(self):
"""Checkout the latest version of the table. This is an in-place operation.
@@ -2188,6 +2256,7 @@ class LanceTable(Table):
version of the table.
"""
LOOP.run(self._table.checkout_latest())
self._checkout_version = None
def restore(self, version: Optional[Union[int, str]] = None):
"""Restore a version of the table. This is an in-place operation.
@@ -2236,6 +2305,7 @@ class LanceTable(Table):
if version is not None:
LOOP.run(self._table.checkout(version))
LOOP.run(self._table.restore())
self._checkout_version = None
def count_rows(self, filter: Optional[str] = None) -> int:
return LOOP.run(self._table.count_rows(filter))
@@ -3294,6 +3364,7 @@ class LanceTable(Table):
self._location = location
self._namespace_client = namespace_client
self._pushdown_operations = pushdown_operations or set()
self._init_reopen_tracking()
if data_storage_version is not None:
warnings.warn(
@@ -4551,6 +4622,10 @@ class AsyncTable:
"""
return await self._inner.latest_storage_options()
async def _table_reopen_state(self) -> dict[str, Any]:
"""Get the Rust-side table state needed to reopen this table."""
return await self._inner._table_reopen_state()
async def add(
self,
data: DATA,

View File

@@ -215,10 +215,51 @@ def test_remote_table_is_picklable():
with mock_lancedb_connection(handler) as db:
table = db.open_table("test")
state = table.__getstate__()
assert state["table_state"] == {
"name": "test",
"namespace_path": [],
"storage_options": None,
}
restored = pickle.loads(pickle.dumps(table))
assert restored.count_rows() == 3
def test_remote_table_reopens_when_pid_changes_without_cached_state():
def handler(request):
request.close_connection = True
if request.path == "/v1/table/test/describe/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(
{
"version": 1,
"schema": {
"fields": [
{"name": "id", "type": {"type": "int64"}, "nullable": False}
]
},
}
)
request.wfile.write(payload.encode())
elif request.path == "/v1/table/test/count_rows/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"3")
else:
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
table = db.open_table("test")
table._pid = -1
table._table_state = None
assert table.count_rows() == 3
def test_remote_table_open_does_not_require_picklable_client_config():
from lancedb.remote import HeaderProvider

View File

@@ -3,6 +3,7 @@
import os
import pickle
import sys
import warnings
from datetime import date, datetime, timedelta
@@ -48,6 +49,36 @@ def test_basic(mem_db: DBConnection):
assert table.to_arrow() == expected_data
def test_lance_table_is_picklable(tmp_db: DBConnection):
table = tmp_db.create_table("pickle_table", pa.table({"id": [1, 2, 3]}))
restored = pickle.loads(pickle.dumps(table))
assert restored.name == "pickle_table"
assert restored.count_rows() == 3
assert restored.to_arrow().column("id").to_pylist() == [1, 2, 3]
def test_lance_table_pickle_preserves_checkout(tmp_db: DBConnection):
table = tmp_db.create_table("pickle_checkout", pa.table({"id": [1]}))
table.add(pa.table({"id": [2]}))
table.checkout(1)
restored = pickle.loads(pickle.dumps(table))
assert restored.count_rows() == 1
assert restored.to_arrow().column("id").to_pylist() == [1]
restored.checkout_latest()
assert restored.count_rows() == 2
def test_memory_lance_table_pickle_is_unsupported(mem_db: DBConnection):
table = mem_db.create_table("memory_pickle", pa.table({"id": [1]}))
with pytest.raises(ValueError, match="in-memory LanceTable"):
pickle.dumps(table)
def test_table_to_pandas_default_matches_arrow(tmp_db: DBConnection):
pd = pytest.importorskip("pandas")
data = pa.table({"id": [1, 2], "text": ["one", "two"]})

View File

@@ -755,6 +755,23 @@ impl Table {
})
}
pub fn _table_reopen_state(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {
let name = inner.name().to_string();
let namespace_path = inner.namespace().to_vec();
let storage_options = inner.initial_storage_options().await;
Python::attach(|py| {
let dict = PyDict::new(py);
dict.set_item("name", name)?;
dict.set_item("namespace_path", namespace_path)?;
dict.set_item("storage_options", storage_options)?;
Ok(dict.unbind())
})
})
}
pub fn __repr__(&self) -> String {
match &self.inner {
None => format!("ClosedTable({})", self.name),