mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-14 10:30:40 +00:00
Merge remote-tracking branch 'origin/main' into xuanwo/native-fts-cjk-tokenizers
# Conflicts: # Cargo.lock
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.31.0-beta.9"
|
||||
current_version = "0.31.0-beta.11"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.31.0-beta.9"
|
||||
version = "0.31.0-beta.11"
|
||||
publish = false
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
license.workspace = true
|
||||
@@ -14,7 +15,7 @@ name = "_lancedb"
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
arrow = { version = "57.2", features = ["pyarrow"] }
|
||||
arrow = { version = "58.0.0", features = ["pyarrow"] }
|
||||
async-trait = "0.1"
|
||||
bytes = "1"
|
||||
lancedb = { path = "../rust/lancedb", default-features = false }
|
||||
@@ -24,8 +25,8 @@ lance-namespace-impls.workspace = true
|
||||
lance-io.workspace = true
|
||||
env_logger.workspace = true
|
||||
log.workspace = true
|
||||
pyo3 = { version = "0.26", features = ["extension-module", "abi3-py39"] }
|
||||
pyo3-async-runtimes = { version = "0.26", features = [
|
||||
pyo3 = { version = "0.28", features = ["extension-module", "abi3-py39"] }
|
||||
pyo3-async-runtimes = { version = "0.28", features = [
|
||||
"attributes",
|
||||
"tokio-runtime",
|
||||
] }
|
||||
@@ -34,10 +35,11 @@ futures.workspace = true
|
||||
serde = "1"
|
||||
serde_json = "1"
|
||||
snafu.workspace = true
|
||||
tokio = { version = "1.40", features = ["sync"] }
|
||||
tokio = { version = "1.40", features = ["sync", "rt-multi-thread"] }
|
||||
libc = "0.2"
|
||||
|
||||
[build-dependencies]
|
||||
pyo3-build-config = { version = "0.26", features = [
|
||||
pyo3-build-config = { version = "0.28", features = [
|
||||
"extension-module",
|
||||
"abi3-py39",
|
||||
] }
|
||||
|
||||
@@ -7,7 +7,6 @@ import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from typing import Dict, Optional, Union, Any, List
|
||||
import warnings
|
||||
|
||||
__version__ = importlib.metadata.version("lancedb")
|
||||
|
||||
@@ -73,6 +72,7 @@ def connect(
|
||||
client_config: Union[ClientConfig, Dict[str, Any], None] = None,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
session: Optional[Session] = None,
|
||||
manifest_enabled: bool = False,
|
||||
namespace_client_impl: Optional[str] = None,
|
||||
namespace_client_properties: Optional[Dict[str, str]] = None,
|
||||
namespace_client_pushdown_operations: Optional[List[str]] = None,
|
||||
@@ -111,6 +111,10 @@ def connect(
|
||||
storage_options: dict, optional
|
||||
Additional options for the storage backend. See available options at
|
||||
<https://docs.lancedb.com/storage/>
|
||||
manifest_enabled : bool, default False
|
||||
When true for local/native connections, use directory namespace
|
||||
manifests as the source of truth for table metadata. Existing
|
||||
directory-listed root tables are migrated into the manifest on access.
|
||||
session: Session, optional
|
||||
(For LanceDB OSS only)
|
||||
A session to use for this connection. Sessions allow you to configure
|
||||
@@ -158,11 +162,11 @@ def connect(
|
||||
conn : DBConnection
|
||||
A connection to a LanceDB database.
|
||||
"""
|
||||
if namespace_client_impl is not None or namespace_client_properties is not None:
|
||||
if namespace_client_impl is None or namespace_client_properties is None:
|
||||
if namespace_client_impl is not None:
|
||||
if namespace_client_properties is None:
|
||||
raise ValueError(
|
||||
"Both namespace_client_impl and "
|
||||
"namespace_client_properties must be provided"
|
||||
"namespace_client_properties must be provided when "
|
||||
"namespace_client_impl is set"
|
||||
)
|
||||
if kwargs:
|
||||
raise ValueError(f"Unknown keyword arguments: {kwargs}")
|
||||
@@ -175,6 +179,12 @@ def connect(
|
||||
namespace_client_pushdown_operations=namespace_client_pushdown_operations,
|
||||
)
|
||||
|
||||
if namespace_client_properties is not None and not manifest_enabled:
|
||||
raise ValueError(
|
||||
"namespace_client_impl must be provided when using "
|
||||
"namespace_client_properties unless manifest_enabled=True"
|
||||
)
|
||||
|
||||
if namespace_client_pushdown_operations is not None:
|
||||
raise ValueError(
|
||||
"namespace_client_pushdown_operations is only valid when "
|
||||
@@ -212,6 +222,8 @@ def connect(
|
||||
read_consistency_interval=read_consistency_interval,
|
||||
storage_options=storage_options,
|
||||
session=session,
|
||||
manifest_enabled=manifest_enabled,
|
||||
namespace_client_properties=namespace_client_properties,
|
||||
)
|
||||
|
||||
|
||||
@@ -289,6 +301,8 @@ def deserialize_conn(
|
||||
parsed["uri"],
|
||||
read_consistency_interval=rci,
|
||||
storage_options=storage_options,
|
||||
manifest_enabled=parsed.get("manifest_enabled", False),
|
||||
namespace_client_properties=parsed.get("namespace_client_properties"),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown connection_type: {connection_type}")
|
||||
@@ -304,6 +318,8 @@ async def connect_async(
|
||||
client_config: Optional[Union[ClientConfig, Dict[str, Any]]] = None,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
session: Optional[Session] = None,
|
||||
manifest_enabled: bool = False,
|
||||
namespace_client_properties: Optional[Dict[str, str]] = None,
|
||||
) -> AsyncConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
|
||||
@@ -343,6 +359,13 @@ async def connect_async(
|
||||
cache sizes for index and metadata caches, which can significantly
|
||||
impact memory use and performance. They can also be re-used across
|
||||
multiple connections to share the same cache state.
|
||||
manifest_enabled : bool, default False
|
||||
When true for local/native connections, use directory namespace
|
||||
manifests as the source of truth for table metadata. Existing
|
||||
directory-listed root tables are migrated into the manifest on access.
|
||||
namespace_client_properties : dict, optional
|
||||
Additional directory namespace client properties to use with
|
||||
``manifest_enabled=True``.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -385,6 +408,8 @@ async def connect_async(
|
||||
client_config,
|
||||
storage_options,
|
||||
session,
|
||||
manifest_enabled,
|
||||
namespace_client_properties,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -412,13 +437,3 @@ __all__ = [
|
||||
"Table",
|
||||
"__version__",
|
||||
]
|
||||
|
||||
|
||||
def __warn_on_fork():
|
||||
warnings.warn(
|
||||
"lance is not fork-safe. If you are using multiprocessing, use spawn instead.",
|
||||
)
|
||||
|
||||
|
||||
if hasattr(os, "register_at_fork"):
|
||||
os.register_at_fork(before=__warn_on_fork) # type: ignore[attr-defined]
|
||||
|
||||
@@ -242,6 +242,8 @@ async def connect(
|
||||
client_config: Optional[Union[ClientConfig, Dict[str, Any]]],
|
||||
storage_options: Optional[Dict[str, str]],
|
||||
session: Optional[Session],
|
||||
manifest_enabled: bool = False,
|
||||
namespace_client_properties: Optional[Dict[str, str]] = None,
|
||||
) -> Connection: ...
|
||||
|
||||
class RecordBatchStream:
|
||||
@@ -440,7 +442,7 @@ class AsyncPermutationBuilder:
|
||||
async def execute(self) -> Table: ...
|
||||
|
||||
def async_permutation_builder(
|
||||
table: Table, dest_table_name: str
|
||||
table: Table,
|
||||
) -> AsyncPermutationBuilder: ...
|
||||
def fts_query_to_json(query: Any) -> str: ...
|
||||
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
|
||||
class BackgroundEventLoop:
|
||||
@@ -13,6 +15,9 @@ class BackgroundEventLoop:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._start()
|
||||
|
||||
def _start(self):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.thread = threading.Thread(
|
||||
target=self.loop.run_forever,
|
||||
@@ -31,3 +36,30 @@ class BackgroundEventLoop:
|
||||
|
||||
|
||||
LOOP = BackgroundEventLoop()
|
||||
|
||||
_FORK_WARNED = False
|
||||
|
||||
|
||||
def _reset_after_fork():
|
||||
# Threads do not survive fork(), so the asyncio loop in LOOP.thread is
|
||||
# dead in the child. Re-initialize the singleton in place so existing
|
||||
# `from .background_loop import LOOP` references in other modules see
|
||||
# the new state. The Rust-side tokio runtime is reset analogously by a
|
||||
# pthread_atfork hook installed in the _lancedb extension.
|
||||
LOOP._start()
|
||||
global _FORK_WARNED
|
||||
if not _FORK_WARNED:
|
||||
_FORK_WARNED = True
|
||||
warnings.warn(
|
||||
"lancedb fork support is experimental: the internal async "
|
||||
"runtime has been reset in the forked child, but a small chance "
|
||||
"of deadlock remains if other state was mid-operation at fork "
|
||||
"time. The 'forkserver' or 'spawn' multiprocessing start method "
|
||||
"is likely a safer alternative.",
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
if hasattr(os, "register_at_fork"):
|
||||
os.register_at_fork(after_in_child=_reset_after_fork)
|
||||
|
||||
@@ -590,8 +590,13 @@ class LanceDBConnection(DBConnection):
|
||||
read_consistency_interval: Optional[timedelta] = None,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
session: Optional[Session] = None,
|
||||
manifest_enabled: bool = False,
|
||||
namespace_client_properties: Optional[Dict[str, str]] = None,
|
||||
_inner: Optional[LanceDbConnection] = None,
|
||||
):
|
||||
self.storage_options = storage_options
|
||||
self._manifest_enabled = manifest_enabled
|
||||
self._namespace_client_properties = namespace_client_properties
|
||||
if _inner is not None:
|
||||
self._conn = _inner
|
||||
self._cached_namespace_client = None
|
||||
@@ -633,6 +638,8 @@ class LanceDBConnection(DBConnection):
|
||||
None,
|
||||
storage_options,
|
||||
session,
|
||||
manifest_enabled,
|
||||
namespace_client_properties,
|
||||
)
|
||||
|
||||
# TODO: It would be nice if we didn't store self.storage_options but it is
|
||||
@@ -640,7 +647,6 @@ class LanceDBConnection(DBConnection):
|
||||
# work because some paths like LanceDBConnection.from_inner will lose the
|
||||
# storage_options. Also, this class really shouldn't be holding any state
|
||||
# beyond _conn.
|
||||
self.storage_options = storage_options
|
||||
self._conn = AsyncConnection(LOOP.run(do_connect()))
|
||||
self._cached_namespace_client: Optional[LanceNamespace] = None
|
||||
|
||||
@@ -677,6 +683,8 @@ class LanceDBConnection(DBConnection):
|
||||
"connection_type": "local",
|
||||
"uri": self.uri,
|
||||
"storage_options": self.storage_options,
|
||||
"manifest_enabled": self._manifest_enabled,
|
||||
"namespace_client_properties": self._namespace_client_properties,
|
||||
"read_consistency_interval_seconds": (
|
||||
rci.total_seconds() if rci else None
|
||||
),
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from deprecation import deprecated
|
||||
from lancedb import AsyncConnection, DBConnection
|
||||
import pyarrow as pa
|
||||
import copy
|
||||
import json
|
||||
|
||||
from deprecation import deprecated
|
||||
import pyarrow as pa
|
||||
|
||||
from ._lancedb import async_permutation_builder, PermutationReader
|
||||
from .table import LanceTable
|
||||
from .background_loop import LOOP
|
||||
@@ -36,10 +37,7 @@ class PermutationBuilder:
|
||||
be referenced by name in the future. If names are not provided then they can only
|
||||
be referenced by their ordinal index. There is no requirement to name every split.
|
||||
|
||||
By default, the permutation will be stored in memory and will be lost when the
|
||||
program exits. To persist the permutation (for very large datasets or to share
|
||||
the permutation across multiple workers) use the [persist](#persist) method to
|
||||
create a permanent table.
|
||||
The permutation is stored in memory and will be lost when the program exits.
|
||||
"""
|
||||
|
||||
def __init__(self, table: LanceTable):
|
||||
@@ -51,15 +49,6 @@ class PermutationBuilder:
|
||||
"""
|
||||
self._async = async_permutation_builder(table)
|
||||
|
||||
def persist(
|
||||
self, database: Union[DBConnection, AsyncConnection], table_name: str
|
||||
) -> "PermutationBuilder":
|
||||
"""
|
||||
Persist the permutation to the given database.
|
||||
"""
|
||||
self._async.persist(database, table_name)
|
||||
return self
|
||||
|
||||
def split_random(
|
||||
self,
|
||||
*,
|
||||
@@ -380,20 +369,44 @@ class Permutation:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reader: PermutationReader,
|
||||
base_table: LanceTable,
|
||||
permutation_table: Optional[LanceTable],
|
||||
split: int,
|
||||
selection: dict[str, str],
|
||||
batch_size: int,
|
||||
transform_fn: Callable[pa.RecordBatch, Any],
|
||||
offset: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
connection_factory: Optional[Callable[[str], LanceTable]] = None,
|
||||
_reader: Optional[PermutationReader] = None,
|
||||
):
|
||||
"""
|
||||
Internal constructor. Use [from_tables](#from_tables) instead.
|
||||
"""
|
||||
assert reader is not None, "reader is required"
|
||||
assert base_table is not None, "base_table is required"
|
||||
assert selection is not None, "selection is required"
|
||||
self.reader = reader
|
||||
self.base_table = base_table
|
||||
self.permutation_table = permutation_table
|
||||
self.split = split
|
||||
self.selection = selection
|
||||
self.transform_fn = transform_fn
|
||||
self.batch_size = batch_size
|
||||
self.offset = offset
|
||||
self.limit = limit
|
||||
self.connection_factory = connection_factory
|
||||
if _reader is None:
|
||||
_reader = LOOP.run(self._build_reader())
|
||||
self.reader: PermutationReader = _reader
|
||||
|
||||
async def _build_reader(self) -> PermutationReader:
|
||||
reader = await PermutationReader.from_tables(
|
||||
self.base_table, self.permutation_table, self.split
|
||||
)
|
||||
if self.offset is not None:
|
||||
reader = await reader.with_offset(self.offset)
|
||||
if self.limit is not None:
|
||||
reader = await reader.with_limit(self.limit)
|
||||
return reader
|
||||
|
||||
def _with_selection(self, selection: dict[str, str]) -> "Permutation":
|
||||
"""
|
||||
@@ -402,21 +415,97 @@ class Permutation:
|
||||
Does not validation of the selection and it replaces it entirely. This is not
|
||||
intended for public use.
|
||||
"""
|
||||
return Permutation(self.reader, selection, self.batch_size, self.transform_fn)
|
||||
|
||||
def _with_reader(self, reader: PermutationReader) -> "Permutation":
|
||||
"""
|
||||
Creates a new permutation with the given reader
|
||||
|
||||
This is an internal method and should not be used directly.
|
||||
"""
|
||||
return Permutation(reader, self.selection, self.batch_size, self.transform_fn)
|
||||
new = copy.copy(self)
|
||||
new.selection = selection
|
||||
return new
|
||||
|
||||
def with_batch_size(self, batch_size: int) -> "Permutation":
|
||||
"""
|
||||
Creates a new permutation with the given batch size
|
||||
"""
|
||||
return Permutation(self.reader, self.selection, batch_size, self.transform_fn)
|
||||
new = copy.copy(self)
|
||||
new.batch_size = batch_size
|
||||
return new
|
||||
|
||||
def with_connection_factory(
|
||||
self, connection_factory: Callable[[str], LanceTable]
|
||||
) -> "Permutation":
|
||||
"""
|
||||
Creates a new permutation that will use ``connection_factory`` to reopen
|
||||
the base table when this permutation is unpickled in a worker process.
|
||||
|
||||
The factory is a callable that takes a single argument — the base table
|
||||
name — and returns a [LanceTable]. It must be picklable; the worker
|
||||
will pickle it via standard ``pickle`` and call it to recover the base
|
||||
table. Picklable callables in practice means top-level (module-level)
|
||||
functions, ``functools.partial`` of such functions, or instances of
|
||||
picklable classes implementing ``__call__``. Lambdas and closures over
|
||||
local variables don't pickle with the default protocol.
|
||||
|
||||
Setting a factory is necessary when the URI alone is not enough to
|
||||
re-open the connection — most importantly for LanceDB Cloud (``db://``)
|
||||
connections, where ``api_key`` and ``region`` aren't recoverable from
|
||||
the connection object after construction.
|
||||
|
||||
For local file or cloud-storage paths the factory is optional: if not
|
||||
set, ``__getstate__`` falls back to capturing
|
||||
``(uri, storage_options, namespace_path)`` and re-opening via
|
||||
``lancedb.connect(uri, storage_options=...)``.
|
||||
|
||||
Examples
|
||||
--------
|
||||
Basic native (file-system path), parameterized via ``functools.partial``::
|
||||
|
||||
import functools, lancedb
|
||||
from lancedb.permutation import Permutation
|
||||
|
||||
def open_native_table(uri: str, table_name: str):
|
||||
return lancedb.connect(uri).open_table(table_name)
|
||||
|
||||
factory = functools.partial(open_native_table, "/data/lance_db")
|
||||
permutation = Permutation.identity(
|
||||
factory("training")
|
||||
).with_connection_factory(factory)
|
||||
|
||||
Native via :func:`lancedb.connect_namespace` (e.g. a directory- or
|
||||
REST-backed namespace client). The factory takes the
|
||||
implementation name and properties dict as partial-bound args so
|
||||
the worker can rebuild the same namespace connection::
|
||||
|
||||
def open_via_namespace(
|
||||
impl: str, properties: dict[str, str], table_name: str,
|
||||
):
|
||||
return lancedb.connect_namespace(impl, properties).open_table(
|
||||
table_name,
|
||||
)
|
||||
|
||||
factory = functools.partial(
|
||||
open_via_namespace,
|
||||
"dir",
|
||||
{"root": "/data/lance_db"},
|
||||
)
|
||||
|
||||
LanceDB Cloud, reading credentials from env vars at worker startup
|
||||
so secrets aren't pickled into the dataset::
|
||||
|
||||
import os, lancedb
|
||||
|
||||
def open_remote_table(table_name: str):
|
||||
db = lancedb.connect(
|
||||
"db://my-database",
|
||||
api_key=os.environ["LANCEDB_API_KEY"],
|
||||
region=os.environ.get("LANCEDB_REGION", "us-east-1"),
|
||||
)
|
||||
return db.open_table(table_name)
|
||||
|
||||
permutation = Permutation.identity(
|
||||
open_remote_table("training")
|
||||
).with_connection_factory(open_remote_table)
|
||||
"""
|
||||
assert connection_factory is not None, "connection_factory is required"
|
||||
new = copy.copy(self)
|
||||
new.connection_factory = connection_factory
|
||||
return new
|
||||
|
||||
@classmethod
|
||||
def identity(cls, table: LanceTable) -> "Permutation":
|
||||
@@ -489,11 +578,126 @@ class Permutation:
|
||||
schema = await reader.output_schema(None)
|
||||
initial_selection = {name: name for name in schema.names}
|
||||
return cls(
|
||||
reader, initial_selection, DEFAULT_BATCH_SIZE, Transforms.arrow2python
|
||||
base_table,
|
||||
permutation_table,
|
||||
split,
|
||||
initial_selection,
|
||||
DEFAULT_BATCH_SIZE,
|
||||
Transforms.arrow2python,
|
||||
_reader=reader,
|
||||
)
|
||||
|
||||
return LOOP.run(do_from_tables())
|
||||
|
||||
def __getstate__(self) -> dict[str, Any]:
|
||||
"""Build a picklable state dict for this permutation.
|
||||
|
||||
The base table is captured either via a user-supplied
|
||||
``connection_factory`` (see [with_connection_factory]) or, as a
|
||||
fallback, by introspecting ``(uri, storage_options, namespace_path)``
|
||||
on the connection. The permutation table — always an in-memory
|
||||
LanceDB table — is captured as a pyarrow Table (which pickles via
|
||||
Arrow IPC natively). The reader is dropped from the wire format;
|
||||
``__setstate__`` rebuilds it from the restored tables.
|
||||
"""
|
||||
permutation_data: Optional[pa.Table] = None
|
||||
if self.permutation_table is not None:
|
||||
permutation_data = self.permutation_table.to_arrow()
|
||||
|
||||
common = {
|
||||
"base_table_name": self.base_table.name,
|
||||
"permutation_data": permutation_data,
|
||||
"split": self.split,
|
||||
"selection": self.selection,
|
||||
"batch_size": self.batch_size,
|
||||
"transform_fn": self.transform_fn,
|
||||
"offset": self.offset,
|
||||
"limit": self.limit,
|
||||
"connection_factory": self.connection_factory,
|
||||
}
|
||||
|
||||
if self.connection_factory is not None:
|
||||
# The factory carries enough state to recover the base table on
|
||||
# its own; we don't need to capture the URI / storage options /
|
||||
# namespace from the existing connection.
|
||||
return common
|
||||
|
||||
# URI-introspection fallback: only viable for native (OSS) connections
|
||||
# where (uri, storage_options) is enough to reopen. Remote / cloud
|
||||
# connections don't expose recoverable api_key / region — those users
|
||||
# must call with_connection_factory().
|
||||
try:
|
||||
base_uri = self.base_table._conn.uri
|
||||
storage_options = self.base_table._conn.storage_options
|
||||
except AttributeError as e:
|
||||
raise ValueError(
|
||||
"Cannot pickle this Permutation: the base table's connection "
|
||||
"does not expose a uri/storage_options, which usually means it "
|
||||
"is a remote (LanceDB Cloud) connection. Call "
|
||||
"Permutation.with_connection_factory(...) first to provide a "
|
||||
"picklable callable that re-opens the base table from a worker "
|
||||
"process."
|
||||
) from e
|
||||
|
||||
if base_uri.startswith("memory://"):
|
||||
# In-memory base tables don't exist in any worker process by
|
||||
# default, so dump the entire base table into the pickle. This
|
||||
# can be expensive for large datasets — users with large
|
||||
# in-memory base tables should either persist them or set a
|
||||
# connection_factory.
|
||||
return {
|
||||
**common,
|
||||
"base_table_data": self.base_table.to_arrow(),
|
||||
}
|
||||
|
||||
return {
|
||||
**common,
|
||||
"base_table_uri": base_uri,
|
||||
"base_table_namespace": self.base_table._namespace_path,
|
||||
"base_table_storage_options": storage_options,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict[str, Any]) -> None:
|
||||
from . import connect
|
||||
|
||||
connection_factory = state["connection_factory"]
|
||||
if connection_factory is not None:
|
||||
base_table = connection_factory(state["base_table_name"])
|
||||
elif "base_table_data" in state:
|
||||
# In-memory base table inlined into the pickle; rebuild the same
|
||||
# way we rebuild the in-memory permutation table.
|
||||
mem_db = connect("memory://")
|
||||
base_table = mem_db.create_table(
|
||||
state["base_table_name"], state["base_table_data"]
|
||||
)
|
||||
else:
|
||||
base_db = connect(
|
||||
state["base_table_uri"],
|
||||
storage_options=state["base_table_storage_options"],
|
||||
)
|
||||
base_table = base_db.open_table(
|
||||
state["base_table_name"],
|
||||
namespace_path=state["base_table_namespace"] or None,
|
||||
)
|
||||
|
||||
permutation_table: Optional[LanceTable] = None
|
||||
if state["permutation_data"] is not None:
|
||||
mem_db = connect("memory://")
|
||||
permutation_table = mem_db.create_table(
|
||||
"permutation", state["permutation_data"]
|
||||
)
|
||||
|
||||
self.base_table = base_table
|
||||
self.permutation_table = permutation_table
|
||||
self.split = state["split"]
|
||||
self.selection = state["selection"]
|
||||
self.batch_size = state["batch_size"]
|
||||
self.transform_fn = state["transform_fn"]
|
||||
self.offset = state["offset"]
|
||||
self.limit = state["limit"]
|
||||
self.connection_factory = connection_factory
|
||||
self.reader = LOOP.run(self._build_reader())
|
||||
|
||||
@property
|
||||
def schema(self) -> pa.Schema:
|
||||
async def do_output_schema():
|
||||
@@ -760,7 +964,9 @@ class Permutation:
|
||||
for expensive operations such as image decoding.
|
||||
"""
|
||||
assert transform is not None, "transform is required"
|
||||
return Permutation(self.reader, self.selection, self.batch_size, transform)
|
||||
new = copy.copy(self)
|
||||
new.transform_fn = transform
|
||||
return new
|
||||
|
||||
def __getitem__(self, index: int) -> Any:
|
||||
"""
|
||||
@@ -795,12 +1001,10 @@ class Permutation:
|
||||
"""
|
||||
Skip the first `skip` rows of the permutation
|
||||
"""
|
||||
|
||||
async def do_with_skip():
|
||||
reader = await self.reader.with_offset(skip)
|
||||
return self._with_reader(reader)
|
||||
|
||||
return LOOP.run(do_with_skip())
|
||||
new = copy.copy(self)
|
||||
new.offset = skip
|
||||
new.reader = LOOP.run(new._build_reader())
|
||||
return new
|
||||
|
||||
@deprecated(details="Use with_take instead")
|
||||
def take(self, limit: int) -> "Permutation":
|
||||
@@ -818,12 +1022,10 @@ class Permutation:
|
||||
"""
|
||||
Limit the permutation to `limit` rows (following any `skip`)
|
||||
"""
|
||||
|
||||
async def do_with_take():
|
||||
reader = await self.reader.with_limit(limit)
|
||||
return self._with_reader(reader)
|
||||
|
||||
return LOOP.run(do_with_take())
|
||||
new = copy.copy(self)
|
||||
new.limit = limit
|
||||
new.reader = LOOP.run(new._build_reader())
|
||||
return new
|
||||
|
||||
@deprecated(details="Use with_repeat instead")
|
||||
def repeat(self, times: int) -> "Permutation":
|
||||
|
||||
@@ -1643,7 +1643,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
def _validate_query(self, query, vector=None, text=None):
|
||||
if query is not None and (vector is not None or text is not None):
|
||||
raise ValueError(
|
||||
"You can either provide a string query in search() method"
|
||||
"You can either provide a string query in search() method "
|
||||
"or set `vector()` and `text()` explicitly for hybrid search."
|
||||
"But not both."
|
||||
)
|
||||
|
||||
@@ -9,21 +9,6 @@ from lancedb import DBConnection, Table, connect
|
||||
from lancedb.permutation import Permutation, Permutations, permutation_builder
|
||||
|
||||
|
||||
def test_permutation_persistence(tmp_path):
|
||||
db = connect(tmp_path)
|
||||
tbl = db.create_table("test_table", pa.table({"x": range(100), "y": range(100)}))
|
||||
|
||||
permutation_tbl = (
|
||||
permutation_builder(tbl).shuffle().persist(db, "test_permutation").execute()
|
||||
)
|
||||
assert permutation_tbl.count_rows() == 100
|
||||
|
||||
re_open = db.open_table("test_permutation")
|
||||
assert re_open.count_rows() == 100
|
||||
|
||||
assert permutation_tbl.to_arrow() == re_open.to_arrow()
|
||||
|
||||
|
||||
def test_split_random_ratios(mem_db):
|
||||
"""Test random splitting with ratios."""
|
||||
tbl = mem_db.create_table(
|
||||
|
||||
@@ -6,6 +6,8 @@ import contextlib
|
||||
from datetime import timedelta
|
||||
import http.server
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
@@ -1230,3 +1232,82 @@ def test_background_loop_cancellation(exception):
|
||||
with pytest.raises(exception):
|
||||
loop.run(None)
|
||||
mock_future.cancel.assert_called_once()
|
||||
|
||||
|
||||
def _remote_fork_child(port: int, queue) -> None:
|
||||
# Build a fresh Connection in the child so we exercise the at-fork-child
|
||||
# tokio runtime reset rather than relying on an inherited reqwest client.
|
||||
db = lancedb.connect(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override=f"http://localhost:{port}",
|
||||
client_config={
|
||||
"retry_config": {"retries": 0},
|
||||
"timeout_config": {"connect_timeout": 2, "read_timeout": 2},
|
||||
},
|
||||
)
|
||||
queue.put(db.table_names())
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform != "linux",
|
||||
reason=(
|
||||
"fork() is unavailable on Windows and unsafe on macOS "
|
||||
"(Apple frameworks/TLS are not fork-safe)"
|
||||
),
|
||||
)
|
||||
def test_remote_connection_after_fork():
|
||||
"""A freshly-built remote Connection in a forked child should not hang.
|
||||
|
||||
The pyo3-async-runtimes tokio runtime would otherwise be inherited from
|
||||
the parent with dead worker threads; the at-fork-child handler in our
|
||||
runtime module rebuilds it on first use in the child.
|
||||
"""
|
||||
|
||||
def handler(request):
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b'{"tables": []}')
|
||||
|
||||
server = http.server.HTTPServer(("localhost", 0), make_mock_http_handler(handler))
|
||||
port = server.server_address[1]
|
||||
server_thread = threading.Thread(target=server.serve_forever)
|
||||
server_thread.start()
|
||||
try:
|
||||
# Hit the server in the parent first so the runtime + LOOP are warm
|
||||
# before fork; a fresh child must still succeed.
|
||||
parent_db = lancedb.connect(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override=f"http://localhost:{port}",
|
||||
client_config={
|
||||
"retry_config": {"retries": 0},
|
||||
"timeout_config": {"connect_timeout": 2, "read_timeout": 2},
|
||||
},
|
||||
)
|
||||
assert parent_db.table_names() == []
|
||||
|
||||
ctx = mp.get_context("fork")
|
||||
queue = ctx.Queue()
|
||||
proc = ctx.Process(target=_remote_fork_child, args=(port, queue))
|
||||
proc.start()
|
||||
proc.join(timeout=15)
|
||||
|
||||
if proc.is_alive():
|
||||
proc.terminate()
|
||||
proc.join(timeout=5)
|
||||
if proc.is_alive():
|
||||
proc.kill()
|
||||
proc.join()
|
||||
pytest.fail("Remote connection hung after fork")
|
||||
|
||||
assert proc.exitcode == 0, f"child exited with code {proc.exitcode}"
|
||||
assert not queue.empty(), "child produced no result"
|
||||
assert queue.get() == []
|
||||
|
||||
# Parent connection must still be usable after the child returned.
|
||||
assert parent_db.table_names() == []
|
||||
finally:
|
||||
server.shutdown()
|
||||
server_thread.join()
|
||||
|
||||
@@ -1,14 +1,29 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import functools
|
||||
import multiprocessing as mp
|
||||
import pickle
|
||||
import sys
|
||||
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from lancedb.permutation import Permutation, Permutations, permutation_builder
|
||||
from lancedb.util import tbl_to_tensor
|
||||
from lancedb.permutation import Permutation
|
||||
|
||||
torch = pytest.importorskip("torch")
|
||||
|
||||
|
||||
def _open_native_table(uri: str, table_name: str):
|
||||
"""Top-level connection factory used by the explicit-factory pickle test.
|
||||
|
||||
Defined at module scope so that pickle can resolve it by name in the
|
||||
worker / unpickling process.
|
||||
"""
|
||||
return lancedb.connect(uri).open_table(table_name)
|
||||
|
||||
|
||||
def test_table_dataloader(mem_db):
|
||||
table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
@@ -40,3 +55,156 @@ def test_permutation_dataloader(mem_db):
|
||||
for batch in dataloader:
|
||||
assert batch.size(0) == 1
|
||||
assert batch.size(1) == 10
|
||||
|
||||
|
||||
def test_permutation_is_picklable(tmp_db):
|
||||
"""A Permutation must be picklable so it can be used with PyTorch's
|
||||
DataLoader when num_workers > 0 (which uses multiprocessing and pickles
|
||||
the dataset to pass it to worker processes)."""
|
||||
table = tmp_db.create_table("test_table", pa.table({"a": range(1000)}))
|
||||
permutation = Permutation.identity(table)
|
||||
|
||||
pickled = pickle.dumps(permutation)
|
||||
restored = pickle.loads(pickled)
|
||||
|
||||
assert len(restored) == 1000
|
||||
rows = restored.__getitems__([0, 1, 2])
|
||||
assert rows == [{"a": 0}, {"a": 1}, {"a": 2}]
|
||||
|
||||
|
||||
def test_permutation_with_memory_base_is_picklable(mem_db):
|
||||
"""An in-memory base table is inlined into the pickle as Arrow IPC bytes
|
||||
and rebuilt on the other side as an in-memory LanceTable, so the
|
||||
Permutation round-trips even though the original database can't be
|
||||
reopened across processes."""
|
||||
table = mem_db.create_table("test_table", pa.table({"a": range(50)}))
|
||||
permutation = Permutation.identity(table)
|
||||
|
||||
restored = pickle.loads(pickle.dumps(permutation))
|
||||
|
||||
assert len(restored) == 50
|
||||
assert restored.__getitems__([0, 10, 49]) == [{"a": 0}, {"a": 10}, {"a": 49}]
|
||||
|
||||
|
||||
def test_permutation_dataloader_multiprocessing(tmp_db):
|
||||
"""Using a Permutation with a PyTorch DataLoader that has num_workers > 0
|
||||
must work end-to-end. Each worker process gets a pickled copy of the
|
||||
dataset and reads batches from it."""
|
||||
table = tmp_db.create_table("test_table", pa.table({"a": range(1000)}))
|
||||
permutation = Permutation.identity(table)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
permutation,
|
||||
batch_size=10,
|
||||
shuffle=True,
|
||||
num_workers=2,
|
||||
multiprocessing_context="spawn",
|
||||
)
|
||||
seen = 0
|
||||
for batch in dataloader:
|
||||
assert batch["a"].size(0) == 10
|
||||
seen += batch["a"].size(0)
|
||||
assert seen == 1000
|
||||
|
||||
|
||||
def test_permutation_pickle_with_connection_factory(tmp_path):
|
||||
"""When the user provides a connection_factory, pickling should round-trip
|
||||
through that factory rather than introspecting the connection URI. Useful
|
||||
for remote / cloud connections where the URI alone isn't reopenable."""
|
||||
db = lancedb.connect(tmp_path)
|
||||
db.create_table("test_table", pa.table({"a": range(50)}))
|
||||
|
||||
factory = functools.partial(_open_native_table, str(tmp_path))
|
||||
permutation = Permutation.identity(factory("test_table")).with_connection_factory(
|
||||
factory
|
||||
)
|
||||
|
||||
restored = pickle.loads(pickle.dumps(permutation))
|
||||
|
||||
assert len(restored) == 50
|
||||
# The factory survives pickling and is what powered base-table reopen.
|
||||
assert restored.connection_factory is not None
|
||||
assert restored.connection_factory.func is _open_native_table
|
||||
assert restored.__getitems__([0, 1, 2]) == [{"a": 0}, {"a": 1}, {"a": 2}]
|
||||
|
||||
|
||||
def test_permutation_with_builder_is_picklable(tmp_db):
|
||||
"""A Permutation built from a non-identity permutation table must round-trip
|
||||
through pickle while preserving the row order defined by the permutation."""
|
||||
table = tmp_db.create_table("test_table", pa.table({"a": range(100)}))
|
||||
perm_tbl = (
|
||||
permutation_builder(table)
|
||||
.split_random(ratios=[0.8, 0.2], seed=42, split_names=["train", "test"])
|
||||
.shuffle(seed=42)
|
||||
.execute()
|
||||
)
|
||||
permutations = Permutations(table, perm_tbl)
|
||||
permutation = permutations["train"]
|
||||
|
||||
indices = list(range(len(permutation)))
|
||||
expected = permutation.__getitems__(indices)
|
||||
|
||||
restored = pickle.loads(pickle.dumps(permutation))
|
||||
|
||||
assert len(restored) == len(permutation)
|
||||
assert restored.__getitems__(indices) == expected
|
||||
|
||||
|
||||
def _multiworker_dataloader_target(db_uri: str, result_queue):
|
||||
import lancedb
|
||||
from lancedb.permutation import Permutation
|
||||
|
||||
db = lancedb.connect(db_uri)
|
||||
table = db.open_table("test_table")
|
||||
permutation = Permutation.identity(table)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
permutation,
|
||||
batch_size=10,
|
||||
num_workers=2,
|
||||
multiprocessing_context="fork",
|
||||
)
|
||||
count = 0
|
||||
for batch in dataloader:
|
||||
assert batch["a"].size(0) == 10
|
||||
count += 1
|
||||
result_queue.put(count)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform != "linux",
|
||||
reason=(
|
||||
"fork() is unavailable on Windows and unsafe on macOS "
|
||||
"(Apple frameworks/TLS are not fork-safe)"
|
||||
),
|
||||
)
|
||||
def test_permutation_dataloader_fork_workers(tmp_path):
|
||||
"""A Permutation used by a fork-based DataLoader should not hang.
|
||||
|
||||
PyTorch's DataLoader uses fork-based multiprocessing by default on Linux.
|
||||
LanceDB drives async work through a background asyncio thread that does
|
||||
not survive a fork, so any LOOP.run() in a worker blocks forever.
|
||||
"""
|
||||
import lancedb
|
||||
|
||||
db_uri = str(tmp_path / "db")
|
||||
db = lancedb.connect(db_uri)
|
||||
db.create_table("test_table", pa.table({"a": list(range(1000))}))
|
||||
|
||||
ctx = mp.get_context("spawn")
|
||||
queue = ctx.Queue()
|
||||
proc = ctx.Process(target=_multiworker_dataloader_target, args=(db_uri, queue))
|
||||
proc.start()
|
||||
proc.join(timeout=30)
|
||||
|
||||
if proc.is_alive():
|
||||
proc.terminate()
|
||||
proc.join(timeout=5)
|
||||
if proc.is_alive():
|
||||
proc.kill()
|
||||
proc.join()
|
||||
pytest.fail("Permutation hung when iterated in a fork-based DataLoader worker")
|
||||
|
||||
assert proc.exitcode == 0, f"child exited with code {proc.exitcode}"
|
||||
assert not queue.empty(), "child produced no batches"
|
||||
assert queue.get() == 100
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::error::PythonErrorExt;
|
||||
use crate::runtime::future_into_py;
|
||||
use arrow::{
|
||||
datatypes::SchemaRef,
|
||||
pyarrow::{IntoPyArrow, ToPyArrow},
|
||||
@@ -12,9 +14,6 @@ use lancedb::arrow::SendableRecordBatchStream;
|
||||
use pyo3::{
|
||||
Bound, Py, PyAny, PyRef, PyResult, Python, exceptions::PyStopAsyncIteration, pyclass, pymethods,
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
use crate::error::PythonErrorExt;
|
||||
|
||||
#[pyclass]
|
||||
pub struct RecordBatchStream {
|
||||
|
||||
@@ -7,6 +7,12 @@ use std::{
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
error::PythonErrorExt,
|
||||
namespace::{create_namespace_storage_options_provider, extract_namespace_arc},
|
||||
runtime::future_into_py,
|
||||
table::Table,
|
||||
};
|
||||
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
|
||||
use lancedb::{
|
||||
connection::Connection as LanceConnection,
|
||||
@@ -20,13 +26,6 @@ use pyo3::{
|
||||
pyclass, pyfunction, pymethods,
|
||||
types::{PyDict, PyDictMethods},
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
use crate::{
|
||||
error::PythonErrorExt,
|
||||
namespace::{create_namespace_storage_options_provider, extract_namespace_arc},
|
||||
table::Table,
|
||||
};
|
||||
|
||||
#[pyclass]
|
||||
pub struct Connection {
|
||||
@@ -525,7 +524,7 @@ impl Connection {
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None))]
|
||||
#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn connect(
|
||||
py: Python<'_>,
|
||||
@@ -537,6 +536,8 @@ pub fn connect(
|
||||
client_config: Option<PyClientConfig>,
|
||||
storage_options: Option<HashMap<String, String>>,
|
||||
session: Option<crate::session::Session>,
|
||||
manifest_enabled: bool,
|
||||
namespace_client_properties: Option<HashMap<String, String>>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
future_into_py(py, async move {
|
||||
let mut builder = lancedb::connect(&uri);
|
||||
@@ -556,6 +557,12 @@ pub fn connect(
|
||||
if let Some(storage_options) = storage_options {
|
||||
builder = builder.storage_options(storage_options);
|
||||
}
|
||||
if manifest_enabled {
|
||||
builder = builder.manifest_enabled(true);
|
||||
}
|
||||
if let Some(namespace_client_properties) = namespace_client_properties {
|
||||
builder = builder.namespace_client_properties(namespace_client_properties);
|
||||
}
|
||||
#[cfg(feature = "remote")]
|
||||
if let Some(client_config) = client_config {
|
||||
builder = builder.client_config(client_config.into());
|
||||
|
||||
@@ -17,7 +17,7 @@ use pyo3::{Bound, PyAny, PyResult, exceptions::PyValueError, prelude::*, pyfunct
|
||||
/// [`expr_lit`] and combined with the methods on this struct. On the Python
|
||||
/// side a thin wrapper class (`lancedb.expr.Expr`) delegates to these methods
|
||||
/// and adds Python operator overloads.
|
||||
#[pyclass(name = "PyExpr")]
|
||||
#[pyclass(name = "PyExpr", from_py_object)]
|
||||
#[derive(Clone)]
|
||||
pub struct PyExpr(pub DfExpr);
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ impl PyHeaderProvider {
|
||||
Ok(headers_py) => {
|
||||
// Convert Python dict to Rust HashMap
|
||||
let bound_headers = headers_py.bind(py);
|
||||
let dict: &Bound<PyDict> = bound_headers.downcast().map_err(|e| {
|
||||
let dict: &Bound<PyDict> = bound_headers.cast().map_err(|e| {
|
||||
format!("HeaderProvider.get_headers must return a dict: {}", e)
|
||||
})?;
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ use pyo3::{
|
||||
Bound, FromPyObject, PyAny, PyResult, Python,
|
||||
exceptions::{PyKeyError, PyValueError},
|
||||
intern, pyclass, pymethods,
|
||||
types::PyAnyMethods,
|
||||
types::{PyAnyMethods, PyString},
|
||||
};
|
||||
|
||||
use crate::util::parse_distance_type;
|
||||
@@ -22,7 +22,7 @@ pub fn class_name(ob: &'_ Bound<'_, PyAny>) -> PyResult<String> {
|
||||
let full_name = ob
|
||||
.getattr(intern!(ob.py(), "__class__"))?
|
||||
.getattr(intern!(ob.py(), "__name__"))?;
|
||||
let full_name = full_name.downcast()?.to_string_lossy();
|
||||
let full_name = full_name.cast::<PyString>()?.to_string_lossy();
|
||||
|
||||
match full_name.rsplit_once('.') {
|
||||
Some((_, name)) => Ok(name.to_string()),
|
||||
|
||||
@@ -28,6 +28,7 @@ pub mod index;
|
||||
pub mod namespace;
|
||||
pub mod permutation;
|
||||
pub mod query;
|
||||
pub mod runtime;
|
||||
pub mod session;
|
||||
pub mod table;
|
||||
pub mod util;
|
||||
|
||||
@@ -183,7 +183,7 @@ async fn call_py_method_primitive<Req, Resp>(
|
||||
) -> lance_core::Result<Resp>
|
||||
where
|
||||
Req: serde::Serialize + Send + 'static,
|
||||
Resp: for<'py> pyo3::FromPyObject<'py> + Send + 'static,
|
||||
Resp: for<'a, 'py> pyo3::FromPyObject<'a, 'py> + Send + 'static,
|
||||
{
|
||||
let request_json = serde_json::to_string(&request).map_err(|e| {
|
||||
lance_core::Error::io(format!(
|
||||
@@ -203,7 +203,7 @@ where
|
||||
|
||||
// Call the Python method
|
||||
let result = py_namespace.call_method1(py, method_name, (request_arg,))?;
|
||||
let value: Resp = result.extract(py)?;
|
||||
let value: Resp = result.extract(py).map_err(Into::into)?;
|
||||
Ok::<_, PyErr>(value)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::{
|
||||
arrow::RecordBatchStream, connection::Connection, error::PythonErrorExt, table::Table,
|
||||
arrow::RecordBatchStream, error::PythonErrorExt, runtime::future_into_py, table::Table,
|
||||
};
|
||||
use arrow::pyarrow::{PyArrowType, ToPyArrow};
|
||||
use lancedb::{
|
||||
@@ -21,16 +21,15 @@ use pyo3::{
|
||||
pyclass, pymethods,
|
||||
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
fn table_from_py<'a>(table: Bound<'a, PyAny>) -> PyResult<Bound<'a, Table>> {
|
||||
if table.hasattr("_inner")? {
|
||||
Ok(table.getattr("_inner")?.downcast_into::<Table>()?)
|
||||
Ok(table.getattr("_inner")?.cast_into::<Table>()?)
|
||||
} else if table.hasattr("_table")? {
|
||||
Ok(table
|
||||
.getattr("_table")?
|
||||
.getattr("_inner")?
|
||||
.downcast_into::<Table>()?)
|
||||
.cast_into::<Table>()?)
|
||||
} else {
|
||||
Err(PyRuntimeError::new_err(
|
||||
"Provided table does not appear to be a Table or RemoteTable instance",
|
||||
@@ -80,24 +79,6 @@ impl PyAsyncPermutationBuilder {
|
||||
|
||||
#[pymethods]
|
||||
impl PyAsyncPermutationBuilder {
|
||||
#[pyo3(signature = (database, table_name))]
|
||||
pub fn persist(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
database: Bound<'_, PyAny>,
|
||||
table_name: String,
|
||||
) -> PyResult<Self> {
|
||||
let conn = if database.hasattr("_conn")? {
|
||||
database
|
||||
.getattr("_conn")?
|
||||
.getattr("_inner")?
|
||||
.downcast_into::<Connection>()?
|
||||
} else {
|
||||
database.getattr("_inner")?.downcast_into::<Connection>()?
|
||||
};
|
||||
let database = conn.borrow().database()?;
|
||||
slf.modify(|builder| builder.persist(database, table_name))
|
||||
}
|
||||
|
||||
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None, split_names=None))]
|
||||
pub fn split_random(
|
||||
slf: PyRefMut<'_, Self>,
|
||||
@@ -243,7 +224,7 @@ impl PyPermutationReader {
|
||||
let Some(selection) = selection else {
|
||||
return Ok(Select::All);
|
||||
};
|
||||
let selection = selection.downcast_into::<PyDict>()?;
|
||||
let selection = selection.cast_into::<PyDict>()?;
|
||||
let selection = selection
|
||||
.iter()
|
||||
.map(|(key, value)| {
|
||||
|
||||
@@ -4,6 +4,11 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::expr::PyExpr;
|
||||
use crate::runtime::future_into_py;
|
||||
use crate::util::parse_distance_type;
|
||||
use crate::{arrow::RecordBatchStream, util::PyLanceDB};
|
||||
use crate::{error::PythonErrorExt, index::class_name};
|
||||
use arrow::array::Array;
|
||||
use arrow::array::ArrayData;
|
||||
use arrow::array::make_array;
|
||||
@@ -33,19 +38,16 @@ use pyo3::pyfunction;
|
||||
use pyo3::pymethods;
|
||||
use pyo3::types::PyList;
|
||||
use pyo3::types::{PyDict, PyString};
|
||||
use pyo3::{FromPyObject, exceptions::PyRuntimeError};
|
||||
use pyo3::{Borrowed, FromPyObject, exceptions::PyRuntimeError};
|
||||
use pyo3::{PyErr, pyclass};
|
||||
use pyo3::{exceptions::PyValueError, intern};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
use crate::expr::PyExpr;
|
||||
use crate::util::parse_distance_type;
|
||||
use crate::{arrow::RecordBatchStream, util::PyLanceDB};
|
||||
use crate::{error::PythonErrorExt, index::class_name};
|
||||
impl<'a, 'py> FromPyObject<'a, 'py> for PyLanceDB<FtsQuery> {
|
||||
type Error = PyErr;
|
||||
|
||||
impl FromPyObject<'_> for PyLanceDB<FtsQuery> {
|
||||
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
|
||||
match class_name(ob)?.as_str() {
|
||||
fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
|
||||
let ob = ob.to_owned();
|
||||
match class_name(&ob)?.as_str() {
|
||||
"MatchQuery" => {
|
||||
let query = ob.getattr("query")?.extract()?;
|
||||
let column = ob.getattr("column")?.extract()?;
|
||||
@@ -424,7 +426,7 @@ impl Query {
|
||||
"Query text is required for nearest_to_text",
|
||||
))?;
|
||||
|
||||
let query = if let Ok(query_text) = fts_query.downcast::<PyString>() {
|
||||
let query = if let Ok(query_text) = fts_query.cast::<PyString>() {
|
||||
let mut query_text = query_text.to_string();
|
||||
let columns = query
|
||||
.get_item("columns")?
|
||||
@@ -606,7 +608,7 @@ impl TakeQuery {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[pyclass(from_py_object)]
|
||||
#[derive(Clone)]
|
||||
pub struct FTSQuery {
|
||||
inner: LanceDbQuery,
|
||||
@@ -735,7 +737,7 @@ impl FTSQuery {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[pyclass(from_py_object)]
|
||||
#[derive(Clone)]
|
||||
pub struct VectorQuery {
|
||||
inner: LanceDbVectorQuery,
|
||||
|
||||
142
python/src/runtime.rs
Normal file
142
python/src/runtime.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Fork-safe wrapper around tokio + pyo3-async-runtimes.
|
||||
//!
|
||||
//! `pyo3_async_runtimes::tokio` keeps its multi-threaded runtime in a
|
||||
//! `OnceLock` that can never be replaced. Tokio's worker threads do not
|
||||
//! survive `fork()`, so once a child inherits a "frozen" runtime, every
|
||||
//! `future_into_py` call hangs forever.
|
||||
//!
|
||||
//! We sidestep the global by routing every future through our own
|
||||
//! [`LanceRuntime`] (a [`pyo3_async_runtimes::generic::Runtime`] impl) backed
|
||||
//! by an [`AtomicPtr`] to a tokio runtime that we own. A `pthread_atfork`
|
||||
//! child handler nulls the pointer; the next `spawn` rebuilds the runtime in
|
||||
//! the child. This mirrors the pattern used in the Lance Python bindings.
|
||||
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
|
||||
|
||||
use pyo3::{Bound, PyAny, PyResult, Python, conversion::IntoPyObject};
|
||||
use pyo3_async_runtimes::{
|
||||
TaskLocals,
|
||||
generic::{ContextExt, JoinError, Runtime},
|
||||
};
|
||||
use tokio::{runtime, task};
|
||||
|
||||
static RUNTIME: AtomicPtr<runtime::Runtime> = AtomicPtr::new(std::ptr::null_mut());
|
||||
static RUNTIME_INSTALLING: AtomicBool = AtomicBool::new(false);
|
||||
static ATFORK_INSTALLED: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
fn create_runtime() -> runtime::Runtime {
|
||||
runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.thread_name("lancedb-tokio-worker")
|
||||
.build()
|
||||
.expect("Failed to build tokio runtime")
|
||||
}
|
||||
|
||||
fn get_runtime() -> &'static runtime::Runtime {
|
||||
loop {
|
||||
let ptr = RUNTIME.load(Ordering::SeqCst);
|
||||
if !ptr.is_null() {
|
||||
return unsafe { &*ptr };
|
||||
}
|
||||
if !RUNTIME_INSTALLING.fetch_or(true, Ordering::SeqCst) {
|
||||
break;
|
||||
}
|
||||
std::thread::yield_now();
|
||||
}
|
||||
if !ATFORK_INSTALLED.fetch_or(true, Ordering::SeqCst) {
|
||||
install_atfork();
|
||||
}
|
||||
let new_ptr = Box::into_raw(Box::new(create_runtime()));
|
||||
RUNTIME.store(new_ptr, Ordering::SeqCst);
|
||||
unsafe { &*new_ptr }
|
||||
}
|
||||
|
||||
/// Runs in async-signal context after `fork()` in the child. We can only
|
||||
/// touch atomics here; we deliberately leak the previous runtime because
|
||||
/// dropping a tokio `Runtime` would try to join its (now-dead) worker
|
||||
/// threads and hang.
|
||||
extern "C" fn atfork_child() {
|
||||
RUNTIME.store(std::ptr::null_mut(), Ordering::SeqCst);
|
||||
RUNTIME_INSTALLING.store(false, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
fn install_atfork() {
|
||||
unsafe { libc::pthread_atfork(None, None, Some(atfork_child)) };
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
fn install_atfork() {}
|
||||
|
||||
/// Marker type implementing [`Runtime`] over our fork-safe runtime slot.
|
||||
pub struct LanceRuntime;
|
||||
|
||||
/// Newtype wrapper around `tokio::task::JoinError` so we can implement the
|
||||
/// foreign [`JoinError`] trait without violating orphan rules.
|
||||
pub struct LanceJoinError(task::JoinError);
|
||||
|
||||
impl JoinError for LanceJoinError {
|
||||
fn is_panic(&self) -> bool {
|
||||
self.0.is_panic()
|
||||
}
|
||||
fn into_panic(self) -> Box<dyn std::any::Any + Send + 'static> {
|
||||
self.0.into_panic()
|
||||
}
|
||||
}
|
||||
|
||||
impl Runtime for LanceRuntime {
|
||||
type JoinError = LanceJoinError;
|
||||
type JoinHandle = Pin<Box<dyn Future<Output = Result<(), Self::JoinError>> + Send>>;
|
||||
|
||||
fn spawn<F>(fut: F) -> Self::JoinHandle
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
let handle = get_runtime().spawn(fut);
|
||||
Box::pin(async move { handle.await.map_err(LanceJoinError) })
|
||||
}
|
||||
|
||||
fn spawn_blocking<F>(f: F) -> Self::JoinHandle
|
||||
where
|
||||
F: FnOnce() + Send + 'static,
|
||||
{
|
||||
let handle = get_runtime().spawn_blocking(f);
|
||||
Box::pin(async move { handle.await.map_err(LanceJoinError) })
|
||||
}
|
||||
}
|
||||
|
||||
tokio::task_local! {
|
||||
static TASK_LOCALS: std::cell::OnceCell<TaskLocals>;
|
||||
}
|
||||
|
||||
impl ContextExt for LanceRuntime {
|
||||
fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
|
||||
where
|
||||
F: Future<Output = R> + Send + 'static,
|
||||
{
|
||||
let cell = std::cell::OnceCell::new();
|
||||
cell.set(locals).unwrap();
|
||||
Box::pin(TASK_LOCALS.scope(cell, fut))
|
||||
}
|
||||
|
||||
fn get_task_locals() -> Option<TaskLocals> {
|
||||
TASK_LOCALS
|
||||
.try_with(|c| c.get().cloned())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Drop-in replacement for `pyo3_async_runtimes::tokio::future_into_py` that
|
||||
/// uses our fork-safe runtime.
|
||||
pub fn future_into_py<F, T>(py: Python<'_>, fut: F) -> PyResult<Bound<'_, PyAny>>
|
||||
where
|
||||
F: Future<Output = PyResult<T>> + Send + 'static,
|
||||
T: for<'py> IntoPyObject<'py> + Send + 'static,
|
||||
{
|
||||
pyo3_async_runtimes::generic::future_into_py::<LanceRuntime, _, T>(py, fut)
|
||||
}
|
||||
@@ -11,7 +11,7 @@ use pyo3::{PyResult, pyclass, pymethods};
|
||||
/// Sessions allow you to configure cache sizes for index and metadata caches,
|
||||
/// which can significantly impact memory use and performance. They can
|
||||
/// also be re-used across multiple connections to share the same cache state.
|
||||
#[pyclass]
|
||||
#[pyclass(from_py_object)]
|
||||
#[derive(Clone)]
|
||||
pub struct Session {
|
||||
pub(crate) inner: Arc<LanceSession>,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use crate::runtime::future_into_py;
|
||||
use crate::{
|
||||
connection::Connection,
|
||||
error::PythonErrorExt,
|
||||
@@ -24,12 +25,11 @@ use pyo3::{
|
||||
pyclass, pymethods,
|
||||
types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods},
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
mod scannable;
|
||||
|
||||
/// Statistics about a compaction operation.
|
||||
#[pyclass(get_all)]
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CompactionStats {
|
||||
/// The number of fragments removed
|
||||
@@ -43,7 +43,7 @@ pub struct CompactionStats {
|
||||
}
|
||||
|
||||
/// Statistics about a cleanup operation
|
||||
#[pyclass(get_all)]
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RemovalStats {
|
||||
/// The number of bytes removed
|
||||
@@ -53,7 +53,7 @@ pub struct RemovalStats {
|
||||
}
|
||||
|
||||
/// Statistics about an optimize operation
|
||||
#[pyclass(get_all)]
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct OptimizeStats {
|
||||
/// Statistics about the compaction operation
|
||||
@@ -62,7 +62,7 @@ pub struct OptimizeStats {
|
||||
pub prune: RemovalStats,
|
||||
}
|
||||
|
||||
#[pyclass(get_all)]
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct UpdateResult {
|
||||
pub rows_updated: u64,
|
||||
@@ -88,7 +88,7 @@ impl From<lancedb::table::UpdateResult> for UpdateResult {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all)]
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AddResult {
|
||||
pub version: u64,
|
||||
@@ -109,7 +109,7 @@ impl From<lancedb::table::AddResult> for AddResult {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all)]
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DeleteResult {
|
||||
pub num_deleted_rows: u64,
|
||||
@@ -135,7 +135,7 @@ impl From<lancedb::table::DeleteResult> for DeleteResult {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all)]
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MergeResult {
|
||||
pub version: u64,
|
||||
@@ -171,7 +171,7 @@ impl From<lancedb::table::MergeResult> for MergeResult {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all)]
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AddColumnsResult {
|
||||
pub version: u64,
|
||||
@@ -192,7 +192,7 @@ impl From<lancedb::table::AddColumnsResult> for AddColumnsResult {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all)]
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AlterColumnsResult {
|
||||
pub version: u64,
|
||||
@@ -213,7 +213,7 @@ impl From<lancedb::table::AlterColumnsResult> for AlterColumnsResult {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all)]
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DropColumnsResult {
|
||||
pub version: u64,
|
||||
|
||||
@@ -126,8 +126,11 @@ impl Scannable for PyScannable {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'py> FromPyObject<'py> for PyScannable {
|
||||
fn extract_bound(ob: &pyo3::Bound<'py, PyAny>) -> pyo3::PyResult<Self> {
|
||||
impl<'a, 'py> FromPyObject<'a, 'py> for PyScannable {
|
||||
type Error = pyo3::PyErr;
|
||||
|
||||
fn extract(ob: pyo3::Borrowed<'a, 'py, PyAny>) -> pyo3::PyResult<Self> {
|
||||
let ob = ob.to_owned();
|
||||
// Convert from Scannable dataclass.
|
||||
let schema: PyArrowType<Schema> = ob.getattr("schema")?.extract()?;
|
||||
let schema = Arc::new(schema.0);
|
||||
|
||||
Reference in New Issue
Block a user