mirror of
https://github.com/lancedb/lancedb.git
synced 2026-04-09 17:30:41 +00:00
Compare commits
5 Commits
python-v0.
...
xuanwo/per
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7c37ba216a | ||
|
|
768d84845c | ||
|
|
2d380d1669 | ||
|
|
a898dc81c2 | ||
|
|
de3f8097e7 |
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.28.0-beta.0"
|
||||
current_version = "0.28.0-beta.1"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
6
Cargo.lock
generated
6
Cargo.lock
generated
@@ -4630,7 +4630,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb"
|
||||
version = "0.28.0-beta.0"
|
||||
version = "0.28.0-beta.1"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"anyhow",
|
||||
@@ -4712,7 +4712,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb-nodejs"
|
||||
version = "0.28.0-beta.0"
|
||||
version = "0.28.0-beta.1"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4734,7 +4734,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb-python"
|
||||
version = "0.31.0-beta.0"
|
||||
version = "0.31.0-beta.1"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"async-trait",
|
||||
|
||||
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
|
||||
<dependency>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-core</artifactId>
|
||||
<version>0.28.0-beta.0</version>
|
||||
<version>0.28.0-beta.1</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
|
||||
@@ -53,3 +53,18 @@ optional tlsConfig: TlsConfig;
|
||||
```ts
|
||||
optional userAgent: string;
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### userId?
|
||||
|
||||
```ts
|
||||
optional userId: string;
|
||||
```
|
||||
|
||||
User identifier for tracking purposes.
|
||||
|
||||
This is sent as the `x-lancedb-user-id` header in requests to LanceDB Cloud/Enterprise.
|
||||
It can be set directly, or via the `LANCEDB_USER_ID` environment variable.
|
||||
Alternatively, set `LANCEDB_USER_ID_ENV_KEY` to specify another environment
|
||||
variable that contains the user ID value.
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.28.0-beta.0</version>
|
||||
<version>0.28.0-beta.1</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.28.0-beta.0</version>
|
||||
<version>0.28.0-beta.1</version>
|
||||
<packaging>pom</packaging>
|
||||
<name>${project.artifactId}</name>
|
||||
<description>LanceDB Java SDK Parent POM</description>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.28.0-beta.0"
|
||||
version = "0.28.0-beta.1"
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.28.0-beta.0",
|
||||
"version": "0.28.0-beta.1",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.28.0-beta.0",
|
||||
"version": "0.28.0-beta.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||
"version": "0.28.0-beta.0",
|
||||
"version": "0.28.0-beta.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.28.0-beta.0",
|
||||
"version": "0.28.0-beta.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||
"version": "0.28.0-beta.0",
|
||||
"version": "0.28.0-beta.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||
"version": "0.28.0-beta.0",
|
||||
"version": "0.28.0-beta.1",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.28.0-beta.0",
|
||||
"version": "0.28.0-beta.1",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.28.0-beta.0",
|
||||
"version": "0.28.0-beta.1",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.28.0-beta.0",
|
||||
"version": "0.28.0-beta.1",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"ann"
|
||||
],
|
||||
"private": false,
|
||||
"version": "0.28.0-beta.0",
|
||||
"version": "0.28.0-beta.1",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -92,6 +92,13 @@ pub struct ClientConfig {
|
||||
pub extra_headers: Option<HashMap<String, String>>,
|
||||
pub id_delimiter: Option<String>,
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
/// User identifier for tracking purposes.
|
||||
///
|
||||
/// This is sent as the `x-lancedb-user-id` header in requests to LanceDB Cloud/Enterprise.
|
||||
/// It can be set directly, or via the `LANCEDB_USER_ID` environment variable.
|
||||
/// Alternatively, set `LANCEDB_USER_ID_ENV_KEY` to specify another environment
|
||||
/// variable that contains the user ID value.
|
||||
pub user_id: Option<String>,
|
||||
}
|
||||
|
||||
impl From<TimeoutConfig> for lancedb::remote::TimeoutConfig {
|
||||
@@ -145,6 +152,7 @@ impl From<ClientConfig> for lancedb::remote::ClientConfig {
|
||||
id_delimiter: config.id_delimiter,
|
||||
tls_config: config.tls_config.map(Into::into),
|
||||
header_provider: None, // the header provider is set separately later
|
||||
user_id: config.user_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,21 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import json
|
||||
import pickle
|
||||
from datetime import timedelta
|
||||
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import pyarrow as pa
|
||||
from deprecation import deprecated
|
||||
from lancedb import AsyncConnection, DBConnection
|
||||
import pyarrow as pa
|
||||
import json
|
||||
|
||||
from ._lancedb import async_permutation_builder, PermutationReader
|
||||
from .table import LanceTable
|
||||
from .background_loop import LOOP
|
||||
from .table import LanceTable
|
||||
from .util import batch_to_tensor, batch_to_tensor_rows
|
||||
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lancedb.dependencies import pandas as pd, numpy as np, polars as pl
|
||||
|
||||
|
||||
def _builtin_transform(format: str) -> Callable[[pa.RecordBatch], Any]:
    """Return the batch transform that backs a built-in format name.

    Parameters
    ----------
    format: str
        One of "python", "python_col", "numpy", "pandas", "arrow", "torch",
        "torch_col" or "polars".

    Returns
    -------
    Callable
        The transform for the requested format. For "polars" the transform is
        obtained by calling ``Transforms.arrow2polars()``.

    Raises
    ------
    ValueError
        If ``format`` is not one of the names listed above.
    """
    # Every entry is resolved lazily so that only the requested transform is
    # looked up (and only the polars factory is ever called).
    resolvers = (
        ("python", lambda: Transforms.arrow2python),
        ("python_col", lambda: Transforms.arrow2pythoncol),
        ("numpy", lambda: Transforms.arrow2numpy),
        ("pandas", lambda: Transforms.arrow2pandas),
        ("arrow", lambda: Transforms.arrow2arrow),
        ("torch", lambda: batch_to_tensor_rows),
        ("torch_col", lambda: batch_to_tensor),
        ("polars", lambda: Transforms.arrow2polars()),
    )
    for name, resolve in resolvers:
        if format == name:
            return resolve()
    raise ValueError(f"Invalid format: {format}")
|
||||
|
||||
|
||||
def _table_to_state(
    table: Union[LanceTable, dict[str, Any]],
) -> dict[str, Any]:
    """Describe a table as a plain dict from which it can be reopened later.

    Parameters
    ----------
    table: LanceTable or dict
        A live table, or a state dict previously produced by this function
        (returned unchanged).

    Returns
    -------
    dict
        Keys: "uri", "name", "version", "storage_options",
        "read_consistency_interval_secs" and "namespace_path".

    Raises
    ------
    pickle.PicklingError
        If ``table`` is not a LanceTable, is backed by a namespace client, or
        belongs to a ``memory://`` database.
    """
    if isinstance(table, dict):
        # Already in state form (e.g. on a permutation that was itself unpickled).
        return table

    if not isinstance(table, LanceTable):
        raise pickle.PicklingError(
            "Permutation pickling only supports LanceTable-backed permutations"
        )

    if table._namespace_client is not None:
        raise pickle.PicklingError(
            "Permutation pickling does not yet support namespace-backed tables"
        )

    connection = table._conn
    uri = connection.uri
    if uri.startswith("memory://"):
        raise pickle.PicklingError(
            "Permutation pickling does not support in-memory databases"
        )

    # Best effort: any failure while reading the interval is treated as "not set".
    try:
        interval = connection.read_consistency_interval
    except Exception:
        interval = None

    state: dict[str, Any] = {
        "uri": uri,
        "name": table.name,
        "version": table.version,
        "storage_options": table.initial_storage_options(),
    }
    if interval is None:
        state["read_consistency_interval_secs"] = None
    else:
        state["read_consistency_interval_secs"] = interval.total_seconds()
    state["namespace_path"] = list(table.namespace)
    return state
|
||||
|
||||
|
||||
def _table_from_state(state: dict[str, Any]) -> LanceTable:
    """Reopen a table from a state dict and pin it to the recorded version.

    Parameters
    ----------
    state: dict
        Must contain "uri", "name", "version", "storage_options",
        "read_consistency_interval_secs" and "namespace_path".

    Returns
    -------
    LanceTable
        The table opened through a new connection, checked out at
        ``state["version"]``.
    """
    # Imported here rather than at module level, as in the original code.
    from . import connect

    interval_secs = state["read_consistency_interval_secs"]
    if interval_secs is None:
        read_consistency_interval = None
    else:
        read_consistency_interval = timedelta(seconds=interval_secs)

    db = connect(
        state["uri"],
        read_consistency_interval=read_consistency_interval,
        storage_options=state["storage_options"],
    )
    reopened = db.open_table(state["name"], namespace_path=state["namespace_path"])
    reopened.checkout(state["version"])
    return reopened
|
||||
|
||||
|
||||
class PermutationBuilder:
|
||||
"""
|
||||
A utility for creating a "permutation table" which is a table that defines an
|
||||
@@ -385,6 +462,13 @@ class Permutation:
|
||||
selection: dict[str, str],
|
||||
batch_size: int,
|
||||
transform_fn: Callable[pa.RecordBatch, Any],
|
||||
*,
|
||||
base_table: Union[LanceTable, dict[str, Any]],
|
||||
permutation_table: Optional[Union[LanceTable, dict[str, Any]]],
|
||||
split: int,
|
||||
offset: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
transform_spec: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Internal constructor. Use [from_tables](#from_tables) instead.
|
||||
@@ -395,6 +479,93 @@ class Permutation:
|
||||
self.selection = selection
|
||||
self.transform_fn = transform_fn
|
||||
self.batch_size = batch_size
|
||||
self._transform_spec = transform_spec
|
||||
# These fields are used to reconstruct the permutation in a new process.
|
||||
self._base_table = base_table
|
||||
self._permutation_table = permutation_table
|
||||
self._split = split
|
||||
self._offset = offset
|
||||
self._limit = limit
|
||||
|
||||
def _reopen_metadata(self) -> dict[str, Any]:
    """Collect the keyword arguments needed to rebuild this permutation.

    The keys match the keyword-only constructor parameters, so the result can
    be splatted into ``Permutation(...)``.
    """
    return dict(
        base_table=self._base_table,
        permutation_table=self._permutation_table,
        split=self._split,
        offset=self._offset,
        limit=self._limit,
        transform_spec=self._transform_spec,
    )
|
||||
|
||||
def __getstate__(self) -> dict[str, Any]:
    """Build the picklable state for this permutation.

    Returns
    -------
    dict
        "selection" and "batch_size" as stored on the instance, a "transform"
        entry (a built-in format name, or the callable itself for custom
        transforms), and a "reopen" entry holding the reopen metadata with
        the tables replaced by their state dicts.

    Raises
    ------
    pickle.PicklingError
        Propagated from ``_table_to_state`` when a table cannot be described.
    """
    spec = self._transform_spec
    if spec is None:
        # Custom transform: the callable itself is handed to pickle.
        transform_state = {
            "kind": "callable",
            "transform_fn": self.transform_fn,
        }
    else:
        transform_state = {
            "kind": "builtin",
            "format": spec,
        }

    state: dict[str, Any] = {
        "selection": self.selection,
        "batch_size": self.batch_size,
        "transform": transform_state,
    }

    # Store reopen state instead of live LanceTable handles.
    reopen = self._reopen_metadata()
    reopen["base_table"] = _table_to_state(self._base_table)
    permutation_table = self._permutation_table
    if permutation_table is None:
        reopen["permutation_table"] = None
    else:
        reopen["permutation_table"] = _table_to_state(permutation_table)

    state["reopen"] = reopen
    return state
|
||||
|
||||
def __setstate__(self, state: dict[str, Any]) -> None:
    """Rebuild the permutation from the dict produced by ``__getstate__``.

    The tables are reopened from their state dicts, a new reader is created
    on the background loop with the recorded offset and limit re-applied, and
    the transform is resolved again (by name for built-in formats, otherwise
    the callable carried in the state).

    Note that ``_base_table`` and ``_permutation_table`` are restored as the
    state dicts, not as live tables.
    """
    reopen_state = state["reopen"]
    base_state = reopen_state["base_table"]
    permutation_state = reopen_state["permutation_table"]

    base_table = _table_from_state(base_state)
    if permutation_state is None:
        permutation_table = None
    else:
        permutation_table = _table_from_state(permutation_state)

    split = reopen_state["split"]
    offset = reopen_state["offset"]
    limit = reopen_state["limit"]

    async def rebuild_reader():
        rebuilt = await PermutationReader.from_tables(
            base_table, permutation_table, split
        )
        if offset is not None:
            rebuilt = await rebuilt.with_offset(offset)
        if limit is not None:
            rebuilt = await rebuilt.with_limit(limit)
        return rebuilt

    transform_state = state["transform"]
    if transform_state["kind"] == "builtin":
        transform_spec = transform_state["format"]
        transform_fn = _builtin_transform(transform_spec)
    else:
        # Custom callables arrive already reconstructed by pickle.
        transform_spec = None
        transform_fn = transform_state["transform_fn"]

    self.reader = LOOP.run(rebuild_reader())
    self.selection = state["selection"]
    self.batch_size = state["batch_size"]
    self.transform_fn = transform_fn
    self._transform_spec = transform_spec
    self._base_table = base_state
    self._permutation_table = permutation_state
    self._split = split
    self._offset = offset
    self._limit = limit
|
||||
|
||||
def _with_selection(self, selection: dict[str, str]) -> "Permutation":
|
||||
"""
|
||||
@@ -403,7 +574,13 @@ class Permutation:
|
||||
Does no validation of the selection and replaces it entirely. This is not
|
||||
intended for public use.
|
||||
"""
|
||||
return Permutation(self.reader, selection, self.batch_size, self.transform_fn)
|
||||
return Permutation(
|
||||
self.reader,
|
||||
selection,
|
||||
self.batch_size,
|
||||
self.transform_fn,
|
||||
**self._reopen_metadata(),
|
||||
)
|
||||
|
||||
def _with_reader(self, reader: PermutationReader) -> "Permutation":
|
||||
"""
|
||||
@@ -411,13 +588,25 @@ class Permutation:
|
||||
|
||||
This is an internal method and should not be used directly.
|
||||
"""
|
||||
return Permutation(reader, self.selection, self.batch_size, self.transform_fn)
|
||||
return Permutation(
|
||||
reader,
|
||||
self.selection,
|
||||
self.batch_size,
|
||||
self.transform_fn,
|
||||
**self._reopen_metadata(),
|
||||
)
|
||||
|
||||
def with_batch_size(self, batch_size: int) -> "Permutation":
|
||||
"""
|
||||
Creates a new permutation with the given batch size
|
||||
"""
|
||||
return Permutation(self.reader, self.selection, batch_size, self.transform_fn)
|
||||
return Permutation(
|
||||
self.reader,
|
||||
self.selection,
|
||||
batch_size,
|
||||
self.transform_fn,
|
||||
**self._reopen_metadata(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def identity(cls, table: LanceTable) -> "Permutation":
|
||||
@@ -491,7 +680,14 @@ class Permutation:
|
||||
schema = await reader.output_schema(None)
|
||||
initial_selection = {name: name for name in schema.names}
|
||||
return cls(
|
||||
reader, initial_selection, DEFAULT_BATCH_SIZE, Transforms.arrow2python
|
||||
reader,
|
||||
initial_selection,
|
||||
DEFAULT_BATCH_SIZE,
|
||||
Transforms.arrow2python,
|
||||
base_table=base_table,
|
||||
permutation_table=permutation_table,
|
||||
split=split,
|
||||
transform_spec="python",
|
||||
)
|
||||
|
||||
return LOOP.run(do_from_tables())
|
||||
@@ -732,24 +928,16 @@ class Permutation:
|
||||
this method.
|
||||
"""
|
||||
assert format is not None, "format is required"
|
||||
if format == "python":
|
||||
return self.with_transform(Transforms.arrow2python)
|
||||
if format == "python_col":
|
||||
return self.with_transform(Transforms.arrow2pythoncol)
|
||||
elif format == "numpy":
|
||||
return self.with_transform(Transforms.arrow2numpy)
|
||||
elif format == "pandas":
|
||||
return self.with_transform(Transforms.arrow2pandas)
|
||||
elif format == "arrow":
|
||||
return self.with_transform(Transforms.arrow2arrow)
|
||||
elif format == "torch":
|
||||
return self.with_transform(batch_to_tensor_rows)
|
||||
elif format == "torch_col":
|
||||
return self.with_transform(batch_to_tensor)
|
||||
elif format == "polars":
|
||||
return self.with_transform(Transforms.arrow2polars())
|
||||
else:
|
||||
raise ValueError(f"Invalid format: {format}")
|
||||
return Permutation(
|
||||
self.reader,
|
||||
self.selection,
|
||||
self.batch_size,
|
||||
_builtin_transform(format),
|
||||
**{
|
||||
**self._reopen_metadata(),
|
||||
"transform_spec": format,
|
||||
},
|
||||
)
|
||||
|
||||
def with_transform(self, transform: Callable[pa.RecordBatch, Any]) -> "Permutation":
|
||||
"""
|
||||
@@ -762,7 +950,16 @@ class Permutation:
|
||||
for expensive operations such as image decoding.
|
||||
"""
|
||||
assert transform is not None, "transform is required"
|
||||
return Permutation(self.reader, self.selection, self.batch_size, transform)
|
||||
return Permutation(
|
||||
self.reader,
|
||||
self.selection,
|
||||
self.batch_size,
|
||||
transform,
|
||||
**{
|
||||
**self._reopen_metadata(),
|
||||
"transform_spec": None,
|
||||
},
|
||||
)
|
||||
|
||||
def __getitem__(self, index: int) -> Any:
|
||||
"""
|
||||
@@ -800,7 +997,16 @@ class Permutation:
|
||||
|
||||
async def do_with_skip():
|
||||
reader = await self.reader.with_offset(skip)
|
||||
return self._with_reader(reader)
|
||||
return Permutation(
|
||||
reader,
|
||||
self.selection,
|
||||
self.batch_size,
|
||||
self.transform_fn,
|
||||
**{
|
||||
**self._reopen_metadata(),
|
||||
"offset": skip,
|
||||
},
|
||||
)
|
||||
|
||||
return LOOP.run(do_with_skip())
|
||||
|
||||
@@ -823,7 +1029,16 @@ class Permutation:
|
||||
|
||||
async def do_with_take():
|
||||
reader = await self.reader.with_limit(limit)
|
||||
return self._with_reader(reader)
|
||||
return Permutation(
|
||||
reader,
|
||||
self.selection,
|
||||
self.batch_size,
|
||||
self.transform_fn,
|
||||
**{
|
||||
**self._reopen_metadata(),
|
||||
"limit": limit,
|
||||
},
|
||||
)
|
||||
|
||||
return LOOP.run(do_with_take())
|
||||
|
||||
|
||||
@@ -145,6 +145,33 @@ class TlsConfig:
|
||||
|
||||
@dataclass
|
||||
class ClientConfig:
|
||||
"""Configuration for the LanceDB Cloud HTTP client.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
user_agent: str
|
||||
User agent string sent with requests.
|
||||
retry_config: RetryConfig
|
||||
Configuration for retrying failed requests.
|
||||
timeout_config: Optional[TimeoutConfig]
|
||||
Configuration for request timeouts.
|
||||
extra_headers: Optional[dict]
|
||||
Additional headers to include in requests.
|
||||
id_delimiter: Optional[str]
|
||||
The delimiter to use when constructing object identifiers.
|
||||
tls_config: Optional[TlsConfig]
|
||||
TLS/mTLS configuration for secure connections.
|
||||
header_provider: Optional[HeaderProvider]
|
||||
Provider for dynamic headers to be added to each request.
|
||||
user_id: Optional[str]
|
||||
User identifier for tracking purposes. This is sent as the
|
||||
`x-lancedb-user-id` header in requests to LanceDB Cloud/Enterprise.
|
||||
|
||||
This can also be set via the `LANCEDB_USER_ID` environment variable.
|
||||
Alternatively, set `LANCEDB_USER_ID_ENV_KEY` to specify another
|
||||
environment variable that contains the user ID value.
|
||||
"""
|
||||
|
||||
user_agent: str = f"LanceDB-Python-Client/{__version__}"
|
||||
retry_config: RetryConfig = field(default_factory=RetryConfig)
|
||||
timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig)
|
||||
@@ -152,6 +179,7 @@ class ClientConfig:
|
||||
id_delimiter: Optional[str] = None
|
||||
tls_config: Optional[TlsConfig] = None
|
||||
header_provider: Optional["HeaderProvider"] = None
|
||||
user_id: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.retry_config, dict):
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import pyarrow as pa
|
||||
import math
|
||||
import pickle
|
||||
import pytest
|
||||
|
||||
from lancedb import DBConnection, Table, connect
|
||||
@@ -599,6 +600,87 @@ def test_limit_offset(some_permutation: Permutation):
|
||||
some_permutation.with_skip(500).with_take(500).num_rows
|
||||
|
||||
|
||||
def test_permutation_pickle_rejects_in_memory_tables(mem_db: DBConnection):
    """Pickling a permutation over a table from ``mem_db`` raises PicklingError."""
    source = mem_db.create_table("identity_table", pa.table({"id": range(10)}))
    identity = Permutation.identity(source)

    with pytest.raises(pickle.PicklingError, match="in-memory databases"):
        pickle.dumps(identity)
|
||||
|
||||
|
||||
def test_identity_permutation_pickle_roundtrip_preserves_table_version(tmp_path):
    """An identity permutation keeps its skip/take window and format after a
    pickle round trip, even though the table is appended to in between."""
    db = connect(tmp_path)
    initial_rows = pa.table({"id": range(10), "value": range(10)})
    table = db.create_table("identity_table", initial_rows)
    windowed = Permutation.identity(table).with_skip(2).with_take(3)
    permutation = windowed.with_format("python_col")

    payload = pickle.dumps(permutation)
    # Modify the table only after the payload has been produced.
    table.add(pa.table({"id": [10], "value": [10]}))

    restored = pickle.loads(payload)
    assert restored.num_rows == 3
    batches = list(restored.iter(10, skip_last_batch=False))
    assert batches == [{"id": [2, 3, 4], "value": [2, 3, 4]}]
|
||||
|
||||
|
||||
def test_permutation_pickle_roundtrip_with_persisted_permutation_table(tmp_path):
    """A permutation backed by a persisted permutation table keeps its batch
    size, column selection, window and row contents after a pickle round trip."""
    db = connect(tmp_path)
    base_rows = pa.table({"id": range(1000), "value": range(1000)})
    table = db.create_table("base_table", base_rows)

    permutation_table = (
        permutation_builder(table)
        .split_random(ratios=[0.95, 0.05], seed=42, split_names=["train", "test"])
        .shuffle(seed=42)
        .persist(db, "persisted_permutation")
        .execute()
    )

    permutation = (
        Permutation.from_tables(table, permutation_table, "test")
        .select_columns(["id"])
        .rename_column("id", "row_id")
        .with_batch_size(32)
        .with_skip(5)
        .with_take(10)
        .with_format("arrow")
    )

    payload = pickle.dumps(permutation)
    restored = pickle.loads(payload)

    assert restored.batch_size == 32
    assert restored.column_names == ["row_id"]
    assert restored.num_rows == 10

    # The restored permutation must return the same rows as the original.
    restored_rows = restored.__getitems__([0, 1, 2]).to_pylist()
    original_rows = permutation.__getitems__([0, 1, 2]).to_pylist()
    assert restored_rows == original_rows
|
||||
|
||||
|
||||
def test_permutation_pickle_roundtrip_preserves_builtin_polars_format(tmp_path):
    """The built-in "polars" format still yields polars DataFrames after a
    pickle round trip. Skipped when polars is not installed."""
    pl = pytest.importorskip("polars")

    db = connect(tmp_path)
    rows = pa.table({"id": range(5), "value": range(5)})
    table = db.create_table("polars_table", rows)
    permutation = Permutation.identity(table).with_take(2).with_format("polars")

    payload = pickle.dumps(permutation)
    restored = pickle.loads(payload)
    batch = restored.__getitems__([0, 1])

    assert isinstance(batch, pl.DataFrame)
    assert batch.to_dict(as_series=False) == {"id": [0, 1], "value": [0, 1]}
|
||||
|
||||
|
||||
def test_remove_columns(some_permutation: Permutation):
|
||||
assert some_permutation.remove_columns(["value"]).schema == pa.schema(
|
||||
[("id", pa.int64())]
|
||||
|
||||
@@ -547,6 +547,7 @@ pub struct PyClientConfig {
|
||||
id_delimiter: Option<String>,
|
||||
tls_config: Option<PyClientTlsConfig>,
|
||||
header_provider: Option<Py<PyAny>>,
|
||||
user_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(FromPyObject)]
|
||||
@@ -631,6 +632,7 @@ impl From<PyClientConfig> for lancedb::remote::ClientConfig {
|
||||
id_delimiter: value.id_delimiter,
|
||||
tls_config: value.tls_config.map(Into::into),
|
||||
header_provider,
|
||||
user_id: value.user_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.28.0-beta.0"
|
||||
version = "0.28.0-beta.1"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
|
||||
@@ -52,6 +52,13 @@ pub struct ClientConfig {
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
/// Provider for custom headers to be added to each request
|
||||
pub header_provider: Option<Arc<dyn HeaderProvider>>,
|
||||
/// User identifier for tracking purposes.
|
||||
///
|
||||
/// This is sent as the `x-lancedb-user-id` header in requests to LanceDB Cloud/Enterprise.
|
||||
/// It can be set directly, or via the `LANCEDB_USER_ID` environment variable.
|
||||
/// Alternatively, set `LANCEDB_USER_ID_ENV_KEY` to specify another environment
|
||||
/// variable that contains the user ID value.
|
||||
pub user_id: Option<String>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ClientConfig {
|
||||
@@ -67,6 +74,7 @@ impl std::fmt::Debug for ClientConfig {
|
||||
"header_provider",
|
||||
&self.header_provider.as_ref().map(|_| "Some(...)"),
|
||||
)
|
||||
.field("user_id", &self.user_id)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
@@ -81,10 +89,41 @@ impl Default for ClientConfig {
|
||||
id_delimiter: None,
|
||||
tls_config: None,
|
||||
header_provider: None,
|
||||
user_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientConfig {
|
||||
/// Resolve the user ID from the config or environment variables.
|
||||
///
|
||||
/// Resolution order:
|
||||
/// 1. If `user_id` is set in the config, use that value
|
||||
/// 2. If `LANCEDB_USER_ID` environment variable is set, use that value
|
||||
/// 3. If `LANCEDB_USER_ID_ENV_KEY` is set, read the env var it points to
|
||||
/// 4. Otherwise, return None
|
||||
pub fn resolve_user_id(&self) -> Option<String> {
|
||||
if self.user_id.is_some() {
|
||||
return self.user_id.clone();
|
||||
}
|
||||
|
||||
if let Ok(user_id) = std::env::var("LANCEDB_USER_ID")
|
||||
&& !user_id.is_empty()
|
||||
{
|
||||
return Some(user_id);
|
||||
}
|
||||
|
||||
if let Ok(env_key) = std::env::var("LANCEDB_USER_ID_ENV_KEY")
|
||||
&& let Ok(user_id) = std::env::var(&env_key)
|
||||
&& !user_id.is_empty()
|
||||
{
|
||||
return Some(user_id);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// How to handle timeouts for HTTP requests.
|
||||
#[derive(Clone, Default, Debug)]
|
||||
pub struct TimeoutConfig {
|
||||
@@ -464,6 +503,15 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(user_id) = config.resolve_user_id() {
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-lancedb-user-id"),
|
||||
HeaderValue::from_str(&user_id).map_err(|_| Error::InvalidInput {
|
||||
message: format!("non-ascii user_id '{}' provided", user_id),
|
||||
})?,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
@@ -1072,4 +1120,91 @@ mod tests {
|
||||
_ => panic!("Expected Runtime error"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Serializes the tests below that mutate process-wide environment variables.
///
/// The default test harness runs `#[test]` functions on several threads, so
/// without this lock one test could observe `LANCEDB_USER_ID*` values written
/// by another.
static USER_ID_ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

/// Every environment variable the tests below may set.
const USER_ID_ENV_VARS: [&str; 3] = [
    "LANCEDB_USER_ID",
    "LANCEDB_USER_ID_ENV_KEY",
    "MY_CUSTOM_USER_ID",
];

/// Holds [`USER_ID_ENV_LOCK`] for the duration of a test and removes the
/// variables in [`USER_ID_ENV_VARS`] both when acquired and when dropped, so a
/// failed assertion cannot leak state into the next test.
struct UserIdEnvGuard {
    _lock: std::sync::MutexGuard<'static, ()>,
}

impl UserIdEnvGuard {
    fn acquire() -> Self {
        // A poisoned lock only means an earlier test panicked; the protected
        // value is `()`, so it is always safe to keep going.
        let lock = USER_ID_ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        Self::clear();
        Self { _lock: lock }
    }

    fn clear() {
        for name in USER_ID_ENV_VARS {
            // SAFETY: only called while USER_ID_ENV_LOCK is held, so the tests
            // in this group never mutate the environment concurrently.
            // NOTE(review): code elsewhere that reads the environment without
            // taking this lock is not covered.
            unsafe {
                std::env::remove_var(name);
            }
        }
    }
}

impl Drop for UserIdEnvGuard {
    fn drop(&mut self) {
        // Runs before the `_lock` field is released.
        Self::clear();
    }
}

#[test]
fn test_resolve_user_id_direct_value() {
    // No environment access is needed: an explicit value wins outright.
    let config = ClientConfig {
        user_id: Some("direct-user-id".to_string()),
        ..Default::default()
    };
    assert_eq!(config.resolve_user_id(), Some("direct-user-id".to_string()));
}

#[test]
fn test_resolve_user_id_none() {
    let _env = UserIdEnvGuard::acquire();
    let config = ClientConfig::default();
    assert_eq!(config.resolve_user_id(), None);
}

#[test]
fn test_resolve_user_id_from_env() {
    let _env = UserIdEnvGuard::acquire();
    // SAFETY: `_env` holds USER_ID_ENV_LOCK; see `UserIdEnvGuard::clear`.
    unsafe {
        std::env::set_var("LANCEDB_USER_ID", "env-user-id");
    }
    let config = ClientConfig::default();
    assert_eq!(config.resolve_user_id(), Some("env-user-id".to_string()));
}

#[test]
fn test_resolve_user_id_from_env_key() {
    let _env = UserIdEnvGuard::acquire();
    // SAFETY: `_env` holds USER_ID_ENV_LOCK; see `UserIdEnvGuard::clear`.
    unsafe {
        std::env::set_var("LANCEDB_USER_ID_ENV_KEY", "MY_CUSTOM_USER_ID");
        std::env::set_var("MY_CUSTOM_USER_ID", "custom-env-user-id");
    }
    let config = ClientConfig::default();
    assert_eq!(
        config.resolve_user_id(),
        Some("custom-env-user-id".to_string())
    );
}

#[test]
fn test_resolve_user_id_direct_takes_precedence() {
    let _env = UserIdEnvGuard::acquire();
    // SAFETY: `_env` holds USER_ID_ENV_LOCK; see `UserIdEnvGuard::clear`.
    unsafe {
        std::env::set_var("LANCEDB_USER_ID", "env-user-id");
    }
    let config = ClientConfig {
        user_id: Some("direct-user-id".to_string()),
        ..Default::default()
    };
    assert_eq!(config.resolve_user_id(), Some("direct-user-id".to_string()));
}

#[test]
fn test_resolve_user_id_empty_env_ignored() {
    let _env = UserIdEnvGuard::acquire();
    // SAFETY: `_env` holds USER_ID_ENV_LOCK; see `UserIdEnvGuard::clear`.
    unsafe {
        std::env::set_var("LANCEDB_USER_ID", "");
    }
    let config = ClientConfig::default();
    assert_eq!(config.resolve_user_id(), None);
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user