Mirror of https://github.com/GreptimeTeam/greptimedb.git, synced 2026-01-10 07:12:54 +00:00
feat: add cursor statements (#5094)
* feat: add sql parsers for cursor operations
* feat: cursor operator
* feat: implement RecordBatchStreamCursor
* feat: implement cursor storage and execution
* test: add tests
* chore: update docstring
* feat: add a temporary sql rewrite for cast in limit (this issue is described in #5097)
* test: add more sql for cursor integration test
* feat: reject non-select query for cursor statement
* refactor: address review issues
* test: add empty result case
* feat: address review comments
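As context for the statements this change introduces, here is a minimal sketch of the cursor lifecycle over the Postgres wire protocol, using the tokio_postgres client that the integration test at the end of this diff also uses. The endpoint, port, and database name are assumptions for illustration, not part of this change:

use tokio_postgres::NoTls;

#[tokio::main]
async fn main() -> Result<(), tokio_postgres::Error> {
    // Assumed endpoint: a locally running GreptimeDB frontend with the
    // Postgres protocol enabled (the port is an assumption here).
    let (client, connection) =
        tokio_postgres::connect("postgres://127.0.0.1:4003/public", NoTls).await?;
    tokio::spawn(connection);

    // Declare a cursor over a SELECT, drain it in fixed-size chunks, close it.
    client
        .execute("DECLARE c1 CURSOR FOR SELECT * FROM numbers", &[])
        .await?;
    loop {
        let rows = client.query("FETCH 100 FROM c1", &[]).await?;
        if rows.is_empty() {
            break; // cursor exhausted
        }
        println!("fetched {} rows", rows.len());
    }
    client.execute("CLOSE c1", &[]).await?;
    Ok(())
}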
Cargo.lock | 2 (generated)
@@ -10987,9 +10987,11 @@ dependencies = [
  "common-catalog",
  "common-error",
  "common-macro",
+ "common-recordbatch",
  "common-telemetry",
  "common-time",
  "derive_builder 0.12.0",
+ "derive_more",
  "meter-core",
  "snafu 0.8.5",
  "sql",
src/common/recordbatch/src/cursor.rs | 173 (new file)
@@ -0,0 +1,173 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use futures::StreamExt;
use tokio::sync::Mutex;

use crate::error::Result;
use crate::recordbatch::merge_record_batches;
use crate::{RecordBatch, SendableRecordBatchStream};

struct Inner {
    stream: SendableRecordBatchStream,
    current_row_index: usize,
    current_batch: Option<RecordBatch>,
    total_rows_in_current_batch: usize,
}

/// A cursor on RecordBatchStream that fetches data batch by batch
pub struct RecordBatchStreamCursor {
    inner: Mutex<Inner>,
}

impl RecordBatchStreamCursor {
    pub fn new(stream: SendableRecordBatchStream) -> RecordBatchStreamCursor {
        Self {
            inner: Mutex::new(Inner {
                stream,
                current_row_index: 0,
                current_batch: None,
                total_rows_in_current_batch: 0,
            }),
        }
    }

    /// Take `size` rows from the `RecordBatchStream` and create a new
    /// `RecordBatch` for these rows.
    pub async fn take(&self, size: usize) -> Result<RecordBatch> {
        let mut remaining_rows_to_take = size;
        let mut accumulated_rows = Vec::new();

        let mut inner = self.inner.lock().await;

        while remaining_rows_to_take > 0 {
            // Ensure we have a current batch or fetch the next one
            if inner.current_batch.is_none()
                || inner.current_row_index >= inner.total_rows_in_current_batch
            {
                match inner.stream.next().await {
                    Some(Ok(batch)) => {
                        inner.total_rows_in_current_batch = batch.num_rows();
                        inner.current_batch = Some(batch);
                        inner.current_row_index = 0;
                    }
                    Some(Err(e)) => return Err(e),
                    None => {
                        // Stream is exhausted
                        break;
                    }
                }
            }

            // If we still have no batch after attempting to fetch
            let current_batch = match &inner.current_batch {
                Some(batch) => batch,
                None => break,
            };

            // Calculate how many rows we can take from this batch
            let rows_to_take_from_batch = remaining_rows_to_take
                .min(inner.total_rows_in_current_batch - inner.current_row_index);

            // Slice the current batch to get the desired rows
            let taken_batch =
                current_batch.slice(inner.current_row_index, rows_to_take_from_batch)?;

            // Add the taken batch to accumulated rows
            accumulated_rows.push(taken_batch);

            // Update cursor and remaining rows
            inner.current_row_index += rows_to_take_from_batch;
            remaining_rows_to_take -= rows_to_take_from_batch;
        }

        // If no rows were accumulated, return empty
        if accumulated_rows.is_empty() {
            return Ok(RecordBatch::new_empty(inner.stream.schema()));
        }

        // If only one batch was accumulated, return it directly
        if accumulated_rows.len() == 1 {
            return Ok(accumulated_rows.remove(0));
        }

        // Merge multiple batches
        merge_record_batches(inner.stream.schema(), &accumulated_rows)
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::StringVector;

    use super::*;
    use crate::RecordBatches;

    #[tokio::test]
    async fn test_cursor() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "a",
            ConcreteDataType::string_datatype(),
            false,
        )]));

        let rbs = RecordBatches::try_from_columns(
            schema.clone(),
            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
        )
        .unwrap();

        let cursor = RecordBatchStreamCursor::new(rbs.as_stream());
        let result_rb = cursor.take(1).await.expect("take from cursor failed");
        assert_eq!(result_rb.num_rows(), 1);

        let result_rb = cursor.take(1).await.expect("take from cursor failed");
        assert_eq!(result_rb.num_rows(), 1);

        let result_rb = cursor.take(1).await.expect("take from cursor failed");
        assert_eq!(result_rb.num_rows(), 0);

        let rb = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
        )
        .unwrap();
        let rbs2 =
            RecordBatches::try_new(schema.clone(), vec![rb.clone(), rb.clone(), rb]).unwrap();
        let cursor = RecordBatchStreamCursor::new(rbs2.as_stream());
        let result_rb = cursor.take(3).await.expect("take from cursor failed");
        assert_eq!(result_rb.num_rows(), 3);
        let result_rb = cursor.take(2).await.expect("take from cursor failed");
        assert_eq!(result_rb.num_rows(), 2);
        let result_rb = cursor.take(2).await.expect("take from cursor failed");
        assert_eq!(result_rb.num_rows(), 1);
        let result_rb = cursor.take(2).await.expect("take from cursor failed");
        assert_eq!(result_rb.num_rows(), 0);

        let rb = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
        )
        .unwrap();
        let rbs3 =
            RecordBatches::try_new(schema.clone(), vec![rb.clone(), rb.clone(), rb]).unwrap();
        let cursor = RecordBatchStreamCursor::new(rbs3.as_stream());
        let result_rb = cursor.take(10).await.expect("take from cursor failed");
        assert_eq!(result_rb.num_rows(), 6);
    }
}
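The take() implementation above carries a partially consumed batch across calls: it slices what it can from the current batch and only pulls the next batch from the stream when the current one is exhausted. A self-contained model of that chunking logic, with plain Vec<i32> batches standing in for RecordBatch (ToyCursor and all names here are illustrative, not part of the change):

// A minimal model of the take() loop: drain items from an iterator of
// batches in chunks of `size`, carrying a partially consumed batch
// (and an index into it) across calls.
struct ToyCursor {
    batches: std::vec::IntoIter<Vec<i32>>,
    current: Vec<i32>,
    index: usize,
}

impl ToyCursor {
    fn take(&mut self, size: usize) -> Vec<i32> {
        let mut out = Vec::new();
        while out.len() < size {
            // Current batch exhausted: pull the next one from the "stream".
            if self.index >= self.current.len() {
                match self.batches.next() {
                    Some(batch) => {
                        self.current = batch;
                        self.index = 0;
                    }
                    None => break, // stream exhausted
                }
            }
            // Take as many rows as this batch can still provide.
            let n = (size - out.len()).min(self.current.len() - self.index);
            out.extend_from_slice(&self.current[self.index..self.index + n]);
            self.index += n;
        }
        out
    }
}

fn main() {
    let mut c = ToyCursor {
        batches: vec![vec![1, 2], vec![3, 4], vec![5, 6]].into_iter(),
        current: Vec::new(),
        index: 0,
    };
    assert_eq!(vec![1, 2, 3], c.take(3)); // crosses a batch boundary
    assert_eq!(vec![4, 5, 6], c.take(10)); // drains the rest
}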
@@ -168,6 +168,13 @@ pub enum Error {
         #[snafu(source)]
         error: tokio::time::error::Elapsed,
     },
+    #[snafu(display("RecordBatch slice index overflow: {visit_index} > {size}"))]
+    RecordBatchSliceIndexOverflow {
+        #[snafu(implicit)]
+        location: Location,
+        size: usize,
+        visit_index: usize,
+    },
 }

 impl ErrorExt for Error {
@@ -182,7 +189,8 @@ impl ErrorExt for Error {
             | Error::Format { .. }
             | Error::ToArrowScalar { .. }
             | Error::ProjectArrowRecordBatch { .. }
-            | Error::PhysicalExpr { .. } => StatusCode::Internal,
+            | Error::PhysicalExpr { .. }
+            | Error::RecordBatchSliceIndexOverflow { .. } => StatusCode::Internal,

             Error::PollStream { .. } => StatusCode::EngineExecuteQuery,
@@ -15,6 +15,7 @@
 #![feature(never_type)]

 pub mod adapter;
+pub mod cursor;
 pub mod error;
 pub mod filter;
 mod recordbatch;
@@ -23,7 +23,7 @@ use datatypes::value::Value;
 use datatypes::vectors::{Helper, VectorRef};
 use serde::ser::{Error, SerializeStruct};
 use serde::{Serialize, Serializer};
-use snafu::{OptionExt, ResultExt};
+use snafu::{ensure, OptionExt, ResultExt};

 use crate::error::{
     self, CastVectorSnafu, ColumnNotExistsSnafu, DataTypesSnafu, ProjectArrowRecordBatchSnafu,
@@ -194,6 +194,19 @@ impl RecordBatch {
             .map(|t| t.to_string())
             .unwrap_or("failed to pretty display a record batch".to_string())
     }
+
+    /// Return a sliced record batch that starts from `offset`, with `len` rows
+    pub fn slice(&self, offset: usize, len: usize) -> Result<RecordBatch> {
+        ensure!(
+            offset + len <= self.num_rows(),
+            error::RecordBatchSliceIndexOverflowSnafu {
+                size: self.num_rows(),
+                visit_index: offset + len
+            }
+        );
+        let columns = self.columns.iter().map(|vector| vector.slice(offset, len));
+        RecordBatch::new(self.schema.clone(), columns)
+    }
 }

 impl Serialize for RecordBatch {
@@ -256,6 +269,36 @@ impl Iterator for RecordBatchRowIterator<'_> {
     }
 }

+/// Merge multiple record batches into a single one
+pub fn merge_record_batches(schema: SchemaRef, batches: &[RecordBatch]) -> Result<RecordBatch> {
+    let batches_len = batches.len();
+    if batches_len == 0 {
+        return Ok(RecordBatch::new_empty(schema));
+    }
+
+    let n_rows = batches.iter().map(|b| b.num_rows()).sum();
+    let n_columns = schema.num_columns();
+    // Collect arrays from each batch
+    let mut merged_columns = Vec::with_capacity(n_columns);
+
+    for col_idx in 0..n_columns {
+        let mut acc = schema.column_schemas()[col_idx]
+            .data_type
+            .create_mutable_vector(n_rows);
+
+        for batch in batches {
+            let column = batch.column(col_idx);
+            acc.extend_slice_of(column.as_ref(), 0, column.len())
+                .context(error::DataTypesSnafu)?;
+        }
+
+        merged_columns.push(acc.to_vector());
+    }
+
+    // Create a new RecordBatch with merged columns
+    RecordBatch::new(schema, merged_columns)
+}
+
 #[cfg(test)]
 mod tests {
     use std::sync::Arc;
@@ -375,4 +418,80 @@ mod tests {

         assert!(record_batch_iter.next().is_none());
     }
+
+    #[test]
+    fn test_record_batch_slice() {
+        let column_schemas = vec![
+            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
+            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
+        ];
+        let schema = Arc::new(Schema::new(column_schemas));
+        let columns: Vec<VectorRef> = vec![
+            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
+            Arc::new(StringVector::from(vec![
+                None,
+                Some("hello"),
+                Some("greptime"),
+                None,
+            ])),
+        ];
+        let recordbatch = RecordBatch::new(schema, columns).unwrap();
+        let recordbatch = recordbatch.slice(1, 2).expect("recordbatch slice");
+        let mut record_batch_iter = recordbatch.rows();
+        assert_eq!(
+            vec![Value::UInt32(2), Value::String("hello".into())],
+            record_batch_iter
+                .next()
+                .unwrap()
+                .into_iter()
+                .collect::<Vec<Value>>()
+        );
+
+        assert_eq!(
+            vec![Value::UInt32(3), Value::String("greptime".into())],
+            record_batch_iter
+                .next()
+                .unwrap()
+                .into_iter()
+                .collect::<Vec<Value>>()
+        );
+
+        assert!(record_batch_iter.next().is_none());
+
+        assert!(recordbatch.slice(1, 5).is_err());
+    }
+
+    #[test]
+    fn test_merge_record_batch() {
+        let column_schemas = vec![
+            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
+            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
+        ];
+        let schema = Arc::new(Schema::new(column_schemas));
+        let columns: Vec<VectorRef> = vec![
+            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
+            Arc::new(StringVector::from(vec![
+                None,
+                Some("hello"),
+                Some("greptime"),
+                None,
+            ])),
+        ];
+        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();
+
+        let columns: Vec<VectorRef> = vec![
+            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
+            Arc::new(StringVector::from(vec![
+                None,
+                Some("hello"),
+                Some("greptime"),
+                None,
+            ])),
+        ];
+        let recordbatch2 = RecordBatch::new(schema.clone(), columns).unwrap();
+
+        let merged = merge_record_batches(schema.clone(), &[recordbatch, recordbatch2])
+            .expect("merge recordbatch");
+        assert_eq!(merged.num_rows(), 8);
+    }
 }
@@ -487,7 +487,11 @@ pub fn check_permission(
         // TODO(dennis): add a hook for admin commands.
         Statement::Admin(_) => {}
         // These are executed by query engine, and will be checked there.
-        Statement::Query(_) | Statement::Explain(_) | Statement::Tql(_) | Statement::Delete(_) => {}
+        Statement::Query(_)
+        | Statement::Explain(_)
+        | Statement::Tql(_)
+        | Statement::Delete(_)
+        | Statement::DeclareCursor(_) => {}
         // database ops won't be checked
         Statement::CreateDatabase(_)
         | Statement::ShowDatabases(_)
@@ -580,6 +584,8 @@ pub fn check_permission(
         Statement::TruncateTable(stmt) => {
             validate_param(stmt.table_name(), query_ctx)?;
         }
+        // cursor operations are always allowed once the cursor is created
+        Statement::FetchCursor(_) | Statement::CloseCursor(_) => {}
     }
     Ok(())
 }
@@ -786,6 +786,12 @@ pub enum Error {
         #[snafu(source)]
         error: Elapsed,
     },
+
+    #[snafu(display("Cursor {name} is not found"))]
+    CursorNotFound { name: String },
+
+    #[snafu(display("A cursor named {name} already exists"))]
+    CursorExists { name: String },
 }

 pub type Result<T> = std::result::Result<T, Error>;
@@ -825,7 +831,9 @@ impl ErrorExt for Error {
             | Error::FunctionArityMismatch { .. }
             | Error::InvalidPartition { .. }
             | Error::PhysicalExpr { .. }
-            | Error::InvalidJsonFormat { .. } => StatusCode::InvalidArguments,
+            | Error::InvalidJsonFormat { .. }
+            | Error::CursorNotFound { .. }
+            | Error::CursorExists { .. } => StatusCode::InvalidArguments,

             Error::TableAlreadyExists { .. } | Error::ViewAlreadyExists { .. } => {
                 StatusCode::TableAlreadyExists
@@ -16,6 +16,7 @@ mod admin;
 mod copy_database;
 mod copy_table_from;
 mod copy_table_to;
+mod cursor;
 mod ddl;
 mod describe;
 mod dml;
@@ -133,6 +134,16 @@ impl StatementExecutor {
                 self.plan_exec(QueryStatement::Sql(stmt), query_ctx).await
             }

+            Statement::DeclareCursor(declare_cursor) => {
+                self.declare_cursor(declare_cursor, query_ctx).await
+            }
+            Statement::FetchCursor(fetch_cursor) => {
+                self.fetch_cursor(fetch_cursor, query_ctx).await
+            }
+            Statement::CloseCursor(close_cursor) => {
+                self.close_cursor(close_cursor, query_ctx).await
+            }
+
             Statement::Insert(insert) => self.insert(insert, query_ctx).await,

             Statement::Tql(tql) => self.execute_tql(tql, query_ctx).await,
src/operator/src/statement/cursor.rs | 98 (new file)
@@ -0,0 +1,98 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use common_query::{Output, OutputData};
use common_recordbatch::cursor::RecordBatchStreamCursor;
use common_recordbatch::RecordBatches;
use common_telemetry::tracing;
use query::parser::QueryStatement;
use session::context::QueryContextRef;
use snafu::ResultExt;
use sql::statements::cursor::{CloseCursor, DeclareCursor, FetchCursor};
use sql::statements::statement::Statement;

use crate::error::{self, Result};
use crate::statement::StatementExecutor;

impl StatementExecutor {
    #[tracing::instrument(skip_all)]
    pub(super) async fn declare_cursor(
        &self,
        declare_cursor: DeclareCursor,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        let cursor_name = declare_cursor.cursor_name.to_string();

        if query_ctx.get_cursor(&cursor_name).is_some() {
            error::CursorExistsSnafu {
                name: cursor_name.to_string(),
            }
            .fail()?;
        }

        let query_stmt = Statement::Query(declare_cursor.query);

        let output = self
            .plan_exec(QueryStatement::Sql(query_stmt), query_ctx.clone())
            .await?;
        match output.data {
            OutputData::RecordBatches(rb) => {
                let rbs = rb.as_stream();
                query_ctx.insert_cursor(cursor_name, RecordBatchStreamCursor::new(rbs));
            }
            OutputData::Stream(rbs) => {
                query_ctx.insert_cursor(cursor_name, RecordBatchStreamCursor::new(rbs));
            }
            // Should not happen because the parser ensures the query type.
            OutputData::AffectedRows(_) => error::NotSupportedSnafu {
                feat: "Non-query statement on cursor",
            }
            .fail()?,
        }

        Ok(Output::new_with_affected_rows(0))
    }

    #[tracing::instrument(skip_all)]
    pub(super) async fn fetch_cursor(
        &self,
        fetch_cursor: FetchCursor,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        let cursor_name = fetch_cursor.cursor_name.to_string();
        let fetch_size = fetch_cursor.fetch_size;
        if let Some(rb) = query_ctx.get_cursor(&cursor_name) {
            let record_batch = rb
                .take(fetch_size as usize)
                .await
                .context(error::BuildRecordBatchSnafu)?;
            let record_batches =
                RecordBatches::try_new(record_batch.schema.clone(), vec![record_batch])
                    .context(error::BuildRecordBatchSnafu)?;
            Ok(Output::new_with_record_batches(record_batches))
        } else {
            error::CursorNotFoundSnafu { name: cursor_name }.fail()
        }
    }

    #[tracing::instrument(skip_all)]
    pub(super) async fn close_cursor(
        &self,
        close_cursor: CloseCursor,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        query_ctx.remove_cursor(&close_cursor.cursor_name.to_string());
        Ok(Output::new_with_affected_rows(0))
    }
}
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+use std::borrow::Cow;
 use std::collections::HashMap;
 use std::sync::Arc;
@@ -112,6 +113,13 @@ pub(crate) fn process<'a>(query: &str, query_ctx: QueryContextRef) -> Option<Vec
     }
 }

+static LIMIT_CAST_PATTERN: Lazy<Regex> =
+    Lazy::new(|| Regex::new("(?i)(LIMIT\\s+\\d+)::bigint").unwrap());
+pub(crate) fn rewrite_sql(query: &str) -> Cow<'_, str> {
+    // TODO(sunng87): remove this when we upgrade datafusion to 43 or newer
+    LIMIT_CAST_PATTERN.replace_all(query, "$1")
+}
+
 #[cfg(test)]
 mod test {
     use session::context::{QueryContext, QueryContextRef};
@@ -195,4 +203,13 @@ mod test {
         assert!(process("SHOW TABLES ", query_context.clone()).is_none());
         assert!(process("SET TIME_ZONE=utc ", query_context.clone()).is_none());
     }
+
+    #[test]
+    fn test_rewrite() {
+        let sql = "SELECT * FROM number LIMIT 1::bigint";
+        let sql2 = "SELECT * FROM number limit 1::BIGINT";
+
+        assert_eq!("SELECT * FROM number LIMIT 1", rewrite_sql(sql));
+        assert_eq!("SELECT * FROM number limit 1", rewrite_sql(sql2));
+    }
 }
@@ -70,6 +70,9 @@ impl SimpleQueryHandler for PostgresServerHandlerInner {
             return Ok(vec![Response::EmptyQuery]);
         }

+        let query = fixtures::rewrite_sql(query);
+        let query = query.as_ref();
+
         if let Some(resps) = fixtures::process(query, query_ctx.clone()) {
             send_warning_opt(client, query_ctx).await?;
             Ok(resps)
@@ -229,6 +232,9 @@ impl QueryParser for DefaultQueryParser {
             });
         }

+        let sql = fixtures::rewrite_sql(sql);
+        let sql = sql.as_ref();
+
         let mut stmts =
             ParserContext::create_with_dialect(sql, &PostgreSqlDialect {}, ParseOptions::default())
                 .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
@@ -17,9 +17,11 @@ auth.workspace = true
 common-catalog.workspace = true
 common-error.workspace = true
 common-macro.workspace = true
+common-recordbatch.workspace = true
 common-telemetry.workspace = true
 common-time.workspace = true
 derive_builder.workspace = true
+derive_more = { version = "1", default-features = false, features = ["debug"] }
 meter-core.workspace = true
 snafu.workspace = true
 sql.workspace = true
@@ -23,6 +23,8 @@ use arc_swap::ArcSwap;
 use auth::UserInfoRef;
 use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
 use common_catalog::{build_db_string, parse_catalog_and_schema_from_db_string};
+use common_recordbatch::cursor::RecordBatchStreamCursor;
+use common_telemetry::warn;
 use common_time::timezone::parse_timezone;
 use common_time::Timezone;
 use derive_builder::Builder;
@@ -34,6 +36,8 @@ use crate::MutableInner;
 pub type QueryContextRef = Arc<QueryContext>;
 pub type ConnInfoRef = Arc<ConnInfo>;

+const CURSOR_COUNT_WARNING_LIMIT: usize = 10;
+
 #[derive(Debug, Builder, Clone)]
 #[builder(pattern = "owned")]
 #[builder(build_fn(skip))]
@@ -299,6 +303,27 @@ impl QueryContext {
     pub fn set_query_timeout(&self, timeout: Duration) {
         self.mutable_session_data.write().unwrap().query_timeout = Some(timeout);
     }
+
+    pub fn insert_cursor(&self, name: String, rb: RecordBatchStreamCursor) {
+        let mut guard = self.mutable_session_data.write().unwrap();
+        guard.cursors.insert(name, Arc::new(rb));
+
+        let cursor_count = guard.cursors.len();
+        if cursor_count > CURSOR_COUNT_WARNING_LIMIT {
+            warn!("Current connection has {} open cursors", cursor_count);
+        }
+    }
+
+    pub fn remove_cursor(&self, name: &str) {
+        let mut guard = self.mutable_session_data.write().unwrap();
+        guard.cursors.remove(name);
+    }
+
+    pub fn get_cursor(&self, name: &str) -> Option<Arc<RecordBatchStreamCursor>> {
+        let guard = self.mutable_session_data.read().unwrap();
+        let rb = guard.cursors.get(name);
+        rb.cloned()
+    }
 }

 impl QueryContextBuilder {
@@ -16,6 +16,7 @@ pub mod context;
 pub mod session_config;
 pub mod table_name;

+use std::collections::HashMap;
 use std::net::SocketAddr;
 use std::sync::{Arc, RwLock};
 use std::time::Duration;
@@ -23,9 +24,11 @@ use std::time::Duration;
 use auth::UserInfoRef;
 use common_catalog::build_db_string;
 use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
+use common_recordbatch::cursor::RecordBatchStreamCursor;
 use common_time::timezone::get_timezone;
 use common_time::Timezone;
 use context::{ConfigurationVariables, QueryContextBuilder};
+use derive_more::Debug;

 use crate::context::{Channel, ConnInfo, QueryContextRef};
@@ -47,6 +50,8 @@ pub(crate) struct MutableInner {
     user_info: UserInfoRef,
     timezone: Timezone,
     query_timeout: Option<Duration>,
+    #[debug(skip)]
+    pub(crate) cursors: HashMap<String, Arc<RecordBatchStreamCursor>>,
 }

 impl Default for MutableInner {
@@ -56,6 +61,7 @@ impl Default for MutableInner {
             user_info: auth::userinfo_by_name(None),
             timezone: get_timezone(None).clone(),
             query_timeout: None,
+            cursors: HashMap::with_capacity(0),
         }
     }
 }
@@ -167,6 +167,12 @@ impl ParserContext<'_> {
                 self.parse_tql()
             }

+            Keyword::DECLARE => self.parse_declare_cursor(),
+
+            Keyword::FETCH => self.parse_fetch_cursor(),
+
+            Keyword::CLOSE => self.parse_close_cursor(),
+
             Keyword::USE => {
                 let _ = self.parser.next_token();
@@ -16,6 +16,7 @@ pub(crate) mod admin_parser;
 mod alter_parser;
 pub(crate) mod copy_parser;
 pub(crate) mod create_parser;
+pub(crate) mod cursor_parser;
 pub(crate) mod deallocate_parser;
 pub(crate) mod delete_parser;
 pub(crate) mod describe_parser;
src/sql/src/parsers/cursor_parser.rs | 157 (new file)
@@ -0,0 +1,157 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use snafu::{ensure, ResultExt};
use sqlparser::keywords::Keyword;
use sqlparser::tokenizer::Token;

use crate::error::{self, Result};
use crate::parser::ParserContext;
use crate::statements::cursor::{CloseCursor, DeclareCursor, FetchCursor};
use crate::statements::statement::Statement;

impl ParserContext<'_> {
    pub(crate) fn parse_declare_cursor(&mut self) -> Result<Statement> {
        let _ = self.parser.expect_keyword(Keyword::DECLARE);
        let cursor_name = self
            .parser
            .parse_object_name(false)
            .context(error::SyntaxSnafu)?;
        let _ = self
            .parser
            .expect_keywords(&[Keyword::CURSOR, Keyword::FOR]);

        let mut is_select = false;
        if let Token::Word(w) = self.parser.peek_token().token {
            match w.keyword {
                Keyword::SELECT | Keyword::WITH => {
                    is_select = true;
                }
                _ => {}
            }
        };
        ensure!(
            is_select,
            error::InvalidSqlSnafu {
                msg: "Expect select query in cursor statement".to_string(),
            }
        );

        let query_stmt = self.parse_query()?;
        match query_stmt {
            Statement::Query(query) => Ok(Statement::DeclareCursor(DeclareCursor {
                cursor_name: ParserContext::canonicalize_object_name(cursor_name),
                query,
            })),
            _ => error::InvalidSqlSnafu {
                msg: format!("Expect query, found {}", query_stmt),
            }
            .fail(),
        }
    }

    pub(crate) fn parse_fetch_cursor(&mut self) -> Result<Statement> {
        let _ = self.parser.expect_keyword(Keyword::FETCH);

        let fetch_size = self
            .parser
            .parse_literal_uint()
            .context(error::SyntaxSnafu)?;
        let _ = self.parser.parse_keyword(Keyword::FROM);

        let cursor_name = self
            .parser
            .parse_object_name(false)
            .context(error::SyntaxSnafu)?;

        Ok(Statement::FetchCursor(FetchCursor {
            cursor_name: ParserContext::canonicalize_object_name(cursor_name),
            fetch_size,
        }))
    }

    pub(crate) fn parse_close_cursor(&mut self) -> Result<Statement> {
        let _ = self.parser.expect_keyword(Keyword::CLOSE);
        let cursor_name = self
            .parser
            .parse_object_name(false)
            .context(error::SyntaxSnafu)?;

        Ok(Statement::CloseCursor(CloseCursor {
            cursor_name: ParserContext::canonicalize_object_name(cursor_name),
        }))
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::ParseOptions;

    #[test]
    fn test_parse_declare_cursor() {
        let sql = "DECLARE c1 CURSOR FOR\nSELECT * FROM numbers";
        let result =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();

        if let Statement::DeclareCursor(dc) = &result[0] {
            assert_eq!("c1", dc.cursor_name.to_string());
            assert_eq!(
                "DECLARE c1 CURSOR FOR SELECT * FROM numbers",
                dc.to_string()
            );
        } else {
            panic!("Unexpected statement");
        }

        let sql = "DECLARE c1 CURSOR FOR\nINSERT INTO numbers VALUES (1);";
        let result =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_fetch_cursor() {
        let sql = "FETCH 1000 FROM c1";
        let result =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();

        if let Statement::FetchCursor(fc) = &result[0] {
            assert_eq!("c1", fc.cursor_name.to_string());
            assert_eq!("1000", fc.fetch_size.to_string());
            assert_eq!(sql, fc.to_string());
        } else {
            panic!("Unexpected statement")
        }
    }

    #[test]
    fn test_parse_close_cursor() {
        let sql = "CLOSE c1";
        let result =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();

        if let Statement::CloseCursor(cc) = &result[0] {
            assert_eq!("c1", cc.cursor_name.to_string());
            assert_eq!(sql, cc.to_string());
        } else {
            panic!("Unexpected statement")
        }
    }
}
@@ -16,6 +16,7 @@ pub mod admin;
 pub mod alter;
 pub mod copy;
 pub mod create;
+pub mod cursor;
 pub mod delete;
 pub mod describe;
 pub mod drop;
@@ -224,7 +225,7 @@ pub fn sql_number_to_value(data_type: &ConcreteDataType, n: &str) -> Result<Valu
     // TODO(hl): also Date/DateTime
 }

-fn parse_sql_number<R: FromStr + std::fmt::Debug>(n: &str) -> Result<R>
+pub(crate) fn parse_sql_number<R: FromStr + std::fmt::Debug>(n: &str) -> Result<R>
 where
     <R as FromStr>::Err: std::fmt::Debug,
 {
src/sql/src/statements/cursor.rs | 60 (new file)
@@ -0,0 +1,60 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::Display;

use sqlparser::ast::ObjectName;
use sqlparser_derive::{Visit, VisitMut};

use super::query::Query;

/// Represents a DECLARE CURSOR statement
///
/// This statement will carry a SQL query
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut)]
pub struct DeclareCursor {
    pub cursor_name: ObjectName,
    pub query: Box<Query>,
}

impl Display for DeclareCursor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "DECLARE {} CURSOR FOR {}", self.cursor_name, self.query)
    }
}

/// Represents a FETCH FROM cursor statement
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut)]
pub struct FetchCursor {
    pub cursor_name: ObjectName,
    pub fetch_size: u64,
}

impl Display for FetchCursor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "FETCH {} FROM {}", self.fetch_size, self.cursor_name)
    }
}

/// Represents a CLOSE cursor statement
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut)]
pub struct CloseCursor {
    pub cursor_name: ObjectName,
}

impl Display for CloseCursor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "CLOSE {}", self.cursor_name)
    }
}
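These Display impls are what the parser tests earlier in this diff assert against: a statement parsed from canonical SQL prints back to the same string. A small round-trip sketch, assuming the sql crate as an external dependency (module paths mirror the use statements in this diff):

use sql::dialect::GreptimeDbDialect;
use sql::parser::{ParseOptions, ParserContext};
use sql::statements::statement::Statement;

fn main() {
    let sql = "DECLARE c1 CURSOR FOR SELECT * FROM numbers";
    let stmts =
        ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
            .unwrap();
    // DeclareCursor's Display reassembles the DECLARE ... CURSOR FOR prefix
    // in front of the parsed query, so the round trip is exact.
    if let Statement::DeclareCursor(dc) = &stmts[0] {
        assert_eq!(sql, dc.to_string());
    } else {
        panic!("expected a DeclareCursor statement");
    }
}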
@@ -24,6 +24,7 @@ use crate::statements::alter::{AlterDatabase, AlterTable};
 use crate::statements::create::{
     CreateDatabase, CreateExternalTable, CreateFlow, CreateTable, CreateTableLike, CreateView,
 };
+use crate::statements::cursor::{CloseCursor, DeclareCursor, FetchCursor};
 use crate::statements::delete::Delete;
 use crate::statements::describe::DescribeTable;
 use crate::statements::drop::{DropDatabase, DropFlow, DropTable, DropView};
@@ -118,6 +119,12 @@ pub enum Statement {
     Use(String),
     // Admin statement(extension)
     Admin(Admin),
+    // DECLARE ... CURSOR FOR ...
+    DeclareCursor(DeclareCursor),
+    // FETCH ... FROM ...
+    FetchCursor(FetchCursor),
+    // CLOSE
+    CloseCursor(CloseCursor),
 }

 impl Display for Statement {
@@ -165,6 +172,9 @@ impl Display for Statement {
             Statement::CreateView(s) => s.fmt(f),
             Statement::Use(s) => s.fmt(f),
             Statement::Admin(admin) => admin.fmt(f),
+            Statement::DeclareCursor(s) => s.fmt(f),
+            Statement::FetchCursor(s) => s.fmt(f),
+            Statement::CloseCursor(s) => s.fmt(f),
         }
     }
 }
@@ -72,6 +72,7 @@ macro_rules! sql_tests {
             test_postgres_parameter_inference,
             test_postgres_array_types,
             test_mysql_prepare_stmt_insert_timestamp,
+            test_declare_fetch_close_cursor,
         );
     )*
 };
@@ -1198,3 +1199,66 @@ pub async fn test_postgres_array_types(store_type: StorageType) {
     let _ = fe_pg_server.shutdown().await;
     guard.remove_all().await;
 }
+
+pub async fn test_declare_fetch_close_cursor(store_type: StorageType) {
+    let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_inference").await;
+
+    let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls)
+        .await
+        .unwrap();
+
+    let (tx, rx) = tokio::sync::oneshot::channel();
+    tokio::spawn(async move {
+        connection.await.unwrap();
+        tx.send(()).unwrap();
+    });
+
+    client
+        .execute(
+            "DECLARE c1 CURSOR FOR SELECT * FROM numbers WHERE number > 2 LIMIT 50::bigint",
+            &[],
+        )
+        .await
+        .expect("declare cursor");
+
+    // duplicated cursor
+    assert!(client
+        .execute("DECLARE c1 CURSOR FOR SELECT 1", &[],)
+        .await
+        .is_err());
+
+    let rows = client.query("FETCH 5 FROM c1", &[]).await.unwrap();
+    assert_eq!(5, rows.len());
+
+    let rows = client.query("FETCH 100 FROM c1", &[]).await.unwrap();
+    assert_eq!(45, rows.len());
+
+    let rows = client.query("FETCH 100 FROM c1", &[]).await.unwrap();
+    assert_eq!(0, rows.len());
+
+    client.execute("CLOSE c1", &[]).await.expect("close cursor");
+
+    // cursor not found
+    let result = client.query("FETCH 100 FROM c1", &[]).await;
+    assert!(result.is_err());
+
+    client
+        .execute(
+            "DECLARE c2 CURSOR FOR SELECT * FROM numbers WHERE number < 0",
+            &[],
+        )
+        .await
+        .expect("declare cursor");
+
+    let rows = client.query("FETCH 5 FROM c2", &[]).await.unwrap();
+    assert_eq!(0, rows.len());
+
+    client.execute("CLOSE c2", &[]).await.expect("close cursor");
+
+    // Shutdown the client.
+    drop(client);
+    rx.await.unwrap();
+
+    let _ = fe_pg_server.shutdown().await;
+    guard.remove_all().await;
+}