mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-21 07:20:41 +00:00
feat: add cursor statements (#5094)
* feat: add sql parsers for cursor operations * feat: cursor operator * feat: implement RecordBatchStreamCursor * feat: implement cursor storage and execution * test: add tests * chore: update docstring * feat: add a temporary sql rewrite for cast in limit this issue is described in #5097 * test: add more sql for cursor integration test * feat: reject non-select query for cursor statement * refactor: address review issues * test: add empty result case * feat: address review comments
This commit is contained in:
173
src/common/recordbatch/src/cursor.rs
Normal file
173
src/common/recordbatch/src/cursor.rs
Normal file
@@ -0,0 +1,173 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use futures::StreamExt;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::recordbatch::merge_record_batches;
|
||||
use crate::{RecordBatch, SendableRecordBatchStream};
|
||||
|
||||
struct Inner {
|
||||
stream: SendableRecordBatchStream,
|
||||
current_row_index: usize,
|
||||
current_batch: Option<RecordBatch>,
|
||||
total_rows_in_current_batch: usize,
|
||||
}
|
||||
|
||||
/// A cursor on RecordBatchStream that fetches data batch by batch
|
||||
pub struct RecordBatchStreamCursor {
|
||||
inner: Mutex<Inner>,
|
||||
}
|
||||
|
||||
impl RecordBatchStreamCursor {
|
||||
pub fn new(stream: SendableRecordBatchStream) -> RecordBatchStreamCursor {
|
||||
Self {
|
||||
inner: Mutex::new(Inner {
|
||||
stream,
|
||||
current_row_index: 0,
|
||||
current_batch: None,
|
||||
total_rows_in_current_batch: 0,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Take `size` of row from the `RecordBatchStream` and create a new
|
||||
/// `RecordBatch` for these rows.
|
||||
pub async fn take(&self, size: usize) -> Result<RecordBatch> {
|
||||
let mut remaining_rows_to_take = size;
|
||||
let mut accumulated_rows = Vec::new();
|
||||
|
||||
let mut inner = self.inner.lock().await;
|
||||
|
||||
while remaining_rows_to_take > 0 {
|
||||
// Ensure we have a current batch or fetch the next one
|
||||
if inner.current_batch.is_none()
|
||||
|| inner.current_row_index >= inner.total_rows_in_current_batch
|
||||
{
|
||||
match inner.stream.next().await {
|
||||
Some(Ok(batch)) => {
|
||||
inner.total_rows_in_current_batch = batch.num_rows();
|
||||
inner.current_batch = Some(batch);
|
||||
inner.current_row_index = 0;
|
||||
}
|
||||
Some(Err(e)) => return Err(e),
|
||||
None => {
|
||||
// Stream is exhausted
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we still have no batch after attempting to fetch
|
||||
let current_batch = match &inner.current_batch {
|
||||
Some(batch) => batch,
|
||||
None => break,
|
||||
};
|
||||
|
||||
// Calculate how many rows we can take from this batch
|
||||
let rows_to_take_from_batch = remaining_rows_to_take
|
||||
.min(inner.total_rows_in_current_batch - inner.current_row_index);
|
||||
|
||||
// Slice the current batch to get the desired rows
|
||||
let taken_batch =
|
||||
current_batch.slice(inner.current_row_index, rows_to_take_from_batch)?;
|
||||
|
||||
// Add the taken batch to accumulated rows
|
||||
accumulated_rows.push(taken_batch);
|
||||
|
||||
// Update cursor and remaining rows
|
||||
inner.current_row_index += rows_to_take_from_batch;
|
||||
remaining_rows_to_take -= rows_to_take_from_batch;
|
||||
}
|
||||
|
||||
// If no rows were accumulated, return empty
|
||||
if accumulated_rows.is_empty() {
|
||||
return Ok(RecordBatch::new_empty(inner.stream.schema()));
|
||||
}
|
||||
|
||||
// If only one batch was accumulated, return it directly
|
||||
if accumulated_rows.len() == 1 {
|
||||
return Ok(accumulated_rows.remove(0));
|
||||
}
|
||||
|
||||
// Merge multiple batches
|
||||
merge_record_batches(inner.stream.schema(), &accumulated_rows)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::StringVector;

    use super::*;
    use crate::RecordBatches;

    /// Calls `take` once per `(size, expected_rows)` pair, in order, and
    /// asserts the row count of each returned batch.
    async fn assert_takes(cursor: &RecordBatchStreamCursor, steps: &[(usize, usize)]) {
        for &(size, expected_rows) in steps {
            let taken = cursor.take(size).await.expect("take from cursor failed");
            assert_eq!(taken.num_rows(), expected_rows);
        }
    }

    #[tokio::test]
    async fn test_cursor() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "a",
            ConcreteDataType::string_datatype(),
            false,
        )]));

        // A single two-row batch, read one row at a time until it runs dry.
        let single_batch = RecordBatches::try_from_columns(
            schema.clone(),
            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
        )
        .unwrap();
        let cursor = RecordBatchStreamCursor::new(single_batch.as_stream());
        assert_takes(&cursor, &[(1, 1), (1, 1), (1, 0)]).await;

        // Builds three identical two-row batches (six rows in total).
        let three_batches = || {
            let batch = RecordBatch::new(
                schema.clone(),
                vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
            )
            .unwrap();
            RecordBatches::try_new(schema.clone(), vec![batch.clone(), batch.clone(), batch])
                .unwrap()
        };

        // Takes that straddle batch boundaries, then the end of the stream.
        let batches = three_batches();
        let cursor = RecordBatchStreamCursor::new(batches.as_stream());
        assert_takes(&cursor, &[(3, 3), (2, 2), (2, 1), (2, 0)]).await;

        // Asking for more rows than exist returns everything that is available.
        let batches = three_batches();
        let cursor = RecordBatchStreamCursor::new(batches.as_stream());
        assert_takes(&cursor, &[(10, 6)]).await;
    }
}
|
||||
@@ -168,6 +168,13 @@ pub enum Error {
|
||||
#[snafu(source)]
|
||||
error: tokio::time::error::Elapsed,
|
||||
},
|
||||
#[snafu(display("RecordBatch slice index overflow: {visit_index} > {size}"))]
|
||||
RecordBatchSliceIndexOverflow {
|
||||
#[snafu(implicit)]
|
||||
location: Location,
|
||||
size: usize,
|
||||
visit_index: usize,
|
||||
},
|
||||
}
|
||||
|
||||
impl ErrorExt for Error {
|
||||
@@ -182,7 +189,8 @@ impl ErrorExt for Error {
|
||||
| Error::Format { .. }
|
||||
| Error::ToArrowScalar { .. }
|
||||
| Error::ProjectArrowRecordBatch { .. }
|
||||
| Error::PhysicalExpr { .. } => StatusCode::Internal,
|
||||
| Error::PhysicalExpr { .. }
|
||||
| Error::RecordBatchSliceIndexOverflow { .. } => StatusCode::Internal,
|
||||
|
||||
Error::PollStream { .. } => StatusCode::EngineExecuteQuery,
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#![feature(never_type)]
|
||||
|
||||
pub mod adapter;
|
||||
pub mod cursor;
|
||||
pub mod error;
|
||||
pub mod filter;
|
||||
mod recordbatch;
|
||||
|
||||
@@ -23,7 +23,7 @@ use datatypes::value::Value;
|
||||
use datatypes::vectors::{Helper, VectorRef};
|
||||
use serde::ser::{Error, SerializeStruct};
|
||||
use serde::{Serialize, Serializer};
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
use snafu::{ensure, OptionExt, ResultExt};
|
||||
|
||||
use crate::error::{
|
||||
self, CastVectorSnafu, ColumnNotExistsSnafu, DataTypesSnafu, ProjectArrowRecordBatchSnafu,
|
||||
@@ -194,6 +194,19 @@ impl RecordBatch {
|
||||
.map(|t| t.to_string())
|
||||
.unwrap_or("failed to pretty display a record batch".to_string())
|
||||
}
|
||||
|
||||
/// Return a slice record batch starts from offset, with len rows
|
||||
pub fn slice(&self, offset: usize, len: usize) -> Result<RecordBatch> {
|
||||
ensure!(
|
||||
offset + len <= self.num_rows(),
|
||||
error::RecordBatchSliceIndexOverflowSnafu {
|
||||
size: self.num_rows(),
|
||||
visit_index: offset + len
|
||||
}
|
||||
);
|
||||
let columns = self.columns.iter().map(|vector| vector.slice(offset, len));
|
||||
RecordBatch::new(self.schema.clone(), columns)
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for RecordBatch {
|
||||
@@ -256,6 +269,36 @@ impl Iterator for RecordBatchRowIterator<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
/// merge multiple recordbatch into a single
|
||||
pub fn merge_record_batches(schema: SchemaRef, batches: &[RecordBatch]) -> Result<RecordBatch> {
|
||||
let batches_len = batches.len();
|
||||
if batches_len == 0 {
|
||||
return Ok(RecordBatch::new_empty(schema));
|
||||
}
|
||||
|
||||
let n_rows = batches.iter().map(|b| b.num_rows()).sum();
|
||||
let n_columns = schema.num_columns();
|
||||
// Collect arrays from each batch
|
||||
let mut merged_columns = Vec::with_capacity(n_columns);
|
||||
|
||||
for col_idx in 0..n_columns {
|
||||
let mut acc = schema.column_schemas()[col_idx]
|
||||
.data_type
|
||||
.create_mutable_vector(n_rows);
|
||||
|
||||
for batch in batches {
|
||||
let column = batch.column(col_idx);
|
||||
acc.extend_slice_of(column.as_ref(), 0, column.len())
|
||||
.context(error::DataTypesSnafu)?;
|
||||
}
|
||||
|
||||
merged_columns.push(acc.to_vector());
|
||||
}
|
||||
|
||||
// Create a new RecordBatch with merged columns
|
||||
RecordBatch::new(schema, merged_columns)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
@@ -375,4 +418,80 @@ mod tests {
|
||||
|
||||
assert!(record_batch_iter.next().is_none());
|
||||
}
|
||||
|
||||
#[test]
fn test_record_batch_slice() {
    let schema = Arc::new(Schema::new(vec![
        ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
        ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
    ]));
    let numbers: VectorRef = Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4]));
    let strings: VectorRef = Arc::new(StringVector::from(vec![
        None,
        Some("hello"),
        Some("greptime"),
        None,
    ]));
    let full_batch = RecordBatch::new(schema, vec![numbers, strings]).unwrap();

    // Rows 1 and 2 of the four-row batch.
    let sliced = full_batch.slice(1, 2).expect("recordbatch slice");

    // Collecting the iterator checks both the row contents and that exactly
    // two rows are produced.
    let actual_rows: Vec<Vec<Value>> = sliced
        .rows()
        .map(|row| row.into_iter().collect::<Vec<Value>>())
        .collect();
    let expected_rows = vec![
        vec![Value::UInt32(2), Value::String("hello".into())],
        vec![Value::UInt32(3), Value::String("greptime".into())],
    ];
    assert_eq!(expected_rows, actual_rows);

    // Slicing past the end of the (already sliced, two-row) batch must fail.
    assert!(sliced.slice(1, 5).is_err());
}
|
||||
|
||||
#[test]
fn test_merge_record_batch() {
    let schema = Arc::new(Schema::new(vec![
        ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
        ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
    ]));

    // Both inputs are identical four-row batches, so build them the same way.
    let build_batch = || {
        let columns: Vec<VectorRef> = vec![
            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
            Arc::new(StringVector::from(vec![
                None,
                Some("hello"),
                Some("greptime"),
                None,
            ])),
        ];
        RecordBatch::new(schema.clone(), columns).unwrap()
    };
    let inputs = [build_batch(), build_batch()];

    let merged = merge_record_batches(schema.clone(), &inputs).expect("merge recordbatch");

    // Four rows from each input.
    assert_eq!(merged.num_rows(), 8);
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user