From fa53cfcfd2efa7a5d34aec169758ea5926901222 Mon Sep 17 00:00:00 2001 From: Bert Date: Tue, 4 Mar 2025 16:58:46 -0500 Subject: [PATCH] feat: support modifying field metadata in lancedb python (#2178) --- python/python/lancedb/table.py | 28 +++++++++++++++++++++++++ python/python/tests/test_table.py | 9 ++++++++ python/src/table.rs | 34 ++++++++++++++++++++++++++++++- 3 files changed, 70 insertions(+), 1 deletion(-) diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index cf077ba5..1aa3725d 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -2405,6 +2405,19 @@ class LanceTable(Table): """ LOOP.run(self._table.migrate_v2_manifest_paths()) + def replace_field_metadata(self, field_name: str, new_metadata: Dict[str, str]): + """ + Replace the metadata of a field in the schema + + Parameters + ---------- + field_name: str + The name of the field to replace the metadata for + new_metadata: dict + The new metadata to set + """ + LOOP.run(self._table.replace_field_metadata(field_name, new_metadata)) + def _handle_bad_vectors( reader: pa.RecordBatchReader, @@ -3635,6 +3648,21 @@ class AsyncTable: """ await self._inner.migrate_manifest_paths_v2() + async def replace_field_metadata( + self, field_name: str, new_metadata: dict[str, str] + ): + """ + Replace the metadata of a field in the schema + + Parameters + ---------- + field_name: str + The name of the field to replace the metadata for + new_metadata: dict + The new metadata to set + """ + await self._inner.replace_field_metadata(field_name, new_metadata) + @dataclass class IndexStatistics: diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 82e11cf5..0d0d0bed 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -1481,3 +1481,12 @@ async def test_optimize_delete_unverified(tmp_db_async: AsyncConnection, tmp_pat cleanup_older_than=timedelta(seconds=0), delete_unverified=True ) assert stats.prune.old_versions_removed == 2 + + +def test_replace_field_metadata(tmp_path): + db = lancedb.connect(tmp_path) + table = db.create_table("my_table", data=[{"x": 0}]) + table.replace_field_metadata("x", {"foo": "bar"}) + schema = table.schema + field = schema[0].metadata + assert field == {b"foo": b"bar"} diff --git a/python/src/table.rs b/python/src/table.rs index 211487fa..87f4b2ef 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -10,12 +10,13 @@ use lancedb::table::{ Table as LanceDbTable, }; use pyo3::{ - exceptions::{PyRuntimeError, PyValueError}, + exceptions::{PyKeyError, PyRuntimeError, PyValueError}, pyclass, pymethods, types::{IntoPyDict, PyAnyMethods, PyDict, PyDictMethods}, Bound, FromPyObject, PyAny, PyRef, PyResult, Python, ToPyObject, }; use pyo3_async_runtimes::tokio::future_into_py; +use std::collections::HashMap; use crate::{ error::PythonErrorExt, @@ -486,6 +487,37 @@ impl Table { Ok(()) }) } + + pub fn replace_field_metadata<'a>( + self_: PyRef<'a, Self>, + field_name: String, + metadata: &Bound<'_, PyDict>, + ) -> PyResult> { + let mut new_metadata = HashMap::::new(); + for (column_name, value) in metadata.into_iter() { + let key: String = column_name.extract()?; + let value: String = value.extract()?; + new_metadata.insert(key, value); + } + + let inner = self_.inner_ref()?.clone(); + future_into_py(self_.py(), async move { + let native_tbl = inner + .as_native() + .ok_or_else(|| PyValueError::new_err("This cannot be run on a remote table"))?; + let schema = native_tbl.manifest().await.infer_error()?.schema; + let field = schema + .field(&field_name) + .ok_or_else(|| PyKeyError::new_err(format!("Field {} not found", field_name)))?; + + native_tbl + .replace_field_metadata(vec![(field.id as u32, new_metadata)]) + .await + .infer_error()?; + + Ok(()) + }) + } } #[derive(FromPyObject)]