// SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors use std::sync::Mutex; use lancedb::DistanceType; use pyo3::{ exceptions::{PyRuntimeError, PyValueError}, pyfunction, PyResult, }; /// A wrapper around a rust builder /// /// Rust builders are often implemented so that the builder methods /// consume the builder and return a new one. This is not compatible /// with the pyo3, which, being garbage collected, cannot easily obtain /// ownership of an object. /// /// This wrapper converts the compile-time safety of rust into runtime /// errors if any attempt to use the builder happens after it is consumed. pub struct BuilderWrapper { name: String, inner: Mutex>, } impl BuilderWrapper { pub fn new(name: impl AsRef, inner: T) -> Self { Self { name: name.as_ref().to_string(), inner: Mutex::new(Some(inner)), } } pub fn consume(&self, mod_fn: impl FnOnce(T) -> O) -> PyResult { let mut inner = self.inner.lock().unwrap(); let inner_builder = inner.take().ok_or_else(|| { PyRuntimeError::new_err(format!("{} has already been consumed", self.name)) })?; let result = mod_fn(inner_builder); Ok(result) } } pub fn parse_distance_type(distance_type: impl AsRef) -> PyResult { match distance_type.as_ref().to_lowercase().as_str() { "l2" => Ok(DistanceType::L2), "cosine" => Ok(DistanceType::Cosine), "dot" => Ok(DistanceType::Dot), "hamming" => Ok(DistanceType::Hamming), _ => Err(PyValueError::new_err(format!( "Invalid distance type '{}'. Must be one of l2, cosine, dot, or hamming", distance_type.as_ref() ))), } } #[pyfunction] pub fn validate_table_name(table_name: &str) -> PyResult<()> { lancedb::utils::validate_table_name(table_name) .map_err(|e| PyValueError::new_err(e.to_string())) } /// A wrapper around a LanceDB type to allow it to be used in Python #[derive(Debug, Clone)] pub struct PyLanceDB(pub T);