Compare commits

..

1 Commits

Author SHA1 Message Date
Ayush Chaurasia
2fa4d04de7 feat(python): align with_format("torch") with HuggingFace semantics
- New "torch" format yields list-of-per-row-dicts whose values are torch
  tensors. PyTorch DataLoader's default collate stacks these into a
  dict[str, Tensor] per batch — matching HuggingFace
  dataset.set_format("torch") and removing the need for a custom collate_fn.
- The previous "torch" behavior (list of 1-D row tensors without column
  names) moves to a new "torch_row" format.
- "torch_col" is unchanged (2-D tensor, requires custom collate).

Closes lancedb/lancedb#3245
2026-04-29 22:15:40 +05:30
35 changed files with 361 additions and 323 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.28.0-beta.11"
current_version = "0.28.0-beta.10"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

406
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -13,40 +13,40 @@ categories = ["database-implementations"]
rust-version = "1.91.0"
[workspace.dependencies]
lance = { "version" = "=6.0.0-beta.7", default-features = false, "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=6.0.0-beta.7", default-features = false, "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=6.0.0-beta.7", default-features = false, "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=6.0.0-beta.7", "tag" = "v6.0.0-beta.7", "git" = "https://github.com/lance-format/lance.git" }
lance = { "version" = "=6.0.0-beta.4", default-features = false, "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=6.0.0-beta.4", default-features = false, "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=6.0.0-beta.4", default-features = false, "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=6.0.0-beta.4", "tag" = "v6.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "58.0.0", optional = false }
arrow-array = "58.0.0"
arrow-data = "58.0.0"
arrow-ipc = "58.0.0"
arrow-ord = "58.0.0"
arrow-schema = "58.0.0"
arrow-select = "58.0.0"
arrow-cast = "58.0.0"
arrow = { version = "57.2", optional = false }
arrow-array = "57.2"
arrow-data = "57.2"
arrow-ipc = "57.2"
arrow-ord = "57.2"
arrow-schema = "57.2"
arrow-select = "57.2"
arrow-cast = "57.2"
async-trait = "0"
datafusion = { version = "53.0.0", default-features = false }
datafusion-catalog = "53.0.0"
datafusion-common = { version = "53.0.0", default-features = false }
datafusion-execution = "53.0.0"
datafusion-expr = "53.0.0"
datafusion-functions = "53.0.0"
datafusion-physical-plan = "53.0.0"
datafusion-physical-expr = "53.0.0"
datafusion-sql = "53.0.0"
datafusion = { version = "52.1", default-features = false }
datafusion-catalog = "52.1"
datafusion-common = { version = "52.1", default-features = false }
datafusion-execution = "52.1"
datafusion-expr = "52.1"
datafusion-functions = "52.1"
datafusion-physical-plan = "52.1"
datafusion-physical-expr = "52.1"
datafusion-sql = "52.1"
env_logger = "0.11"
half = { "version" = "2.7.1", default-features = false, features = [
"num-traits",

View File

@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
<dependency>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-core</artifactId>
<version>0.28.0-beta.11</version>
<version>0.28.0-beta.10</version>
</dependency>
```

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.28.0-beta.11</version>
<version>0.28.0-beta.10</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.28.0-beta.11</version>
<version>0.28.0-beta.10</version>
<packaging>pom</packaging>
<name>${project.artifactId}</name>
<description>LanceDB Java SDK Parent POM</description>
@@ -28,7 +28,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<arrow.version>15.0.0</arrow.version>
<lance-core.version>6.0.0-beta.7</lance-core.version>
<lance-core.version>6.0.0-beta.4</lance-core.version>
<spotless.skip>false</spotless.skip>
<spotless.version>2.30.0</spotless.version>
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>

View File

@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.28.0-beta.11"
version = "0.28.0-beta.10"
publish = false
license.workspace = true
description.workspace = true
@@ -16,7 +16,7 @@ crate-type = ["cdylib"]
async-trait.workspace = true
arrow-ipc.workspace = true
arrow-array.workspace = true
arrow-buffer = "58.0.0"
arrow-buffer = "57.2"
half.workspace = true
arrow-schema.workspace = true
env_logger.workspace = true

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.28.0-beta.11",
"version": "0.28.0-beta.10",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.28.0-beta.11",
"version": "0.28.0-beta.10",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.28.0-beta.11",
"version": "0.28.0-beta.10",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.28.0-beta.11",
"version": "0.28.0-beta.10",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.28.0-beta.11",
"version": "0.28.0-beta.10",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.28.0-beta.11",
"version": "0.28.0-beta.10",
"os": [
"win32"
],

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.28.0-beta.11",
"version": "0.28.0-beta.10",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

View File

@@ -11,7 +11,7 @@
"ann"
],
"private": false,
"version": "0.28.0-beta.11",
"version": "0.28.0-beta.10",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.31.0-beta.11"
current_version = "0.31.0-beta.10"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.31.0-beta.11"
version = "0.31.0-beta.10"
publish = false
edition.workspace = true
description = "Python bindings for LanceDB"
@@ -15,7 +15,7 @@ name = "_lancedb"
crate-type = ["cdylib"]
[dependencies]
arrow = { version = "58.0.0", features = ["pyarrow"] }
arrow = { version = "57.2", features = ["pyarrow"] }
async-trait = "0.1"
bytes = "1"
lancedb = { path = "../rust/lancedb", default-features = false }
@@ -25,8 +25,8 @@ lance-namespace-impls.workspace = true
lance-io.workspace = true
env_logger.workspace = true
log.workspace = true
pyo3 = { version = "0.28", features = ["extension-module", "abi3-py39"] }
pyo3-async-runtimes = { version = "0.28", features = [
pyo3 = { version = "0.26", features = ["extension-module", "abi3-py39"] }
pyo3-async-runtimes = { version = "0.26", features = [
"attributes",
"tokio-runtime",
] }
@@ -38,7 +38,7 @@ snafu.workspace = true
tokio = { version = "1.40", features = ["sync"] }
[build-dependencies]
pyo3-build-config = { version = "0.28", features = [
pyo3-build-config = { version = "0.26", features = [
"extension-module",
"abi3-py39",
] }

View File

@@ -9,7 +9,7 @@ import json
from ._lancedb import async_permutation_builder, PermutationReader
from .table import LanceTable
from .background_loop import LOOP
from .util import batch_to_tensor, batch_to_tensor_rows
from .util import batch_to_tensor, batch_to_tensor_dict, batch_to_tensor_rows
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
if TYPE_CHECKING:
@@ -697,6 +697,7 @@ class Permutation:
"pandas",
"arrow",
"torch",
"torch_row",
"torch_col",
"polars",
],
@@ -712,8 +713,17 @@ class Permutation:
- "python_col" - the batch will be a dict of lists (one entry per column)
- "pandas" - the batch will be a pandas DataFrame
- "arrow" - the batch will be a pyarrow RecordBatch
- "torch" - the batch will be a list of tensors, one per row
- "torch_col" - the batch will be a 2D torch tensor (first dim indexes columns)
- "torch" - a list of per-row dicts whose values are torch tensors. When
used with ``torch.utils.data.DataLoader`` (default collate), each
batch yielded by the loader is ``dict[str, Tensor]`` — one tensor per
column, with column names preserved. This matches HuggingFace
``dataset.set_format("torch")`` semantics.
- "torch_row" - a list of 1-D torch tensors, one per row. Each tensor
stacks all column values into a single row vector and column names
are not preserved. (This was the previous "torch" behavior.)
- "torch_col" - a 2-D torch tensor of shape ``(n_cols, n_rows)``. Column
names are not preserved. Requires ``collate_fn=lambda x: x`` if used
with ``DataLoader``.
- "polars" - the batch will be a polars DataFrame
Conversion may or may not involve a data copy. Lance uses Arrow internally
@@ -741,6 +751,8 @@ class Permutation:
elif format == "arrow":
return self.with_transform(Transforms.arrow2arrow)
elif format == "torch":
return self.with_transform(batch_to_tensor_dict)
elif format == "torch_row":
return self.with_transform(batch_to_tensor_rows)
elif format == "torch_col":
return self.with_transform(batch_to_tensor)

View File

@@ -448,3 +448,29 @@ def batch_to_tensor_rows(batch: pa.RecordBatch):
stacked = torch.tensor(numpy.column_stack(columns))
rows = list(stacked.unbind(dim=0))
return rows
def batch_to_tensor_dict(batch: pa.RecordBatch):
"""
Convert a PyArrow RecordBatch into a list of per-row dicts whose values
are PyTorch tensors.
Each column is converted to a tensor in one shot (zero-copy via DLPack
when supported), then sliced per row. The result is shaped to work with
PyTorch's default DataLoader collate, which stacks the per-row dicts
into a single ``dict[str, Tensor]`` per batch — matching the
HuggingFace ``dataset.set_format("torch")`` convention.
Fails if torch is not installed.
Fails if a column's data type is not supported by PyTorch.
"""
torch = attempt_import_or_raise("torch", "torch")
columns: dict[str, "torch.Tensor"] = {}
for i, name in enumerate(batch.schema.names):
col = batch.column(i)
try:
columns[name] = torch.from_dlpack(col)
except Exception:
columns[name] = torch.tensor(col.to_numpy(zero_copy_only=False))
n = batch.num_rows
return [{name: t[i] for name, t in columns.items()} for i in range(n)]

View File

@@ -950,14 +950,27 @@ def test_transform_fn(mem_db):
try:
import torch
# "torch" format: list of per-row dicts of tensors (HF-compatible).
torch_result = list(
permutation.with_format("torch").iter(10, skip_last_batch=False)
)[0]
assert isinstance(torch_result, list)
assert len(torch_result) == 10
assert isinstance(torch_result[0], torch.Tensor)
assert torch_result[0].shape == (2,)
assert torch_result[0].dtype == torch.int64
assert isinstance(torch_result[0], dict)
assert set(torch_result[0].keys()) == {"id", "value"}
assert isinstance(torch_result[0]["id"], torch.Tensor)
assert torch_result[0]["id"].shape == ()
assert torch_result[0]["id"].dtype == torch.int64
# "torch_row" format: list of 1-D row tensors (previous "torch" behavior).
torch_row_result = list(
permutation.with_format("torch_row").iter(10, skip_last_batch=False)
)[0]
assert isinstance(torch_row_result, list)
assert len(torch_row_result) == 10
assert isinstance(torch_row_result[0], torch.Tensor)
assert torch_row_result[0].shape == (2,)
assert torch_row_result[0].dtype == torch.int64
except ImportError:
# Skip check if torch is not installed
pass

View File

@@ -27,8 +27,18 @@ def test_permutation_dataloader(mem_db):
for batch in dataloader:
assert batch["a"].size(0) == 10
# New "torch" format: per-row dicts of tensors, default collate yields
# dict[str, Tensor] (HuggingFace style).
permutation = permutation.with_format("torch")
dataloader = torch.utils.data.DataLoader(permutation, batch_size=10, shuffle=True)
for batch in dataloader:
assert isinstance(batch, dict)
assert "a" in batch
assert batch["a"].size() == (10,)
# Previous "torch" semantics is preserved under the "torch_row" name.
permutation = permutation.with_format("torch_row")
dataloader = torch.utils.data.DataLoader(permutation, batch_size=10, shuffle=True)
for batch in dataloader:
assert batch.size(0) == 10
assert batch.size(1) == 1

View File

@@ -17,7 +17,7 @@ use pyo3::{Bound, PyAny, PyResult, exceptions::PyValueError, prelude::*, pyfunct
/// [`expr_lit`] and combined with the methods on this struct. On the Python
/// side a thin wrapper class (`lancedb.expr.Expr`) delegates to these methods
/// and adds Python operator overloads.
#[pyclass(name = "PyExpr", from_py_object)]
#[pyclass(name = "PyExpr")]
#[derive(Clone)]
pub struct PyExpr(pub DfExpr);

View File

@@ -33,7 +33,7 @@ impl PyHeaderProvider {
Ok(headers_py) => {
// Convert Python dict to Rust HashMap
let bound_headers = headers_py.bind(py);
let dict: &Bound<PyDict> = bound_headers.cast().map_err(|e| {
let dict: &Bound<PyDict> = bound_headers.downcast().map_err(|e| {
format!("HeaderProvider.get_headers must return a dict: {}", e)
})?;

View File

@@ -13,7 +13,7 @@ use pyo3::{
Bound, FromPyObject, PyAny, PyResult, Python,
exceptions::{PyKeyError, PyValueError},
intern, pyclass, pymethods,
types::{PyAnyMethods, PyString},
types::PyAnyMethods,
};
use crate::util::parse_distance_type;
@@ -22,7 +22,7 @@ pub fn class_name(ob: &'_ Bound<'_, PyAny>) -> PyResult<String> {
let full_name = ob
.getattr(intern!(ob.py(), "__class__"))?
.getattr(intern!(ob.py(), "__name__"))?;
let full_name = full_name.cast::<PyString>()?.to_string_lossy();
let full_name = full_name.downcast()?.to_string_lossy();
match full_name.rsplit_once('.') {
Some((_, name)) => Ok(name.to_string()),

View File

@@ -183,7 +183,7 @@ async fn call_py_method_primitive<Req, Resp>(
) -> lance_core::Result<Resp>
where
Req: serde::Serialize + Send + 'static,
Resp: for<'a, 'py> pyo3::FromPyObject<'a, 'py> + Send + 'static,
Resp: for<'py> pyo3::FromPyObject<'py> + Send + 'static,
{
let request_json = serde_json::to_string(&request).map_err(|e| {
lance_core::Error::io(format!(
@@ -203,7 +203,7 @@ where
// Call the Python method
let result = py_namespace.call_method1(py, method_name, (request_arg,))?;
let value: Resp = result.extract(py).map_err(Into::into)?;
let value: Resp = result.extract(py)?;
Ok::<_, PyErr>(value)
})
})

View File

@@ -25,12 +25,12 @@ use pyo3_async_runtimes::tokio::future_into_py;
fn table_from_py<'a>(table: Bound<'a, PyAny>) -> PyResult<Bound<'a, Table>> {
if table.hasattr("_inner")? {
Ok(table.getattr("_inner")?.cast_into::<Table>()?)
Ok(table.getattr("_inner")?.downcast_into::<Table>()?)
} else if table.hasattr("_table")? {
Ok(table
.getattr("_table")?
.getattr("_inner")?
.cast_into::<Table>()?)
.downcast_into::<Table>()?)
} else {
Err(PyRuntimeError::new_err(
"Provided table does not appear to be a Table or RemoteTable instance",
@@ -90,9 +90,9 @@ impl PyAsyncPermutationBuilder {
database
.getattr("_conn")?
.getattr("_inner")?
.cast_into::<Connection>()?
.downcast_into::<Connection>()?
} else {
database.getattr("_inner")?.cast_into::<Connection>()?
database.getattr("_inner")?.downcast_into::<Connection>()?
};
let database = conn.borrow().database()?;
slf.modify(|builder| builder.persist(database, table_name))
@@ -243,7 +243,7 @@ impl PyPermutationReader {
let Some(selection) = selection else {
return Ok(Select::All);
};
let selection = selection.cast_into::<PyDict>()?;
let selection = selection.downcast_into::<PyDict>()?;
let selection = selection
.iter()
.map(|(key, value)| {

View File

@@ -33,7 +33,7 @@ use pyo3::pyfunction;
use pyo3::pymethods;
use pyo3::types::PyList;
use pyo3::types::{PyDict, PyString};
use pyo3::{Borrowed, FromPyObject, exceptions::PyRuntimeError};
use pyo3::{FromPyObject, exceptions::PyRuntimeError};
use pyo3::{PyErr, pyclass};
use pyo3::{exceptions::PyValueError, intern};
use pyo3_async_runtimes::tokio::future_into_py;
@@ -43,12 +43,9 @@ use crate::util::parse_distance_type;
use crate::{arrow::RecordBatchStream, util::PyLanceDB};
use crate::{error::PythonErrorExt, index::class_name};
impl<'a, 'py> FromPyObject<'a, 'py> for PyLanceDB<FtsQuery> {
type Error = PyErr;
fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
let ob = ob.to_owned();
match class_name(&ob)?.as_str() {
impl FromPyObject<'_> for PyLanceDB<FtsQuery> {
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
match class_name(ob)?.as_str() {
"MatchQuery" => {
let query = ob.getattr("query")?.extract()?;
let column = ob.getattr("column")?.extract()?;
@@ -427,7 +424,7 @@ impl Query {
"Query text is required for nearest_to_text",
))?;
let query = if let Ok(query_text) = fts_query.cast::<PyString>() {
let query = if let Ok(query_text) = fts_query.downcast::<PyString>() {
let mut query_text = query_text.to_string();
let columns = query
.get_item("columns")?
@@ -609,7 +606,7 @@ impl TakeQuery {
}
}
#[pyclass(from_py_object)]
#[pyclass]
#[derive(Clone)]
pub struct FTSQuery {
inner: LanceDbQuery,
@@ -738,7 +735,7 @@ impl FTSQuery {
}
}
#[pyclass(from_py_object)]
#[pyclass]
#[derive(Clone)]
pub struct VectorQuery {
inner: LanceDbVectorQuery,

View File

@@ -11,7 +11,7 @@ use pyo3::{PyResult, pyclass, pymethods};
/// Sessions allow you to configure cache sizes for index and metadata caches,
/// which can significantly impact memory use and performance. They can
/// also be re-used across multiple connections to share the same cache state.
#[pyclass(from_py_object)]
#[pyclass]
#[derive(Clone)]
pub struct Session {
pub(crate) inner: Arc<LanceSession>,

View File

@@ -29,7 +29,7 @@ use pyo3_async_runtimes::tokio::future_into_py;
mod scannable;
/// Statistics about a compaction operation.
#[pyclass(get_all, from_py_object)]
#[pyclass(get_all)]
#[derive(Clone, Debug)]
pub struct CompactionStats {
/// The number of fragments removed
@@ -43,7 +43,7 @@ pub struct CompactionStats {
}
/// Statistics about a cleanup operation
#[pyclass(get_all, from_py_object)]
#[pyclass(get_all)]
#[derive(Clone, Debug)]
pub struct RemovalStats {
/// The number of bytes removed
@@ -53,7 +53,7 @@ pub struct RemovalStats {
}
/// Statistics about an optimize operation
#[pyclass(get_all, from_py_object)]
#[pyclass(get_all)]
#[derive(Clone, Debug)]
pub struct OptimizeStats {
/// Statistics about the compaction operation
@@ -62,7 +62,7 @@ pub struct OptimizeStats {
pub prune: RemovalStats,
}
#[pyclass(get_all, from_py_object)]
#[pyclass(get_all)]
#[derive(Clone, Debug)]
pub struct UpdateResult {
pub rows_updated: u64,
@@ -88,7 +88,7 @@ impl From<lancedb::table::UpdateResult> for UpdateResult {
}
}
#[pyclass(get_all, from_py_object)]
#[pyclass(get_all)]
#[derive(Clone, Debug)]
pub struct AddResult {
pub version: u64,
@@ -109,7 +109,7 @@ impl From<lancedb::table::AddResult> for AddResult {
}
}
#[pyclass(get_all, from_py_object)]
#[pyclass(get_all)]
#[derive(Clone, Debug)]
pub struct DeleteResult {
pub num_deleted_rows: u64,
@@ -135,7 +135,7 @@ impl From<lancedb::table::DeleteResult> for DeleteResult {
}
}
#[pyclass(get_all, from_py_object)]
#[pyclass(get_all)]
#[derive(Clone, Debug)]
pub struct MergeResult {
pub version: u64,
@@ -171,7 +171,7 @@ impl From<lancedb::table::MergeResult> for MergeResult {
}
}
#[pyclass(get_all, from_py_object)]
#[pyclass(get_all)]
#[derive(Clone, Debug)]
pub struct AddColumnsResult {
pub version: u64,
@@ -192,7 +192,7 @@ impl From<lancedb::table::AddColumnsResult> for AddColumnsResult {
}
}
#[pyclass(get_all, from_py_object)]
#[pyclass(get_all)]
#[derive(Clone, Debug)]
pub struct AlterColumnsResult {
pub version: u64,
@@ -213,7 +213,7 @@ impl From<lancedb::table::AlterColumnsResult> for AlterColumnsResult {
}
}
#[pyclass(get_all, from_py_object)]
#[pyclass(get_all)]
#[derive(Clone, Debug)]
pub struct DropColumnsResult {
pub version: u64,

View File

@@ -126,11 +126,8 @@ impl Scannable for PyScannable {
}
}
impl<'a, 'py> FromPyObject<'a, 'py> for PyScannable {
type Error = pyo3::PyErr;
fn extract(ob: pyo3::Borrowed<'a, 'py, PyAny>) -> pyo3::PyResult<Self> {
let ob = ob.to_owned();
impl<'py> FromPyObject<'py> for PyScannable {
fn extract_bound(ob: &pyo3::Bound<'py, PyAny>) -> pyo3::PyResult<Self> {
// Convert from Scannable dataclass.
let schema: PyArrowType<Schema> = ob.getattr("schema")?.extract()?;
let schema = Arc::new(schema.0);

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.28.0-beta.11"
version = "0.28.0-beta.10"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true

View File

@@ -43,7 +43,7 @@ pub struct RemoteInsertExec<S: HttpSend = Sender> {
client: RestfulLanceDbClient<S>,
input: Arc<dyn ExecutionPlan>,
overwrite: bool,
properties: Arc<PlanProperties>,
properties: PlanProperties,
add_result: Arc<Mutex<Option<AddResult>>>,
metrics: ExecutionPlanMetricsSet,
upload_id: Option<String>,
@@ -118,7 +118,7 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
client,
input,
overwrite,
properties: Arc::new(properties),
properties,
add_result: Arc::new(Mutex::new(None)),
metrics: ExecutionPlanMetricsSet::new(),
upload_id,
@@ -232,7 +232,7 @@ impl<S: HttpSend + 'static> ExecutionPlan for RemoteInsertExec<S> {
self
}
fn properties(&self) -> &Arc<PlanProperties> {
fn properties(&self) -> &PlanProperties {
&self.properties
}

View File

@@ -39,26 +39,21 @@ use lance_index::scalar::FullTextSearchQuery;
struct MetadataEraserExec {
input: Arc<dyn ExecutionPlan>,
schema: Arc<ArrowSchema>,
properties: Arc<PlanProperties>,
properties: PlanProperties,
}
impl MetadataEraserExec {
fn compute_properties_from_input(
input: &Arc<dyn ExecutionPlan>,
schema: &Arc<ArrowSchema>,
) -> Arc<PlanProperties> {
) -> PlanProperties {
let input_properties = input.properties();
let eq_properties = input_properties
.eq_properties
.clone()
.with_new_schema(schema.clone())
.unwrap();
Arc::new(
input_properties
.as_ref()
.clone()
.with_eq_properties(eq_properties),
)
input_properties.clone().with_eq_properties(eq_properties)
}
fn new(input: Arc<dyn ExecutionPlan>) -> Self {
@@ -92,7 +87,7 @@ impl ExecutionPlan for MetadataEraserExec {
self
}
fn properties(&self) -> &Arc<PlanProperties> {
fn properties(&self) -> &PlanProperties {
&self.properties
}

View File

@@ -81,7 +81,7 @@ pub struct InsertExec {
dataset: Arc<Dataset>,
input: Arc<dyn ExecutionPlan>,
write_params: WriteParams,
properties: Arc<PlanProperties>,
properties: PlanProperties,
partial_transactions: Arc<Mutex<Vec<Transaction>>>,
metrics: ExecutionPlanMetricsSet,
}
@@ -107,7 +107,7 @@ impl InsertExec {
dataset,
input,
write_params,
properties: Arc::new(properties),
properties,
partial_transactions: Arc::new(Mutex::new(Vec::with_capacity(num_partitions))),
metrics: ExecutionPlanMetricsSet::new(),
}
@@ -136,7 +136,7 @@ impl ExecutionPlan for InsertExec {
self
}
fn properties(&self) -> &Arc<PlanProperties> {
fn properties(&self) -> &PlanProperties {
&self.properties
}

View File

@@ -20,7 +20,7 @@ pub(crate) struct ScannableExec {
// We don't require Scannable to be Sync, so we wrap it in a Mutex to allow safe concurrent access.
source: Mutex<Box<dyn Scannable>>,
num_rows: Option<usize>,
properties: Arc<PlanProperties>,
properties: PlanProperties,
tracker: Option<Arc<WriteProgressTracker>>,
}
@@ -49,7 +49,7 @@ impl ScannableExec {
Self {
source,
num_rows,
properties: Arc::new(properties),
properties,
tracker,
}
}
@@ -70,7 +70,7 @@ impl ExecutionPlan for ScannableExec {
self
}
fn properties(&self) -> &Arc<PlanProperties> {
fn properties(&self) -> &PlanProperties {
&self.properties
}