diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 6f896741c9..e1616012fc 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -35,6 +35,7 @@ jobs: build-macos: name: Build macOS binary strategy: + fail-fast: false matrix: # The file format is greptime-- include: @@ -129,6 +130,7 @@ jobs: build-linux: name: Build linux binary strategy: + fail-fast: false matrix: # The file format is greptime-- include: diff --git a/Cargo.lock b/Cargo.lock index 9bf924c252..f31e4451f3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4098,7 +4098,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "greptime-proto" version = "0.1.0" -source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=5d5eb65bb985ff47b3a417fb2505e315e2f5c319#5d5eb65bb985ff47b3a417fb2505e315e2f5c319" +source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=7aeaeaba1e0ca6a5c736b6ab2eb63144ae3d284b#7aeaeaba1e0ca6a5c736b6ab2eb63144ae3d284b" dependencies = [ "prost", "serde", @@ -8538,6 +8538,8 @@ dependencies = [ "common-test-util", "common-time", "datafusion", + "datafusion-common", + "datafusion-expr", "datatypes", "derive_builder 0.12.0", "digest", @@ -8974,6 +8976,7 @@ dependencies = [ "bitflags 1.3.2", "byteorder", "bytes", + "chrono", "crc", "crossbeam-queue", "digest", @@ -9554,6 +9557,7 @@ dependencies = [ "axum", "axum-test-helper", "catalog", + "chrono", "client", "common-base", "common-catalog", diff --git a/Cargo.toml b/Cargo.toml index 6ffe5f6151..0b09ee555e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -72,7 +72,7 @@ datafusion-sql = { git = "https://github.com/waynexia/arrow-datafusion.git", rev datafusion-substrait = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "63e52dde9e44cac4b1f6c6e6b6bf6368ba3bd323" } futures = "0.3" futures-util = "0.3" -greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "5d5eb65bb985ff47b3a417fb2505e315e2f5c319" } +greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "7aeaeaba1e0ca6a5c736b6ab2eb63144ae3d284b" } itertools = "0.10" parquet = "40.0" paste = "1.0" diff --git a/Makefile b/Makefile index 1e3dd4eb0a..c35715ffa7 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ docker-image: ## Build docker image. ##@ Test test: nextest ## Run unit and integration tests. - cargo nextest run + cargo nextest run --retries 3 .PHONY: nextest ## Install nextest tools. nextest: diff --git a/src/catalog/src/remote/manager.rs b/src/catalog/src/remote/manager.rs index 6cc2c78799..dc8b521793 100644 --- a/src/catalog/src/remote/manager.rs +++ b/src/catalog/src/remote/manager.rs @@ -85,6 +85,7 @@ impl RemoteCatalogManager { catalog_name: catalog_name.to_string(), backend: self.backend.clone(), engine_manager: self.engine_manager.clone(), + region_alive_keepers: self.region_alive_keepers.clone(), }) as _ } @@ -132,10 +133,17 @@ impl RemoteCatalogManager { increment_gauge!(crate::metrics::METRIC_CATALOG_MANAGER_CATALOG_COUNT, 1.0); + let region_alive_keepers = self.region_alive_keepers.clone(); joins.push(common_runtime::spawn_bg(async move { - let max_table_id = - initiate_schemas(node_id, backend, engine_manager, &catalog_name, catalog) - .await?; + let max_table_id = initiate_schemas( + node_id, + backend, + engine_manager, + &catalog_name, + catalog, + region_alive_keepers, + ) + .await?; info!( "Catalog name: {}, max table id allocated: {}", &catalog_name, max_table_id @@ -164,6 +172,7 @@ impl RemoteCatalogManager { self.engine_manager.clone(), catalog_name, schema_name, + self.region_alive_keepers.clone(), ); let catalog_provider = self.new_catalog_provider(catalog_name); @@ -209,6 +218,7 @@ fn new_schema_provider( engine_manager: TableEngineManagerRef, catalog_name: &str, schema_name: &str, + region_alive_keepers: Arc, ) -> SchemaProviderRef { Arc::new(RemoteSchemaProvider { catalog_name: catalog_name.to_string(), @@ -216,6 +226,7 @@ fn new_schema_provider( node_id, backend, engine_manager, + region_alive_keepers, }) as _ } @@ -249,6 +260,7 @@ async fn initiate_schemas( engine_manager: TableEngineManagerRef, catalog_name: &str, catalog: CatalogProviderRef, + region_alive_keepers: Arc, ) -> Result { let mut schemas = iter_remote_schemas(&backend, catalog_name).await; let mut joins = Vec::new(); @@ -268,6 +280,7 @@ async fn initiate_schemas( engine_manager.clone(), &catalog_name, &schema_name, + region_alive_keepers.clone(), ); catalog .register_schema(schema_name.clone(), schema.clone()) @@ -611,18 +624,7 @@ impl CatalogManager for RemoteCatalogManager { &[crate::metrics::db_label(catalog, schema)], ); schema_provider - .register_table(table_name.to_string(), request.table.clone()) - .await?; - - let table_ident = TableIdent { - catalog: request.catalog, - schema: request.schema, - table: request.table_name, - table_id: request.table_id, - engine: request.table.table_info().meta.engine.clone(), - }; - self.region_alive_keepers - .register_table(table_ident, request.table) + .register_table(table_name.to_string(), request.table) .await?; Ok(true) @@ -678,6 +680,7 @@ impl CatalogManager for RemoteCatalogManager { self.engine_manager.clone(), &catalog_name, &schema_name, + self.region_alive_keepers.clone(), ); catalog_provider .register_schema(schema_name, schema_provider) @@ -813,6 +816,7 @@ pub struct RemoteCatalogProvider { catalog_name: String, backend: KvBackendRef, engine_manager: TableEngineManagerRef, + region_alive_keepers: Arc, } impl RemoteCatalogProvider { @@ -821,12 +825,14 @@ impl RemoteCatalogProvider { backend: KvBackendRef, engine_manager: TableEngineManagerRef, node_id: u64, + region_alive_keepers: Arc, ) -> Self { Self { node_id, catalog_name, backend, engine_manager, + region_alive_keepers, } } @@ -844,6 +850,7 @@ impl RemoteCatalogProvider { node_id: self.node_id, backend: self.backend.clone(), engine_manager: self.engine_manager.clone(), + region_alive_keepers: self.region_alive_keepers.clone(), }; Arc::new(provider) as Arc<_> } @@ -906,6 +913,7 @@ pub struct RemoteSchemaProvider { node_id: u64, backend: KvBackendRef, engine_manager: TableEngineManagerRef, + region_alive_keepers: Arc, } impl RemoteSchemaProvider { @@ -915,6 +923,7 @@ impl RemoteSchemaProvider { node_id: u64, engine_manager: TableEngineManagerRef, backend: KvBackendRef, + region_alive_keepers: Arc, ) -> Self { Self { catalog_name, @@ -922,6 +931,7 @@ impl RemoteSchemaProvider { node_id, backend, engine_manager, + region_alive_keepers, } } @@ -1004,6 +1014,18 @@ impl SchemaProvider for RemoteSchemaProvider { &table_value.as_bytes().context(InvalidCatalogValueSnafu)?, ) .await?; + + let table_ident = TableIdent { + catalog: table_info.catalog_name.clone(), + schema: table_info.schema_name.clone(), + table: table_info.name.clone(), + table_id: table_info.ident.table_id, + engine: table_info.meta.engine.clone(), + }; + self.region_alive_keepers + .register_table(table_ident, table) + .await?; + debug!( "Successfully set catalog table entry, key: {}, table value: {:?}", table_key, table_value diff --git a/src/catalog/src/remote/region_alive_keeper.rs b/src/catalog/src/remote/region_alive_keeper.rs index 327e846b3b..61daee4cf1 100644 --- a/src/catalog/src/remote/region_alive_keeper.rs +++ b/src/catalog/src/remote/region_alive_keeper.rs @@ -14,6 +14,7 @@ use std::collections::HashMap; use std::future::Future; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use async_trait::async_trait; @@ -30,7 +31,7 @@ use table::engine::manager::TableEngineManagerRef; use table::engine::{CloseTableResult, EngineContext, TableEngineRef}; use table::requests::CloseTableRequest; use table::TableRef; -use tokio::sync::{mpsc, Mutex}; +use tokio::sync::{mpsc, oneshot, Mutex}; use tokio::task::JoinHandle; use tokio::time::{Duration, Instant}; @@ -40,6 +41,8 @@ use crate::error::{Result, TableEngineNotFoundSnafu}; pub struct RegionAliveKeepers { table_engine_manager: TableEngineManagerRef, keepers: Arc>>>, + heartbeat_interval_millis: u64, + started: AtomicBool, /// The epoch when [RegionAliveKeepers] is created. It's used to get a monotonically non-decreasing /// elapsed time when submitting heartbeats to Metasrv (because [Instant] is monotonically @@ -49,23 +52,24 @@ pub struct RegionAliveKeepers { } impl RegionAliveKeepers { - pub fn new(table_engine_manager: TableEngineManagerRef) -> Self { + pub fn new( + table_engine_manager: TableEngineManagerRef, + heartbeat_interval_millis: u64, + ) -> Self { Self { table_engine_manager, keepers: Arc::new(Mutex::new(HashMap::new())), + heartbeat_interval_millis, + started: AtomicBool::new(false), epoch: Instant::now(), } } - async fn find_keeper(&self, table_ident: &TableIdent) -> Option> { + pub async fn find_keeper(&self, table_ident: &TableIdent) -> Option> { self.keepers.lock().await.get(table_ident).cloned() } - pub(crate) async fn register_table( - &self, - table_ident: TableIdent, - table: TableRef, - ) -> Result<()> { + pub async fn register_table(&self, table_ident: TableIdent, table: TableRef) -> Result<()> { let keeper = self.find_keeper(&table_ident).await; if keeper.is_some() { return Ok(()); @@ -78,17 +82,29 @@ impl RegionAliveKeepers { engine_name: &table_ident.engine, })?; - let keeper = Arc::new(RegionAliveKeeper::new(table_engine, table_ident.clone())); + let keeper = Arc::new(RegionAliveKeeper::new( + table_engine, + table_ident.clone(), + self.heartbeat_interval_millis, + )); for r in table.table_info().meta.region_numbers.iter() { keeper.register_region(*r).await; } - info!("Register RegionAliveKeeper for table {table_ident}"); - self.keepers.lock().await.insert(table_ident, keeper); + let mut keepers = self.keepers.lock().await; + keepers.insert(table_ident.clone(), keeper.clone()); + + if self.started.load(Ordering::Relaxed) { + keeper.start().await; + + info!("RegionAliveKeeper for table {table_ident} is started!"); + } else { + info!("RegionAliveKeeper for table {table_ident} is registered but not started yet!"); + } Ok(()) } - pub(crate) async fn deregister_table(&self, table_ident: &TableIdent) { + pub async fn deregister_table(&self, table_ident: &TableIdent) { if self.keepers.lock().await.remove(table_ident).is_some() { info!("Deregister RegionAliveKeeper for table {table_ident}"); } @@ -114,10 +130,17 @@ impl RegionAliveKeepers { keeper.deregister_region(region_ident.region_number).await } - pub async fn start(&self, heartbeat_interval_millis: u64) { - for keeper in self.keepers.lock().await.values() { - keeper.start(heartbeat_interval_millis).await; + pub async fn start(&self) { + let keepers = self.keepers.lock().await; + for keeper in keepers.values() { + keeper.start().await; } + self.started.store(true, Ordering::Relaxed); + + info!( + "RegionAliveKeepers for tables {:?} are started!", + keepers.keys().map(|x| x.to_string()).collect::>(), + ); } pub fn epoch(&self) -> Instant { @@ -171,18 +194,26 @@ impl HeartbeatResponseHandler for RegionAliveKeepers { /// opened regions to Metasrv, in heartbeats. If Metasrv decides some region could be resided in this /// Datanode, it will "extend" the region's "lease", with a deadline for [RegionAliveKeeper] to /// countdown. -struct RegionAliveKeeper { +pub struct RegionAliveKeeper { table_engine: TableEngineRef, table_ident: TableIdent, countdown_task_handles: Arc>>>, + heartbeat_interval_millis: u64, + started: AtomicBool, } impl RegionAliveKeeper { - fn new(table_engine: TableEngineRef, table_ident: TableIdent) -> Self { + fn new( + table_engine: TableEngineRef, + table_ident: TableIdent, + heartbeat_interval_millis: u64, + ) -> Self { Self { table_engine, table_ident, countdown_task_handles: Arc::new(Mutex::new(HashMap::new())), + heartbeat_interval_millis, + started: AtomicBool::new(false), } } @@ -210,14 +241,22 @@ impl RegionAliveKeeper { || on_task_finished, )); - self.countdown_task_handles - .lock() - .await - .insert(region, handle); - info!( - "Register alive countdown for new region {region} in table {}", - self.table_ident - ) + let mut handles = self.countdown_task_handles.lock().await; + handles.insert(region, handle.clone()); + + if self.started.load(Ordering::Relaxed) { + handle.start(self.heartbeat_interval_millis).await; + + info!( + "Region alive countdown for region {region} in table {} is started!", + self.table_ident + ); + } else { + info!( + "Region alive countdown for region {region} in table {} is registered but not started yet!", + self.table_ident + ); + } } async fn deregister_region(&self, region: RegionNumber) { @@ -235,14 +274,18 @@ impl RegionAliveKeeper { } } - async fn start(&self, heartbeat_interval_millis: u64) { - for handle in self.countdown_task_handles.lock().await.values() { - handle.start(heartbeat_interval_millis).await; + async fn start(&self) { + let handles = self.countdown_task_handles.lock().await; + for handle in handles.values() { + handle.start(self.heartbeat_interval_millis).await; } + + self.started.store(true, Ordering::Relaxed); info!( - "RegionAliveKeeper for table {} is started!", + "Region alive countdowns for regions {:?} in table {} are started!", + handles.keys().copied().collect::>(), self.table_ident - ) + ); } async fn keep_lived(&self, designated_regions: Vec, deadline: Instant) { @@ -253,15 +296,24 @@ impl RegionAliveKeeper { // Else the region alive keeper might be triggered by lagging messages, we can safely ignore it. } } + + pub async fn deadline(&self, region: RegionNumber) -> Option { + let mut deadline = None; + if let Some(handle) = self.find_handle(®ion).await { + let (s, r) = oneshot::channel(); + if handle.tx.send(CountdownCommand::Deadline(s)).await.is_ok() { + deadline = r.await.ok() + } + } + deadline + } } #[derive(Debug)] enum CountdownCommand { Start(u64), Reset(Instant), - - #[cfg(test)] - Deadline(tokio::sync::oneshot::Sender), + Deadline(oneshot::Sender), } struct CountdownTaskHandle { @@ -362,7 +414,10 @@ impl CountdownTask { }, Some(CountdownCommand::Reset(deadline)) => { if countdown.deadline() < deadline { - debug!("Reset deadline to region {region} of table {table_ident} to {deadline:?}"); + debug!( + "Reset deadline of region {region} of table {table_ident} to approximately {} seconds later", + (deadline - Instant::now()).as_secs_f32(), + ); countdown.set(tokio::time::sleep_until(deadline)); } // Else the countdown could be either: @@ -378,10 +433,8 @@ impl CountdownTask { ); break; }, - - #[cfg(test)] Some(CountdownCommand::Deadline(tx)) => { - tx.send(countdown.deadline()).unwrap() + let _ = tx.send(countdown.deadline()); } } } @@ -433,7 +486,6 @@ mod test { use table::engine::{TableEngine, TableReference}; use table::requests::{CreateTableRequest, TableOptions}; use table::test_util::EmptyTable; - use tokio::sync::oneshot; use super::*; use crate::remote::mock::MockTableEngine; @@ -441,7 +493,7 @@ mod test { async fn prepare_keepers() -> (TableIdent, RegionAliveKeepers) { let table_engine = Arc::new(MockTableEngine::default()); let table_engine_manager = Arc::new(MemoryTableEngineManager::new(table_engine)); - let keepers = RegionAliveKeepers::new(table_engine_manager); + let keepers = RegionAliveKeepers::new(table_engine_manager, 5000); let catalog = "my_catalog"; let schema = "my_schema"; @@ -483,7 +535,7 @@ mod test { async fn test_handle_heartbeat_response() { let (table_ident, keepers) = prepare_keepers().await; - keepers.start(5000).await; + keepers.start().await; let startup_protection_until = Instant::now() + Duration::from_secs(21); let duration_since_epoch = (Instant::now() - keepers.epoch).as_millis() as _; @@ -517,8 +569,7 @@ mod test { keep_alive_until: Instant, is_kept_live: bool, ) { - let handles = keeper.countdown_task_handles.lock().await; - let deadline = deadline(&handles.get(®ion_number).unwrap().tx).await; + let deadline = keeper.deadline(region_number).await.unwrap(); if is_kept_live { assert!(deadline > startup_protection_until && deadline == keep_alive_until); } else { @@ -555,11 +606,16 @@ mod test { }) .await; - keepers.start(5000).await; + keepers.start().await; for keeper in keepers.keepers.lock().await.values() { - for handle in keeper.countdown_task_handles.lock().await.values() { + let regions = { + let handles = keeper.countdown_task_handles.lock().await; + handles.keys().copied().collect::>() + }; + for region in regions { // assert countdown tasks are started - assert!(deadline(&handle.tx).await <= Instant::now() + Duration::from_secs(20)); + let deadline = keeper.deadline(region).await.unwrap(); + assert!(deadline <= Instant::now() + Duration::from_secs(20)); } } @@ -598,22 +654,13 @@ mod test { table_id: 1024, engine: "mito".to_string(), }; - let keeper = RegionAliveKeeper::new(table_engine, table_ident); + let keeper = RegionAliveKeeper::new(table_engine, table_ident, 1000); let region = 1; assert!(keeper.find_handle(®ion).await.is_none()); keeper.register_region(region).await; assert!(keeper.find_handle(®ion).await.is_some()); - let sender = &keeper - .countdown_task_handles - .lock() - .await - .get(®ion) - .unwrap() - .tx - .clone(); - let ten_seconds_later = || Instant::now() + Duration::from_secs(10); keeper.keep_lived(vec![1, 2, 3], ten_seconds_later()).await; @@ -622,12 +669,12 @@ mod test { let far_future = Instant::now() + Duration::from_secs(86400 * 365 * 29); // assert if keeper is not started, keep_lived is of no use - assert!(deadline(sender).await > far_future); + assert!(keeper.deadline(region).await.unwrap() > far_future); - keeper.start(1000).await; + keeper.start().await; keeper.keep_lived(vec![1, 2, 3], ten_seconds_later()).await; // assert keep_lived works if keeper is started - assert!(deadline(sender).await <= ten_seconds_later()); + assert!(keeper.deadline(region).await.unwrap() <= ten_seconds_later()); keeper.deregister_region(region).await; assert!(keeper.find_handle(®ion).await.is_none()); @@ -726,6 +773,12 @@ mod test { task.run().await; }); + async fn deadline(tx: &mpsc::Sender) -> Instant { + let (s, r) = oneshot::channel(); + tx.send(CountdownCommand::Deadline(s)).await.unwrap(); + r.await.unwrap() + } + // if countdown task is not started, its deadline is set to far future assert!(deadline(&tx).await > Instant::now() + Duration::from_secs(86400 * 365 * 29)); @@ -747,10 +800,4 @@ mod test { tokio::time::sleep(Duration::from_millis(2000)).await; assert!(!table_engine.table_exists(ctx, &table_ref)); } - - async fn deadline(tx: &mpsc::Sender) -> Instant { - let (s, r) = oneshot::channel(); - tx.send(CountdownCommand::Deadline(s)).await.unwrap(); - r.await.unwrap() - } } diff --git a/src/catalog/tests/remote_catalog_tests.rs b/src/catalog/tests/remote_catalog_tests.rs index 776c6be6c9..9d3f539f83 100644 --- a/src/catalog/tests/remote_catalog_tests.rs +++ b/src/catalog/tests/remote_catalog_tests.rs @@ -19,6 +19,7 @@ mod tests { use std::assert_matches::assert_matches; use std::collections::HashSet; use std::sync::Arc; + use std::time::Duration; use catalog::helper::{CatalogKey, CatalogValue, SchemaKey, SchemaValue}; use catalog::remote::mock::{MockKvBackend, MockTableEngine}; @@ -29,11 +30,27 @@ mod tests { }; use catalog::{CatalogManager, RegisterTableRequest}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, MITO_ENGINE}; + use common_meta::ident::TableIdent; use datatypes::schema::RawSchema; use futures_util::StreamExt; use table::engine::manager::{MemoryTableEngineManager, TableEngineManagerRef}; use table::engine::{EngineContext, TableEngineRef}; use table::requests::CreateTableRequest; + use table::test_util::EmptyTable; + use tokio::time::Instant; + + struct TestingComponents { + kv_backend: KvBackendRef, + catalog_manager: Arc, + table_engine_manager: TableEngineManagerRef, + region_alive_keepers: Arc, + } + + impl TestingComponents { + fn table_engine(&self) -> TableEngineRef { + self.table_engine_manager.engine(MITO_ENGINE).unwrap() + } + } #[tokio::test] async fn test_backend() { @@ -121,14 +138,7 @@ mod tests { assert!(ret.is_none()); } - async fn prepare_components( - node_id: u64, - ) -> ( - KvBackendRef, - TableEngineRef, - Arc, - TableEngineManagerRef, - ) { + async fn prepare_components(node_id: u64) -> TestingComponents { let cached_backend = Arc::new(CachedMetaKvBackend::wrap( Arc::new(MockKvBackend::default()), )); @@ -136,30 +146,34 @@ mod tests { let table_engine = Arc::new(MockTableEngine::default()); let engine_manager = Arc::new(MemoryTableEngineManager::alias( MITO_ENGINE.to_string(), - table_engine.clone(), + table_engine, )); + let region_alive_keepers = Arc::new(RegionAliveKeepers::new(engine_manager.clone(), 5000)); + let catalog_manager = RemoteCatalogManager::new( engine_manager.clone(), node_id, cached_backend.clone(), - Arc::new(RegionAliveKeepers::new(engine_manager.clone())), + region_alive_keepers.clone(), ); catalog_manager.start().await.unwrap(); - ( - cached_backend, - table_engine, - Arc::new(catalog_manager), - engine_manager as Arc<_>, - ) + TestingComponents { + kv_backend: cached_backend, + catalog_manager: Arc::new(catalog_manager), + table_engine_manager: engine_manager, + region_alive_keepers, + } } #[tokio::test] async fn test_remote_catalog_default() { common_telemetry::init_default_ut_logging(); let node_id = 42; - let (_, _, catalog_manager, _) = prepare_components(node_id).await; + let TestingComponents { + catalog_manager, .. + } = prepare_components(node_id).await; assert_eq!( vec![DEFAULT_CATALOG_NAME.to_string()], catalog_manager.catalog_names().await.unwrap() @@ -180,14 +194,16 @@ mod tests { async fn test_remote_catalog_register_nonexistent() { common_telemetry::init_default_ut_logging(); let node_id = 42; - let (_, table_engine, catalog_manager, _) = prepare_components(node_id).await; + let components = prepare_components(node_id).await; + // register a new table with an nonexistent catalog let catalog_name = "nonexistent_catalog".to_string(); let schema_name = "nonexistent_schema".to_string(); let table_name = "fail_table".to_string(); // this schema has no effect let table_schema = RawSchema::new(vec![]); - let table = table_engine + let table = components + .table_engine() .create_table( &EngineContext {}, CreateTableRequest { @@ -213,7 +229,7 @@ mod tests { table_id: 1, table, }; - let res = catalog_manager.register_table(reg_req).await; + let res = components.catalog_manager.register_table(reg_req).await; // because nonexistent_catalog does not exist yet. assert_matches!( @@ -225,7 +241,8 @@ mod tests { #[tokio::test] async fn test_register_table() { let node_id = 42; - let (_, table_engine, catalog_manager, _) = prepare_components(node_id).await; + let components = prepare_components(node_id).await; + let catalog_manager = &components.catalog_manager; let default_catalog = catalog_manager .catalog(DEFAULT_CATALOG_NAME) .await @@ -249,7 +266,8 @@ mod tests { let table_id = 1; // this schema has no effect let table_schema = RawSchema::new(vec![]); - let table = table_engine + let table = components + .table_engine() .create_table( &EngineContext {}, CreateTableRequest { @@ -285,8 +303,10 @@ mod tests { #[tokio::test] async fn test_register_catalog_schema_table() { let node_id = 42; - let (backend, table_engine, catalog_manager, engine_manager) = - prepare_components(node_id).await; + let components = prepare_components(node_id).await; + let backend = &components.kv_backend; + let catalog_manager = components.catalog_manager.clone(); + let engine_manager = components.table_engine_manager.clone(); let catalog_name = "test_catalog".to_string(); let schema_name = "nonexistent_schema".to_string(); @@ -295,6 +315,7 @@ mod tests { backend.clone(), engine_manager.clone(), node_id, + components.region_alive_keepers.clone(), )); // register catalog to catalog manager @@ -308,7 +329,8 @@ mod tests { HashSet::from_iter(catalog_manager.catalog_names().await.unwrap().into_iter()) ); - let table_to_register = table_engine + let table_to_register = components + .table_engine() .create_table( &EngineContext {}, CreateTableRequest { @@ -355,6 +377,7 @@ mod tests { node_id, engine_manager, backend.clone(), + components.region_alive_keepers.clone(), )); let prev = new_catalog @@ -374,4 +397,94 @@ mod tests { .collect() ) } + + #[tokio::test] + async fn test_register_table_before_and_after_region_alive_keeper_started() { + let components = prepare_components(42).await; + let catalog_manager = &components.catalog_manager; + let region_alive_keepers = &components.region_alive_keepers; + + let table_before = TableIdent { + catalog: DEFAULT_CATALOG_NAME.to_string(), + schema: DEFAULT_SCHEMA_NAME.to_string(), + table: "table_before".to_string(), + table_id: 1, + engine: MITO_ENGINE.to_string(), + }; + let request = RegisterTableRequest { + catalog: table_before.catalog.clone(), + schema: table_before.schema.clone(), + table_name: table_before.table.clone(), + table_id: table_before.table_id, + table: Arc::new(EmptyTable::new(CreateTableRequest { + id: table_before.table_id, + catalog_name: table_before.catalog.clone(), + schema_name: table_before.schema.clone(), + table_name: table_before.table.clone(), + desc: None, + schema: RawSchema::new(vec![]), + region_numbers: vec![0], + primary_key_indices: vec![], + create_if_not_exists: false, + table_options: Default::default(), + engine: MITO_ENGINE.to_string(), + })), + }; + assert!(catalog_manager.register_table(request).await.unwrap()); + + let keeper = region_alive_keepers + .find_keeper(&table_before) + .await + .unwrap(); + let deadline = keeper.deadline(0).await.unwrap(); + let far_future = Instant::now() + Duration::from_secs(86400 * 365 * 29); + // assert region alive countdown is not started + assert!(deadline > far_future); + + region_alive_keepers.start().await; + + let table_after = TableIdent { + catalog: DEFAULT_CATALOG_NAME.to_string(), + schema: DEFAULT_SCHEMA_NAME.to_string(), + table: "table_after".to_string(), + table_id: 2, + engine: MITO_ENGINE.to_string(), + }; + let request = RegisterTableRequest { + catalog: table_after.catalog.clone(), + schema: table_after.schema.clone(), + table_name: table_after.table.clone(), + table_id: table_after.table_id, + table: Arc::new(EmptyTable::new(CreateTableRequest { + id: table_after.table_id, + catalog_name: table_after.catalog.clone(), + schema_name: table_after.schema.clone(), + table_name: table_after.table.clone(), + desc: None, + schema: RawSchema::new(vec![]), + region_numbers: vec![0], + primary_key_indices: vec![], + create_if_not_exists: false, + table_options: Default::default(), + engine: MITO_ENGINE.to_string(), + })), + }; + assert!(catalog_manager.register_table(request).await.unwrap()); + + let keeper = region_alive_keepers + .find_keeper(&table_after) + .await + .unwrap(); + let deadline = keeper.deadline(0).await.unwrap(); + // assert countdown is started for the table registered after [RegionAliveKeepers] started + assert!(deadline <= Instant::now() + Duration::from_secs(20)); + + let keeper = region_alive_keepers + .find_keeper(&table_before) + .await + .unwrap(); + let deadline = keeper.deadline(0).await.unwrap(); + // assert countdown is started for the table registered before [RegionAliveKeepers] started, too + assert!(deadline <= Instant::now() + Duration::from_secs(20)); + } } diff --git a/src/common/time/src/date.rs b/src/common/time/src/date.rs index fff9f412db..4540490111 100644 --- a/src/common/time/src/date.rs +++ b/src/common/time/src/date.rs @@ -52,6 +52,12 @@ impl From for Date { } } +impl From for Date { + fn from(date: NaiveDate) -> Self { + Self(date.num_days_from_ce() - UNIX_EPOCH_FROM_CE) + } +} + impl Display for Date { /// [Date] is formatted according to ISO-8601 standard. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { diff --git a/src/datanode/src/heartbeat.rs b/src/datanode/src/heartbeat.rs index 87275b6613..6c9e3e0365 100644 --- a/src/datanode/src/heartbeat.rs +++ b/src/datanode/src/heartbeat.rs @@ -30,12 +30,14 @@ use snafu::ResultExt; use tokio::sync::mpsc; use tokio::time::Instant; +use crate::datanode::DatanodeOptions; use crate::error::{self, MetaClientInitSnafu, Result}; pub(crate) mod handler; pub struct HeartbeatTask { node_id: u64, + node_epoch: u64, server_addr: String, server_hostname: Option, running: Arc, @@ -56,21 +58,23 @@ impl HeartbeatTask { /// Create a new heartbeat task instance. pub fn new( node_id: u64, - server_addr: String, - server_hostname: Option, + opts: &DatanodeOptions, meta_client: Arc, catalog_manager: CatalogManagerRef, resp_handler_executor: HeartbeatResponseHandlerExecutorRef, + heartbeat_interval_millis: u64, region_alive_keepers: Arc, ) -> Self { Self { node_id, - server_addr, - server_hostname, + // We use datanode's start time millis as the node's epoch. + node_epoch: common_time::util::current_time_millis() as u64, + server_addr: opts.rpc_addr.clone(), + server_hostname: opts.rpc_hostname.clone(), running: Arc::new(AtomicBool::new(false)), meta_client, catalog_manager, - interval: 5_000, // default interval is set to 5 secs + interval: heartbeat_interval_millis, resp_handler_executor, region_alive_keepers, } @@ -133,10 +137,11 @@ impl HeartbeatTask { } let interval = self.interval; let node_id = self.node_id; + let node_epoch = self.node_epoch; let addr = resolve_addr(&self.server_addr, &self.server_hostname); info!("Starting heartbeat to Metasrv with interval {interval}. My node id is {node_id}, address is {addr}."); - self.region_alive_keepers.start(interval).await; + self.region_alive_keepers.start().await; let meta_client = self.meta_client.clone(); let catalog_manager_clone = self.catalog_manager.clone(); @@ -201,6 +206,7 @@ impl HeartbeatTask { }), region_stats, duration_since_epoch: (Instant::now() - epoch).as_millis() as u64, + node_epoch, ..Default::default() }; sleep.as_mut().reset(Instant::now() + Duration::from_millis(interval)); diff --git a/src/datanode/src/heartbeat/handler/close_region.rs b/src/datanode/src/heartbeat/handler/close_region.rs index 1dc0157fe7..abc492d40f 100644 --- a/src/datanode/src/heartbeat/handler/close_region.rs +++ b/src/datanode/src/heartbeat/handler/close_region.rs @@ -23,6 +23,7 @@ use common_meta::heartbeat::handler::{ HandleControl, HeartbeatResponseHandler, HeartbeatResponseHandlerContext, }; use common_meta::instruction::{Instruction, InstructionReply, SimpleReply}; +use common_meta::RegionIdent; use common_telemetry::{error, info, warn}; use snafu::ResultExt; use store_api::storage::RegionNumber; @@ -55,25 +56,8 @@ impl HeartbeatResponseHandler for CloseRegionHandler { let mailbox = ctx.mailbox.clone(); let self_ref = Arc::new(self.clone()); - let region_alive_keepers = self.region_alive_keepers.clone(); common_runtime::spawn_bg(async move { - let table_ident = ®ion_ident.table_ident; - let table_ref = TableReference::full( - &table_ident.catalog, - &table_ident.schema, - &table_ident.table, - ); - let result = self_ref - .close_region_inner( - table_ident.engine.clone(), - &table_ref, - vec![region_ident.region_number], - ) - .await; - - if matches!(result, Ok(true)) { - region_alive_keepers.deregister_region(®ion_ident).await; - } + let result = self_ref.close_region_inner(region_ident).await; if let Err(e) = mailbox .send((meta, CloseRegionHandler::map_result(result))) @@ -152,20 +136,21 @@ impl CloseRegionHandler { Ok(true) } - async fn close_region_inner( - &self, - engine: String, - table_ref: &TableReference<'_>, - region_numbers: Vec, - ) -> Result { - let engine = - self.table_engine_manager - .engine(&engine) - .context(error::TableEngineNotFoundSnafu { - engine_name: &engine, - })?; + async fn close_region_inner(&self, region_ident: RegionIdent) -> Result { + let table_ident = ®ion_ident.table_ident; + let engine_name = &table_ident.engine; + let engine = self + .table_engine_manager + .engine(engine_name) + .context(error::TableEngineNotFoundSnafu { engine_name })?; let ctx = EngineContext::default(); + let table_ref = &TableReference::full( + &table_ident.catalog, + &table_ident.schema, + &table_ident.table, + ); + let region_numbers = vec![region_ident.region_number]; if self .regions_closed( table_ref.catalog, @@ -203,7 +188,15 @@ impl CloseRegionHandler { })? { CloseTableResult::NotFound | CloseTableResult::Released(_) => { // Deregister table if The table released. - self.deregister_table(table_ref).await + let deregistered = self.deregister_table(table_ref).await?; + + if deregistered { + self.region_alive_keepers + .deregister_table(table_ident) + .await; + } + + Ok(deregistered) } CloseTableResult::PartialClosed(regions) => { // Requires caller to update the region_numbers @@ -211,6 +204,11 @@ impl CloseRegionHandler { "Close partial regions: {:?} in table: {}", regions, table_ref ); + + self.region_alive_keepers + .deregister_region(®ion_ident) + .await; + Ok(true) } }; diff --git a/src/datanode/src/instance.rs b/src/datanode/src/instance.rs index eb4f68ee5b..b80a22c0e2 100644 --- a/src/datanode/src/instance.rs +++ b/src/datanode/src/instance.rs @@ -197,8 +197,12 @@ impl Instance { let kv_backend = Arc::new(CachedMetaKvBackend::new(meta_client.clone())); - let region_alive_keepers = - Arc::new(RegionAliveKeepers::new(engine_manager.clone())); + let heartbeat_interval_millis = 5000; + + let region_alive_keepers = Arc::new(RegionAliveKeepers::new( + engine_manager.clone(), + heartbeat_interval_millis, + )); let catalog_manager = Arc::new(RemoteCatalogManager::new( engine_manager.clone(), @@ -224,11 +228,11 @@ impl Instance { let heartbeat_task = Some(HeartbeatTask::new( opts.node_id.context(MissingNodeIdSnafu)?, - opts.rpc_addr.clone(), - opts.rpc_hostname.clone(), + opts, meta_client, catalog_manager.clone(), Arc::new(handlers_executor), + heartbeat_interval_millis, region_alive_keepers, )); diff --git a/src/datanode/src/tests.rs b/src/datanode/src/tests.rs index 1796a6c875..5b4ba4de3d 100644 --- a/src/datanode/src/tests.rs +++ b/src/datanode/src/tests.rs @@ -14,6 +14,7 @@ use std::assert_matches::assert_matches; use std::sync::Arc; +use std::time::Duration; use api::v1::greptime_request::Request as GrpcRequest; use api::v1::meta::HeartbeatResponse; @@ -32,8 +33,10 @@ use datatypes::prelude::ConcreteDataType; use servers::query_handler::grpc::GrpcQueryHandler; use session::context::QueryContext; use table::engine::manager::TableEngineManagerRef; +use table::TableRef; use test_util::MockInstance; use tokio::sync::mpsc::{self, Receiver}; +use tokio::time::Instant; use crate::heartbeat::handler::close_region::CloseRegionHandler; use crate::heartbeat::handler::open_region::OpenRegionHandler; @@ -64,7 +67,7 @@ async fn test_close_region_handler() { CloseRegionHandler::new( catalog_manager_ref.clone(), engine_manager_ref.clone(), - Arc::new(RegionAliveKeepers::new(engine_manager_ref.clone())), + Arc::new(RegionAliveKeepers::new(engine_manager_ref.clone(), 5000)), ), )])); @@ -134,43 +137,57 @@ async fn test_open_region_handler() { .. } = prepare_handler_test("test_open_region_handler").await; - let region_alive_keeper = Arc::new(RegionAliveKeepers::new(engine_manager_ref.clone())); + let region_alive_keepers = Arc::new(RegionAliveKeepers::new(engine_manager_ref.clone(), 5000)); + region_alive_keepers.start().await; let executor = Arc::new(HandlerGroupExecutor::new(vec![ Arc::new(OpenRegionHandler::new( catalog_manager_ref.clone(), engine_manager_ref.clone(), - region_alive_keeper.clone(), + region_alive_keepers.clone(), )), Arc::new(CloseRegionHandler::new( catalog_manager_ref.clone(), engine_manager_ref.clone(), - region_alive_keeper, + region_alive_keepers.clone(), )), ])); - prepare_table(instance.inner()).await; + let instruction = open_region_instruction(); + let Instruction::OpenRegion(region_ident) = instruction.clone() else { unreachable!() }; + let table_ident = ®ion_ident.table_ident; + + let table = prepare_table(instance.inner()).await; + region_alive_keepers + .register_table(table_ident.clone(), table) + .await + .unwrap(); // Opens a opened table - handle_instruction(executor.clone(), mailbox.clone(), open_region_instruction()).await; + handle_instruction(executor.clone(), mailbox.clone(), instruction.clone()).await; let (_, reply) = rx.recv().await.unwrap(); assert_matches!( reply, InstructionReply::OpenRegion(SimpleReply { result: true, .. }) ); + let keeper = region_alive_keepers.find_keeper(table_ident).await.unwrap(); + let deadline = keeper.deadline(0).await.unwrap(); + assert!(deadline <= Instant::now() + Duration::from_secs(20)); + // Opens a non-exist table + let non_exist_table_ident = TableIdent { + catalog: "greptime".to_string(), + schema: "public".to_string(), + table: "non-exist".to_string(), + table_id: 2024, + engine: "mito".to_string(), + }; handle_instruction( executor.clone(), mailbox.clone(), Instruction::OpenRegion(RegionIdent { - table_ident: TableIdent { - catalog: "greptime".to_string(), - schema: "public".to_string(), - table: "non-exist".to_string(), - table_id: 2024, - engine: "mito".to_string(), - }, + table_ident: non_exist_table_ident.clone(), region_number: 0, cluster_id: 1, datanode_id: 2, @@ -183,6 +200,11 @@ async fn test_open_region_handler() { InstructionReply::OpenRegion(SimpleReply { result: false, .. }) ); + assert!(region_alive_keepers + .find_keeper(&non_exist_table_ident) + .await + .is_none()); + // Closes demo table handle_instruction( executor.clone(), @@ -197,8 +219,13 @@ async fn test_open_region_handler() { ); assert_test_table_not_found(instance.inner()).await; + assert!(region_alive_keepers + .find_keeper(table_ident) + .await + .is_none()); + // Opens demo table - handle_instruction(executor.clone(), mailbox.clone(), open_region_instruction()).await; + handle_instruction(executor.clone(), mailbox.clone(), instruction).await; let (_, reply) = rx.recv().await.unwrap(); assert_matches!( reply, @@ -275,10 +302,10 @@ fn open_region_instruction() -> Instruction { }) } -async fn prepare_table(instance: &Instance) { +async fn prepare_table(instance: &Instance) -> TableRef { test_util::create_test_table(instance, ConcreteDataType::timestamp_millisecond_datatype()) .await - .unwrap(); + .unwrap() } async fn assert_test_table_not_found(instance: &Instance) { diff --git a/src/datanode/src/tests/test_util.rs b/src/datanode/src/tests/test_util.rs index 91e344a00d..d59f9f5670 100644 --- a/src/datanode/src/tests/test_util.rs +++ b/src/datanode/src/tests/test_util.rs @@ -22,6 +22,7 @@ use servers::Mode; use snafu::ResultExt; use table::engine::{EngineContext, TableEngineRef}; use table::requests::{CreateTableRequest, TableOptions}; +use table::TableRef; use crate::datanode::{ DatanodeOptions, FileConfig, ObjectStoreConfig, ProcedureConfig, StorageConfig, WalConfig, @@ -84,7 +85,7 @@ fn create_tmp_dir_and_datanode_opts(name: &str) -> (DatanodeOptions, TestGuard) pub(crate) async fn create_test_table( instance: &Instance, ts_type: ConcreteDataType, -) -> Result<()> { +) -> Result { let column_schemas = vec![ ColumnSchema::new("host", ConcreteDataType::string_datatype(), true), ColumnSchema::new("cpu", ConcreteDataType::float64_datatype(), true), @@ -125,8 +126,8 @@ pub(crate) async fn create_test_table( .unwrap() .unwrap(); schema_provider - .register_table(table_name.to_string(), table) + .register_table(table_name.to_string(), table.clone()) .await .unwrap(); - Ok(()) + Ok(table) } diff --git a/src/datatypes/src/data_type.rs b/src/datatypes/src/data_type.rs index f9e40880c8..f0d225fdb9 100644 --- a/src/datatypes/src/data_type.rs +++ b/src/datatypes/src/data_type.rs @@ -183,6 +183,12 @@ impl ConcreteDataType { } } +impl From<&ConcreteDataType> for ConcreteDataType { + fn from(t: &ConcreteDataType) -> Self { + t.clone() + } +} + impl TryFrom<&ArrowDataType> for ConcreteDataType { type Error = Error; diff --git a/src/datatypes/src/value.rs b/src/datatypes/src/value.rs index 733b853808..319da1066d 100644 --- a/src/datatypes/src/value.rs +++ b/src/datatypes/src/value.rs @@ -248,7 +248,7 @@ impl Value { Value::Binary(v) => ScalarValue::LargeBinary(Some(v.to_vec())), Value::Date(v) => ScalarValue::Date32(Some(v.val())), Value::DateTime(v) => ScalarValue::Date64(Some(v.val())), - Value::Null => to_null_value(output_type), + Value::Null => to_null_scalar_value(output_type), Value::List(list) => { // Safety: The logical type of the value and output_type are the same. let list_type = output_type.as_list().unwrap(); @@ -261,7 +261,7 @@ impl Value { } } -fn to_null_value(output_type: &ConcreteDataType) -> ScalarValue { +pub fn to_null_scalar_value(output_type: &ConcreteDataType) -> ScalarValue { match output_type { ConcreteDataType::Null(_) => ScalarValue::Null, ConcreteDataType::Boolean(_) => ScalarValue::Boolean(None), @@ -285,7 +285,7 @@ fn to_null_value(output_type: &ConcreteDataType) -> ScalarValue { } ConcreteDataType::Dictionary(dict) => ScalarValue::Dictionary( Box::new(dict.key_type().as_arrow_type()), - Box::new(to_null_value(dict.value_type())), + Box::new(to_null_scalar_value(dict.value_type())), ), } } diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index f76a44c00a..3507b7e36e 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -53,7 +53,9 @@ use meta_client::MetaClientOptions; use partition::manager::PartitionRuleManager; use partition::route::TableRoutes; use query::parser::{PromQuery, QueryLanguageParser, QueryStatement}; +use query::plan::LogicalPlan; use query::query_engine::options::{validate_catalog_and_schema, QueryOptions}; +use query::query_engine::DescribeResult; use query::{QueryEngineFactory, QueryEngineRef}; use servers::error as server_error; use servers::error::{ExecuteQuerySnafu, ParsePromQLSnafu}; @@ -73,8 +75,9 @@ use sql::statements::statement::Statement; use crate::catalog::FrontendCatalogManager; use crate::error::{ - self, Error, ExecutePromqlSnafu, ExternalSnafu, InvalidInsertRequestSnafu, - MissingMetasrvOptsSnafu, ParseSqlSnafu, PlanStatementSnafu, Result, SqlExecInterceptedSnafu, + self, Error, ExecLogicalPlanSnafu, ExecutePromqlSnafu, ExternalSnafu, + InvalidInsertRequestSnafu, MissingMetasrvOptsSnafu, ParseSqlSnafu, PlanStatementSnafu, Result, + SqlExecInterceptedSnafu, }; use crate::expr_factory::{CreateExprFactoryRef, DefaultCreateExprFactory}; use crate::frontend::FrontendOptions; @@ -506,6 +509,14 @@ impl SqlQueryHandler for Instance { } } + async fn do_exec_plan(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result { + let _timer = timer!(metrics::METRIC_EXEC_PLAN_ELAPSED); + self.query_engine + .execute(plan, query_ctx) + .await + .context(ExecLogicalPlanSnafu) + } + async fn do_promql_query( &self, query: &PromQuery, @@ -523,8 +534,11 @@ impl SqlQueryHandler for Instance { &self, stmt: Statement, query_ctx: QueryContextRef, - ) -> Result> { - if let Statement::Query(_) = stmt { + ) -> Result> { + if matches!( + stmt, + Statement::Insert(_) | Statement::Query(_) | Statement::Delete(_) + ) { let plan = self .query_engine .planner() diff --git a/src/frontend/src/metrics.rs b/src/frontend/src/metrics.rs index 61ffab4089..cb7745d8c0 100644 --- a/src/frontend/src/metrics.rs +++ b/src/frontend/src/metrics.rs @@ -13,6 +13,7 @@ // limitations under the License. pub(crate) const METRIC_HANDLE_SQL_ELAPSED: &str = "frontend.handle_sql_elapsed"; +pub(crate) const METRIC_EXEC_PLAN_ELAPSED: &str = "frontend.exec_plan_elapsed"; pub(crate) const METRIC_HANDLE_SCRIPTS_ELAPSED: &str = "frontend.handle_scripts_elapsed"; pub(crate) const METRIC_RUN_SCRIPT_ELAPSED: &str = "frontend.run_script_elapsed"; diff --git a/src/meta-client/src/client.rs b/src/meta-client/src/client.rs index a893be1311..78a612f4a2 100644 --- a/src/meta-client/src/client.rs +++ b/src/meta-client/src/client.rs @@ -755,16 +755,21 @@ mod tests { async fn test_batch_put() { let tc = new_client("test_batch_put").await; - let req = BatchPutRequest::new() - .add_kv(tc.key("key"), b"value".to_vec()) - .add_kv(tc.key("key2"), b"value2".to_vec()); + let mut req = BatchPutRequest::new(); + for i in 0..256 { + req = req.add_kv( + tc.key(&format!("key-{}", i)), + format!("value-{}", i).into_bytes(), + ); + } + let res = tc.client.batch_put(req).await; assert_eq!(0, res.unwrap().take_prev_kvs().len()); - let req = RangeRequest::new().with_range(tc.key("key"), tc.key("key3")); + let req = RangeRequest::new().with_prefix(tc.key("key-")); let res = tc.client.range(req).await; let kvs = res.unwrap().take_kvs(); - assert_eq!(2, kvs.len()); + assert_eq!(256, kvs.len()); } #[tokio::test] @@ -772,16 +777,17 @@ mod tests { let tc = new_client("test_batch_get").await; tc.gen_data().await; - let req = BatchGetRequest::default() - .add_key(tc.key("key-1")) - .add_key(tc.key("key-2")); + let mut req = BatchGetRequest::default(); + for i in 0..256 { + req = req.add_key(tc.key(&format!("key-{}", i))); + } let mut res = tc.client.batch_get(req).await.unwrap(); - assert_eq!(2, res.take_kvs().len()); + assert_eq!(10, res.take_kvs().len()); let req = BatchGetRequest::default() .add_key(tc.key("key-1")) - .add_key(tc.key("key-222")); + .add_key(tc.key("key-999")); let mut res = tc.client.batch_get(req).await.unwrap(); assert_eq!(1, res.take_kvs().len()); diff --git a/src/meta-srv/src/handler.rs b/src/meta-srv/src/handler.rs index 84acb376c4..b2727a8ef2 100644 --- a/src/meta-srv/src/handler.rs +++ b/src/meta-srv/src/handler.rs @@ -25,12 +25,12 @@ use api::v1::meta::{ pub use check_leader_handler::CheckLeaderHandler; pub use collect_stats_handler::CollectStatsHandler; use common_meta::instruction::{Instruction, InstructionReply}; -use common_telemetry::{debug, info, warn}; +use common_telemetry::{debug, info, timer, warn}; use dashmap::DashMap; pub use failure_handler::RegionFailureHandler; pub use keep_lease_handler::KeepLeaseHandler; use metrics::{decrement_gauge, increment_gauge}; -pub use on_leader_start::OnLeaderStartHandler; +pub use on_leader_start_handler::OnLeaderStartHandler; pub use persist_stats_handler::PersistStatsHandler; pub use response_header_handler::ResponseHeaderHandler; use snafu::{OptionExt, ResultExt}; @@ -40,7 +40,7 @@ use tokio::sync::{oneshot, Notify, RwLock}; use self::node_stat::Stat; use crate::error::{self, DeserializeFromJsonSnafu, Result, UnexpectedInstructionReplySnafu}; use crate::metasrv::Context; -use crate::metrics::METRIC_META_HEARTBEAT_CONNECTION_NUM; +use crate::metrics::{METRIC_META_HANDLER_EXECUTE, METRIC_META_HEARTBEAT_CONNECTION_NUM}; use crate::sequence::Sequence; use crate::service::mailbox::{ BroadcastChannel, Channel, Mailbox, MailboxReceiver, MailboxRef, MessageId, @@ -52,7 +52,7 @@ pub(crate) mod failure_handler; mod keep_lease_handler; pub mod mailbox_handler; pub mod node_stat; -mod on_leader_start; +mod on_leader_start_handler; mod persist_stats_handler; pub(crate) mod region_lease_handler; mod response_header_handler; @@ -61,6 +61,12 @@ mod response_header_handler; pub trait HeartbeatHandler: Send + Sync { fn is_acceptable(&self, role: Role) -> bool; + fn name(&self) -> &'static str { + let type_name = std::any::type_name::(); + // short name + type_name.split("::").last().unwrap_or(type_name) + } + async fn handle( &self, req: &HeartbeatRequest, @@ -171,9 +177,22 @@ impl Pushers { } } +struct NameCachedHandler { + name: &'static str, + handler: Box, +} + +impl NameCachedHandler { + fn new(handler: impl HeartbeatHandler + 'static) -> Self { + let name = handler.name(); + let handler = Box::new(handler); + Self { name, handler } + } +} + #[derive(Clone, Default)] pub struct HeartbeatHandlerGroup { - handlers: Arc>>>, + handlers: Arc>>, pushers: Pushers, } @@ -187,7 +206,7 @@ impl HeartbeatHandlerGroup { pub async fn add_handler(&self, handler: impl HeartbeatHandler + 'static) { let mut handlers = self.handlers.write().await; - handlers.push(Box::new(handler)); + handlers.push(NameCachedHandler::new(handler)); } pub async fn register(&self, key: impl AsRef, pusher: Pusher) { @@ -223,13 +242,14 @@ impl HeartbeatHandlerGroup { err_msg: format!("invalid role: {:?}", req.header), })?; - for h in handlers.iter() { + for NameCachedHandler { name, handler } in handlers.iter() { if ctx.is_skip_all() { break; } - if h.is_acceptable(role) { - h.handle(&req, &mut ctx, &mut acc).await?; + if handler.is_acceptable(role) { + let _timer = timer!(METRIC_META_HANDLER_EXECUTE, &[("name", *name)]); + handler.handle(&req, &mut ctx, &mut acc).await?; } } let header = std::mem::take(&mut acc.header); @@ -383,7 +403,11 @@ mod tests { use api::v1::meta::{MailboxMessage, RequestHeader, Role, PROTOCOL_VERSION}; use tokio::sync::mpsc; - use crate::handler::{HeartbeatHandlerGroup, HeartbeatMailbox, Pusher}; + use crate::handler::mailbox_handler::MailboxHandler; + use crate::handler::{ + CheckLeaderHandler, CollectStatsHandler, HeartbeatHandlerGroup, HeartbeatMailbox, + OnLeaderStartHandler, PersistStatsHandler, Pusher, ResponseHeaderHandler, + }; use crate::sequence::Sequence; use crate::service::mailbox::{Channel, MailboxReceiver, MailboxRef}; use crate::service::store::memory::MemStore; @@ -452,4 +476,25 @@ mod tests { (mailbox, receiver) } + + #[tokio::test] + async fn test_handler_name() { + let group = HeartbeatHandlerGroup::default(); + group.add_handler(ResponseHeaderHandler::default()).await; + group.add_handler(CheckLeaderHandler::default()).await; + group.add_handler(OnLeaderStartHandler::default()).await; + group.add_handler(CollectStatsHandler::default()).await; + group.add_handler(MailboxHandler::default()).await; + group.add_handler(PersistStatsHandler::default()).await; + + let handlers = group.handlers.read().await; + + assert_eq!(6, handlers.len()); + assert_eq!("ResponseHeaderHandler", handlers[0].handler.name()); + assert_eq!("CheckLeaderHandler", handlers[1].handler.name()); + assert_eq!("OnLeaderStartHandler", handlers[2].handler.name()); + assert_eq!("CollectStatsHandler", handlers[3].handler.name()); + assert_eq!("MailboxHandler", handlers[4].handler.name()); + assert_eq!("PersistStatsHandler", handlers[5].handler.name()); + } } diff --git a/src/meta-srv/src/handler/node_stat.rs b/src/meta-srv/src/handler/node_stat.rs index ef9289470a..f0f17eefa9 100644 --- a/src/meta-srv/src/handler/node_stat.rs +++ b/src/meta-srv/src/handler/node_stat.rs @@ -42,6 +42,8 @@ pub struct Stat { pub write_io_rate: f64, /// Region stats on this node pub region_stats: Vec, + // The node epoch is used to check whether the node has restarted or redeployed. + pub node_epoch: u64, } #[derive(Debug, Default, Serialize, Deserialize)] @@ -79,6 +81,7 @@ impl TryFrom for Stat { is_leader, node_stat, region_stats, + node_epoch, .. } = value; @@ -104,6 +107,7 @@ impl TryFrom for Stat { read_io_rate: node_stat.read_io_rate, write_io_rate: node_stat.write_io_rate, region_stats: region_stats.into_iter().map(RegionStat::from).collect(), + node_epoch, }) } _ => Err(()), diff --git a/src/meta-srv/src/handler/on_leader_start.rs b/src/meta-srv/src/handler/on_leader_start_handler.rs similarity index 100% rename from src/meta-srv/src/handler/on_leader_start.rs rename to src/meta-srv/src/handler/on_leader_start_handler.rs diff --git a/src/meta-srv/src/handler/persist_stats_handler.rs b/src/meta-srv/src/handler/persist_stats_handler.rs index 09751c32ee..2b1ad61e11 100644 --- a/src/meta-srv/src/handler/persist_stats_handler.rs +++ b/src/meta-srv/src/handler/persist_stats_handler.rs @@ -23,9 +23,47 @@ use crate::metasrv::Context; const MAX_CACHED_STATS_PER_KEY: usize = 10; +#[derive(Default)] +struct EpochStats { + stats: Vec, + epoch: Option, +} + +impl EpochStats { + #[inline] + fn drain_all(&mut self) -> Vec { + self.stats.drain(..).collect() + } + + #[inline] + fn clear(&mut self) { + self.stats.clear(); + } + + #[inline] + fn push(&mut self, stat: Stat) { + self.stats.push(stat); + } + + #[inline] + fn len(&self) -> usize { + self.stats.len() + } + + #[inline] + fn epoch(&self) -> Option { + self.epoch + } + + #[inline] + fn set_epoch(&mut self, epoch: u64) { + self.epoch = Some(epoch); + } +} + #[derive(Default)] pub struct PersistStatsHandler { - stats_cache: DashMap>, + stats_cache: DashMap, } #[async_trait::async_trait] @@ -40,26 +78,47 @@ impl HeartbeatHandler for PersistStatsHandler { ctx: &mut Context, acc: &mut HeartbeatAccumulator, ) -> Result<()> { - let Some(stat) = acc.stat.take() else { return Ok(()) }; + let Some(current_stat) = acc.stat.take() else { return Ok(()) }; - let key = stat.stat_key(); + let key = current_stat.stat_key(); let mut entry = self .stats_cache .entry(key) - .or_insert_with(|| Vec::with_capacity(MAX_CACHED_STATS_PER_KEY)); - let stats = entry.value_mut(); - stats.push(stat); + .or_insert_with(EpochStats::default); - if stats.len() < MAX_CACHED_STATS_PER_KEY { + let key: Vec = key.into(); + let epoch_stats = entry.value_mut(); + + let refresh = if let Some(epoch) = epoch_stats.epoch() { + // This node may have been redeployed. + if current_stat.node_epoch > epoch { + epoch_stats.set_epoch(current_stat.node_epoch); + epoch_stats.clear(); + true + } else { + false + } + } else { + epoch_stats.set_epoch(current_stat.node_epoch); + // If the epoch is empty, it indicates that the current node sending the heartbeat + // for the first time to the current meta leader, so it is necessary to persist + // the data to the KV store as soon as possible. + true + }; + + epoch_stats.push(current_stat); + + if !refresh && epoch_stats.len() < MAX_CACHED_STATS_PER_KEY { return Ok(()); } - let stats = stats.drain(..).collect(); - let val = StatValue { stats }; - + let value: Vec = StatValue { + stats: epoch_stats.drain_all(), + } + .try_into()?; let put = PutRequest { - key: key.into(), - value: val.try_into()?, + key, + value, ..Default::default() }; @@ -74,12 +133,11 @@ mod tests { use std::sync::atomic::AtomicBool; use std::sync::Arc; - use api::v1::meta::RangeRequest; - use super::*; use crate::handler::{HeartbeatMailbox, Pushers}; use crate::keys::StatKey; use crate::sequence::Sequence; + use crate::service::store::ext::KvStoreExt; use crate::service::store::memory::MemStore; #[tokio::test] @@ -88,7 +146,7 @@ mod tests { let kv_store = Arc::new(MemStore::new()); let seq = Sequence::new("test_seq", 0, 10, kv_store.clone()); let mailbox = HeartbeatMailbox::create(Pushers::default(), seq); - let mut ctx = Context { + let ctx = Context { server_addr: "127.0.0.1:0000".to_string(), in_memory, kv_store, @@ -98,9 +156,40 @@ mod tests { is_infancy: false, }; - let req = HeartbeatRequest::default(); let handler = PersistStatsHandler::default(); - for i in 1..=MAX_CACHED_STATS_PER_KEY { + handle_request_many_times(ctx.clone(), &handler, 1).await; + + let key = StatKey { + cluster_id: 3, + node_id: 101, + }; + let res = ctx.in_memory.get(key.try_into().unwrap()).await.unwrap(); + assert!(res.is_some()); + let kv = res.unwrap(); + let key: StatKey = kv.key.clone().try_into().unwrap(); + assert_eq!(3, key.cluster_id); + assert_eq!(101, key.node_id); + let val: StatValue = kv.value.try_into().unwrap(); + // first new stat must be set in kv store immediately + assert_eq!(1, val.stats.len()); + assert_eq!(Some(1), val.stats[0].region_num); + + handle_request_many_times(ctx.clone(), &handler, 10).await; + let res = ctx.in_memory.get(key.try_into().unwrap()).await.unwrap(); + assert!(res.is_some()); + let kv = res.unwrap(); + let val: StatValue = kv.value.try_into().unwrap(); + // refresh every 10 stats + assert_eq!(10, val.stats.len()); + } + + async fn handle_request_many_times( + mut ctx: Context, + handler: &PersistStatsHandler, + loop_times: i32, + ) { + let req = HeartbeatRequest::default(); + for i in 1..=loop_times { let mut acc = HeartbeatAccumulator { stat: Some(Stat { cluster_id: 3, @@ -112,30 +201,5 @@ mod tests { }; handler.handle(&req, &mut ctx, &mut acc).await.unwrap(); } - - let key = StatKey { - cluster_id: 3, - node_id: 101, - }; - - let req = RangeRequest { - key: key.try_into().unwrap(), - ..Default::default() - }; - - let res = ctx.in_memory.range(req).await.unwrap(); - - assert_eq!(1, res.kvs.len()); - - let kv = &res.kvs[0]; - - let key: StatKey = kv.key.clone().try_into().unwrap(); - assert_eq!(3, key.cluster_id); - assert_eq!(101, key.node_id); - - let val: StatValue = kv.value.clone().try_into().unwrap(); - - assert_eq!(10, val.stats.len()); - assert_eq!(Some(1), val.stats[0].region_num); } } diff --git a/src/meta-srv/src/metrics.rs b/src/meta-srv/src/metrics.rs index f468c4fef5..cac6598991 100644 --- a/src/meta-srv/src/metrics.rs +++ b/src/meta-srv/src/metrics.rs @@ -17,3 +17,4 @@ pub(crate) const METRIC_META_CREATE_SCHEMA: &str = "meta.create_schema"; pub(crate) const METRIC_META_KV_REQUEST: &str = "meta.kv_request"; pub(crate) const METRIC_META_ROUTE_REQUEST: &str = "meta.route_request"; pub(crate) const METRIC_META_HEARTBEAT_CONNECTION_NUM: &str = "meta.heartbeat_connection_num"; +pub(crate) const METRIC_META_HANDLER_EXECUTE: &str = "meta.handler_execute"; diff --git a/src/meta-srv/src/service/store/etcd.rs b/src/meta-srv/src/service/store/etcd.rs index 55cebce8bd..22834b355b 100644 --- a/src/meta-srv/src/service/store/etcd.rs +++ b/src/meta-srv/src/service/store/etcd.rs @@ -24,6 +24,7 @@ use common_error::prelude::*; use common_telemetry::{timer, warn}; use etcd_client::{ Client, Compare, CompareOp, DeleteOptions, GetOptions, PutOptions, Txn, TxnOp, TxnOpResponse, + TxnResponse, }; use crate::error; @@ -31,6 +32,8 @@ use crate::error::Result; use crate::metrics::METRIC_META_KV_REQUEST; use crate::service::store::kv::{KvStore, KvStoreRef}; +const MAX_TXN_SIZE: usize = 128; + pub struct EtcdStore { client: Client, } @@ -51,6 +54,51 @@ impl EtcdStore { pub fn with_etcd_client(client: Client) -> Result { Ok(Arc::new(Self { client })) } + + async fn do_multi_txn(&self, mut txn_ops: Vec) -> Result> { + if txn_ops.len() < MAX_TXN_SIZE { + // fast path + let txn = Txn::new().and_then(txn_ops); + let txn_res = self + .client + .kv_client() + .txn(txn) + .await + .context(error::EtcdFailedSnafu)?; + return Ok(vec![txn_res]); + } + + let mut txns = vec![]; + loop { + if txn_ops.is_empty() { + break; + } + + if txn_ops.len() < MAX_TXN_SIZE { + let txn = Txn::new().and_then(txn_ops); + txns.push(txn); + break; + } + + let part = txn_ops.drain(..MAX_TXN_SIZE).collect::>(); + let txn = Txn::new().and_then(part); + txns.push(txn); + } + + let mut txn_responses = Vec::with_capacity(txns.len()); + // Considering the pressure on etcd, it would be more appropriate to execute txn in + // a serial manner. + for txn in txns { + let txn_res = self + .client + .kv_client() + .txn(txn) + .await + .context(error::EtcdFailedSnafu)?; + txn_responses.push(txn_res); + } + Ok(txn_responses) + } } #[async_trait::async_trait] @@ -142,26 +190,19 @@ impl KvStore for EtcdStore { .into_iter() .map(|k| TxnOp::get(k, options.clone())) .collect(); - if get_ops.len() > 128 { - warn!("batch_get too large, size: {}", get_ops.len()); - } - let txn = Txn::new().and_then(get_ops); - let txn_res = self - .client - .kv_client() - .txn(txn) - .await - .context(error::EtcdFailedSnafu)?; + let txn_responses = self.do_multi_txn(get_ops).await?; let mut kvs = vec![]; - for op_res in txn_res.op_responses() { - let get_res = match op_res { - TxnOpResponse::Get(get_res) => get_res, - _ => unreachable!(), - }; + for txn_res in txn_responses { + for op_res in txn_res.op_responses() { + let get_res = match op_res { + TxnOpResponse::Get(get_res) => get_res, + _ => unreachable!(), + }; - kvs.extend(get_res.kvs().iter().map(KvPair::from_etcd_kv)); + kvs.extend(get_res.kvs().iter().map(KvPair::from_etcd_kv)); + } } let header = Some(ResponseHeader::success(cluster_id)); @@ -188,27 +229,20 @@ impl KvStore for EtcdStore { .into_iter() .map(|kv| (TxnOp::put(kv.key, kv.value, options.clone()))) .collect::>(); - if put_ops.len() > 128 { - warn!("batch_put too large, size: {}", put_ops.len()); - } - let txn = Txn::new().and_then(put_ops); - let txn_res = self - .client - .kv_client() - .txn(txn) - .await - .context(error::EtcdFailedSnafu)?; + let txn_responses = self.do_multi_txn(put_ops).await?; let mut prev_kvs = vec![]; - for op_res in txn_res.op_responses() { - match op_res { - TxnOpResponse::Put(put_res) => { - if let Some(prev_kv) = put_res.prev_key() { - prev_kvs.push(KvPair::from_etcd_kv(prev_kv)); + for txn_res in txn_responses { + for op_res in txn_res.op_responses() { + match op_res { + TxnOpResponse::Put(put_res) => { + if let Some(prev_kv) = put_res.prev_key() { + prev_kvs.push(KvPair::from_etcd_kv(prev_kv)); + } } + _ => unreachable!(), // never get here } - _ => unreachable!(), // never get here } } @@ -238,31 +272,23 @@ impl KvStore for EtcdStore { .into_iter() .map(|k| TxnOp::delete(k, options.clone())) .collect::>(); - if delete_ops.len() > 128 { - warn!("batch_delete too large, size: {}", delete_ops.len()); - } - let txn = Txn::new().and_then(delete_ops); - let txn_res = self - .client - .kv_client() - .txn(txn) - .await - .context(error::EtcdFailedSnafu)?; + let txn_responses = self.do_multi_txn(delete_ops).await?; - for op_res in txn_res.op_responses() { - match op_res { - TxnOpResponse::Delete(delete_res) => { - delete_res.prev_kvs().iter().for_each(|kv| { - prev_kvs.push(KvPair::from_etcd_kv(kv)); - }); + for txn_res in txn_responses { + for op_res in txn_res.op_responses() { + match op_res { + TxnOpResponse::Delete(delete_res) => { + delete_res.prev_kvs().iter().for_each(|kv| { + prev_kvs.push(KvPair::from_etcd_kv(kv)); + }); + } + _ => unreachable!(), // never get here } - _ => unreachable!(), // never get here } } let header = Some(ResponseHeader::success(cluster_id)); - Ok(BatchDeleteResponse { header, prev_kvs }) } diff --git a/src/meta-srv/src/service/store/ext.rs b/src/meta-srv/src/service/store/ext.rs index 9e629ef176..2cbe2c18ca 100644 --- a/src/meta-srv/src/service/store/ext.rs +++ b/src/meta-srv/src/service/store/ext.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use api::v1::meta::{KeyValue, RangeRequest}; +use api::v1::meta::{DeleteRangeRequest, KeyValue, RangeRequest}; use crate::error::Result; use crate::service::store::kv::KvStore; @@ -24,6 +24,10 @@ pub trait KvStoreExt { /// Check if a key exists, it does not return the value. async fn exists(&self, key: Vec) -> Result; + + /// Delete the value by the given key. If prev_kv is true, + /// the previous key-value pairs will be returned. + async fn delete(&self, key: Vec, prev_kv: bool) -> Result>; } #[async_trait::async_trait] @@ -53,6 +57,18 @@ where Ok(!kvs.is_empty()) } + + async fn delete(&self, key: Vec, prev_kv: bool) -> Result> { + let req = DeleteRangeRequest { + key, + prev_kv, + ..Default::default() + }; + + let mut prev_kvs = self.delete_range(req).await?.prev_kvs; + + Ok(prev_kvs.pop()) + } } #[cfg(test)] @@ -115,6 +131,31 @@ mod tests { assert!(!in_mem.exists("test_key".as_bytes().to_vec()).await.unwrap()); } + #[tokio::test] + async fn test_delete() { + let mut in_mem = Arc::new(MemStore::new()) as KvStoreRef; + + let mut prev_kv = in_mem + .delete("test_key1".as_bytes().to_vec(), true) + .await + .unwrap(); + assert!(prev_kv.is_none()); + + put_stats_to_store(&mut in_mem).await; + + assert!(in_mem + .exists("test_key1".as_bytes().to_vec()) + .await + .unwrap()); + + prev_kv = in_mem + .delete("test_key1".as_bytes().to_vec(), true) + .await + .unwrap(); + assert!(prev_kv.is_some()); + assert_eq!("test_key1".as_bytes(), prev_kv.unwrap().key); + } + async fn put_stats_to_store(store: &mut KvStoreRef) { store .put(PutRequest { diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index 4018fad5e4..cbe88a30ad 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -38,7 +38,6 @@ use datafusion::physical_plan::ExecutionPlan; use datafusion_common::ResolvedTableReference; use datafusion_expr::{DmlStatement, LogicalPlan as DfLogicalPlan, WriteOp}; use datatypes::prelude::VectorRef; -use datatypes::schema::Schema; use futures_util::StreamExt; use session::context::QueryContextRef; use snafu::{ensure, OptionExt, ResultExt}; @@ -59,7 +58,7 @@ use crate::physical_planner::PhysicalPlanner; use crate::physical_wrapper::PhysicalWrapperRef; use crate::plan::LogicalPlan; use crate::planner::{DfLogicalPlanner, LogicalPlanner}; -use crate::query_engine::{QueryEngineContext, QueryEngineState}; +use crate::query_engine::{DescribeResult, QueryEngineContext, QueryEngineState}; use crate::{metrics, QueryEngine}; pub struct DatafusionQueryEngine { @@ -234,11 +233,12 @@ impl QueryEngine for DatafusionQueryEngine { "datafusion" } - async fn describe(&self, plan: LogicalPlan) -> Result { - // TODO(sunng87): consider cache optmised logical plan between describe - // and execute + async fn describe(&self, plan: LogicalPlan) -> Result { let optimised_plan = self.optimize(&plan)?; - optimised_plan.schema() + Ok(DescribeResult { + schema: optimised_plan.schema()?, + logical_plan: optimised_plan, + }) } async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result { @@ -553,7 +553,10 @@ mod tests { .await .unwrap(); - let schema = engine.describe(plan).await.unwrap(); + let DescribeResult { + schema, + logical_plan, + } = engine.describe(plan).await.unwrap(); assert_eq!( schema.column_schemas()[0], @@ -563,5 +566,6 @@ mod tests { true ) ); + assert_eq!("Limit: skip=0, fetch=20\n Aggregate: groupBy=[[]], aggr=[[SUM(numbers.number)]]\n TableScan: numbers projection=[number]", format!("{}", logical_plan.display_indent())); } } diff --git a/src/query/src/plan.rs b/src/query/src/plan.rs index b24ddc4504..14ff331122 100644 --- a/src/query/src/plan.rs +++ b/src/query/src/plan.rs @@ -12,13 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; use std::fmt::{Debug, Display}; +use common_query::prelude::ScalarValue; use datafusion_expr::LogicalPlan as DfLogicalPlan; +use datatypes::data_type::ConcreteDataType; use datatypes::schema::Schema; use snafu::ResultExt; -use crate::error::{ConvertDatafusionSchemaSnafu, Result}; +use crate::error::{ConvertDatafusionSchemaSnafu, DataFusionSnafu, Result}; /// A LogicalPlan represents the different types of relational /// operators (such as Projection, Filter, etc) and can be created by @@ -59,4 +62,28 @@ impl LogicalPlan { let LogicalPlan::DfPlan(plan) = self; plan.display_indent() } + + /// Walk the logical plan, find any `PlaceHolder` tokens, + /// and return a map of their IDs and ConcreteDataTypes + pub fn get_param_types(&self) -> Result>> { + let LogicalPlan::DfPlan(plan) = self; + let types = plan.get_parameter_types().context(DataFusionSnafu)?; + + Ok(types + .into_iter() + .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v)))) + .collect()) + } + + /// Return a logical plan with all placeholders/params (e.g $1 $2, + /// ...) replaced with corresponding values provided in the + /// params_values + pub fn replace_params_with_values(&self, values: &[ScalarValue]) -> Result { + let LogicalPlan::DfPlan(plan) = self; + + plan.clone() + .replace_params_with_values(values) + .context(DataFusionSnafu) + .map(LogicalPlan::DfPlan) + } } diff --git a/src/query/src/planner.rs b/src/query/src/planner.rs index 92131be37c..2ec425c24c 100644 --- a/src/query/src/planner.rs +++ b/src/query/src/planner.rs @@ -77,6 +77,7 @@ impl DfLogicalPlanner { }; PlanSqlSnafu { sql } })?; + Ok(LogicalPlan::DfPlan(result)) } diff --git a/src/query/src/query_engine.rs b/src/query/src/query_engine.rs index 153c90274c..e996ebfc7c 100644 --- a/src/query/src/query_engine.rs +++ b/src/query/src/query_engine.rs @@ -43,6 +43,15 @@ pub use crate::query_engine::state::QueryEngineState; pub type SqlStatementExecutorRef = Arc; +/// Describe statement result +#[derive(Debug)] +pub struct DescribeResult { + /// The schema of statement + pub schema: Schema, + /// The logical plan for statement + pub logical_plan: LogicalPlan, +} + #[async_trait] pub trait SqlStatementExecutor: Send + Sync { async fn execute_sql(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result; @@ -58,7 +67,7 @@ pub trait QueryEngine: Send + Sync { fn name(&self) -> &str; - async fn describe(&self, plan: LogicalPlan) -> Result; + async fn describe(&self, plan: LogicalPlan) -> Result; async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result; diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 9212a1cd1d..8621aa0522 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -33,6 +33,9 @@ common-runtime = { path = "../common/runtime" } common-telemetry = { path = "../common/telemetry" } common-time = { path = "../common/time" } datafusion.workspace = true +datafusion-common.workspace = true +datafusion-expr.workspace = true + datatypes = { path = "../datatypes" } derive_builder = "0.12" digest = "0.10" diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index 2f224da225..317953d970 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -11,7 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - use std::any::Any; use std::net::SocketAddr; use std::string::FromUtf8Error; @@ -23,6 +22,7 @@ use base64::DecodeError; use catalog; use common_error::prelude::*; use common_telemetry::logging; +use datatypes::prelude::ConcreteDataType; use query::parser::PromQuery; use serde_json::json; use snafu::Location; @@ -75,6 +75,12 @@ pub enum Error { source: BoxedError, }, + #[snafu(display("Failed to execute plan, source: {}", source))] + ExecutePlan { + location: Location, + source: BoxedError, + }, + #[snafu(display("{source}"))] ExecuteGrpcQuery { location: Location, @@ -250,6 +256,12 @@ pub enum Error { source: query::error::Error, }, + #[snafu(display("Failed to get param types, source: {source}, location: {location}"))] + GetPreparedStmtParams { + source: query::error::Error, + location: Location, + }, + #[snafu(display("{}", reason))] UnexpectedResult { reason: String, location: Location }, @@ -269,10 +281,7 @@ pub enum Error { #[cfg(feature = "pprof")] #[snafu(display("Failed to dump pprof data, source: {}", source))] - DumpPprof { - #[snafu(backtrace)] - source: common_pprof::Error, - }, + DumpPprof { source: common_pprof::Error }, #[snafu(display("Failed to update jemalloc metrics, source: {source}, location: {location}"))] UpdateJemallocMetrics { @@ -285,6 +294,31 @@ pub enum Error { source: datafusion::error::DataFusionError, location: Location, }, + + #[snafu(display( + "Failed to replace params with values in prepared statement, source: {source}, location: {location}" + ))] + ReplacePreparedStmtParams { + source: query::error::Error, + location: Location, + }, + + #[snafu(display("Failed to convert scalar value, source: {source}, location: {location}"))] + ConvertScalarValue { + source: datatypes::error::Error, + location: Location, + }, + + #[snafu(display( + "Expected type: {:?}, actual: {:?}, location: {location}", + expected, + actual + ))] + PreparedStmtTypeMismatch { + expected: ConcreteDataType, + actual: opensrv_mysql::ColumnType, + location: Location, + }, } pub type Result = std::result::Result; @@ -309,6 +343,7 @@ impl ErrorExt for Error { InsertScript { source, .. } | ExecuteScript { source, .. } | ExecuteQuery { source, .. } + | ExecutePlan { source, .. } | ExecuteGrpcQuery { source, .. } | CheckDatabaseValidity { source, .. } => source.status_code(), @@ -324,6 +359,7 @@ impl ErrorExt for Error { | InvalidFlightTicket { .. } | InvalidPrepareStatement { .. } | DataFrame { .. } + | PreparedStmtTypeMismatch { .. } | TimePrecision { .. } => StatusCode::InvalidArguments, InfluxdbLinesWrite { source, .. } | PromSeriesWrite { source, .. } => { @@ -347,7 +383,9 @@ impl ErrorExt for Error { DumpProfileData { source, .. } => source.status_code(), InvalidFlushArgument { .. } => StatusCode::InvalidArguments, - ParsePromQL { source, .. } => source.status_code(), + ReplacePreparedStmtParams { source, .. } + | GetPreparedStmtParams { source, .. } + | ParsePromQL { source, .. } => source.status_code(), Other { source, .. } => source.status_code(), UnexpectedResult { .. } => StatusCode::Unexpected, @@ -366,6 +404,8 @@ impl ErrorExt for Error { DumpPprof { source, .. } => source.status_code(), UpdateJemallocMetrics { .. } => StatusCode::Internal, + + ConvertScalarValue { source, .. } => source.status_code(), } } diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index cfe9f668a5..0b28ae0623 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -719,6 +719,8 @@ mod test { use datatypes::schema::{ColumnSchema, Schema}; use datatypes::vectors::{StringVector, UInt32Vector}; use query::parser::PromQuery; + use query::plan::LogicalPlan; + use query::query_engine::DescribeResult; use session::context::QueryContextRef; use tokio::sync::mpsc; @@ -760,11 +762,19 @@ mod test { unimplemented!() } + async fn do_exec_plan( + &self, + _plan: LogicalPlan, + _query_ctx: QueryContextRef, + ) -> std::result::Result { + unimplemented!() + } + async fn do_describe( &self, _stmt: sql::statements::statement::Statement, _query_ctx: QueryContextRef, - ) -> Result> { + ) -> Result> { unimplemented!() } diff --git a/src/servers/src/mysql.rs b/src/servers/src/mysql.rs index 04059124f7..0e73ae617a 100644 --- a/src/servers/src/mysql.rs +++ b/src/servers/src/mysql.rs @@ -14,5 +14,6 @@ mod federated; pub mod handler; +mod helper; pub mod server; pub mod writer; diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index d917db5597..6205829fcd 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -11,7 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - use std::collections::HashMap; use std::net::SocketAddr; use std::sync::atomic::{AtomicU32, Ordering}; @@ -22,18 +21,20 @@ use async_trait::async_trait; use chrono::{NaiveDate, NaiveDateTime}; use common_error::prelude::ErrorExt; use common_query::Output; -use common_telemetry::tracing::log; -use common_telemetry::{error, timer, trace, warn}; +use common_telemetry::{error, logging, timer, trace, warn}; +use datatypes::prelude::ConcreteDataType; use metrics::increment_counter; use opensrv_mysql::{ - AsyncMysqlShim, Column, ColumnFlags, ColumnType, ErrorKind, InitWriter, ParamParser, - ParamValue, QueryResultWriter, StatementMetaWriter, ValueInner, + AsyncMysqlShim, Column, ErrorKind, InitWriter, ParamParser, ParamValue, QueryResultWriter, + StatementMetaWriter, ValueInner, }; use parking_lot::RwLock; +use query::plan::LogicalPlan; +use query::query_engine::DescribeResult; use rand::RngCore; use session::context::Channel; use session::{Session, SessionRef}; -use snafu::ensure; +use snafu::{ensure, ResultExt}; use sql::dialect::MySqlDialect; use sql::parser::ParserContext; use sql::statements::statement::Statement; @@ -41,17 +42,27 @@ use tokio::io::AsyncWrite; use crate::auth::{Identity, Password, UserProviderRef}; use crate::error::{self, InvalidPrepareStatementSnafu, Result}; +use crate::mysql::helper::{ + self, format_placeholder, replace_placeholders, transform_placeholders, +}; use crate::mysql::writer; +use crate::mysql::writer::create_mysql_column; use crate::query_handler::sql::ServerSqlQueryHandlerRef; +/// Cached SQL and logical plan +#[derive(Clone)] +struct SqlPlan { + query: String, + plan: Option, +} + // An intermediate shim for executing MySQL queries. pub struct MysqlInstanceShim { query_handler: ServerSqlQueryHandlerRef, salt: [u8; 20], session: SessionRef, user_provider: Option, - // TODO(SSebo): use something like moka to achieve TTL or LRU - prepared_stmts: Arc>>, + prepared_stmts: Arc>>, prepared_stmts_counter: AtomicU32, } @@ -105,14 +116,34 @@ impl MysqlInstanceShim { output } - fn set_query(&self, query: String) -> u32 { - let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::SeqCst); - let mut guard = self.prepared_stmts.write(); - guard.insert(stmt_id, query); + /// Execute the logical plan and return the output + async fn do_exec_plan(&self, query: &str, plan: LogicalPlan) -> Result { + if let Some(output) = crate::mysql::federated::check(query, self.session.context()) { + Ok(output) + } else { + self.query_handler + .do_exec_plan(plan, self.session.context()) + .await + } + } + + /// Describe the statement + async fn do_describe(&self, statement: Statement) -> Result> { + self.query_handler + .do_describe(statement, self.session.context()) + .await + } + + /// Save query and logical plan, return the unique id + fn save_plan(&self, plan: SqlPlan) -> u32 { + let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::Relaxed); + let mut prepared_stmts = self.prepared_stmts.write(); + prepared_stmts.insert(stmt_id, plan); stmt_id } - fn query(&self, stmt_id: u32) -> Option { + /// Retrieve the query and logical plan by id + fn plan(&self, stmt_id: u32) -> Option { let guard = self.prepared_stmts.read(); guard.get(&stmt_id).cloned() } @@ -175,15 +206,36 @@ impl AsyncMysqlShim for MysqlInstanceShi query: &'a str, w: StatementMetaWriter<'a, W>, ) -> Result<()> { - let (query, param_num) = replace_placeholder(query); - if let Err(e) = validate_query(&query).await { - w.error(ErrorKind::ER_UNKNOWN_ERROR, e.to_string().as_bytes()) - .await?; - return Ok(()); + let raw_query = query.clone(); + let (query, param_num) = replace_placeholders(query); + + let statement = validate_query(raw_query).await?; + + // We have to transform the placeholder, because DataFusion only parses placeholders + // in the form of "$i", it can't process "?" right now. + let statement = transform_placeholders(statement); + + let plan = self + .do_describe(statement.clone()) + .await? + .map(|DescribeResult { logical_plan, .. }| logical_plan); + + let params = if let Some(plan) = &plan { + prepared_params( + &plan + .get_param_types() + .context(error::GetPreparedStmtParamsSnafu)?, + )? + } else { + dummy_params(param_num)? }; - let stmt_id = self.set_query(query); - let params = dummy_params(param_num); + debug_assert_eq!(params.len(), param_num - 1); + + let stmt_id = self.save_plan(SqlPlan { + query: query.to_string(), + plan, + }); w.reply(stmt_id, ¶ms, &[]).await?; increment_counter!( @@ -216,7 +268,7 @@ impl AsyncMysqlShim for MysqlInstanceShi ] ); let params: Vec = p.into_iter().collect(); - let query = match self.query(stmt_id) { + let sql_plan = match self.plan(stmt_id) { None => { w.error( ErrorKind::ER_UNKNOWN_STMT_HANDLER, @@ -225,13 +277,36 @@ impl AsyncMysqlShim for MysqlInstanceShi .await?; return Ok(()); } - Some(query) => query, + Some(sql_plan) => sql_plan, }; - let query = replace_params(params, query); - log::debug!("execute replaced query: {}", query); + let (query, outputs) = match sql_plan.plan { + Some(plan) => { + let param_types = plan + .get_param_types() + .context(error::GetPreparedStmtParamsSnafu)?; + + if params.len() != param_types.len() { + return error::InternalSnafu { + err_msg: "prepare statement params number mismatch".to_string(), + } + .fail(); + } + let plan = replace_params_with_values(&plan, param_types, params)?; + logging::debug!("Mysql execute prepared plan: {}", plan.display_indent()); + let outputs = vec![self.do_exec_plan(&sql_plan.query, plan).await]; + + (sql_plan.query, outputs) + } + None => { + let query = replace_params(params, sql_plan.query); + logging::debug!("Mysql execute replaced query: {}", query); + let outputs = self.do_query(&query).await; + + (query, outputs) + } + }; - let outputs = self.do_query(&query).await; writer::write_output(w, &query, self.session.context(), outputs).await?; Ok(()) @@ -318,7 +393,7 @@ fn replace_params(params: Vec, query: String) -> String { ValueInner::Datetime(_) => NaiveDateTime::from(param.value).to_string(), ValueInner::Time(_) => format_duration(Duration::from(param.value)), }; - query = query.replace(&format!("${}", index), &s); + query = query.replace(&format_placeholder(index), &s); index += 1; } query @@ -331,6 +406,27 @@ fn format_duration(duration: Duration) -> String { format!("{}:{}:{}", hours, minutes, seconds) } +fn replace_params_with_values( + plan: &LogicalPlan, + param_types: HashMap>, + params: Vec, +) -> Result { + debug_assert_eq!(param_types.len(), params.len()); + + let mut values = Vec::with_capacity(params.len()); + + for (i, param) in params.iter().enumerate() { + if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) { + let value = helper::convert_value(param, t)?; + + values.push(value); + } + } + + plan.replace_params_with_values(&values) + .context(error::ReplacePreparedStmtParamsSnafu) +} + async fn validate_query(query: &str) -> Result { let statement = ParserContext::create_with_dialect(query, &MySqlDialect {}); let mut statement = statement.map_err(|e| { @@ -352,29 +448,27 @@ async fn validate_query(query: &str) -> Result { Ok(statement) } -// dummy columns to satisfy opensrv_mysql, just the number of params is useful -// TODO(SSebo): use parameter type inference to return actual types -fn dummy_params(index: u32) -> Vec { - let mut params = vec![]; +fn dummy_params(index: usize) -> Result> { + let mut params = Vec::with_capacity(index - 1); for _ in 1..index { - params.push(opensrv_mysql::Column { - table: "".to_string(), - column: "".to_string(), - coltype: ColumnType::MYSQL_TYPE_LONG, - colflags: ColumnFlags::NOT_NULL_FLAG, - }); + params.push(create_mysql_column(&ConcreteDataType::null_datatype(), "")?); } - params + + Ok(params) } -fn replace_placeholder(query: &str) -> (String, u32) { - let mut query = query.to_string(); - let mut index = 1; - while let Some(position) = query.find('?') { - let place_holder = format!("${}", index); - query.replace_range(position..position + 1, &place_holder); - index += 1; +/// Parameters that the client must provide when executing the prepared statement. +fn prepared_params(param_types: &HashMap>) -> Result> { + let mut params = Vec::with_capacity(param_types.len()); + + // Placeholder index starts from 1 + for index in 1..=param_types.len() { + if let Some(Some(t)) = param_types.get(&format_placeholder(index)) { + let column = create_mysql_column(t, "")?; + params.push(column); + } } - (query, index) + + Ok(params) } diff --git a/src/servers/src/mysql/helper.rs b/src/servers/src/mysql/helper.rs new file mode 100644 index 0000000000..e734b821c2 --- /dev/null +++ b/src/servers/src/mysql/helper.rs @@ -0,0 +1,238 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +use std::ops::ControlFlow; +use std::time::Duration; + +use chrono::{NaiveDate, NaiveDateTime}; +use common_query::prelude::ScalarValue; +use datatypes::prelude::ConcreteDataType; +use datatypes::value::{self, Value}; +use itertools::Itertools; +use opensrv_mysql::{ParamValue, ValueInner}; +use snafu::ResultExt; +use sql::ast::{visit_expressions_mut, Expr, Value as ValueExpr, VisitMut}; +use sql::statements::statement::Statement; + +use crate::error::{self, Result}; + +/// Returns the placeholder string "$i". +pub fn format_placeholder(i: usize) -> String { + format!("${}", i) +} + +/// Replace all the "?" placeholder into "$i" in SQL, +/// returns the new SQL and the last placeholder index. +pub fn replace_placeholders(query: &str) -> (String, usize) { + let query_parts = query.split('?').collect::>(); + let parts_len = query_parts.len(); + let mut index = 0; + let query = query_parts + .into_iter() + .enumerate() + .map(|(i, part)| { + if i == parts_len - 1 { + return part.to_string(); + } + + index += 1; + format!("{part}{}", format_placeholder(index)) + }) + .join(""); + + (query, index + 1) +} + +/// Transform all the "?" placeholder into "$i". +/// Only works for Insert,Query and Delete statements. +pub fn transform_placeholders(stmt: Statement) -> Statement { + match stmt { + Statement::Query(mut query) => { + visit_placeholders(&mut query.inner); + Statement::Query(query) + } + Statement::Insert(mut insert) => { + visit_placeholders(&mut insert.inner); + Statement::Insert(insert) + } + Statement::Delete(mut delete) => { + visit_placeholders(&mut delete.inner); + Statement::Delete(delete) + } + stmt => stmt, + } +} + +fn visit_placeholders(v: &mut V) +where + V: VisitMut, +{ + let mut index = 1; + visit_expressions_mut(v, |expr| { + if let Expr::Value(ValueExpr::Placeholder(s)) = expr { + *s = format_placeholder(index); + index += 1; + } + ControlFlow::<()>::Continue(()) + }); +} + +/// Convert [`ParamValue`] into [`Value`] according to param type. +/// It will try it's best to do type conversions if possible +pub fn convert_value(param: &ParamValue, t: &ConcreteDataType) -> Result { + match param.value.into_inner() { + ValueInner::Int(i) => match t { + ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(i as i8))), + ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(i as i16))), + ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(i as i32))), + ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(i))), + ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(i as u8))), + ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(i as u16))), + ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(i as u32))), + ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(i as u64))), + ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(i as f32))), + ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(i as f64))), + ConcreteDataType::Timestamp(ts_type) => Value::Timestamp(ts_type.create_timestamp(i)) + .try_to_scalar_value(t) + .context(error::ConvertScalarValueSnafu), + + _ => error::PreparedStmtTypeMismatchSnafu { + expected: t, + actual: param.coltype, + } + .fail(), + }, + ValueInner::UInt(u) => match t { + ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(u as i8))), + ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(u as i16))), + ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(u as i32))), + ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(u as i64))), + ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(u as u8))), + ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(u as u16))), + ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(u as u32))), + ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(u))), + ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(u as f32))), + ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(u as f64))), + ConcreteDataType::Timestamp(ts_type) => { + Value::Timestamp(ts_type.create_timestamp(u as i64)) + .try_to_scalar_value(t) + .context(error::ConvertScalarValueSnafu) + } + + _ => error::PreparedStmtTypeMismatchSnafu { + expected: t, + actual: param.coltype, + } + .fail(), + }, + ValueInner::Double(f) => match t { + ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(f as i8))), + ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(f as i16))), + ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(f as i32))), + ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(f as i64))), + ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(f as u8))), + ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(f as u16))), + ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(f as u32))), + ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(f as u64))), + ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(f as f32))), + ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(f))), + + _ => error::PreparedStmtTypeMismatchSnafu { + expected: t, + actual: param.coltype, + } + .fail(), + }, + ValueInner::NULL => Ok(value::to_null_scalar_value(t)), + ValueInner::Bytes(b) => match t { + ConcreteDataType::String(_) => Ok(ScalarValue::Utf8(Some( + String::from_utf8_lossy(b).to_string(), + ))), + ConcreteDataType::Binary(_) => Ok(ScalarValue::LargeBinary(Some(b.to_vec()))), + + _ => error::PreparedStmtTypeMismatchSnafu { + expected: t, + actual: param.coltype, + } + .fail(), + }, + ValueInner::Date(_) => { + let date: common_time::Date = NaiveDate::from(param.value).into(); + Ok(ScalarValue::Date32(Some(date.val()))) + } + ValueInner::Datetime(_) => Ok(ScalarValue::Date64(Some( + NaiveDateTime::from(param.value).timestamp_millis(), + ))), + ValueInner::Time(_) => Ok(ScalarValue::Time64Nanosecond(Some( + Duration::from(param.value).as_millis() as i64, + ))), + } +} + +#[cfg(test)] +mod tests { + use sql::dialect::MySqlDialect; + use sql::parser::ParserContext; + + use super::*; + + #[test] + fn test_format_placeholder() { + assert_eq!("$1", format_placeholder(1)); + assert_eq!("$3", format_placeholder(3)); + } + + #[test] + fn test_replace_placeholders() { + let create = "create table demo(host string, ts timestamp time index)"; + let (sql, index) = replace_placeholders(create); + assert_eq!(create, sql); + assert_eq!(1, index); + + let insert = "insert into demo values(?,?,?)"; + let (sql, index) = replace_placeholders(insert); + assert_eq!("insert into demo values($1,$2,$3)", sql); + assert_eq!(4, index); + + let query = "select from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?"; + let (sql, index) = replace_placeholders(query); + assert_eq!("select from demo where host=$1 and idc in (select idc from idcs where name=$2) and cpu>$3", sql); + assert_eq!(4, index); + } + + fn parse_sql(sql: &str) -> Statement { + let mut stmts = ParserContext::create_with_dialect(sql, &MySqlDialect {}).unwrap(); + stmts.remove(0) + } + + #[test] + fn test_transform_placeholders() { + let insert = parse_sql("insert into demo values(?,?,?)"); + let Statement::Insert(insert) = transform_placeholders(insert) else { unreachable!()}; + assert_eq!( + "INSERT INTO demo VALUES ($1, $2, $3)", + insert.inner.to_string() + ); + + let delete = parse_sql("delete from demo where host=? and idc=?"); + let Statement::Delete(delete) = transform_placeholders(delete) else { unreachable!()}; + assert_eq!( + "DELETE FROM demo WHERE host = $1 AND idc = $2", + delete.inner.to_string() + ); + + let select = parse_sql("select from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?"); + let Statement::Query(select) = transform_placeholders(select) else { unreachable!()}; + assert_eq!("SELECT from AS demo WHERE host = $1 AND idc IN (SELECT idc FROM idcs WHERE name = $2) AND cpu > $3", select.inner.to_string()); + } +} diff --git a/src/servers/src/mysql/writer.rs b/src/servers/src/mysql/writer.rs index 6a060635b6..4249b82780 100644 --- a/src/servers/src/mysql/writer.rs +++ b/src/servers/src/mysql/writer.rs @@ -18,7 +18,7 @@ use common_query::Output; use common_recordbatch::{util, RecordBatch}; use common_telemetry::error; use datatypes::prelude::{ConcreteDataType, Value}; -use datatypes::schema::{ColumnSchema, SchemaRef}; +use datatypes::schema::SchemaRef; use opensrv_mysql::{ Column, ColumnFlags, ColumnType, ErrorKind, OkResponse, QueryResultWriter, RowWriter, }; @@ -176,8 +176,8 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> { Value::Float64(v) => row_writer.write_col(v.0)?, Value::String(v) => row_writer.write_col(v.as_utf8())?, Value::Binary(v) => row_writer.write_col(v.deref())?, - Value::Date(v) => row_writer.write_col(v.val())?, - Value::DateTime(v) => row_writer.write_col(v.val())?, + Value::Date(v) => row_writer.write_col(v.to_chrono_date())?, + Value::DateTime(v) => row_writer.write_col(v.to_chrono_datetime())?, Value::Timestamp(v) => row_writer .write_col(v.to_timezone_aware_string(query_context.time_zone()))?, Value::List(_) => { @@ -208,8 +208,11 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> { } } -fn create_mysql_column(column_schema: &ColumnSchema) -> Result { - let column_type = match column_schema.data_type { +pub(crate) fn create_mysql_column( + data_type: &ConcreteDataType, + column_name: &str, +) -> Result { + let column_type = match data_type { ConcreteDataType::Null(_) => Ok(ColumnType::MYSQL_TYPE_NULL), ConcreteDataType::Boolean(_) | ConcreteDataType::Int8(_) | ConcreteDataType::UInt8(_) => { Ok(ColumnType::MYSQL_TYPE_TINY) @@ -230,15 +233,12 @@ fn create_mysql_column(column_schema: &ColumnSchema) -> Result { ConcreteDataType::Date(_) => Ok(ColumnType::MYSQL_TYPE_DATE), ConcreteDataType::DateTime(_) => Ok(ColumnType::MYSQL_TYPE_DATETIME), _ => error::InternalSnafu { - err_msg: format!( - "not implemented for column datatype {:?}", - column_schema.data_type - ), + err_msg: format!("not implemented for column datatype {:?}", data_type), } .fail(), }; let mut colflags = ColumnFlags::empty(); - match column_schema.data_type { + match data_type { ConcreteDataType::UInt16(_) | ConcreteDataType::UInt8(_) | ConcreteDataType::UInt32(_) @@ -246,7 +246,7 @@ fn create_mysql_column(column_schema: &ColumnSchema) -> Result { _ => {} }; column_type.map(|column_type| Column { - column: column_schema.name.clone(), + column: column_name.to_string(), coltype: column_type, // TODO(LFC): Currently "table" and "colflags" are not relevant in MySQL server @@ -261,6 +261,6 @@ pub fn create_mysql_column_def(schema: &SchemaRef) -> Result> { schema .column_schemas() .iter() - .map(create_mysql_column) + .map(|column_schema| create_mysql_column(&column_schema.data_type, &column_schema.name)) .collect() } diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 583e988474..5fc71be2f7 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -33,6 +33,7 @@ use pgwire::api::stmt::QueryParser; use pgwire::api::store::MemPortalStore; use pgwire::api::{ClientInfo, Type}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; +use query::query_engine::DescribeResult; use sql::dialect::PostgreSqlDialect; use sql::parser::ParserContext; use sql::statements::statement::Statement; @@ -405,7 +406,7 @@ impl ExtendedQueryHandler for PostgresServerHandler { // get Statement part of the tuple let (stmt, _) = stmt; - if let Some(schema) = self + if let Some(DescribeResult { schema, .. }) = self .query_handler .do_describe(stmt.clone(), self.session.context()) .await diff --git a/src/servers/src/query_handler/sql.rs b/src/servers/src/query_handler/sql.rs index af8bbed5c2..3ae964c9cc 100644 --- a/src/servers/src/query_handler/sql.rs +++ b/src/servers/src/query_handler/sql.rs @@ -17,8 +17,8 @@ use std::sync::Arc; use async_trait::async_trait; use common_error::prelude::*; use common_query::Output; -use datatypes::schema::Schema; use query::parser::PromQuery; +use query::plan::LogicalPlan; use session::context::QueryContextRef; use sql::statements::statement::Statement; @@ -26,6 +26,7 @@ use crate::error::{self, Result}; pub type SqlQueryHandlerRef = Arc + Send + Sync>; pub type ServerSqlQueryHandlerRef = SqlQueryHandlerRef; +use query::query_engine::DescribeResult; #[async_trait] pub trait SqlQueryHandler { @@ -37,6 +38,12 @@ pub trait SqlQueryHandler { query_ctx: QueryContextRef, ) -> Vec>; + async fn do_exec_plan( + &self, + plan: LogicalPlan, + query_ctx: QueryContextRef, + ) -> std::result::Result; + async fn do_promql_query( &self, query: &PromQuery, @@ -47,7 +54,7 @@ pub trait SqlQueryHandler { &self, stmt: Statement, query_ctx: QueryContextRef, - ) -> std::result::Result, Self::Error>; + ) -> std::result::Result, Self::Error>; async fn is_valid_schema( &self, @@ -83,6 +90,14 @@ where .collect() } + async fn do_exec_plan(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result { + self.0 + .do_exec_plan(plan, query_ctx) + .await + .map_err(BoxedError::new) + .context(error::ExecutePlanSnafu) + } + async fn do_promql_query( &self, query: &PromQuery, @@ -107,7 +122,7 @@ where &self, stmt: Statement, query_ctx: QueryContextRef, - ) -> Result> { + ) -> Result> { self.0 .do_describe(stmt, query_ctx) .await diff --git a/src/servers/tests/http/influxdb_test.rs b/src/servers/tests/http/influxdb_test.rs index 658dc929d1..d7a92543af 100644 --- a/src/servers/tests/http/influxdb_test.rs +++ b/src/servers/tests/http/influxdb_test.rs @@ -21,8 +21,9 @@ use axum::{http, Router}; use axum_test_helper::TestClient; use common_query::Output; use common_test_util::ports; -use datatypes::schema::Schema; use query::parser::PromQuery; +use query::plan::LogicalPlan; +use query::query_engine::DescribeResult; use servers::error::{Error, Result}; use servers::http::{HttpOptions, HttpServerBuilder}; use servers::influxdb::InfluxdbRequest; @@ -71,6 +72,14 @@ impl SqlQueryHandler for DummyInstance { unimplemented!() } + async fn do_exec_plan( + &self, + _plan: LogicalPlan, + _query_ctx: QueryContextRef, + ) -> std::result::Result { + unimplemented!() + } + async fn do_promql_query( &self, _: &PromQuery, @@ -83,7 +92,7 @@ impl SqlQueryHandler for DummyInstance { &self, _stmt: sql::statements::statement::Statement, _query_ctx: QueryContextRef, - ) -> Result> { + ) -> Result> { unimplemented!() } diff --git a/src/servers/tests/http/opentsdb_test.rs b/src/servers/tests/http/opentsdb_test.rs index 9cce749fc9..d01f2a0a10 100644 --- a/src/servers/tests/http/opentsdb_test.rs +++ b/src/servers/tests/http/opentsdb_test.rs @@ -20,8 +20,9 @@ use axum::Router; use axum_test_helper::TestClient; use common_query::Output; use common_test_util::ports; -use datatypes::schema::Schema; use query::parser::PromQuery; +use query::plan::LogicalPlan; +use query::query_engine::DescribeResult; use servers::error::{self, Result}; use servers::http::{HttpOptions, HttpServerBuilder}; use servers::opentsdb::codec::DataPoint; @@ -70,6 +71,14 @@ impl SqlQueryHandler for DummyInstance { unimplemented!() } + async fn do_exec_plan( + &self, + _plan: LogicalPlan, + _query_ctx: QueryContextRef, + ) -> std::result::Result { + unimplemented!() + } + async fn do_promql_query( &self, _: &PromQuery, @@ -82,7 +91,7 @@ impl SqlQueryHandler for DummyInstance { &self, _stmt: sql::statements::statement::Statement, _query_ctx: QueryContextRef, - ) -> Result> { + ) -> Result> { unimplemented!() } diff --git a/src/servers/tests/http/prometheus_test.rs b/src/servers/tests/http/prometheus_test.rs index ba70759a72..ddcdbb7fa1 100644 --- a/src/servers/tests/http/prometheus_test.rs +++ b/src/servers/tests/http/prometheus_test.rs @@ -23,9 +23,10 @@ use axum::Router; use axum_test_helper::TestClient; use common_query::Output; use common_test_util::ports; -use datatypes::schema::Schema; use prost::Message; use query::parser::PromQuery; +use query::plan::LogicalPlan; +use query::query_engine::DescribeResult; use servers::error::{Error, Result}; use servers::http::{HttpOptions, HttpServerBuilder}; use servers::prometheus; @@ -95,6 +96,14 @@ impl SqlQueryHandler for DummyInstance { unimplemented!() } + async fn do_exec_plan( + &self, + _plan: LogicalPlan, + _query_ctx: QueryContextRef, + ) -> std::result::Result { + unimplemented!() + } + async fn do_promql_query( &self, _: &PromQuery, @@ -107,7 +116,7 @@ impl SqlQueryHandler for DummyInstance { &self, _stmt: sql::statements::statement::Statement, _query_ctx: QueryContextRef, - ) -> Result> { + ) -> Result> { unimplemented!() } diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index e91eebb9d5..910b2ca132 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -21,8 +21,9 @@ use async_trait::async_trait; use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_query::Output; -use datatypes::schema::Schema; use query::parser::{PromQuery, QueryLanguageParser, QueryStatement}; +use query::plan::LogicalPlan; +use query::query_engine::DescribeResult; use query::{QueryEngineFactory, QueryEngineRef}; use script::engine::{CompileContext, EvalContext, Script, ScriptEngine}; use script::python::{PyEngine, PyScript}; @@ -78,6 +79,10 @@ impl SqlQueryHandler for DummyInstance { vec![Ok(output)] } + async fn do_exec_plan(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result { + Ok(self.query_engine.execute(plan, query_ctx).await.unwrap()) + } + async fn do_promql_query( &self, _: &PromQuery, @@ -90,7 +95,7 @@ impl SqlQueryHandler for DummyInstance { &self, stmt: Statement, query_ctx: QueryContextRef, - ) -> Result> { + ) -> Result> { if let Statement::Query(_) = stmt { let plan = self .query_engine diff --git a/src/sql/src/ast.rs b/src/sql/src/ast.rs index b35b71b51b..a72d7965b7 100644 --- a/src/sql/src/ast.rs +++ b/src/sql/src/ast.rs @@ -13,7 +13,7 @@ // limitations under the License. pub use sqlparser::ast::{ - BinaryOperator, ColumnDef, ColumnOption, ColumnOptionDef, DataType, Expr, Function, - FunctionArg, FunctionArgExpr, Ident, ObjectName, SqlOption, TableConstraint, TimezoneInfo, - Value, + visit_expressions_mut, BinaryOperator, ColumnDef, ColumnOption, ColumnOptionDef, DataType, + Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, SqlOption, TableConstraint, + TimezoneInfo, Value, VisitMut, Visitor, }; diff --git a/src/table/src/test_util/empty_table.rs b/src/table/src/test_util/empty_table.rs index 679ace6887..0503515642 100644 --- a/src/table/src/test_util/empty_table.rs +++ b/src/table/src/test_util/empty_table.rs @@ -37,8 +37,10 @@ impl EmptyTable { .next_column_id(0) .options(req.table_options) .region_numbers(req.region_numbers) + .engine(req.engine) .build(); let table_info = TableInfoBuilder::default() + .table_id(req.id) .catalog_name(req.catalog_name) .schema_name(req.schema_name) .name(req.table_name) diff --git a/src/table/src/test_util/memtable.rs b/src/table/src/test_util/memtable.rs index f2e942ce8d..3d27650f20 100644 --- a/src/table/src/test_util/memtable.rs +++ b/src/table/src/test_util/memtable.rs @@ -56,7 +56,7 @@ impl MemTable { Self::new_with_catalog( table_name, recordbatch, - 0, + 1, "greptime".to_string(), "public".to_string(), regions, diff --git a/tests-integration/Cargo.toml b/tests-integration/Cargo.toml index 90bed17fa3..e4cc07dde3 100644 --- a/tests-integration/Cargo.toml +++ b/tests-integration/Cargo.toml @@ -13,6 +13,7 @@ axum = "0.6" axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" } async-trait = "0.1" catalog = { path = "../src/catalog" } +chrono.workspace = true client = { path = "../src/client", features = ["testing"] } common-base = { path = "../src/common/base" } common-catalog = { path = "../src/common/catalog" } @@ -49,6 +50,7 @@ sqlx = { version = "0.6", features = [ "runtime-tokio-rustls", "mysql", "postgres", + "chrono", ] } table = { path = "../src/table" } tempfile.workspace = true diff --git a/tests-integration/tests/sql.rs b/tests-integration/tests/sql.rs index 915ab909bb..f5d770582f 100644 --- a/tests-integration/tests/sql.rs +++ b/tests-integration/tests/sql.rs @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc}; use sqlx::mysql::MySqlPoolOptions; use sqlx::postgres::PgPoolOptions; use sqlx::Row; @@ -62,20 +63,24 @@ pub async fn test_mysql_crud(store_type: StorageType) { .await .unwrap(); - sqlx::query("create table demo(i bigint, ts timestamp time index)") + sqlx::query("create table demo(i bigint, ts timestamp time index, d date, dt datetime)") .execute(&pool) .await .unwrap(); for i in 0..10 { - sqlx::query("insert into demo values(?, ?)") + let dt = DateTime::::from_utc(NaiveDateTime::from_timestamp_opt(60, i).unwrap(), Utc); + let d = NaiveDate::from_yo_opt(2015, 100).unwrap(); + sqlx::query("insert into demo values(?, ?, ?, ?)") .bind(i) .bind(i) + .bind(d) + .bind(dt) .execute(&pool) .await .unwrap(); } - let rows = sqlx::query("select i from demo") + let rows = sqlx::query("select i, d, dt from demo") .fetch_all(&pool) .await .unwrap(); @@ -83,7 +88,34 @@ pub async fn test_mysql_crud(store_type: StorageType) { for (i, row) in rows.iter().enumerate() { let ret: i64 = row.get(0); + let d: NaiveDate = row.get(1); + let dt: DateTime = row.get(2); assert_eq!(ret, i as i64); + + let expected_d = NaiveDate::from_yo_opt(2015, 100).unwrap(); + assert_eq!(expected_d, d); + + let expected_dt = DateTime::::from_utc( + NaiveDateTime::from_timestamp_opt(60, i as u32).unwrap(), + Utc, + ); + + assert_eq!( + format!("{}", expected_dt.format("%Y-%m-%d %H:%M:%S")), + format!("{}", dt.format("%Y-%m-%d %H:%M:%S")) + ); + } + + let rows = sqlx::query("select i from demo where i=?") + .bind(6) + .fetch_all(&pool) + .await + .unwrap(); + assert_eq!(rows.len(), 1); + + for row in rows { + let ret: i64 = row.get(0); + assert_eq!(ret, 6); } sqlx::query("delete from demo") @@ -133,6 +165,18 @@ pub async fn test_postgres_crud(store_type: StorageType) { assert_eq!(ret, i as i64); } + let rows = sqlx::query("select i from demo where i=$1") + .bind(6) + .fetch_all(&pool) + .await + .unwrap(); + assert_eq!(rows.len(), 1); + + for row in rows { + let ret: i64 = row.get(0); + assert_eq!(ret, 6); + } + sqlx::query("delete from demo") .execute(&pool) .await