feat: add cursor statements (#5094)

* feat: add sql parsers for cursor operations

* feat: cursor operator

* feat: implement RecordBatchStreamCursor

* feat: implement cursor storage and execution

* test: add tests

* chore: update docstring

* feat: add a temporary sql rewrite for casts in LIMIT

this strips e.g. `LIMIT 50::bigint` down to `LIMIT 50` until datafusion is upgraded; the issue is described in #5097

* test: add more sql for cursor integration test

* feat: reject non-select query for cursor statement

* refactor: address review issues

* test: add empty result case

* feat: address review comments
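
Taken together, these commits add a PostgreSQL-style cursor lifecycle (DECLARE / FETCH / CLOSE) over the pg wire protocol. A minimal client-side sketch of the new surface — assuming a running frontend at `addr`, the tokio_postgres crate, and the built-in `numbers` table; the integration test at the bottom of this diff is the authoritative version:

use tokio_postgres::NoTls;

async fn cursor_demo(addr: &str) -> Result<(), tokio_postgres::Error> {
    let (client, conn) =
        tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls).await?;
    // Drive the connection in the background.
    tokio::spawn(conn);

    // DECLARE plans the query and parks its result stream in the session.
    client
        .execute("DECLARE c1 CURSOR FOR SELECT * FROM numbers", &[])
        .await?;
    // FETCH pulls at most N rows from the parked stream.
    let rows = client.query("FETCH 5 FROM c1", &[]).await?;
    assert!(rows.len() <= 5);
    // CLOSE removes the cursor from the session.
    client.execute("CLOSE c1", &[]).await?;
    Ok(())
}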
Author: Ning Sun
Date: 2024-12-06 17:32:22 +08:00
Committed by: GitHub
Parent: 8b944268da
Commit: 3133f3fb4e
21 changed files with 786 additions and 5 deletions

Cargo.lock (generated, 2 changes)
View File

@@ -10987,9 +10987,11 @@ dependencies = [
"common-catalog",
"common-error",
"common-macro",
"common-recordbatch",
"common-telemetry",
"common-time",
"derive_builder 0.12.0",
"derive_more",
"meter-core",
"snafu 0.8.5",
"sql",

View File

@@ -0,0 +1,173 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use futures::StreamExt;
use tokio::sync::Mutex;
use crate::error::Result;
use crate::recordbatch::merge_record_batches;
use crate::{RecordBatch, SendableRecordBatchStream};
struct Inner {
stream: SendableRecordBatchStream,
current_row_index: usize,
current_batch: Option<RecordBatch>,
total_rows_in_current_batch: usize,
}
/// A cursor on RecordBatchStream that fetches data batch by batch
pub struct RecordBatchStreamCursor {
inner: Mutex<Inner>,
}
impl RecordBatchStreamCursor {
pub fn new(stream: SendableRecordBatchStream) -> RecordBatchStreamCursor {
Self {
inner: Mutex::new(Inner {
stream,
current_row_index: 0,
current_batch: None,
total_rows_in_current_batch: 0,
}),
}
}
/// Take up to `size` rows from the `RecordBatchStream` and create a new
/// `RecordBatch` from them. Returns fewer rows, possibly zero, once the
/// stream is exhausted.
pub async fn take(&self, size: usize) -> Result<RecordBatch> {
let mut remaining_rows_to_take = size;
let mut accumulated_rows = Vec::new();
let mut inner = self.inner.lock().await;
while remaining_rows_to_take > 0 {
// Ensure we have a current batch or fetch the next one
if inner.current_batch.is_none()
|| inner.current_row_index >= inner.total_rows_in_current_batch
{
match inner.stream.next().await {
Some(Ok(batch)) => {
inner.total_rows_in_current_batch = batch.num_rows();
inner.current_batch = Some(batch);
inner.current_row_index = 0;
}
Some(Err(e)) => return Err(e),
None => {
// Stream is exhausted
break;
}
}
}
// If we still have no batch after attempting to fetch
let current_batch = match &inner.current_batch {
Some(batch) => batch,
None => break,
};
// Calculate how many rows we can take from this batch
let rows_to_take_from_batch = remaining_rows_to_take
.min(inner.total_rows_in_current_batch - inner.current_row_index);
// Slice the current batch to get the desired rows
let taken_batch =
current_batch.slice(inner.current_row_index, rows_to_take_from_batch)?;
// Add the taken batch to accumulated rows
accumulated_rows.push(taken_batch);
// Update cursor and remaining rows
inner.current_row_index += rows_to_take_from_batch;
remaining_rows_to_take -= rows_to_take_from_batch;
}
// If no rows were accumulated, return empty
if accumulated_rows.is_empty() {
return Ok(RecordBatch::new_empty(inner.stream.schema()));
}
// If only one batch was accumulated, return it directly
if accumulated_rows.len() == 1 {
return Ok(accumulated_rows.remove(0));
}
// Merge multiple batches
merge_record_batches(inner.stream.schema(), &accumulated_rows)
}
}
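A minimal consumer sketch of the cursor above (hedged; `stream` is any existing `SendableRecordBatchStream`, and `Result` is this module's error alias). Note the exhaustion protocol: `take` returns an empty batch, not an error, once the stream runs dry:

async fn drain(stream: SendableRecordBatchStream) -> Result<usize> {
    let cursor = RecordBatchStreamCursor::new(stream);
    let mut total = 0;
    loop {
        // Each call yields up to 1024 rows, merged across batch boundaries.
        let batch = cursor.take(1024).await?;
        if batch.num_rows() == 0 {
            break; // stream exhausted
        }
        total += batch.num_rows();
    }
    Ok(total)
}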
#[cfg(test)]
mod tests {
use std::sync::Arc;
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::StringVector;
use super::*;
use crate::RecordBatches;
#[tokio::test]
async fn test_cursor() {
let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
"a",
ConcreteDataType::string_datatype(),
false,
)]));
let rbs = RecordBatches::try_from_columns(
schema.clone(),
vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
)
.unwrap();
let cursor = RecordBatchStreamCursor::new(rbs.as_stream());
let result_rb = cursor.take(1).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 1);
let result_rb = cursor.take(1).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 1);
let result_rb = cursor.take(1).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 0);
let rb = RecordBatch::new(
schema.clone(),
vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
)
.unwrap();
let rbs2 =
RecordBatches::try_new(schema.clone(), vec![rb.clone(), rb.clone(), rb]).unwrap();
let cursor = RecordBatchStreamCursor::new(rbs2.as_stream());
let result_rb = cursor.take(3).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 3);
let result_rb = cursor.take(2).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 2);
let result_rb = cursor.take(2).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 1);
let result_rb = cursor.take(2).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 0);
let rb = RecordBatch::new(
schema.clone(),
vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
)
.unwrap();
let rbs3 =
RecordBatches::try_new(schema.clone(), vec![rb.clone(), rb.clone(), rb]).unwrap();
let cursor = RecordBatchStreamCursor::new(rbs3.as_stream());
let result_rb = cursor.take(10).await.expect("take from cursor failed");
assert_eq!(result_rb.num_rows(), 6);
}
}

View File

@@ -168,6 +168,13 @@ pub enum Error {
#[snafu(source)]
error: tokio::time::error::Elapsed,
},
#[snafu(display("RecordBatch slice index overflow: {visit_index} > {size}"))]
RecordBatchSliceIndexOverflow {
#[snafu(implicit)]
location: Location,
size: usize,
visit_index: usize,
},
}
impl ErrorExt for Error {
@@ -182,7 +189,8 @@ impl ErrorExt for Error {
| Error::Format { .. }
| Error::ToArrowScalar { .. }
| Error::ProjectArrowRecordBatch { .. }
| Error::PhysicalExpr { .. } => StatusCode::Internal,
| Error::PhysicalExpr { .. }
| Error::RecordBatchSliceIndexOverflow { .. } => StatusCode::Internal,
Error::PollStream { .. } => StatusCode::EngineExecuteQuery,

View File

@@ -15,6 +15,7 @@
#![feature(never_type)]
pub mod adapter;
pub mod cursor;
pub mod error;
pub mod filter;
mod recordbatch;

View File

@@ -23,7 +23,7 @@ use datatypes::value::Value;
use datatypes::vectors::{Helper, VectorRef};
use serde::ser::{Error, SerializeStruct};
use serde::{Serialize, Serializer};
use snafu::{OptionExt, ResultExt};
use snafu::{ensure, OptionExt, ResultExt};
use crate::error::{
self, CastVectorSnafu, ColumnNotExistsSnafu, DataTypesSnafu, ProjectArrowRecordBatchSnafu,
@@ -194,6 +194,19 @@ impl RecordBatch {
.map(|t| t.to_string())
.unwrap_or("failed to pretty display a record batch".to_string())
}
/// Return a record batch sliced from `offset`, containing `len` rows
pub fn slice(&self, offset: usize, len: usize) -> Result<RecordBatch> {
ensure!(
offset + len <= self.num_rows(),
error::RecordBatchSliceIndexOverflowSnafu {
size: self.num_rows(),
visit_index: offset + len
}
);
let columns = self.columns.iter().map(|vector| vector.slice(offset, len));
RecordBatch::new(self.schema.clone(), columns)
}
}
impl Serialize for RecordBatch {
@@ -256,6 +269,36 @@ impl Iterator for RecordBatchRowIterator<'_> {
}
}
/// Merge multiple record batches into a single one
pub fn merge_record_batches(schema: SchemaRef, batches: &[RecordBatch]) -> Result<RecordBatch> {
let batches_len = batches.len();
if batches_len == 0 {
return Ok(RecordBatch::new_empty(schema));
}
let n_rows = batches.iter().map(|b| b.num_rows()).sum();
let n_columns = schema.num_columns();
// Collect arrays from each batch
let mut merged_columns = Vec::with_capacity(n_columns);
for col_idx in 0..n_columns {
let mut acc = schema.column_schemas()[col_idx]
.data_type
.create_mutable_vector(n_rows);
for batch in batches {
let column = batch.column(col_idx);
acc.extend_slice_of(column.as_ref(), 0, column.len())
.context(error::DataTypesSnafu)?;
}
merged_columns.push(acc.to_vector());
}
// Create a new RecordBatch with merged columns
RecordBatch::new(schema, merged_columns)
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
@@ -375,4 +418,80 @@ mod tests {
assert!(record_batch_iter.next().is_none());
}
#[test]
fn test_record_batch_slice() {
let column_schemas = vec![
ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
];
let schema = Arc::new(Schema::new(column_schemas));
let columns: Vec<VectorRef> = vec![
Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
Arc::new(StringVector::from(vec![
None,
Some("hello"),
Some("greptime"),
None,
])),
];
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let recordbatch = recordbatch.slice(1, 2).expect("recordbatch slice");
let mut record_batch_iter = recordbatch.rows();
assert_eq!(
vec![Value::UInt32(2), Value::String("hello".into())],
record_batch_iter
.next()
.unwrap()
.into_iter()
.collect::<Vec<Value>>()
);
assert_eq!(
vec![Value::UInt32(3), Value::String("greptime".into())],
record_batch_iter
.next()
.unwrap()
.into_iter()
.collect::<Vec<Value>>()
);
assert!(record_batch_iter.next().is_none());
assert!(recordbatch.slice(1, 5).is_err());
}
#[test]
fn test_merge_record_batch() {
let column_schemas = vec![
ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
];
let schema = Arc::new(Schema::new(column_schemas));
let columns: Vec<VectorRef> = vec![
Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
Arc::new(StringVector::from(vec![
None,
Some("hello"),
Some("greptime"),
None,
])),
];
let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();
let columns: Vec<VectorRef> = vec![
Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
Arc::new(StringVector::from(vec![
None,
Some("hello"),
Some("greptime"),
None,
])),
];
let recordbatch2 = RecordBatch::new(schema.clone(), columns).unwrap();
let merged = merge_record_batches(schema.clone(), &[recordbatch, recordbatch2])
.expect("merge recordbatch");
assert_eq!(merged.num_rows(), 8);
}
}

View File

@@ -487,7 +487,11 @@ pub fn check_permission(
// TODO(dennis): add a hook for admin commands.
Statement::Admin(_) => {}
// These are executed by query engine, and will be checked there.
Statement::Query(_) | Statement::Explain(_) | Statement::Tql(_) | Statement::Delete(_) => {}
Statement::Query(_)
| Statement::Explain(_)
| Statement::Tql(_)
| Statement::Delete(_)
| Statement::DeclareCursor(_) => {}
// database ops won't be checked
Statement::CreateDatabase(_)
| Statement::ShowDatabases(_)
@@ -580,6 +584,8 @@ pub fn check_permission(
Statement::TruncateTable(stmt) => {
validate_param(stmt.table_name(), query_ctx)?;
}
// Cursor operations are always allowed once the cursor has been created
Statement::FetchCursor(_) | Statement::CloseCursor(_) => {}
}
Ok(())
}

View File

@@ -786,6 +786,12 @@ pub enum Error {
#[snafu(source)]
error: Elapsed,
},
#[snafu(display("Cursor {name} is not found"))]
CursorNotFound { name: String },
#[snafu(display("A cursor named {name} already exists"))]
CursorExists { name: String },
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -825,7 +831,9 @@ impl ErrorExt for Error {
| Error::FunctionArityMismatch { .. }
| Error::InvalidPartition { .. }
| Error::PhysicalExpr { .. }
| Error::InvalidJsonFormat { .. } => StatusCode::InvalidArguments,
| Error::InvalidJsonFormat { .. }
| Error::CursorNotFound { .. }
| Error::CursorExists { .. } => StatusCode::InvalidArguments,
Error::TableAlreadyExists { .. } | Error::ViewAlreadyExists { .. } => {
StatusCode::TableAlreadyExists

View File

@@ -16,6 +16,7 @@ mod admin;
mod copy_database;
mod copy_table_from;
mod copy_table_to;
mod cursor;
mod ddl;
mod describe;
mod dml;
@@ -133,6 +134,16 @@ impl StatementExecutor {
self.plan_exec(QueryStatement::Sql(stmt), query_ctx).await
}
Statement::DeclareCursor(declare_cursor) => {
self.declare_cursor(declare_cursor, query_ctx).await
}
Statement::FetchCursor(fetch_cursor) => {
self.fetch_cursor(fetch_cursor, query_ctx).await
}
Statement::CloseCursor(close_cursor) => {
self.close_cursor(close_cursor, query_ctx).await
}
Statement::Insert(insert) => self.insert(insert, query_ctx).await,
Statement::Tql(tql) => self.execute_tql(tql, query_ctx).await,

View File

@@ -0,0 +1,98 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use common_query::{Output, OutputData};
use common_recordbatch::cursor::RecordBatchStreamCursor;
use common_recordbatch::RecordBatches;
use common_telemetry::tracing;
use query::parser::QueryStatement;
use session::context::QueryContextRef;
use snafu::ResultExt;
use sql::statements::cursor::{CloseCursor, DeclareCursor, FetchCursor};
use sql::statements::statement::Statement;
use crate::error::{self, Result};
use crate::statement::StatementExecutor;
impl StatementExecutor {
#[tracing::instrument(skip_all)]
pub(super) async fn declare_cursor(
&self,
declare_cursor: DeclareCursor,
query_ctx: QueryContextRef,
) -> Result<Output> {
let cursor_name = declare_cursor.cursor_name.to_string();
if query_ctx.get_cursor(&cursor_name).is_some() {
error::CursorExistsSnafu {
name: cursor_name.to_string(),
}
.fail()?;
}
let query_stmt = Statement::Query(declare_cursor.query);
let output = self
.plan_exec(QueryStatement::Sql(query_stmt), query_ctx.clone())
.await?;
match output.data {
OutputData::RecordBatches(rb) => {
let rbs = rb.as_stream();
query_ctx.insert_cursor(cursor_name, RecordBatchStreamCursor::new(rbs));
}
OutputData::Stream(rbs) => {
query_ctx.insert_cursor(cursor_name, RecordBatchStreamCursor::new(rbs));
}
// Should not happen because the parser guarantees a query statement here.
OutputData::AffectedRows(_) => error::NotSupportedSnafu {
feat: "Non-query statement on cursor",
}
.fail()?,
}
Ok(Output::new_with_affected_rows(0))
}
#[tracing::instrument(skip_all)]
pub(super) async fn fetch_cursor(
&self,
fetch_cursor: FetchCursor,
query_ctx: QueryContextRef,
) -> Result<Output> {
let cursor_name = fetch_cursor.cursor_name.to_string();
let fetch_size = fetch_cursor.fetch_size;
if let Some(rb) = query_ctx.get_cursor(&cursor_name) {
let record_batch = rb
.take(fetch_size as usize)
.await
.context(error::BuildRecordBatchSnafu)?;
let record_batches =
RecordBatches::try_new(record_batch.schema.clone(), vec![record_batch])
.context(error::BuildRecordBatchSnafu)?;
Ok(Output::new_with_record_batches(record_batches))
} else {
error::CursorNotFoundSnafu { name: cursor_name }.fail()
}
}
#[tracing::instrument(skip_all)]
pub(super) async fn close_cursor(
&self,
close_cursor: CloseCursor,
query_ctx: QueryContextRef,
) -> Result<Output> {
query_ctx.remove_cursor(&close_cursor.cursor_name.to_string());
Ok(Output::new_with_affected_rows(0))
}
}

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;
@@ -112,6 +113,13 @@ pub(crate) fn process<'a>(query: &str, query_ctx: QueryContextRef) -> Option<Vec
}
}
static LIMIT_CAST_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)(LIMIT\\s+\\d+)::bigint").unwrap());
pub(crate) fn rewrite_sql(query: &str) -> Cow<'_, str> {
// TODO(sunng87): remove this once datafusion is upgraded to 43 or newer
LIMIT_CAST_PATTERN.replace_all(query, "$1")
}
#[cfg(test)]
mod test {
use session::context::{QueryContext, QueryContextRef};
@@ -195,4 +203,13 @@ mod test {
assert!(process("SHOW TABLES ", query_context.clone()).is_none());
assert!(process("SET TIME_ZONE=utc ", query_context.clone()).is_none());
}
#[test]
fn test_rewrite() {
let sql = "SELECT * FROM number LIMIT 1::bigint";
let sql2 = "SELECT * FROM number limit 1::BIGINT";
assert_eq!("SELECT * FROM number LIMIT 1", rewrite_sql(sql));
assert_eq!("SELECT * FROM number limit 1", rewrite_sql(sql2));
}
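    // The pattern is deliberately narrow: only a literal integer LIMIT
    // followed by `::bigint` is rewritten. A hedged sketch of the boundary
    // (hypothetical extra test, not part of the original commit):
    #[test]
    fn test_rewrite_passthrough() {
        // Parameterized LIMIT: `$1` is not a numeric literal, so no rewrite.
        let sql = "SELECT * FROM number LIMIT $1::bigint";
        assert_eq!(sql, rewrite_sql(sql));
        // The cast sits on OFFSET, not LIMIT, so the pattern does not apply.
        let sql = "SELECT * FROM number LIMIT 10 OFFSET 2::bigint";
        assert_eq!(sql, rewrite_sql(sql));
    }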
}

View File

@@ -70,6 +70,9 @@ impl SimpleQueryHandler for PostgresServerHandlerInner {
return Ok(vec![Response::EmptyQuery]);
}
let query = fixtures::rewrite_sql(query);
let query = query.as_ref();
if let Some(resps) = fixtures::process(query, query_ctx.clone()) {
send_warning_opt(client, query_ctx).await?;
Ok(resps)
@@ -229,6 +232,9 @@ impl QueryParser for DefaultQueryParser {
});
}
let sql = fixtures::rewrite_sql(sql);
let sql = sql.as_ref();
let mut stmts =
ParserContext::create_with_dialect(sql, &PostgreSqlDialect {}, ParseOptions::default())
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

View File

@@ -17,9 +17,11 @@ auth.workspace = true
common-catalog.workspace = true
common-error.workspace = true
common-macro.workspace = true
common-recordbatch.workspace = true
common-telemetry.workspace = true
common-time.workspace = true
derive_builder.workspace = true
derive_more = { version = "1", default-features = false, features = ["debug"] }
meter-core.workspace = true
snafu.workspace = true
sql.workspace = true

View File

@@ -23,6 +23,8 @@ use arc_swap::ArcSwap;
use auth::UserInfoRef;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_catalog::{build_db_string, parse_catalog_and_schema_from_db_string};
use common_recordbatch::cursor::RecordBatchStreamCursor;
use common_telemetry::warn;
use common_time::timezone::parse_timezone;
use common_time::Timezone;
use derive_builder::Builder;
@@ -34,6 +36,8 @@ use crate::MutableInner;
pub type QueryContextRef = Arc<QueryContext>;
pub type ConnInfoRef = Arc<ConnInfo>;
const CURSOR_COUNT_WARNING_LIMIT: usize = 10;
#[derive(Debug, Builder, Clone)]
#[builder(pattern = "owned")]
#[builder(build_fn(skip))]
@@ -299,6 +303,27 @@ impl QueryContext {
pub fn set_query_timeout(&self, timeout: Duration) {
self.mutable_session_data.write().unwrap().query_timeout = Some(timeout);
}
pub fn insert_cursor(&self, name: String, rb: RecordBatchStreamCursor) {
let mut guard = self.mutable_session_data.write().unwrap();
guard.cursors.insert(name, Arc::new(rb));
let cursor_count = guard.cursors.len();
if cursor_count > CURSOR_COUNT_WARNING_LIMIT {
warn!("Current connection has {} open cursors", cursor_count);
}
}
pub fn remove_cursor(&self, name: &str) {
let mut guard = self.mutable_session_data.write().unwrap();
guard.cursors.remove(name);
}
pub fn get_cursor(&self, name: &str) -> Option<Arc<RecordBatchStreamCursor>> {
let guard = self.mutable_session_data.read().unwrap();
let rb = guard.cursors.get(name);
rb.cloned()
}
}
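A sketch of the registry flow these methods enable (hedged; `ctx` is any `QueryContextRef` and `cursor` an already-built `RecordBatchStreamCursor`, mirroring what the statement executor does):

// DECLARE: register under the client-chosen name; a warning is logged once
// a single connection holds more than CURSOR_COUNT_WARNING_LIMIT cursors.
ctx.insert_cursor("c1".to_string(), cursor);
// FETCH: the lookup clones the Arc, so the session lock is not held while
// the cursor pulls rows.
if let Some(c) = ctx.get_cursor("c1") {
    let _batch = c.take(5).await?;
}
// CLOSE: drop it from the session; closing the connection drops the rest.
ctx.remove_cursor("c1");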
impl QueryContextBuilder {

View File

@@ -16,6 +16,7 @@ pub mod context;
pub mod session_config;
pub mod table_name;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::Duration;
@@ -23,9 +24,11 @@ use std::time::Duration;
use auth::UserInfoRef;
use common_catalog::build_db_string;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_recordbatch::cursor::RecordBatchStreamCursor;
use common_time::timezone::get_timezone;
use common_time::Timezone;
use context::{ConfigurationVariables, QueryContextBuilder};
use derive_more::Debug;
use crate::context::{Channel, ConnInfo, QueryContextRef};
@@ -47,6 +50,8 @@ pub(crate) struct MutableInner {
user_info: UserInfoRef,
timezone: Timezone,
query_timeout: Option<Duration>,
#[debug(skip)]
pub(crate) cursors: HashMap<String, Arc<RecordBatchStreamCursor>>,
}
impl Default for MutableInner {
@@ -56,6 +61,7 @@ impl Default for MutableInner {
user_info: auth::userinfo_by_name(None),
timezone: get_timezone(None).clone(),
query_timeout: None,
cursors: HashMap::with_capacity(0),
}
}
}

View File

@@ -167,6 +167,12 @@ impl ParserContext<'_> {
self.parse_tql()
}
Keyword::DECLARE => self.parse_declare_cursor(),
Keyword::FETCH => self.parse_fetch_cursor(),
Keyword::CLOSE => self.parse_close_cursor(),
Keyword::USE => {
let _ = self.parser.next_token();

View File

@@ -16,6 +16,7 @@ pub(crate) mod admin_parser;
mod alter_parser;
pub(crate) mod copy_parser;
pub(crate) mod create_parser;
pub(crate) mod cursor_parser;
pub(crate) mod deallocate_parser;
pub(crate) mod delete_parser;
pub(crate) mod describe_parser;

View File

@@ -0,0 +1,157 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use snafu::{ensure, ResultExt};
use sqlparser::keywords::Keyword;
use sqlparser::tokenizer::Token;
use crate::error::{self, Result};
use crate::parser::ParserContext;
use crate::statements::cursor::{CloseCursor, DeclareCursor, FetchCursor};
use crate::statements::statement::Statement;
impl ParserContext<'_> {
pub(crate) fn parse_declare_cursor(&mut self) -> Result<Statement> {
let _ = self.parser.expect_keyword(Keyword::DECLARE);
let cursor_name = self
.parser
.parse_object_name(false)
.context(error::SyntaxSnafu)?;
let _ = self
.parser
.expect_keywords(&[Keyword::CURSOR, Keyword::FOR]);
let mut is_select = false;
if let Token::Word(w) = self.parser.peek_token().token {
match w.keyword {
Keyword::SELECT | Keyword::WITH => {
is_select = true;
}
_ => {}
}
};
ensure!(
is_select,
error::InvalidSqlSnafu {
msg: "Expect select query in cursor statement".to_string(),
}
);
let query_stmt = self.parse_query()?;
match query_stmt {
Statement::Query(query) => Ok(Statement::DeclareCursor(DeclareCursor {
cursor_name: ParserContext::canonicalize_object_name(cursor_name),
query,
})),
_ => error::InvalidSqlSnafu {
msg: format!("Expect query, found {}", query_stmt),
}
.fail(),
}
}
pub(crate) fn parse_fetch_cursor(&mut self) -> Result<Statement> {
let _ = self.parser.expect_keyword(Keyword::FETCH);
let fetch_size = self
.parser
.parse_literal_uint()
.context(error::SyntaxSnafu)?;
let _ = self.parser.parse_keyword(Keyword::FROM);
let cursor_name = self
.parser
.parse_object_name(false)
.context(error::SyntaxSnafu)?;
Ok(Statement::FetchCursor(FetchCursor {
cursor_name: ParserContext::canonicalize_object_name(cursor_name),
fetch_size,
}))
}
pub(crate) fn parse_close_cursor(&mut self) -> Result<Statement> {
let _ = self.parser.expect_keyword(Keyword::CLOSE);
let cursor_name = self
.parser
.parse_object_name(false)
.context(error::SyntaxSnafu)?;
Ok(Statement::CloseCursor(CloseCursor {
cursor_name: ParserContext::canonicalize_object_name(cursor_name),
}))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dialect::GreptimeDbDialect;
use crate::parser::ParseOptions;
#[test]
fn test_parse_declare_cursor() {
let sql = "DECLARE c1 CURSOR FOR\nSELECT * FROM numbers";
let result =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
.unwrap();
if let Statement::DeclareCursor(dc) = &result[0] {
assert_eq!("c1", dc.cursor_name.to_string());
assert_eq!(
"DECLARE c1 CURSOR FOR SELECT * FROM numbers",
dc.to_string()
);
} else {
panic!("Unexpected statement");
}
let sql = "DECLARE c1 CURSOR FOR\nINSERT INTO numbers VALUES (1);";
let result =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
assert!(result.is_err());
}
#[test]
fn test_parse_fetch_cursor() {
let sql = "FETCH 1000 FROM c1";
let result =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
.unwrap();
if let Statement::FetchCursor(fc) = &result[0] {
assert_eq!("c1", fc.cursor_name.to_string());
assert_eq!("1000", fc.fetch_size.to_string());
assert_eq!(sql, fc.to_string());
} else {
panic!("Unexpected statement")
}
}
#[test]
fn test_parse_close_cursor() {
let sql = "CLOSE c1";
let result =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
.unwrap();
if let Statement::CloseCursor(cc) = &result[0] {
assert_eq!("c1", cc.cursor_name.to_string());
assert_eq!(sql, cc.to_string());
} else {
panic!("Unexpected statement")
}
}
}

View File

@@ -16,6 +16,7 @@ pub mod admin;
pub mod alter;
pub mod copy;
pub mod create;
pub mod cursor;
pub mod delete;
pub mod describe;
pub mod drop;
@@ -224,7 +225,7 @@ pub fn sql_number_to_value(data_type: &ConcreteDataType, n: &str) -> Result<Valu
// TODO(hl): also Date/DateTime
}
fn parse_sql_number<R: FromStr + std::fmt::Debug>(n: &str) -> Result<R>
pub(crate) fn parse_sql_number<R: FromStr + std::fmt::Debug>(n: &str) -> Result<R>
where
<R as FromStr>::Err: std::fmt::Debug,
{

View File

@@ -0,0 +1,60 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fmt::Display;
use sqlparser::ast::ObjectName;
use sqlparser_derive::{Visit, VisitMut};
use super::query::Query;
/// Represents a DECLARE CURSOR statement
///
/// This statement carries the SQL query the cursor iterates over
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut)]
pub struct DeclareCursor {
pub cursor_name: ObjectName,
pub query: Box<Query>,
}
impl Display for DeclareCursor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "DECLARE {} CURSOR FOR {}", self.cursor_name, self.query)
}
}
/// Represents a FETCH FROM cursor statement
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut)]
pub struct FetchCursor {
pub cursor_name: ObjectName,
pub fetch_size: u64,
}
impl Display for FetchCursor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "FETCH {} FROM {}", self.fetch_size, self.cursor_name)
}
}
/// Represents a CLOSE cursor statement
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut)]
pub struct CloseCursor {
pub cursor_name: ObjectName,
}
impl Display for CloseCursor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "CLOSE {}", self.cursor_name)
}
}

View File

@@ -24,6 +24,7 @@ use crate::statements::alter::{AlterDatabase, AlterTable};
use crate::statements::create::{
CreateDatabase, CreateExternalTable, CreateFlow, CreateTable, CreateTableLike, CreateView,
};
use crate::statements::cursor::{CloseCursor, DeclareCursor, FetchCursor};
use crate::statements::delete::Delete;
use crate::statements::describe::DescribeTable;
use crate::statements::drop::{DropDatabase, DropFlow, DropTable, DropView};
@@ -118,6 +119,12 @@ pub enum Statement {
Use(String),
// Admin statement(extension)
Admin(Admin),
// DECLARE ... CURSOR FOR ...
DeclareCursor(DeclareCursor),
// FETCH ... FROM ...
FetchCursor(FetchCursor),
// CLOSE
CloseCursor(CloseCursor),
}
impl Display for Statement {
@@ -165,6 +172,9 @@ impl Display for Statement {
Statement::CreateView(s) => s.fmt(f),
Statement::Use(s) => s.fmt(f),
Statement::Admin(admin) => admin.fmt(f),
Statement::DeclareCursor(s) => s.fmt(f),
Statement::FetchCursor(s) => s.fmt(f),
Statement::CloseCursor(s) => s.fmt(f),
}
}
}

View File

@@ -72,6 +72,7 @@ macro_rules! sql_tests {
test_postgres_parameter_inference,
test_postgres_array_types,
test_mysql_prepare_stmt_insert_timestamp,
test_declare_fetch_close_cursor,
);
)*
};
@@ -1198,3 +1199,66 @@ pub async fn test_postgres_array_types(store_type: StorageType) {
let _ = fe_pg_server.shutdown().await;
guard.remove_all().await;
}
pub async fn test_declare_fetch_close_cursor(store_type: StorageType) {
let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_inference").await;
let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls)
.await
.unwrap();
let (tx, rx) = tokio::sync::oneshot::channel();
tokio::spawn(async move {
connection.await.unwrap();
tx.send(()).unwrap();
});
client
.execute(
"DECLARE c1 CURSOR FOR SELECT * FROM numbers WHERE number > 2 LIMIT 50::bigint",
&[],
)
.await
.expect("declare cursor");
// duplicated cursor
assert!(client
.execute("DECLARE c1 CURSOR FOR SELECT 1", &[],)
.await
.is_err());
let rows = client.query("FETCH 5 FROM c1", &[]).await.unwrap();
assert_eq!(5, rows.len());
let rows = client.query("FETCH 100 FROM c1", &[]).await.unwrap();
assert_eq!(45, rows.len());
let rows = client.query("FETCH 100 FROM c1", &[]).await.unwrap();
assert_eq!(0, rows.len());
client.execute("CLOSE c1", &[]).await.expect("close cursor");
// cursor not found
let result = client.query("FETCH 100 FROM c1", &[]).await;
assert!(result.is_err());
client
.execute(
"DECLARE c2 CURSOR FOR SELECT * FROM numbers WHERE number < 0",
&[],
)
.await
.expect("declare cursor");
let rows = client.query("FETCH 5 FROM c2", &[]).await.unwrap();
assert_eq!(0, rows.len());
client.execute("CLOSE c2", &[]).await.expect("close cursor");
// Shutdown the client.
drop(client);
rx.await.unwrap();
let _ = fe_pg_server.shutdown().await;
guard.remove_all().await;
}