From facdda4d9fbcc0722049f7dfe8c87594903aaaa0 Mon Sep 17 00:00:00 2001
From: Weny Xu
Date: Thu, 16 Mar 2023 11:36:38 +0800
Subject: [PATCH] feat: implement CONNECTION clause of Copy To (#1163)

* feat: implement CONNECTION clause of Copy To

* test: add tests for s3 backend

* Apply suggestions from code review

Co-authored-by: Yingwen

---------

Co-authored-by: Yingwen
---
 .env.example                            |   2 +-
 Cargo.lock                              |   1 +
 src/datanode/Cargo.toml                 |   1 +
 src/datanode/src/instance/sql.rs        |   6 +-
 src/datanode/src/sql/copy_table.rs      |  51 +++++++--
 src/datanode/src/sql/copy_table_from.rs | 126 +++++++++++----------
 src/datanode/src/tests/instance_test.rs | 141 ++++++++++++++++++++++++
 src/frontend/src/instance.rs            |   2 +-
 src/sql/src/parsers/copy_parser.rs      |  64 ++++++++++-
 src/sql/src/statements/copy.rs          |  27 ++---
 src/table/src/requests.rs               |   1 +
 11 files changed, 328 insertions(+), 94 deletions(-)

diff --git a/.env.example b/.env.example
index da1bbcc213..2f842b2f76 100644
--- a/.env.example
+++ b/.env.example
@@ -2,7 +2,7 @@
 GT_S3_BUCKET=S3 bucket
 GT_S3_ACCESS_KEY_ID=S3 access key id
 GT_S3_ACCESS_KEY=S3 secret access key
-
+GT_S3_ENDPOINT_URL=S3 endpoint url
 # Settings for oss test
 GT_OSS_BUCKET=OSS bucket
 GT_OSS_ACCESS_KEY_ID=OSS access key id
diff --git a/Cargo.lock b/Cargo.lock
index 8445d67bd1..8b0a6e2f96 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2327,6 +2327,7 @@ dependencies = [
  "tower",
  "tower-http",
  "url",
+ "uuid",
 ]

 [[package]]
diff --git a/src/datanode/Cargo.toml b/src/datanode/Cargo.toml
index eae2a9cd69..7745b8bcfd 100644
--- a/src/datanode/Cargo.toml
+++ b/src/datanode/Cargo.toml
@@ -64,6 +64,7 @@ tonic.workspace = true
 tower = { version = "0.4", features = ["full"] }
 tower-http = { version = "0.3", features = ["full"] }
 url = "2.3.1"
+uuid.workspace = true

 [dev-dependencies]
 axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" }
diff --git a/src/datanode/src/instance/sql.rs b/src/datanode/src/instance/sql.rs
index becbdbdea3..b363d2c8cd 100644
--- a/src/datanode/src/instance/sql.rs
+++ b/src/datanode/src/instance/sql.rs
@@ -163,14 +163,14 @@ impl Instance {
             QueryStatement::Sql(Statement::Copy(copy_table)) => match copy_table {
                 CopyTable::To(copy_table) => {
                     let (catalog_name, schema_name, table_name) =
-                        table_idents_to_full_name(copy_table.table_name(), query_ctx.clone())?;
-                    let file_name = copy_table.file_name().to_string();
-
+                        table_idents_to_full_name(&copy_table.table_name, query_ctx.clone())?;
+                    let file_name = copy_table.file_name;
                     let req = CopyTableRequest {
                         catalog_name,
                         schema_name,
                         table_name,
                         file_name,
+                        connection: copy_table.connection,
                     };

                     self.sql_handler
diff --git a/src/datanode/src/sql/copy_table.rs b/src/datanode/src/sql/copy_table.rs
index 5b401649d4..8acd9447ae 100644
--- a/src/datanode/src/sql/copy_table.rs
+++ b/src/datanode/src/sql/copy_table.rs
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+use std::collections::HashMap;
 use std::pin::Pin;

 use common_query::physical_plan::SessionContext;
@@ -22,16 +23,54 @@ use datafusion::parquet::basic::{Compression, Encoding};
 use datafusion::parquet::file::properties::WriterProperties;
 use datafusion::physical_plan::RecordBatchStream;
 use futures::TryStreamExt;
-use object_store::services::Fs as Builder;
-use object_store::{ObjectStore, ObjectStoreBuilder};
+use object_store::ObjectStore;
 use snafu::ResultExt;
 use table::engine::TableReference;
 use table::requests::CopyTableRequest;
+use url::{ParseError, Url};

+use super::copy_table_from::{build_fs_backend, build_s3_backend, S3_SCHEMA};
 use crate::error::{self, Result};
 use crate::sql::SqlHandler;

 impl SqlHandler {
+    fn build_backend(
+        &self,
+        url: &str,
+        connection: HashMap<String, String>,
+    ) -> Result<(ObjectStore, String)> {
+        let result = Url::parse(url);
+
+        match result {
+            Ok(url) => {
+                let host = url.host_str();
+
+                let schema = url.scheme();
+
+                let path = url.path();
+
+                match schema.to_uppercase().as_str() {
+                    S3_SCHEMA => {
+                        let object_store = build_s3_backend(host, "/", connection)?;
+                        Ok((object_store, path.to_string()))
+                    }
+
+                    _ => error::UnsupportedBackendProtocolSnafu {
+                        protocol: schema.to_string(),
+                    }
+                    .fail(),
+                }
+            }
+            Err(ParseError::RelativeUrlWithoutBase) => {
+                let object_store = build_fs_backend("/")?;
+                Ok((object_store, url.to_string()))
+            }
+            Err(err) => Err(error::Error::InvalidUrl {
+                url: url.to_string(),
+                source: err,
+            }),
+        }
+    }
     pub(crate) async fn copy_table(&self, req: CopyTableRequest) -> Result<Output> {
         let table_ref = TableReference {
             catalog: &req.catalog_name,
@@ -52,13 +91,9 @@ impl SqlHandler {
             .context(error::TableScanExecSnafu)?;
         let stream = Box::pin(DfRecordBatchStreamAdapter::new(stream));

-        let accessor = Builder::default()
-            .root("/")
-            .build()
-            .context(error::BuildBackendSnafu)?;
-        let object_store = ObjectStore::new(accessor).finish();
+        let (object_store, file_name) = self.build_backend(&req.file_name, req.connection)?;

-        let mut parquet_writer = ParquetWriter::new(req.file_name, stream, object_store);
+        let mut parquet_writer = ParquetWriter::new(file_name, stream, object_store);
         // TODO(jiachun):
         // For now, COPY is implemented synchronously.
         // When copying large table, it will be blocked for a long time.
diff --git a/src/datanode/src/sql/copy_table_from.rs b/src/datanode/src/sql/copy_table_from.rs
index bedb36cf8e..3c21e15769 100644
--- a/src/datanode/src/sql/copy_table_from.rs
+++ b/src/datanode/src/sql/copy_table_from.rs
@@ -34,7 +34,7 @@ use url::{ParseError, Url};
 use crate::error::{self, Result};
 use crate::sql::SqlHandler;

-const S3_SCHEMA: &str = "S3";
+pub const S3_SCHEMA: &str = "S3";
 const ENDPOINT_URL: &str = "ENDPOINT_URL";
 const ACCESS_KEY_ID: &str = "ACCESS_KEY_ID";
 const SECRET_ACCESS_KEY: &str = "SECRET_ACCESS_KEY";
@@ -165,13 +165,10 @@ impl DataSource {
             Source::Dir
         };

-        let accessor = Fs::default()
-            .root(&path)
-            .build()
-            .context(error::BuildBackendSnafu)?;
+        let object_store = build_fs_backend(&path)?;

         Ok(DataSource {
-            object_store: ObjectStore::new(accessor).finish(),
+            object_store,
             source,
             path,
             regex,
@@ -184,59 +181,6 @@ impl DataSource {
         }
     }

-    fn build_s3_backend(
-        host: Option<&str>,
-        path: &str,
-        connection: HashMap<String, String>,
-    ) -> Result<ObjectStore> {
-        let mut builder = S3::default();
-
-        builder.root(path);
-
-        if let Some(bucket) = host {
-            builder.bucket(bucket);
-        }
-
-        if let Some(endpoint) = connection.get(ENDPOINT_URL) {
-            builder.endpoint(endpoint);
-        }
-
-        if let Some(region) = connection.get(REGION) {
-            builder.region(region);
-        }
-
-        if let Some(key_id) = connection.get(ACCESS_KEY_ID) {
-            builder.access_key_id(key_id);
-        }
-
-        if let Some(key) = connection.get(SECRET_ACCESS_KEY) {
-            builder.secret_access_key(key);
-        }
-
-        if let Some(session_token) = connection.get(SESSION_TOKEN) {
-            builder.security_token(session_token);
-        }
-
-        if let Some(enable_str) = connection.get(ENABLE_VIRTUAL_HOST_STYLE) {
-            let enable = enable_str.as_str().parse::<bool>().map_err(|e| {
-                error::InvalidConnectionSnafu {
-                    msg: format!(
-                        "failed to parse the option {}={}, {}",
-                        ENABLE_VIRTUAL_HOST_STYLE, enable_str, e
-                    ),
-                }
-                .build()
-            })?;
-            if enable {
-                builder.enable_virtual_host_style();
-            }
-        }
-
-        let accessor = builder.build().context(error::BuildBackendSnafu)?;
-
-        Ok(ObjectStore::new(accessor).finish())
-    }
-
     fn from_url(
         url: Url,
         regex: Option<Regex>,
@@ -257,7 +201,7 @@ impl DataSource {
         };

         let object_store = match schema.to_uppercase().as_str() {
-            S3_SCHEMA => DataSource::build_s3_backend(host, &dir, connection)?,
+            S3_SCHEMA => build_s3_backend(host, &dir, connection)?,
             _ => {
                 return error::UnsupportedBackendProtocolSnafu {
                     protocol: schema.to_string(),
@@ -348,6 +292,68 @@ impl DataSource {
     }
 }

+pub fn build_s3_backend(
+    host: Option<&str>,
+    path: &str,
+    connection: HashMap<String, String>,
+) -> Result<ObjectStore> {
+    let mut builder = S3::default();
+
+    builder.root(path);
+
+    if let Some(bucket) = host {
+        builder.bucket(bucket);
+    }
+
+    if let Some(endpoint) = connection.get(ENDPOINT_URL) {
+        builder.endpoint(endpoint);
+    }
+
+    if let Some(region) = connection.get(REGION) {
+        builder.region(region);
+    }
+
+    if let Some(key_id) = connection.get(ACCESS_KEY_ID) {
+        builder.access_key_id(key_id);
+    }
+
+    if let Some(key) = connection.get(SECRET_ACCESS_KEY) {
+        builder.secret_access_key(key);
+    }
+
+    if let Some(session_token) = connection.get(SESSION_TOKEN) {
+        builder.security_token(session_token);
+    }
+
+    if let Some(enable_str) = connection.get(ENABLE_VIRTUAL_HOST_STYLE) {
+        let enable = enable_str.as_str().parse::<bool>().map_err(|e| {
+            error::InvalidConnectionSnafu {
+                msg: format!(
+                    "failed to parse the option {}={}, {}",
+                    ENABLE_VIRTUAL_HOST_STYLE, enable_str, e
+                ),
+            }
+            .build()
+        })?;
+        if enable {
+            builder.enable_virtual_host_style();
+        }
+    }
+
+    let accessor = builder.build().context(error::BuildBackendSnafu)?;
+
+    Ok(ObjectStore::new(accessor).finish())
+}
+
+pub fn build_fs_backend(root: &str) -> Result<ObjectStore> {
+    let accessor = Fs::default()
+        .root(root)
+        .build()
+        .context(error::BuildBackendSnafu)?;
+
+    Ok(ObjectStore::new(accessor).finish())
+}
+
 #[cfg(test)]
 mod tests {

diff --git a/src/datanode/src/tests/instance_test.rs b/src/datanode/src/tests/instance_test.rs
index 9ec4682b89..1fca1d9775 100644
--- a/src/datanode/src/tests/instance_test.rs
+++ b/src/datanode/src/tests/instance_test.rs
@@ -12,11 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+use std::env;
 use std::sync::Arc;

 use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
 use common_query::Output;
 use common_recordbatch::util;
+use common_telemetry::logging;
 use datatypes::data_type::ConcreteDataType;
 use datatypes::vectors::{Int64Vector, StringVector, UInt64Vector, VectorRef};
 use query::parser::{QueryLanguageParser, QueryStatement};
@@ -797,6 +799,45 @@ async fn test_execute_copy_to() {
     assert!(matches!(output, Output::AffectedRows(2)));
 }

+#[tokio::test(flavor = "multi_thread")]
+async fn test_execute_copy_to_s3() {
+    logging::init_default_ut_logging();
+    if let Ok(bucket) = env::var("GT_S3_BUCKET") {
+        if !bucket.is_empty() {
+            let instance = setup_test_instance("test_execute_copy_to_s3").await;
+
+            // setups
+            execute_sql(
+                &instance,
+                "create table demo(host string, cpu double, memory double, ts timestamp time index);",
+            )
+            .await;
+
+            let output = execute_sql(
+                &instance,
+                r#"insert into demo(host, cpu, memory, ts) values
+                           ('host1', 66.6, 1024, 1655276557000),
+                           ('host2', 88.8, 333.3, 1655276558000)
+                           "#,
+            )
+            .await;
+            assert!(matches!(output, Output::AffectedRows(2)));
+            let key_id = env::var("GT_S3_ACCESS_KEY_ID").unwrap();
+            let key = env::var("GT_S3_ACCESS_KEY").unwrap();
+            let url =
+                env::var("GT_S3_ENDPOINT_URL").unwrap_or("https://s3.amazonaws.com".to_string());
+
+            let root = uuid::Uuid::new_v4().to_string();
+
+            // exports
+            let copy_to_stmt = format!("Copy demo TO 's3://{}/{}/export/demo.parquet' CONNECTION (ACCESS_KEY_ID='{}',SECRET_ACCESS_KEY='{}',ENDPOINT_URL='{}')", bucket, root, key_id, key, url);
+
+            let output = execute_sql(&instance, &copy_to_stmt).await;
+            assert!(matches!(output, Output::AffectedRows(2)));
+        }
+    }
+}
+
 #[tokio::test(flavor = "multi_thread")]
 async fn test_execute_copy_from() {
     let instance = setup_test_instance("test_execute_copy_from").await;
@@ -882,6 +923,106 @@ async fn test_execute_copy_from() {
     }
 }

+#[tokio::test(flavor = "multi_thread")]
+async fn test_execute_copy_from_s3() {
+    logging::init_default_ut_logging();
+    if let Ok(bucket) = env::var("GT_S3_BUCKET") {
+        if !bucket.is_empty() {
+            let instance = setup_test_instance("test_execute_copy_from_s3").await;
+
+            // setups
+            execute_sql(
+                &instance,
+                "create table demo(host string, cpu double, memory double, ts timestamp time index);",
+            )
+            .await;
+
+            let output = execute_sql(
+                &instance,
+                r#"insert into demo(host, cpu, memory, ts) values
+                           ('host1', 66.6, 1024, 1655276557000),
+                           ('host2', 88.8, 333.3, 1655276558000)
+                           "#,
+            )
+            .await;
+            assert!(matches!(output, Output::AffectedRows(2)));
+
+            // export
+            let root = uuid::Uuid::new_v4().to_string();
+            let key_id = env::var("GT_S3_ACCESS_KEY_ID").unwrap();
+            let key = env::var("GT_S3_ACCESS_KEY").unwrap();
+            let url =
+                env::var("GT_S3_ENDPOINT_URL").unwrap_or("https://s3.amazonaws.com".to_string());
+
+            let copy_to_stmt = format!("Copy demo TO 's3://{}/{}/export/demo.parquet' CONNECTION (ACCESS_KEY_ID='{}',SECRET_ACCESS_KEY='{}',ENDPOINT_URL='{}')", bucket, root, key_id, key, url);
+            logging::info!("Copy table to s3: {}", copy_to_stmt);
+
+            let output = execute_sql(&instance, &copy_to_stmt).await;
+            assert!(matches!(output, Output::AffectedRows(2)));
+
+            struct Test<'a> {
+                sql: &'a str,
+                table_name: &'a str,
+            }
+            let tests = [
+                Test {
+                    sql: &format!(
+                        "Copy with_filename FROM 's3://{}/{}/export/demo.parquet_1_2'",
+                        bucket, root
+                    ),
+                    table_name: "with_filename",
+                },
+                Test {
+                    sql: &format!("Copy with_path FROM 's3://{}/{}/export/'", bucket, root),
+                    table_name: "with_path",
+                },
+                Test {
+                    sql: &format!(
+                        "Copy with_pattern FROM 's3://{}/{}/export/' WITH (PATTERN = 'demo.*')",
+                        bucket, root
+                    ),
+                    table_name: "with_pattern",
+                },
+            ];
+
+            for test in tests {
+                // import
+                execute_sql(
+                    &instance,
+                    &format!(
+                        "create table {}(host string, cpu double, memory double, ts timestamp time index);",
+                        test.table_name
+                    ),
+                )
+                .await;
+                let sql = format!(
+                    "{} CONNECTION (ACCESS_KEY_ID='{}',SECRET_ACCESS_KEY='{}',ENDPOINT_URL='{}')",
+                    test.sql, key_id, key, url
+                );
+                logging::info!("Running sql: {}", sql);
+
+                let output = execute_sql(&instance, &sql).await;
+                assert!(matches!(output, Output::AffectedRows(2)));
+
+                let output = execute_sql(
+                    &instance,
+                    &format!("select * from {} order by ts", test.table_name),
+                )
+                .await;
+                let expected = "\
++-------+------+--------+---------------------+
+| host  | cpu  | memory | ts                  |
++-------+------+--------+---------------------+
+| host1 | 66.6 | 1024.0 | 2022-06-15T07:02:37 |
+| host2 | 88.8 | 333.3  | 2022-06-15T07:02:38 |
++-------+------+--------+---------------------+"
+                    .to_string();
+                check_output_stream(output, expected).await;
+            }
+        }
+    }
+}
+
 #[tokio::test(flavor = "multi_thread")]
 async fn test_create_by_procedure() {
     common_telemetry::init_default_ut_logging();
diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs
index 2ccb4fa48a..2d1b2264a1 100644
--- a/src/frontend/src/instance.rs
+++ b/src/frontend/src/instance.rs
@@ -677,7 +677,7 @@ pub fn check_permission(
             validate_param(delete.table_name(), query_ctx)?;
         }
         Statement::Copy(stmd) => match stmd {
-            CopyTable::To(copy_table_to) => validate_param(copy_table_to.table_name(), query_ctx)?,
+            CopyTable::To(copy_table_to) => validate_param(&copy_table_to.table_name, query_ctx)?,
             CopyTable::From(copy_table_from) => {
                 validate_param(&copy_table_from.table_name, query_ctx)?
             }
diff --git a/src/sql/src/parsers/copy_parser.rs b/src/sql/src/parsers/copy_parser.rs
index 8bdac09243..518056dea8 100644
--- a/src/sql/src/parsers/copy_parser.rs
+++ b/src/sql/src/parsers/copy_parser.rs
@@ -130,8 +130,24 @@ impl<'a> ParserContext<'a> {
             }
         }

+        let connection_options = self
+            .parser
+            .parse_options(Keyword::CONNECTION)
+            .context(error::SyntaxSnafu { sql: self.sql })?;
+
+        let connection = connection_options
+            .into_iter()
+            .filter_map(|option| {
+                if let Some(v) = ParserContext::parse_option_string(option.value) {
+                    Some((option.name.value.to_uppercase(), v))
+                } else {
+                    None
+                }
+            })
+            .collect();
+
         Ok(CopyTable::To(CopyTableTo::new(
-            table_name, file_name, format,
+            table_name, file_name, format, connection,
         )))
     }

@@ -167,7 +183,7 @@ mod tests {
         match statement {
             Statement::Copy(CopyTable::To(copy_table)) => {
                 let (catalog, schema, table) =
-                    if let [catalog, schema, table] = &copy_table.table_name().0[..] {
+                    if let [catalog, schema, table] = &copy_table.table_name.0[..] {
                         (
                             catalog.value.clone(),
                             schema.value.clone(),
@@ -181,11 +197,11 @@
                 assert_eq!("schema0", schema);
                 assert_eq!("tbl", table);

-                let file_name = copy_table.file_name();
+                let file_name = copy_table.file_name;
                 assert_eq!("tbl_file.parquet", file_name);

-                let format = copy_table.format();
-                assert_eq!(Format::Parquet, *format);
+                let format = copy_table.format;
+                assert_eq!(Format::Parquet, format);
             }
             _ => unreachable!(),
         }
@@ -275,6 +291,44 @@
         }
     }

+    #[test]
+    fn test_parse_copy_table_to() {
+        struct Test<'a> {
+            sql: &'a str,
+            expected_connection: HashMap<String, String>,
+        }
+
+        let tests = [
+            Test {
+                sql: "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' ",
+                expected_connection: HashMap::new(),
+            },
+            Test {
+                sql: "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' CONNECTION (FOO='Bar', ONE='two')",
+                expected_connection: [("FOO","Bar"),("ONE","two")].into_iter().map(|(k,v)|{(k.to_string(),v.to_string())}).collect()
+            },
+            Test {
+                sql:"COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet') CONNECTION (FOO='Bar', ONE='two')",
+                expected_connection: [("FOO","Bar"),("ONE","two")].into_iter().map(|(k,v)|{(k.to_string(),v.to_string())}).collect()
+            },
+        ];
+
+        for test in tests {
+            let mut result =
+                ParserContext::create_with_dialect(test.sql, &GenericDialect {}).unwrap();
+            assert_eq!(1, result.len());
+
+            let statement = result.remove(0);
+            assert_matches!(statement, Statement::Copy { .. });
+            match statement {
+                Statement::Copy(CopyTable::To(copy_table)) => {
+                    assert_eq!(copy_table.connection.clone(), test.expected_connection);
+                }
+                _ => unreachable!(),
+            }
+        }
+    }
+
     #[test]
     fn test_parse_copy_table_with_unsupopoted_format() {
         let results = [
diff --git a/src/sql/src/statements/copy.rs b/src/sql/src/statements/copy.rs
index 140e0babde..e2c3862a1a 100644
--- a/src/sql/src/statements/copy.rs
+++ b/src/sql/src/statements/copy.rs
@@ -26,31 +26,26 @@ pub enum CopyTable {

 #[derive(Debug, Clone, PartialEq, Eq)]
 pub struct CopyTableTo {
-    table_name: ObjectName,
-    file_name: String,
-    format: Format,
+    pub table_name: ObjectName,
+    pub file_name: String,
+    pub format: Format,
+    pub connection: HashMap<String, String>,
 }

 impl CopyTableTo {
-    pub(crate) fn new(table_name: ObjectName, file_name: String, format: Format) -> Self {
+    pub(crate) fn new(
+        table_name: ObjectName,
+        file_name: String,
+        format: Format,
+        connection: HashMap<String, String>,
+    ) -> Self {
         Self {
             table_name,
             file_name,
             format,
+            connection,
         }
     }
-
-    pub fn table_name(&self) -> &ObjectName {
-        &self.table_name
-    }
-
-    pub fn file_name(&self) -> &str {
-        &self.file_name
-    }
-
-    pub fn format(&self) -> &Format {
-        &self.format
-    }
 }

 // TODO: To combine struct CopyTableFrom and CopyTableTo
diff --git a/src/table/src/requests.rs b/src/table/src/requests.rs
index 76fff8a986..b09e47446f 100644
--- a/src/table/src/requests.rs
+++ b/src/table/src/requests.rs
@@ -197,6 +197,7 @@ pub struct CopyTableRequest {
     pub schema_name: String,
    pub table_name: String,
     pub file_name: String,
+    pub connection: HashMap<String, String>,
 }

 #[derive(Debug)]
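
For reference, a minimal usage sketch of the new clause, based on the statements exercised by the tests in this patch (the bucket name, object path, and credential values below are placeholders, not values taken from the patch):

    COPY demo TO 's3://my-bucket/export/demo.parquet'
        CONNECTION (ACCESS_KEY_ID='...', SECRET_ACCESS_KEY='...', ENDPOINT_URL='https://s3.amazonaws.com');

The CONNECTION key/value pairs are collected by the parser into a map and forwarded to build_s3_backend; keys other than the ones read there (ENDPOINT_URL, REGION, ACCESS_KEY_ID, SECRET_ACCESS_KEY, SESSION_TOKEN, ENABLE_VIRTUAL_HOST_STYLE) are ignored.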