mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-29 18:00:40 +00:00
feat: expose storage options in LanceDB (#1204)
Exposes `storage_options` in LanceDB. This is provided for Python async, Node `lancedb`, and Node `vectordb` (and Rust of course). Python synchronous is omitted because it's not compatible with the PyArrow filesystems we use there currently. In the future, we will move the sync API to wrap the async one, and then it will get support for `storage_options`. 1. Fixes #1168 2. Closes #1165 3. Closes #1082 4. Closes #439 5. Closes #897 6. Closes #642 7. Closes #281 8. Closes #114 9. Closes #990 10. Deprecating `awsCredentials` and `awsRegion`. Users are encouraged to use `storageOptions` instead.
This commit is contained in:
@@ -12,19 +12,12 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use lance::io::ObjectStoreParams;
|
||||
use neon::prelude::*;
|
||||
use object_store::aws::{AwsCredential, AwsCredentialProvider};
|
||||
use object_store::CredentialProvider;
|
||||
use once_cell::sync::OnceCell;
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
use lancedb::connect;
|
||||
use lancedb::connection::Connection;
|
||||
use lancedb::table::ReadParams;
|
||||
|
||||
use crate::error::ResultExt;
|
||||
use crate::query::JsQuery;
|
||||
@@ -44,33 +37,6 @@ struct JsDatabase {
|
||||
|
||||
impl Finalize for JsDatabase {}
|
||||
|
||||
// TODO: object_store didn't export this type so I copied it.
|
||||
// Make a request to object_store to export this type
|
||||
#[derive(Debug)]
|
||||
pub struct StaticCredentialProvider<T> {
|
||||
credential: Arc<T>,
|
||||
}
|
||||
|
||||
impl<T> StaticCredentialProvider<T> {
|
||||
pub fn new(credential: T) -> Self {
|
||||
Self {
|
||||
credential: Arc::new(credential),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T> CredentialProvider for StaticCredentialProvider<T>
|
||||
where
|
||||
T: std::fmt::Debug + Send + Sync,
|
||||
{
|
||||
type Credential = T;
|
||||
|
||||
async fn get_credential(&self) -> object_store::Result<Arc<T>> {
|
||||
Ok(Arc::clone(&self.credential))
|
||||
}
|
||||
}
|
||||
|
||||
fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
|
||||
static RUNTIME: OnceCell<Runtime> = OnceCell::new();
|
||||
static LOG: OnceCell<()> = OnceCell::new();
|
||||
@@ -82,29 +48,28 @@ fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
|
||||
|
||||
fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let path = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
let aws_creds = get_aws_creds(&mut cx, 1)?;
|
||||
let region = get_aws_region(&mut cx, 4)?;
|
||||
let read_consistency_interval = cx
|
||||
.argument_opt(5)
|
||||
.and_then(|arg| arg.downcast::<JsNumber, _>(&mut cx).ok())
|
||||
.map(|v| v.value(&mut cx))
|
||||
.map(std::time::Duration::from_secs_f64);
|
||||
|
||||
let storage_options_js = cx.argument::<JsArray>(1)?.to_vec(&mut cx)?;
|
||||
let mut storage_options: Vec<(String, String)> = Vec::with_capacity(storage_options_js.len());
|
||||
for handle in storage_options_js {
|
||||
let obj = handle.downcast::<JsArray, _>(&mut cx).unwrap();
|
||||
let key = obj.get::<JsString, _, _>(&mut cx, 0)?.value(&mut cx);
|
||||
let value = obj.get::<JsString, _, _>(&mut cx, 0)?.value(&mut cx);
|
||||
|
||||
storage_options.push((key, value));
|
||||
}
|
||||
|
||||
let rt = runtime(&mut cx)?;
|
||||
let channel = cx.channel();
|
||||
let (deferred, promise) = cx.promise();
|
||||
|
||||
let mut conn_builder = connect(&path);
|
||||
if let Some(region) = region {
|
||||
conn_builder = conn_builder.region(®ion);
|
||||
}
|
||||
if let Some(aws_creds) = aws_creds {
|
||||
conn_builder = conn_builder.aws_creds(AwsCredential {
|
||||
key_id: aws_creds.key_id,
|
||||
secret_key: aws_creds.secret_key,
|
||||
token: aws_creds.token,
|
||||
});
|
||||
}
|
||||
let mut conn_builder = connect(&path).storage_options(storage_options);
|
||||
|
||||
if let Some(interval) = read_consistency_interval {
|
||||
conn_builder = conn_builder.read_consistency_interval(interval);
|
||||
}
|
||||
@@ -143,93 +108,19 @@ fn database_table_names(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
/// Get AWS creds arguments from the context
|
||||
/// Consumes 3 arguments
|
||||
fn get_aws_creds(
|
||||
cx: &mut FunctionContext,
|
||||
arg_starting_location: i32,
|
||||
) -> NeonResult<Option<AwsCredential>> {
|
||||
let secret_key_id = cx
|
||||
.argument_opt(arg_starting_location)
|
||||
.filter(|arg| arg.is_a::<JsString, _>(cx))
|
||||
.and_then(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
|
||||
.map(|v| v.value(cx));
|
||||
|
||||
let secret_key = cx
|
||||
.argument_opt(arg_starting_location + 1)
|
||||
.filter(|arg| arg.is_a::<JsString, _>(cx))
|
||||
.and_then(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
|
||||
.map(|v| v.value(cx));
|
||||
|
||||
let temp_token = cx
|
||||
.argument_opt(arg_starting_location + 2)
|
||||
.filter(|arg| arg.is_a::<JsString, _>(cx))
|
||||
.and_then(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
|
||||
.map(|v| v.value(cx));
|
||||
|
||||
match (secret_key_id, secret_key, temp_token) {
|
||||
(Some(key_id), Some(key), optional_token) => Ok(Some(AwsCredential {
|
||||
key_id,
|
||||
secret_key: key,
|
||||
token: optional_token,
|
||||
})),
|
||||
(None, None, None) => Ok(None),
|
||||
_ => cx.throw_error("Invalid credentials configuration"),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_aws_credential_provider(
|
||||
cx: &mut FunctionContext,
|
||||
arg_starting_location: i32,
|
||||
) -> NeonResult<Option<AwsCredentialProvider>> {
|
||||
Ok(get_aws_creds(cx, arg_starting_location)?.map(|aws_cred| {
|
||||
Arc::new(StaticCredentialProvider::new(aws_cred))
|
||||
as Arc<dyn CredentialProvider<Credential = AwsCredential>>
|
||||
}))
|
||||
}
|
||||
|
||||
/// Get AWS region arguments from the context
|
||||
fn get_aws_region(cx: &mut FunctionContext, arg_location: i32) -> NeonResult<Option<String>> {
|
||||
let region = cx
|
||||
.argument_opt(arg_location)
|
||||
.filter(|arg| arg.is_a::<JsString, _>(cx))
|
||||
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx));
|
||||
|
||||
match region {
|
||||
Some(Ok(region)) => Ok(Some(region.value(cx))),
|
||||
None => Ok(None),
|
||||
Some(Err(e)) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let db = cx
|
||||
.this()
|
||||
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
|
||||
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
|
||||
let aws_creds = get_aws_credential_provider(&mut cx, 1)?;
|
||||
|
||||
let aws_region = get_aws_region(&mut cx, 4)?;
|
||||
|
||||
let params = ReadParams {
|
||||
store_options: Some(ObjectStoreParams::with_aws_credentials(
|
||||
aws_creds, aws_region,
|
||||
)),
|
||||
..ReadParams::default()
|
||||
};
|
||||
|
||||
let rt = runtime(&mut cx)?;
|
||||
let channel = cx.channel();
|
||||
let database = db.database.clone();
|
||||
|
||||
let (deferred, promise) = cx.promise();
|
||||
rt.spawn(async move {
|
||||
let table_rst = database
|
||||
.open_table(&table_name)
|
||||
.lance_read_params(params)
|
||||
.execute()
|
||||
.await;
|
||||
let table_rst = database.open_table(&table_name).execute().await;
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
let js_table = JsTable::from(table_rst.or_throw(&mut cx)?);
|
||||
|
||||
@@ -17,7 +17,6 @@ use std::ops::Deref;
|
||||
use arrow_array::{RecordBatch, RecordBatchIterator};
|
||||
use lance::dataset::optimize::CompactionOptions;
|
||||
use lance::dataset::{ColumnAlteration, NewColumnTransform, WriteMode, WriteParams};
|
||||
use lance::io::ObjectStoreParams;
|
||||
use lancedb::table::{OptimizeAction, WriteOptions};
|
||||
|
||||
use crate::arrow::{arrow_buffer_to_record_batch, record_batch_to_buffer};
|
||||
@@ -26,7 +25,7 @@ use neon::prelude::*;
|
||||
use neon::types::buffer::TypedArray;
|
||||
|
||||
use crate::error::ResultExt;
|
||||
use crate::{convert, get_aws_credential_provider, get_aws_region, runtime, JsDatabase};
|
||||
use crate::{convert, runtime, JsDatabase};
|
||||
|
||||
pub struct JsTable {
|
||||
pub table: LanceDbTable,
|
||||
@@ -59,6 +58,10 @@ impl JsTable {
|
||||
return cx.throw_error("Table::create only supports 'overwrite' and 'create' modes")
|
||||
}
|
||||
};
|
||||
let params = WriteParams {
|
||||
mode,
|
||||
..WriteParams::default()
|
||||
};
|
||||
|
||||
let rt = runtime(&mut cx)?;
|
||||
let channel = cx.channel();
|
||||
@@ -66,17 +69,6 @@ impl JsTable {
|
||||
let (deferred, promise) = cx.promise();
|
||||
let database = db.database.clone();
|
||||
|
||||
let aws_creds = get_aws_credential_provider(&mut cx, 3)?;
|
||||
let aws_region = get_aws_region(&mut cx, 6)?;
|
||||
|
||||
let params = WriteParams {
|
||||
store_params: Some(ObjectStoreParams::with_aws_credentials(
|
||||
aws_creds, aws_region,
|
||||
)),
|
||||
mode,
|
||||
..WriteParams::default()
|
||||
};
|
||||
|
||||
rt.spawn(async move {
|
||||
let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
|
||||
let table_rst = database
|
||||
@@ -112,13 +104,8 @@ impl JsTable {
|
||||
"overwrite" => WriteMode::Overwrite,
|
||||
s => return cx.throw_error(format!("invalid write mode {}", s)),
|
||||
};
|
||||
let aws_creds = get_aws_credential_provider(&mut cx, 2)?;
|
||||
let aws_region = get_aws_region(&mut cx, 5)?;
|
||||
|
||||
let params = WriteParams {
|
||||
store_params: Some(ObjectStoreParams::with_aws_credentials(
|
||||
aws_creds, aws_region,
|
||||
)),
|
||||
mode: write_mode,
|
||||
..WriteParams::default()
|
||||
};
|
||||
|
||||
@@ -46,8 +46,13 @@ tempfile = "3.5.0"
|
||||
rand = { version = "0.8.3", features = ["small_rng"] }
|
||||
uuid = { version = "1.7.0", features = ["v4"] }
|
||||
walkdir = "2"
|
||||
# For s3 integration tests (dev deps aren't allowed to be optional atm)
|
||||
aws-sdk-s3 = { version = "1.0" }
|
||||
aws-sdk-kms = { version = "1.0" }
|
||||
aws-config = { version = "1.0" }
|
||||
|
||||
[features]
|
||||
default = ["remote"]
|
||||
remote = ["dep:reqwest"]
|
||||
fp16kernels = ["lance-linalg/fp16kernels"]
|
||||
fp16kernels = ["lance-linalg/fp16kernels"]
|
||||
s3-test = []
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
//! LanceDB Database
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fs::create_dir_all;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
@@ -22,9 +23,7 @@ use arrow_array::{RecordBatchIterator, RecordBatchReader};
|
||||
use arrow_schema::SchemaRef;
|
||||
use lance::dataset::{ReadParams, WriteMode};
|
||||
use lance::io::{ObjectStore, ObjectStoreParams, WrappingObjectStore};
|
||||
use object_store::{
|
||||
aws::AwsCredential, local::LocalFileSystem, CredentialProvider, StaticCredentialProvider,
|
||||
};
|
||||
use object_store::{aws::AwsCredential, local::LocalFileSystem};
|
||||
use snafu::prelude::*;
|
||||
|
||||
use crate::arrow::IntoArrow;
|
||||
@@ -208,6 +207,50 @@ impl<const HAS_DATA: bool, T: IntoArrow> CreateTableBuilder<HAS_DATA, T> {
|
||||
self.mode = mode;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set an option for the storage layer.
|
||||
///
|
||||
/// Options already set on the connection will be inherited by the table,
|
||||
/// but can be overridden here.
|
||||
///
|
||||
/// See available options at <https://lancedb.github.io/lancedb/guides/storage/>
|
||||
pub fn storage_option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
|
||||
let store_options = self
|
||||
.write_options
|
||||
.lance_write_params
|
||||
.get_or_insert(Default::default())
|
||||
.store_params
|
||||
.get_or_insert(Default::default())
|
||||
.storage_options
|
||||
.get_or_insert(Default::default());
|
||||
store_options.insert(key.into(), value.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set multiple options for the storage layer.
|
||||
///
|
||||
/// Options already set on the connection will be inherited by the table,
|
||||
/// but can be overridden here.
|
||||
///
|
||||
/// See available options at <https://lancedb.github.io/lancedb/guides/storage/>
|
||||
pub fn storage_options(
|
||||
mut self,
|
||||
pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
|
||||
) -> Self {
|
||||
let store_options = self
|
||||
.write_options
|
||||
.lance_write_params
|
||||
.get_or_insert(Default::default())
|
||||
.store_params
|
||||
.get_or_insert(Default::default())
|
||||
.storage_options
|
||||
.get_or_insert(Default::default());
|
||||
|
||||
for (key, value) in pairs {
|
||||
store_options.insert(key.into(), value.into());
|
||||
}
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -252,6 +295,48 @@ impl OpenTableBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set an option for the storage layer.
|
||||
///
|
||||
/// Options already set on the connection will be inherited by the table,
|
||||
/// but can be overridden here.
|
||||
///
|
||||
/// See available options at <https://lancedb.github.io/lancedb/guides/storage/>
|
||||
pub fn storage_option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
|
||||
let storage_options = self
|
||||
.lance_read_params
|
||||
.get_or_insert(Default::default())
|
||||
.store_options
|
||||
.get_or_insert(Default::default())
|
||||
.storage_options
|
||||
.get_or_insert(Default::default());
|
||||
storage_options.insert(key.into(), value.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set multiple options for the storage layer.
|
||||
///
|
||||
/// Options already set on the connection will be inherited by the table,
|
||||
/// but can be overridden here.
|
||||
///
|
||||
/// See available options at <https://lancedb.github.io/lancedb/guides/storage/>
|
||||
pub fn storage_options(
|
||||
mut self,
|
||||
pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
|
||||
) -> Self {
|
||||
let storage_options = self
|
||||
.lance_read_params
|
||||
.get_or_insert(Default::default())
|
||||
.store_options
|
||||
.get_or_insert(Default::default())
|
||||
.storage_options
|
||||
.get_or_insert(Default::default());
|
||||
|
||||
for (key, value) in pairs {
|
||||
storage_options.insert(key.into(), value.into());
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Open the table
|
||||
pub async fn execute(self) -> Result<Table> {
|
||||
self.parent.clone().do_open_table(self).await
|
||||
@@ -385,8 +470,7 @@ pub struct ConnectBuilder {
|
||||
/// LanceDB Cloud host override, only required if using an on-premises Lance Cloud instance
|
||||
host_override: Option<String>,
|
||||
|
||||
/// User provided AWS credentials
|
||||
aws_creds: Option<AwsCredential>,
|
||||
storage_options: HashMap<String, String>,
|
||||
|
||||
/// The interval at which to check for updates from other processes.
|
||||
///
|
||||
@@ -409,8 +493,8 @@ impl ConnectBuilder {
|
||||
api_key: None,
|
||||
region: None,
|
||||
host_override: None,
|
||||
aws_creds: None,
|
||||
read_consistency_interval: None,
|
||||
storage_options: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -430,8 +514,37 @@ impl ConnectBuilder {
|
||||
}
|
||||
|
||||
/// [`AwsCredential`] to use when connecting to S3.
|
||||
#[deprecated(note = "Pass through storage_options instead")]
|
||||
pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self {
|
||||
self.aws_creds = Some(aws_creds);
|
||||
self.storage_options
|
||||
.insert("aws_access_key_id".into(), aws_creds.key_id.clone());
|
||||
self.storage_options
|
||||
.insert("aws_secret_access_key".into(), aws_creds.secret_key.clone());
|
||||
if let Some(token) = &aws_creds.token {
|
||||
self.storage_options
|
||||
.insert("aws_session_token".into(), token.clone());
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Set an option for the storage layer.
|
||||
///
|
||||
/// See available options at <https://lancedb.github.io/lancedb/guides/storage/>
|
||||
pub fn storage_option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
|
||||
self.storage_options.insert(key.into(), value.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set multiple options for the storage layer.
|
||||
///
|
||||
/// See available options at <https://lancedb.github.io/lancedb/guides/storage/>
|
||||
pub fn storage_options(
|
||||
mut self,
|
||||
pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
|
||||
) -> Self {
|
||||
for (key, value) in pairs {
|
||||
self.storage_options.insert(key.into(), value.into());
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
@@ -522,6 +635,9 @@ struct Database {
|
||||
pub(crate) store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
|
||||
|
||||
read_consistency_interval: Option<std::time::Duration>,
|
||||
|
||||
// Storage options to be inherited by tables created from this connection
|
||||
storage_options: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Database {
|
||||
@@ -604,20 +720,11 @@ impl Database {
|
||||
};
|
||||
|
||||
let plain_uri = url.to_string();
|
||||
let os_params: ObjectStoreParams = if let Some(aws_creds) = &options.aws_creds {
|
||||
let credential_provider: Arc<
|
||||
dyn CredentialProvider<Credential = AwsCredential>,
|
||||
> = Arc::new(StaticCredentialProvider::new(AwsCredential {
|
||||
key_id: aws_creds.key_id.clone(),
|
||||
secret_key: aws_creds.secret_key.clone(),
|
||||
token: aws_creds.token.clone(),
|
||||
}));
|
||||
ObjectStoreParams::with_aws_credentials(
|
||||
Some(credential_provider),
|
||||
options.region.clone(),
|
||||
)
|
||||
} else {
|
||||
ObjectStoreParams::default()
|
||||
|
||||
let storage_options = options.storage_options.clone();
|
||||
let os_params = ObjectStoreParams {
|
||||
storage_options: Some(storage_options.clone()),
|
||||
..Default::default()
|
||||
};
|
||||
let (object_store, base_path) =
|
||||
ObjectStore::from_uri_and_params(&plain_uri, &os_params).await?;
|
||||
@@ -641,6 +748,7 @@ impl Database {
|
||||
object_store,
|
||||
store_wrapper: write_store_wrapper,
|
||||
read_consistency_interval: options.read_consistency_interval,
|
||||
storage_options,
|
||||
})
|
||||
}
|
||||
Err(_) => Self::open_path(uri, options.read_consistency_interval).await,
|
||||
@@ -662,6 +770,7 @@ impl Database {
|
||||
object_store,
|
||||
store_wrapper: None,
|
||||
read_consistency_interval,
|
||||
storage_options: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -734,11 +843,26 @@ impl ConnectionInternal for Database {
|
||||
|
||||
async fn do_create_table(
|
||||
&self,
|
||||
options: CreateTableBuilder<false, NoData>,
|
||||
mut options: CreateTableBuilder<false, NoData>,
|
||||
data: Box<dyn RecordBatchReader + Send>,
|
||||
) -> Result<Table> {
|
||||
let table_uri = self.table_uri(&options.name)?;
|
||||
|
||||
// Inherit storage options from the connection
|
||||
let storage_options = options
|
||||
.write_options
|
||||
.lance_write_params
|
||||
.get_or_insert_with(Default::default)
|
||||
.store_params
|
||||
.get_or_insert_with(Default::default)
|
||||
.storage_options
|
||||
.get_or_insert_with(Default::default);
|
||||
for (key, value) in self.storage_options.iter() {
|
||||
if !storage_options.contains_key(key) {
|
||||
storage_options.insert(key.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let mut write_params = options.write_options.lance_write_params.unwrap_or_default();
|
||||
if matches!(&options.mode, CreateTableMode::Overwrite) {
|
||||
write_params.mode = WriteMode::Overwrite;
|
||||
@@ -768,8 +892,23 @@ impl ConnectionInternal for Database {
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_open_table(&self, options: OpenTableBuilder) -> Result<Table> {
|
||||
async fn do_open_table(&self, mut options: OpenTableBuilder) -> Result<Table> {
|
||||
let table_uri = self.table_uri(&options.name)?;
|
||||
|
||||
// Inherit storage options from the connection
|
||||
let storage_options = options
|
||||
.lance_read_params
|
||||
.get_or_insert_with(Default::default)
|
||||
.store_options
|
||||
.get_or_insert_with(Default::default)
|
||||
.storage_options
|
||||
.get_or_insert_with(Default::default);
|
||||
for (key, value) in self.storage_options.iter() {
|
||||
if !storage_options.contains_key(key) {
|
||||
storage_options.insert(key.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let native_table = Arc::new(
|
||||
NativeTable::open_with_params(
|
||||
&table_uri,
|
||||
@@ -801,7 +940,10 @@ impl ConnectionInternal for Database {
|
||||
}
|
||||
|
||||
async fn drop_db(&self) -> Result<()> {
|
||||
todo!()
|
||||
self.object_store
|
||||
.remove_dir_all(self.base_path.clone())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
//! LanceDB Table APIs
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -757,6 +758,8 @@ pub struct NativeTable {
|
||||
// the object store wrapper to use on write path
|
||||
store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
|
||||
|
||||
storage_options: HashMap<String, String>,
|
||||
|
||||
// This comes from the connection options. We store here so we can pass down
|
||||
// to the dataset when we recreate it (for example, in checkout_latest).
|
||||
read_consistency_interval: Option<std::time::Duration>,
|
||||
@@ -822,6 +825,13 @@ impl NativeTable {
|
||||
None => params,
|
||||
};
|
||||
|
||||
let storage_options = params
|
||||
.store_options
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.storage_options
|
||||
.unwrap_or_default();
|
||||
|
||||
let dataset = DatasetBuilder::from_uri(uri)
|
||||
.with_read_params(params)
|
||||
.load()
|
||||
@@ -840,6 +850,7 @@ impl NativeTable {
|
||||
uri: uri.to_string(),
|
||||
dataset,
|
||||
store_wrapper: write_store_wrapper,
|
||||
storage_options,
|
||||
read_consistency_interval,
|
||||
})
|
||||
}
|
||||
@@ -908,6 +919,13 @@ impl NativeTable {
|
||||
None => params,
|
||||
};
|
||||
|
||||
let storage_options = params
|
||||
.store_params
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.storage_options
|
||||
.unwrap_or_default();
|
||||
|
||||
let dataset = Dataset::write(batches, uri, Some(params))
|
||||
.await
|
||||
.map_err(|e| match e {
|
||||
@@ -921,6 +939,7 @@ impl NativeTable {
|
||||
uri: uri.to_string(),
|
||||
dataset: DatasetConsistencyWrapper::new_latest(dataset, read_consistency_interval),
|
||||
store_wrapper: write_store_wrapper,
|
||||
storage_options,
|
||||
read_consistency_interval,
|
||||
})
|
||||
}
|
||||
@@ -1312,7 +1331,7 @@ impl TableInternal for NativeTable {
|
||||
add: AddDataBuilder<NoData>,
|
||||
data: Box<dyn RecordBatchReader + Send>,
|
||||
) -> Result<()> {
|
||||
let lance_params = add.write_options.lance_write_params.unwrap_or(WriteParams {
|
||||
let mut lance_params = add.write_options.lance_write_params.unwrap_or(WriteParams {
|
||||
mode: match add.mode {
|
||||
AddDataMode::Append => WriteMode::Append,
|
||||
AddDataMode::Overwrite => WriteMode::Overwrite,
|
||||
@@ -1320,6 +1339,18 @@ impl TableInternal for NativeTable {
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
// Bring storage options from table
|
||||
let storage_options = lance_params
|
||||
.store_params
|
||||
.get_or_insert(Default::default())
|
||||
.storage_options
|
||||
.get_or_insert(Default::default());
|
||||
for (key, value) in self.storage_options.iter() {
|
||||
if !storage_options.contains_key(key) {
|
||||
storage_options.insert(key.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// patch the params if we have a write store wrapper
|
||||
let lance_params = match self.store_wrapper.clone() {
|
||||
Some(wrapper) => lance_params.patch_with_store_wrapper(wrapper)?,
|
||||
|
||||
290
rust/lancedb/tests/object_store_test.rs
Normal file
290
rust/lancedb/tests/object_store_test.rs
Normal file
@@ -0,0 +1,290 @@
|
||||
// Copyright 2023 LanceDB Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#![cfg(feature = "s3-test")]
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
|
||||
use aws_config::{BehaviorVersion, ConfigLoader, Region, SdkConfig};
|
||||
use aws_sdk_s3::{config::Credentials, types::ServerSideEncryption, Client as S3Client};
|
||||
use lancedb::Result;
|
||||
|
||||
const CONFIG: &[(&str, &str)] = &[
|
||||
("access_key_id", "ACCESS_KEY"),
|
||||
("secret_access_key", "SECRET_KEY"),
|
||||
("endpoint", "http://127.0.0.1:4566"),
|
||||
("allow_http", "true"),
|
||||
];
|
||||
|
||||
async fn aws_config() -> SdkConfig {
|
||||
let credentials = Credentials::new(CONFIG[0].1, CONFIG[1].1, None, None, "static");
|
||||
ConfigLoader::default()
|
||||
.credentials_provider(credentials)
|
||||
.endpoint_url(CONFIG[2].1)
|
||||
.behavior_version(BehaviorVersion::latest())
|
||||
.region(Region::new("us-east-1"))
|
||||
.load()
|
||||
.await
|
||||
}
|
||||
|
||||
struct S3Bucket(String);
|
||||
|
||||
impl S3Bucket {
|
||||
async fn new(bucket: &str) -> Self {
|
||||
let config = aws_config().await;
|
||||
let client = S3Client::new(&config);
|
||||
|
||||
// In case it wasn't deleted earlier
|
||||
Self::delete_bucket(client.clone(), bucket).await;
|
||||
|
||||
client.create_bucket().bucket(bucket).send().await.unwrap();
|
||||
|
||||
Self(bucket.to_string())
|
||||
}
|
||||
|
||||
async fn delete_bucket(client: S3Client, bucket: &str) {
|
||||
// Before we delete the bucket, we need to delete all objects in it
|
||||
let res = client
|
||||
.list_objects_v2()
|
||||
.bucket(bucket)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| err.into_service_error());
|
||||
match res {
|
||||
Err(e) if e.is_no_such_bucket() => return,
|
||||
Err(e) => panic!("Failed to list objects in bucket: {}", e),
|
||||
_ => {}
|
||||
}
|
||||
let objects = res.unwrap().contents.unwrap_or_default();
|
||||
for object in objects {
|
||||
client
|
||||
.delete_object()
|
||||
.bucket(bucket)
|
||||
.key(object.key.unwrap())
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
client.delete_bucket().bucket(bucket).send().await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for S3Bucket {
|
||||
fn drop(&mut self) {
|
||||
let bucket_name = self.0.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let config = aws_config().await;
|
||||
let client = S3Client::new(&config);
|
||||
Self::delete_bucket(client, &bucket_name).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn test_data() -> RecordBatch {
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("a", DataType::Int32, false),
|
||||
Field::new("b", DataType::Utf8, false),
|
||||
]));
|
||||
RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(Int32Array::from(vec![1, 2, 3])),
|
||||
Arc::new(StringArray::from(vec!["a", "b", "c"])),
|
||||
],
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_minio_lifecycle() -> Result<()> {
|
||||
// test create, update, drop, list on localstack minio
|
||||
let bucket = S3Bucket::new("test-bucket").await;
|
||||
let uri = format!("s3://{}", bucket.0);
|
||||
|
||||
let db = lancedb::connect(&uri)
|
||||
.storage_options(CONFIG.iter().cloned())
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
let data = test_data();
|
||||
let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
|
||||
|
||||
let table = db.create_table("test_table", data).execute().await?;
|
||||
|
||||
let row_count = table.count_rows(None).await?;
|
||||
assert_eq!(row_count, 3);
|
||||
|
||||
let table_names = db.table_names().execute().await?;
|
||||
assert_eq!(table_names, vec!["test_table"]);
|
||||
|
||||
// Re-open the table
|
||||
let table = db.open_table("test_table").execute().await?;
|
||||
let row_count = table.count_rows(None).await?;
|
||||
assert_eq!(row_count, 3);
|
||||
|
||||
let data = test_data();
|
||||
let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
|
||||
table.add(data).execute().await?;
|
||||
|
||||
db.drop_table("test_table").await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct KMSKey(String);
|
||||
|
||||
impl KMSKey {
|
||||
async fn new() -> Self {
|
||||
let config = aws_config().await;
|
||||
let client = aws_sdk_kms::Client::new(&config);
|
||||
let key = client
|
||||
.create_key()
|
||||
.description("test key")
|
||||
.send()
|
||||
.await
|
||||
.unwrap()
|
||||
.key_metadata
|
||||
.unwrap()
|
||||
.key_id;
|
||||
Self(key)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for KMSKey {
|
||||
fn drop(&mut self) {
|
||||
let key_id = self.0.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let config = aws_config().await;
|
||||
let client = aws_sdk_kms::Client::new(&config);
|
||||
client
|
||||
.schedule_key_deletion()
|
||||
.key_id(&key_id)
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn validate_objects_encrypted(bucket: &str, path: &str, kms_key_id: &str) {
|
||||
// Get S3 client
|
||||
let config = aws_config().await;
|
||||
let client = S3Client::new(&config);
|
||||
|
||||
// list the objects are the path
|
||||
let objects = client
|
||||
.list_objects_v2()
|
||||
.bucket(bucket)
|
||||
.prefix(path)
|
||||
.send()
|
||||
.await
|
||||
.unwrap()
|
||||
.contents
|
||||
.unwrap();
|
||||
|
||||
let mut errors = vec![];
|
||||
let mut correctly_encrypted = vec![];
|
||||
|
||||
// For each object, call head
|
||||
for object in &objects {
|
||||
let head = client
|
||||
.head_object()
|
||||
.bucket(bucket)
|
||||
.key(object.key().unwrap())
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify the object is encrypted
|
||||
if head.server_side_encryption() != Some(&ServerSideEncryption::AwsKms) {
|
||||
errors.push(format!("Object {} is not encrypted", object.key().unwrap()));
|
||||
continue;
|
||||
}
|
||||
if !(head
|
||||
.ssekms_key_id()
|
||||
.map(|arn| arn.ends_with(kms_key_id))
|
||||
.unwrap_or(false))
|
||||
{
|
||||
errors.push(format!(
|
||||
"Object {} has wrong key id: {:?}, vs expected: {}",
|
||||
object.key().unwrap(),
|
||||
head.ssekms_key_id(),
|
||||
kms_key_id
|
||||
));
|
||||
continue;
|
||||
}
|
||||
correctly_encrypted.push(object.key().unwrap().to_string());
|
||||
}
|
||||
|
||||
if !errors.is_empty() {
|
||||
panic!(
|
||||
"{} of {} correctly encrypted: {:?}\n{} of {} not correct: {:?}",
|
||||
correctly_encrypted.len(),
|
||||
objects.len(),
|
||||
correctly_encrypted,
|
||||
errors.len(),
|
||||
objects.len(),
|
||||
errors
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_encryption() -> Result<()> {
|
||||
// test encryption on localstack minio
|
||||
let bucket = S3Bucket::new("test-encryption").await;
|
||||
let key = KMSKey::new().await;
|
||||
|
||||
let uri = format!("s3://{}", bucket.0);
|
||||
let db = lancedb::connect(&uri)
|
||||
.storage_options(CONFIG.iter().cloned())
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
// Create a table with encryption
|
||||
let data = test_data();
|
||||
let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
|
||||
|
||||
let mut builder = db.create_table("test_table", data);
|
||||
for (key, value) in CONFIG {
|
||||
builder = builder.storage_option(*key, *value);
|
||||
}
|
||||
let table = builder
|
||||
.storage_option("aws_server_side_encryption", "aws:kms")
|
||||
.storage_option("aws_sse_kms_key_id", &key.0)
|
||||
.execute()
|
||||
.await?;
|
||||
validate_objects_encrypted(&bucket.0, "test_table", &key.0).await;
|
||||
|
||||
table.delete("a = 1").await?;
|
||||
validate_objects_encrypted(&bucket.0, "test_table", &key.0).await;
|
||||
|
||||
// Test we can set encryption at the connection level.
|
||||
let db = lancedb::connect(&uri)
|
||||
.storage_options(CONFIG.iter().cloned())
|
||||
.storage_option("aws_server_side_encryption", "aws:kms")
|
||||
.storage_option("aws_sse_kms_key_id", &key.0)
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
let table = db.open_table("test_table").execute().await?;
|
||||
|
||||
let data = test_data();
|
||||
let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
|
||||
table.add(data).execute().await?;
|
||||
validate_objects_encrypted(&bucket.0, "test_table", &key.0).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user