From 9d8699f99e9ba2e7517bee0ae532780eb81ec057 Mon Sep 17 00:00:00 2001 From: Zelys Date: Fri, 3 Apr 2026 12:40:49 -0500 Subject: [PATCH 1/6] feat(python): support Enum types in Pydantic to Arrow schema conversion (#3232) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Fixes #1846. Python `Enum` fields raised `TypeError: Converting Pydantic type to Arrow Type: unsupported type ` when converting a Pydantic model to an Arrow schema. The fix adds Enum detection in `_pydantic_type_to_arrow_type`. When an Enum subclass is encountered, the value type of its members is inspected and mapped to the appropriate Arrow type: - `str`-valued enums (e.g. `class Status(str, Enum)`) → `pa.utf8()` - `int`-valued enums (e.g. `class Priority(int, Enum)`) → `pa.int64()` - Other homogeneous value types → the Arrow type for that Python type - Mixed-value or empty enums → `pa.utf8()` (safe fallback) This covers the common `(str, Enum)` and `(int, Enum)` mixin patterns used in practice. ## Changes - `python/python/lancedb/pydantic.py`: add Enum branch in `_pydantic_type_to_arrow_type` - `python/python/tests/test_pydantic.py`: add `test_enum_types` covering `str`, `int`, and `Optional` Enum fields ## Note on #2395 PR #2395 handles `StrEnum` (Python 3.11+) specifically, using a dictionary-encoded type. This PR handles the broader `(str, Enum)` / `(int, Enum)` mixin pattern that works across all Python versions and stores values as their natural Arrow type. AI assistance was used in developing this fix. --- python/python/lancedb/pydantic.py | 14 ++++++++++++++ python/python/tests/test_pydantic.py | 27 +++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/python/python/lancedb/pydantic.py b/python/python/lancedb/pydantic.py index 653ea3333..8b1f991f9 100644 --- a/python/python/lancedb/pydantic.py +++ b/python/python/lancedb/pydantic.py @@ -10,6 +10,7 @@ import sys import types from abc import ABC, abstractmethod from datetime import date, datetime +from enum import Enum from typing import ( TYPE_CHECKING, Any, @@ -314,6 +315,19 @@ def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType: return pa.list_(pa.list_(tp.value_arrow_type(), tp.dim())) # For regular Vector return pa.list_(tp.value_arrow_type(), tp.dim()) + if _safe_issubclass(tp, Enum): + # Map Enum to the Arrow type of its value. + # For string-valued enums, use dictionary encoding for efficiency. + # For integer enums, use the native type. + # Fall back to utf8 for mixed-type or empty enums. + value_types = {type(m.value) for m in tp} + if len(value_types) == 1: + value_type = value_types.pop() + if value_type is str: + # Use dictionary encoding for string enums + return pa.dictionary(pa.int32(), pa.utf8()) + return _py_type_to_arrow_type(value_type, field) + return pa.utf8() return _py_type_to_arrow_type(tp, field) diff --git a/python/python/tests/test_pydantic.py b/python/python/tests/test_pydantic.py index fd0bb2c64..701bbef5a 100644 --- a/python/python/tests/test_pydantic.py +++ b/python/python/tests/test_pydantic.py @@ -3,6 +3,7 @@ import json from datetime import date, datetime +from enum import Enum from typing import List, Optional, Tuple import pyarrow as pa @@ -673,3 +674,29 @@ async def test_aliases_in_lance_model_async(mem_db_async): assert hasattr(model, "name") assert hasattr(model, "distance") assert model.distance < 0.01 + + +def test_enum_types(): + """Enum fields should map to the Arrow type of their value (issue #1846).""" + + class StrStatus(str, Enum): + PENDING = "pending" + RUNNING = "running" + DONE = "done" + + class IntPriority(int, Enum): + LOW = 1 + MEDIUM = 2 + HIGH = 3 + + class TestModel(pydantic.BaseModel): + status: StrStatus + priority: IntPriority + opt_status: Optional[StrStatus] = None + + schema = pydantic_to_schema(TestModel) + + assert schema.field("status").type == pa.dictionary(pa.int32(), pa.utf8()) + assert schema.field("priority").type == pa.int64() + assert schema.field("opt_status").type == pa.dictionary(pa.int32(), pa.utf8()) + assert schema.field("opt_status").nullable From d082c2d2ac538170a560703f35bc8b51759d7169 Mon Sep 17 00:00:00 2001 From: LanceDB Robot Date: Sun, 5 Apr 2026 10:49:51 +0800 Subject: [PATCH 2/6] chore: update lance dependency to v5.0.0-beta.5 (#3237) ## Summary - update Rust Lance workspace dependencies to `v5.0.0-beta.5` using `ci/set_lance_version.py` - update Java `lance-core` dependency property to `5.0.0-beta.5` - refresh Cargo lockfile to the new Lance tag ## Verification - `cargo clippy --workspace --tests --all-features -- -D warnings` - `cargo fmt --all` ## Upstream Tag - https://github.com/lance-format/lance/releases/tag/v5.0.0-beta.5 --------- Co-authored-by: Jack Ye --- Cargo.lock | 64 +++++++++++++------------- Cargo.toml | 28 +++++------ java/pom.xml | 2 +- python/pyproject.toml | 4 +- python/python/lancedb/_lancedb.pyi | 3 ++ python/python/lancedb/db.py | 55 ++++++++++++++++++++++ python/python/lancedb/namespace.py | 27 +++++++++++ python/python/lancedb/remote/db.py | 14 ++++++ python/python/tests/test_db.py | 57 +++++++++++++++++++++++ python/src/connection.rs | 19 ++++++++ rust/lancedb/src/connection.rs | 10 ++++ rust/lancedb/src/database.rs | 9 ++++ rust/lancedb/src/database/listing.rs | 9 ++++ rust/lancedb/src/database/namespace.rs | 10 ++++ rust/lancedb/src/remote/db.rs | 26 +++++++++++ 15 files changed, 288 insertions(+), 49 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 54e25c4f8..e2cf80fa0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3072,8 +3072,8 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "fsst" -version = "5.0.0-beta.4" -source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac" +version = "5.0.0-beta.5" +source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945" dependencies = [ "arrow-array", "rand 0.9.2", @@ -4134,8 +4134,8 @@ dependencies = [ [[package]] name = "lance" -version = "5.0.0-beta.4" -source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac" +version = "5.0.0-beta.5" +source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945" dependencies = [ "arrow", "arrow-arith", @@ -4201,8 +4201,8 @@ dependencies = [ [[package]] name = "lance-arrow" -version = "5.0.0-beta.4" -source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac" +version = "5.0.0-beta.5" +source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945" dependencies = [ "arrow-array", "arrow-buffer", @@ -4222,8 +4222,8 @@ dependencies = [ [[package]] name = "lance-bitpacking" -version = "5.0.0-beta.4" -source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac" +version = "5.0.0-beta.5" +source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945" dependencies = [ "arrayref", "paste", @@ -4232,8 +4232,8 @@ dependencies = [ [[package]] name = "lance-core" -version = "5.0.0-beta.4" -source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac" +version = "5.0.0-beta.5" +source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945" dependencies = [ "arrow-array", "arrow-buffer", @@ -4270,8 +4270,8 @@ dependencies = [ [[package]] name = "lance-datafusion" -version = "5.0.0-beta.4" -source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac" +version = "5.0.0-beta.5" +source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945" dependencies = [ "arrow", "arrow-array", @@ -4301,8 +4301,8 @@ dependencies = [ [[package]] name = "lance-datagen" -version = "5.0.0-beta.4" -source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac" +version = "5.0.0-beta.5" +source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945" dependencies = [ "arrow", "arrow-array", @@ -4320,8 +4320,8 @@ dependencies = [ [[package]] name = "lance-encoding" -version = "5.0.0-beta.4" -source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac" +version = "5.0.0-beta.5" +source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945" dependencies = [ "arrow-arith", "arrow-array", @@ -4358,8 +4358,8 @@ dependencies = [ [[package]] name = "lance-file" -version = "5.0.0-beta.4" -source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac" +version = "5.0.0-beta.5" +source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945" dependencies = [ "arrow-arith", "arrow-array", @@ -4391,8 +4391,8 @@ dependencies = [ [[package]] name = "lance-index" -version = "5.0.0-beta.4" -source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac" +version = "5.0.0-beta.5" +source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945" dependencies = [ "arrow", "arrow-arith", @@ -4456,8 +4456,8 @@ dependencies = [ [[package]] name = "lance-io" -version = "5.0.0-beta.4" -source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac" +version = "5.0.0-beta.5" +source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945" dependencies = [ "arrow", "arrow-arith", @@ -4501,8 +4501,8 @@ dependencies = [ [[package]] name = "lance-linalg" -version = "5.0.0-beta.4" -source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac" +version = "5.0.0-beta.5" +source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945" dependencies = [ "arrow-array", "arrow-buffer", @@ -4518,8 +4518,8 @@ dependencies = [ [[package]] name = "lance-namespace" -version = "5.0.0-beta.4" -source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac" +version = "5.0.0-beta.5" +source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945" dependencies = [ "arrow", "async-trait", @@ -4532,8 +4532,8 @@ dependencies = [ [[package]] name = "lance-namespace-impls" -version = "5.0.0-beta.4" -source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac" +version = "5.0.0-beta.5" +source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945" dependencies = [ "arrow", "arrow-ipc", @@ -4578,8 +4578,8 @@ dependencies = [ [[package]] name = "lance-table" -version = "5.0.0-beta.4" -source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac" +version = "5.0.0-beta.5" +source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945" dependencies = [ "arrow", "arrow-array", @@ -4618,8 +4618,8 @@ dependencies = [ [[package]] name = "lance-testing" -version = "5.0.0-beta.4" -source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac" +version = "5.0.0-beta.5" +source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945" dependencies = [ "arrow-array", "arrow-schema", diff --git a/Cargo.toml b/Cargo.toml index feef1066d..9bb9ab8a4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,20 +15,20 @@ categories = ["database-implementations"] rust-version = "1.91.0" [workspace.dependencies] -lance = { "version" = "=5.0.0-beta.4", default-features = false, "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" } -lance-core = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" } -lance-datagen = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" } -lance-file = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" } -lance-io = { "version" = "=5.0.0-beta.4", default-features = false, "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" } -lance-index = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" } -lance-linalg = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" } -lance-namespace = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" } -lance-namespace-impls = { "version" = "=5.0.0-beta.4", default-features = false, "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" } -lance-table = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" } -lance-testing = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" } -lance-datafusion = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" } -lance-encoding = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" } -lance-arrow = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" } +lance = { "version" = "=5.0.0-beta.5", default-features = false, "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" } +lance-core = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" } +lance-datagen = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" } +lance-file = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" } +lance-io = { "version" = "=5.0.0-beta.5", default-features = false, "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" } +lance-index = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" } +lance-linalg = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" } +lance-namespace = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" } +lance-namespace-impls = { "version" = "=5.0.0-beta.5", default-features = false, "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" } +lance-table = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" } +lance-testing = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" } +lance-datafusion = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" } +lance-encoding = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" } +lance-arrow = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" } ahash = "0.8" # Note that this one does not include pyarrow arrow = { version = "57.2", optional = false } diff --git a/java/pom.xml b/java/pom.xml index 69d728acc..9aef9d792 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -28,7 +28,7 @@ UTF-8 15.0.0 - 5.0.0-beta.4 + 5.0.0-beta.5 false 2.30.0 1.7 diff --git a/python/pyproject.toml b/python/pyproject.toml index 1f3de8b5d..98dfad32c 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -45,7 +45,7 @@ repository = "https://github.com/lancedb/lancedb" [project.optional-dependencies] pylance = [ - "pylance>=5.0.0b3", + "pylance>=5.0.0b5", ] tests = [ "aiohttp>=3.9.0", @@ -59,7 +59,7 @@ tests = [ "polars>=0.19, <=1.3.0", "tantivy>=0.20.0", "pyarrow-stubs>=16.0", - "pylance>=5.0.0b3", + "pylance>=5.0.0b5", "requests>=2.31.0", "datafusion>=52,<53", ] diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 6a8b51d16..76c08041b 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -151,6 +151,9 @@ class Connection(object): async def drop_all_tables( self, namespace_path: Optional[List[str]] = None ) -> None: ... + async def namespace_client_config( + self, + ) -> Dict[str, Any]: ... class Table: def name(self) -> str: ... diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index bfe7f8d70..869f1481f 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -23,11 +23,13 @@ from lancedb.embeddings.registry import EmbeddingFunctionRegistry from lancedb.common import data_to_reader, sanitize_uri, validate_schema from lancedb.background_loop import LOOP from lance_namespace import ( + LanceNamespace, ListNamespacesResponse, CreateNamespaceResponse, DropNamespaceResponse, DescribeNamespaceResponse, ListTablesResponse, + connect as namespace_connect, ) from . import __version__ @@ -507,6 +509,26 @@ class DBConnection(EnforceOverrides): def uri(self) -> str: return self._uri + def namespace_client(self) -> LanceNamespace: + """Get the equivalent namespace client for this connection. + + For native storage connections, this returns a DirectoryNamespace + pointing to the same root with the same storage options. + + For namespace connections, this returns the backing namespace client. + + For enterprise (remote) connections, this returns a RestNamespace + with the same URI and authentication headers. + + Returns + ------- + LanceNamespace + The namespace client for this connection. + """ + raise NotImplementedError( + "namespace_client is not supported for this connection type" + ) + class LanceDBConnection(DBConnection): """ @@ -1044,6 +1066,20 @@ class LanceDBConnection(DBConnection): ) ) + @override + def namespace_client(self) -> LanceNamespace: + """Get the equivalent namespace client for this connection. + + Returns a DirectoryNamespace pointing to the same root with the + same storage options. + + Returns + ------- + LanceNamespace + The namespace client for this connection. + """ + return LOOP.run(self._conn.namespace_client()) + @deprecation.deprecated( deprecated_in="0.15.1", removed_in="0.17", @@ -1716,6 +1752,25 @@ class AsyncConnection(object): namespace_path = [] await self._inner.drop_all_tables(namespace_path=namespace_path) + async def namespace_client(self) -> LanceNamespace: + """Get the equivalent namespace client for this connection. + + For native storage connections, this returns a DirectoryNamespace + pointing to the same root with the same storage options. + + For namespace connections, this returns the backing namespace client. + + For enterprise (remote) connections, this returns a RestNamespace + with the same URI and authentication headers. + + Returns + ------- + LanceNamespace + The namespace client for this connection. + """ + config = await self._inner.namespace_client_config() + return namespace_connect(config["impl"], config["properties"]) + @deprecation.deprecated( deprecated_in="0.15.1", removed_in="0.17", diff --git a/python/python/lancedb/namespace.py b/python/python/lancedb/namespace.py index a400b2817..55df0c82b 100644 --- a/python/python/lancedb/namespace.py +++ b/python/python/lancedb/namespace.py @@ -890,6 +890,20 @@ class LanceNamespaceDBConnection(DBConnection): pushdown_operations=self._pushdown_operations, ) + @override + def namespace_client(self) -> LanceNamespace: + """Get the namespace client for this connection. + + For namespace connections, this returns the backing namespace client + that was provided during construction. + + Returns + ------- + LanceNamespace + The namespace client for this connection. + """ + return self._namespace_client + class AsyncLanceNamespaceDBConnection: """ @@ -1387,6 +1401,19 @@ class AsyncLanceNamespaceDBConnection: page_token=response.page_token, ) + async def namespace_client(self) -> LanceNamespace: + """Get the namespace client for this connection. + + For namespace connections, this returns the backing namespace client + that was provided during construction. + + Returns + ------- + LanceNamespace + The namespace client for this connection. + """ + return self._namespace_client + def connect_namespace( namespace_client_impl: str, diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index 0afe469b9..e110cdac1 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -24,6 +24,7 @@ from ..common import DATA from ..db import DBConnection, LOOP from ..embeddings import EmbeddingFunctionConfig from lance_namespace import ( + LanceNamespace, CreateNamespaceResponse, DescribeNamespaceResponse, DropNamespaceResponse, @@ -570,6 +571,19 @@ class RemoteDBConnection(DBConnection): ) ) + @override + def namespace_client(self) -> LanceNamespace: + """Get the equivalent namespace client for this connection. + + Returns a RestNamespace with the same URI and authentication headers. + + Returns + ------- + LanceNamespace + The namespace client for this connection. + """ + return LOOP.run(self._conn.namespace_client()) + async def close(self): """Close the connection to the database.""" self._conn.close() diff --git a/python/python/tests/test_db.py b/python/python/tests/test_db.py index 5ad72f8ed..ebb42b61e 100644 --- a/python/python/tests/test_db.py +++ b/python/python/tests/test_db.py @@ -3,6 +3,7 @@ import re +import sys from datetime import timedelta import os @@ -1048,3 +1049,59 @@ def test_clone_table_deep_clone_fails(tmp_path): source_uri = os.path.join(tmp_path, "source.lance") with pytest.raises(Exception, match="Deep clone is not yet implemented"): db.clone_table("cloned", source_uri, is_shallow=False) + + +@pytest.mark.skipif(sys.platform == "win32", reason="Namespace client issues") +def test_namespace_client_native_storage(tmp_path): + """Test namespace_client() returns DirectoryNamespace for native storage.""" + from lance.namespace import DirectoryNamespace + + db = lancedb.connect(tmp_path) + ns_client = db.namespace_client() + + assert isinstance(ns_client, DirectoryNamespace) + assert str(tmp_path) in ns_client.namespace_id() + + +@pytest.mark.skipif(sys.platform == "win32", reason="Namespace client issues") +def test_namespace_client_with_storage_options(tmp_path): + """Test namespace_client() preserves storage options.""" + from lance.namespace import DirectoryNamespace + + storage_options = {"timeout": "10s"} + db = lancedb.connect(tmp_path, storage_options=storage_options) + ns_client = db.namespace_client() + + assert isinstance(ns_client, DirectoryNamespace) + + +@pytest.mark.skipif(sys.platform == "win32", reason="Namespace client issues") +def test_namespace_client_operations(tmp_path): + """Test that namespace_client() returns a functional namespace client.""" + db = lancedb.connect(tmp_path) + ns_client = db.namespace_client() + + # Create a table through the main db connection + data = [{"id": 1, "text": "hello", "vector": [1.0, 2.0]}] + db.create_table("test_table", data=data) + + # Verify the namespace client can see the table + from lance_namespace import ListTablesRequest + + # id=[] means root namespace + response = ns_client.list_tables(ListTablesRequest(id=[])) + # Tables can be strings or objects with name attribute + table_names = [t.name if hasattr(t, "name") else t for t in response.tables] + assert "test_table" in table_names + + +@pytest.mark.skipif(sys.platform == "win32", reason="Namespace client issues") +def test_namespace_client_namespace_connection(tmp_path): + """Test namespace_client() returns the backing client for namespace connections.""" + from lance.namespace import DirectoryNamespace + + db = lancedb.connect_namespace("dir", {"root": str(tmp_path)}) + ns_client = db.namespace_client() + + assert isinstance(ns_client, DirectoryNamespace) + assert str(tmp_path) in ns_client.namespace_id() diff --git a/python/src/connection.rs b/python/src/connection.rs index 1e0bdee21..1db4e7344 100644 --- a/python/src/connection.rs +++ b/python/src/connection.rs @@ -474,6 +474,25 @@ impl Connection { }) }) } + + /// Get the configuration for constructing an equivalent namespace client. + /// Returns a dict with: + /// - "impl": "dir" for DirectoryNamespace, "rest" for RestNamespace + /// - "properties": configuration properties for the namespace + #[pyo3(signature = ())] + pub fn namespace_client_config(self_: PyRef<'_, Self>) -> PyResult> { + let inner = self_.get_inner()?.clone(); + let py = self_.py(); + future_into_py(py, async move { + let (impl_type, properties) = inner.namespace_client_config().await.infer_error()?; + Python::attach(|py| -> PyResult> { + let dict = PyDict::new(py); + dict.set_item("impl", impl_type)?; + dict.set_item("properties", properties)?; + Ok(dict.unbind()) + }) + }) + } } #[pyfunction] diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 169246baf..e89dea37e 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -541,6 +541,16 @@ impl Connection { self.internal.namespace_client().await } + /// Get the configuration for constructing an equivalent namespace client. + /// Returns (impl_type, properties) where: + /// - impl_type: "dir" for DirectoryNamespace, "rest" for RestNamespace + /// - properties: configuration properties for the namespace + pub async fn namespace_client_config( + &self, + ) -> Result<(String, std::collections::HashMap)> { + self.internal.namespace_client_config().await + } + /// List tables with pagination support pub async fn list_tables(&self, request: ListTablesRequest) -> Result { self.internal.list_tables(request).await diff --git a/rust/lancedb/src/database.rs b/rust/lancedb/src/database.rs index 1fee2f295..787221909 100644 --- a/rust/lancedb/src/database.rs +++ b/rust/lancedb/src/database.rs @@ -265,4 +265,13 @@ pub trait Database: /// For ListingDatabase, it is the equivalent DirectoryNamespace. /// For RemoteDatabase, it is the equivalent RestNamespace. async fn namespace_client(&self) -> Result>; + + /// Get the configuration for constructing an equivalent namespace client. + /// Returns (impl_type, properties) where: + /// - impl_type: "dir" for DirectoryNamespace, "rest" for RestNamespace + /// - properties: configuration properties for the namespace + /// + /// This is useful for Python bindings where we want to return a Python + /// namespace object rather than a Rust trait object. + async fn namespace_client_config(&self) -> Result<(String, HashMap)>; } diff --git a/rust/lancedb/src/database/listing.rs b/rust/lancedb/src/database/listing.rs index 9c19fbee4..09b902f4d 100644 --- a/rust/lancedb/src/database/listing.rs +++ b/rust/lancedb/src/database/listing.rs @@ -1099,6 +1099,15 @@ impl Database for ListingDatabase { })?; Ok(Arc::new(namespace) as Arc) } + + async fn namespace_client_config(&self) -> Result<(String, HashMap)> { + let mut properties = HashMap::new(); + properties.insert("root".to_string(), self.uri.clone()); + for (key, value) in &self.storage_options { + properties.insert(format!("storage.{}", key), value.clone()); + } + Ok(("dir".to_string(), properties)) + } } #[cfg(test)] diff --git a/rust/lancedb/src/database/namespace.rs b/rust/lancedb/src/database/namespace.rs index 2e381aaea..03bc434e6 100644 --- a/rust/lancedb/src/database/namespace.rs +++ b/rust/lancedb/src/database/namespace.rs @@ -45,6 +45,10 @@ pub struct LanceNamespaceDatabase { uri: String, // Operations to push down to the namespace server pushdown_operations: HashSet, + // Namespace implementation type (e.g., "dir", "rest") + ns_impl: String, + // Namespace properties used to construct the namespace client + ns_properties: HashMap, } impl LanceNamespaceDatabase { @@ -74,6 +78,8 @@ impl LanceNamespaceDatabase { session, uri: format!("namespace://{}", ns_impl), pushdown_operations, + ns_impl: ns_impl.to_string(), + ns_properties, }) } } @@ -345,6 +351,10 @@ impl Database for LanceNamespaceDatabase { async fn namespace_client(&self) -> Result> { Ok(self.namespace.clone()) } + + async fn namespace_client_config(&self) -> Result<(String, HashMap)> { + Ok((self.ns_impl.clone(), self.ns_properties.clone())) + } } #[cfg(test)] diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index 553ce5599..be50b6859 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -777,6 +777,32 @@ impl Database for RemoteDatabase { let namespace = builder.build(); Ok(Arc::new(namespace) as Arc) } + + async fn namespace_client_config(&self) -> Result<(String, HashMap)> { + let mut properties = HashMap::new(); + properties.insert("uri".to_string(), self.client.host().to_string()); + properties.insert("delimiter".to_string(), self.client.id_delimiter.clone()); + for (key, value) in &self.namespace_headers { + properties.insert(format!("header.{}", key), value.clone()); + } + // Add TLS configuration if present + if let Some(tls_config) = &self.tls_config { + if let Some(cert_file) = &tls_config.cert_file { + properties.insert("tls.cert_file".to_string(), cert_file.clone()); + } + if let Some(key_file) = &tls_config.key_file { + properties.insert("tls.key_file".to_string(), key_file.clone()); + } + if let Some(ssl_ca_cert) = &tls_config.ssl_ca_cert { + properties.insert("tls.ssl_ca_cert".to_string(), ssl_ca_cert.clone()); + } + properties.insert( + "tls.assert_hostname".to_string(), + tls_config.assert_hostname.to_string(), + ); + } + Ok(("rest".to_string(), properties)) + } } /// RemoteOptions contains a subset of StorageOptions that are compatible with Remote LanceDB connections From 0ac59de5f19503c2063e0e40100a38e4a6ced199 Mon Sep 17 00:00:00 2001 From: Lance Release Date: Sun, 5 Apr 2026 02:50:52 +0000 Subject: [PATCH 3/6] =?UTF-8?q?Bump=20version:=200.31.0-beta.0=20=E2=86=92?= =?UTF-8?q?=200.31.0-beta.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/.bumpversion.toml | 2 +- python/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/.bumpversion.toml b/python/.bumpversion.toml index bee6c3e66..7c7821994 100644 --- a/python/.bumpversion.toml +++ b/python/.bumpversion.toml @@ -1,5 +1,5 @@ [tool.bumpversion] -current_version = "0.31.0-beta.0" +current_version = "0.31.0-beta.1" parse = """(?x) (?P0|[1-9]\\d*)\\. (?P0|[1-9]\\d*)\\. diff --git a/python/Cargo.toml b/python/Cargo.toml index 047853633..cbbdedf20 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lancedb-python" -version = "0.31.0-beta.0" +version = "0.31.0-beta.1" edition.workspace = true description = "Python bindings for LanceDB" license.workspace = true From de3f8097e78d38ca9c7c6879f1366c768dcf3343 Mon Sep 17 00:00:00 2001 From: Lance Release Date: Sun, 5 Apr 2026 02:51:09 +0000 Subject: [PATCH 4/6] =?UTF-8?q?Bump=20version:=200.28.0-beta.0=20=E2=86=92?= =?UTF-8?q?=200.28.0-beta.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.toml | 2 +- Cargo.lock | 6 +++--- docs/src/java/java.md | 2 +- java/lancedb-core/pom.xml | 2 +- java/pom.xml | 2 +- nodejs/Cargo.toml | 2 +- nodejs/npm/darwin-arm64/package.json | 2 +- nodejs/npm/linux-arm64-gnu/package.json | 2 +- nodejs/npm/linux-arm64-musl/package.json | 2 +- nodejs/npm/linux-x64-gnu/package.json | 2 +- nodejs/npm/linux-x64-musl/package.json | 2 +- nodejs/npm/win32-arm64-msvc/package.json | 2 +- nodejs/npm/win32-x64-msvc/package.json | 2 +- nodejs/package-lock.json | 4 ++-- nodejs/package.json | 2 +- rust/lancedb/Cargo.toml | 2 +- 16 files changed, 19 insertions(+), 19 deletions(-) diff --git a/.bumpversion.toml b/.bumpversion.toml index f649bcfc3..ee1693420 100644 --- a/.bumpversion.toml +++ b/.bumpversion.toml @@ -1,5 +1,5 @@ [tool.bumpversion] -current_version = "0.28.0-beta.0" +current_version = "0.28.0-beta.1" parse = """(?x) (?P0|[1-9]\\d*)\\. (?P0|[1-9]\\d*)\\. diff --git a/Cargo.lock b/Cargo.lock index e2cf80fa0..8ed07c153 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4630,7 +4630,7 @@ dependencies = [ [[package]] name = "lancedb" -version = "0.28.0-beta.0" +version = "0.28.0-beta.1" dependencies = [ "ahash", "anyhow", @@ -4712,7 +4712,7 @@ dependencies = [ [[package]] name = "lancedb-nodejs" -version = "0.28.0-beta.0" +version = "0.28.0-beta.1" dependencies = [ "arrow-array", "arrow-buffer", @@ -4734,7 +4734,7 @@ dependencies = [ [[package]] name = "lancedb-python" -version = "0.31.0-beta.0" +version = "0.31.0-beta.1" dependencies = [ "arrow", "async-trait", diff --git a/docs/src/java/java.md b/docs/src/java/java.md index ac0a15806..fb63a4b33 100644 --- a/docs/src/java/java.md +++ b/docs/src/java/java.md @@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`: com.lancedb lancedb-core - 0.28.0-beta.0 + 0.28.0-beta.1 ``` diff --git a/java/lancedb-core/pom.xml b/java/lancedb-core/pom.xml index 276036ca1..77cd36609 100644 --- a/java/lancedb-core/pom.xml +++ b/java/lancedb-core/pom.xml @@ -8,7 +8,7 @@ com.lancedb lancedb-parent - 0.28.0-beta.0 + 0.28.0-beta.1 ../pom.xml diff --git a/java/pom.xml b/java/pom.xml index 9aef9d792..e2da10a82 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -6,7 +6,7 @@ com.lancedb lancedb-parent - 0.28.0-beta.0 + 0.28.0-beta.1 pom ${project.artifactId} LanceDB Java SDK Parent POM diff --git a/nodejs/Cargo.toml b/nodejs/Cargo.toml index 5b3be8cbd..fb1971516 100644 --- a/nodejs/Cargo.toml +++ b/nodejs/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "lancedb-nodejs" edition.workspace = true -version = "0.28.0-beta.0" +version = "0.28.0-beta.1" license.workspace = true description.workspace = true repository.workspace = true diff --git a/nodejs/npm/darwin-arm64/package.json b/nodejs/npm/darwin-arm64/package.json index d67248af8..dac329961 100644 --- a/nodejs/npm/darwin-arm64/package.json +++ b/nodejs/npm/darwin-arm64/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-darwin-arm64", - "version": "0.28.0-beta.0", + "version": "0.28.0-beta.1", "os": ["darwin"], "cpu": ["arm64"], "main": "lancedb.darwin-arm64.node", diff --git a/nodejs/npm/linux-arm64-gnu/package.json b/nodejs/npm/linux-arm64-gnu/package.json index b55edd04b..3812b0032 100644 --- a/nodejs/npm/linux-arm64-gnu/package.json +++ b/nodejs/npm/linux-arm64-gnu/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-linux-arm64-gnu", - "version": "0.28.0-beta.0", + "version": "0.28.0-beta.1", "os": ["linux"], "cpu": ["arm64"], "main": "lancedb.linux-arm64-gnu.node", diff --git a/nodejs/npm/linux-arm64-musl/package.json b/nodejs/npm/linux-arm64-musl/package.json index 23bd06975..07888c548 100644 --- a/nodejs/npm/linux-arm64-musl/package.json +++ b/nodejs/npm/linux-arm64-musl/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-linux-arm64-musl", - "version": "0.28.0-beta.0", + "version": "0.28.0-beta.1", "os": ["linux"], "cpu": ["arm64"], "main": "lancedb.linux-arm64-musl.node", diff --git a/nodejs/npm/linux-x64-gnu/package.json b/nodejs/npm/linux-x64-gnu/package.json index 51bb6ab0e..52bb5b876 100644 --- a/nodejs/npm/linux-x64-gnu/package.json +++ b/nodejs/npm/linux-x64-gnu/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-linux-x64-gnu", - "version": "0.28.0-beta.0", + "version": "0.28.0-beta.1", "os": ["linux"], "cpu": ["x64"], "main": "lancedb.linux-x64-gnu.node", diff --git a/nodejs/npm/linux-x64-musl/package.json b/nodejs/npm/linux-x64-musl/package.json index c493c9bca..e7e691f11 100644 --- a/nodejs/npm/linux-x64-musl/package.json +++ b/nodejs/npm/linux-x64-musl/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-linux-x64-musl", - "version": "0.28.0-beta.0", + "version": "0.28.0-beta.1", "os": ["linux"], "cpu": ["x64"], "main": "lancedb.linux-x64-musl.node", diff --git a/nodejs/npm/win32-arm64-msvc/package.json b/nodejs/npm/win32-arm64-msvc/package.json index 35625d568..015c4127e 100644 --- a/nodejs/npm/win32-arm64-msvc/package.json +++ b/nodejs/npm/win32-arm64-msvc/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-win32-arm64-msvc", - "version": "0.28.0-beta.0", + "version": "0.28.0-beta.1", "os": [ "win32" ], diff --git a/nodejs/npm/win32-x64-msvc/package.json b/nodejs/npm/win32-x64-msvc/package.json index 8bf356780..8a08b0182 100644 --- a/nodejs/npm/win32-x64-msvc/package.json +++ b/nodejs/npm/win32-x64-msvc/package.json @@ -1,6 +1,6 @@ { "name": "@lancedb/lancedb-win32-x64-msvc", - "version": "0.28.0-beta.0", + "version": "0.28.0-beta.1", "os": ["win32"], "cpu": ["x64"], "main": "lancedb.win32-x64-msvc.node", diff --git a/nodejs/package-lock.json b/nodejs/package-lock.json index a2247faa1..211c66806 100644 --- a/nodejs/package-lock.json +++ b/nodejs/package-lock.json @@ -1,12 +1,12 @@ { "name": "@lancedb/lancedb", - "version": "0.28.0-beta.0", + "version": "0.28.0-beta.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@lancedb/lancedb", - "version": "0.28.0-beta.0", + "version": "0.28.0-beta.1", "cpu": [ "x64", "arm64" diff --git a/nodejs/package.json b/nodejs/package.json index 8e52e7aa1..7974453a2 100644 --- a/nodejs/package.json +++ b/nodejs/package.json @@ -11,7 +11,7 @@ "ann" ], "private": false, - "version": "0.28.0-beta.0", + "version": "0.28.0-beta.1", "main": "dist/index.js", "exports": { ".": "./dist/index.js", diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index a0c11fd77..6b6f41104 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lancedb" -version = "0.28.0-beta.0" +version = "0.28.0-beta.1" edition.workspace = true description = "LanceDB: A serverless, low-latency vector database for AI applications" license.workspace = true From a898dc81c22e55c247f67895e857ea482305b7fb Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Mon, 6 Apr 2026 11:20:10 -0700 Subject: [PATCH 5/6] feat: add user_id field to ClientConfig for user identification (#3240) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Add a `user_id` field to `ClientConfig` that allows users to identify themselves to LanceDB Cloud/Enterprise - The user_id is sent as the `x-lancedb-user-id` HTTP header in all requests - Supports three configuration methods: - Direct assignment via `ClientConfig.user_id` - Environment variable `LANCEDB_USER_ID` - Indirect env var lookup via `LANCEDB_USER_ID_ENV_KEY` Closes #3230 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.5 --- docs/src/js/interfaces/ClientConfig.md | 15 +++ nodejs/src/remote.rs | 8 ++ python/python/lancedb/remote/__init__.py | 28 +++++ python/src/connection.rs | 2 + rust/lancedb/src/remote/client.rs | 135 +++++++++++++++++++++++ 5 files changed, 188 insertions(+) diff --git a/docs/src/js/interfaces/ClientConfig.md b/docs/src/js/interfaces/ClientConfig.md index c09764cba..94cee5b5e 100644 --- a/docs/src/js/interfaces/ClientConfig.md +++ b/docs/src/js/interfaces/ClientConfig.md @@ -53,3 +53,18 @@ optional tlsConfig: TlsConfig; ```ts optional userAgent: string; ``` + +*** + +### userId? + +```ts +optional userId: string; +``` + +User identifier for tracking purposes. + +This is sent as the `x-lancedb-user-id` header in requests to LanceDB Cloud/Enterprise. +It can be set directly, or via the `LANCEDB_USER_ID` environment variable. +Alternatively, set `LANCEDB_USER_ID_ENV_KEY` to specify another environment +variable that contains the user ID value. diff --git a/nodejs/src/remote.rs b/nodejs/src/remote.rs index 04602c49f..8cfcbc984 100644 --- a/nodejs/src/remote.rs +++ b/nodejs/src/remote.rs @@ -92,6 +92,13 @@ pub struct ClientConfig { pub extra_headers: Option>, pub id_delimiter: Option, pub tls_config: Option, + /// User identifier for tracking purposes. + /// + /// This is sent as the `x-lancedb-user-id` header in requests to LanceDB Cloud/Enterprise. + /// It can be set directly, or via the `LANCEDB_USER_ID` environment variable. + /// Alternatively, set `LANCEDB_USER_ID_ENV_KEY` to specify another environment + /// variable that contains the user ID value. + pub user_id: Option, } impl From for lancedb::remote::TimeoutConfig { @@ -145,6 +152,7 @@ impl From for lancedb::remote::ClientConfig { id_delimiter: config.id_delimiter, tls_config: config.tls_config.map(Into::into), header_provider: None, // the header provider is set separately later + user_id: config.user_id, } } } diff --git a/python/python/lancedb/remote/__init__.py b/python/python/lancedb/remote/__init__.py index 585c25a94..289e28942 100644 --- a/python/python/lancedb/remote/__init__.py +++ b/python/python/lancedb/remote/__init__.py @@ -145,6 +145,33 @@ class TlsConfig: @dataclass class ClientConfig: + """Configuration for the LanceDB Cloud HTTP client. + + Attributes + ---------- + user_agent: str + User agent string sent with requests. + retry_config: RetryConfig + Configuration for retrying failed requests. + timeout_config: Optional[TimeoutConfig] + Configuration for request timeouts. + extra_headers: Optional[dict] + Additional headers to include in requests. + id_delimiter: Optional[str] + The delimiter to use when constructing object identifiers. + tls_config: Optional[TlsConfig] + TLS/mTLS configuration for secure connections. + header_provider: Optional[HeaderProvider] + Provider for dynamic headers to be added to each request. + user_id: Optional[str] + User identifier for tracking purposes. This is sent as the + `x-lancedb-user-id` header in requests to LanceDB Cloud/Enterprise. + + This can also be set via the `LANCEDB_USER_ID` environment variable. + Alternatively, set `LANCEDB_USER_ID_ENV_KEY` to specify another + environment variable that contains the user ID value. + """ + user_agent: str = f"LanceDB-Python-Client/{__version__}" retry_config: RetryConfig = field(default_factory=RetryConfig) timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig) @@ -152,6 +179,7 @@ class ClientConfig: id_delimiter: Optional[str] = None tls_config: Optional[TlsConfig] = None header_provider: Optional["HeaderProvider"] = None + user_id: Optional[str] = None def __post_init__(self): if isinstance(self.retry_config, dict): diff --git a/python/src/connection.rs b/python/src/connection.rs index 1db4e7344..f19bfba97 100644 --- a/python/src/connection.rs +++ b/python/src/connection.rs @@ -547,6 +547,7 @@ pub struct PyClientConfig { id_delimiter: Option, tls_config: Option, header_provider: Option>, + user_id: Option, } #[derive(FromPyObject)] @@ -631,6 +632,7 @@ impl From for lancedb::remote::ClientConfig { id_delimiter: value.id_delimiter, tls_config: value.tls_config.map(Into::into), header_provider, + user_id: value.user_id, } } } diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs index 9549c5b44..b50ca2206 100644 --- a/rust/lancedb/src/remote/client.rs +++ b/rust/lancedb/src/remote/client.rs @@ -52,6 +52,13 @@ pub struct ClientConfig { pub tls_config: Option, /// Provider for custom headers to be added to each request pub header_provider: Option>, + /// User identifier for tracking purposes. + /// + /// This is sent as the `x-lancedb-user-id` header in requests to LanceDB Cloud/Enterprise. + /// It can be set directly, or via the `LANCEDB_USER_ID` environment variable. + /// Alternatively, set `LANCEDB_USER_ID_ENV_KEY` to specify another environment + /// variable that contains the user ID value. + pub user_id: Option, } impl std::fmt::Debug for ClientConfig { @@ -67,6 +74,7 @@ impl std::fmt::Debug for ClientConfig { "header_provider", &self.header_provider.as_ref().map(|_| "Some(...)"), ) + .field("user_id", &self.user_id) .finish() } } @@ -81,10 +89,41 @@ impl Default for ClientConfig { id_delimiter: None, tls_config: None, header_provider: None, + user_id: None, } } } +impl ClientConfig { + /// Resolve the user ID from the config or environment variables. + /// + /// Resolution order: + /// 1. If `user_id` is set in the config, use that value + /// 2. If `LANCEDB_USER_ID` environment variable is set, use that value + /// 3. If `LANCEDB_USER_ID_ENV_KEY` is set, read the env var it points to + /// 4. Otherwise, return None + pub fn resolve_user_id(&self) -> Option { + if self.user_id.is_some() { + return self.user_id.clone(); + } + + if let Ok(user_id) = std::env::var("LANCEDB_USER_ID") + && !user_id.is_empty() + { + return Some(user_id); + } + + if let Ok(env_key) = std::env::var("LANCEDB_USER_ID_ENV_KEY") + && let Ok(user_id) = std::env::var(&env_key) + && !user_id.is_empty() + { + return Some(user_id); + } + + None + } +} + /// How to handle timeouts for HTTP requests. #[derive(Clone, Default, Debug)] pub struct TimeoutConfig { @@ -464,6 +503,15 @@ impl RestfulLanceDbClient { ); } + if let Some(user_id) = config.resolve_user_id() { + headers.insert( + HeaderName::from_static("x-lancedb-user-id"), + HeaderValue::from_str(&user_id).map_err(|_| Error::InvalidInput { + message: format!("non-ascii user_id '{}' provided", user_id), + })?, + ); + } + Ok(headers) } @@ -1072,4 +1120,91 @@ mod tests { _ => panic!("Expected Runtime error"), } } + + #[test] + fn test_resolve_user_id_direct_value() { + let config = ClientConfig { + user_id: Some("direct-user-id".to_string()), + ..Default::default() + }; + assert_eq!(config.resolve_user_id(), Some("direct-user-id".to_string())); + } + + #[test] + fn test_resolve_user_id_none() { + let config = ClientConfig::default(); + // Clear env vars that might be set from other tests + // SAFETY: This is only called in tests + unsafe { + std::env::remove_var("LANCEDB_USER_ID"); + std::env::remove_var("LANCEDB_USER_ID_ENV_KEY"); + } + assert_eq!(config.resolve_user_id(), None); + } + + #[test] + fn test_resolve_user_id_from_env() { + // SAFETY: This is only called in tests + unsafe { + std::env::set_var("LANCEDB_USER_ID", "env-user-id"); + } + let config = ClientConfig::default(); + assert_eq!(config.resolve_user_id(), Some("env-user-id".to_string())); + // SAFETY: This is only called in tests + unsafe { + std::env::remove_var("LANCEDB_USER_ID"); + } + } + + #[test] + fn test_resolve_user_id_from_env_key() { + // SAFETY: This is only called in tests + unsafe { + std::env::remove_var("LANCEDB_USER_ID"); + std::env::set_var("LANCEDB_USER_ID_ENV_KEY", "MY_CUSTOM_USER_ID"); + std::env::set_var("MY_CUSTOM_USER_ID", "custom-env-user-id"); + } + let config = ClientConfig::default(); + assert_eq!( + config.resolve_user_id(), + Some("custom-env-user-id".to_string()) + ); + // SAFETY: This is only called in tests + unsafe { + std::env::remove_var("LANCEDB_USER_ID_ENV_KEY"); + std::env::remove_var("MY_CUSTOM_USER_ID"); + } + } + + #[test] + fn test_resolve_user_id_direct_takes_precedence() { + // SAFETY: This is only called in tests + unsafe { + std::env::set_var("LANCEDB_USER_ID", "env-user-id"); + } + let config = ClientConfig { + user_id: Some("direct-user-id".to_string()), + ..Default::default() + }; + assert_eq!(config.resolve_user_id(), Some("direct-user-id".to_string())); + // SAFETY: This is only called in tests + unsafe { + std::env::remove_var("LANCEDB_USER_ID"); + } + } + + #[test] + fn test_resolve_user_id_empty_env_ignored() { + // SAFETY: This is only called in tests + unsafe { + std::env::set_var("LANCEDB_USER_ID", ""); + std::env::remove_var("LANCEDB_USER_ID_ENV_KEY"); + } + let config = ClientConfig::default(); + assert_eq!(config.resolve_user_id(), None); + // SAFETY: This is only called in tests + unsafe { + std::env::remove_var("LANCEDB_USER_ID"); + } + } } From a813ce2f71bd7f4d8b2783dc75fcf48aa21db033 Mon Sep 17 00:00:00 2001 From: yaommen Date: Thu, 9 Apr 2026 00:09:41 +0800 Subject: [PATCH 6/6] fix(python): sanitize bad vectors before Arrow cast (#3158) ## Problem `on_bad_vectors="drop"` is supposed to remove invalid vector rows before write, but for some schema-defined vector columns it can still fail later during Arrow cast instead of dropping the bad row. Repro: ```python class MySchema(LanceModel): text: str embedding: Vector(16) table = db.create_table("test", schema=MySchema) table.add( [ {"text": "hello", "embedding": []}, {"text": "bar", "embedding": [0.1] * 16}, ], on_bad_vectors="drop", ) ``` Before: ``` RuntimeError Arrow error: C Data interface error: Invalid: ListType can only be casted to FixedSizeListType if the lists are all the expected size. ``` After: ``` rows 1 texts ['bar'] ``` ## Solution Make bad-vector sanitization use schema dimensions before cast, while keeping the handling scoped to vector columns identified by schema metadata or existing vector-name heuristics. This also preserves existing integer vector inputs and avoids applying on_bad_vectors to unrelated fixed-size float columns. Fixes #1670 Signed-off-by: yaommen --- python/python/lancedb/table.py | 321 ++++++++++++++++++++++++------ python/python/tests/test_table.py | 225 +++++++++++++++++++++ python/python/tests/test_util.py | 115 ++++++++++- 3 files changed, 597 insertions(+), 64 deletions(-) diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index ccd40b4a7..45f76bdc9 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -270,15 +270,17 @@ def _sanitize_data( reader, on_bad_vectors=on_bad_vectors, fill_value=fill_value, + target_schema=target_schema, + metadata=metadata, ) if target_schema is None: target_schema, reader = _infer_target_schema(reader) if metadata: - new_metadata = target_schema.metadata or {} - new_metadata.update(metadata) - target_schema = target_schema.with_metadata(new_metadata) + target_schema = target_schema.with_metadata( + _merge_metadata(target_schema.metadata, metadata) + ) _validate_schema(target_schema) reader = _cast_to_target_schema(reader, target_schema, allow_subschema) @@ -294,7 +296,7 @@ def _cast_to_target_schema( # pa.Table.cast expects field order not to be changed. # Lance doesn't care about field order, so we don't need to rearrange fields # to match the target schema. We just need to correctly cast the fields. - if reader.schema == target_schema: + if reader.schema.equals(target_schema, check_metadata=True): # Fast path when the schemas are already the same return reader @@ -314,7 +316,13 @@ def _cast_to_target_schema( def gen(): for batch in reader: # Table but not RecordBatch has cast. - yield pa.Table.from_batches([batch]).cast(reordered_schema).to_batches()[0] + cast_batches = ( + pa.Table.from_batches([batch]).cast(reordered_schema).to_batches() + ) + if cast_batches: + yield pa.RecordBatch.from_arrays( + cast_batches[0].columns, schema=reordered_schema + ) return pa.RecordBatchReader.from_batches(reordered_schema, gen()) @@ -332,37 +340,51 @@ def _align_field_types( if target_field is None: raise ValueError(f"Field '{field.name}' not found in target schema") if pa.types.is_struct(target_field.type): - new_type = pa.struct( - _align_field_types( - field.type.fields, - target_field.type.fields, + if pa.types.is_struct(field.type): + new_type = pa.struct( + _align_field_types( + field.type.fields, + target_field.type.fields, + ) ) - ) + else: + new_type = target_field.type elif pa.types.is_list(target_field.type): - new_type = pa.list_( - _align_field_types( - [field.type.value_field], - [target_field.type.value_field], - )[0] - ) + if _is_list_like(field.type): + new_type = pa.list_( + _align_field_types( + [field.type.value_field], + [target_field.type.value_field], + )[0] + ) + else: + new_type = target_field.type elif pa.types.is_large_list(target_field.type): - new_type = pa.large_list( - _align_field_types( - [field.type.value_field], - [target_field.type.value_field], - )[0] - ) + if _is_list_like(field.type): + new_type = pa.large_list( + _align_field_types( + [field.type.value_field], + [target_field.type.value_field], + )[0] + ) + else: + new_type = target_field.type elif pa.types.is_fixed_size_list(target_field.type): - new_type = pa.list_( - _align_field_types( - [field.type.value_field], - [target_field.type.value_field], - )[0], - target_field.type.list_size, - ) + if _is_list_like(field.type): + new_type = pa.list_( + _align_field_types( + [field.type.value_field], + [target_field.type.value_field], + )[0], + target_field.type.list_size, + ) + else: + new_type = target_field.type else: new_type = target_field.type - new_fields.append(pa.field(field.name, new_type, field.nullable)) + new_fields.append( + pa.field(field.name, new_type, field.nullable, target_field.metadata) + ) return new_fields @@ -440,6 +462,7 @@ def sanitize_create_table( schema = data.schema if metadata: + metadata = _merge_metadata(schema.metadata, metadata) schema = schema.with_metadata(metadata) # Need to apply metadata to the data as well if isinstance(data, pa.Table): @@ -492,9 +515,9 @@ def _append_vector_columns( vector columns to the table. """ if schema is None: - metadata = metadata or {} + metadata = _merge_metadata(metadata) else: - metadata = schema.metadata or metadata or {} + metadata = _merge_metadata(schema.metadata, metadata) functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata) if not functions: @@ -3211,43 +3234,157 @@ def _handle_bad_vectors( reader: pa.RecordBatchReader, on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error", fill_value: float = 0.0, + target_schema: Optional[pa.Schema] = None, + metadata: Optional[dict] = None, ) -> pa.RecordBatchReader: - vector_columns = [] + vector_columns = _find_vector_columns(reader.schema, target_schema, metadata) + if not vector_columns: + return reader - for field in reader.schema: - # They can provide a 'vector' column that isn't yet a FSL - named_vector_col = ( - ( - pa.types.is_list(field.type) - or pa.types.is_large_list(field.type) - or pa.types.is_fixed_size_list(field.type) - ) - and pa.types.is_floating(field.type.value_type) - and field.name == VECTOR_COLUMN_NAME - ) - # TODO: we're making an assumption that fixed size list of 10 or more - # is a vector column. This is definitely a bit hacky. - likely_vector_col = ( - pa.types.is_fixed_size_list(field.type) - and pa.types.is_floating(field.type.value_type) - and (field.type.list_size >= 10) - ) - - if named_vector_col or likely_vector_col: - vector_columns.append(field.name) + output_schema = _vector_output_schema(reader.schema, vector_columns) def gen(): for batch in reader: - for name in vector_columns: + pending_dims = [] + for vector_column in vector_columns: + dim = vector_column["expected_dim"] + if target_schema is not None and dim is None: + dim = _infer_vector_dim(batch[vector_column["name"]]) + pending_dims.append(vector_column) batch = _handle_bad_vector_column( batch, - vector_column_name=name, + vector_column_name=vector_column["name"], on_bad_vectors=on_bad_vectors, fill_value=fill_value, + expected_dim=dim, + expected_value_type=vector_column["expected_value_type"], ) - yield batch + for vector_column in pending_dims: + if vector_column["expected_dim"] is None: + vector_column["expected_dim"] = _infer_vector_dim( + batch[vector_column["name"]] + ) + if batch.schema.equals(output_schema, check_metadata=True): + yield batch + continue - return pa.RecordBatchReader.from_batches(reader.schema, gen()) + cast_batches = ( + pa.Table.from_batches([batch]).cast(output_schema).to_batches() + ) + if cast_batches: + yield pa.RecordBatch.from_arrays( + cast_batches[0].columns, + schema=output_schema, + ) + + return pa.RecordBatchReader.from_batches(output_schema, gen()) + + +def _find_vector_columns( + reader_schema: pa.Schema, + target_schema: Optional[pa.Schema], + metadata: Optional[dict], +) -> List[dict]: + if target_schema is None: + vector_columns = [] + for field in reader_schema: + named_vector_col = ( + _is_list_like(field.type) + and pa.types.is_floating(field.type.value_type) + and field.name == VECTOR_COLUMN_NAME + ) + likely_vector_col = ( + pa.types.is_fixed_size_list(field.type) + and pa.types.is_floating(field.type.value_type) + and (field.type.list_size >= 10) + ) + if named_vector_col or likely_vector_col: + vector_columns.append( + { + "name": field.name, + "expected_dim": None, + "expected_value_type": None, + } + ) + return vector_columns + + reader_column_names = set(reader_schema.names) + active_metadata = _merge_metadata(target_schema.metadata, metadata) + embedding_function_columns = set( + EmbeddingFunctionRegistry.get_instance().parse_functions(active_metadata).keys() + ) + vector_columns = [] + for field in target_schema: + if field.name not in reader_column_names: + continue + if not _is_list_like(field.type) or not pa.types.is_floating( + field.type.value_type + ): + continue + + reader_field = reader_schema.field(field.name) + named_vector_col = ( + field.name in embedding_function_columns + or field.name == VECTOR_COLUMN_NAME + or (field.name == "embedding" and pa.types.is_fixed_size_list(field.type)) + ) + typed_fixed_vector_col = ( + pa.types.is_fixed_size_list(reader_field.type) + and pa.types.is_floating(reader_field.type.value_type) + and reader_field.type.list_size >= 10 + ) + + if named_vector_col or typed_fixed_vector_col: + vector_columns.append( + { + "name": field.name, + "expected_dim": ( + field.type.list_size + if pa.types.is_fixed_size_list(field.type) + else None + ), + "expected_value_type": field.type.value_type, + } + ) + + return vector_columns + + +def _vector_output_schema( + reader_schema: pa.Schema, + vector_columns: List[dict], +) -> pa.Schema: + columns_by_name = {column["name"]: column for column in vector_columns} + fields = [] + for field in reader_schema: + column = columns_by_name.get(field.name) + if column is None: + output_type = field.type + else: + output_type = _vector_output_type(field, column) + fields.append(pa.field(field.name, output_type, field.nullable, field.metadata)) + return pa.schema(fields, metadata=reader_schema.metadata) + + +def _vector_output_type(field: pa.Field, vector_column: dict) -> pa.DataType: + if not _is_list_like(field.type): + return field.type + + if vector_column["expected_value_type"] is not None and ( + pa.types.is_null(field.type.value_type) + or pa.types.is_integer(field.type.value_type) + or pa.types.is_unsigned_integer(field.type.value_type) + ): + return pa.list_(vector_column["expected_value_type"]) + + if ( + vector_column["expected_dim"] is not None + and pa.types.is_fixed_size_list(field.type) + and field.type.list_size != vector_column["expected_dim"] + ): + return pa.list_(field.type.value_type) + + return field.type def _handle_bad_vector_column( @@ -3255,6 +3392,8 @@ def _handle_bad_vector_column( vector_column_name: str, on_bad_vectors: str = "error", fill_value: float = 0.0, + expected_dim: Optional[int] = None, + expected_value_type: Optional[pa.DataType] = None, ) -> pa.RecordBatch: """ Ensure that the vector column exists and has type fixed_size_list(float) @@ -3271,14 +3410,39 @@ def _handle_bad_vector_column( fill_value: float, default 0.0 The value to use when filling vectors. Only used if on_bad_vectors="fill". """ + position = data.column_names.index(vector_column_name) vec_arr = data[vector_column_name] + if not _is_list_like(vec_arr.type): + return data - has_nan = has_nan_values(vec_arr) + if ( + expected_dim is not None + and pa.types.is_fixed_size_list(vec_arr.type) + and vec_arr.type.list_size != expected_dim + ): + vec_arr = pa.array(vec_arr.to_pylist(), type=pa.list_(vec_arr.type.value_type)) + data = data.set_column(position, vector_column_name, vec_arr) - if pa.types.is_fixed_size_list(vec_arr.type): + if expected_value_type is not None and ( + pa.types.is_integer(vec_arr.type.value_type) + or pa.types.is_unsigned_integer(vec_arr.type.value_type) + ): + vec_arr = pa.array(vec_arr.to_pylist(), type=pa.list_(expected_value_type)) + data = data.set_column(position, vector_column_name, vec_arr) + + if pa.types.is_floating(vec_arr.type.value_type): + has_nan = has_nan_values(vec_arr) + else: + has_nan = pa.array([False] * len(vec_arr)) + + if expected_dim is not None: + dim = expected_dim + elif pa.types.is_fixed_size_list(vec_arr.type): dim = vec_arr.type.list_size else: - dim = _modal_list_size(vec_arr) + dim = _infer_vector_dim(vec_arr) + if dim is None: + return data has_wrong_dim = pc.not_equal(pc.list_value_length(vec_arr), dim) has_bad_vectors = pc.any(has_nan).as_py() or pc.any(has_wrong_dim).as_py() @@ -3316,13 +3480,12 @@ def _handle_bad_vector_column( ) vec_arr = pc.if_else( is_bad, - pa.scalar([fill_value] * dim), + pa.scalar([fill_value] * dim, type=vec_arr.type), vec_arr, ) else: raise ValueError(f"Invalid value for on_bad_vectors: {on_bad_vectors}") - position = data.column_names.index(vector_column_name) return data.set_column(position, vector_column_name, vec_arr) @@ -3343,6 +3506,28 @@ def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray return pc.is_in(indices, has_nan_indices) +def _is_list_like(data_type: pa.DataType) -> bool: + return ( + pa.types.is_list(data_type) + or pa.types.is_large_list(data_type) + or pa.types.is_fixed_size_list(data_type) + ) + + +def _merge_metadata(*metadata_dicts: Optional[dict]) -> dict: + merged = {} + for metadata in metadata_dicts: + if metadata is None: + continue + for key, value in metadata.items(): + if isinstance(key, str): + key = key.encode("utf-8") + if isinstance(value, str): + value = value.encode("utf-8") + merged[key] = value + return merged + + def _name_suggests_vector_column(field_name: str) -> bool: """Check if a field name indicates a vector column.""" name_lower = field_name.lower() @@ -3410,6 +3595,16 @@ def _modal_list_size(arr: Union[pa.ListArray, pa.ChunkedArray]) -> int: return pc.mode(pc.list_value_length(arr))[0].as_py()["mode"] +def _infer_vector_dim(arr: Union[pa.Array, pa.ChunkedArray]) -> Optional[int]: + if not _is_list_like(arr.type): + return None + lengths = pc.list_value_length(arr) + lengths = pc.filter(lengths, pc.greater(lengths, 0)) + if len(lengths) == 0: + return None + return pc.mode(lengths)[0].as_py()["mode"] + + def _validate_schema(schema: pa.Schema): """ Make sure the metadata is valid utf8 diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index f1e71e4cb..639afe903 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -1049,6 +1049,231 @@ def test_add_with_nans(mem_db: DBConnection): assert np.allclose(v, np.array([0.0, 0.0])) +def test_add_with_empty_fixed_size_list_drops_bad_rows(mem_db: DBConnection): + class Schema(LanceModel): + text: str + embedding: Vector(16) + + table = mem_db.create_table("test_empty_embeddings", schema=Schema) + table.add( + [ + {"text": "hello", "embedding": []}, + {"text": "bar", "embedding": [0.1] * 16}, + ], + on_bad_vectors="drop", + ) + + data = table.to_arrow() + assert data["text"].to_pylist() == ["bar"] + assert np.allclose(data["embedding"].to_pylist()[0], np.array([0.1] * 16)) + + +def test_add_with_integer_embeddings_preserves_casting(mem_db: DBConnection): + class Schema(LanceModel): + text: str + embedding: Vector(4) + + table = mem_db.create_table("test_integer_embeddings", schema=Schema) + table.add( + [{"text": "foo", "embedding": [1, 2, 3, 4]}], + on_bad_vectors="drop", + ) + + assert table.to_arrow()["embedding"].to_pylist() == [[1.0, 2.0, 3.0, 4.0]] + + +def test_on_bad_vectors_does_not_handle_non_vector_fixed_size_lists( + mem_db: DBConnection, +): + schema = pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), 4)), + pa.field("bbox", pa.list_(pa.float32(), 4)), + ] + ) + table = mem_db.create_table("test_bbox_schema", schema=schema) + + with pytest.raises(RuntimeError, match="FixedSizeListType"): + table.add( + [{"vector": [1.0, 2.0, 3.0, 4.0], "bbox": [0.0, 1.0]}], + on_bad_vectors="drop", + ) + + +def test_on_bad_vectors_does_not_handle_custom_named_fixed_size_lists( + mem_db: DBConnection, +): + schema = pa.schema([pa.field("features", pa.list_(pa.float32(), 16))]) + table = mem_db.create_table("test_custom_named_fixed_size_vector", schema=schema) + + with pytest.raises(RuntimeError, match="FixedSizeListType"): + table.add( + [ + {"features": []}, + {"features": [0.1] * 16}, + ], + on_bad_vectors="drop", + ) + + +def test_on_bad_vectors_with_schema_list_vector_still_sanitizes(mem_db: DBConnection): + schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))]) + table = mem_db.create_table("test_schema_list_vector", schema=schema) + table.add( + [ + {"vector": [1.0, 2.0]}, + {"vector": [3.0]}, + {"vector": [4.0, 5.0]}, + ], + on_bad_vectors="drop", + ) + + assert table.to_arrow()["vector"].to_pylist() == [[1.0, 2.0], [4.0, 5.0]] + + +def test_on_bad_vectors_handles_typed_custom_fixed_vectors_for_list_schema( + mem_db: DBConnection, +): + schema = pa.schema([pa.field("vec", pa.list_(pa.float32()))]) + table = mem_db.create_table("test_typed_custom_fixed_vector", schema=schema) + data = pa.table( + { + "vec": pa.array( + [[float("nan")] * 16, [1.0] * 16], + type=pa.list_(pa.float32(), 16), + ) + } + ) + + table.add(data, on_bad_vectors="drop") + + assert table.to_arrow()["vec"].to_pylist() == [[1.0] * 16] + + +def test_on_bad_vectors_fill_preserves_arrow_nested_vector_type(mem_db: DBConnection): + schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))]) + table = mem_db.create_table("test_fill_arrow_nested_type", schema=schema) + data = pa.table( + { + "vector": pa.array( + [[1.0, 2.0], [float("nan"), 3.0]], + type=pa.list_(pa.float32(), 2), + ) + } + ) + + table.add( + data, + on_bad_vectors="fill", + fill_value=0.0, + ) + + assert table.to_arrow()["vector"].to_pylist() == [[1.0, 2.0], [0.0, 0.0]] + + +@pytest.mark.parametrize( + ("table_name", "batch1", "expected"), + [ + ( + "test_schema_list_vector_empty_prefix", + pa.record_batch({"vector": [[], []]}), + [[], [], [1.0, 2.0], [3.0, 4.0]], + ), + ( + "test_schema_list_vector_all_bad_prefix", + pa.record_batch({"vector": [[float("nan")] * 3, [float("nan")] * 3]}), + [[1.0, 2.0], [3.0, 4.0]], + ), + ], +) +def test_on_bad_vectors_with_schema_list_vector_ignores_invalid_prefix_batches( + mem_db: DBConnection, + table_name: str, + batch1: pa.RecordBatch, + expected: list, +): + schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))]) + table = mem_db.create_table(table_name, schema=schema) + batch2 = pa.record_batch({"vector": [[1.0, 2.0], [3.0, 4.0]]}) + reader = pa.RecordBatchReader.from_batches(batch1.schema, [batch1, batch2]) + + table.add(reader, on_bad_vectors="drop") + + assert table.to_arrow()["vector"].to_pylist() == expected + + +def test_on_bad_vectors_with_multiple_vectors_locks_dim_after_final_drop( + mem_db: DBConnection, +): + registry = EmbeddingFunctionRegistry.get_instance() + func = MockTextEmbeddingFunction.create() + metadata = registry.get_table_metadata( + [ + EmbeddingFunctionConfig( + source_column="text1", vector_column="vec1", function=func + ), + EmbeddingFunctionConfig( + source_column="text2", vector_column="vec2", function=func + ), + ] + ) + schema = pa.schema( + [ + pa.field("vec1", pa.list_(pa.float32())), + pa.field("vec2", pa.list_(pa.float32())), + ], + metadata=metadata, + ) + table = mem_db.create_table("test_multi_vector_dim_lock", schema=schema) + batch1 = pa.record_batch( + { + "vec1": [[1.0, 2.0, 3.0], [10.0, 11.0]], + "vec2": [[float("nan"), 0.0], [5.0, 6.0]], + } + ) + batch2 = pa.record_batch( + { + "vec1": [[20.0, 21.0], [30.0, 31.0]], + "vec2": [[7.0, 8.0], [9.0, 10.0]], + } + ) + reader = pa.RecordBatchReader.from_batches(batch1.schema, [batch1, batch2]) + + table.add(reader, on_bad_vectors="drop") + + data = table.to_arrow() + assert data["vec1"].to_pylist() == [[10.0, 11.0], [20.0, 21.0], [30.0, 31.0]] + assert data["vec2"].to_pylist() == [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]] + + +def test_on_bad_vectors_does_not_handle_non_vector_list_columns(mem_db: DBConnection): + schema = pa.schema([pa.field("embedding_history", pa.list_(pa.float32()))]) + table = mem_db.create_table("test_non_vector_list_schema", schema=schema) + table.add( + [ + {"embedding_history": [1.0, 2.0]}, + {"embedding_history": [3.0]}, + ], + on_bad_vectors="drop", + ) + + assert table.to_arrow()["embedding_history"].to_pylist() == [ + [1.0, 2.0], + [3.0], + ] + + +def test_on_bad_vectors_all_null_schema_vector_batches_do_not_crash( + mem_db: DBConnection, +): + schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2), nullable=True)]) + table = mem_db.create_table("test_all_null_vector_batch", schema=schema) + + table.add([{"vector": None}], on_bad_vectors="drop") + + assert table.to_arrow()["vector"].to_pylist() == [None] + + def test_restore(mem_db: DBConnection): table = mem_db.create_table( "my_table", diff --git a/python/python/tests/test_util.py b/python/python/tests/test_util.py index 74296a221..b5ab159b7 100644 --- a/python/python/tests/test_util.py +++ b/python/python/tests/test_util.py @@ -15,8 +15,10 @@ from lancedb.table import ( _cast_to_target_schema, _handle_bad_vectors, _into_pyarrow_reader, - _sanitize_data, _infer_target_schema, + _merge_metadata, + _sanitize_data, + sanitize_create_table, ) import pyarrow as pa import pandas as pd @@ -304,6 +306,117 @@ def test_handle_bad_vectors_noop(): assert output["vector"] == vector +def test_handle_bad_vectors_updates_reader_schema_for_target_schema(): + data = pa.table({"vector": [[1, 2, 3, 4]]}) + target_schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 4))]) + + output = _handle_bad_vectors( + data.to_reader(), + on_bad_vectors="drop", + target_schema=target_schema, + ) + + assert output.schema == pa.schema([pa.field("vector", pa.list_(pa.float32()))]) + assert output.read_all()["vector"].to_pylist() == [[1.0, 2.0, 3.0, 4.0]] + + +def test_sanitize_data_keeps_target_field_metadata(): + source_field = pa.field( + "vector", + pa.list_(pa.float32(), 2), + metadata={b"source": b"drop-me"}, + ) + target_field = pa.field( + "vector", + pa.list_(pa.float32(), 2), + metadata={b"target": b"keep-me"}, + ) + data = pa.table( + {"vector": pa.array([[1.0, 2.0]], type=pa.list_(pa.float32(), 2))}, + schema=pa.schema([source_field]), + ) + + output = _sanitize_data( + data, + target_schema=pa.schema([target_field]), + on_bad_vectors="drop", + ).read_all() + + assert output.schema.field("vector").metadata == {b"target": b"keep-me"} + + +def test_sanitize_data_uses_separate_embedding_metadata_for_bad_vectors(): + registry = EmbeddingFunctionRegistry.get_instance() + conf = EmbeddingFunctionConfig( + source_column="text", + vector_column="custom_vector", + function=MockTextEmbeddingFunction.create(), + ) + metadata = registry.get_table_metadata([conf]) + schema = pa.schema( + { + "text": pa.string(), + "custom_vector": pa.list_(pa.float32(), 10), + }, + metadata={b"note": b"keep-me"}, + ) + data = pa.table( + { + "text": ["bad", "good"], + "custom_vector": [[1.0] * 9, [2.0] * 10], + } + ) + + output = _sanitize_data( + data, + target_schema=schema, + metadata=metadata, + on_bad_vectors="drop", + ).read_all() + + assert output["text"].to_pylist() == ["good"] + assert output.schema.metadata[b"note"] == b"keep-me" + assert b"embedding_functions" in output.schema.metadata + + +def test_sanitize_create_table_merges_and_overrides_embedding_metadata(): + registry = EmbeddingFunctionRegistry.get_instance() + old_conf = EmbeddingFunctionConfig( + source_column="text", + vector_column="old_vector", + function=MockTextEmbeddingFunction.create(), + ) + new_conf = EmbeddingFunctionConfig( + source_column="text", + vector_column="custom_vector", + function=MockTextEmbeddingFunction.create(), + ) + metadata = registry.get_table_metadata([new_conf]) + schema = pa.schema( + { + "text": pa.string(), + "custom_vector": pa.list_(pa.float32(), 10), + }, + metadata=_merge_metadata( + {b"note": b"keep-me"}, + registry.get_table_metadata([old_conf]), + ), + ) + + data, schema = sanitize_create_table( + pa.table({"text": ["good"]}), + schema, + metadata=metadata, + on_bad_vectors="drop", + ) + + assert schema.metadata[b"note"] == b"keep-me" + assert b"embedding_functions" in schema.metadata + assert data.schema.metadata[b"note"] == b"keep-me" + funcs = EmbeddingFunctionRegistry.get_instance().parse_functions(schema.metadata) + assert set(funcs.keys()) == {"custom_vector"} + + class TestModel(lancedb.pydantic.LanceModel): a: Optional[int] b: Optional[int]