// SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors use std::{collections::HashMap, sync::Arc, time::Duration}; use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow}; use lancedb::{ connection::Connection as LanceConnection, database::{CreateTableMode, Database, ReadConsistency}, }; use pyo3::{ exceptions::{PyRuntimeError, PyValueError}, pyclass, pyfunction, pymethods, types::{PyDict, PyDictMethods}, Bound, FromPyObject, Py, PyAny, PyObject, PyRef, PyResult, Python, }; use pyo3_async_runtimes::tokio::future_into_py; use crate::{ error::PythonErrorExt, storage_options::py_object_to_storage_options_provider, table::Table, }; #[pyclass] pub struct Connection { inner: Option, } impl Connection { pub(crate) fn new(inner: LanceConnection) -> Self { Self { inner: Some(inner) } } pub(crate) fn get_inner(&self) -> PyResult<&LanceConnection> { self.inner .as_ref() .ok_or_else(|| PyRuntimeError::new_err("Connection is closed")) } } impl Connection { fn parse_create_mode_str(mode: &str) -> PyResult { match mode { "create" => Ok(CreateTableMode::Create), "overwrite" => Ok(CreateTableMode::Overwrite), "exist_ok" => Ok(CreateTableMode::exist_ok(|builder| builder)), _ => Err(PyValueError::new_err(format!("Invalid mode {}", mode))), } } pub fn database(&self) -> PyResult> { Ok(self.get_inner()?.database().clone()) } } #[pymethods] impl Connection { fn __repr__(&self) -> String { match &self.inner { Some(inner) => inner.to_string(), None => "ClosedConnection".to_string(), } } fn is_open(&self) -> bool { self.inner.is_some() } fn close(&mut self) { self.inner.take(); } #[getter] pub fn uri(&self) -> PyResult { self.get_inner().map(|inner| inner.uri().to_string()) } #[pyo3(signature = ())] pub fn get_read_consistency_interval(self_: PyRef<'_, Self>) -> PyResult> { let inner = self_.get_inner()?.clone(); future_into_py(self_.py(), async move { Ok(match inner.read_consistency().await.infer_error()? { ReadConsistency::Manual => None, ReadConsistency::Eventual(duration) => Some(duration.as_secs_f64()), ReadConsistency::Strong => Some(0.0_f64), }) }) } #[pyo3(signature = (namespace=vec![], start_after=None, limit=None))] pub fn table_names( self_: PyRef<'_, Self>, namespace: Vec, start_after: Option, limit: Option, ) -> PyResult> { let inner = self_.get_inner()?.clone(); let mut op = inner.table_names(); op = op.namespace(namespace); if let Some(start_after) = start_after { op = op.start_after(start_after); } if let Some(limit) = limit { op = op.limit(limit); } future_into_py(self_.py(), async move { op.execute().await.infer_error() }) } #[allow(clippy::too_many_arguments)] #[pyo3(signature = (name, mode, data, namespace=vec![], storage_options=None, storage_options_provider=None, location=None))] pub fn create_table<'a>( self_: PyRef<'a, Self>, name: String, mode: &str, data: Bound<'_, PyAny>, namespace: Vec, storage_options: Option>, storage_options_provider: Option, location: Option, ) -> PyResult> { let inner = self_.get_inner()?.clone(); let mode = Self::parse_create_mode_str(mode)?; let batches = ArrowArrayStreamReader::from_pyarrow_bound(&data)?; let mut builder = inner.create_table(name, batches).mode(mode); builder = builder.namespace(namespace); if let Some(storage_options) = storage_options { builder = builder.storage_options(storage_options); } if let Some(provider_obj) = storage_options_provider { let provider = py_object_to_storage_options_provider(provider_obj)?; builder = builder.storage_options_provider(provider); } if let Some(location) = location { builder = builder.location(location); } future_into_py(self_.py(), async move { let table = builder.execute().await.infer_error()?; Ok(Table::new(table)) }) } #[allow(clippy::too_many_arguments)] #[pyo3(signature = (name, mode, schema, namespace=vec![], storage_options=None, storage_options_provider=None, location=None))] pub fn create_empty_table<'a>( self_: PyRef<'a, Self>, name: String, mode: &str, schema: Bound<'_, PyAny>, namespace: Vec, storage_options: Option>, storage_options_provider: Option, location: Option, ) -> PyResult> { let inner = self_.get_inner()?.clone(); let mode = Self::parse_create_mode_str(mode)?; let schema = Schema::from_pyarrow_bound(&schema)?; let mut builder = inner.create_empty_table(name, Arc::new(schema)).mode(mode); builder = builder.namespace(namespace); if let Some(storage_options) = storage_options { builder = builder.storage_options(storage_options); } if let Some(provider_obj) = storage_options_provider { let provider = py_object_to_storage_options_provider(provider_obj)?; builder = builder.storage_options_provider(provider); } if let Some(location) = location { builder = builder.location(location); } future_into_py(self_.py(), async move { let table = builder.execute().await.infer_error()?; Ok(Table::new(table)) }) } #[pyo3(signature = (name, namespace=vec![], storage_options = None, storage_options_provider=None, index_cache_size = None, location=None))] pub fn open_table( self_: PyRef<'_, Self>, name: String, namespace: Vec, storage_options: Option>, storage_options_provider: Option, index_cache_size: Option, location: Option, ) -> PyResult> { let inner = self_.get_inner()?.clone(); let mut builder = inner.open_table(name); builder = builder.namespace(namespace); if let Some(storage_options) = storage_options { builder = builder.storage_options(storage_options); } if let Some(provider_obj) = storage_options_provider { let provider = py_object_to_storage_options_provider(provider_obj)?; builder = builder.storage_options_provider(provider); } if let Some(index_cache_size) = index_cache_size { builder = builder.index_cache_size(index_cache_size); } if let Some(location) = location { builder = builder.location(location); } future_into_py(self_.py(), async move { let table = builder.execute().await.infer_error()?; Ok(Table::new(table)) }) } #[pyo3(signature = (target_table_name, source_uri, target_namespace=vec![], source_version=None, source_tag=None, is_shallow=true))] pub fn clone_table( self_: PyRef<'_, Self>, target_table_name: String, source_uri: String, target_namespace: Vec, source_version: Option, source_tag: Option, is_shallow: bool, ) -> PyResult> { let inner = self_.get_inner()?.clone(); let mut builder = inner.clone_table(target_table_name, source_uri); builder = builder.target_namespace(target_namespace); if let Some(version) = source_version { builder = builder.source_version(version); } if let Some(tag) = source_tag { builder = builder.source_tag(tag); } builder = builder.is_shallow(is_shallow); future_into_py(self_.py(), async move { let table = builder.execute().await.infer_error()?; Ok(Table::new(table)) }) } #[pyo3(signature = (cur_name, new_name, cur_namespace=vec![], new_namespace=vec![]))] pub fn rename_table( self_: PyRef<'_, Self>, cur_name: String, new_name: String, cur_namespace: Vec, new_namespace: Vec, ) -> PyResult> { let inner = self_.get_inner()?.clone(); future_into_py(self_.py(), async move { inner .rename_table(cur_name, new_name, &cur_namespace, &new_namespace) .await .infer_error() }) } #[pyo3(signature = (name, namespace=vec![]))] pub fn drop_table( self_: PyRef<'_, Self>, name: String, namespace: Vec, ) -> PyResult> { let inner = self_.get_inner()?.clone(); future_into_py(self_.py(), async move { inner.drop_table(name, &namespace).await.infer_error() }) } #[pyo3(signature = (namespace=vec![],))] pub fn drop_all_tables( self_: PyRef<'_, Self>, namespace: Vec, ) -> PyResult> { let inner = self_.get_inner()?.clone(); future_into_py(self_.py(), async move { inner.drop_all_tables(&namespace).await.infer_error() }) } // Namespace management methods #[pyo3(signature = (namespace=vec![], page_token=None, limit=None))] pub fn list_namespaces( self_: PyRef<'_, Self>, namespace: Vec, page_token: Option, limit: Option, ) -> PyResult> { let inner = self_.get_inner()?.clone(); let py = self_.py(); future_into_py(py, async move { use lance_namespace::models::ListNamespacesRequest; let request = ListNamespacesRequest { id: if namespace.is_empty() { None } else { Some(namespace) }, page_token, limit: limit.map(|l| l as i32), }; let response = inner.list_namespaces(request).await.infer_error()?; Python::with_gil(|py| -> PyResult> { let dict = PyDict::new(py); dict.set_item("namespaces", response.namespaces)?; dict.set_item("page_token", response.page_token)?; Ok(dict.unbind()) }) }) } #[pyo3(signature = (namespace, mode=None, properties=None))] pub fn create_namespace( self_: PyRef<'_, Self>, namespace: Vec, mode: Option, properties: Option>, ) -> PyResult> { let inner = self_.get_inner()?.clone(); let py = self_.py(); future_into_py(py, async move { use lance_namespace::models::{create_namespace_request, CreateNamespaceRequest}; let mode_enum = mode.and_then(|m| match m.to_lowercase().as_str() { "create" => Some(create_namespace_request::Mode::Create), "exist_ok" => Some(create_namespace_request::Mode::ExistOk), "overwrite" => Some(create_namespace_request::Mode::Overwrite), _ => None, }); let request = CreateNamespaceRequest { id: if namespace.is_empty() { None } else { Some(namespace) }, mode: mode_enum, properties, }; let response = inner.create_namespace(request).await.infer_error()?; Python::with_gil(|py| -> PyResult> { let dict = PyDict::new(py); dict.set_item("properties", response.properties)?; Ok(dict.unbind()) }) }) } #[pyo3(signature = (namespace, mode=None, behavior=None))] pub fn drop_namespace( self_: PyRef<'_, Self>, namespace: Vec, mode: Option, behavior: Option, ) -> PyResult> { let inner = self_.get_inner()?.clone(); let py = self_.py(); future_into_py(py, async move { use lance_namespace::models::{drop_namespace_request, DropNamespaceRequest}; let mode_enum = mode.and_then(|m| match m.to_uppercase().as_str() { "SKIP" => Some(drop_namespace_request::Mode::Skip), "FAIL" => Some(drop_namespace_request::Mode::Fail), _ => None, }); let behavior_enum = behavior.and_then(|b| match b.to_uppercase().as_str() { "RESTRICT" => Some(drop_namespace_request::Behavior::Restrict), "CASCADE" => Some(drop_namespace_request::Behavior::Cascade), _ => None, }); let request = DropNamespaceRequest { id: if namespace.is_empty() { None } else { Some(namespace) }, mode: mode_enum, behavior: behavior_enum, }; let response = inner.drop_namespace(request).await.infer_error()?; Python::with_gil(|py| -> PyResult> { let dict = PyDict::new(py); dict.set_item("properties", response.properties)?; dict.set_item("transaction_id", response.transaction_id)?; Ok(dict.unbind()) }) }) } #[pyo3(signature = (namespace,))] pub fn describe_namespace( self_: PyRef<'_, Self>, namespace: Vec, ) -> PyResult> { let inner = self_.get_inner()?.clone(); let py = self_.py(); future_into_py(py, async move { use lance_namespace::models::DescribeNamespaceRequest; let request = DescribeNamespaceRequest { id: if namespace.is_empty() { None } else { Some(namespace) }, }; let response = inner.describe_namespace(request).await.infer_error()?; Python::with_gil(|py| -> PyResult> { let dict = PyDict::new(py); dict.set_item("properties", response.properties)?; Ok(dict.unbind()) }) }) } #[pyo3(signature = (namespace=vec![], page_token=None, limit=None))] pub fn list_tables( self_: PyRef<'_, Self>, namespace: Vec, page_token: Option, limit: Option, ) -> PyResult> { let inner = self_.get_inner()?.clone(); let py = self_.py(); future_into_py(py, async move { use lance_namespace::models::ListTablesRequest; let request = ListTablesRequest { id: if namespace.is_empty() { None } else { Some(namespace) }, page_token, limit: limit.map(|l| l as i32), }; let response = inner.list_tables(request).await.infer_error()?; Python::with_gil(|py| -> PyResult> { let dict = PyDict::new(py); dict.set_item("tables", response.tables)?; dict.set_item("page_token", response.page_token)?; Ok(dict.unbind()) }) }) } } #[pyfunction] #[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None))] #[allow(clippy::too_many_arguments)] pub fn connect( py: Python<'_>, uri: String, api_key: Option, region: Option, host_override: Option, read_consistency_interval: Option, client_config: Option, storage_options: Option>, session: Option, ) -> PyResult> { future_into_py(py, async move { let mut builder = lancedb::connect(&uri); if let Some(api_key) = api_key { builder = builder.api_key(&api_key); } if let Some(region) = region { builder = builder.region(®ion); } if let Some(host_override) = host_override { builder = builder.host_override(&host_override); } if let Some(read_consistency_interval) = read_consistency_interval { let read_consistency_interval = Duration::from_secs_f64(read_consistency_interval); builder = builder.read_consistency_interval(read_consistency_interval); } if let Some(storage_options) = storage_options { builder = builder.storage_options(storage_options); } #[cfg(feature = "remote")] if let Some(client_config) = client_config { builder = builder.client_config(client_config.into()); } if let Some(session) = session { builder = builder.session(session.inner.clone()); } Ok(Connection::new(builder.execute().await.infer_error()?)) }) } #[derive(FromPyObject)] pub struct PyClientConfig { user_agent: String, retry_config: Option, timeout_config: Option, extra_headers: Option>, id_delimiter: Option, tls_config: Option, header_provider: Option>, } #[derive(FromPyObject)] pub struct PyClientRetryConfig { retries: Option, connect_retries: Option, read_retries: Option, backoff_factor: Option, backoff_jitter: Option, statuses: Option>, } #[derive(FromPyObject)] pub struct PyClientTimeoutConfig { timeout: Option, connect_timeout: Option, read_timeout: Option, pool_idle_timeout: Option, } #[derive(FromPyObject)] pub struct PyClientTlsConfig { cert_file: Option, key_file: Option, ssl_ca_cert: Option, assert_hostname: bool, } #[cfg(feature = "remote")] impl From for lancedb::remote::RetryConfig { fn from(value: PyClientRetryConfig) -> Self { Self { retries: value.retries, connect_retries: value.connect_retries, read_retries: value.read_retries, backoff_factor: value.backoff_factor, backoff_jitter: value.backoff_jitter, statuses: value.statuses, } } } #[cfg(feature = "remote")] impl From for lancedb::remote::TimeoutConfig { fn from(value: PyClientTimeoutConfig) -> Self { Self { timeout: value.timeout, connect_timeout: value.connect_timeout, read_timeout: value.read_timeout, pool_idle_timeout: value.pool_idle_timeout, } } } #[cfg(feature = "remote")] impl From for lancedb::remote::TlsConfig { fn from(value: PyClientTlsConfig) -> Self { Self { cert_file: value.cert_file, key_file: value.key_file, ssl_ca_cert: value.ssl_ca_cert, assert_hostname: value.assert_hostname, } } } #[cfg(feature = "remote")] impl From for lancedb::remote::ClientConfig { fn from(value: PyClientConfig) -> Self { use crate::header::PyHeaderProvider; let header_provider = value.header_provider.map(|provider| { let py_provider = PyHeaderProvider::new(provider); Arc::new(py_provider) as Arc }); Self { user_agent: value.user_agent, retry_config: value.retry_config.map(Into::into).unwrap_or_default(), timeout_config: value.timeout_config.map(Into::into).unwrap_or_default(), extra_headers: value.extra_headers.unwrap_or_default(), id_delimiter: value.id_delimiter, tls_config: value.tls_config.map(Into::into), header_provider, } } }