// SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors //! Namespace utilities for Python bindings use std::collections::HashMap; use std::sync::Arc; use async_trait::async_trait; use bytes::Bytes; use lance_namespace::LanceNamespace as LanceNamespaceTrait; use lance_namespace::models::*; use pyo3::prelude::*; use pyo3::types::PyDict; /// Wrapper that allows any Python object implementing LanceNamespace protocol /// to be used as a Rust LanceNamespace. /// /// This is similar to PyLanceNamespace in lance's Python bindings - it wraps a Python /// object and calls back into Python when namespace methods are invoked. pub struct PyLanceNamespace { py_namespace: Arc>, namespace_id: String, } impl PyLanceNamespace { /// Create a new PyLanceNamespace wrapper around a Python namespace object. pub fn new(_py: Python<'_>, py_namespace: &Bound<'_, PyAny>) -> PyResult { let namespace_id = py_namespace .call_method0("namespace_id")? .extract::()?; Ok(Self { py_namespace: Arc::new(py_namespace.clone().unbind()), namespace_id, }) } /// Create an Arc from a Python namespace object. pub fn create_arc( py: Python<'_>, py_namespace: &Bound<'_, PyAny>, ) -> PyResult> { let wrapper = Self::new(py, py_namespace)?; Ok(Arc::new(wrapper)) } } impl std::fmt::Debug for PyLanceNamespace { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "PyLanceNamespace {{ id: {} }}", self.namespace_id) } } /// Get or create the DictWithModelDump class in Python. /// This class acts like a dict but also has model_dump() method. /// This allows it to work with both: /// - depythonize (which expects a dict/Mapping) /// - Python code that calls .model_dump() (like DirectoryNamespace wrapper) fn get_dict_with_model_dump_class(py: Python<'_>) -> PyResult> { // Use a module-level cache via __builtins__ let builtins = py.import("builtins")?; if builtins.hasattr("_DictWithModelDump")? { return builtins.getattr("_DictWithModelDump"); } // Create the class using exec let locals = PyDict::new(py); py.run( c"class DictWithModelDump(dict): def model_dump(self): return dict(self)", None, Some(&locals), )?; let class = locals.get_item("DictWithModelDump")?.ok_or_else(|| { pyo3::exceptions::PyRuntimeError::new_err("Failed to create DictWithModelDump class") })?; // Cache it builtins.setattr("_DictWithModelDump", &class)?; Ok(class) } /// Helper to call a Python namespace method with JSON serialization. /// For methods that take a request and return a response. /// Uses DictWithModelDump to pass a dict that also has model_dump() method, /// making it compatible with both depythonize and Python wrappers. async fn call_py_method( py_namespace: Arc>, method_name: &'static str, request: Req, ) -> lance_core::Result where Req: serde::Serialize + Send + 'static, Resp: serde::de::DeserializeOwned + Send + 'static, { let request_json = serde_json::to_string(&request).map_err(|e| { lance_core::Error::io(format!( "Failed to serialize request for {}: {}", method_name, e )) })?; let response_json = tokio::task::spawn_blocking(move || { Python::attach(|py| { let json_module = py.import("json")?; let request_dict = json_module.call_method1("loads", (&request_json,))?; // Wrap dict in DictWithModelDump so it works with both depythonize and .model_dump() let dict_class = get_dict_with_model_dump_class(py)?; let request_arg = dict_class.call1((request_dict,))?; // Call the Python method let result = py_namespace.call_method1(py, method_name, (request_arg,))?; // Convert response to dict, then to JSON // Pydantic models have model_dump() method let result_dict = if result.bind(py).hasattr("model_dump")? { result.call_method0(py, "model_dump")? } else { result }; let response_json: String = json_module .call_method1("dumps", (result_dict,))? .extract()?; Ok::<_, PyErr>(response_json) }) }) .await .map_err(|e| lance_core::Error::io(format!("Task join error for {}: {}", method_name, e)))? .map_err(|e: PyErr| lance_core::Error::io(format!("Python error in {}: {}", method_name, e)))?; serde_json::from_str(&response_json).map_err(|e| { lance_core::Error::io(format!( "Failed to deserialize response from {}: {}", method_name, e )) }) } /// Helper for methods that return () on success async fn call_py_method_unit( py_namespace: Arc>, method_name: &'static str, request: Req, ) -> lance_core::Result<()> where Req: serde::Serialize + Send + 'static, { let request_json = serde_json::to_string(&request).map_err(|e| { lance_core::Error::io(format!( "Failed to serialize request for {}: {}", method_name, e )) })?; tokio::task::spawn_blocking(move || { Python::attach(|py| { let json_module = py.import("json")?; let request_dict = json_module.call_method1("loads", (&request_json,))?; // Wrap dict in DictWithModelDump let dict_class = get_dict_with_model_dump_class(py)?; let request_arg = dict_class.call1((request_dict,))?; // Call the Python method py_namespace.call_method1(py, method_name, (request_arg,))?; Ok::<_, PyErr>(()) }) }) .await .map_err(|e| lance_core::Error::io(format!("Task join error for {}: {}", method_name, e)))? .map_err(|e: PyErr| lance_core::Error::io(format!("Python error in {}: {}", method_name, e))) } /// Helper for methods that return a primitive type async fn call_py_method_primitive( py_namespace: Arc>, method_name: &'static str, request: Req, ) -> lance_core::Result where Req: serde::Serialize + Send + 'static, Resp: for<'py> pyo3::FromPyObject<'py> + Send + 'static, { let request_json = serde_json::to_string(&request).map_err(|e| { lance_core::Error::io(format!( "Failed to serialize request for {}: {}", method_name, e )) })?; tokio::task::spawn_blocking(move || { Python::attach(|py| { let json_module = py.import("json")?; let request_dict = json_module.call_method1("loads", (&request_json,))?; // Wrap dict in DictWithModelDump let dict_class = get_dict_with_model_dump_class(py)?; let request_arg = dict_class.call1((request_dict,))?; // Call the Python method let result = py_namespace.call_method1(py, method_name, (request_arg,))?; let value: Resp = result.extract(py)?; Ok::<_, PyErr>(value) }) }) .await .map_err(|e| lance_core::Error::io(format!("Task join error for {}: {}", method_name, e)))? .map_err(|e: PyErr| lance_core::Error::io(format!("Python error in {}: {}", method_name, e))) } /// Helper for methods that return Bytes async fn call_py_method_bytes( py_namespace: Arc>, method_name: &'static str, request: Req, ) -> lance_core::Result where Req: serde::Serialize + Send + 'static, { let request_json = serde_json::to_string(&request).map_err(|e| { lance_core::Error::io(format!( "Failed to serialize request for {}: {}", method_name, e )) })?; tokio::task::spawn_blocking(move || { Python::attach(|py| { let json_module = py.import("json")?; let request_dict = json_module.call_method1("loads", (&request_json,))?; // Wrap dict in DictWithModelDump let dict_class = get_dict_with_model_dump_class(py)?; let request_arg = dict_class.call1((request_dict,))?; // Call the Python method let result = py_namespace.call_method1(py, method_name, (request_arg,))?; let bytes_data: Vec = result.extract(py)?; Ok::<_, PyErr>(Bytes::from(bytes_data)) }) }) .await .map_err(|e| lance_core::Error::io(format!("Task join error for {}: {}", method_name, e)))? .map_err(|e: PyErr| lance_core::Error::io(format!("Python error in {}: {}", method_name, e))) } /// Helper for methods that take request + data and return a response async fn call_py_method_with_data( py_namespace: Arc>, method_name: &'static str, request: Req, data: Bytes, ) -> lance_core::Result where Req: serde::Serialize + Send + 'static, Resp: serde::de::DeserializeOwned + Send + 'static, { let request_json = serde_json::to_string(&request).map_err(|e| { lance_core::Error::io(format!( "Failed to serialize request for {}: {}", method_name, e )) })?; let response_json = tokio::task::spawn_blocking(move || { Python::attach(|py| { let json_module = py.import("json")?; let request_dict = json_module.call_method1("loads", (&request_json,))?; // Wrap dict in DictWithModelDump let dict_class = get_dict_with_model_dump_class(py)?; let request_arg = dict_class.call1((request_dict,))?; // Pass request and bytes to Python method let py_bytes = pyo3::types::PyBytes::new(py, &data); let result = py_namespace.call_method1(py, method_name, (request_arg, py_bytes))?; // Convert response dict to JSON let response_json: String = json_module.call_method1("dumps", (result,))?.extract()?; Ok::<_, PyErr>(response_json) }) }) .await .map_err(|e| lance_core::Error::io(format!("Task join error for {}: {}", method_name, e)))? .map_err(|e: PyErr| lance_core::Error::io(format!("Python error in {}: {}", method_name, e)))?; serde_json::from_str(&response_json).map_err(|e| { lance_core::Error::io(format!( "Failed to deserialize response from {}: {}", method_name, e )) }) } #[async_trait] impl LanceNamespaceTrait for PyLanceNamespace { fn namespace_id(&self) -> String { self.namespace_id.clone() } async fn list_namespaces( &self, request: ListNamespacesRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "list_namespaces", request).await } async fn describe_namespace( &self, request: DescribeNamespaceRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "describe_namespace", request).await } async fn create_namespace( &self, request: CreateNamespaceRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "create_namespace", request).await } async fn drop_namespace( &self, request: DropNamespaceRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "drop_namespace", request).await } async fn namespace_exists(&self, request: NamespaceExistsRequest) -> lance_core::Result<()> { call_py_method_unit(self.py_namespace.clone(), "namespace_exists", request).await } async fn list_tables( &self, request: ListTablesRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "list_tables", request).await } async fn describe_table( &self, request: DescribeTableRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "describe_table", request).await } async fn register_table( &self, request: RegisterTableRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "register_table", request).await } async fn table_exists(&self, request: TableExistsRequest) -> lance_core::Result<()> { call_py_method_unit(self.py_namespace.clone(), "table_exists", request).await } async fn drop_table(&self, request: DropTableRequest) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "drop_table", request).await } async fn deregister_table( &self, request: DeregisterTableRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "deregister_table", request).await } async fn count_table_rows(&self, request: CountTableRowsRequest) -> lance_core::Result { call_py_method_primitive(self.py_namespace.clone(), "count_table_rows", request).await } async fn create_table( &self, request: CreateTableRequest, request_data: Bytes, ) -> lance_core::Result { call_py_method_with_data( self.py_namespace.clone(), "create_table", request, request_data, ) .await } async fn declare_table( &self, request: DeclareTableRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "declare_table", request).await } async fn insert_into_table( &self, request: InsertIntoTableRequest, request_data: Bytes, ) -> lance_core::Result { call_py_method_with_data( self.py_namespace.clone(), "insert_into_table", request, request_data, ) .await } async fn merge_insert_into_table( &self, request: MergeInsertIntoTableRequest, request_data: Bytes, ) -> lance_core::Result { call_py_method_with_data( self.py_namespace.clone(), "merge_insert_into_table", request, request_data, ) .await } async fn update_table( &self, request: UpdateTableRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "update_table", request).await } async fn delete_from_table( &self, request: DeleteFromTableRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "delete_from_table", request).await } async fn query_table(&self, request: QueryTableRequest) -> lance_core::Result { call_py_method_bytes(self.py_namespace.clone(), "query_table", request).await } async fn create_table_index( &self, request: CreateTableIndexRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "create_table_index", request).await } async fn list_table_indices( &self, request: ListTableIndicesRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "list_table_indices", request).await } async fn describe_table_index_stats( &self, request: DescribeTableIndexStatsRequest, ) -> lance_core::Result { call_py_method( self.py_namespace.clone(), "describe_table_index_stats", request, ) .await } async fn describe_transaction( &self, request: DescribeTransactionRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "describe_transaction", request).await } async fn alter_transaction( &self, request: AlterTransactionRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "alter_transaction", request).await } async fn create_table_scalar_index( &self, request: CreateTableIndexRequest, ) -> lance_core::Result { call_py_method( self.py_namespace.clone(), "create_table_scalar_index", request, ) .await } async fn drop_table_index( &self, request: DropTableIndexRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "drop_table_index", request).await } async fn list_all_tables( &self, request: ListTablesRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "list_all_tables", request).await } async fn restore_table( &self, request: RestoreTableRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "restore_table", request).await } async fn rename_table( &self, request: RenameTableRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "rename_table", request).await } async fn list_table_versions( &self, request: ListTableVersionsRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "list_table_versions", request).await } async fn create_table_version( &self, request: CreateTableVersionRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "create_table_version", request).await } async fn describe_table_version( &self, request: DescribeTableVersionRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "describe_table_version", request).await } async fn batch_delete_table_versions( &self, request: BatchDeleteTableVersionsRequest, ) -> lance_core::Result { call_py_method( self.py_namespace.clone(), "batch_delete_table_versions", request, ) .await } async fn update_table_schema_metadata( &self, request: UpdateTableSchemaMetadataRequest, ) -> lance_core::Result { call_py_method( self.py_namespace.clone(), "update_table_schema_metadata", request, ) .await } async fn get_table_stats( &self, request: GetTableStatsRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "get_table_stats", request).await } async fn explain_table_query_plan( &self, request: ExplainTableQueryPlanRequest, ) -> lance_core::Result { call_py_method_primitive( self.py_namespace.clone(), "explain_table_query_plan", request, ) .await } async fn analyze_table_query_plan( &self, request: AnalyzeTableQueryPlanRequest, ) -> lance_core::Result { call_py_method_primitive( self.py_namespace.clone(), "analyze_table_query_plan", request, ) .await } async fn alter_table_add_columns( &self, request: AlterTableAddColumnsRequest, ) -> lance_core::Result { call_py_method( self.py_namespace.clone(), "alter_table_add_columns", request, ) .await } async fn alter_table_alter_columns( &self, request: AlterTableAlterColumnsRequest, ) -> lance_core::Result { call_py_method( self.py_namespace.clone(), "alter_table_alter_columns", request, ) .await } async fn alter_table_drop_columns( &self, request: AlterTableDropColumnsRequest, ) -> lance_core::Result { call_py_method( self.py_namespace.clone(), "alter_table_drop_columns", request, ) .await } async fn list_table_tags( &self, request: ListTableTagsRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "list_table_tags", request).await } async fn create_table_tag( &self, request: CreateTableTagRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "create_table_tag", request).await } async fn delete_table_tag( &self, request: DeleteTableTagRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "delete_table_tag", request).await } async fn update_table_tag( &self, request: UpdateTableTagRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "update_table_tag", request).await } async fn get_table_tag_version( &self, request: GetTableTagVersionRequest, ) -> lance_core::Result { call_py_method(self.py_namespace.clone(), "get_table_tag_version", request).await } } /// Convert Python dict to HashMap #[allow(dead_code)] fn dict_to_hashmap(dict: &Bound<'_, PyDict>) -> PyResult> { let mut map = HashMap::new(); for (key, value) in dict.iter() { let key_str: String = key.extract()?; let value_str: String = value.extract()?; map.insert(key_str, value_str); } Ok(map) } /// Extract an Arc from a Python namespace object. /// /// This function wraps any Python namespace object with PyLanceNamespace. /// The PyLanceNamespace wrapper uses DictWithModelDump to pass requests, /// which works with both: /// - Native namespaces (DirectoryNamespace, RestNamespace) that use depythonize (expects dict) /// - Custom Python implementations that call .model_dump() on the request pub fn extract_namespace_arc( py: Python<'_>, ns: Py, ) -> PyResult> { let ns_ref = ns.bind(py); PyLanceNamespace::create_arc(py, ns_ref) }