diff --git a/src/cli/src/data/import_v2/command.rs b/src/cli/src/data/import_v2/command.rs index ed9727cd8a..43bef46be0 100644 --- a/src/cli/src/data/import_v2/command.rs +++ b/src/cli/src/data/import_v2/command.rs @@ -71,6 +71,10 @@ pub struct ImportV2Command { #[clap(long, value_enum, default_value_t = ProgressMode::Auto)] progress: ProgressMode, + /// Number of import data tasks to run concurrently on the client (1..=64). + #[clap(long, default_value = "1", value_parser = parse_task_parallelism)] + task_parallelism: usize, + /// Basic authentication (user:password). #[clap(long)] auth_basic: Option, @@ -132,6 +136,7 @@ impl ImportV2Command { schemas, dry_run: self.dry_run, progress: self.progress, + task_parallelism: self.task_parallelism, snapshot_uri: self.from.clone(), storage_config: self.storage.clone(), storage: Box::new(storage), @@ -140,12 +145,24 @@ impl ImportV2Command { } } +fn parse_task_parallelism(value: &str) -> std::result::Result { + let parallelism = value + .parse::() + .map_err(|_| "task parallelism must be an integer between 1 and 64".to_string())?; + if (1..=64).contains(¶llelism) { + Ok(parallelism) + } else { + Err("task parallelism must be between 1 and 64".to_string()) + } +} + /// Import tool implementation. pub struct Import { catalog: String, schemas: Option>, dry_run: bool, progress: ProgressMode, + task_parallelism: usize, snapshot_uri: String, storage_config: ObjectStoreConfig, storage: Box, @@ -241,6 +258,7 @@ impl Import { schemas: schemas_to_import.clone(), state_path, tasks: data_tasks, + task_parallelism: self.task_parallelism, }) .await?, ) @@ -735,6 +753,42 @@ mod tests { ); } + #[test] + fn test_task_parallelism_defaults_to_one() { + assert_eq!(parse_command(&[]).task_parallelism, 1); + } + + #[test] + fn test_task_parallelism_parses_valid_values() { + assert_eq!( + parse_command(&["--task-parallelism", "2"]).task_parallelism, + 2 + ); + assert_eq!( + parse_command(&["--task-parallelism", "64"]).task_parallelism, + 64 + ); + } + + #[test] + fn test_task_parallelism_rejects_invalid_values() { + for value in ["0", "65", "abc"] { + assert!( + ImportV2Command::try_parse_from([ + "import-v2", + "--addr", + "127.0.0.1:4000", + "--from", + "file:///tmp/snapshot", + "--task-parallelism", + value, + ]) + .is_err(), + "value {value} should be rejected" + ); + } + } + #[test] fn test_parse_ddl_statements() { let content = r#" diff --git a/src/cli/src/data/import_v2/coordinator.rs b/src/cli/src/data/import_v2/coordinator.rs index 58b1bdf4d7..ff1cd2c336 100644 --- a/src/cli/src/data/import_v2/coordinator.rs +++ b/src/cli/src/data/import_v2/coordinator.rs @@ -18,10 +18,12 @@ use std::time::Instant; use async_trait::async_trait; use common_telemetry::{info, warn}; +use futures::StreamExt; +use futures::stream::FuturesUnordered; use crate::data::export_v2::manifest::{ChunkMeta, ChunkStatus}; use crate::data::import_v2::error::{ - ImportStateDdlIncompleteSnafu, ImportStateMismatchSnafu, Result, + Error, ImportStateDdlIncompleteSnafu, ImportStateMismatchSnafu, Result, }; use crate::data::import_v2::state::{ ImportState, ImportStateLockGuard, ImportTaskKey, ImportTaskStatus, canonical_schema_selection, @@ -44,6 +46,8 @@ pub(crate) struct ImportResumeConfig { pub(crate) schemas: Vec, pub(crate) state_path: PathBuf, pub(crate) tasks: Vec, + /// Number of data tasks to run concurrently. `1` preserves serial behavior. + pub(crate) task_parallelism: usize, } pub(crate) struct ImportResumeSession { @@ -187,6 +191,39 @@ where ); let import_start = Instant::now(); + let result = if config.task_parallelism <= 1 { + import_tasks_serial(&config, &mut state, executor, progress).await + } else { + import_tasks_concurrent( + &config, + &mut state, + executor, + config.task_parallelism, + progress, + ) + .await + }; + progress_phase.finish(); + + // On failure, leave the state file in place so a later run can resume. + result?; + + delete_import_state(&config.state_path).await?; + info!("Data import finished in {:?}", import_start.elapsed()); + drop(lock); + Ok(()) +} + +/// Imports data tasks one at a time, preserving the original serial behavior. +async fn import_tasks_serial( + config: &ImportResumeConfig, + state: &mut ImportState, + executor: &E, + progress: &dyn ProgressReporter, +) -> Result<()> +where + E: ImportTaskExecutor + Sync, +{ for (idx, task) in config.tasks.iter().enumerate() { if state.task_status(task.chunk_id, &task.schema) == Some(ImportTaskStatus::Completed) { info!( @@ -213,7 +250,7 @@ where ImportTaskStatus::InProgress, None, )?; - save_import_state(&config.state_path, &state).await?; + save_import_state(&config.state_path, state).await?; let task_start = Instant::now(); let result = executor.import_task(task).await; @@ -225,14 +262,8 @@ where // duplicating data depending on engine semantics), but we must // not pretend the import as a whole failed - return the persist // error so the operator notices, after logging the success. - update_status_and_save( - &config, - &mut state, - task, - ImportTaskStatus::Completed, - None, - ) - .await?; + update_status_and_save(config, state, task, ImportTaskStatus::Completed, None) + .await?; info!( "[{}/{}] Chunk {} schema {}: done in {:?}", idx + 1, @@ -247,32 +278,166 @@ where // Persist Failed best-effort, but always surface the original // task error to the caller. State persistence problems are // logged so they are not silently lost. - if let Err(persist_error) = update_status_and_save( - &config, - &mut state, - task, - ImportTaskStatus::Failed, - Some(task_error.to_string()), - ) - .await - { - warn!( - "Failed to persist Failed status for chunk {} schema {} after task error ({}); state file may be out of date: {}", - task.chunk_id, task.schema, task_error, persist_error - ); - } + persist_failed_best_effort(config, state, task, &task_error).await; return Err(task_error); } } } - progress_phase.finish(); - delete_import_state(&config.state_path).await?; - info!("Data import finished in {:?}", import_start.elapsed()); - drop(lock); Ok(()) } +/// Imports up to `task_parallelism` data tasks concurrently on the client. +/// +/// The coordinator owns all state mutation/persistence: it marks tasks +/// `InProgress` and persists the state before polling their futures, then +/// applies each task result and persists again on completion. The task futures +/// only run the import; they never touch the state, so state writes stay +/// serialized in this task. +/// +/// On the first task failure we stop scheduling new tasks but let already +/// in-flight tasks finish and persist their final status, then return the first +/// error. +async fn import_tasks_concurrent( + config: &ImportResumeConfig, + state: &mut ImportState, + executor: &E, + task_parallelism: usize, + progress: &dyn ProgressReporter, +) -> Result<()> +where + E: ImportTaskExecutor + Sync, +{ + let mut pending = FuturesUnordered::new(); + let mut next_idx = 0; + let mut first_error: Option = None; + + loop { + let mut scheduled = false; + + // Schedule eligible tasks in order up to the parallelism limit. Once a + // failure is seen, stop scheduling but keep draining in-flight tasks. + while first_error.is_none() && pending.len() < task_parallelism { + let Some(idx) = next_pending_task(config, state, &mut next_idx, progress) else { + break; + }; + + let task = &config.tasks[idx]; + info!( + "[{}/{}] Chunk {} schema {}: importing...", + idx + 1, + config.tasks.len(), + task.chunk_id, + task.schema + ); + state.set_task_status( + task.chunk_id, + &task.schema, + ImportTaskStatus::InProgress, + None, + )?; + scheduled = true; + + pending.push(async move { + let result = executor.import_task(task).await; + (idx, result) + }); + } + + if scheduled { + save_import_state(&config.state_path, state).await?; + } + + let Some((idx, task_result)) = pending.next().await else { + break; + }; + + let task = &config.tasks[idx]; + match task_result { + Ok(()) => { + // The task itself succeeded. If we cannot persist the Completed + // marker, surface the persist error so the operator notices, + // after logging the success. + update_status_and_save(config, state, task, ImportTaskStatus::Completed, None) + .await?; + info!( + "[{}/{}] Chunk {} schema {}: done", + idx + 1, + config.tasks.len(), + task.chunk_id, + task.schema + ); + progress.inc(1); + } + Err(task_error) => { + // Persist Failed best-effort, but stop scheduling and remember + // the first error to return after draining in-flight tasks. + persist_failed_best_effort(config, state, task, &task_error).await; + if first_error.is_none() { + first_error = Some(task_error); + } + } + } + } + + match first_error { + Some(err) => Err(err), + None => Ok(()), + } +} + +/// Returns the index of the next task eligible for import, scanning forward from +/// `next_idx` and skipping tasks already marked `Completed` (counting each +/// skipped task once toward progress). Advances `next_idx` past the returned +/// task so each task is scheduled at most once. +fn next_pending_task( + config: &ImportResumeConfig, + state: &ImportState, + next_idx: &mut usize, + progress: &dyn ProgressReporter, +) -> Option { + while *next_idx < config.tasks.len() { + let idx = *next_idx; + *next_idx += 1; + let task = &config.tasks[idx]; + if state.task_status(task.chunk_id, &task.schema) == Some(ImportTaskStatus::Completed) { + info!( + "[{}/{}] Chunk {} schema {}: already completed, skipped", + idx + 1, + config.tasks.len(), + task.chunk_id, + task.schema + ); + progress.inc(1); + continue; + } + return Some(idx); + } + None +} + +async fn persist_failed_best_effort( + config: &ImportResumeConfig, + state: &mut ImportState, + task: &ImportTaskKey, + task_error: &Error, +) { + if let Err(persist_error) = update_status_and_save( + config, + state, + task, + ImportTaskStatus::Failed, + Some(task_error.to_string()), + ) + .await + { + warn!( + "Failed to persist Failed status for chunk {} schema {} after task error ({}); state file may be out of date: {}", + task.chunk_id, task.schema, task_error, persist_error + ); + } +} + async fn update_status_and_save( config: &ImportResumeConfig, state: &mut ImportState, @@ -445,6 +610,14 @@ mod tests { } fn config(path: PathBuf, tasks: Vec) -> ImportResumeConfig { + config_with_parallelism(path, tasks, 1) + } + + fn config_with_parallelism( + path: PathBuf, + tasks: Vec, + task_parallelism: usize, + ) -> ImportResumeConfig { ImportResumeConfig { snapshot_id: "snapshot-1".to_string(), target_addr: "127.0.0.1:4000".to_string(), @@ -452,6 +625,7 @@ mod tests { schemas: vec!["public".to_string(), "analytics".to_string()], state_path: path, tasks, + task_parallelism, } } @@ -807,4 +981,233 @@ mod tests { assert_eq!(progress.total_inc(), 0); assert!(progress.events().contains(&ProgressEvent::FinishPhase)); } + + /// Executor that records the maximum number of concurrently in-flight tasks + /// observed. Each task yields a few times so siblings get scheduled before + /// it completes, making the observed maximum a faithful proxy for the + /// coordinator's in-flight limit. + struct ConcurrencyTrackingExecutor { + imported: Arc>>, + in_flight: Arc, + max_in_flight: Arc, + } + + #[async_trait] + impl ImportTaskExecutor for ConcurrencyTrackingExecutor { + async fn import_task(&self, task: &ImportTaskKey) -> Result<()> { + let current = self.in_flight.fetch_add(1, Ordering::SeqCst) + 1; + self.max_in_flight.fetch_max(current, Ordering::SeqCst); + for _ in 0..8 { + tokio::task::yield_now().await; + } + self.in_flight.fetch_sub(1, Ordering::SeqCst); + self.imported.lock().unwrap().push(task.clone()); + Ok(()) + } + } + + fn ddl_completed_state(tasks: &[ImportTaskKey]) -> ImportState { + let mut state = ImportState::new( + "snapshot-1", + "127.0.0.1:4000", + "greptime", + &["public".to_string(), "analytics".to_string()], + tasks.to_vec(), + ); + state.mark_ddl_completed(); + state + } + + #[tokio::test] + async fn test_import_concurrent_caps_in_flight_tasks() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("import_state.json"); + let tasks: Vec = (0..8).map(|id| ImportTaskKey::new(id, "public")).collect(); + save_import_state(&path, &ddl_completed_state(&tasks)) + .await + .unwrap(); + + let imported = Arc::new(Mutex::new(Vec::new())); + let max_in_flight = Arc::new(AtomicUsize::new(0)); + let executor = ConcurrencyTrackingExecutor { + imported: imported.clone(), + in_flight: Arc::new(AtomicUsize::new(0)), + max_in_flight: max_in_flight.clone(), + }; + + let session = + prepare_import_resume(config_with_parallelism(path.clone(), tasks.clone(), 3)) + .await + .unwrap(); + import_with_resume_session(session, &executor) + .await + .unwrap(); + + // Never exceed the requested parallelism, but actually run concurrently. + let observed = max_in_flight.load(Ordering::SeqCst); + assert!( + observed <= 3, + "observed {observed} in-flight, expected <= 3" + ); + assert!( + observed >= 2, + "observed {observed} in-flight, expected concurrency" + ); + // All tasks imported and the state file is cleaned up on success. + assert_eq!(imported.lock().unwrap().len(), tasks.len()); + assert!(load_import_state(&path).await.unwrap().is_none()); + } + + #[tokio::test] + async fn test_import_concurrent_skips_completed_and_counts_progress_once() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("import_state.json"); + let tasks = vec![ + ImportTaskKey::new(1, "public"), + ImportTaskKey::new(2, "analytics"), + ImportTaskKey::new(3, "public"), + ]; + let mut state = ddl_completed_state(&tasks); + state + .set_task_status(1, "public", ImportTaskStatus::Completed, None) + .unwrap(); + save_import_state(&path, &state).await.unwrap(); + + let imported = Arc::new(Mutex::new(Vec::new())); + let executor = recording_executor(imported.clone()); + let progress = RecordingProgress::default(); + + let session = prepare_import_resume(config_with_parallelism(path.clone(), tasks, 4)) + .await + .unwrap(); + import_with_resume_session_with_progress(session, &executor, &progress) + .await + .unwrap(); + + // The already-completed task is not re-imported. + let mut imported = imported.lock().unwrap().clone(); + imported.sort_by_key(|task| task.chunk_id); + assert_eq!( + imported, + vec![ + ImportTaskKey::new(2, "analytics"), + ImportTaskKey::new(3, "public"), + ] + ); + // One unit for the skipped-completed task plus one per imported task. + assert_eq!(progress.total_inc(), 3); + assert!(load_import_state(&path).await.unwrap().is_none()); + } + + #[tokio::test] + async fn test_import_concurrent_persists_failed_and_returns_first_error() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("import_state.json"); + let failed_task = ImportTaskKey::new(1, "public"); + let tasks = vec![failed_task.clone(), ImportTaskKey::new(2, "analytics")]; + save_import_state(&path, &ddl_completed_state(&tasks)) + .await + .unwrap(); + + let executor = RecordingExecutor { + imported: Arc::new(Mutex::new(Vec::new())), + fail_task: Some(failed_task.clone()), + failure_mode: Some(FailureMode::Fatal), + attempts: Arc::new(AtomicUsize::new(0)), + }; + + let session = prepare_import_resume(config_with_parallelism(path.clone(), tasks, 4)) + .await + .unwrap(); + let error = import_with_resume_session(session, &executor) + .await + .unwrap_err(); + assert!(matches!( + error, + crate::data::import_v2::error::Error::TestTaskFailed { + retryable: false, + .. + } + )); + + // The state file is retained for resume, with the failed task recorded. + let state = load_import_state(&path).await.unwrap().unwrap(); + assert_eq!( + state.task_status(failed_task.chunk_id, &failed_task.schema), + Some(ImportTaskStatus::Failed) + ); + } + + struct FailFastExecutor { + imported: Arc>>, + fail_task: ImportTaskKey, + } + + #[async_trait] + impl ImportTaskExecutor for FailFastExecutor { + async fn import_task(&self, task: &ImportTaskKey) -> Result<()> { + if task == &self.fail_task { + return TestTaskFailedSnafu { + message: "fatal failure".to_string(), + retryable: false, + } + .fail(); + } + + for _ in 0..8 { + tokio::task::yield_now().await; + } + self.imported.lock().unwrap().push(task.clone()); + Ok(()) + } + } + + #[tokio::test] + async fn test_import_concurrent_stops_scheduling_new_tasks_after_failure() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("import_state.json"); + let failed_task = ImportTaskKey::new(1, "public"); + let in_flight_task = ImportTaskKey::new(2, "analytics"); + let unscheduled_task_1 = ImportTaskKey::new(3, "public"); + let unscheduled_task_2 = ImportTaskKey::new(4, "analytics"); + let tasks = vec![ + failed_task.clone(), + in_flight_task.clone(), + unscheduled_task_1.clone(), + unscheduled_task_2.clone(), + ]; + save_import_state(&path, &ddl_completed_state(&tasks)) + .await + .unwrap(); + + let imported = Arc::new(Mutex::new(Vec::new())); + let executor = FailFastExecutor { + imported: imported.clone(), + fail_task: failed_task.clone(), + }; + + let session = prepare_import_resume(config_with_parallelism(path.clone(), tasks, 2)) + .await + .unwrap(); + import_with_resume_session(session, &executor) + .await + .unwrap_err(); + + // The already in-flight sibling is drained, but tasks beyond the + // parallelism window are not scheduled after the first failure. + assert_eq!(imported.lock().unwrap().clone(), vec![in_flight_task]); + let state = load_import_state(&path).await.unwrap().unwrap(); + assert_eq!( + state.task_status(failed_task.chunk_id, &failed_task.schema), + Some(ImportTaskStatus::Failed) + ); + assert_eq!( + state.task_status(unscheduled_task_1.chunk_id, &unscheduled_task_1.schema), + Some(ImportTaskStatus::Pending) + ); + assert_eq!( + state.task_status(unscheduled_task_2.chunk_id, &unscheduled_task_2.schema), + Some(ImportTaskStatus::Pending) + ); + } }