mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 06:39:57 +00:00
62 lines
2.0 KiB
Rust
62 lines
2.0 KiB
Rust
// SPDX-License-Identifier: Apache-2.0
|
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
|
|
use std::sync::Mutex;
|
|
|
|
use lancedb::DistanceType;
|
|
use pyo3::{
|
|
exceptions::{PyRuntimeError, PyValueError},
|
|
pyfunction, PyResult,
|
|
};
|
|
|
|
/// A wrapper around a rust builder
|
|
///
|
|
/// Rust builders are often implemented so that the builder methods
|
|
/// consume the builder and return a new one. This is not compatible
|
|
/// with the pyo3, which, being garbage collected, cannot easily obtain
|
|
/// ownership of an object.
|
|
///
|
|
/// This wrapper converts the compile-time safety of rust into runtime
|
|
/// errors if any attempt to use the builder happens after it is consumed.
|
|
pub struct BuilderWrapper<T> {
|
|
name: String,
|
|
inner: Mutex<Option<T>>,
|
|
}
|
|
|
|
impl<T> BuilderWrapper<T> {
|
|
pub fn new(name: impl AsRef<str>, inner: T) -> Self {
|
|
Self {
|
|
name: name.as_ref().to_string(),
|
|
inner: Mutex::new(Some(inner)),
|
|
}
|
|
}
|
|
|
|
pub fn consume<O>(&self, mod_fn: impl FnOnce(T) -> O) -> PyResult<O> {
|
|
let mut inner = self.inner.lock().unwrap();
|
|
let inner_builder = inner.take().ok_or_else(|| {
|
|
PyRuntimeError::new_err(format!("{} has already been consumed", self.name))
|
|
})?;
|
|
let result = mod_fn(inner_builder);
|
|
Ok(result)
|
|
}
|
|
}
|
|
|
|
pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceType> {
|
|
match distance_type.as_ref().to_lowercase().as_str() {
|
|
"l2" => Ok(DistanceType::L2),
|
|
"cosine" => Ok(DistanceType::Cosine),
|
|
"dot" => Ok(DistanceType::Dot),
|
|
"hamming" => Ok(DistanceType::Hamming),
|
|
_ => Err(PyValueError::new_err(format!(
|
|
"Invalid distance type '{}'. Must be one of l2, cosine, dot, or hamming",
|
|
distance_type.as_ref()
|
|
))),
|
|
}
|
|
}
|
|
|
|
#[pyfunction]
|
|
pub fn validate_table_name(table_name: &str) -> PyResult<()> {
|
|
lancedb::utils::validate_table_name(table_name)
|
|
.map_err(|e| PyValueError::new_err(e.to_string()))
|
|
}
|