mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-02 20:00:46 +00:00
fix(python): move table pickle state onto tables
This commit is contained in:
@@ -217,6 +217,7 @@ class Table:
|
||||
async def uri(self) -> str: ...
|
||||
async def initial_storage_options(self) -> Optional[Dict[str, str]]: ...
|
||||
async def latest_storage_options(self) -> Optional[Dict[str, str]]: ...
|
||||
# Resolves to the dict used to reopen the table; the native implementation
# fills the keys "name", "namespace_path" and "storage_options".
async def _table_reopen_state(self) -> Dict[str, Any]: ...
|
||||
async def set_unenforced_primary_key(self, columns: List[str]) -> None: ...
|
||||
async def set_lsm_write_spec(self, spec: LsmWriteSpec) -> None: ...
|
||||
async def unset_lsm_write_spec(self) -> None: ...
|
||||
|
||||
@@ -358,36 +358,28 @@ DEFAULT_BATCH_SIZE = 100
|
||||
def _table_to_pickle_state(table: Table) -> dict[str, Any]:
|
||||
from .remote.table import RemoteTable
|
||||
|
||||
if isinstance(table, RemoteTable):
|
||||
return {
|
||||
"kind": "remote",
|
||||
"table": table,
|
||||
}
|
||||
|
||||
if not isinstance(table, LanceTable):
|
||||
raise ValueError(f"Cannot pickle table of type {type(table)!r}")
|
||||
|
||||
base_uri = table._conn.uri
|
||||
if base_uri.startswith("memory://"):
|
||||
if isinstance(table, LanceTable) and table._conn.uri.startswith("memory://"):
|
||||
return {
|
||||
"kind": "memory",
|
||||
"name": table.name,
|
||||
"data": table.to_arrow(),
|
||||
}
|
||||
|
||||
return {
|
||||
"kind": "local",
|
||||
"name": table.name,
|
||||
"uri": base_uri,
|
||||
"namespace": table._namespace_path,
|
||||
"storage_options": table._conn.storage_options,
|
||||
}
|
||||
if isinstance(table, (LanceTable, RemoteTable)):
|
||||
return {
|
||||
"kind": "table",
|
||||
"table": table,
|
||||
}
|
||||
|
||||
raise ValueError(f"Cannot pickle table of type {type(table)!r}")
|
||||
|
||||
|
||||
def _table_from_pickle_state(state: dict[str, Any]) -> Table:
|
||||
from . import connect
|
||||
|
||||
kind = state["kind"]
|
||||
if kind == "table":
|
||||
return state["table"]
|
||||
if kind == "remote":
|
||||
return state["table"]
|
||||
if kind == "memory":
|
||||
|
||||
@@ -74,6 +74,7 @@ class RemoteTable(Table):
|
||||
self._connection_state = connection_state
|
||||
self._namespace_path = list(namespace_path or [])
|
||||
self._checkout_version: Optional[int] = None
|
||||
self._table_state: Optional[dict[str, Any]] = None
|
||||
self._pid = os.getpid()
|
||||
|
||||
def _serialized_connection_state(self) -> str:
|
||||
@@ -86,6 +87,16 @@ class RemoteTable(Table):
|
||||
self._connection_state = self._connection_state()
|
||||
return self._connection_state
|
||||
|
||||
def _reopen_state(self) -> dict[str, Any]:
|
||||
if self._table_state is not None:
|
||||
return self._table_state
|
||||
self._table_state = {
|
||||
"name": self._name,
|
||||
"namespace_path": self._namespace_path,
|
||||
"storage_options": None,
|
||||
}
|
||||
return self._table_state
|
||||
|
||||
@property
|
||||
def _table(self) -> AsyncTable:
|
||||
self._ensure_open()
|
||||
@@ -96,6 +107,7 @@ class RemoteTable(Table):
|
||||
def _table(self, table: AsyncTable) -> None:
|
||||
self._table_handle = table
|
||||
self._name = table.name
|
||||
self._table_state = None
|
||||
self._pid = os.getpid()
|
||||
|
||||
def _ensure_open(self) -> None:
|
||||
@@ -108,7 +120,11 @@ class RemoteTable(Table):
|
||||
from lancedb import deserialize_conn
|
||||
|
||||
db = deserialize_conn(self._serialized_connection_state(), for_worker=True)
|
||||
table = db.open_table(self._name, namespace_path=self._namespace_path)
|
||||
table_state = self._reopen_state()
|
||||
table = db.open_table(
|
||||
table_state["name"],
|
||||
namespace_path=table_state["namespace_path"] or None,
|
||||
)
|
||||
if self._checkout_version is not None:
|
||||
table.checkout(self._checkout_version)
|
||||
|
||||
@@ -120,17 +136,24 @@ class RemoteTable(Table):
|
||||
return {
|
||||
"connection_state": self._serialized_connection_state(),
|
||||
"db_name": self.db_name,
|
||||
"name": self.name,
|
||||
"namespace_path": self._namespace_path,
|
||||
"table_state": self._reopen_state(),
|
||||
"checkout_version": self._checkout_version,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict) -> None:
|
||||
self._table_handle = None
|
||||
self._name = state["name"]
|
||||
table_state = state.get("table_state")
|
||||
if table_state is None:
|
||||
table_state = {
|
||||
"name": state["name"],
|
||||
"namespace_path": state["namespace_path"],
|
||||
"storage_options": None,
|
||||
}
|
||||
self._table_state = table_state
|
||||
self._name = table_state["name"]
|
||||
self.db_name = state["db_name"]
|
||||
self._connection_state = state["connection_state"]
|
||||
self._namespace_path = state["namespace_path"]
|
||||
self._namespace_path = table_state["namespace_path"]
|
||||
self._checkout_version = state["checkout_version"]
|
||||
self._pid = None
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
import deprecation
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -757,8 +758,12 @@ class Table(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _ensure_open(self) -> None:
|
||||
pass
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""The number of rows in this Table"""
|
||||
self._ensure_open()
|
||||
return self.count_rows(None)
|
||||
|
||||
@property
|
||||
@@ -1404,6 +1409,7 @@ class Table(ABC):
|
||||
pa.RecordBatch
|
||||
A record batch containing the rows at the given offsets.
|
||||
"""
|
||||
self._ensure_open()
|
||||
# We don't know the order of the results at all. So we calculate a permutation
|
||||
# for ordering the given offsets. Then we load the data with the _rowoffset
|
||||
# column. Then we sort by _rowoffset and apply the inverse of the permutation
|
||||
@@ -1962,6 +1968,7 @@ class LanceTable(Table):
|
||||
self._location = location # Store location for use in _dataset_path
|
||||
self._namespace_client = namespace_client
|
||||
self._pushdown_operations = pushdown_operations or set()
|
||||
self._init_reopen_tracking()
|
||||
if _async is not None:
|
||||
self._table = _async
|
||||
else:
|
||||
@@ -1977,6 +1984,66 @@ class LanceTable(Table):
|
||||
)
|
||||
)
|
||||
|
||||
def _init_reopen_tracking(self) -> None:
|
||||
self._checkout_version: Optional[int] = None
|
||||
self._table_state: Optional[dict[str, Any]] = None
|
||||
self._pid = os.getpid()
|
||||
|
||||
def _reopen_state(self) -> dict[str, Any]:
    """Return the state needed to reopen this table.

    Returns
    -------
    dict[str, Any]
        Whatever the underlying table's ``_table_reopen_state()`` resolves to.

    Raises
    ------
    ValueError
        If the URI scheme of the owning connection is ``"memory"``.
    """
    # Reject in-memory tables before paying for the round trip to the
    # underlying table; the answer would be discarded anyway.
    if get_uri_scheme(self._conn.uri) == "memory":
        raise ValueError(
            "Cannot pickle an in-memory LanceTable. Use a persisted table "
            "or provide a worker-side connection factory."
        )
    return LOOP.run(self._table._table_reopen_state())
|
||||
|
||||
def _copy_reopened_table(self, table: "LanceTable") -> None:
|
||||
self._conn = table._conn
|
||||
self._namespace_path = table._namespace_path
|
||||
self._location = table._location
|
||||
self._namespace_client = table._namespace_client
|
||||
self._pushdown_operations = table._pushdown_operations
|
||||
self._table = table._table
|
||||
self._pid = os.getpid()
|
||||
|
||||
def _ensure_open(self) -> None:
|
||||
pid = os.getpid()
|
||||
if getattr(self, "_table", None) is not None and self._pid == pid:
|
||||
return
|
||||
if self._table_state is None:
|
||||
self._table_state = self._reopen_state()
|
||||
|
||||
table = self._conn.open_table(
|
||||
self._table_state["name"],
|
||||
namespace_path=self._table_state["namespace_path"] or None,
|
||||
storage_options=self._table_state["storage_options"],
|
||||
)
|
||||
if self._checkout_version is not None:
|
||||
table.checkout(self._checkout_version)
|
||||
self._copy_reopened_table(table)
|
||||
|
||||
def __getstate__(self) -> dict[str, Any]:
|
||||
return {
|
||||
"connection_state": self._conn.serialize(),
|
||||
"table_state": self._reopen_state(),
|
||||
"checkout_version": self._checkout_version,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict[str, Any]) -> None:
    """Rebuild the table from pickled ``state`` and reopen it immediately.

    Parameters
    ----------
    state : dict[str, Any]
        Must contain ``"connection_state"``, ``"table_state"`` and
        ``"checkout_version"``.
    """
    from . import deserialize_conn

    self._conn = deserialize_conn(state["connection_state"], for_worker=True)

    table_state = state["table_state"]
    self._namespace_path = list(table_state["namespace_path"] or [])

    # Location, namespace client and pushdown operations are not part of the
    # pickled state, so they start out empty.
    self._location = None
    self._namespace_client = None
    self._pushdown_operations = set()

    self._checkout_version = state["checkout_version"]
    self._table_state = table_state

    # No handle and no recorded pid: the call below has to reopen the table.
    self._table = None
    self._pid = None
    self._ensure_open()
|
||||
|
||||
@property
def name(self) -> str:
    """Name of the table, read from the underlying table handle."""
    handle = self._table
    return handle.name
|
||||
@@ -2180,6 +2247,7 @@ class LanceTable(Table):
|
||||
0 [1.1, 0.9] vector
|
||||
"""
|
||||
LOOP.run(self._table.checkout(version))
|
||||
self._checkout_version = self.version
|
||||
|
||||
def checkout_latest(self):
|
||||
"""Checkout the latest version of the table. This is an in-place operation.
|
||||
@@ -2188,6 +2256,7 @@ class LanceTable(Table):
|
||||
version of the table.
|
||||
"""
|
||||
LOOP.run(self._table.checkout_latest())
|
||||
self._checkout_version = None
|
||||
|
||||
def restore(self, version: Optional[Union[int, str]] = None):
|
||||
"""Restore a version of the table. This is an in-place operation.
|
||||
@@ -2236,6 +2305,7 @@ class LanceTable(Table):
|
||||
if version is not None:
|
||||
LOOP.run(self._table.checkout(version))
|
||||
LOOP.run(self._table.restore())
|
||||
self._checkout_version = None
|
||||
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
    """Count rows by running the underlying table's ``count_rows`` on ``LOOP``.

    Parameters
    ----------
    filter : str, optional
        Forwarded unchanged to the underlying table.
    """
    pending = self._table.count_rows(filter)
    return LOOP.run(pending)
|
||||
@@ -3294,6 +3364,7 @@ class LanceTable(Table):
|
||||
self._location = location
|
||||
self._namespace_client = namespace_client
|
||||
self._pushdown_operations = pushdown_operations or set()
|
||||
self._init_reopen_tracking()
|
||||
|
||||
if data_storage_version is not None:
|
||||
warnings.warn(
|
||||
@@ -4551,6 +4622,10 @@ class AsyncTable:
|
||||
"""
|
||||
return await self._inner.latest_storage_options()
|
||||
|
||||
async def _table_reopen_state(self) -> dict[str, Any]:
|
||||
"""Get the Rust-side table state needed to reopen this table."""
|
||||
return await self._inner._table_reopen_state()
|
||||
|
||||
async def add(
|
||||
self,
|
||||
data: DATA,
|
||||
|
||||
@@ -215,10 +215,51 @@ def test_remote_table_is_picklable():
|
||||
|
||||
with mock_lancedb_connection(handler) as db:
|
||||
table = db.open_table("test")
|
||||
state = table.__getstate__()
|
||||
assert state["table_state"] == {
|
||||
"name": "test",
|
||||
"namespace_path": [],
|
||||
"storage_options": None,
|
||||
}
|
||||
restored = pickle.loads(pickle.dumps(table))
|
||||
assert restored.count_rows() == 3
|
||||
|
||||
|
||||
def test_remote_table_reopens_when_pid_changes_without_cached_state():
    """A remote table with a stale pid and no cached state is still usable."""
    describe_body = json.dumps(
        {
            "version": 1,
            "schema": {
                "fields": [
                    {"name": "id", "type": {"type": "int64"}, "nullable": False}
                ]
            },
        }
    ).encode()
    ok_bodies = {
        "/v1/table/test/describe/": describe_body,
        "/v1/table/test/count_rows/": b"3",
    }

    def handler(request):
        request.close_connection = True
        body = ok_bodies.get(request.path)
        if body is None:
            # Any other path is unknown to this mock server.
            request.send_response(404)
            request.end_headers()
            return
        request.send_response(200)
        request.send_header("Content-Type", "application/json")
        request.end_headers()
        request.wfile.write(body)

    with mock_lancedb_connection(handler) as db:
        table = db.open_table("test")
        # Pretend the handle came from another process and drop the cached
        # reopen state, so the next call has to rebuild both.
        table._pid = -1
        table._table_state = None

        assert table.count_rows() == 3
|
||||
|
||||
|
||||
def test_remote_table_open_does_not_require_picklable_client_config():
|
||||
from lancedb.remote import HeaderProvider
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import warnings
|
||||
from datetime import date, datetime, timedelta
|
||||
@@ -48,6 +49,36 @@ def test_basic(mem_db: DBConnection):
|
||||
assert table.to_arrow() == expected_data
|
||||
|
||||
|
||||
def test_lance_table_is_picklable(tmp_db: DBConnection):
    """A table created in ``tmp_db`` survives a pickle round trip."""
    original = tmp_db.create_table("pickle_table", pa.table({"id": [1, 2, 3]}))

    payload = pickle.dumps(original)
    restored = pickle.loads(payload)

    assert restored.name == "pickle_table"
    assert restored.count_rows() == 3
    restored_ids = restored.to_arrow().column("id").to_pylist()
    assert restored_ids == [1, 2, 3]
|
||||
|
||||
|
||||
def test_lance_table_pickle_preserves_checkout(tmp_db: DBConnection):
    """Pickling keeps the checked-out version until ``checkout_latest``."""
    original = tmp_db.create_table("pickle_checkout", pa.table({"id": [1]}))
    original.add(pa.table({"id": [2]}))
    # Pin the table to version 1, which holds only the first row.
    original.checkout(1)

    payload = pickle.dumps(original)
    restored = pickle.loads(payload)

    assert restored.count_rows() == 1
    pinned_ids = restored.to_arrow().column("id").to_pylist()
    assert pinned_ids == [1]

    restored.checkout_latest()
    assert restored.count_rows() == 2
|
||||
|
||||
|
||||
def test_memory_lance_table_pickle_is_unsupported(mem_db: DBConnection):
    """Pickling a table created in ``mem_db`` raises ``ValueError``."""
    memory_table = mem_db.create_table("memory_pickle", pa.table({"id": [1]}))

    expected_error = pytest.raises(ValueError, match="in-memory LanceTable")
    with expected_error:
        pickle.dumps(memory_table)
|
||||
|
||||
|
||||
def test_table_to_pandas_default_matches_arrow(tmp_db: DBConnection):
|
||||
pd = pytest.importorskip("pandas")
|
||||
data = pa.table({"id": [1, 2], "text": ["one", "two"]})
|
||||
|
||||
@@ -755,6 +755,23 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
/// Return a Python awaitable that resolves to the state needed to reopen
/// this table.
///
/// The resolved value is a Python dict with the keys `name`,
/// `namespace_path` and `storage_options`, all read from the inner table.
/// An error from `inner_ref()` is returned directly, before any future is
/// created.
pub fn _table_reopen_state(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
// Clone the inner table so the async block below owns its own copy.
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let name = inner.name().to_string();
|
||||
let namespace_path = inner.namespace().to_vec();
|
||||
// The only awaited call: fetch the initial storage options.
let storage_options = inner.initial_storage_options().await;
|
||||
|
||||
// Build the result dict inside `Python::attach`, which supplies the `py`
// token needed to create Python objects.
Python::attach(|py| {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("name", name)?;
|
||||
dict.set_item("namespace_path", namespace_path)?;
|
||||
dict.set_item("storage_options", storage_options)?;
|
||||
Ok(dict.unbind())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn __repr__(&self) -> String {
|
||||
match &self.inner {
|
||||
None => format!("ClosedTable({})", self.name),
|
||||
|
||||
Reference in New Issue
Block a user