diff --git a/src/cmd/src/cli/export.rs b/src/cmd/src/cli/export.rs
index 5634d3f1b6..ee5f5329cd 100644
--- a/src/cmd/src/cli/export.rs
+++ b/src/cmd/src/cli/export.rs
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 use std::collections::HashSet;
-use std::path::Path;
+use std::path::PathBuf;
 use std::sync::Arc;
 
 use async_trait::async_trait;
@@ -29,7 +29,7 @@ use tracing_appender::non_blocking::WorkerGuard;
 
 use crate::cli::database::DatabaseClient;
 use crate::cli::{database, Instance, Tool};
-use crate::error::{EmptyResultSnafu, Error, FileIoSnafu, Result};
+use crate::error::{EmptyResultSnafu, Error, FileIoSnafu, Result, SchemaNotFoundSnafu};
 
 type TableReference = (String, String, String);
 
@@ -120,18 +120,34 @@ pub struct Export {
 }
 
 impl Export {
+    fn catalog_path(&self) -> PathBuf {
+        PathBuf::from(&self.output_dir).join(&self.catalog)
+    }
+
     async fn get_db_names(&self) -> Result<Vec<String>> {
-        if let Some(schema) = &self.schema {
-            Ok(vec![schema.clone()])
-        } else {
-            self.all_db_names().await
-        }
+        let db_names = self.all_db_names().await?;
+        let Some(schema) = &self.schema else {
+            return Ok(db_names);
+        };
+
+        // Check if the schema exists
+        db_names
+            .into_iter()
+            .find(|db_name| db_name.to_lowercase() == schema.to_lowercase())
+            .map(|name| vec![name])
+            .context(SchemaNotFoundSnafu {
+                catalog: &self.catalog,
+                schema,
+            })
     }
 
     /// Iterate over all db names.
     async fn all_db_names(&self) -> Result<Vec<String>> {
-        let result = self.database_client.sql_in_public("SHOW DATABASES").await?;
-        let records = result.context(EmptyResultSnafu)?;
+        let records = self
+            .database_client
+            .sql_in_public("SHOW DATABASES")
+            .await?
+            .context(EmptyResultSnafu)?;
         let mut result = Vec::with_capacity(records.len());
         for value in records {
             let Value::String(schema) = &value[0] else {
@@ -167,8 +183,11 @@ impl Export {
                 and table_catalog = \'{catalog}\' \
                 and table_schema = \'{schema}\'"
         );
-        let result = self.database_client.sql_in_public(&sql).await?;
-        let records = result.context(EmptyResultSnafu)?;
+        let records = self
+            .database_client
+            .sql_in_public(&sql)
+            .await?
+            .context(EmptyResultSnafu)?;
         let mut metric_physical_tables = HashSet::with_capacity(records.len());
         for value in records {
             let mut t = Vec::with_capacity(3);
@@ -181,7 +200,6 @@ impl Export {
             metric_physical_tables.insert((t[0].clone(), t[1].clone(), t[2].clone()));
         }
 
-        // TODO: SQL injection hurts
        let sql = format!(
             "SELECT table_catalog, table_schema, table_name, table_type \
             FROM information_schema.tables \
@@ -189,8 +207,11 @@ impl Export {
                 and table_catalog = \'{catalog}\' \
                 and table_schema = \'{schema}\'",
         );
-        let result = self.database_client.sql_in_public(&sql).await?;
-        let records = result.context(EmptyResultSnafu)?;
+        let records = self
+            .database_client
+            .sql_in_public(&sql)
+            .await?
+            .context(EmptyResultSnafu)?;
 
         debug!("Fetched table/view list: {:?}", records);
 
@@ -232,19 +253,50 @@ impl Export {
         show_type: &str,
         catalog: &str,
         schema: &str,
-        table: &str,
+        table: Option<&str>,
     ) -> Result<String> {
-        let sql = format!(
-            r#"SHOW CREATE {} "{}"."{}"."{}""#,
-            show_type, catalog, schema, table
-        );
-        let result = self.database_client.sql_in_public(&sql).await?;
-        let records = result.context(EmptyResultSnafu)?;
-        let Value::String(create_table) = &records[0][1] else {
+        let sql = match table {
+            Some(table) => format!(
+                r#"SHOW CREATE {} "{}"."{}"."{}""#,
+                show_type, catalog, schema, table
+            ),
+            None => format!(r#"SHOW CREATE {} "{}"."{}""#, show_type, catalog, schema),
+        };
+        let records = self
+            .database_client
+            .sql_in_public(&sql)
+            .await?
+            .context(EmptyResultSnafu)?;
+        let Value::String(create) = &records[0][1] else {
             unreachable!()
         };
-        Ok(format!("{};\n", create_table))
+        Ok(format!("{};\n", create))
+    }
+
+    async fn export_create_database(&self) -> Result<()> {
+        let timer = Instant::now();
+        let db_names = self.get_db_names().await?;
+        let db_count = db_names.len();
+        for schema in db_names {
+            let db_dir = self.catalog_path().join(format!("{schema}/"));
+            tokio::fs::create_dir_all(&db_dir)
+                .await
+                .context(FileIoSnafu)?;
+            let file = db_dir.join("create_database.sql");
+            let mut file = File::create(file).await.context(FileIoSnafu)?;
+            let create_database = self
+                .show_create("DATABASE", &self.catalog, &schema, None)
+                .await?;
+            file.write_all(create_database.as_bytes())
+                .await
+                .context(FileIoSnafu)?;
+        }
+
+        let elapsed = timer.elapsed();
+        info!("Success {db_count} jobs, cost: {elapsed:?}");
+
+        Ok(())
     }
 
     async fn export_create_table(&self) -> Result<()> {
@@ -261,43 +313,29 @@ impl Export {
                 self.get_table_list(&self.catalog, &schema).await?;
             let table_count = metric_physical_tables.len() + remaining_tables.len() + views.len();
-            let output_dir = Path::new(&self.output_dir)
-                .join(&self.catalog)
-                .join(format!("{schema}/"));
-            tokio::fs::create_dir_all(&output_dir)
+            let db_dir = self.catalog_path().join(format!("{schema}/"));
+            tokio::fs::create_dir_all(&db_dir)
                 .await
                 .context(FileIoSnafu)?;
-            let output_file = Path::new(&output_dir).join("create_tables.sql");
-            let mut file = File::create(output_file).await.context(FileIoSnafu)?;
+            let file = db_dir.join("create_tables.sql");
+            let mut file = File::create(file).await.context(FileIoSnafu)?;
             for (c, s, t) in metric_physical_tables.into_iter().chain(remaining_tables) {
-                match self.show_create("TABLE", &c, &s, &t).await {
-                    Err(e) => {
-                        error!(e; r#"Failed to export table "{}"."{}"."{}""#, c, s, t)
-                    }
-                    Ok(create_table) => {
-                        file.write_all(create_table.as_bytes())
-                            .await
-                            .context(FileIoSnafu)?;
-                    }
-                }
+                let create_table = self.show_create("TABLE", &c, &s, Some(&t)).await?;
+                file.write_all(create_table.as_bytes())
+                    .await
+                    .context(FileIoSnafu)?;
             }
             for (c, s, v) in views {
-                match self.show_create("VIEW", &c, &s, &v).await {
-                    Err(e) => {
-                        error!(e; r#"Failed to export view "{}"."{}"."{}""#, c, s, v)
-                    }
-                    Ok(create_view) => {
-                        file.write_all(create_view.as_bytes())
-                            .await
-                            .context(FileIoSnafu)?;
-                    }
-                }
+                let create_view = self.show_create("VIEW", &c, &s, Some(&v)).await?;
+                file.write_all(create_view.as_bytes())
+                    .await
+                    .context(FileIoSnafu)?;
             }
             info!(
                 "Finished exporting {}.{schema} with {table_count} table schemas to path: {}",
                 self.catalog,
-                output_dir.to_string_lossy()
+                db_dir.to_string_lossy()
             );
 
             Ok::<(), Error>(())
@@ -317,7 +355,7 @@ impl Export {
             .count();
 
         let elapsed = timer.elapsed();
-        info!("Success {success}/{db_count} jobs, cost: {:?}", elapsed);
+        info!("Success {success}/{db_count} jobs, cost: {elapsed:?}");
 
         Ok(())
     }
 
@@ -332,10 +370,8 @@ impl Export {
             let semaphore_moved = semaphore.clone();
             tasks.push(async move {
                 let _permit = semaphore_moved.acquire().await.unwrap();
-                let output_dir = Path::new(&self.output_dir)
-                    .join(&self.catalog)
-                    .join(format!("{schema}/"));
-                tokio::fs::create_dir_all(&output_dir)
+                let db_dir = self.catalog_path().join(format!("{schema}/"));
+                tokio::fs::create_dir_all(&db_dir)
                     .await
                     .context(FileIoSnafu)?;
 
@@ -359,7 +395,7 @@ impl Export {
                     r#"COPY DATABASE "{}"."{}" TO '{}' {};"#,
                     self.catalog,
                     schema,
-                    output_dir.to_str().unwrap(),
+                    db_dir.to_str().unwrap(),
                     with_options
                 );
 
@@ -370,18 +406,18 @@ impl Export {
                 info!(
                     "Finished exporting {}.{schema} data into path: {}",
                     self.catalog,
-                    output_dir.to_string_lossy()
+                    db_dir.to_string_lossy()
                 );
 
                 // The export copy from sql
-                let copy_from_file = output_dir.join("copy_from.sql");
+                let copy_from_file = db_dir.join("copy_from.sql");
                 let mut writer =
                     BufWriter::new(File::create(copy_from_file).await.context(FileIoSnafu)?);
                 let copy_database_from_sql = format!(
                     r#"COPY DATABASE "{}"."{}" FROM '{}' WITH (FORMAT='parquet');"#,
                     self.catalog,
                     schema,
-                    output_dir.to_str().unwrap()
+                    db_dir.to_str().unwrap()
                 );
                 writer
                     .write(copy_database_from_sql.as_bytes())
@@ -418,9 +454,13 @@ impl Export {
 impl Tool for Export {
     async fn do_work(&self) -> Result<()> {
         match self.target {
-            ExportTarget::Schema => self.export_create_table().await,
+            ExportTarget::Schema => {
+                self.export_create_database().await?;
+                self.export_create_table().await
+            }
             ExportTarget::Data => self.export_database_data().await,
             ExportTarget::All => {
+                self.export_create_database().await?;
                 self.export_create_table().await?;
                 self.export_database_data().await
             }
diff --git a/src/cmd/src/cli/import.rs b/src/cmd/src/cli/import.rs
index 920e225d7a..b1d27fb0e0 100644
--- a/src/cmd/src/cli/import.rs
+++ b/src/cmd/src/cli/import.rs
@@ -17,15 +17,16 @@ use std::sync::Arc;
 
 use async_trait::async_trait;
 use clap::{Parser, ValueEnum};
+use common_catalog::consts::DEFAULT_SCHEMA_NAME;
 use common_telemetry::{error, info, warn};
-use snafu::ResultExt;
+use snafu::{OptionExt, ResultExt};
 use tokio::sync::Semaphore;
 use tokio::time::Instant;
 use tracing_appender::non_blocking::WorkerGuard;
 
 use crate::cli::database::DatabaseClient;
 use crate::cli::{database, Instance, Tool};
-use crate::error::{Error, FileIoSnafu, Result};
+use crate::error::{Error, FileIoSnafu, Result, SchemaNotFoundSnafu};
 
 #[derive(Debug, Default, Clone, ValueEnum)]
 enum ImportTarget {
@@ -100,14 +101,17 @@ pub struct Import {
 
 impl Import {
     async fn import_create_table(&self) -> Result<()> {
-        self.do_sql_job("create_tables.sql").await
+        // Use default db to create other dbs
+        self.do_sql_job("create_database.sql", Some(DEFAULT_SCHEMA_NAME))
+            .await?;
+        self.do_sql_job("create_tables.sql", None).await
     }
 
     async fn import_database_data(&self) -> Result<()> {
-        self.do_sql_job("copy_from.sql").await
+        self.do_sql_job("copy_from.sql", None).await
     }
 
-    async fn do_sql_job(&self, filename: &str) -> Result<()> {
+    async fn do_sql_job(&self, filename: &str, exec_db: Option<&str>) -> Result<()> {
         let timer = Instant::now();
         let semaphore = Arc::new(Semaphore::new(self.parallelism));
         let db_names = self.get_db_names().await?;
@@ -125,7 +129,8 @@ impl Import {
                 if sql.is_empty() {
                     info!("Empty `{filename}` {database_input_dir:?}");
                 } else {
-                    self.database_client.sql(&sql, &schema).await?;
+                    let db = exec_db.unwrap_or(&schema);
+                    self.database_client.sql(&sql, db).await?;
                     info!("Imported `{filename}` for database {schema}");
                 }
@@ -155,11 +160,20 @@ impl Import {
     }
 
     async fn get_db_names(&self) -> Result<Vec<String>> {
-        if let Some(schema) = &self.schema {
-            Ok(vec![schema.clone()])
-        } else {
-            self.all_db_names().await
-        }
+        let db_names = self.all_db_names().await?;
+        let Some(schema) = &self.schema else {
+            return Ok(db_names);
+        };
+
+        // Check if the schema exists
+        db_names
+            .into_iter()
+            .find(|db_name| db_name.to_lowercase() == schema.to_lowercase())
+            .map(|name| vec![name])
+            .context(SchemaNotFoundSnafu {
+                catalog: &self.catalog,
+                schema,
+            })
     }
 
     // Get all database names in the input directory.
diff --git a/src/cmd/src/error.rs b/src/cmd/src/error.rs
index 66cc57c625..0c2a4fbad9 100644
--- a/src/cmd/src/error.rs
+++ b/src/cmd/src/error.rs
@@ -354,6 +354,14 @@ pub enum Error {
         error: tonic::transport::Error,
         msg: Option<String>,
     },
+
+    #[snafu(display("Cannot find schema {schema} in catalog {catalog}"))]
+    SchemaNotFound {
+        catalog: String,
+        schema: String,
+        #[snafu(implicit)]
+        location: Location,
+    },
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
@@ -415,6 +423,7 @@ impl ErrorExt for Error {
             }
             Error::MetaClientInit { source, .. } => source.status_code(),
             Error::TonicTransport { .. } => StatusCode::Internal,
+            Error::SchemaNotFound { .. } => StatusCode::DatabaseNotFound,
         }
     }
diff --git a/src/servers/src/http/handler.rs b/src/servers/src/http/handler.rs
index d1690e79a9..4d5ca58461 100644
--- a/src/servers/src/http/handler.rs
+++ b/src/servers/src/http/handler.rs
@@ -20,6 +20,7 @@ use aide::transform::TransformOperation;
 use axum::extract::{Json, Query, State};
 use axum::response::{IntoResponse, Response};
 use axum::{Extension, Form};
+use common_catalog::parse_catalog_and_schema_from_db_string;
 use common_error::ext::ErrorExt;
 use common_error::status_code::StatusCode;
 use common_plugins::GREPTIME_EXEC_WRITE_COST;
@@ -76,6 +77,11 @@ pub async fn sql(
 ) -> HttpResponse {
     let start = Instant::now();
     let sql_handler = &state.sql_handler;
+    if let Some(db) = &query_params.db.or(form_params.db) {
+        let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
+        query_ctx.set_current_catalog(&catalog);
+        query_ctx.set_current_schema(&schema);
+    }
     let db = query_ctx.get_db_string();
     query_ctx.set_channel(Channel::Http);
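
Note on the schema check shared by both CLIs: `Export::get_db_names` and `Import::get_db_names` now validate a user-supplied schema against the names returned by `SHOW DATABASES`, matching case-insensitively while preserving the server-side casing, and surface `SchemaNotFound` when there is no match. A minimal standalone sketch of that lookup using only the standard library; the free function `find_schema` and its plain `Option` return are illustrative only, where the patch itself converts the `None` case into `SchemaNotFoundSnafu`:

/// Case-insensitive lookup of `schema` among the names from `SHOW DATABASES`.
fn find_schema(db_names: Vec<String>, schema: &str) -> Option<Vec<String>> {
    db_names
        .into_iter()
        .find(|db_name| db_name.to_lowercase() == schema.to_lowercase())
        .map(|name| vec![name])
}

fn main() {
    let dbs = vec!["greptime_private".to_string(), "Public".to_string()];
    // Matching ignores case but keeps the server-side casing in the result.
    assert_eq!(find_schema(dbs.clone(), "public"), Some(vec!["Public".to_string()]));
    // An unknown schema yields None, which the patch maps to Error::SchemaNotFound.
    assert_eq!(find_schema(dbs, "does_not_exist"), None);
}

Preserving the matched name rather than the user input matters here because the returned name is what the export path logic (`catalog_path().join(format!("{schema}/"))`) uses to build per-database directories.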