From 6a50d719207088c9d77f634cac7b966057e5dc01 Mon Sep 17 00:00:00 2001 From: "Lei, HUANG" <6406592+v0y4g3r@users.noreply.github.com> Date: Mon, 14 Apr 2025 21:15:56 +0800 Subject: [PATCH] fix: memtable panic (#5894) * fix: memtable panic * fix: ci --- src/mito2/src/memtable/time_series.rs | 196 ++++++++++++++++++++++---- 1 file changed, 171 insertions(+), 25 deletions(-) diff --git a/src/mito2/src/memtable/time_series.rs b/src/mito2/src/memtable/time_series.rs index 82758a542b..44bce1ec74 100644 --- a/src/mito2/src/memtable/time_series.rs +++ b/src/mito2/src/memtable/time_series.rs @@ -161,18 +161,15 @@ impl TimeSeriesMemtable { let primary_key_encoded = self.row_codec.encode(kv.primary_keys())?; - let (series, series_allocated) = self.series_set.get_or_add_series(primary_key_encoded); - stats.key_bytes += series_allocated; + let (key_allocated, value_allocated) = + self.series_set.push_to_series(primary_key_encoded, &kv); + stats.key_bytes += key_allocated; + stats.value_bytes += value_allocated; // safety: timestamp of kv must be both present and a valid timestamp value. let ts = kv.timestamp().as_timestamp().unwrap().unwrap().value(); stats.min_ts = stats.min_ts.min(ts); stats.max_ts = stats.max_ts.max(ts); - - let mut guard = series.write().unwrap(); - let size = guard.push(kv.timestamp(), kv.sequence(), kv.op_type(), kv.fields()); - stats.value_bytes += size; - Ok(()) } } @@ -368,25 +365,46 @@ impl SeriesSet { } impl SeriesSet { - /// Returns the series for given primary key, or create a new series if not already exist, - /// along with the allocated memory footprint for primary keys. - fn get_or_add_series(&self, primary_key: Vec) -> (Arc>, usize) { + /// Push [KeyValue] to SeriesSet with given primary key and return key/value allocated memory size. + fn push_to_series(&self, primary_key: Vec, kv: &KeyValue) -> (usize, usize) { if let Some(series) = self.series.read().unwrap().get(&primary_key) { - return (series.clone(), 0); + let value_allocated = series.write().unwrap().push( + kv.timestamp(), + kv.sequence(), + kv.op_type(), + kv.fields(), + ); + return (0, value_allocated); }; - let s = Arc::new(RwLock::new(Series::new(&self.region_metadata))); + let mut indices = self.series.write().unwrap(); match indices.entry(primary_key) { Entry::Vacant(v) => { let key_len = v.key().len(); - v.insert(s.clone()); - (s, key_len) + let mut series = Series::new(&self.region_metadata); + let value_allocated = + series.push(kv.timestamp(), kv.sequence(), kv.op_type(), kv.fields()); + v.insert(Arc::new(RwLock::new(series))); + (key_len, value_allocated) } // safety: series must exist at given index. - Entry::Occupied(v) => (v.get().clone(), 0), + Entry::Occupied(v) => { + let value_allocated = v.get().write().unwrap().push( + kv.timestamp(), + kv.sequence(), + kv.op_type(), + kv.fields(), + ); + (0, value_allocated) + } } } + #[cfg(test)] + fn get_series(&self, primary_key: &[u8]) -> Option>> { + self.series.read().unwrap().get(primary_key).cloned() + } + /// Iterates all series in [SeriesSet]. fn iter_series( &self, @@ -948,7 +966,7 @@ mod tests { use api::helper::ColumnDataTypeWrapper; use api::v1::value::ValueData; - use api::v1::{Row, Rows, SemanticType}; + use api::v1::{Mutation, Row, Rows, SemanticType}; use common_time::Timestamp; use datatypes::prelude::{ConcreteDataType, ScalarVector}; use datatypes::schema::ColumnSchema; @@ -959,6 +977,7 @@ mod tests { use super::*; use crate::row_converter::SortField; + use crate::test_util::column_metadata_to_column_schema; fn schema_for_test() -> RegionMetadataRef { let mut builder = RegionMetadataBuilder::new(RegionId::new(123, 456)); @@ -1242,18 +1261,54 @@ mod tests { let mut handles = Vec::with_capacity(concurrency); for i in 0..concurrency { let set = set.clone(); + let schema = schema.clone(); + let column_schemas = schema + .column_metadatas + .iter() + .map(column_metadata_to_column_schema) + .collect::>(); let handle = std::thread::spawn(move || { for j in i * 100..(i + 1) * 100 { let pk = j % pk_num; let primary_key = format!("pk-{}", pk).as_bytes().to_vec(); - let (series, _) = set.get_or_add_series(primary_key); - let mut guard = series.write().unwrap(); - guard.push( - ts_value_ref(j as i64), - j as u64, - OpType::Put, - field_value_ref(j as i64, j as f64), - ); + + let kvs = KeyValues::new( + &schema, + Mutation { + op_type: OpType::Put as i32, + sequence: j as u64, + rows: Some(Rows { + schema: column_schemas.clone(), + rows: vec![Row { + values: vec![ + api::v1::Value { + value_data: Some(ValueData::StringValue(format!( + "{}", + j + ))), + }, + api::v1::Value { + value_data: Some(ValueData::I64Value(j as i64)), + }, + api::v1::Value { + value_data: Some(ValueData::TimestampMillisecondValue( + j as i64, + )), + }, + api::v1::Value { + value_data: Some(ValueData::I64Value(j as i64)), + }, + api::v1::Value { + value_data: Some(ValueData::F64Value(j as f64)), + }, + ], + }], + }), + write_hint: None, + }, + ) + .unwrap(); + set.push_to_series(primary_key, &kvs.iter().next().unwrap()); } }); handles.push(handle); @@ -1269,7 +1324,7 @@ mod tests { for i in 0..pk_num { let pk = format!("pk-{}", i).as_bytes().to_vec(); - let (series, _) = set.get_or_add_series(pk); + let series = set.get_series(&pk).unwrap(); let mut guard = series.write().unwrap(); let values = guard.compact(&schema).unwrap(); timestamps.extend(values.sequence.iter_data().map(|v| v.unwrap() as i64)); @@ -1385,4 +1440,95 @@ mod tests { } assert_eq!((0..100i64).collect::>(), v0_all); } + + #[test] + fn test_memtable_concurrent_write_read() { + common_telemetry::init_default_ut_logging(); + let schema = schema_for_test(); + let memtable = Arc::new(TimeSeriesMemtable::new( + schema.clone(), + 42, + None, + true, + MergeMode::LastRow, + )); + + // Number of writer threads + let num_writers = 10; + // Number of reader threads + let num_readers = 5; + // Number of series per writer + let series_per_writer = 100; + // Number of rows per series + let rows_per_series = 10; + // Total number of series + let total_series = num_writers * series_per_writer; + + // Create a barrier to synchronize the start of all threads + let barrier = Arc::new(std::sync::Barrier::new(num_writers + num_readers + 1)); + + // Spawn writer threads + let mut writer_handles = Vec::with_capacity(num_writers); + for writer_id in 0..num_writers { + let memtable = memtable.clone(); + let schema = schema.clone(); + let barrier = barrier.clone(); + + let handle = std::thread::spawn(move || { + // Wait for all threads to be ready + barrier.wait(); + + // Create and write series + for series_id in 0..series_per_writer { + let series_key = format!("writer-{}-series-{}", writer_id, series_id); + let kvs = + build_key_values(&schema, series_key, series_id as i64, rows_per_series); + memtable.write(&kvs).unwrap(); + } + }); + + writer_handles.push(handle); + } + + // Spawn reader threads + let mut reader_handles = Vec::with_capacity(num_readers); + for _ in 0..num_readers { + let memtable = memtable.clone(); + let barrier = barrier.clone(); + + let handle = std::thread::spawn(move || { + barrier.wait(); + + for _ in 0..10 { + let iter = memtable.iter(None, None, None).unwrap(); + for batch_result in iter { + let _ = batch_result.unwrap(); + } + } + }); + + reader_handles.push(handle); + } + + barrier.wait(); + + for handle in writer_handles { + handle.join().unwrap(); + } + for handle in reader_handles { + handle.join().unwrap(); + } + + let iter = memtable.iter(None, None, None).unwrap(); + let mut series_count = 0; + let mut row_count = 0; + + for batch_result in iter { + let batch = batch_result.unwrap(); + series_count += 1; + row_count += batch.num_rows(); + } + assert_eq!(total_series, series_count); + assert_eq!(total_series * rows_per_series, row_count); + } }