diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs
index cf0e6899..390a7604 100644
--- a/rust/lancedb/src/connection.rs
+++ b/rust/lancedb/src/connection.rs
@@ -142,12 +142,6 @@ impl CreateTableBuilder {
         }
     }
 
-    /// Apply the given write options when writing the initial data
-    pub fn write_options(mut self, write_options: WriteOptions) -> Self {
-        self.request.write_options = write_options;
-        self
-    }
-
     /// Execute the create table operation
     pub async fn execute(self) -> Result<Table> {
         let embedding_registry = self.embedding_registry.clone();
@@ -229,6 +223,12 @@ impl CreateTableBuilder {
         self
     }
 
+    /// Apply the given write options when writing the initial data
+    pub fn write_options(mut self, write_options: WriteOptions) -> Self {
+        self.request.write_options = write_options;
+        self
+    }
+
     /// Set an option for the storage layer.
     ///
     /// Options already set on the connection will be inherited by the table,
diff --git a/rust/lancedb/src/io/object_store.rs b/rust/lancedb/src/io/object_store.rs
index 9d4b0dca..66fa7053 100644
--- a/rust/lancedb/src/io/object_store.rs
+++ b/rust/lancedb/src/io/object_store.rs
@@ -14,6 +14,9 @@ use object_store::{
 
 use async_trait::async_trait;
 
+#[cfg(test)]
+pub mod io_tracking;
+
 #[derive(Debug)]
 struct MirroringObjectStore {
     primary: Arc<dyn ObjectStore>,
diff --git a/rust/lancedb/src/io/object_store/io_tracking.rs b/rust/lancedb/src/io/object_store/io_tracking.rs
new file mode 100644
index 00000000..71c68068
--- /dev/null
+++ b/rust/lancedb/src/io/object_store/io_tracking.rs
@@ -0,0 +1,237 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use std::{
+    fmt::{Display, Formatter},
+    sync::{Arc, Mutex},
+};
+
+use bytes::Bytes;
+use futures::stream::BoxStream;
+use lance::io::WrappingObjectStore;
+use object_store::{
+    path::Path, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
+    PutMultipartOpts, PutOptions, PutPayload, PutResult, Result as OSResult, UploadPart,
+};
+
+#[derive(Debug, Default)]
+pub struct IoStats {
+    pub read_iops: u64,
+    pub read_bytes: u64,
+    pub write_iops: u64,
+    pub write_bytes: u64,
+}
+
+impl Display for IoStats {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{:#?}", self)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct IoTrackingStore {
+    target: Arc<dyn ObjectStore>,
+    stats: Arc<Mutex<IoStats>>,
+}
+
+impl Display for IoTrackingStore {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{:#?}", self)
+    }
+}
+
+#[derive(Debug, Default, Clone)]
+pub struct IoStatsHolder(Arc<Mutex<IoStats>>);
+
+impl IoStatsHolder {
+    pub fn incremental_stats(&self) -> IoStats {
+        std::mem::take(&mut self.0.lock().expect("failed to lock IoStats"))
+    }
+}
+
+impl WrappingObjectStore for IoStatsHolder {
+    fn wrap(&self, target: Arc<dyn ObjectStore>) -> Arc<dyn ObjectStore> {
+        Arc::new(IoTrackingStore {
+            target,
+            stats: self.0.clone(),
+        })
+    }
+}
+
+impl IoTrackingStore {
+    pub fn new_wrapper() -> (Arc<dyn WrappingObjectStore>, Arc<Mutex<IoStats>>) {
+        let stats = Arc::new(Mutex::new(IoStats::default()));
+        (Arc::new(IoStatsHolder(stats.clone())), stats)
+    }
+
+    fn record_read(&self, num_bytes: u64) {
+        let mut stats = self.stats.lock().unwrap();
+        stats.read_iops += 1;
+        stats.read_bytes += num_bytes;
+    }
+
+    fn record_write(&self, num_bytes: u64) {
+        let mut stats = self.stats.lock().unwrap();
+        stats.write_iops += 1;
+        stats.write_bytes += num_bytes;
+    }
+}
+
+#[async_trait::async_trait]
+#[deny(clippy::missing_trait_methods)]
+impl ObjectStore for IoTrackingStore {
+    async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult<PutResult> {
+        self.record_write(bytes.content_length() as u64);
+        self.target.put(location, bytes).await
+    }
+
+    async fn put_opts(
+        &self,
+        location: &Path,
+        bytes: PutPayload,
+        opts: PutOptions,
+    ) -> OSResult<PutResult> {
+        self.record_write(bytes.content_length() as u64);
+        self.target.put_opts(location, bytes, opts).await
+    }
+
+    async fn put_multipart(&self, location: &Path) -> OSResult<Box<dyn MultipartUpload>> {
+        let target = self.target.put_multipart(location).await?;
+        Ok(Box::new(IoTrackingMultipartUpload {
+            target,
+            stats: self.stats.clone(),
+        }))
+    }
+
+    async fn put_multipart_opts(
+        &self,
+        location: &Path,
+        opts: PutMultipartOpts,
+    ) -> OSResult<Box<dyn MultipartUpload>> {
+        let target = self.target.put_multipart_opts(location, opts).await?;
+        Ok(Box::new(IoTrackingMultipartUpload {
+            target,
+            stats: self.stats.clone(),
+        }))
+    }
+
+    async fn get(&self, location: &Path) -> OSResult<GetResult> {
+        let result = self.target.get(location).await;
+        if let Ok(result) = &result {
+            let num_bytes = result.range.end - result.range.start;
+            self.record_read(num_bytes as u64);
+        }
+        result
+    }
+
+    async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
+        let result = self.target.get_opts(location, options).await;
+        if let Ok(result) = &result {
+            let num_bytes = result.range.end - result.range.start;
+            self.record_read(num_bytes as u64);
+        }
+        result
+    }
+
+    async fn get_range(&self, location: &Path, range: std::ops::Range<usize>) -> OSResult<Bytes> {
+        let result = self.target.get_range(location, range).await;
+        if let Ok(result) = &result {
+            self.record_read(result.len() as u64);
+        }
+        result
+    }
+
+    async fn get_ranges(
+        &self,
+        location: &Path,
+        ranges: &[std::ops::Range<usize>],
+    ) -> OSResult<Vec<Bytes>> {
+        let result = self.target.get_ranges(location, ranges).await;
+        if let Ok(result) = &result {
+            self.record_read(result.iter().map(|b| b.len() as u64).sum());
+        }
+        result
+    }
+
+    async fn head(&self, location: &Path) -> OSResult<ObjectMeta> {
+        self.record_read(0);
+        self.target.head(location).await
+    }
+
+    async fn delete(&self, location: &Path) -> OSResult<()> {
+        self.record_write(0);
+        self.target.delete(location).await
+    }
+
+    fn delete_stream<'a>(
+        &'a self,
+        locations: BoxStream<'a, OSResult<Path>>,
+    ) -> BoxStream<'a, OSResult<Path>> {
+        self.target.delete_stream(locations)
+    }
+
+    fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, OSResult<ObjectMeta>> {
+        self.record_read(0);
+        self.target.list(prefix)
+    }
+
+    fn list_with_offset(
+        &self,
+        prefix: Option<&Path>,
+        offset: &Path,
+    ) -> BoxStream<'_, OSResult<ObjectMeta>> {
+        self.record_read(0);
+        self.target.list_with_offset(prefix, offset)
+    }
+
+    async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
+        self.record_read(0);
+        self.target.list_with_delimiter(prefix).await
+    }
+
+    async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> {
+        self.record_write(0);
+        self.target.copy(from, to).await
+    }
+
+    async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> {
+        self.record_write(0);
+        self.target.rename(from, to).await
+    }
+
+    async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
+        self.record_write(0);
+        self.target.rename_if_not_exists(from, to).await
+    }
+
+    async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
+        self.record_write(0);
+        self.target.copy_if_not_exists(from, to).await
+    }
+}
+
+#[derive(Debug)]
+struct IoTrackingMultipartUpload {
+    target: Box<dyn MultipartUpload>,
+    stats: Arc<Mutex<IoStats>>,
+}
+
+#[async_trait::async_trait]
+impl MultipartUpload for IoTrackingMultipartUpload {
+    async fn abort(&mut self) -> OSResult<()> {
+        self.target.abort().await
+    }
+
+    async fn complete(&mut self) -> OSResult<PutResult> {
+        self.target.complete().await
+    }
+
+    fn put_part(&mut self, payload: PutPayload) -> UploadPart {
+        {
+            let mut stats = self.stats.lock().unwrap();
+            stats.write_iops += 1;
+            stats.write_bytes += payload.content_length() as u64;
+        }
+        self.target.put_part(payload)
+    }
+}
diff --git a/rust/lancedb/src/table/dataset.rs b/rust/lancedb/src/table/dataset.rs
index fa797ef3..48a4e123 100644
--- a/rust/lancedb/src/table/dataset.rs
+++ b/rust/lancedb/src/table/dataset.rs
@@ -48,7 +48,6 @@ impl DatasetRef {
                 refresh_task,
                 ..
             } => {
-                dataset.checkout_latest().await?;
                 // Replace the refresh task
                 if let Some(refresh_task) = refresh_task {
                     refresh_task.abort();
@@ -372,3 +371,48 @@ impl DerefMut for DatasetWriteGuard<'_> {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use arrow_schema::{DataType, Field, Schema};
+    use lance::{dataset::WriteParams, io::ObjectStoreParams};
+
+    use super::*;
+
+    use crate::{connect, io::object_store::io_tracking::IoStatsHolder, table::WriteOptions};
+
+    #[tokio::test]
+    async fn test_iops_open_strong_consistency() {
+        let db = connect("memory://")
+            .read_consistency_interval(Some(Duration::ZERO))
+            .execute()
+            .await
+            .expect("Failed to connect to database");
+        let io_stats = IoStatsHolder::default();
+
+        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+
+        let table = db
+            .create_empty_table("test", schema)
+            .write_options(WriteOptions {
+                lance_write_params: Some(WriteParams {
+                    store_params: Some(ObjectStoreParams {
+                        object_store_wrapper: Some(Arc::new(io_stats.clone())),
+                        ..Default::default()
+                    }),
+                    ..Default::default()
+                }),
+            })
+            .execute()
+            .await
+            .unwrap();
+
+        io_stats.incremental_stats();
+
+        // We should only need 1 read IOP to check the schema: looking for the
+        // latest version.
+        table.schema().await.unwrap();
+        let stats = io_stats.incremental_stats();
+        assert_eq!(stats.read_iops, 1);
+    }
+}