feat: implement CONNECTION clause of Copy To (#1163)

* feat: implement CONNECTION clause of Copy To

* test: add tests for s3 backend

* Apply suggestions from code review

Co-authored-by: Yingwen <realevenyag@gmail.com>

---------

Co-authored-by: Yingwen <realevenyag@gmail.com>
This commit is contained in:
Weny Xu
2023-03-16 11:36:38 +08:00
committed by GitHub
parent 17eb99bc52
commit facdda4d9f
11 changed files with 328 additions and 94 deletions

View File

@@ -64,6 +64,7 @@ tonic.workspace = true
tower = { version = "0.4", features = ["full"] }
tower-http = { version = "0.3", features = ["full"] }
url = "2.3.1"
uuid.workspace = true
[dev-dependencies]
axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" }

View File

@@ -163,14 +163,14 @@ impl Instance {
QueryStatement::Sql(Statement::Copy(copy_table)) => match copy_table {
CopyTable::To(copy_table) => {
let (catalog_name, schema_name, table_name) =
table_idents_to_full_name(copy_table.table_name(), query_ctx.clone())?;
let file_name = copy_table.file_name().to_string();
table_idents_to_full_name(&copy_table.table_name, query_ctx.clone())?;
let file_name = copy_table.file_name;
let req = CopyTableRequest {
catalog_name,
schema_name,
table_name,
file_name,
connection: copy_table.connection,
};
self.sql_handler

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::pin::Pin;
use common_query::physical_plan::SessionContext;
@@ -22,16 +23,54 @@ use datafusion::parquet::basic::{Compression, Encoding};
use datafusion::parquet::file::properties::WriterProperties;
use datafusion::physical_plan::RecordBatchStream;
use futures::TryStreamExt;
use object_store::services::Fs as Builder;
use object_store::{ObjectStore, ObjectStoreBuilder};
use object_store::ObjectStore;
use snafu::ResultExt;
use table::engine::TableReference;
use table::requests::CopyTableRequest;
use url::{ParseError, Url};
use super::copy_table_from::{build_fs_backend, build_s3_backend, S3_SCHEMA};
use crate::error::{self, Result};
use crate::sql::SqlHandler;
impl SqlHandler {
fn build_backend(
&self,
url: &str,
connection: HashMap<String, String>,
) -> Result<(ObjectStore, String)> {
let result = Url::parse(url);
match result {
Ok(url) => {
let host = url.host_str();
let schema = url.scheme();
let path = url.path();
match schema.to_uppercase().as_str() {
S3_SCHEMA => {
let object_store = build_s3_backend(host, "/", connection)?;
Ok((object_store, path.to_string()))
}
_ => error::UnsupportedBackendProtocolSnafu {
protocol: schema.to_string(),
}
.fail(),
}
}
Err(ParseError::RelativeUrlWithoutBase) => {
let object_store = build_fs_backend("/")?;
Ok((object_store, url.to_string()))
}
Err(err) => Err(error::Error::InvalidUrl {
url: url.to_string(),
source: err,
}),
}
}
pub(crate) async fn copy_table(&self, req: CopyTableRequest) -> Result<Output> {
let table_ref = TableReference {
catalog: &req.catalog_name,
@@ -52,13 +91,9 @@ impl SqlHandler {
.context(error::TableScanExecSnafu)?;
let stream = Box::pin(DfRecordBatchStreamAdapter::new(stream));
let accessor = Builder::default()
.root("/")
.build()
.context(error::BuildBackendSnafu)?;
let object_store = ObjectStore::new(accessor).finish();
let (object_store, file_name) = self.build_backend(&req.file_name, req.connection)?;
let mut parquet_writer = ParquetWriter::new(req.file_name, stream, object_store);
let mut parquet_writer = ParquetWriter::new(file_name, stream, object_store);
// TODO(jiachun):
// For now, COPY is implemented synchronously.
// When copying large table, it will be blocked for a long time.

View File

@@ -34,7 +34,7 @@ use url::{ParseError, Url};
use crate::error::{self, Result};
use crate::sql::SqlHandler;
const S3_SCHEMA: &str = "S3";
pub const S3_SCHEMA: &str = "S3";
const ENDPOINT_URL: &str = "ENDPOINT_URL";
const ACCESS_KEY_ID: &str = "ACCESS_KEY_ID";
const SECRET_ACCESS_KEY: &str = "SECRET_ACCESS_KEY";
@@ -165,13 +165,10 @@ impl DataSource {
Source::Dir
};
let accessor = Fs::default()
.root(&path)
.build()
.context(error::BuildBackendSnafu)?;
let object_store = build_fs_backend(&path)?;
Ok(DataSource {
object_store: ObjectStore::new(accessor).finish(),
object_store,
source,
path,
regex,
@@ -184,59 +181,6 @@ impl DataSource {
}
}
fn build_s3_backend(
host: Option<&str>,
path: &str,
connection: HashMap<String, String>,
) -> Result<ObjectStore> {
let mut builder = S3::default();
builder.root(path);
if let Some(bucket) = host {
builder.bucket(bucket);
}
if let Some(endpoint) = connection.get(ENDPOINT_URL) {
builder.endpoint(endpoint);
}
if let Some(region) = connection.get(REGION) {
builder.region(region);
}
if let Some(key_id) = connection.get(ACCESS_KEY_ID) {
builder.access_key_id(key_id);
}
if let Some(key) = connection.get(SECRET_ACCESS_KEY) {
builder.secret_access_key(key);
}
if let Some(session_token) = connection.get(SESSION_TOKEN) {
builder.security_token(session_token);
}
if let Some(enable_str) = connection.get(ENABLE_VIRTUAL_HOST_STYLE) {
let enable = enable_str.as_str().parse::<bool>().map_err(|e| {
error::InvalidConnectionSnafu {
msg: format!(
"failed to parse the option {}={}, {}",
ENABLE_VIRTUAL_HOST_STYLE, enable_str, e
),
}
.build()
})?;
if enable {
builder.enable_virtual_host_style();
}
}
let accessor = builder.build().context(error::BuildBackendSnafu)?;
Ok(ObjectStore::new(accessor).finish())
}
fn from_url(
url: Url,
regex: Option<Regex>,
@@ -257,7 +201,7 @@ impl DataSource {
};
let object_store = match schema.to_uppercase().as_str() {
S3_SCHEMA => DataSource::build_s3_backend(host, &dir, connection)?,
S3_SCHEMA => build_s3_backend(host, &dir, connection)?,
_ => {
return error::UnsupportedBackendProtocolSnafu {
protocol: schema.to_string(),
@@ -348,6 +292,68 @@ impl DataSource {
}
}
pub fn build_s3_backend(
host: Option<&str>,
path: &str,
connection: HashMap<String, String>,
) -> Result<ObjectStore> {
let mut builder = S3::default();
builder.root(path);
if let Some(bucket) = host {
builder.bucket(bucket);
}
if let Some(endpoint) = connection.get(ENDPOINT_URL) {
builder.endpoint(endpoint);
}
if let Some(region) = connection.get(REGION) {
builder.region(region);
}
if let Some(key_id) = connection.get(ACCESS_KEY_ID) {
builder.access_key_id(key_id);
}
if let Some(key) = connection.get(SECRET_ACCESS_KEY) {
builder.secret_access_key(key);
}
if let Some(session_token) = connection.get(SESSION_TOKEN) {
builder.security_token(session_token);
}
if let Some(enable_str) = connection.get(ENABLE_VIRTUAL_HOST_STYLE) {
let enable = enable_str.as_str().parse::<bool>().map_err(|e| {
error::InvalidConnectionSnafu {
msg: format!(
"failed to parse the option {}={}, {}",
ENABLE_VIRTUAL_HOST_STYLE, enable_str, e
),
}
.build()
})?;
if enable {
builder.enable_virtual_host_style();
}
}
let accessor = builder.build().context(error::BuildBackendSnafu)?;
Ok(ObjectStore::new(accessor).finish())
}
pub fn build_fs_backend(root: &str) -> Result<ObjectStore> {
let accessor = Fs::default()
.root(root)
.build()
.context(error::BuildBackendSnafu)?;
Ok(ObjectStore::new(accessor).finish())
}
#[cfg(test)]
mod tests {

View File

@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::env;
use std::sync::Arc;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_query::Output;
use common_recordbatch::util;
use common_telemetry::logging;
use datatypes::data_type::ConcreteDataType;
use datatypes::vectors::{Int64Vector, StringVector, UInt64Vector, VectorRef};
use query::parser::{QueryLanguageParser, QueryStatement};
@@ -797,6 +799,45 @@ async fn test_execute_copy_to() {
assert!(matches!(output, Output::AffectedRows(2)));
}
#[tokio::test(flavor = "multi_thread")]
async fn test_execute_copy_to_s3() {
logging::init_default_ut_logging();
if let Ok(bucket) = env::var("GT_S3_BUCKET") {
if !bucket.is_empty() {
let instance = setup_test_instance("test_execute_copy_to_s3").await;
// setups
execute_sql(
&instance,
"create table demo(host string, cpu double, memory double, ts timestamp time index);",
)
.await;
let output = execute_sql(
&instance,
r#"insert into demo(host, cpu, memory, ts) values
('host1', 66.6, 1024, 1655276557000),
('host2', 88.8, 333.3, 1655276558000)
"#,
)
.await;
assert!(matches!(output, Output::AffectedRows(2)));
let key_id = env::var("GT_S3_ACCESS_KEY_ID").unwrap();
let key = env::var("GT_S3_ACCESS_KEY").unwrap();
let url =
env::var("GT_S3_ENDPOINT_URL").unwrap_or("https://s3.amazonaws.com".to_string());
let root = uuid::Uuid::new_v4().to_string();
// exports
let copy_to_stmt = format!("Copy demo TO 's3://{}/{}/export/demo.parquet' CONNECTION (ACCESS_KEY_ID='{}',SECRET_ACCESS_KEY='{}',ENDPOINT_URL='{}')", bucket, root, key_id, key, url);
let output = execute_sql(&instance, &copy_to_stmt).await;
assert!(matches!(output, Output::AffectedRows(2)));
}
}
}
#[tokio::test(flavor = "multi_thread")]
async fn test_execute_copy_from() {
let instance = setup_test_instance("test_execute_copy_from").await;
@@ -882,6 +923,106 @@ async fn test_execute_copy_from() {
}
}
#[tokio::test(flavor = "multi_thread")]
async fn test_execute_copy_from_s3() {
logging::init_default_ut_logging();
if let Ok(bucket) = env::var("GT_S3_BUCKET") {
if !bucket.is_empty() {
let instance = setup_test_instance("test_execute_copy_from_s3").await;
// setups
execute_sql(
&instance,
"create table demo(host string, cpu double, memory double, ts timestamp time index);",
)
.await;
let output = execute_sql(
&instance,
r#"insert into demo(host, cpu, memory, ts) values
('host1', 66.6, 1024, 1655276557000),
('host2', 88.8, 333.3, 1655276558000)
"#,
)
.await;
assert!(matches!(output, Output::AffectedRows(2)));
// export
let root = uuid::Uuid::new_v4().to_string();
let key_id = env::var("GT_S3_ACCESS_KEY_ID").unwrap();
let key = env::var("GT_S3_ACCESS_KEY").unwrap();
let url =
env::var("GT_S3_ENDPOINT_URL").unwrap_or("https://s3.amazonaws.com".to_string());
let copy_to_stmt = format!("Copy demo TO 's3://{}/{}/export/demo.parquet' CONNECTION (ACCESS_KEY_ID='{}',SECRET_ACCESS_KEY='{}',ENDPOINT_URL='{}')", bucket, root, key_id, key, url);
logging::info!("Copy table to s3: {}", copy_to_stmt);
let output = execute_sql(&instance, &copy_to_stmt).await;
assert!(matches!(output, Output::AffectedRows(2)));
struct Test<'a> {
sql: &'a str,
table_name: &'a str,
}
let tests = [
Test {
sql: &format!(
"Copy with_filename FROM 's3://{}/{}/export/demo.parquet_1_2'",
bucket, root
),
table_name: "with_filename",
},
Test {
sql: &format!("Copy with_path FROM 's3://{}/{}/export/'", bucket, root),
table_name: "with_path",
},
Test {
sql: &format!(
"Copy with_pattern FROM 's3://{}/{}/export/' WITH (PATTERN = 'demo.*')",
bucket, root
),
table_name: "with_pattern",
},
];
for test in tests {
// import
execute_sql(
&instance,
&format!(
"create table {}(host string, cpu double, memory double, ts timestamp time index);",
test.table_name
),
)
.await;
let sql = format!(
"{} CONNECTION (ACCESS_KEY_ID='{}',SECRET_ACCESS_KEY='{}',ENDPOINT_URL='{}')",
test.sql, key_id, key, url
);
logging::info!("Running sql: {}", sql);
let output = execute_sql(&instance, &sql).await;
assert!(matches!(output, Output::AffectedRows(2)));
let output = execute_sql(
&instance,
&format!("select * from {} order by ts", test.table_name),
)
.await;
let expected = "\
+-------+------+--------+---------------------+
| host | cpu | memory | ts |
+-------+------+--------+---------------------+
| host1 | 66.6 | 1024.0 | 2022-06-15T07:02:37 |
| host2 | 88.8 | 333.3 | 2022-06-15T07:02:38 |
+-------+------+--------+---------------------+"
.to_string();
check_output_stream(output, expected).await;
}
}
}
}
#[tokio::test(flavor = "multi_thread")]
async fn test_create_by_procedure() {
common_telemetry::init_default_ut_logging();