mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-23 06:50:40 +00:00
Compare commits
2 Commits
v0.28.0-be
...
ticket/324
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ff08a996fc | ||
|
|
049a689a1c |
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.28.0-beta.11"
|
||||
current_version = "0.28.0-beta.10"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
406
Cargo.lock
generated
406
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
62
Cargo.toml
62
Cargo.toml
@@ -13,40 +13,40 @@ categories = ["database-implementations"]
|
||||
rust-version = "1.91.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=6.0.0-beta.7", default-features = false, "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-core = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datagen = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-file = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-io = { "version" = "=6.0.0-beta.7", default-features = false, "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-index = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-linalg = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=6.0.0-beta.7", default-features = false, "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-table = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-testing = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datafusion = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-encoding = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-arrow = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance = { "version" = "=6.0.0-beta.4", default-features = false, "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-core = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datagen = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-file = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-io = { "version" = "=6.0.0-beta.4", default-features = false, "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-index = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-linalg = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=6.0.0-beta.4", default-features = false, "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-table = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-testing = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datafusion = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-encoding = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-arrow = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
ahash = "0.8"
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "58.0.0", optional = false }
|
||||
arrow-array = "58.0.0"
|
||||
arrow-data = "58.0.0"
|
||||
arrow-ipc = "58.0.0"
|
||||
arrow-ord = "58.0.0"
|
||||
arrow-schema = "58.0.0"
|
||||
arrow-select = "58.0.0"
|
||||
arrow-cast = "58.0.0"
|
||||
arrow = { version = "57.2", optional = false }
|
||||
arrow-array = "57.2"
|
||||
arrow-data = "57.2"
|
||||
arrow-ipc = "57.2"
|
||||
arrow-ord = "57.2"
|
||||
arrow-schema = "57.2"
|
||||
arrow-select = "57.2"
|
||||
arrow-cast = "57.2"
|
||||
async-trait = "0"
|
||||
datafusion = { version = "53.0.0", default-features = false }
|
||||
datafusion-catalog = "53.0.0"
|
||||
datafusion-common = { version = "53.0.0", default-features = false }
|
||||
datafusion-execution = "53.0.0"
|
||||
datafusion-expr = "53.0.0"
|
||||
datafusion-functions = "53.0.0"
|
||||
datafusion-physical-plan = "53.0.0"
|
||||
datafusion-physical-expr = "53.0.0"
|
||||
datafusion-sql = "53.0.0"
|
||||
datafusion = { version = "52.1", default-features = false }
|
||||
datafusion-catalog = "52.1"
|
||||
datafusion-common = { version = "52.1", default-features = false }
|
||||
datafusion-execution = "52.1"
|
||||
datafusion-expr = "52.1"
|
||||
datafusion-functions = "52.1"
|
||||
datafusion-physical-plan = "52.1"
|
||||
datafusion-physical-expr = "52.1"
|
||||
datafusion-sql = "52.1"
|
||||
env_logger = "0.11"
|
||||
half = { "version" = "2.7.1", default-features = false, features = [
|
||||
"num-traits",
|
||||
|
||||
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
|
||||
<dependency>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-core</artifactId>
|
||||
<version>0.28.0-beta.11</version>
|
||||
<version>0.28.0-beta.10</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.28.0-beta.11</version>
|
||||
<version>0.28.0-beta.10</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.28.0-beta.11</version>
|
||||
<version>0.28.0-beta.10</version>
|
||||
<packaging>pom</packaging>
|
||||
<name>${project.artifactId}</name>
|
||||
<description>LanceDB Java SDK Parent POM</description>
|
||||
@@ -28,7 +28,7 @@
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<arrow.version>15.0.0</arrow.version>
|
||||
<lance-core.version>6.0.0-beta.7</lance-core.version>
|
||||
<lance-core.version>6.0.0-beta.4</lance-core.version>
|
||||
<spotless.skip>false</spotless.skip>
|
||||
<spotless.version>2.30.0</spotless.version>
|
||||
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.28.0-beta.11"
|
||||
version = "0.28.0-beta.10"
|
||||
publish = false
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
@@ -16,7 +16,7 @@ crate-type = ["cdylib"]
|
||||
async-trait.workspace = true
|
||||
arrow-ipc.workspace = true
|
||||
arrow-array.workspace = true
|
||||
arrow-buffer = "58.0.0"
|
||||
arrow-buffer = "57.2"
|
||||
half.workspace = true
|
||||
arrow-schema.workspace = true
|
||||
env_logger.workspace = true
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.28.0-beta.11",
|
||||
"version": "0.28.0-beta.10",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.28.0-beta.11",
|
||||
"version": "0.28.0-beta.10",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||
"version": "0.28.0-beta.11",
|
||||
"version": "0.28.0-beta.10",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.28.0-beta.11",
|
||||
"version": "0.28.0-beta.10",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||
"version": "0.28.0-beta.11",
|
||||
"version": "0.28.0-beta.10",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||
"version": "0.28.0-beta.11",
|
||||
"version": "0.28.0-beta.10",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.28.0-beta.11",
|
||||
"version": "0.28.0-beta.10",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"ann"
|
||||
],
|
||||
"private": false,
|
||||
"version": "0.28.0-beta.11",
|
||||
"version": "0.28.0-beta.10",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.31.0-beta.11"
|
||||
current_version = "0.31.0-beta.10"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.31.0-beta.11"
|
||||
version = "0.31.0-beta.10"
|
||||
publish = false
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
@@ -15,7 +15,7 @@ name = "_lancedb"
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
arrow = { version = "58.0.0", features = ["pyarrow"] }
|
||||
arrow = { version = "57.2", features = ["pyarrow"] }
|
||||
async-trait = "0.1"
|
||||
bytes = "1"
|
||||
lancedb = { path = "../rust/lancedb", default-features = false }
|
||||
@@ -25,8 +25,8 @@ lance-namespace-impls.workspace = true
|
||||
lance-io.workspace = true
|
||||
env_logger.workspace = true
|
||||
log.workspace = true
|
||||
pyo3 = { version = "0.28", features = ["extension-module", "abi3-py39"] }
|
||||
pyo3-async-runtimes = { version = "0.28", features = [
|
||||
pyo3 = { version = "0.26", features = ["extension-module", "abi3-py39"] }
|
||||
pyo3-async-runtimes = { version = "0.26", features = [
|
||||
"attributes",
|
||||
"tokio-runtime",
|
||||
] }
|
||||
@@ -38,7 +38,7 @@ snafu.workspace = true
|
||||
tokio = { version = "1.40", features = ["sync"] }
|
||||
|
||||
[build-dependencies]
|
||||
pyo3-build-config = { version = "0.28", features = [
|
||||
pyo3-build-config = { version = "0.26", features = [
|
||||
"extension-module",
|
||||
"abi3-py39",
|
||||
] }
|
||||
|
||||
230
python/python/lancedb/integrations/torch.py
Normal file
230
python/python/lancedb/integrations/torch.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
"""
|
||||
PyTorch integration for LanceDB.
|
||||
|
||||
Exposes ``LanceTorchDataset`` (map-style) and ``LanceIterableTorchDataset``
|
||||
(iterable-style) wrappers that adapt a LanceDB table or permutation to the
|
||||
PyTorch ``torch.utils.data`` API, while transparently handling the bits
|
||||
that make a hand-rolled subclass tricky:
|
||||
|
||||
* The underlying Lance reader holds Rust state that is not picklable, but
|
||||
``DataLoader(num_workers > 0)`` needs to fork the dataset to its workers.
|
||||
These classes strip the reader on pickle and re-open it in the worker on
|
||||
first read.
|
||||
* Constructing a permutation from a table involves several steps
|
||||
(``permutation_builder``/``Permutation.from_tables``/``select_columns``
|
||||
/``with_format``/...). The wrapper takes those as constructor arguments
|
||||
and applies them once the dataset is opened in the worker.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> import lancedb, torch # doctest: +SKIP
|
||||
>>> from lancedb.integrations.torch import LanceTorchDataset
|
||||
>>> db = lancedb.connect(uri) # doctest: +SKIP
|
||||
>>> tbl = db.open_table("images_224") # doctest: +SKIP
|
||||
>>> ds = LanceTorchDataset( # doctest: +SKIP
|
||||
... tbl, columns=["image_bytes", "label"], format="torch"
|
||||
... )
|
||||
>>> loader = torch.utils.data.DataLoader( # doctest: +SKIP
|
||||
... ds, batch_size=64, num_workers=4, shuffle=True,
|
||||
... )
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch.utils.data as _torch_data
|
||||
|
||||
from ..permutation import Permutation
|
||||
from ..table import LanceTable
|
||||
|
||||
|
||||
def _capture_table_state(table: LanceTable) -> Dict[str, Any]:
|
||||
"""Pull just enough state out of a LanceTable so we can re-open the same
|
||||
table in a forked worker process where the Rust handle isn't valid."""
|
||||
conn = table._conn
|
||||
connect_kwargs: Dict[str, Any] = {}
|
||||
storage_options = getattr(conn, "storage_options", None)
|
||||
if storage_options is not None:
|
||||
connect_kwargs["storage_options"] = storage_options
|
||||
return {
|
||||
"uri": conn.uri,
|
||||
"table_name": table.name,
|
||||
"connect_kwargs": connect_kwargs,
|
||||
}
|
||||
|
||||
|
||||
def _open_permutation(state: Dict[str, Any]) -> Permutation:
    """Build a ``Permutation`` from a state dict.

    Connects to ``state["uri"]`` with ``state["connect_kwargs"]`` and opens
    ``state["table_name"]`` as the base table. When ``perm_table_name`` is
    set, that table is opened from the same database and combined with the
    base table (and the optional ``split``) via ``Permutation.from_tables``;
    otherwise the identity permutation over the base table is used.

    Afterwards the optional ``columns``, ``format``, ``transform`` and
    ``batch_size`` entries are applied, in that order, skipping any that are
    missing or ``None``.
    """
    import lancedb

    database = lancedb.connect(state["uri"], **state["connect_kwargs"])
    base_table = database.open_table(state["table_name"])

    ordering_name = state.get("perm_table_name")
    if ordering_name is None:
        permutation = Permutation.identity(base_table)
    else:
        permutation = Permutation.from_tables(
            base_table, database.open_table(ordering_name), state.get("split")
        )

    # (state key, Permutation method) pairs, applied in this fixed order.
    optional_steps = (
        ("columns", "select_columns"),
        ("format", "with_format"),
        ("transform", "with_transform"),
        ("batch_size", "with_batch_size"),
    )
    for key, method_name in optional_steps:
        setting = state.get(key)
        if setting is not None:
            permutation = getattr(permutation, method_name)(setting)
    return permutation
|
||||
|
||||
|
||||
class LanceTorchDataset(_torch_data.Dataset):
    """
    Map-style ``torch.utils.data.Dataset`` that reads from a LanceDB table.

    Rather than holding on to an open reader, the dataset keeps a small
    rebuild recipe in ``self._state``: database URI, table name, connect
    options and the permutation settings listed below. The ``Permutation``
    itself is created lazily on the first ``len()`` or item access and is
    dropped whenever the object is pickled, so an unpickled copy re-opens
    the table on its own first read.

    Batched index access goes through ``__getitems__``, which forwards the
    whole index list to ``Permutation.fetch``.

    Parameters
    ----------
    table : LanceTable, optional
        Base table to read from. Either ``table`` or both ``uri`` and
        ``table_name`` must be given.
    uri : str, optional
        Database URI to reconnect to. Required when ``table`` is omitted.
    table_name : str, optional
        Name of the base table inside ``uri``.
    connect_kwargs : dict, optional
        Extra keyword arguments forwarded to ``lancedb.connect`` when the
        database is re-opened. When given together with ``table`` it
        replaces the options captured from the table's connection.
    permutation_table : LanceTable, optional
        A pre-built permutation table (see ``permutation_builder``) that
        defines the row ordering. Only its name is stored. Without it the
        identity permutation is used.
    split : str or int, optional
        Split selector for when ``permutation_table`` defines splits.
    columns : list[str], optional
        Subset of columns to read.
    format : str, optional
        Output format, forwarded to ``Permutation.with_format`` (e.g.
        ``"torch"``).
    transform : Callable, optional
        Custom batch transform, forwarded to ``Permutation.with_transform``.
        It is stored in the pickled state, so it must be picklable for
        ``num_workers > 0``.
    batch_size : int, optional
        Forwarded to ``Permutation.with_batch_size``. A DataLoader does its
        own batching, so this only matters when iterating the permutation
        directly.
    """

    def __init__(
        self,
        table: Optional[LanceTable] = None,
        *,
        uri: Optional[str] = None,
        table_name: Optional[str] = None,
        connect_kwargs: Optional[Dict[str, Any]] = None,
        permutation_table: Optional[LanceTable] = None,
        split: Optional[Union[str, int]] = None,
        columns: Optional[List[str]] = None,
        format: Optional[str] = None,
        transform: Optional[Callable] = None,
        batch_size: Optional[int] = None,
    ):
        if table is not None:
            # Derive the reconnect info from the live table; an explicit
            # connect_kwargs argument overrides what was captured.
            recipe = _capture_table_state(table)
            if connect_kwargs is not None:
                recipe["connect_kwargs"] = connect_kwargs
        elif uri is None or table_name is None:
            raise ValueError(
                "Provide either `table` or both `uri` and `table_name`."
            )
        else:
            recipe = {
                "uri": uri,
                "table_name": table_name,
                "connect_kwargs": connect_kwargs or {},
            }

        recipe.update(
            perm_table_name=(
                None if permutation_table is None else permutation_table.name
            ),
            split=split,
            columns=columns,
            format=format,
            transform=transform,
            batch_size=batch_size,
        )

        self._state: Dict[str, Any] = recipe
        # Opened lazily by _ensure_open(); never pickled.
        self._perm: Optional[Permutation] = None

    def __getstate__(self) -> Dict[str, Any]:
        # Pickle everything except the live permutation; the receiving side
        # rebuilds it from self._state on first use.
        return {**self.__dict__, "_perm": None}

    def __setstate__(self, d: Dict[str, Any]) -> None:
        vars(self).update(d)

    def _ensure_open(self) -> None:
        """Create the permutation from the stored recipe if not yet open."""
        if self._perm is not None:
            return
        self._perm = _open_permutation(self._state)

    def __len__(self) -> int:
        self._ensure_open()
        permutation = self._perm
        return len(permutation)

    def __getitem__(self, index: int) -> Any:
        self._ensure_open()
        permutation = self._perm
        return permutation[index]

    def __getitems__(self, indices: List[int]) -> Any:
        # Batched access: hand the whole index list to Permutation.fetch.
        self._ensure_open()
        permutation = self._perm
        return permutation.fetch(indices)
|
||||
|
||||
|
||||
class LanceIterableTorchDataset(_torch_data.IterableDataset):
    """
    Iterable-style ``torch.utils.data.IterableDataset`` over a LanceDB
    permutation.

    Iterating the dataset iterates the underlying ``Permutation`` directly,
    so items come out in the permutation's own order. No per-worker sharding
    is done here: with ``num_workers > 1`` every worker walks the whole
    permutation independently. For sharded iteration use the map-style
    ``LanceTorchDataset`` together with a sampler.

    All constructor arguments are passed straight through to
    ``LanceTorchDataset``.
    """

    def __init__(self, *args, **kwargs):
        # Reuse the map-style dataset for state capture and lazy opening.
        self._inner = LanceTorchDataset(*args, **kwargs)

    def __getstate__(self) -> Dict[str, Any]:
        # Pickle only the inner dataset's state, which excludes the reader.
        return dict(_inner=self._inner.__getstate__())

    def __setstate__(self, d: Dict[str, Any]) -> None:
        # Rebuild the inner dataset without running its __init__.
        self._inner = restored = LanceTorchDataset.__new__(LanceTorchDataset)
        restored.__setstate__(d["_inner"])

    def __iter__(self):
        inner = self._inner
        inner._ensure_open()
        return iter(inner._perm)
|
||||
@@ -779,6 +779,25 @@ class Permutation:
|
||||
batch = LOOP.run(do_getitems())
|
||||
return self.transform_fn(batch)
|
||||
|
||||
def fetch(self, indices: list[int]) -> Any:
    """
    Return the rows at the given offsets of the permutation.

    Public entry point for batched access; it delegates to
    ``__getitems__``, so the result has whatever shape was configured via
    [with_format](#with_format) / [with_transform](#with_transform).

    Examples
    --------
    >>> import lancedb
    >>> db = lancedb.connect("memory:///")
    >>> tbl = db.create_table("tbl", data=[{"x": x} for x in range(10)])
    >>> perm = Permutation.identity(tbl)
    >>> perm.fetch([0, 5, 9])
    [{'x': 0}, {'x': 5}, {'x': 9}]
    """
    rows = self.__getitems__(indices)
    return rows
|
||||
|
||||
@deprecated(details="Use with_skip instead")
|
||||
def skip(self, skip: int) -> "Permutation":
|
||||
"""
|
||||
|
||||
@@ -1095,3 +1095,23 @@ def test_getitems_invalid_offset(some_permutation: Permutation):
|
||||
"""Test __getitems__ with an out-of-range offset raises an error."""
|
||||
with pytest.raises(Exception):
|
||||
some_permutation.__getitems__([999999])
|
||||
|
||||
|
||||
def test_fetch_matches_getitems(some_permutation: Permutation):
    """fetch() and __getitems__ must return equal results for the same offsets."""
    offsets = [0, 1, 2, 10, 100]
    via_fetch = some_permutation.fetch(offsets)
    via_dunder = some_permutation.__getitems__(offsets)
    assert via_fetch == via_dunder
|
||||
|
||||
|
||||
def test_fetch_respects_format(some_permutation: Permutation):
    """With the "arrow" format, fetch() returns a RecordBatch of the requested rows."""
    batch = some_permutation.with_format("arrow").fetch([0, 1, 2])
    assert isinstance(batch, pa.RecordBatch)
    assert batch.num_rows == 3
|
||||
|
||||
|
||||
def test_fetch_invalid_offset(some_permutation: Permutation):
    """An out-of-range offset passed to fetch() must raise."""
    out_of_range = [999999]
    with pytest.raises(Exception):
        some_permutation.fetch(out_of_range)
|
||||
|
||||
140
python/python/tests/test_torch_dataset.py
Normal file
140
python/python/tests/test_torch_dataset.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import pickle
|
||||
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb import connect
|
||||
from lancedb.permutation import permutation_builder
|
||||
|
||||
torch = pytest.importorskip("torch")
|
||||
from lancedb.integrations.torch import ( # noqa: E402
|
||||
LanceIterableTorchDataset,
|
||||
LanceTorchDataset,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def db_path(tmp_path):
    """Directory for an on-disk database.

    The tests re-open the table from its path (after unpickling and inside
    DataLoader workers), so pytest's per-test temporary directory is used.
    """
    database_dir = tmp_path
    return database_dir
|
||||
|
||||
|
||||
def _make_table(db_path, name="imgs", n=20):
    """Create table ``name`` under ``db_path`` holding ``n`` rows.

    Column ``x`` is the float sequence 0.0 .. n-1 and column ``y`` is the
    matching integer sequence 0 .. n-1. Returns the created table.
    """
    database = connect(db_path)
    ys = list(range(n))
    data = pa.table({"x": [float(y) for y in ys], "y": ys})
    return database.create_table(name, data)
|
||||
|
||||
|
||||
def test_basic_len_and_getitem(db_path):
    """len() reports the row count and indexing returns the first row."""
    dataset = LanceTorchDataset(_make_table(db_path))
    assert len(dataset) == 20
    first = dataset[0]
    # In the default format a single index comes back as a one-element
    # list of dicts.
    assert isinstance(first, list)
    assert first[0] == {"x": 0.0, "y": 0}
|
||||
|
||||
|
||||
def test_getitems_uses_fetch(db_path):
    """__getitems__ returns exactly the rows at the requested indices."""
    dataset = LanceTorchDataset(_make_table(db_path))
    wanted = (0, 2, 4)
    expected = [{"x": float(i), "y": i} for i in wanted]
    assert dataset.__getitems__(list(wanted)) == expected
|
||||
|
||||
|
||||
def test_dataloader_default_collate(db_path):
    """A plain DataLoader collates the rows into a dict of per-column tensors."""
    dataset = LanceTorchDataset(_make_table(db_path, n=40))
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False)
    first_batch = next(iter(loader))
    # The default collate function turns a list of dicts into a dict of tensors.
    assert isinstance(first_batch, dict)
    for column in ("x", "y"):
        assert first_batch[column].size() == (8,)
|
||||
|
||||
|
||||
def test_picklable(db_path):
    """An opened dataset pickles without its reader and still works once restored."""
    dataset = LanceTorchDataset(_make_table(db_path), columns=["x"])

    # Open the permutation first so pickling has a live reader to drop.
    len(dataset)
    restored: LanceTorchDataset = pickle.loads(pickle.dumps(dataset))

    # The reader must not travel through the pickle...
    assert restored._perm is None
    # ...yet the restored dataset re-opens itself on demand.
    assert len(restored) == 20
    assert restored[0] == [{"x": 0.0}]
|
||||
|
||||
|
||||
def test_dataloader_with_workers(db_path):
    """With two worker processes every row is still delivered exactly once."""
    dataset = LanceTorchDataset(_make_table(db_path, n=32))
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=4, num_workers=2, shuffle=False
    )
    collected = [value for batch in loader for value in batch["x"].tolist()]
    assert sorted(collected) == [float(i) for i in range(32)]
|
||||
|
||||
|
||||
def test_with_permutation_table(db_path):
    """A persisted permutation table and split survive a pickle round trip."""
    base = _make_table(db_path, n=30)
    database = connect(db_path)
    builder = permutation_builder(base).split_random(
        ratios=[0.5, 0.5], seed=1, split_names=["train", "test"]
    )
    perm_tbl = builder.persist(database, "imgs_perm").execute()

    dataset = LanceTorchDataset(base, permutation_table=perm_tbl, split="train")
    # The permutation table reference must be restored as well: half of the
    # 30 rows belong to the "train" split.
    restored = pickle.loads(pickle.dumps(dataset))
    assert len(restored) == 15
|
||||
|
||||
|
||||
def test_format_passthrough_dataloader(db_path):
    """The `format` argument reaches the underlying Permutation."""
    dataset = LanceTorchDataset(_make_table(db_path, n=20), format="arrow")
    # Arrow batches are not handled by the default collate, so pass them
    # through untouched.
    passthrough = lambda x: x  # noqa: E731
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=5, shuffle=False, collate_fn=passthrough
    )
    first_batch = next(iter(loader))
    assert isinstance(first_batch, pa.RecordBatch)
    assert first_batch.num_rows == 5
|
||||
|
||||
|
||||
def test_iterable_dataset(db_path):
    """Iterating the iterable dataset yields batches of the requested size."""
    dataset = LanceIterableTorchDataset(_make_table(db_path, n=20), batch_size=5)
    sizes = [len(batch) for batch in dataset]
    # 20 rows with an explicit batch_size of 5: four batches of five rows each.
    assert len(sizes) == 4
    assert all(size == 5 for size in sizes)
|
||||
|
||||
|
||||
def test_uri_table_name_constructor(db_path):
    """The dataset can be built from a URI and table name instead of a table."""
    _make_table(db_path)
    dataset = LanceTorchDataset(uri=str(db_path), table_name="imgs")
    assert len(dataset) == 20
    first = dataset[0]
    assert first == [{"x": 0.0, "y": 0}]
|
||||
|
||||
|
||||
def test_constructor_validates_args():
    """Constructing with neither a table nor uri/table_name raises ValueError."""
    no_arguments = {}
    with pytest.raises(ValueError, match="table"):
        LanceTorchDataset(**no_arguments)
|
||||
@@ -17,7 +17,7 @@ use pyo3::{Bound, PyAny, PyResult, exceptions::PyValueError, prelude::*, pyfunct
|
||||
/// [`expr_lit`] and combined with the methods on this struct. On the Python
|
||||
/// side a thin wrapper class (`lancedb.expr.Expr`) delegates to these methods
|
||||
/// and adds Python operator overloads.
|
||||
#[pyclass(name = "PyExpr", from_py_object)]
|
||||
#[pyclass(name = "PyExpr")]
|
||||
#[derive(Clone)]
|
||||
pub struct PyExpr(pub DfExpr);
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ impl PyHeaderProvider {
|
||||
Ok(headers_py) => {
|
||||
// Convert Python dict to Rust HashMap
|
||||
let bound_headers = headers_py.bind(py);
|
||||
let dict: &Bound<PyDict> = bound_headers.cast().map_err(|e| {
|
||||
let dict: &Bound<PyDict> = bound_headers.downcast().map_err(|e| {
|
||||
format!("HeaderProvider.get_headers must return a dict: {}", e)
|
||||
})?;
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ use pyo3::{
|
||||
Bound, FromPyObject, PyAny, PyResult, Python,
|
||||
exceptions::{PyKeyError, PyValueError},
|
||||
intern, pyclass, pymethods,
|
||||
types::{PyAnyMethods, PyString},
|
||||
types::PyAnyMethods,
|
||||
};
|
||||
|
||||
use crate::util::parse_distance_type;
|
||||
@@ -22,7 +22,7 @@ pub fn class_name(ob: &'_ Bound<'_, PyAny>) -> PyResult<String> {
|
||||
let full_name = ob
|
||||
.getattr(intern!(ob.py(), "__class__"))?
|
||||
.getattr(intern!(ob.py(), "__name__"))?;
|
||||
let full_name = full_name.cast::<PyString>()?.to_string_lossy();
|
||||
let full_name = full_name.downcast()?.to_string_lossy();
|
||||
|
||||
match full_name.rsplit_once('.') {
|
||||
Some((_, name)) => Ok(name.to_string()),
|
||||
|
||||
@@ -183,7 +183,7 @@ async fn call_py_method_primitive<Req, Resp>(
|
||||
) -> lance_core::Result<Resp>
|
||||
where
|
||||
Req: serde::Serialize + Send + 'static,
|
||||
Resp: for<'a, 'py> pyo3::FromPyObject<'a, 'py> + Send + 'static,
|
||||
Resp: for<'py> pyo3::FromPyObject<'py> + Send + 'static,
|
||||
{
|
||||
let request_json = serde_json::to_string(&request).map_err(|e| {
|
||||
lance_core::Error::io(format!(
|
||||
@@ -203,7 +203,7 @@ where
|
||||
|
||||
// Call the Python method
|
||||
let result = py_namespace.call_method1(py, method_name, (request_arg,))?;
|
||||
let value: Resp = result.extract(py).map_err(Into::into)?;
|
||||
let value: Resp = result.extract(py)?;
|
||||
Ok::<_, PyErr>(value)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -25,12 +25,12 @@ use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
fn table_from_py<'a>(table: Bound<'a, PyAny>) -> PyResult<Bound<'a, Table>> {
|
||||
if table.hasattr("_inner")? {
|
||||
Ok(table.getattr("_inner")?.cast_into::<Table>()?)
|
||||
Ok(table.getattr("_inner")?.downcast_into::<Table>()?)
|
||||
} else if table.hasattr("_table")? {
|
||||
Ok(table
|
||||
.getattr("_table")?
|
||||
.getattr("_inner")?
|
||||
.cast_into::<Table>()?)
|
||||
.downcast_into::<Table>()?)
|
||||
} else {
|
||||
Err(PyRuntimeError::new_err(
|
||||
"Provided table does not appear to be a Table or RemoteTable instance",
|
||||
@@ -90,9 +90,9 @@ impl PyAsyncPermutationBuilder {
|
||||
database
|
||||
.getattr("_conn")?
|
||||
.getattr("_inner")?
|
||||
.cast_into::<Connection>()?
|
||||
.downcast_into::<Connection>()?
|
||||
} else {
|
||||
database.getattr("_inner")?.cast_into::<Connection>()?
|
||||
database.getattr("_inner")?.downcast_into::<Connection>()?
|
||||
};
|
||||
let database = conn.borrow().database()?;
|
||||
slf.modify(|builder| builder.persist(database, table_name))
|
||||
@@ -243,7 +243,7 @@ impl PyPermutationReader {
|
||||
let Some(selection) = selection else {
|
||||
return Ok(Select::All);
|
||||
};
|
||||
let selection = selection.cast_into::<PyDict>()?;
|
||||
let selection = selection.downcast_into::<PyDict>()?;
|
||||
let selection = selection
|
||||
.iter()
|
||||
.map(|(key, value)| {
|
||||
|
||||
@@ -33,7 +33,7 @@ use pyo3::pyfunction;
|
||||
use pyo3::pymethods;
|
||||
use pyo3::types::PyList;
|
||||
use pyo3::types::{PyDict, PyString};
|
||||
use pyo3::{Borrowed, FromPyObject, exceptions::PyRuntimeError};
|
||||
use pyo3::{FromPyObject, exceptions::PyRuntimeError};
|
||||
use pyo3::{PyErr, pyclass};
|
||||
use pyo3::{exceptions::PyValueError, intern};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
@@ -43,12 +43,9 @@ use crate::util::parse_distance_type;
|
||||
use crate::{arrow::RecordBatchStream, util::PyLanceDB};
|
||||
use crate::{error::PythonErrorExt, index::class_name};
|
||||
|
||||
impl<'a, 'py> FromPyObject<'a, 'py> for PyLanceDB<FtsQuery> {
|
||||
type Error = PyErr;
|
||||
|
||||
fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
|
||||
let ob = ob.to_owned();
|
||||
match class_name(&ob)?.as_str() {
|
||||
impl FromPyObject<'_> for PyLanceDB<FtsQuery> {
|
||||
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
|
||||
match class_name(ob)?.as_str() {
|
||||
"MatchQuery" => {
|
||||
let query = ob.getattr("query")?.extract()?;
|
||||
let column = ob.getattr("column")?.extract()?;
|
||||
@@ -427,7 +424,7 @@ impl Query {
|
||||
"Query text is required for nearest_to_text",
|
||||
))?;
|
||||
|
||||
let query = if let Ok(query_text) = fts_query.cast::<PyString>() {
|
||||
let query = if let Ok(query_text) = fts_query.downcast::<PyString>() {
|
||||
let mut query_text = query_text.to_string();
|
||||
let columns = query
|
||||
.get_item("columns")?
|
||||
@@ -609,7 +606,7 @@ impl TakeQuery {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(from_py_object)]
|
||||
#[pyclass]
|
||||
#[derive(Clone)]
|
||||
pub struct FTSQuery {
|
||||
inner: LanceDbQuery,
|
||||
@@ -738,7 +735,7 @@ impl FTSQuery {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(from_py_object)]
|
||||
#[pyclass]
|
||||
#[derive(Clone)]
|
||||
pub struct VectorQuery {
|
||||
inner: LanceDbVectorQuery,
|
||||
|
||||
@@ -11,7 +11,7 @@ use pyo3::{PyResult, pyclass, pymethods};
|
||||
/// Sessions allow you to configure cache sizes for index and metadata caches,
|
||||
/// which can significantly impact memory use and performance. They can
|
||||
/// also be re-used across multiple connections to share the same cache state.
|
||||
#[pyclass(from_py_object)]
|
||||
#[pyclass]
|
||||
#[derive(Clone)]
|
||||
pub struct Session {
|
||||
pub(crate) inner: Arc<LanceSession>,
|
||||
|
||||
@@ -29,7 +29,7 @@ use pyo3_async_runtimes::tokio::future_into_py;
|
||||
mod scannable;
|
||||
|
||||
/// Statistics about a compaction operation.
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CompactionStats {
|
||||
/// The number of fragments removed
|
||||
@@ -43,7 +43,7 @@ pub struct CompactionStats {
|
||||
}
|
||||
|
||||
/// Statistics about a cleanup operation
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RemovalStats {
|
||||
/// The number of bytes removed
|
||||
@@ -53,7 +53,7 @@ pub struct RemovalStats {
|
||||
}
|
||||
|
||||
/// Statistics about an optimize operation
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct OptimizeStats {
|
||||
/// Statistics about the compaction operation
|
||||
@@ -62,7 +62,7 @@ pub struct OptimizeStats {
|
||||
pub prune: RemovalStats,
|
||||
}
|
||||
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct UpdateResult {
|
||||
pub rows_updated: u64,
|
||||
@@ -88,7 +88,7 @@ impl From<lancedb::table::UpdateResult> for UpdateResult {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AddResult {
|
||||
pub version: u64,
|
||||
@@ -109,7 +109,7 @@ impl From<lancedb::table::AddResult> for AddResult {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DeleteResult {
|
||||
pub num_deleted_rows: u64,
|
||||
@@ -135,7 +135,7 @@ impl From<lancedb::table::DeleteResult> for DeleteResult {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MergeResult {
|
||||
pub version: u64,
|
||||
@@ -171,7 +171,7 @@ impl From<lancedb::table::MergeResult> for MergeResult {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AddColumnsResult {
|
||||
pub version: u64,
|
||||
@@ -192,7 +192,7 @@ impl From<lancedb::table::AddColumnsResult> for AddColumnsResult {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AlterColumnsResult {
|
||||
pub version: u64,
|
||||
@@ -213,7 +213,7 @@ impl From<lancedb::table::AlterColumnsResult> for AlterColumnsResult {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all, from_py_object)]
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DropColumnsResult {
|
||||
pub version: u64,
|
||||
|
||||
@@ -126,11 +126,8 @@ impl Scannable for PyScannable {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'py> FromPyObject<'a, 'py> for PyScannable {
|
||||
type Error = pyo3::PyErr;
|
||||
|
||||
fn extract(ob: pyo3::Borrowed<'a, 'py, PyAny>) -> pyo3::PyResult<Self> {
|
||||
let ob = ob.to_owned();
|
||||
impl<'py> FromPyObject<'py> for PyScannable {
|
||||
fn extract_bound(ob: &pyo3::Bound<'py, PyAny>) -> pyo3::PyResult<Self> {
|
||||
// Convert from Scannable dataclass.
|
||||
let schema: PyArrowType<Schema> = ob.getattr("schema")?.extract()?;
|
||||
let schema = Arc::new(schema.0);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.28.0-beta.11"
|
||||
version = "0.28.0-beta.10"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
|
||||
@@ -43,7 +43,7 @@ pub struct RemoteInsertExec<S: HttpSend = Sender> {
|
||||
client: RestfulLanceDbClient<S>,
|
||||
input: Arc<dyn ExecutionPlan>,
|
||||
overwrite: bool,
|
||||
properties: Arc<PlanProperties>,
|
||||
properties: PlanProperties,
|
||||
add_result: Arc<Mutex<Option<AddResult>>>,
|
||||
metrics: ExecutionPlanMetricsSet,
|
||||
upload_id: Option<String>,
|
||||
@@ -118,7 +118,7 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
|
||||
client,
|
||||
input,
|
||||
overwrite,
|
||||
properties: Arc::new(properties),
|
||||
properties,
|
||||
add_result: Arc::new(Mutex::new(None)),
|
||||
metrics: ExecutionPlanMetricsSet::new(),
|
||||
upload_id,
|
||||
@@ -232,7 +232,7 @@ impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
|
||||
self
|
||||
}
|
||||
|
||||
fn properties(&self) -> &Arc<PlanProperties> {
|
||||
fn properties(&self) -> &PlanProperties {
|
||||
&self.properties
|
||||
}
|
||||
|
||||
|
||||
@@ -39,26 +39,21 @@ use lance_index::scalar::FullTextSearchQuery;
|
||||
struct MetadataEraserExec {
|
||||
input: Arc<dyn ExecutionPlan>,
|
||||
schema: Arc<ArrowSchema>,
|
||||
properties: Arc<PlanProperties>,
|
||||
properties: PlanProperties,
|
||||
}
|
||||
|
||||
impl MetadataEraserExec {
|
||||
fn compute_properties_from_input(
|
||||
input: &Arc<dyn ExecutionPlan>,
|
||||
schema: &Arc<ArrowSchema>,
|
||||
) -> Arc<PlanProperties> {
|
||||
) -> PlanProperties {
|
||||
let input_properties = input.properties();
|
||||
let eq_properties = input_properties
|
||||
.eq_properties
|
||||
.clone()
|
||||
.with_new_schema(schema.clone())
|
||||
.unwrap();
|
||||
Arc::new(
|
||||
input_properties
|
||||
.as_ref()
|
||||
.clone()
|
||||
.with_eq_properties(eq_properties),
|
||||
)
|
||||
input_properties.clone().with_eq_properties(eq_properties)
|
||||
}
|
||||
|
||||
fn new(input: Arc<dyn ExecutionPlan>) -> Self {
|
||||
@@ -92,7 +87,7 @@ impl ExecutionPlan for MetadataEraserExec {
|
||||
self
|
||||
}
|
||||
|
||||
fn properties(&self) -> &Arc<PlanProperties> {
|
||||
fn properties(&self) -> &PlanProperties {
|
||||
&self.properties
|
||||
}
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ pub struct InsertExec {
|
||||
dataset: Arc<Dataset>,
|
||||
input: Arc<dyn ExecutionPlan>,
|
||||
write_params: WriteParams,
|
||||
properties: Arc<PlanProperties>,
|
||||
properties: PlanProperties,
|
||||
partial_transactions: Arc<Mutex<Vec<Transaction>>>,
|
||||
metrics: ExecutionPlanMetricsSet,
|
||||
}
|
||||
@@ -107,7 +107,7 @@ impl InsertExec {
|
||||
dataset,
|
||||
input,
|
||||
write_params,
|
||||
properties: Arc::new(properties),
|
||||
properties,
|
||||
partial_transactions: Arc::new(Mutex::new(Vec::with_capacity(num_partitions))),
|
||||
metrics: ExecutionPlanMetricsSet::new(),
|
||||
}
|
||||
@@ -136,7 +136,7 @@ impl ExecutionPlan for InsertExec {
|
||||
self
|
||||
}
|
||||
|
||||
fn properties(&self) -> &Arc<PlanProperties> {
|
||||
fn properties(&self) -> &PlanProperties {
|
||||
&self.properties
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ pub(crate) struct ScannableExec {
|
||||
// We don't require Scannable to be Sync, so we wrap it in a Mutex to allow safe concurrent access.
|
||||
source: Mutex<Box<dyn Scannable>>,
|
||||
num_rows: Option<usize>,
|
||||
properties: Arc<PlanProperties>,
|
||||
properties: PlanProperties,
|
||||
tracker: Option<Arc<WriteProgressTracker>>,
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ impl ScannableExec {
|
||||
Self {
|
||||
source,
|
||||
num_rows,
|
||||
properties: Arc::new(properties),
|
||||
properties,
|
||||
tracker,
|
||||
}
|
||||
}
|
||||
@@ -70,7 +70,7 @@ impl ExecutionPlan for ScannableExec {
|
||||
self
|
||||
}
|
||||
|
||||
fn properties(&self) -> &Arc<PlanProperties> {
|
||||
fn properties(&self) -> &PlanProperties {
|
||||
&self.properties
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user