diff --git a/rust/vectordb/src/database.rs b/rust/vectordb/src/database.rs index 4b3710b7..92c30d72 100644 --- a/rust/vectordb/src/database.rs +++ b/rust/vectordb/src/database.rs @@ -20,7 +20,7 @@ use lance::dataset::WriteParams; use lance::io::object_store::ObjectStore; use snafu::prelude::*; -use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result}; +use crate::error::{CreateDirSnafu, InvalidTableNameSnafu, Result}; use crate::table::{ReadParams, Table}; pub const LANCE_FILE_EXTENSION: &str = "lance"; @@ -36,17 +36,6 @@ pub struct Database { const LANCE_EXTENSION: &str = "lance"; const ENGINE: &str = "engine"; -/// Parse a url, if it's not a valid url, assume it's a local file -/// and try to parse with file:// appended -fn parse_url(url: &str) -> Result { - match url::Url::parse(url) { - Ok(url) => Ok(url), - Err(_) => url::Url::parse(format!("file://{}", url).as_str()).map_err(|e| Error::Lance { - message: format!("Failed to parse uri: {}", e), - }), - } -} - /// A connection to LanceDB impl Database { /// Connects to LanceDB @@ -59,71 +48,73 @@ impl Database { /// /// * A [Database] object. pub async fn connect(uri: &str) -> Result { - // For a native (using lance directly) connection - // The DB doesn't use any uri parameters, but lance does - // So we need to parse the uri, extract the query string, and progate it to lance - let mut url = parse_url(uri)?; + let parse_res = url::Url::parse(uri); - // special handling for windows - if url.scheme().len() == 1 && cfg!(windows) { - let (object_store, base_path) = ObjectStore::from_uri(uri).await?; - if object_store.is_local() { - Self::try_create_dir(uri).context(CreateDirSnafu { path: uri })?; + match parse_res { + Ok(url) if url.scheme().len() == 1 && cfg!(windows) => Self::open_path(uri).await, + Ok(mut url) => { + // iter thru the query params and extract the commit store param + let mut engine = None; + let mut filtered_querys = vec![]; + + // WARNING: specifying engine is NOT a publicly supported feature in lancedb yet + // THE API WILL CHANGE + for (key, value) in url.query_pairs() { + if key == ENGINE { + engine = Some(value.to_string()); + } else { + // to owned so we can modify the url + filtered_querys.push((key.to_string(), value.to_string())); + } + } + + // Filter out the commit store query param -- it's a lancedb param + url.query_pairs_mut().clear(); + url.query_pairs_mut().extend_pairs(filtered_querys); + // Take a copy of the query string so we can propagate it to lance + let query_string = url.query().map(|s| s.to_string()); + // clear the query string so we can use the url as the base uri + // use .set_query(None) instead of .set_query("") because the latter + // will add a trailing '?' to the url + url.set_query(None); + + let table_base_uri = if let Some(store) = engine { + static WARN_ONCE: std::sync::Once = std::sync::Once::new(); + WARN_ONCE.call_once(|| { + log::warn!("Specifing engine is not a publicly supported feature in lancedb yet. THE API WILL CHANGE"); + }); + let old_scheme = url.scheme().to_string(); + let new_scheme = format!("{}+{}", old_scheme, store); + url.to_string().replacen(&old_scheme, &new_scheme, 1) + } else { + url.to_string() + }; + + let plain_uri = url.to_string(); + let (object_store, base_path) = ObjectStore::from_uri(&plain_uri).await?; + if object_store.is_local() { + Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?; + } + + Ok(Database { + uri: table_base_uri, + query_string, + base_path, + object_store, + }) } - return Ok(Database { - uri: uri.to_string(), - query_string: None, - base_path, - object_store, - }); + Err(_) => Self::open_path(uri).await, } + } - // iter thru the query params and extract the commit store param - let mut engine = None; - let mut filtered_querys = vec![]; - - // WARNING: specifying engine is NOT a publicly supported feature in lancedb yet - // THE API WILL CHANGE - for (key, value) in url.query_pairs() { - if key == ENGINE { - engine = Some(value.to_string()); - } else { - // to owned so we can modify the url - filtered_querys.push((key.to_string(), value.to_string())); - } - } - - // Filter out the commit store query param -- it's a lancedb param - url.query_pairs_mut().clear(); - url.query_pairs_mut().extend_pairs(filtered_querys); - // Take a copy of the query string so we can propagate it to lance - let query_string = url.query().map(|s| s.to_string()); - // clear the query string so we can use the url as the base uri - // use .set_query(None) instead of .set_query("") because the latter - // will add a trailing '?' to the url - url.set_query(None); - - let table_base_uri = if let Some(store) = engine { - static WARN_ONCE: std::sync::Once = std::sync::Once::new(); - WARN_ONCE.call_once(|| { - log::warn!("Specifing engine is not a publicly supported feature in lancedb yet. THE API WILL CHANGE"); - }); - let old_scheme = url.scheme().to_string(); - let new_scheme = format!("{}+{}", old_scheme, store); - url.to_string().replacen(&old_scheme, &new_scheme, 1) - } else { - url.to_string() - }; - - let plain_uri = url.to_string(); - let (object_store, base_path) = ObjectStore::from_uri(&plain_uri).await?; + async fn open_path(path: &str) -> Result { + let (object_store, base_path) = ObjectStore::from_uri(path).await?; if object_store.is_local() { - Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?; + Self::try_create_dir(path).context(CreateDirSnafu { path: path })?; } - - Ok(Database { - uri: table_base_uri, - query_string, + Ok(Self { + uri: path.to_string(), + query_string: None, base_path, object_store, }) @@ -250,15 +241,29 @@ mod tests { let uri = tmp_dir.path().to_str().unwrap(); let db = Database::connect(uri).await.unwrap(); - // file:// scheme should be automatically appended if not specified - // windows path come with drive letter, so file:// won't be appended - let expected = if cfg!(windows) { - uri.to_string() - } else { - format!("file://{}", uri) - }; + assert_eq!(db.uri, uri); + } - assert_eq!(db.uri, expected); + #[cfg(not(windows))] + #[tokio::test] + async fn test_connect_relative() { + let tmp_dir = tempdir().unwrap(); + let uri = std::fs::canonicalize(tmp_dir.path().to_str().unwrap()).unwrap(); + + let mut relative_anacestors = vec![]; + let current_dir = std::env::current_dir().unwrap(); + let mut ancestors = current_dir.ancestors(); + while let Some(_) = ancestors.next() { + relative_anacestors.push(".."); + } + let relative_root = std::path::PathBuf::from(relative_anacestors.join("/")); + let relative_uri = relative_root.join(&uri); + + let db = Database::connect(relative_uri.to_str().unwrap()) + .await + .unwrap(); + + assert_eq!(db.uri, relative_uri.to_str().unwrap().to_string()); } #[tokio::test]