Mirror of https://github.com/GreptimeTeam/greptimedb.git, synced 2025-12-26 08:00:01 +00:00
Compare commits
35 Commits
v0.18.0-ni
| Author | SHA1 | Date |
|---|---|---|
| | 03954e8b3b | |
| | bd8f5d2b71 | |
| | 74721a06ba | |
| | 18e4839a17 | |
| | cbe0cf4a74 | |
| | e26b98f452 | |
| | d8b967408e | |
| | c35407fdce | |
| | edf4b3f7f8 | |
| | 14550429e9 | |
| | ff2da4903e | |
| | c92ab4217f | |
| | 77981a7de5 | |
| | 9096c5ebbf | |
| | 0a959f9920 | |
| | 85c1a91bae | |
| | 7aba9a18fd | |
| | 4c18d140b4 | |
| | b8e0c49cb4 | |
| | db42ad42dc | |
| | 8ce963f63e | |
| | b3aabb6706 | |
| | 028effe952 | |
| | d86f489a74 | |
| | 6c066c1a4a | |
| | 9ab87e11a4 | |
| | 9fe7069146 | |
| | 733a1afcd1 | |
| | 5e65581f94 | |
| | e75e5baa63 | |
| | c4b89df523 | |
| | 6a15e62719 | |
| | 2bddbe8c47 | |
| | ea8125aafb | |
| | 49722951c6 | |
6 .github/workflows/semantic-pull-request.yml (vendored)
@@ -1,7 +1,7 @@
 name: "Semantic Pull Request"

 on:
-  pull_request:
+  pull_request_target:
     types:
       - opened
       - reopened
@@ -12,9 +12,9 @@ concurrency:
   cancel-in-progress: true

 permissions:
-  issues: write
-  contents: write
+  contents: read
+  pull-requests: write
+  issues: write

 jobs:
   check:
7 Cargo.lock (generated)
@@ -5302,7 +5302,7 @@ dependencies = [
 [[package]]
 name = "greptime-proto"
 version = "0.1.0"
-source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=f9836cf8aab30e672f640c6ef4c1cfd2cf9fbc36#f9836cf8aab30e672f640c6ef4c1cfd2cf9fbc36"
+source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=3e821d0d405e6733690a4e4352812ba2ff780a3e#3e821d0d405e6733690a4e4352812ba2ff780a3e"
 dependencies = [
  "prost 0.13.5",
  "prost-types 0.13.5",
@@ -7287,8 +7287,7 @@ checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0"
 [[package]]
 name = "memcomparable"
 version = "0.2.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "376101dbd964fc502d5902216e180f92b3d003b5cc3d2e40e044eb5470fca677"
+source = "git+https://github.com/v0y4g3r/memcomparable.git?rev=a07122dc03556bbd88ad66234cbea7efd3b23efb#a07122dc03556bbd88ad66234cbea7efd3b23efb"
 dependencies = [
  "bytes",
  "serde",
@@ -7607,7 +7606,6 @@ dependencies = [
  "itertools 0.14.0",
  "lazy_static",
  "log-store",
- "memcomparable",
  "mito-codec",
  "moka",
  "object-store",
@@ -12360,6 +12358,7 @@ dependencies = [
  "sqlparser 0.55.0-greptime",
  "strum 0.27.1",
  "tokio",
  "uuid",
 ]

 [[package]]
@@ -145,7 +145,7 @@ etcd-client = { git = "https://github.com/GreptimeTeam/etcd-client", rev = "f62d
 fst = "0.4.7"
 futures = "0.3"
 futures-util = "0.3"
-greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "f9836cf8aab30e672f640c6ef4c1cfd2cf9fbc36" }
+greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "3e821d0d405e6733690a4e4352812ba2ff780a3e" }
 hex = "0.4"
 http = "1"
 humantime = "2.1"
@@ -151,6 +151,7 @@
 | `region_engine.mito.max_concurrent_scan_files` | Integer | `384` | Maximum number of SST files to scan concurrently. |
 | `region_engine.mito.allow_stale_entries` | Bool | `false` | Whether to allow stale WAL entries read during replay. |
 | `region_engine.mito.min_compaction_interval` | String | `0m` | Minimum time interval between two compactions.<br/>To align with the old behavior, the default value is 0 (no restrictions). |
+| `region_engine.mito.enable_experimental_flat_format` | Bool | `false` | Whether to enable experimental flat format. |
 | `region_engine.mito.index` | -- | -- | The options for index in Mito engine. |
 | `region_engine.mito.index.aux_path` | String | `""` | Auxiliary directory path for the index in filesystem, used to store intermediate files for<br/>creating the index and staging files for searching the index, defaults to `{data_home}/index_intermediate`.<br/>The default name for this directory is `index_intermediate` for backward compatibility.<br/><br/>This path contains two subdirectories:<br/>- `__intm`: for storing intermediate files used during creating index.<br/>- `staging`: for storing staging files used during searching index. |
 | `region_engine.mito.index.staging_size` | String | `2GB` | The max capacity of the staging directory. |
@@ -543,6 +544,7 @@
 | `region_engine.mito.max_concurrent_scan_files` | Integer | `384` | Maximum number of SST files to scan concurrently. |
 | `region_engine.mito.allow_stale_entries` | Bool | `false` | Whether to allow stale WAL entries read during replay. |
 | `region_engine.mito.min_compaction_interval` | String | `0m` | Minimum time interval between two compactions.<br/>To align with the old behavior, the default value is 0 (no restrictions). |
+| `region_engine.mito.enable_experimental_flat_format` | Bool | `false` | Whether to enable experimental flat format. |
 | `region_engine.mito.index` | -- | -- | The options for index in Mito engine. |
 | `region_engine.mito.index.aux_path` | String | `""` | Auxiliary directory path for the index in filesystem, used to store intermediate files for<br/>creating the index and staging files for searching the index, defaults to `{data_home}/index_intermediate`.<br/>The default name for this directory is `index_intermediate` for backward compatibility.<br/><br/>This path contains two subdirectories:<br/>- `__intm`: for storing intermediate files used during creating index.<br/>- `staging`: for storing staging files used during searching index. |
 | `region_engine.mito.index.staging_size` | String | `2GB` | The max capacity of the staging directory. |
@@ -497,6 +497,9 @@ allow_stale_entries = false
 ## To align with the old behavior, the default value is 0 (no restrictions).
 min_compaction_interval = "0m"

+## Whether to enable experimental flat format.
+enable_experimental_flat_format = false
+
 ## The options for index in Mito engine.
 [region_engine.mito.index]

@@ -576,6 +576,9 @@ allow_stale_entries = false
 ## To align with the old behavior, the default value is 0 (no restrictions).
 min_compaction_interval = "0m"

+## Whether to enable experimental flat format.
+enable_experimental_flat_format = false
+
 ## The options for index in Mito engine.
 [region_engine.mito.index]
@@ -30,22 +30,7 @@ curl https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph
 ## Profiling

-### Configuration
-
-You can control heap profiling activation through configuration. Add the following to your configuration file:
-
-```toml
-[memory]
-# Whether to enable heap profiling activation during startup.
-# When enabled, heap profiling will be activated if the `MALLOC_CONF` environment variable
-# is set to "prof:true,prof_active:false". The official image adds this env variable.
-# Default is true.
-enable_heap_profiling = true
-```
-
-By default, if you set `MALLOC_CONF=prof:true,prof_active:false`, the database will enable profiling during startup. You can disable this behavior by setting `enable_heap_profiling = false` in the configuration.
-
-### Starting with environment variables
+### Enable memory profiling for greptimedb binary

 Start GreptimeDB instance with environment variables:

@@ -57,6 +42,22 @@ MALLOC_CONF=prof:true ./target/debug/greptime standalone start
 _RJEM_MALLOC_CONF=prof:true ./target/debug/greptime standalone start
 ```

+### Memory profiling for greptimedb docker image
+
+We have memory profiling enabled and activated by default in our official docker
+image.
+
+This behavior is controlled by the `enable_heap_profiling` configuration:
+
+```toml
+[memory]
+# Whether to enable heap profiling activation during startup.
+# Default is true.
+enable_heap_profiling = true
+```
+
+To disable memory profiling, set `enable_heap_profiling` to `false`.
+
 ### Memory profiling control

 You can control heap profiling activation using the new HTTP APIs:
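As a hedged sketch of the activation rule described above (not GreptimeDB's actual code; the function and variable names here are invented for illustration): profiling is activated at startup only when the jemalloc environment variable requests it and the config flag is left on.

```rust
use std::env;

/// Illustrative only: mirrors the decision the docs describe.
/// `enable_heap_profiling` stands in for the `[memory]` config flag;
/// the env var names are the ones mentioned above.
fn should_activate_heap_profiling(enable_heap_profiling: bool) -> bool {
    let malloc_conf = env::var("MALLOC_CONF")
        .or_else(|_| env::var("_RJEM_MALLOC_CONF"))
        .unwrap_or_default();
    // The official image ships MALLOC_CONF=prof:true,prof_active:false and
    // lets the database flip profiling active during startup.
    enable_heap_profiling && malloc_conf.contains("prof:true")
}

fn main() {
    println!("{}", should_activate_heap_profiling(true));
}
```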
41 scripts/generate_certs.sh (new executable file)
@@ -0,0 +1,41 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+CERT_DIR="${1:-$(dirname "$0")/../tests-integration/fixtures/certs}"
+DAYS="${2:-365}"
+
+mkdir -p "${CERT_DIR}"
+cd "${CERT_DIR}"
+
+echo "Generating CA certificate..."
+openssl req -new -x509 -days "${DAYS}" -nodes -text \
+  -out root.crt -keyout root.key \
+  -subj "/CN=GreptimeDBRootCA"
+
+
+echo "Generating server certificate..."
+openssl req -new -nodes -text \
+  -out server.csr -keyout server.key \
+  -subj "/CN=greptime"
+
+openssl x509 -req -in server.csr -text -days "${DAYS}" \
+  -CA root.crt -CAkey root.key -CAcreateserial \
+  -out server.crt \
+  -extensions v3_req -extfile <(printf "[v3_req]\nsubjectAltName=DNS:localhost,IP:127.0.0.1")
+
+echo "Generating client certificate..."
+# Make sure the client certificate is for the greptimedb user
+openssl req -new -nodes -text \
+  -out client.csr -keyout client.key \
+  -subj "/CN=greptimedb"
+
+openssl x509 -req -in client.csr -CA root.crt -CAkey root.key -CAcreateserial \
+  -out client.crt -days "${DAYS}" -extensions v3_req -extfile <(printf "[v3_req]\nsubjectAltName=DNS:localhost")
+
+rm -f *.csr
+
+echo "TLS certificates generated successfully in ${CERT_DIR}"
+
+chmod 644 root.key
+chmod 644 client.key
+chmod 644 server.key
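The StoreConfig changes below consume certificates like these through `servers::tls::TlsOption`. A hedged sketch of pointing those fields at the fixtures this script produces (the field names come from the diff below, but the verifying `TlsMode` variant is an assumption; only `Disable` appears there):

```rust
use servers::tls::{TlsMode, TlsOption};

/// Illustrative only: wire the certificates generated by
/// `scripts/generate_certs.sh` into the backend TLS options used below.
fn fixture_tls_option() -> TlsOption {
    let dir = "tests-integration/fixtures/certs";
    TlsOption {
        // Assumption: a verifying mode such as `VerifyFull`; the diff below
        // only shows `TlsMode::Disable` explicitly.
        mode: TlsMode::VerifyFull,
        cert_path: format!("{dir}/client.crt"),
        key_path: format!("{dir}/client.key"),
        ca_cert_path: format!("{dir}/root.crt"),
        watch: false,
    }
}
```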
@@ -19,8 +19,8 @@ use common_error::ext::BoxedError;
 use common_meta::kv_backend::KvBackendRef;
 use common_meta::kv_backend::chroot::ChrootKvBackend;
 use common_meta::kv_backend::etcd::EtcdStore;
-use meta_srv::bootstrap::create_etcd_client_with_tls;
 use meta_srv::metasrv::BackendImpl;
+use meta_srv::utils::etcd::create_etcd_client_with_tls;
 use servers::tls::{TlsMode, TlsOption};

 use crate::error::{EmptyStoreAddrsSnafu, UnsupportedMemoryBackendSnafu};
@@ -83,6 +83,20 @@ pub(crate) struct StoreConfig {
 }

 impl StoreConfig {
+    pub fn tls_config(&self) -> Option<TlsOption> {
+        if self.backend_tls_mode != TlsMode::Disable {
+            Some(TlsOption {
+                mode: self.backend_tls_mode.clone(),
+                cert_path: self.backend_tls_cert_path.clone(),
+                key_path: self.backend_tls_key_path.clone(),
+                ca_cert_path: self.backend_tls_ca_cert_path.clone(),
+                watch: self.backend_tls_watch,
+            })
+        } else {
+            None
+        }
+    }
+
     /// Builds a [`KvBackendRef`] from the store configuration.
     pub async fn build(&self) -> Result<KvBackendRef, BoxedError> {
         let max_txn_ops = self.max_txn_ops;
@@ -92,17 +106,7 @@ impl StoreConfig {
         } else {
             let kvbackend = match self.backend {
                 BackendImpl::EtcdStore => {
-                    let tls_config = if self.backend_tls_mode != TlsMode::Disable {
-                        Some(TlsOption {
-                            mode: self.backend_tls_mode.clone(),
-                            cert_path: self.backend_tls_cert_path.clone(),
-                            key_path: self.backend_tls_key_path.clone(),
-                            ca_cert_path: self.backend_tls_ca_cert_path.clone(),
-                            watch: self.backend_tls_watch,
-                        })
-                    } else {
-                        None
-                    };
+                    let tls_config = self.tls_config();
                     let etcd_client = create_etcd_client_with_tls(store_addrs, tls_config.as_ref())
                         .await
                         .map_err(BoxedError::new)?;
@@ -111,9 +115,14 @@
                 #[cfg(feature = "pg_kvbackend")]
                 BackendImpl::PostgresStore => {
                     let table_name = &self.meta_table_name;
-                    let pool = meta_srv::bootstrap::create_postgres_pool(store_addrs, None)
-                        .await
-                        .map_err(BoxedError::new)?;
+                    let tls_config = self.tls_config();
+                    let pool = meta_srv::utils::postgres::create_postgres_pool(
+                        store_addrs,
+                        None,
+                        tls_config,
+                    )
+                    .await
+                    .map_err(BoxedError::new)?;
                     let schema_name = self.meta_schema_name.as_deref();
                     Ok(common_meta::kv_backend::rds::PgStore::with_pg_pool(
                         pool,
@@ -127,9 +136,11 @@
                 #[cfg(feature = "mysql_kvbackend")]
                 BackendImpl::MysqlStore => {
                     let table_name = &self.meta_table_name;
-                    let pool = meta_srv::bootstrap::create_mysql_pool(store_addrs)
-                        .await
-                        .map_err(BoxedError::new)?;
+                    let tls_config = self.tls_config();
+                    let pool =
+                        meta_srv::utils::mysql::create_mysql_pool(store_addrs, tls_config.as_ref())
+                            .await
+                            .map_err(BoxedError::new)?;
                     Ok(common_meta::kv_backend::rds::MySqlStore::with_mysql_pool(
                         pool,
                         table_name,
@@ -196,7 +196,10 @@ pub async fn stream_to_parquet(
     concurrency: usize,
 ) -> Result<usize> {
     let write_props = column_wise_config(
-        WriterProperties::builder().set_compression(Compression::ZSTD(ZstdLevel::default())),
+        WriterProperties::builder()
+            .set_compression(Compression::ZSTD(ZstdLevel::default()))
+            .set_statistics_truncate_length(None)
+            .set_column_index_truncate_length(None),
         schema,
     )
     .build();
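For context, a standalone sketch of the same builder chain against the `parquet` crate (assuming a crate version that exposes both truncate knobs, as the diff implies). Disabling truncation keeps full min/max statistics and column-index values, so predicates on long string or binary columns keep their exact bounds.

```rust
use parquet::basic::{Compression, ZstdLevel};
use parquet::file::properties::WriterProperties;

/// Sketch of the writer configuration above: ZSTD compression plus
/// untruncated statistics and column-index values.
fn writer_props() -> WriterProperties {
    WriterProperties::builder()
        .set_compression(Compression::ZSTD(ZstdLevel::default()))
        .set_statistics_truncate_length(None)
        .set_column_index_truncate_length(None)
        .build()
}
```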
@@ -12,23 +12,19 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-mod add_region_follower;
 mod flush_compact_region;
 mod flush_compact_table;
 mod migrate_region;
 mod reconcile_catalog;
 mod reconcile_database;
 mod reconcile_table;
-mod remove_region_follower;

-use add_region_follower::AddRegionFollowerFunction;
 use flush_compact_region::{CompactRegionFunction, FlushRegionFunction};
 use flush_compact_table::{CompactTableFunction, FlushTableFunction};
 use migrate_region::MigrateRegionFunction;
 use reconcile_catalog::ReconcileCatalogFunction;
 use reconcile_database::ReconcileDatabaseFunction;
 use reconcile_table::ReconcileTableFunction;
-use remove_region_follower::RemoveRegionFollowerFunction;

 use crate::flush_flow::FlushFlowFunction;
 use crate::function_registry::FunctionRegistry;
@@ -40,8 +36,6 @@ impl AdminFunction {
     /// Register all admin functions to [`FunctionRegistry`].
     pub fn register(registry: &FunctionRegistry) {
         registry.register(MigrateRegionFunction::factory());
-        registry.register(AddRegionFollowerFunction::factory());
-        registry.register(RemoveRegionFollowerFunction::factory());
         registry.register(FlushRegionFunction::factory());
         registry.register(CompactRegionFunction::factory());
         registry.register(FlushTableFunction::factory());
@@ -1,155 +0,0 @@
-// Copyright 2023 Greptime Team
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-use common_macro::admin_fn;
-use common_meta::rpc::procedure::AddRegionFollowerRequest;
-use common_query::error::{
-    InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
-    UnsupportedInputDataTypeSnafu,
-};
-use datafusion_expr::{Signature, TypeSignature, Volatility};
-use datatypes::data_type::DataType;
-use datatypes::prelude::ConcreteDataType;
-use datatypes::value::{Value, ValueRef};
-use session::context::QueryContextRef;
-use snafu::ensure;
-
-use crate::handlers::ProcedureServiceHandlerRef;
-use crate::helper::cast_u64;
-
-/// A function to add a follower to a region.
-/// Only available in cluster mode.
-///
-/// - `add_region_follower(region_id, peer_id)`.
-///
-/// The parameters:
-/// - `region_id`: the region id
-/// - `peer_id`: the peer id
-#[admin_fn(
-    name = AddRegionFollowerFunction,
-    display_name = add_region_follower,
-    sig_fn = signature,
-    ret = uint64
-)]
-pub(crate) async fn add_region_follower(
-    procedure_service_handler: &ProcedureServiceHandlerRef,
-    _ctx: &QueryContextRef,
-    params: &[ValueRef<'_>],
-) -> Result<Value> {
-    ensure!(
-        params.len() == 2,
-        InvalidFuncArgsSnafu {
-            err_msg: format!(
-                "The length of the args is not correct, expect exactly 2, have: {}",
-                params.len()
-            ),
-        }
-    );
-
-    let Some(region_id) = cast_u64(&params[0])? else {
-        return UnsupportedInputDataTypeSnafu {
-            function: "add_region_follower",
-            datatypes: params.iter().map(|v| v.data_type()).collect::<Vec<_>>(),
-        }
-        .fail();
-    };
-    let Some(peer_id) = cast_u64(&params[1])? else {
-        return UnsupportedInputDataTypeSnafu {
-            function: "add_region_follower",
-            datatypes: params.iter().map(|v| v.data_type()).collect::<Vec<_>>(),
-        }
-        .fail();
-    };
-
-    procedure_service_handler
-        .add_region_follower(AddRegionFollowerRequest { region_id, peer_id })
-        .await?;
-
-    Ok(Value::from(0u64))
-}
-
-fn signature() -> Signature {
-    Signature::one_of(
-        vec![
-            // add_region_follower(region_id, peer)
-            TypeSignature::Uniform(
-                2,
-                ConcreteDataType::numerics()
-                    .into_iter()
-                    .map(|dt| dt.as_arrow_type())
-                    .collect(),
-            ),
-        ],
-        Volatility::Immutable,
-    )
-}
-
-#[cfg(test)]
-mod tests {
-    use std::sync::Arc;
-
-    use arrow::array::UInt64Array;
-    use arrow::datatypes::{DataType, Field};
-    use datafusion_expr::ColumnarValue;
-
-    use super::*;
-    use crate::function::FunctionContext;
-    use crate::function_factory::ScalarFunctionFactory;
-
-    #[test]
-    fn test_add_region_follower_misc() {
-        let factory: ScalarFunctionFactory = AddRegionFollowerFunction::factory().into();
-        let f = factory.provide(FunctionContext::mock());
-        assert_eq!("add_region_follower", f.name());
-        assert_eq!(DataType::UInt64, f.return_type(&[]).unwrap());
-        assert!(matches!(f.signature(),
-            datafusion_expr::Signature {
-                type_signature: datafusion_expr::TypeSignature::OneOf(sigs),
-                volatility: datafusion_expr::Volatility::Immutable
-            } if sigs.len() == 1));
-    }
-
-    #[tokio::test]
-    async fn test_add_region_follower() {
-        let factory: ScalarFunctionFactory = AddRegionFollowerFunction::factory().into();
-        let provider = factory.provide(FunctionContext::mock());
-        let f = provider.as_async().unwrap();
-
-        let func_args = datafusion::logical_expr::ScalarFunctionArgs {
-            args: vec![
-                ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
-                ColumnarValue::Array(Arc::new(UInt64Array::from(vec![2]))),
-            ],
-            arg_fields: vec![
-                Arc::new(Field::new("arg_0", DataType::UInt64, false)),
-                Arc::new(Field::new("arg_1", DataType::UInt64, false)),
-            ],
-            return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
-            number_rows: 1,
-            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
-        };
-
-        let result = f.invoke_async_with_args(func_args).await.unwrap();
-
-        match result {
-            ColumnarValue::Array(array) => {
-                let result_array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
-                assert_eq!(result_array.value(0), 0u64);
-            }
-            ColumnarValue::Scalar(scalar) => {
-                assert_eq!(scalar, datafusion_common::ScalarValue::UInt64(Some(0)));
-            }
-        }
-    }
-}
@@ -1,155 +0,0 @@
-// Copyright 2023 Greptime Team
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-use common_macro::admin_fn;
-use common_meta::rpc::procedure::RemoveRegionFollowerRequest;
-use common_query::error::{
-    InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
-    UnsupportedInputDataTypeSnafu,
-};
-use datafusion_expr::{Signature, TypeSignature, Volatility};
-use datatypes::data_type::DataType;
-use datatypes::prelude::ConcreteDataType;
-use datatypes::value::{Value, ValueRef};
-use session::context::QueryContextRef;
-use snafu::ensure;
-
-use crate::handlers::ProcedureServiceHandlerRef;
-use crate::helper::cast_u64;
-
-/// A function to remove a follower from a region.
-/// Only available in cluster mode.
-///
-/// - `remove_region_follower(region_id, peer_id)`.
-///
-/// The parameters:
-/// - `region_id`: the region id
-/// - `peer_id`: the peer id
-#[admin_fn(
-    name = RemoveRegionFollowerFunction,
-    display_name = remove_region_follower,
-    sig_fn = signature,
-    ret = uint64
-)]
-pub(crate) async fn remove_region_follower(
-    procedure_service_handler: &ProcedureServiceHandlerRef,
-    _ctx: &QueryContextRef,
-    params: &[ValueRef<'_>],
-) -> Result<Value> {
-    ensure!(
-        params.len() == 2,
-        InvalidFuncArgsSnafu {
-            err_msg: format!(
-                "The length of the args is not correct, expect exactly 2, have: {}",
-                params.len()
-            ),
-        }
-    );
-
-    let Some(region_id) = cast_u64(&params[0])? else {
-        return UnsupportedInputDataTypeSnafu {
-            function: "add_region_follower",
-            datatypes: params.iter().map(|v| v.data_type()).collect::<Vec<_>>(),
-        }
-        .fail();
-    };
-    let Some(peer_id) = cast_u64(&params[1])? else {
-        return UnsupportedInputDataTypeSnafu {
-            function: "add_region_follower",
-            datatypes: params.iter().map(|v| v.data_type()).collect::<Vec<_>>(),
-        }
-        .fail();
-    };
-
-    procedure_service_handler
-        .remove_region_follower(RemoveRegionFollowerRequest { region_id, peer_id })
-        .await?;
-
-    Ok(Value::from(0u64))
-}
-
-fn signature() -> Signature {
-    Signature::one_of(
-        vec![
-            // remove_region_follower(region_id, peer_id)
-            TypeSignature::Uniform(
-                2,
-                ConcreteDataType::numerics()
-                    .into_iter()
-                    .map(|dt| dt.as_arrow_type())
-                    .collect(),
-            ),
-        ],
-        Volatility::Immutable,
-    )
-}
-
-#[cfg(test)]
-mod tests {
-    use std::sync::Arc;
-
-    use arrow::array::UInt64Array;
-    use arrow::datatypes::{DataType, Field};
-    use datafusion_expr::ColumnarValue;
-
-    use super::*;
-    use crate::function::FunctionContext;
-    use crate::function_factory::ScalarFunctionFactory;
-
-    #[test]
-    fn test_remove_region_follower_misc() {
-        let factory: ScalarFunctionFactory = RemoveRegionFollowerFunction::factory().into();
-        let f = factory.provide(FunctionContext::mock());
-        assert_eq!("remove_region_follower", f.name());
-        assert_eq!(DataType::UInt64, f.return_type(&[]).unwrap());
-        assert!(matches!(f.signature(),
-            datafusion_expr::Signature {
-                type_signature: datafusion_expr::TypeSignature::OneOf(sigs),
-                volatility: datafusion_expr::Volatility::Immutable
-            } if sigs.len() == 1));
-    }
-
-    #[tokio::test]
-    async fn test_remove_region_follower() {
-        let factory: ScalarFunctionFactory = RemoveRegionFollowerFunction::factory().into();
-        let provider = factory.provide(FunctionContext::mock());
-        let f = provider.as_async().unwrap();
-
-        let func_args = datafusion::logical_expr::ScalarFunctionArgs {
-            args: vec![
-                ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
-                ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
-            ],
-            arg_fields: vec![
-                Arc::new(Field::new("arg_0", DataType::UInt64, false)),
-                Arc::new(Field::new("arg_1", DataType::UInt64, false)),
-            ],
-            return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
-            number_rows: 1,
-            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
-        };
-
-        let result = f.invoke_async_with_args(func_args).await.unwrap();
-
-        match result {
-            ColumnarValue::Array(array) => {
-                let result_array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
-                assert_eq!(result_array.value(0), 0u64);
-            }
-            ColumnarValue::Scalar(scalar) => {
-                assert_eq!(scalar, datafusion_common::ScalarValue::UInt64(Some(0)));
-            }
-        }
-    }
-}
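Both deleted admin functions are superseded by a single `manage_region_follower` entry point on the procedure-service handler (see the handler diff further below). As a purely illustrative sketch, with the caveat that none of these field names are confirmed by the diffs (only the `ManageRegionFollowerRequest` type name is), the consolidated request plausibly looks like:

```rust
/// Hypothetical shape, for illustration only: one request type covering
/// both "add follower" and "remove follower" admin operations.
pub struct ManageRegionFollowerRequest {
    pub region_id: u64,
    pub peer_id: u64,
    pub action: FollowerAction,
}

pub enum FollowerAction {
    Add,
    Remove,
}
```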
@@ -15,11 +15,15 @@
 use std::fmt;
 use std::sync::Arc;

-use common_query::error::Result;
+use common_error::ext::{BoxedError, PlainError};
+use common_error::status_code::StatusCode;
+use common_query::error::{ExecuteSnafu, Result};
 use datafusion::arrow::datatypes::DataType;
-use datafusion_expr::Signature;
+use datafusion::logical_expr::ColumnarValue;
+use datafusion_expr::{ScalarFunctionArgs, Signature};
 use datatypes::vectors::VectorRef;
 use session::context::{QueryContextBuilder, QueryContextRef};
+use snafu::ResultExt;

 use crate::state::FunctionState;

@@ -68,8 +72,26 @@ pub trait Function: fmt::Display + Sync + Send {
     /// The signature of function.
     fn signature(&self) -> Signature;

+    fn invoke_with_args(
+        &self,
+        args: ScalarFunctionArgs,
+    ) -> datafusion_common::Result<ColumnarValue> {
+        // TODO(LFC): Remove default implementation once all UDFs have implemented this function.
+        let _ = args;
+        Err(datafusion_common::DataFusionError::NotImplemented(
+            "invoke_with_args".to_string(),
+        ))
+    }
+
     /// Evaluate the function, e.g. run/execute the function.
-    fn eval(&self, ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef>;
+    /// TODO(LFC): Remove `eval` when all UDFs are rewritten to `invoke_with_args`
+    fn eval(&self, _: &FunctionContext, _: &[VectorRef]) -> Result<VectorRef> {
+        Err(BoxedError::new(PlainError::new(
+            "unsupported".to_string(),
+            StatusCode::Unsupported,
+        )))
+        .context(ExecuteSnafu)
+    }

     fn aliases(&self) -> &[String] {
         &[]
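As a minimal sketch of the new calling convention (following the pattern the ported UDFs below use; `my_add` is a made-up function name), an `invoke_with_args` body that adds its two arguments:

```rust
use datafusion::arrow::array::ArrayRef;
use datafusion::arrow::compute::kernels::numeric;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::utils::take_function_args;
use datafusion_expr::ScalarFunctionArgs;

/// Sketch of an `invoke_with_args` body: normalize scalars to arrays,
/// destructure the expected arity, run an Arrow kernel, and return a
/// `ColumnarValue`.
fn add_impl(args: ScalarFunctionArgs) -> datafusion_common::Result<ColumnarValue> {
    let arrays: Vec<ArrayRef> = ColumnarValue::values_to_arrays(&args.args)?;
    let [left, right] = take_function_args("my_add", arrays)?;
    let sum = numeric::add(&left, &right)?;
    Ok(ColumnarValue::Array(sum))
}
```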
@@ -19,8 +19,7 @@ use async_trait::async_trait;
 use catalog::CatalogManagerRef;
 use common_base::AffectedRows;
 use common_meta::rpc::procedure::{
-    AddRegionFollowerRequest, MigrateRegionRequest, ProcedureStateResponse,
-    RemoveRegionFollowerRequest,
+    ManageRegionFollowerRequest, MigrateRegionRequest, ProcedureStateResponse,
 };
 use common_query::Output;
 use common_query::error::Result;
@@ -72,11 +71,8 @@ pub trait ProcedureServiceHandler: Send + Sync {
     /// Query the procedure' state by its id
     async fn query_procedure_state(&self, pid: &str) -> Result<ProcedureStateResponse>;

-    /// Add a region follower to a region.
-    async fn add_region_follower(&self, request: AddRegionFollowerRequest) -> Result<()>;
-
-    /// Remove a region follower from a region.
-    async fn remove_region_follower(&self, request: RemoveRegionFollowerRequest) -> Result<()>;
+    /// Manage a region follower to a region.
+    async fn manage_region_follower(&self, request: ManageRegionFollowerRequest) -> Result<()>;

     /// Get the catalog manager
     fn catalog_manager(&self) -> &CatalogManagerRef;
@@ -14,14 +14,15 @@
 use std::fmt;

-use common_query::error::{ArrowComputeSnafu, IntoVectorSnafu, InvalidFuncArgsSnafu, Result};
-use datafusion_expr::Signature;
+use common_query::error::{ArrowComputeSnafu, Result};
+use datafusion::logical_expr::ColumnarValue;
+use datafusion_common::utils;
+use datafusion_expr::{ScalarFunctionArgs, Signature};
 use datatypes::arrow::compute::kernels::numeric;
 use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
-use datatypes::vectors::{Helper, VectorRef};
-use snafu::{ResultExt, ensure};
+use snafu::ResultExt;

-use crate::function::{Function, FunctionContext};
+use crate::function::Function;
 use crate::helper;

 /// A function adds an interval value to Timestamp, Date, and return the result.
@@ -58,25 +59,15 @@ impl Function for DateAddFunction {
         )
     }

-    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
-        ensure!(
-            columns.len() == 2,
-            InvalidFuncArgsSnafu {
-                err_msg: format!(
-                    "The length of the args is not correct, expect 2, have: {}",
-                    columns.len()
-                ),
-            }
-        );
-
-        let left = columns[0].to_arrow_array();
-        let right = columns[1].to_arrow_array();
+    fn invoke_with_args(
+        &self,
+        args: ScalarFunctionArgs,
+    ) -> datafusion_common::Result<ColumnarValue> {
+        let args = ColumnarValue::values_to_arrays(&args.args)?;
+        let [left, right] = utils::take_function_args(self.name(), args)?;

         let result = numeric::add(&left, &right).context(ArrowComputeSnafu)?;
-        let arrow_type = result.data_type().clone();
-        Helper::try_into_vector(result).context(IntoVectorSnafu {
-            data_type: arrow_type,
-        })
+        Ok(ColumnarValue::Array(result))
     }
 }

@@ -90,12 +81,14 @@ impl fmt::Display for DateAddFunction {
 mod tests {
     use std::sync::Arc;

-    use datafusion_expr::{TypeSignature, Volatility};
-    use datatypes::arrow::datatypes::IntervalDayTime;
-    use datatypes::value::Value;
-    use datatypes::vectors::{
-        DateVector, IntervalDayTimeVector, IntervalYearMonthVector, TimestampSecondVector,
+    use arrow_schema::Field;
+    use datafusion::arrow::array::{
+        Array, AsArray, Date32Array, IntervalDayTimeArray, IntervalYearMonthArray,
+        TimestampSecondArray,
     };
+    use datafusion::arrow::datatypes::{Date32Type, IntervalDayTime, TimestampSecondType};
+    use datafusion_common::config::ConfigOptions;
+    use datafusion_expr::{TypeSignature, Volatility};

     use super::{DateAddFunction, *};

@@ -142,25 +135,37 @@ mod tests {
         ];
         let results = [Some(124), None, Some(45), None];

-        let time_vector = TimestampSecondVector::from(times.clone());
-        let interval_vector = IntervalDayTimeVector::from_vec(intervals);
-        let args: Vec<VectorRef> = vec![Arc::new(time_vector), Arc::new(interval_vector)];
-        let vector = f.eval(&FunctionContext::default(), &args).unwrap();
+        let args = vec![
+            ColumnarValue::Array(Arc::new(TimestampSecondArray::from(times.clone()))),
+            ColumnarValue::Array(Arc::new(IntervalDayTimeArray::from(intervals))),
+        ];
+
+        let vector = f
+            .invoke_with_args(ScalarFunctionArgs {
+                args,
+                arg_fields: vec![],
+                number_rows: 4,
+                return_field: Arc::new(Field::new(
+                    "x",
+                    DataType::Timestamp(TimeUnit::Second, None),
+                    true,
+                )),
+                config_options: Arc::new(ConfigOptions::new()),
+            })
+            .and_then(|v| ColumnarValue::values_to_arrays(&[v]))
+            .map(|mut a| a.remove(0))
+            .unwrap();
+        let vector = vector.as_primitive::<TimestampSecondType>();

         assert_eq!(4, vector.len());
         for (i, _t) in times.iter().enumerate() {
-            let v = vector.get(i);
             let result = results.get(i).unwrap();

-            if result.is_none() {
-                assert_eq!(Value::Null, v);
-                continue;
-            }
-            match v {
-                Value::Timestamp(ts) => {
-                    assert_eq!(ts.value(), result.unwrap());
-                }
-                _ => unreachable!(),
+            if let Some(x) = result {
+                assert!(vector.is_valid(i));
+                assert_eq!(vector.value(i), *x);
+            } else {
+                assert!(vector.is_null(i));
             }
         }
     }
@@ -174,25 +179,37 @@ mod tests {
         let intervals = vec![1, 2, 3, 1];
         let results = [Some(154), None, Some(131), None];

-        let date_vector = DateVector::from(dates.clone());
-        let interval_vector = IntervalYearMonthVector::from_vec(intervals);
-        let args: Vec<VectorRef> = vec![Arc::new(date_vector), Arc::new(interval_vector)];
-        let vector = f.eval(&FunctionContext::default(), &args).unwrap();
+        let args = vec![
+            ColumnarValue::Array(Arc::new(Date32Array::from(dates.clone()))),
+            ColumnarValue::Array(Arc::new(IntervalYearMonthArray::from(intervals))),
+        ];
+
+        let vector = f
+            .invoke_with_args(ScalarFunctionArgs {
+                args,
+                arg_fields: vec![],
+                number_rows: 4,
+                return_field: Arc::new(Field::new(
+                    "x",
+                    DataType::Timestamp(TimeUnit::Second, None),
+                    true,
+                )),
+                config_options: Arc::new(ConfigOptions::new()),
+            })
+            .and_then(|v| ColumnarValue::values_to_arrays(&[v]))
+            .map(|mut a| a.remove(0))
+            .unwrap();
+        let vector = vector.as_primitive::<Date32Type>();

         assert_eq!(4, vector.len());
         for (i, _t) in dates.iter().enumerate() {
-            let v = vector.get(i);
             let result = results.get(i).unwrap();

-            if result.is_none() {
-                assert_eq!(Value::Null, v);
-                continue;
-            }
-            match v {
-                Value::Date(date) => {
-                    assert_eq!(date.val(), result.unwrap());
-                }
-                _ => unreachable!(),
+            if let Some(x) = result {
+                assert!(vector.is_valid(i));
+                assert_eq!(vector.value(i), *x);
+            } else {
+                assert!(vector.is_null(i));
             }
         }
     }
@@ -14,14 +14,15 @@
 use std::fmt;

-use common_query::error::{ArrowComputeSnafu, IntoVectorSnafu, InvalidFuncArgsSnafu, Result};
-use datafusion_expr::Signature;
+use common_query::error::{ArrowComputeSnafu, Result};
+use datafusion::logical_expr::ColumnarValue;
+use datafusion_common::utils;
+use datafusion_expr::{ScalarFunctionArgs, Signature};
 use datatypes::arrow::compute::kernels::numeric;
 use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
-use datatypes::vectors::{Helper, VectorRef};
-use snafu::{ResultExt, ensure};
+use snafu::ResultExt;

-use crate::function::{Function, FunctionContext};
+use crate::function::Function;
 use crate::helper;

 /// A function subtracts an interval value to Timestamp, Date, and return the result.
@@ -58,25 +59,15 @@ impl Function for DateSubFunction {
         )
     }

-    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
-        ensure!(
-            columns.len() == 2,
-            InvalidFuncArgsSnafu {
-                err_msg: format!(
-                    "The length of the args is not correct, expect 2, have: {}",
-                    columns.len()
-                ),
-            }
-        );
-
-        let left = columns[0].to_arrow_array();
-        let right = columns[1].to_arrow_array();
+    fn invoke_with_args(
+        &self,
+        args: ScalarFunctionArgs,
+    ) -> datafusion_common::Result<ColumnarValue> {
+        let args = ColumnarValue::values_to_arrays(&args.args)?;
+        let [left, right] = utils::take_function_args(self.name(), args)?;

         let result = numeric::sub(&left, &right).context(ArrowComputeSnafu)?;
-        let arrow_type = result.data_type().clone();
-        Helper::try_into_vector(result).context(IntoVectorSnafu {
-            data_type: arrow_type,
-        })
+        Ok(ColumnarValue::Array(result))
     }
 }

@@ -90,12 +81,14 @@ impl fmt::Display for DateSubFunction {
 mod tests {
     use std::sync::Arc;

-    use datafusion_expr::{TypeSignature, Volatility};
-    use datatypes::arrow::datatypes::IntervalDayTime;
-    use datatypes::value::Value;
-    use datatypes::vectors::{
-        DateVector, IntervalDayTimeVector, IntervalYearMonthVector, TimestampSecondVector,
+    use arrow_schema::Field;
+    use datafusion::arrow::array::{
+        Array, AsArray, Date32Array, IntervalDayTimeArray, IntervalYearMonthArray,
+        TimestampSecondArray,
     };
+    use datafusion::arrow::datatypes::{Date32Type, IntervalDayTime, TimestampSecondType};
+    use datafusion_common::config::ConfigOptions;
+    use datafusion_expr::{TypeSignature, Volatility};

     use super::{DateSubFunction, *};

@@ -142,25 +135,37 @@ mod tests {
         ];
         let results = [Some(122), None, Some(39), None];

-        let time_vector = TimestampSecondVector::from(times.clone());
-        let interval_vector = IntervalDayTimeVector::from_vec(intervals);
-        let args: Vec<VectorRef> = vec![Arc::new(time_vector), Arc::new(interval_vector)];
-        let vector = f.eval(&FunctionContext::default(), &args).unwrap();
+        let args = vec![
+            ColumnarValue::Array(Arc::new(TimestampSecondArray::from(times.clone()))),
+            ColumnarValue::Array(Arc::new(IntervalDayTimeArray::from(intervals))),
+        ];
+
+        let vector = f
+            .invoke_with_args(ScalarFunctionArgs {
+                args,
+                arg_fields: vec![],
+                number_rows: 4,
+                return_field: Arc::new(Field::new(
+                    "x",
+                    DataType::Timestamp(TimeUnit::Second, None),
+                    true,
+                )),
+                config_options: Arc::new(ConfigOptions::new()),
+            })
+            .and_then(|v| ColumnarValue::values_to_arrays(&[v]))
+            .map(|mut a| a.remove(0))
+            .unwrap();
+        let vector = vector.as_primitive::<TimestampSecondType>();

         assert_eq!(4, vector.len());
         for (i, _t) in times.iter().enumerate() {
-            let v = vector.get(i);
             let result = results.get(i).unwrap();

-            if result.is_none() {
-                assert_eq!(Value::Null, v);
-                continue;
-            }
-            match v {
-                Value::Timestamp(ts) => {
-                    assert_eq!(ts.value(), result.unwrap());
-                }
-                _ => unreachable!(),
+            if let Some(x) = result {
+                assert!(vector.is_valid(i));
+                assert_eq!(vector.value(i), *x);
+            } else {
+                assert!(vector.is_null(i));
             }
         }
     }
@@ -180,25 +185,37 @@ mod tests {
         let intervals = vec![1, 2, 3, 1];
         let results = [Some(3659), None, Some(1168), None];

-        let date_vector = DateVector::from(dates.clone());
-        let interval_vector = IntervalYearMonthVector::from_vec(intervals);
-        let args: Vec<VectorRef> = vec![Arc::new(date_vector), Arc::new(interval_vector)];
-        let vector = f.eval(&FunctionContext::default(), &args).unwrap();
+        let args = vec![
+            ColumnarValue::Array(Arc::new(Date32Array::from(dates.clone()))),
+            ColumnarValue::Array(Arc::new(IntervalYearMonthArray::from(intervals))),
+        ];
+
+        let vector = f
+            .invoke_with_args(ScalarFunctionArgs {
+                args,
+                arg_fields: vec![],
+                number_rows: 4,
+                return_field: Arc::new(Field::new(
+                    "x",
+                    DataType::Timestamp(TimeUnit::Second, None),
+                    true,
+                )),
+                config_options: Arc::new(ConfigOptions::new()),
+            })
+            .and_then(|v| ColumnarValue::values_to_arrays(&[v]))
+            .map(|mut a| a.remove(0))
+            .unwrap();
+        let vector = vector.as_primitive::<Date32Type>();

         assert_eq!(4, vector.len());
         for (i, _t) in dates.iter().enumerate() {
-            let v = vector.get(i);
             let result = results.get(i).unwrap();

-            if result.is_none() {
-                assert_eq!(Value::Null, v);
-                continue;
-            }
-            match v {
-                Value::Date(date) => {
-                    assert_eq!(date.val(), result.unwrap());
-                }
-                _ => unreachable!(),
+            if let Some(x) = result {
+                assert!(vector.is_valid(i));
+                assert_eq!(vector.value(i), *x);
+            } else {
+                assert!(vector.is_null(i));
             }
         }
     }
@@ -17,62 +17,26 @@ use std::sync::Arc;

 use common_error::ext::{BoxedError, PlainError};
 use common_error::status_code::StatusCode;
-use common_query::error::{self, InvalidFuncArgsSnafu, Result};
-use datafusion::arrow::datatypes::Field;
+use common_query::error::{self, Result};
+use datafusion::arrow::array::{Array, AsArray, ListBuilder, StringViewBuilder};
+use datafusion::arrow::datatypes::{DataType, Field, Float64Type, UInt8Type};
 use datafusion::logical_expr::ColumnarValue;
+use datafusion_common::{DataFusionError, utils};
 use datafusion_expr::type_coercion::aggregates::INTEGERS;
-use datafusion_expr::{Signature, TypeSignature, Volatility};
-use datatypes::arrow::datatypes::DataType;
-use datatypes::prelude::ConcreteDataType;
-use datatypes::scalars::{Scalar, ScalarVectorBuilder};
-use datatypes::value::{ListValue, Value};
-use datatypes::vectors::{ListVectorBuilder, MutableVector, StringVectorBuilder, VectorRef};
+use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
 use geohash::Coord;
-use snafu::{ResultExt, ensure};
+use snafu::ResultExt;

-use crate::function::{Function, FunctionContext};
+use crate::function::Function;
 use crate::scalars::geo::helpers;

-macro_rules! ensure_resolution_usize {
-    ($v: ident) => {
-        if !($v > 0 && $v <= 12) {
-            Err(BoxedError::new(PlainError::new(
-                format!("Invalid geohash resolution {}, expect value: [1, 12]", $v),
-                StatusCode::EngineExecuteQuery,
-            )))
-            .context(error::ExecuteSnafu)
-        } else {
-            Ok($v as usize)
-        }
-    };
-}
-
-fn try_into_resolution(v: Value) -> Result<usize> {
-    match v {
-        Value::Int8(v) => {
-            ensure_resolution_usize!(v)
-        }
-        Value::Int16(v) => {
-            ensure_resolution_usize!(v)
-        }
-        Value::Int32(v) => {
-            ensure_resolution_usize!(v)
-        }
-        Value::Int64(v) => {
-            ensure_resolution_usize!(v)
-        }
-        Value::UInt8(v) => {
-            ensure_resolution_usize!(v)
-        }
-        Value::UInt16(v) => {
-            ensure_resolution_usize!(v)
-        }
-        Value::UInt32(v) => {
-            ensure_resolution_usize!(v)
-        }
-        Value::UInt64(v) => {
-            ensure_resolution_usize!(v)
-        }
-        _ => unreachable!(),
-    }
-}
+fn ensure_resolution_usize(v: u8) -> datafusion_common::Result<usize> {
+    if v == 0 || v > 12 {
+        return Err(DataFusionError::Execution(format!(
+            "Invalid geohash resolution {v}, valid value range: [1, 12]"
+        )));
+    }
+    Ok(v as usize)
+}

 /// Function that return geohash string for a given geospatial coordinate.
@@ -109,31 +73,33 @@ impl Function for GeohashFunction {
         Signature::one_of(signatures, Volatility::Stable)
     }

-    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
-        ensure!(
-            columns.len() == 3,
-            InvalidFuncArgsSnafu {
-                err_msg: format!(
-                    "The length of the args is not correct, expect 3, provided : {}",
-                    columns.len()
-                ),
-            }
-        );
+    fn invoke_with_args(
+        &self,
+        args: ScalarFunctionArgs,
+    ) -> datafusion_common::Result<ColumnarValue> {
+        let args = ColumnarValue::values_to_arrays(&args.args)?;
+        let [lat_vec, lon_vec, resolutions] = utils::take_function_args(self.name(), args)?;

-        let lat_vec = &columns[0];
-        let lon_vec = &columns[1];
-        let resolution_vec = &columns[2];
+        let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
+        let lat_vec = lat_vec.as_primitive::<Float64Type>();
+        let lon_vec = helpers::cast::<Float64Type>(&lon_vec)?;
+        let lon_vec = lon_vec.as_primitive::<Float64Type>();
+        let resolutions = helpers::cast::<UInt8Type>(&resolutions)?;
+        let resolutions = resolutions.as_primitive::<UInt8Type>();

         let size = lat_vec.len();
-        let mut results = StringVectorBuilder::with_capacity(size);
+        let mut builder = StringViewBuilder::with_capacity(size);

         for i in 0..size {
-            let lat = lat_vec.get(i).as_f64_lossy();
-            let lon = lon_vec.get(i).as_f64_lossy();
-            let r = try_into_resolution(resolution_vec.get(i))?;
+            let lat = lat_vec.is_valid(i).then(|| lat_vec.value(i));
+            let lon = lon_vec.is_valid(i).then(|| lon_vec.value(i));
+            let r = resolutions
+                .is_valid(i)
+                .then(|| ensure_resolution_usize(resolutions.value(i)))
+                .transpose()?;

-            let result = match (lat, lon) {
-                (Some(lat), Some(lon)) => {
+            let result = match (lat, lon, r) {
+                (Some(lat), Some(lon), Some(r)) => {
                     let coord = Coord { x: lon, y: lat };
                     let encoded = geohash::encode(coord, r)
                         .map_err(|e| {
@@ -148,10 +114,10 @@ impl Function for GeohashFunction {
                 _ => None,
             };

-            results.push(result.as_deref());
+            builder.append_option(result);
         }

-        Ok(results.to_vector())
+        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
     }
 }

@@ -176,8 +142,8 @@ impl Function for GeohashNeighboursFunction {

     fn return_type(&self, _: &[DataType]) -> Result<DataType> {
         Ok(DataType::List(Arc::new(Field::new(
-            "x",
-            DataType::Utf8,
+            "item",
+            DataType::Utf8View,
             false,
         ))))
     }
@@ -199,32 +165,33 @@ impl Function for GeohashNeighboursFunction {
         Signature::one_of(signatures, Volatility::Stable)
     }

-    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
-        ensure!(
-            columns.len() == 3,
-            InvalidFuncArgsSnafu {
-                err_msg: format!(
-                    "The length of the args is not correct, expect 3, provided : {}",
-                    columns.len()
-                ),
-            }
-        );
+    fn invoke_with_args(
+        &self,
+        args: ScalarFunctionArgs,
+    ) -> datafusion_common::Result<ColumnarValue> {
+        let args = ColumnarValue::values_to_arrays(&args.args)?;
+        let [lat_vec, lon_vec, resolutions] = utils::take_function_args(self.name(), args)?;

-        let lat_vec = &columns[0];
-        let lon_vec = &columns[1];
-        let resolution_vec = &columns[2];
+        let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
+        let lat_vec = lat_vec.as_primitive::<Float64Type>();
+        let lon_vec = helpers::cast::<Float64Type>(&lon_vec)?;
+        let lon_vec = lon_vec.as_primitive::<Float64Type>();
+        let resolutions = helpers::cast::<UInt8Type>(&resolutions)?;
+        let resolutions = resolutions.as_primitive::<UInt8Type>();

         let size = lat_vec.len();
-        let mut results =
-            ListVectorBuilder::with_type_capacity(ConcreteDataType::string_datatype(), size);
+        let mut builder = ListBuilder::new(StringViewBuilder::new());

         for i in 0..size {
-            let lat = lat_vec.get(i).as_f64_lossy();
-            let lon = lon_vec.get(i).as_f64_lossy();
-            let r = try_into_resolution(resolution_vec.get(i))?;
+            let lat = lat_vec.is_valid(i).then(|| lat_vec.value(i));
+            let lon = lon_vec.is_valid(i).then(|| lon_vec.value(i));
+            let r = resolutions
+                .is_valid(i)
+                .then(|| ensure_resolution_usize(resolutions.value(i)))
+                .transpose()?;

-            let result = match (lat, lon) {
-                (Some(lat), Some(lon)) => {
+            match (lat, lon, r) {
+                (Some(lat), Some(lon), Some(r)) => {
                     let coord = Coord { x: lon, y: lat };
                     let encoded = geohash::encode(coord, r)
                         .map_err(|e| {
@@ -242,8 +209,8 @@ impl Function for GeohashNeighboursFunction {
                         ))
                     })
                     .context(error::ExecuteSnafu)?;
-                    Some(ListValue::new(
-                        vec![
+                    builder.append_value(
+                        [
                             neighbours.n,
                             neighbours.nw,
                             neighbours.w,
@@ -254,22 +221,14 @@ impl Function for GeohashNeighboursFunction {
                             neighbours.ne,
                         ]
                         .into_iter()
-                        .map(Value::from)
-                        .collect(),
-                        ConcreteDataType::string_datatype(),
-                    ))
+                        .map(Some),
+                    );
                 }
-                _ => None,
+                _ => builder.append_null(),
             };

-            if let Some(list_value) = result {
-                results.push(Some(list_value.as_scalar_ref()));
-            } else {
-                results.push(None);
-            }
         }

-        Ok(results.to_vector())
+        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
     }
 }
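Both geohash UDFs reduce to the `geohash` crate's `encode` call per row; a standalone sketch with arbitrary sample coordinates:

```rust
use geohash::Coord;

// Encode an (x = longitude, y = latitude) coordinate at resolution 9,
// the same call the UDFs above make per row after null and range checks.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let hash = geohash::encode(Coord { x: -120.6623, y: 35.3003 }, 9)?;
    println!("{hash}"); // a 9-character geohash string
    Ok(())
}
```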
(File diff suppressed because it is too large.)
@@ -12,6 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+use datafusion::arrow::array::{ArrayRef, ArrowPrimitiveType};
+use datafusion::arrow::compute;
+
 macro_rules! ensure_columns_len {
     ($columns:ident) => {
         snafu::ensure!(
@@ -73,3 +76,15 @@ macro_rules! ensure_and_coerce {
 }

 pub(crate) use ensure_and_coerce;
+
+pub(crate) fn cast<T: ArrowPrimitiveType>(array: &ArrayRef) -> datafusion_common::Result<ArrayRef> {
+    let x = compute::cast_with_options(
+        array.as_ref(),
+        &T::DATA_TYPE,
+        &compute::CastOptions {
+            safe: false,
+            ..Default::default()
+        },
+    )?;
+    Ok(x)
+}
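A quick, illustrative use of the new `cast` helper above (assuming it is in scope): with `safe: false`, an out-of-range cast surfaces as an error instead of a silent null, which is why the geohash UDFs can trust every valid slot after casting.

```rust
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, AsArray, Int64Array};
use datafusion::arrow::datatypes::UInt8Type;

fn main() {
    // In-range values cast cleanly.
    let ok: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 12]));
    let cast_ok = cast::<UInt8Type>(&ok).unwrap(); // helper defined above
    assert_eq!(cast_ok.as_primitive::<UInt8Type>().value(2), 12);

    // 300 does not fit in u8; with `safe: false` this is an error, not null.
    let bad: ArrayRef = Arc::new(Int64Array::from(vec![300]));
    assert!(cast::<UInt8Type>(&bad).is_err());
}
```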
@@ -16,23 +16,20 @@ use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::error::{
|
||||
GeneralDataFusionSnafu, IntoVectorSnafu, InvalidFuncArgsSnafu, InvalidInputTypeSnafu, Result,
|
||||
};
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion::arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
|
||||
use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion};
|
||||
use datafusion::common::{DFSchema, Result as DfResult};
|
||||
use datafusion::execution::SessionStateBuilder;
|
||||
use datafusion::logical_expr::{self, Expr, Volatility};
|
||||
use datafusion::logical_expr::{self, ColumnarValue, Expr, Volatility};
|
||||
use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
|
||||
use datafusion_expr::Signature;
|
||||
use datafusion_common::{DataFusionError, utils};
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
use datatypes::arrow::array::RecordBatch;
|
||||
use datatypes::arrow::datatypes::{DataType, Field};
|
||||
use datatypes::prelude::VectorRef;
|
||||
use datatypes::vectors::BooleanVector;
|
||||
use snafu::{OptionExt, ResultExt, ensure};
|
||||
use store_api::storage::ConcreteDataType;
|
||||
use snafu::{OptionExt, ensure};
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
use crate::function::Function;
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
/// `matches` for full text search.
|
||||
@@ -67,38 +64,36 @@ impl Function for MatchesFunction {
|
||||
}
|
||||
|
||||
// TODO: read case-sensitive config
|
||||
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
ensure!(
|
||||
columns.len() == 2,
|
||||
InvalidFuncArgsSnafu {
|
||||
err_msg: format!(
|
||||
"The length of the args is not correct, expect exactly 2, have: {}",
|
||||
columns.len()
|
||||
),
|
||||
}
|
||||
);
|
||||
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DfResult<ColumnarValue> {
|
||||
let args = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let [data_column, patterns] = utils::take_function_args(self.name(), args)?;
|
||||
|
||||
let data_column = &columns[0];
|
||||
if data_column.is_empty() {
|
||||
return Ok(Arc::new(BooleanVector::from(Vec::<bool>::with_capacity(0))));
|
||||
return Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(
|
||||
Vec::<bool>::with_capacity(0),
|
||||
))));
|
||||
}
|
||||
|
||||
let pattern_vector = &columns[1]
|
||||
.cast(&ConcreteDataType::string_datatype())
|
||||
.context(InvalidInputTypeSnafu {
|
||||
err_msg: "cannot cast `pattern` to string",
|
||||
})?;
|
||||
// Safety: both length and type are checked before
|
||||
let pattern = pattern_vector.get(0).as_string().unwrap();
|
||||
let pattern = match patterns.data_type() {
|
||||
DataType::Utf8View => patterns.as_string_view().value(0),
|
||||
DataType::Utf8 => patterns.as_string::<i32>().value(0),
|
||||
DataType::LargeUtf8 => patterns.as_string::<i64>().value(0),
|
||||
t => {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"unsupported datatype {t}"
|
||||
)));
|
||||
}
|
||||
};
        self.eval(data_column, pattern)
    }
}

impl MatchesFunction {
    fn eval(&self, data: &VectorRef, pattern: String) -> Result<VectorRef> {
    fn eval(&self, data_array: ArrayRef, pattern: &str) -> DfResult<ColumnarValue> {
        let col_name = "data";
        let parser_context = ParserContext::default();
        let raw_ast = parser_context.parse_pattern(&pattern)?;
        let raw_ast = parser_context.parse_pattern(pattern)?;
        let ast = raw_ast.transform_ast()?;

        let like_expr = ast.into_like_expr(col_name);
@@ -106,27 +101,17 @@ impl MatchesFunction {
        let input_schema = Self::input_schema();
        let session_state = SessionStateBuilder::new().with_default_features().build();
        let planner = DefaultPhysicalPlanner::default();
        let physical_expr = planner
            .create_physical_expr(&like_expr, &input_schema, &session_state)
            .context(GeneralDataFusionSnafu)?;
        let physical_expr =
            planner.create_physical_expr(&like_expr, &input_schema, &session_state)?;

        let data_array = data.to_arrow_array();
        let arrow_schema = Arc::new(input_schema.as_arrow().clone());
        let input_record_batch = RecordBatch::try_new(arrow_schema, vec![data_array]).unwrap();

        let num_rows = input_record_batch.num_rows();
        let result = physical_expr
            .evaluate(&input_record_batch)
            .context(GeneralDataFusionSnafu)?;
        let result_array = result
            .into_array(num_rows)
            .context(GeneralDataFusionSnafu)?;
        let result_vector =
            BooleanVector::try_from_arrow_array(result_array).context(IntoVectorSnafu {
                data_type: DataType::Boolean,
            })?;
        let result = physical_expr.evaluate(&input_record_batch)?;
        let result_array = result.into_array(num_rows)?;

        Ok(Arc::new(result_vector))
        Ok(ColumnarValue::Array(Arc::new(result_array)))
    }

    fn input_schema() -> DFSchema {
@@ -210,14 +195,12 @@ impl PatternAst {
    /// Transform this AST with preset rules to make it correct.
    fn transform_ast(self) -> Result<Self> {
        self.transform_up(Self::collapse_binary_branch_fn)
            .context(GeneralDataFusionSnafu)
            .map(|data| data.data)?
            .transform_up(Self::eliminate_optional_fn)
            .context(GeneralDataFusionSnafu)
            .map(|data| data.data)?
            .transform_down(Self::eliminate_single_child_fn)
            .context(GeneralDataFusionSnafu)
            .map(|data| data.data)
            .map_err(Into::into)
    }
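    // The rewrite rules run in sequence: binary branches are collapsed and
    // optional nodes eliminated bottom-up, then single-child nodes are
    // removed top-down.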

    /// Collapse binary branch with the same operator. I.e., this transformer
@@ -842,7 +825,9 @@ impl Tokenizer {

#[cfg(test)]
mod test {
    use datatypes::vectors::StringVector;
    use datafusion::arrow::array::StringArray;
    use datafusion_common::ScalarValue;
    use datafusion_common::config::ConfigOptions;

    use super::*;

@@ -1309,7 +1294,7 @@ mod test {
            "The quick brown fox jumps over dog",
            "The quick brown fox jumps over the dog",
        ];
        let input_vector: VectorRef = Arc::new(StringVector::from(input_data));
        let col: ArrayRef = Arc::new(StringArray::from(input_data));
        let cases = [
            // basic cases
            ("quick", vec![true, false, true, true, true, true, true]),
@@ -1400,9 +1385,22 @@ mod test {

        let f = MatchesFunction;
        for (pattern, expected) in cases {
            let actual: VectorRef = f.eval(&input_vector, pattern.to_string()).unwrap();
            let expected: VectorRef = Arc::new(BooleanVector::from(expected)) as _;
            assert_eq!(expected, actual, "{pattern}");
            let args = ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(col.clone()),
                    ColumnarValue::Scalar(ScalarValue::Utf8View(Some(pattern.to_string()))),
                ],
                arg_fields: vec![],
                number_rows: col.len(),
                return_field: Arc::new(Field::new("x", col.data_type().clone(), true)),
                config_options: Arc::new(ConfigOptions::new()),
            };
            let actual = f
                .invoke_with_args(args)
                .and_then(|x| x.to_array(col.len()))
                .unwrap();
            let expected: ArrayRef = Arc::new(BooleanArray::from(expected));
            assert_eq!(expected.as_ref(), actual.as_ref(), "{pattern}");
        }
    }
}

@@ -19,15 +19,13 @@ mod rate;
use std::fmt;

pub use clamp::{ClampFunction, ClampMaxFunction, ClampMinFunction};
use common_query::error::{GeneralDataFusionSnafu, Result};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::error::DataFusionError;
use datafusion_expr::{Signature, Volatility};
use datatypes::vectors::VectorRef;
pub use rate::RateFunction;
use snafu::ResultExt;

use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::function_registry::FunctionRegistry;
use crate::scalars::math::modulo::ModuloFunction;

@@ -68,7 +66,7 @@ impl Function for RangeFunction {
            .ok_or(DataFusionError::Internal(
                "No expr found in range_fn".into(),
            ))
            .context(GeneralDataFusionSnafu)
            .map_err(Into::into)
    }

    /// `range_fn` will never be used. As long as a legal signature is returned, the specific content of the signature does not matter.
@@ -76,11 +74,4 @@ impl Function for RangeFunction {
    fn signature(&self) -> Signature {
        Signature::variadic_any(Volatility::Immutable)
    }

    fn eval(&self, _func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
        Err(DataFusionError::Internal(
            "range_fn is just an empty function used in range select; it should never be evaluated!".into(),
        ))
        .context(GeneralDataFusionSnafu)
    }
}

@@ -15,54 +15,21 @@
use std::fmt::{self, Display};
use std::sync::Arc;

use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion::arrow::array::{ArrayIter, PrimitiveArray};
use common_query::error::Result;
use datafusion::arrow::array::{Array, ArrayRef, AsArray, PrimitiveArray};
use datafusion::arrow::datatypes::DataType as ArrowDataType;
use datafusion::logical_expr::Volatility;
use datafusion_expr::Signature;
use datafusion::logical_expr::{ColumnarValue, Volatility};
use datafusion_common::{DataFusionError, ScalarValue, utils};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datatypes::data_type::DataType;
use datatypes::prelude::VectorRef;
use datatypes::types::LogicalPrimitiveType;
use datatypes::value::TryAsPrimitive;
use datatypes::vectors::PrimitiveVector;
use datatypes::with_match_primitive_type_id;
use snafu::{OptionExt, ensure};
use datafusion_expr::{ScalarFunctionArgs, Signature};

use crate::function::{Function, FunctionContext};
use crate::function::Function;

#[derive(Clone, Debug, Default)]
pub struct ClampFunction;

const CLAMP_NAME: &str = "clamp";

/// Ensure the vector is constant and not empty (i.e., all values are identical)
fn ensure_constant_vector(vector: &VectorRef) -> Result<()> {
    ensure!(
        !vector.is_empty(),
        InvalidFuncArgsSnafu {
            err_msg: "Expect at least one value",
        }
    );

    if vector.is_const() {
        return Ok(());
    }

    let first = vector.get_ref(0);
    for i in 1..vector.len() {
        let v = vector.get_ref(i);
        if first != v {
            return InvalidFuncArgsSnafu {
                err_msg: "All values in min/max argument must be identical",
            }
            .fail();
        }
    }

    Ok(())
}

impl Function for ClampFunction {
    fn name(&self) -> &str {
        CLAMP_NAME
@@ -78,76 +45,12 @@ impl Function for ClampFunction {
        Signature::uniform(3, NUMERICS.to_vec(), Volatility::Immutable)
    }

    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
        ensure!(
            columns.len() == 3,
            InvalidFuncArgsSnafu {
                err_msg: format!(
                    "The length of the args is not correct, expect exactly 3, have: {}",
                    columns.len()
                ),
            }
        );
        ensure!(
            columns[0].data_type().is_numeric(),
            InvalidFuncArgsSnafu {
                err_msg: format!(
                    "The first arg's type is not numeric, have: {}",
                    columns[0].data_type()
                ),
            }
        );
        ensure!(
            columns[0].data_type() == columns[1].data_type()
                && columns[1].data_type() == columns[2].data_type(),
            InvalidFuncArgsSnafu {
                err_msg: format!(
                    "Arguments don't have identical types: {}, {}, {}",
                    columns[0].data_type(),
                    columns[1].data_type(),
                    columns[2].data_type()
                ),
            }
        );

        ensure_constant_vector(&columns[1])?;
        ensure_constant_vector(&columns[2])?;

        with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
            let input_array = columns[0].to_arrow_array();
            let input = input_array
                .as_any()
                .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
                .unwrap();

            let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
                .with_context(|| {
                    InvalidFuncArgsSnafu {
                        err_msg: "The second arg should not be none",
                    }
                })?;
            let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[2].get(0))
                .with_context(|| {
                    InvalidFuncArgsSnafu {
                        err_msg: "The third arg should not be none",
                    }
                })?;

            // ensure min <= max
            ensure!(
                min <= max,
                InvalidFuncArgsSnafu {
                    err_msg: format!(
                        "The second arg should be less than or equal to the third arg, have: {:?}, {:?}",
                        columns[1], columns[2]
                    ),
                }
            );

            clamp_impl::<$S, true, true>(input, min, max)
        },{
            unreachable!()
        })
    fn invoke_with_args(
        &self,
        args: ScalarFunctionArgs,
    ) -> datafusion_common::Result<ColumnarValue> {
        let [col, min, max] = utils::take_function_args(self.name(), args.args)?;
        clamp_impl(col, min, max)
    }
}

@@ -157,25 +60,155 @@ impl Display for ClampFunction {
    }
}

fn clamp_impl<T: LogicalPrimitiveType, const CLAMP_MIN: bool, const CLAMP_MAX: bool>(
    input: &PrimitiveArray<T::ArrowPrimitive>,
    min: T::Native,
    max: T::Native,
) -> Result<VectorRef> {
    let iter = ArrayIter::new(input);
    let result = iter.map(|x| {
        x.map(|x| {
            if CLAMP_MIN && x < min {
                min
            } else if CLAMP_MAX && x > max {
                max
            } else {
                x
fn clamp_impl(
    col: ColumnarValue,
    min: ColumnarValue,
    max: ColumnarValue,
) -> datafusion_common::Result<ColumnarValue> {
    if col.data_type() != min.data_type() || min.data_type() != max.data_type() {
        return Err(DataFusionError::Execution(format!(
            "argument data types mismatch: {}, {}, {}",
            col.data_type(),
            min.data_type(),
            max.data_type(),
        )));
    }

    macro_rules! with_match_numerics_types {
        ($data_type:expr, | $_:tt $T:ident | $body:tt) => {{
            macro_rules! __with_ty__ {
                ( $_ $T:ident ) => {
                    $body
                };
            }
            })
        });
    let result = PrimitiveArray::<T::ArrowPrimitive>::from_iter(result);
    Ok(Arc::new(PrimitiveVector::<T>::from(result)))

            use datafusion::arrow::datatypes::{
                Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type,
                UInt16Type, UInt32Type, UInt64Type,
            };

            match $data_type {
                ArrowDataType::Int8 => Ok(__with_ty__! { Int8Type }),
                ArrowDataType::Int16 => Ok(__with_ty__! { Int16Type }),
                ArrowDataType::Int32 => Ok(__with_ty__! { Int32Type }),
                ArrowDataType::Int64 => Ok(__with_ty__! { Int64Type }),
                ArrowDataType::UInt8 => Ok(__with_ty__! { UInt8Type }),
                ArrowDataType::UInt16 => Ok(__with_ty__! { UInt16Type }),
                ArrowDataType::UInt32 => Ok(__with_ty__! { UInt32Type }),
                ArrowDataType::UInt64 => Ok(__with_ty__! { UInt64Type }),
                ArrowDataType::Float32 => Ok(__with_ty__! { Float32Type }),
                ArrowDataType::Float64 => Ok(__with_ty__! { Float64Type }),
                _ => Err(DataFusionError::Execution(format!(
                    "unsupported numeric data type: '{}'",
                    $data_type
                ))),
            }
        }};
    }

    macro_rules! clamp {
        ($v: ident, $min: ident, $max: ident) => {
            if $v < $min {
                $min
            } else if $v > $max {
                $max
            } else {
                $v
            }
        };
    }
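    // For example, `clamp!(v, min, max)` with v = 5, min = 2, max = 4 yields 4,
    // and with v = 1 yields 2; values already inside [min, max] pass through.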

    match (col, min, max) {
        (ColumnarValue::Scalar(col), ColumnarValue::Scalar(min), ColumnarValue::Scalar(max)) => {
            if min > max {
                return Err(DataFusionError::Execution(format!(
                    "min '{}' > max '{}'",
                    min, max
                )));
            }
            Ok(ColumnarValue::Scalar(clamp!(col, min, max)))
        }

        (ColumnarValue::Array(col), ColumnarValue::Array(min), ColumnarValue::Array(max)) => {
            if col.len() != min.len() || col.len() != max.len() {
                return Err(DataFusionError::Internal(
                    "arguments not of same length".to_string(),
                ));
            }
            let result = with_match_numerics_types!(
                col.data_type(),
                |$S| {
                    let col = col.as_primitive::<$S>();
                    let min = min.as_primitive::<$S>();
                    let max = max.as_primitive::<$S>();
                    Arc::new(PrimitiveArray::<$S>::from(
                        (0..col.len())
                            .map(|i| {
                                let v = col.is_valid(i).then(|| col.value(i));
                                // Index safety: checked above, all have same length.
                                let min = min.is_valid(i).then(|| min.value(i));
                                let max = max.is_valid(i).then(|| max.value(i));
                                Ok(match (v, min, max) {
                                    (Some(v), Some(min), Some(max)) => {
                                        if min > max {
                                            return Err(DataFusionError::Execution(format!(
                                                "min '{}' > max '{}'",
                                                min, max
                                            )));
                                        }
                                        Some(clamp!(v, min, max))
                                    },
                                    _ => None,
                                })
                            })
                            .collect::<datafusion_common::Result<Vec<_>>>()?,
                        )
                    ) as ArrayRef
                }
            )?;
            Ok(ColumnarValue::Array(result))
        }

        (ColumnarValue::Array(col), ColumnarValue::Scalar(min), ColumnarValue::Scalar(max)) => {
            if min.is_null() || max.is_null() {
                return Err(DataFusionError::Execution(
                    "argument 'min' or 'max' is null".to_string(),
                ));
            }
            let min = min.to_array()?;
            let max = max.to_array()?;
            let result = with_match_numerics_types!(
                col.data_type(),
                |$S| {
                    let col = col.as_primitive::<$S>();
                    // Index safety: checked above, both are not nulls.
                    let min = min.as_primitive::<$S>().value(0);
                    let max = max.as_primitive::<$S>().value(0);
                    if min > max {
                        return Err(DataFusionError::Execution(format!(
                            "min '{}' > max '{}'",
                            min, max
                        )));
                    }
                    Arc::new(PrimitiveArray::<$S>::from(
                        (0..col.len())
                            .map(|x| {
                                col.is_valid(x).then(|| {
                                    let v = col.value(x);
                                    clamp!(v, min, max)
                                })
                            })
                            .collect::<Vec<_>>(),
                        )
                    ) as ArrayRef
                }
            )?;
            Ok(ColumnarValue::Array(result))
        }
        _ => Err(DataFusionError::Internal(
            "argument column types mismatch".to_string(),
        )),
    }
}
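// `clamp_impl` accepts all-scalar, all-array, and array-with-scalar-bounds
// argument shapes; any other combination is rejected as a type mismatch.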

#[derive(Clone, Debug, Default)]
@@ -197,59 +230,19 @@ impl Function for ClampMinFunction {
        Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
    }

    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
        ensure!(
            columns.len() == 2,
            InvalidFuncArgsSnafu {
                err_msg: format!(
                    "The length of the args is not correct, expect exactly 2, have: {}",
                    columns.len()
                ),
            }
        );
        ensure!(
            columns[0].data_type().is_numeric(),
            InvalidFuncArgsSnafu {
                err_msg: format!(
                    "The first arg's type is not numeric, have: {}",
                    columns[0].data_type()
                ),
            }
        );
        ensure!(
            columns[0].data_type() == columns[1].data_type(),
            InvalidFuncArgsSnafu {
                err_msg: format!(
                    "Arguments don't have identical types: {}, {}",
                    columns[0].data_type(),
                    columns[1].data_type()
                ),
            }
        );
    fn invoke_with_args(
        &self,
        args: ScalarFunctionArgs,
    ) -> datafusion_common::Result<ColumnarValue> {
        let [col, min] = utils::take_function_args(self.name(), args.args)?;

        ensure_constant_vector(&columns[1])?;

        with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
            let input_array = columns[0].to_arrow_array();
            let input = input_array
                .as_any()
                .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
                .unwrap();

            let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
                .with_context(|| {
                    InvalidFuncArgsSnafu {
                        err_msg: "The second arg (min) should not be none",
                    }
                })?;
            // For clamp_min, max is effectively infinity, so we don't use it in the clamp_impl logic.
            // We pass a default/dummy value for max.
            let max_dummy = <$S as LogicalPrimitiveType>::Native::default();

            clamp_impl::<$S, true, false>(input, min, max_dummy)
        },{
            unreachable!()
        })
        let Some(max) = ScalarValue::max(&min.data_type()) else {
            return Err(DataFusionError::Internal(format!(
                "cannot find a max value for numeric data type {}",
                min.data_type()
            )));
        };
        clamp_impl(col, min, ColumnarValue::Scalar(max))
    }
}
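// `clamp_min` delegates to `clamp_impl` with the type's maximum value as the
// upper bound, so only the lower bound can ever take effect.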

@@ -278,59 +271,19 @@ impl Function for ClampMaxFunction {
        Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
    }

    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
        ensure!(
            columns.len() == 2,
            InvalidFuncArgsSnafu {
                err_msg: format!(
                    "The length of the args is not correct, expect exactly 2, have: {}",
                    columns.len()
                ),
            }
        );
        ensure!(
            columns[0].data_type().is_numeric(),
            InvalidFuncArgsSnafu {
                err_msg: format!(
                    "The first arg's type is not numeric, have: {}",
                    columns[0].data_type()
                ),
            }
        );
        ensure!(
            columns[0].data_type() == columns[1].data_type(),
            InvalidFuncArgsSnafu {
                err_msg: format!(
                    "Arguments don't have identical types: {}, {}",
                    columns[0].data_type(),
                    columns[1].data_type()
                ),
            }
        );
    fn invoke_with_args(
        &self,
        args: ScalarFunctionArgs,
    ) -> datafusion_common::Result<ColumnarValue> {
        let [col, max] = utils::take_function_args(self.name(), args.args)?;

        ensure_constant_vector(&columns[1])?;

        with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
            let input_array = columns[0].to_arrow_array();
            let input = input_array
                .as_any()
                .downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
                .unwrap();

            let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
                .with_context(|| {
                    InvalidFuncArgsSnafu {
                        err_msg: "The second arg (max) should not be none",
                    }
                })?;
            // For clamp_max, min is effectively -infinity, so we don't use it in the clamp_impl logic.
            // We pass a default/dummy value for min.
            let min_dummy = <$S as LogicalPrimitiveType>::Native::default();

            clamp_impl::<$S, false, true>(input, min_dummy, max)
        },{
            unreachable!()
        })
        let Some(min) = ScalarValue::min(&max.data_type()) else {
            return Err(DataFusionError::Internal(format!(
                "cannot find a min value for numeric data type {}",
                max.data_type()
            )));
        };
        clamp_impl(col, ColumnarValue::Scalar(min), max)
    }
}

@@ -345,55 +298,80 @@ mod test {

    use std::sync::Arc;

    use datatypes::prelude::ScalarVector;
    use datatypes::vectors::{
        ConstantVector, Float64Vector, Int64Vector, StringVector, UInt64Vector,
    };
    use arrow_schema::Field;
    use datafusion_common::config::ConfigOptions;
    use datatypes::arrow::array::{ArrayRef, Float64Array, Int64Array, UInt64Array};
    use datatypes::arrow_array::StringArray;

    use super::*;
    use crate::function::FunctionContext;

    macro_rules! impl_test_eval {
        ($func: ty) => {
            impl $func {
                fn test_eval(
                    &self,
                    args: Vec<ColumnarValue>,
                    number_rows: usize,
                ) -> datafusion_common::Result<ArrayRef> {
                    let input_type = args[0].data_type();
                    self.invoke_with_args(ScalarFunctionArgs {
                        args,
                        arg_fields: vec![],
                        number_rows,
                        return_field: Arc::new(Field::new("x", input_type, false)),
                        config_options: Arc::new(ConfigOptions::new()),
                    })
                    .and_then(|v| ColumnarValue::values_to_arrays(&[v]).map_err(Into::into))
                    .map(|mut a| a.remove(0))
                }
            }
        };
    }

    impl_test_eval!(ClampFunction);
    impl_test_eval!(ClampMinFunction);
    impl_test_eval!(ClampMaxFunction);

    #[test]
    fn clamp_i64() {
        let inputs = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1,
                10,
                -1i64,
                10i64,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                0,
                0,
                0i64,
                0i64,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2,
                1,
                -2i64,
                1i64,
                vec![Some(-2), None, Some(-1), None, None, Some(1)],
            ),
            (
                vec![None, None, None, None, None],
                0,
                1,
                0i64,
                1i64,
                vec![None, None, None, None, None],
            ),
        ];

        let func = ClampFunction;
        for (in_data, min, max, expected) in inputs {
            let args = [
                Arc::new(Int64Vector::from(in_data)) as _,
                Arc::new(Int64Vector::from_vec(vec![min])) as _,
                Arc::new(Int64Vector::from_vec(vec![max])) as _,
            let number_rows = in_data.len();
            let args = vec![
                ColumnarValue::Array(Arc::new(Int64Array::from(in_data))),
                ColumnarValue::Scalar(min.into()),
                ColumnarValue::Scalar(max.into()),
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Int64Vector::from(expected));
            assert_eq!(expected, result);
            let result = func.test_eval(args, number_rows).unwrap();
            let expected: ArrayRef = Arc::new(Int64Array::from(expected));
            assert_eq!(expected.as_ref(), result.as_ref());
        }
    }

@@ -402,42 +380,41 @@ mod test {
        let inputs = [
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                1,
                3,
                1u64,
                3u64,
                vec![Some(1), Some(1), Some(2), Some(3), Some(3), Some(3)],
            ),
            (
                vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
                0,
                0,
                0u64,
                0u64,
                vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
            ),
            (
                vec![Some(0), None, Some(2), None, None, Some(5)],
                1,
                3,
                1u64,
                3u64,
                vec![Some(1), None, Some(2), None, None, Some(3)],
            ),
            (
                vec![None, None, None, None, None],
                0,
                1,
                0u64,
                1u64,
                vec![None, None, None, None, None],
            ),
        ];

        let func = ClampFunction;
        for (in_data, min, max, expected) in inputs {
            let args = [
                Arc::new(UInt64Vector::from(in_data)) as _,
                Arc::new(UInt64Vector::from_vec(vec![min])) as _,
                Arc::new(UInt64Vector::from_vec(vec![max])) as _,
            let number_rows = in_data.len();
            let args = vec![
                ColumnarValue::Array(Arc::new(UInt64Array::from(in_data))),
                ColumnarValue::Scalar(min.into()),
                ColumnarValue::Scalar(max.into()),
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(UInt64Vector::from(expected));
            assert_eq!(expected, result);
            let result = func.test_eval(args, number_rows).unwrap();
            let expected: ArrayRef = Arc::new(UInt64Array::from(expected));
            assert_eq!(expected.as_ref(), result.as_ref());
        }
    }

@@ -472,38 +449,18 @@ mod test {

        let func = ClampFunction;
        for (in_data, min, max, expected) in inputs {
            let args = [
                Arc::new(Float64Vector::from(in_data)) as _,
                Arc::new(Float64Vector::from_vec(vec![min])) as _,
                Arc::new(Float64Vector::from_vec(vec![max])) as _,
            let number_rows = in_data.len();
            let args = vec![
                ColumnarValue::Array(Arc::new(Float64Array::from(in_data))),
                ColumnarValue::Scalar(min.into()),
                ColumnarValue::Scalar(max.into()),
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Float64Vector::from(expected));
            assert_eq!(expected, result);
            let result = func.test_eval(args, number_rows).unwrap();
            let expected: ArrayRef = Arc::new(Float64Array::from(expected));
            assert_eq!(expected.as_ref(), result.as_ref());
        }
    }

    #[test]
    fn clamp_const_i32() {
        let input = vec![Some(5)];
        let min = 2;
        let max = 4;

        let func = ClampFunction;
        let args = [
            Arc::new(ConstantVector::new(Arc::new(Int64Vector::from(input)), 1)) as _,
            Arc::new(Int64Vector::from_vec(vec![min])) as _,
            Arc::new(Int64Vector::from_vec(vec![max])) as _,
        ];
        let result = func
            .eval(&FunctionContext::default(), args.as_slice())
            .unwrap();
        let expected: VectorRef = Arc::new(Int64Vector::from(vec![Some(4)]));
        assert_eq!(expected, result);
    }

    #[test]
    fn clamp_invalid_min_max() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
@@ -511,28 +468,30 @@ mod test {
        let max = -1.0;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Float64Vector::from_vec(vec![min])) as _,
            Arc::new(Float64Vector::from_vec(vec![max])) as _,
        let number_rows = input.len();
        let args = vec![
            ColumnarValue::Array(Arc::new(Float64Array::from(input))),
            ColumnarValue::Scalar(min.into()),
            ColumnarValue::Scalar(max.into()),
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        let result = func.test_eval(args, number_rows);
        assert!(result.is_err());
    }

    #[test]
    fn clamp_type_not_match() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -1;
        let max = 10;
        let min = -1i64;
        let max = 10u64;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Int64Vector::from_vec(vec![min])) as _,
            Arc::new(UInt64Vector::from_vec(vec![max])) as _,
        let number_rows = input.len();
        let args = vec![
            ColumnarValue::Array(Arc::new(Float64Array::from(input))),
            ColumnarValue::Scalar(min.into()),
            ColumnarValue::Scalar(max.into()),
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        let result = func.test_eval(args, number_rows);
        assert!(result.is_err());
    }

@@ -543,12 +502,13 @@ mod test {
        let max = 1.0;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Float64Vector::from_vec(vec![min, max])) as _,
            Arc::new(Float64Vector::from_vec(vec![max, min])) as _,
        let number_rows = input.len();
        let args = vec![
            ColumnarValue::Array(Arc::new(Float64Array::from(input))),
            ColumnarValue::Array(Arc::new(Float64Array::from(vec![min, max]))),
            ColumnarValue::Array(Arc::new(Float64Array::from(vec![max, min]))),
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        let result = func.test_eval(args, number_rows);
        assert!(result.is_err());
    }

@@ -558,11 +518,12 @@ mod test {
        let min = -10.0;

        let func = ClampFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Float64Vector::from_vec(vec![min])) as _,
        let number_rows = input.len();
        let args = vec![
            ColumnarValue::Array(Arc::new(Float64Array::from(input))),
            ColumnarValue::Scalar(min.into()),
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        let result = func.test_eval(args, number_rows);
        assert!(result.is_err());
    }

@@ -571,12 +532,13 @@ mod test {
        let input = vec![Some("foo"), Some("foo"), Some("foo"), Some("foo")];

        let func = ClampFunction;
        let args = [
            Arc::new(StringVector::from(input)) as _,
            Arc::new(StringVector::from_vec(vec!["bar"])) as _,
            Arc::new(StringVector::from_vec(vec!["baz"])) as _,
        let number_rows = input.len();
        let args = vec![
            ColumnarValue::Array(Arc::new(StringArray::from(input))),
            ColumnarValue::Scalar("bar".into()),
            ColumnarValue::Scalar("baz".into()),
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        let result = func.test_eval(args, number_rows);
        assert!(result.is_err());
    }

@@ -585,27 +547,26 @@ mod test {
        let inputs = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                -1,
                -1i64,
                vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                -2,
                -2i64,
                vec![Some(-2), None, Some(-1), None, None, Some(2)],
            ),
        ];

        let func = ClampMinFunction;
        for (in_data, min, expected) in inputs {
            let args = [
                Arc::new(Int64Vector::from(in_data)) as _,
                Arc::new(Int64Vector::from_vec(vec![min])) as _,
            let number_rows = in_data.len();
            let args = vec![
                ColumnarValue::Array(Arc::new(Int64Array::from(in_data))),
                ColumnarValue::Scalar(min.into()),
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Int64Vector::from(expected));
            assert_eq!(expected, result);
            let result = func.test_eval(args, number_rows).unwrap();
            let expected: ArrayRef = Arc::new(Int64Array::from(expected));
            assert_eq!(expected.as_ref(), result.as_ref());
        }
    }

@@ -614,27 +575,26 @@ mod test {
        let inputs = [
            (
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
                1,
                1i64,
                vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(1)],
            ),
            (
                vec![Some(-3), None, Some(-1), None, None, Some(2)],
                0,
                0i64,
                vec![Some(-3), None, Some(-1), None, None, Some(0)],
            ),
        ];

        let func = ClampMaxFunction;
        for (in_data, max, expected) in inputs {
            let args = [
                Arc::new(Int64Vector::from(in_data)) as _,
                Arc::new(Int64Vector::from_vec(vec![max])) as _,
            let number_rows = in_data.len();
            let args = vec![
                ColumnarValue::Array(Arc::new(Int64Array::from(in_data))),
                ColumnarValue::Scalar(max.into()),
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Int64Vector::from(expected));
            assert_eq!(expected, result);
            let result = func.test_eval(args, number_rows).unwrap();
            let expected: ArrayRef = Arc::new(Int64Array::from(expected));
            assert_eq!(expected.as_ref(), result.as_ref());
        }
    }

@@ -648,15 +608,14 @@ mod test {

        let func = ClampMinFunction;
        for (in_data, min, expected) in inputs {
            let args = [
                Arc::new(Float64Vector::from(in_data)) as _,
                Arc::new(Float64Vector::from_vec(vec![min])) as _,
            let number_rows = in_data.len();
            let args = vec![
                ColumnarValue::Array(Arc::new(Float64Array::from(in_data))),
                ColumnarValue::Scalar(min.into()),
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Float64Vector::from(expected));
            assert_eq!(expected, result);
            let result = func.test_eval(args, number_rows).unwrap();
            let expected: ArrayRef = Arc::new(Float64Array::from(expected));
            assert_eq!(expected.as_ref(), result.as_ref());
        }
    }

@@ -670,43 +629,44 @@ mod test {

        let func = ClampMaxFunction;
        for (in_data, max, expected) in inputs {
            let args = [
                Arc::new(Float64Vector::from(in_data)) as _,
                Arc::new(Float64Vector::from_vec(vec![max])) as _,
            let number_rows = in_data.len();
            let args = vec![
                ColumnarValue::Array(Arc::new(Float64Array::from(in_data))),
                ColumnarValue::Scalar(max.into()),
            ];
            let result = func
                .eval(&FunctionContext::default(), args.as_slice())
                .unwrap();
            let expected: VectorRef = Arc::new(Float64Vector::from(expected));
            assert_eq!(expected, result);
            let result = func.test_eval(args, number_rows).unwrap();
            let expected: ArrayRef = Arc::new(Float64Array::from(expected));
            assert_eq!(expected.as_ref(), result.as_ref());
        }
    }

    #[test]
    fn clamp_min_type_not_match() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let min = -1;
        let min = -1i64;

        let func = ClampMinFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Int64Vector::from_vec(vec![min])) as _,
        let number_rows = input.len();
        let args = vec![
            ColumnarValue::Array(Arc::new(Float64Array::from(input))),
            ColumnarValue::Scalar(min.into()),
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        let result = func.test_eval(args, number_rows);
        assert!(result.is_err());
    }

    #[test]
    fn clamp_max_type_not_match() {
        let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
        let max = 1;
        let max = 1i64;

        let func = ClampMaxFunction;
        let args = [
            Arc::new(Float64Vector::from(input)) as _,
            Arc::new(Int64Vector::from_vec(vec![max])) as _,
        let number_rows = input.len();
        let args = vec![
            ColumnarValue::Array(Arc::new(Float64Array::from(input))),
            ColumnarValue::Scalar(max.into()),
        ];
        let result = func.eval(&FunctionContext::default(), args.as_slice());
        let result = func.test_eval(args, number_rows);
        assert!(result.is_err());
    }
}

@@ -65,6 +65,14 @@ impl ScalarUDFImpl for ScalarUdf {
        &self,
        args: ScalarFunctionArgs,
    ) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
        let result = self.function.invoke_with_args(args.clone());
        if !matches!(
            result,
            Err(datafusion_common::DataFusionError::NotImplemented(_))
        ) {
            return result;
        }

        let columns = args
            .args
            .iter()

@@ -28,7 +28,14 @@ mod vector_norm;
mod vector_sub;
mod vector_subvector;

use std::borrow::Cow;

use datafusion_common::{DataFusionError, Result, ScalarValue, utils};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};

use crate::function_registry::FunctionRegistry;
use crate::scalars::vector::impl_conv::as_veclit;

pub(crate) struct VectorFunction;

impl VectorFunction {
@@ -59,3 +66,155 @@ impl VectorFunction {
        registry.register_scalar(elem_product::ElemProductFunction);
    }
}

// Use macro instead of function to "return" the reference to `ScalarValue` in the
// `ColumnarValue::Array` match arm.
macro_rules! try_get_scalar_value {
    ($col: ident, $i: ident) => {
        match $col {
            datafusion::logical_expr::ColumnarValue::Array(a) => {
                &datafusion_common::ScalarValue::try_from_array(a.as_ref(), $i)?
            }
            datafusion::logical_expr::ColumnarValue::Scalar(v) => v,
        }
    };
}

pub(crate) fn ensure_same_length(values: &[&ColumnarValue]) -> Result<usize> {
    if values.is_empty() {
        return Ok(0);
    }

    let mut array_len = None;
    for v in values {
        array_len = match (v, array_len) {
            (ColumnarValue::Array(a), None) => Some(a.len()),
            (ColumnarValue::Array(a), Some(array_len)) => {
                if array_len == a.len() {
                    Some(array_len)
                } else {
                    return Err(DataFusionError::Internal(format!(
                        "Arguments have mixed lengths. Expected length: {array_len}, found length: {}",
                        a.len()
                    )));
                }
            }
            (ColumnarValue::Scalar(_), array_len) => array_len,
        }
    }

    // If array_len is none, it means there are only scalars; treat each of them as a 1-element array.
    let array_len = array_len.unwrap_or(1);
    Ok(array_len)
}
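// For example, an array of length 4 mixed with scalars yields Ok(4), while an
// all-scalar argument list yields Ok(1).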

struct VectorCalculator<'a, F> {
    name: &'a str,
    func: F,
}

impl<F> VectorCalculator<'_, F>
where
    F: Fn(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
{
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;

        if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
            let result = (self.func)(v0, v1)?;
            return Ok(ColumnarValue::Scalar(result));
        }

        let len = ensure_same_length(&[arg0, arg1])?;
        let mut results = Vec::with_capacity(len);
        for i in 0..len {
            let v0 = try_get_scalar_value!(arg0, i);
            let v1 = try_get_scalar_value!(arg1, i);
            results.push((self.func)(v0, v1)?);
        }

        let results = ScalarValue::iter_to_array(results.into_iter())?;
        Ok(ColumnarValue::Array(results))
    }
}
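// The calculator applies `func` row by row over `ScalarValue`s: two scalar
// inputs short-circuit to a scalar result, otherwise each row is extracted,
// evaluated, and the results are collected back into a single array.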

impl<F> VectorCalculator<'_, F>
where
    F: Fn(&Option<Cow<[f32]>>, &Option<Cow<[f32]>>) -> Result<ScalarValue>,
{
    fn invoke_with_vectors(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;

        if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
            let v0 = as_veclit(v0)?;
            let v1 = as_veclit(v1)?;
            let result = (self.func)(&v0, &v1)?;
            return Ok(ColumnarValue::Scalar(result));
        }

        let len = ensure_same_length(&[arg0, arg1])?;
        let mut results = Vec::with_capacity(len);

        match (arg0, arg1) {
            (ColumnarValue::Scalar(v0), ColumnarValue::Array(a1)) => {
                let v0 = as_veclit(v0)?;
                for i in 0..len {
                    let v1 = ScalarValue::try_from_array(a1, i)?;
                    let v1 = as_veclit(&v1)?;
                    results.push((self.func)(&v0, &v1)?);
                }
            }
            (ColumnarValue::Array(a0), ColumnarValue::Scalar(v1)) => {
                let v1 = as_veclit(v1)?;
                for i in 0..len {
                    let v0 = ScalarValue::try_from_array(a0, i)?;
                    let v0 = as_veclit(&v0)?;
                    results.push((self.func)(&v0, &v1)?);
                }
            }
            (ColumnarValue::Array(a0), ColumnarValue::Array(a1)) => {
                for i in 0..len {
                    let v0 = ScalarValue::try_from_array(a0, i)?;
                    let v0 = as_veclit(&v0)?;
                    let v1 = ScalarValue::try_from_array(a1, i)?;
                    let v1 = as_veclit(&v1)?;
                    results.push((self.func)(&v0, &v1)?);
                }
            }
            (ColumnarValue::Scalar(_), ColumnarValue::Scalar(_)) => {
                // unreachable because this arm has been separately dealt with above
                unreachable!()
            }
        }

        let results = ScalarValue::iter_to_array(results.into_iter())?;
        Ok(ColumnarValue::Array(results))
    }
}

impl<F> VectorCalculator<'_, F>
where
    F: Fn(&ScalarValue) -> Result<ScalarValue>,
{
    fn invoke_with_single_argument(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let [arg0] = utils::take_function_args(self.name, &args.args)?;

        let arg0 = match arg0 {
            ColumnarValue::Scalar(v) => {
                let result = (self.func)(v)?;
                return Ok(ColumnarValue::Scalar(result));
            }
            ColumnarValue::Array(a) => a,
        };

        let len = arg0.len();
        let mut results = Vec::with_capacity(len);
        for i in 0..len {
            let v = ScalarValue::try_from_array(arg0, i)?;
            results.push((self.func)(&v)?);
        }

        let results = ScalarValue::iter_to_array(results.into_iter())?;
        Ok(ColumnarValue::Array(results))
    }
}

@@ -17,7 +17,7 @@ use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion::arrow::datatypes::DataType;
use datafusion_expr::type_coercion::aggregates::BINARYS;
use datafusion_expr::{Signature, Volatility};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::types::vector_type_value_to_string;
use datatypes::value::Value;
@@ -41,7 +41,13 @@ impl Function for VectorToStringFunction {
    }

    fn signature(&self) -> Signature {
        Signature::uniform(1, BINARYS.to_vec(), Volatility::Immutable)
        Signature::one_of(
            vec![
                TypeSignature::Uniform(1, vec![DataType::BinaryView]),
                TypeSignature::Uniform(1, BINARYS.to_vec()),
            ],
            Volatility::Immutable,
        )
    }

    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {

@@ -19,20 +19,17 @@ mod l2sq;
use std::borrow::Cow;
use std::fmt::Display;

use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use common_query::error::Result;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature};
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
use snafu::ensure;

use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};

macro_rules! define_distance_function {
    ($StructName:ident, $display_name:expr, $similarity_method:path) => {

        /// A function that calculates the distance between two vectors.

        #[derive(Debug, Clone, Default)]
@@ -54,59 +51,34 @@ macro_rules! define_distance_function {
                )
            }

            fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
                ensure!(
                    columns.len() == 2,
                    InvalidFuncArgsSnafu {
                        err_msg: format!(
                            "The length of the args is not correct, expect exactly two, have: {}",
                            columns.len()
                        ),
                    }
                );
                let arg0 = &columns[0];
                let arg1 = &columns[1];
            fn invoke_with_args(
                &self,
                args: ScalarFunctionArgs,
            ) -> datafusion_common::Result<ColumnarValue> {
                let body = |v0: &Option<Cow<[f32]>>,
                            v1: &Option<Cow<[f32]>>|
                 -> datafusion_common::Result<ScalarValue> {
                    let result = if let (Some(v0), Some(v1)) = (v0, v1) {
                        if v0.len() != v1.len() {
                            return Err(datafusion_common::DataFusionError::Execution(format!(
                                "vector lengths do not match: {}",
                                self.name()
                            )));
                        }

                let size = arg0.len();
                let mut result = Float32VectorBuilder::with_capacity(size);
                if size == 0 {
                    return Ok(result.to_vector());
                }

                let arg0_const = as_veclit_if_const(arg0)?;
                let arg1_const = as_veclit_if_const(arg1)?;

                for i in 0..size {
                    let vec0 = match arg0_const.as_ref() {
                        Some(a) => Some(Cow::Borrowed(a.as_ref())),
                        None => as_veclit(arg0.get_ref(i))?,
                    };
                    let vec1 = match arg1_const.as_ref() {
                        Some(b) => Some(Cow::Borrowed(b.as_ref())),
                        None => as_veclit(arg1.get_ref(i))?,
                    };

                    if let (Some(vec0), Some(vec1)) = (vec0, vec1) {
                        ensure!(
                            vec0.len() == vec1.len(),
                            InvalidFuncArgsSnafu {
                                err_msg: format!(
                                    "The length of the vectors must match to calculate distance, have: {} vs {}",
                                    vec0.len(),
                                    vec1.len()
                                ),
                            }
                        );

                        // Checked if the length of the vectors match
                        let d = $similarity_method(vec0.as_ref(), vec1.as_ref());
                        result.push(Some(d));
                        let d = $similarity_method(v0, v1);
                        Some(d)
                    } else {
                        result.push_null();
                    }
                }
                        None
                    };
                    Ok(ScalarValue::Float32(result))
                };

                return Ok(result.to_vector());
                let calculator = $crate::scalars::vector::VectorCalculator {
                    name: self.name(),
                    func: body,
                };
                calculator.invoke_with_vectors(args)
            }
        }

@@ -115,7 +87,7 @@ macro_rules! define_distance_function {
                write!(f, "{}", $display_name.to_ascii_uppercase())
            }
        }
    }
    };
}

define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos);
@@ -126,10 +98,29 @@ define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot);
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::{BinaryVector, ConstantVector, StringVector};
    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, ArrayRef, AsArray, BinaryArray, StringViewArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    fn test_invoke(func: &dyn Function, args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef> {
        let number_rows = args[0].len();
        let args = ScalarFunctionArgs {
            args: args
                .iter()
                .map(|x| ColumnarValue::Array(x.clone()))
                .collect::<Vec<_>>(),
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::Float32, false)),
            config_options: Arc::new(ConfigOptions::new()),
        };
        func.invoke_with_args(args)
            .and_then(|x| x.to_array(number_rows))
    }

    #[test]
    fn test_distance_string_string() {
        let funcs = [
@@ -139,36 +130,34 @@ mod tests {
        ];

        for func in funcs {
            let vec1 = Arc::new(StringVector::from(vec![
            let vec1: ArrayRef = Arc::new(StringViewArray::from(vec![
                Some("[0.0, 1.0]"),
                Some("[1.0, 0.0]"),
                None,
                Some("[1.0, 0.0]"),
            ])) as VectorRef;
            let vec2 = Arc::new(StringVector::from(vec![
            ]));
            let vec2: ArrayRef = Arc::new(StringViewArray::from(vec![
                Some("[0.0, 1.0]"),
                Some("[0.0, 1.0]"),
                Some("[0.0, 1.0]"),
                None,
            ])) as VectorRef;
            ]));

            let result = func
                .eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()])
                .unwrap();
            let result = test_invoke(func.as_ref(), &[vec1.clone(), vec2.clone()]).unwrap();
            let result = result.as_primitive::<Float32Type>();

            assert!(!result.get(0).is_null());
            assert!(!result.get(1).is_null());
            assert!(result.get(2).is_null());
            assert!(result.get(3).is_null());
            assert!(!result.is_null(0));
            assert!(!result.is_null(1));
            assert!(result.is_null(2));
            assert!(result.is_null(3));

            let result = func
                .eval(&FunctionContext::default(), &[vec2, vec1])
                .unwrap();
            let result = test_invoke(func.as_ref(), &[vec2, vec1]).unwrap();
            let result = result.as_primitive::<Float32Type>();

            assert!(!result.get(0).is_null());
            assert!(!result.get(1).is_null());
            assert!(result.get(2).is_null());
            assert!(result.get(3).is_null());
            assert!(!result.is_null(0));
            assert!(!result.is_null(1));
            assert!(result.is_null(2));
            assert!(result.is_null(3));
        }
    }

@@ -181,37 +170,35 @@ mod tests {
        ];

        for func in funcs {
            let vec1 = Arc::new(BinaryVector::from(vec![
            let vec1: ArrayRef = Arc::new(BinaryArray::from_iter(vec![
                Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
                Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
                None,
                Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
            ])) as VectorRef;
            let vec2 = Arc::new(BinaryVector::from(vec![
            ]));
            let vec2: ArrayRef = Arc::new(BinaryArray::from_iter(vec![
                // [0.0, 1.0]
                Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
                Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
                Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
                None,
            ])) as VectorRef;
            ]));

            let result = func
                .eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()])
                .unwrap();
            let result = test_invoke(func.as_ref(), &[vec1.clone(), vec2.clone()]).unwrap();
            let result = result.as_primitive::<Float32Type>();

            assert!(!result.get(0).is_null());
            assert!(!result.get(1).is_null());
            assert!(result.get(2).is_null());
            assert!(result.get(3).is_null());
            assert!(!result.is_null(0));
            assert!(!result.is_null(1));
            assert!(result.is_null(2));
            assert!(result.is_null(3));

            let result = func
                .eval(&FunctionContext::default(), &[vec2, vec1])
                .unwrap();
            let result = test_invoke(func.as_ref(), &[vec2, vec1]).unwrap();
            let result = result.as_primitive::<Float32Type>();

            assert!(!result.get(0).is_null());
            assert!(!result.get(1).is_null());
            assert!(result.get(2).is_null());
            assert!(result.get(3).is_null());
            assert!(!result.is_null(0));
            assert!(!result.is_null(1));
            assert!(result.is_null(2));
            assert!(result.is_null(3));
        }
    }
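    // Note: each 8-byte binary literal above is two f32 values in little-endian
    // byte order; for example, [0, 0, 128, 63] decodes to 1.0f32.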

@@ -224,115 +211,35 @@ mod tests {
        ];

        for func in funcs {
            let vec1 = Arc::new(StringVector::from(vec![
            let vec1: ArrayRef = Arc::new(StringViewArray::from(vec![
                Some("[0.0, 1.0]"),
                Some("[1.0, 0.0]"),
                None,
                Some("[1.0, 0.0]"),
            ])) as VectorRef;
            let vec2 = Arc::new(BinaryVector::from(vec![
            ]));
            let vec2: ArrayRef = Arc::new(BinaryArray::from_iter(vec![
                // [0.0, 1.0]
                Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
                Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
                Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
                None,
            ])) as VectorRef;
            ]));

            let result = func
                .eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()])
                .unwrap();
            let result = test_invoke(func.as_ref(), &[vec1.clone(), vec2.clone()]).unwrap();
            let result = result.as_primitive::<Float32Type>();

            assert!(!result.get(0).is_null());
            assert!(!result.get(1).is_null());
            assert!(result.get(2).is_null());
            assert!(result.get(3).is_null());
            assert!(!result.is_null(0));
            assert!(!result.is_null(1));
            assert!(result.is_null(2));
            assert!(result.is_null(3));

            let result = func
                .eval(&FunctionContext::default(), &[vec2, vec1])
                .unwrap();
            let result = test_invoke(func.as_ref(), &[vec2, vec1]).unwrap();
            let result = result.as_primitive::<Float32Type>();

            assert!(!result.get(0).is_null());
            assert!(!result.get(1).is_null());
            assert!(result.get(2).is_null());
            assert!(result.get(3).is_null());
        }
    }

    #[test]
    fn test_distance_const_string() {
        let funcs = [
            Box::new(CosDistanceFunction {}) as Box<dyn Function>,
            Box::new(L2SqDistanceFunction {}) as Box<dyn Function>,
            Box::new(DotProductFunction {}) as Box<dyn Function>,
        ];

        for func in funcs {
            let const_str = Arc::new(ConstantVector::new(
                Arc::new(StringVector::from(vec!["[0.0, 1.0]"])),
                4,
            ));

            let vec1 = Arc::new(StringVector::from(vec![
                Some("[0.0, 1.0]"),
                Some("[1.0, 0.0]"),
                None,
                Some("[1.0, 0.0]"),
            ])) as VectorRef;
            let vec2 = Arc::new(BinaryVector::from(vec![
                // [0.0, 1.0]
                Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
                Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
                Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
                None,
            ])) as VectorRef;

            let result = func
                .eval(
                    &FunctionContext::default(),
                    &[const_str.clone(), vec1.clone()],
                )
                .unwrap();

            assert!(!result.get(0).is_null());
            assert!(!result.get(1).is_null());
            assert!(result.get(2).is_null());
            assert!(!result.get(3).is_null());

            let result = func
                .eval(
                    &FunctionContext::default(),
                    &[vec1.clone(), const_str.clone()],
                )
                .unwrap();

            assert!(!result.get(0).is_null());
            assert!(!result.get(1).is_null());
            assert!(result.get(2).is_null());
            assert!(!result.get(3).is_null());

            let result = func
                .eval(
                    &FunctionContext::default(),
                    &[const_str.clone(), vec2.clone()],
                )
                .unwrap();

            assert!(!result.get(0).is_null());
            assert!(!result.get(1).is_null());
            assert!(!result.get(2).is_null());
            assert!(result.get(3).is_null());

            let result = func
                .eval(
                    &FunctionContext::default(),
                    &[vec2.clone(), const_str.clone()],
                )
                .unwrap();

            assert!(!result.get(0).is_null());
            assert!(!result.get(1).is_null());
            assert!(!result.get(2).is_null());
            assert!(result.get(3).is_null());
            assert!(!result.is_null(0));
            assert!(!result.is_null(1));
            assert!(result.is_null(2));
            assert!(result.is_null(3));
        }
    }

@@ -345,15 +252,16 @@ mod tests {
        ];

        for func in funcs {
            let vec1 = Arc::new(StringVector::from(vec!["[1.0]"])) as VectorRef;
            let vec2 = Arc::new(StringVector::from(vec!["[1.0, 1.0]"])) as VectorRef;
            let result = func.eval(&FunctionContext::default(), &[vec1, vec2]);
            let vec1: ArrayRef = Arc::new(StringViewArray::from(vec!["[1.0]"]));
            let vec2: ArrayRef = Arc::new(StringViewArray::from(vec!["[1.0, 1.0]"]));
            let result = test_invoke(func.as_ref(), &[vec1, vec2]);
            assert!(result.is_err());

            let vec1 = Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63]])) as VectorRef;
            let vec2 =
                Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63, 0, 0, 0, 64]])) as VectorRef;
            let result = func.eval(&FunctionContext::default(), &[vec1, vec2]);
            let vec1: ArrayRef = Arc::new(BinaryArray::from_iter_values(vec![vec![0, 0, 128, 63]]));
            let vec2: ArrayRef = Arc::new(BinaryArray::from_iter_values(vec![vec![
                0, 0, 128, 63, 0, 0, 0, 64,
            ]]));
            let result = test_invoke(func.as_ref(), &[vec1, vec2]);
            assert!(result.is_err());
        }
    }

@@ -12,20 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Cow;
use std::fmt::Display;

use common_query::error::{InvalidFuncArgsSnafu, Result};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use nalgebra::DVectorView;
use snafu::ensure;

use crate::function::{Function, FunctionContext};
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
use crate::function::Function;
use crate::scalars::vector::{VectorCalculator, impl_conv};

const NAME: &str = "vec_elem_product";

@@ -64,43 +62,21 @@ impl Function for ElemProductFunction {
)
}

fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let v0 = impl_conv::as_veclit(v0)?
.map(|v0| DVectorView::from_slice(&v0, v0.len()).product());
Ok(ScalarValue::Float32(v0))
};

let len = arg0.len();
let mut result = Float32VectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}

let arg0_const = as_veclit_if_const(arg0)?;

for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let Some(arg0) = arg0 else {
result.push_null();
continue;
};
result.push(Some(DVectorView::from_slice(&arg0, arg0.len()).product()));
}

Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_single_argument(args)
}
}

@@ -114,27 +90,39 @@ impl Display for ElemProductFunction {
mod tests {
use std::sync::Arc;

use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, StringArray};
use datafusion::arrow::datatypes::Float32Type;
use datafusion_common::config::ConfigOptions;

use super::*;
use crate::function::FunctionContext;

#[test]
fn test_elem_product() {
let func = ElemProductFunction;

let input0 = Arc::new(StringVector::from(vec![
let input = Arc::new(StringArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
]));

let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
let result = func
.invoke_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input.clone())],
arg_fields: vec![],
number_rows: input.len(),
return_field: Arc::new(Field::new("x", DataType::Float32, true)),
config_options: Arc::new(ConfigOptions::new()),
})
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
.map(|mut a| a.remove(0))
.unwrap();
let result = result.as_primitive::<Float32Type>();

let result = result.as_ref();
assert_eq!(result.len(), 3);
assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(6.0));
assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(120.0));
assert_eq!(result.get_ref(2).as_f32().unwrap(), None);
assert_eq!(result.value(0), 6.0);
assert_eq!(result.value(1), 120.0);
assert!(result.is_null(2));
}
}

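Every rewrite in this diff funnels through a `VectorCalculator` holding a `name` and a per-row `func`; its definition is not part of this excerpt. As a hedged sketch only, the call sites imply roughly the following shape for the single-argument path (names and signatures inferred from usage, not taken from the actual source):

```rust
use datafusion_common::ScalarValue;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};

// Hypothetical shape inferred from the call sites in this diff; the real
// definition lives elsewhere in the crate and may differ in detail.
struct VectorCalculator<'a, F> {
    // The real implementation appears to use `name` in error messages
    // (e.g. "vectors length not match: vec_div").
    #[allow(dead_code)]
    name: &'a str,
    func: F,
}

impl<F> VectorCalculator<'_, F>
where
    F: Fn(&ScalarValue) -> datafusion_common::Result<ScalarValue>,
{
    // Applies `func` row by row to the single input column and
    // reassembles the per-row scalars into one output array.
    fn invoke_with_single_argument(
        &self,
        args: ScalarFunctionArgs,
    ) -> datafusion_common::Result<ColumnarValue> {
        let arrays = ColumnarValue::values_to_arrays(&args.args)?;
        let input = &arrays[0];
        let mut out = Vec::with_capacity(input.len());
        for i in 0..input.len() {
            let v = ScalarValue::try_from_array(input, i)?;
            out.push((self.func)(&v)?);
        }
        Ok(ColumnarValue::Array(ScalarValue::iter_to_array(out)?))
    }
}
```
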
@@ -12,20 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Cow;
use std::fmt::Display;

use common_query::error::{InvalidFuncArgsSnafu, Result};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::ScalarValue;
use datafusion_expr::type_coercion::aggregates::{BINARYS, STRINGS};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use nalgebra::DVectorView;
use snafu::ensure;

use crate::function::{Function, FunctionContext};
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
use crate::function::Function;
use crate::scalars::vector::{VectorCalculator, impl_conv};

const NAME: &str = "vec_elem_sum";

@@ -51,43 +49,21 @@ impl Function for ElemSumFunction {
)
}

fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let v0 =
impl_conv::as_veclit(v0)?.map(|v0| DVectorView::from_slice(&v0, v0.len()).sum());
Ok(ScalarValue::Float32(v0))
};

let len = arg0.len();
let mut result = Float32VectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}

let arg0_const = as_veclit_if_const(arg0)?;

for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let Some(arg0) = arg0 else {
result.push_null();
continue;
};
result.push(Some(DVectorView::from_slice(&arg0, arg0.len()).sum()));
}

Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_single_argument(args)
}
}

@@ -101,27 +77,40 @@ impl Display for ElemSumFunction {
mod tests {
use std::sync::Arc;

use datatypes::vectors::StringVector;
use arrow::array::StringViewArray;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray};
use datafusion::arrow::datatypes::Float32Type;
use datafusion_common::config::ConfigOptions;

use super::*;
use crate::function::FunctionContext;

#[test]
fn test_elem_sum() {
let func = ElemSumFunction;

let input0 = Arc::new(StringVector::from(vec![
let input = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
]));

let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
let result = func
.invoke_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input.clone())],
arg_fields: vec![],
number_rows: input.len(),
return_field: Arc::new(Field::new("x", DataType::Float32, true)),
config_options: Arc::new(ConfigOptions::new()),
})
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
.map(|mut a| a.remove(0))
.unwrap();
let result = result.as_primitive::<Float32Type>();

let result = result.as_ref();
assert_eq!(result.len(), 3);
assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(6.0));
assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(15.0));
assert_eq!(result.get_ref(2).as_f32().unwrap(), None);
assert_eq!(result.value(0), 6.0);
assert_eq!(result.value(1), 15.0);
assert!(result.is_null(2));
}
}

@@ -13,40 +13,18 @@
// limitations under the License.

use std::borrow::Cow;
use std::sync::Arc;

use common_query::error::{InvalidFuncArgsSnafu, Result};
use datatypes::prelude::ConcreteDataType;
use datatypes::value::ValueRef;
use datatypes::vectors::Vector;

/// Convert a constant string or binary literal to a vector literal.
pub fn as_veclit_if_const(arg: &Arc<dyn Vector>) -> Result<Option<Cow<'_, [f32]>>> {
if !arg.is_const() {
return Ok(None);
}
if arg.data_type() != ConcreteDataType::string_datatype()
&& arg.data_type() != ConcreteDataType::binary_datatype()
{
return Ok(None);
}
as_veclit(arg.get_ref(0))
}
use datafusion_common::ScalarValue;

/// Convert a string or binary literal to a vector literal.
pub fn as_veclit(arg: ValueRef<'_>) -> Result<Option<Cow<'_, [f32]>>> {
match arg.data_type() {
ConcreteDataType::Binary(_) => arg
.as_binary()
.unwrap() // Safe: checked if it is a binary
.map(binlit_as_veclit)
pub fn as_veclit(arg: &ScalarValue) -> Result<Option<Cow<'_, [f32]>>> {
match arg {
ScalarValue::Binary(b) => b.as_ref().map(|x| binlit_as_veclit(x)).transpose(),
ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) => s
.as_ref()
.map(|x| parse_veclit_from_strlit(x).map(Cow::Owned))
.transpose(),
ConcreteDataType::String(_) => arg
.as_string()
.unwrap() // Safe: checked if it is a string
.map(|s| Ok(Cow::Owned(parse_veclit_from_strlit(s)?)))
.transpose(),
ConcreteDataType::Null(_) => Ok(None),
_ => InvalidFuncArgsSnafu {
err_msg: format!("Unsupported data type: {:?}", arg.data_type()),
}

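A short usage sketch of the reworked `as_veclit`, grounded in the match arms above (string literals go through `parse_veclit_from_strlit`, whose body is not shown here; `as_veclit` is assumed to be in scope, and `Result` is the `common_query::error::Result` from the imports):

```rust
use std::borrow::Cow;

use datafusion_common::ScalarValue;

fn demo() -> common_query::error::Result<()> {
    // Utf8 and Utf8View literals are parsed into an owned Vec<f32>.
    let s = ScalarValue::Utf8(Some("[1.0,2.0]".to_string()));
    let v: Option<Cow<'_, [f32]>> = as_veclit(&s)?;
    assert_eq!(v.as_deref(), Some(&[1.0f32, 2.0][..]));

    // A NULL literal maps to Ok(None) rather than an error.
    assert!(as_veclit(&ScalarValue::Utf8(None))?.is_none());
    Ok(())
}
```
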
@@ -12,20 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Cow;
use std::fmt::Display;

use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature};
use nalgebra::DVectorView;
use snafu::ensure;

use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};

const NAME: &str = "vec_scalar_add";

@@ -60,7 +59,7 @@ impl Function for ScalarAddFunction {
}

fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}

fn signature(&self) -> Signature {
@@ -70,52 +69,26 @@ impl Function for ScalarAddFunction {
)
}

fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
let arg0 = &columns[0];
let arg1 = &columns[1];

let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}

let arg1_const = as_veclit_if_const(arg1)?;

for i in 0..len {
let arg0 = arg0.get(i).as_f64_lossy();
let Some(arg0) = arg0 else {
result.push_null();
continue;
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let ScalarValue::Float64(Some(v0)) = v0 else {
return Ok(ScalarValue::BinaryView(None));
};

let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
let Some(arg1) = arg1 else {
result.push_null();
continue;
};
let v1 = as_veclit(v1)?
.map(|v1| DVectorView::from_slice(&v1, v1.len()).add_scalar(*v0 as f32));
let result = v1.map(|v1| veclit_to_binlit(v1.as_slice()));
Ok(ScalarValue::BinaryView(result))
};

let vec = DVectorView::from_slice(&arg1, arg1.len());
let vec_res = vec.add_scalar(arg0 as _);

let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
}

Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_args(args)
}
}

@@ -129,7 +102,9 @@ impl Display for ScalarAddFunction {
mod tests {
use std::sync::Arc;

use datatypes::vectors::{Float32Vector, StringVector};
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, Float64Array, StringViewArray};
use datafusion_common::config::ConfigOptions;

use super::*;

@@ -137,34 +112,42 @@ mod tests {
fn test_scalar_add() {
let func = ScalarAddFunction;

let input0 = Arc::new(Float32Vector::from(vec![
let input0 = Arc::new(Float64Array::from(vec![
Some(1.0),
Some(-1.0),
None,
Some(3.0),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
None,
]));

let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();

let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice())
result.value(0),
veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[3.0, 4.0, 5.0]).as_slice())
result.value(1),
veclit_to_binlit(&[3.0, 4.0, 5.0]).as_slice()
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert!(result.is_null(2));
assert!(result.is_null(3));
}
}

@@ -12,20 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Cow;
use std::fmt::Display;

use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature};
use nalgebra::DVectorView;
use snafu::ensure;

use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};

const NAME: &str = "vec_scalar_mul";

@@ -60,7 +59,7 @@ impl Function for ScalarMulFunction {
}

fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}

fn signature(&self) -> Signature {
@@ -70,52 +69,26 @@ impl Function for ScalarMulFunction {
)
}

fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
let arg0 = &columns[0];
let arg1 = &columns[1];

let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}

let arg1_const = as_veclit_if_const(arg1)?;

for i in 0..len {
let arg0 = arg0.get(i).as_f64_lossy();
let Some(arg0) = arg0 else {
result.push_null();
continue;
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let ScalarValue::Float64(Some(v0)) = v0 else {
return Ok(ScalarValue::BinaryView(None));
};

let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
let Some(arg1) = arg1 else {
result.push_null();
continue;
};
let v1 =
as_veclit(v1)?.map(|v1| DVectorView::from_slice(&v1, v1.len()).scale(*v0 as f32));
let result = v1.map(|v1| veclit_to_binlit(v1.as_slice()));
Ok(ScalarValue::BinaryView(result))
};

let vec = DVectorView::from_slice(&arg1, arg1.len());
let vec_res = vec.scale(arg0 as _);

let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
}

Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_args(args)
}
}

@@ -129,7 +102,9 @@ impl Display for ScalarMulFunction {
mod tests {
use std::sync::Arc;

use datatypes::vectors::{Float32Vector, StringVector};
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, Float64Array, StringViewArray};
use datafusion_common::config::ConfigOptions;

use super::*;

@@ -137,34 +112,42 @@ mod tests {
fn test_scalar_mul() {
let func = ScalarMulFunction;

let input0 = Arc::new(Float32Vector::from(vec![
let input0 = Arc::new(Float64Array::from(vec![
Some(2.0),
Some(-0.5),
None,
Some(3.0),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[8.0,10.0,12.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
None,
]));

let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();

let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[2.0, 4.0, 6.0]).as_slice())
result.value(0),
veclit_to_binlit(&[2.0, 4.0, 6.0]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[-4.0, -5.0, -6.0]).as_slice())
result.value(1),
veclit_to_binlit(&[-4.0, -5.0, -6.0]).as_slice()
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert!(result.is_null(2));
assert!(result.is_null(3));
}
}

@@ -15,17 +15,17 @@
use std::borrow::Cow;
use std::fmt::Display;

use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use nalgebra::DVectorView;
use snafu::ensure;

use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::veclit_to_binlit;

const NAME: &str = "vec_add";

@@ -51,7 +51,7 @@ impl Function for VectorAddFunction {
}

fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}

fn signature(&self) -> Signature {
@@ -61,66 +61,36 @@ impl Function for VectorAddFunction {
)
}

fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
let arg1 = &columns[1];
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &Option<Cow<[f32]>>,
v1: &Option<Cow<[f32]>>|
-> datafusion_common::Result<ScalarValue> {
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
let v0 = DVectorView::from_slice(v0, v0.len());
let v1 = DVectorView::from_slice(v1, v1.len());
if v0.len() != v1.len() {
return Err(DataFusionError::Execution(format!(
"vectors length not match: {}",
self.name()
)));
}

ensure!(
arg0.len() == arg1.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The lengths of the vector are not aligned, args 0: {}, args 1: {}",
arg0.len(),
arg1.len(),
)
}
);

let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}

let arg0_const = as_veclit_if_const(arg0)?;
let arg1_const = as_veclit_if_const(arg1)?;

for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
let result = veclit_to_binlit((v0 + v1).as_slice());
Some(result)
} else {
None
};
let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
let (Some(arg0), Some(arg1)) = (arg0, arg1) else {
result.push_null();
continue;
};
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
let vec1 = DVectorView::from_slice(&arg1, arg1.len());
Ok(ScalarValue::BinaryView(result))
};

let vec_res = vec0 + vec1;
let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
}

Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_vectors(args)
}
}

@@ -134,8 +104,9 @@ impl Display for VectorAddFunction {
mod tests {
use std::sync::Arc;

use common_query::error::Error;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
use datafusion_common::config::ConfigOptions;

use super::*;

@@ -143,63 +114,71 @@ mod tests {
fn test_sub() {
let func = VectorAddFunction;

let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
Some("[2.0,3.0,3.0]".to_string()),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[6.0,5.0,4.0]".to_string()),
Some("[3.0,2.0,2.0]".to_string()),
None,
]));

let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();

let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice())
result.value(0),
veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[10.0, 10.0, 10.0]).as_slice())
result.value(1),
veclit_to_binlit(&[10.0, 10.0, 10.0]).as_slice()
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert!(result.is_null(2));
assert!(result.is_null(3));
}

#[test]
fn test_sub_error() {
let func = VectorAddFunction;

let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
Some("[2.0,3.0,3.0]".to_string()),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[6.0,5.0,4.0]".to_string()),
Some("[3.0,2.0,2.0]".to_string()),
]));

let result = func.eval(&FunctionContext::default(), &[input0, input1]);

match result {
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
assert_eq!(
err_msg,
"The lengths of the vector are not aligned, args 0: 4, args 1: 3"
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(e.to_string().starts_with(
"Internal error: Arguments has mixed length. Expected length: 4, found length: 3."
));
}
}

@@ -12,19 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Cow;
use std::fmt::Display;

use common_query::error::{InvalidFuncArgsSnafu, Result};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{MutableVector, UInt64VectorBuilder, VectorRef};
use snafu::ensure;
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};

use crate::function::{Function, FunctionContext};
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
use crate::function::Function;
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::as_veclit;

const NAME: &str = "vec_dim";

@@ -63,43 +62,20 @@ impl Function for VectorDimFunction {
)
}

fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let v = as_veclit(v0)?.map(|v0| v0.len() as u64);
Ok(ScalarValue::UInt64(v))
};

let len = arg0.len();
let mut result = UInt64VectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}

let arg0_const = as_veclit_if_const(arg0)?;

for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let Some(arg0) = arg0 else {
result.push_null();
continue;
};
result.push(Some(arg0.len() as u64));
}

Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_single_argument(args)
}
}

@@ -113,8 +89,10 @@ impl Display for VectorDimFunction {
mod tests {
use std::sync::Arc;

use common_query::error::Error;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
use datafusion::arrow::datatypes::UInt64Type;
use datafusion_common::config::ConfigOptions;

use super::*;

@@ -122,49 +100,60 @@ mod tests {
fn test_vec_dim() {
let func = VectorDimFunction;

let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[0.0,2.0,3.0]".to_string()),
Some("[1.0,2.0,3.0,4.0]".to_string()),
None,
Some("[5.0]".to_string()),
]));

let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();

let result = result.as_ref();
let result = result.as_primitive::<UInt64Type>();
assert_eq!(result.len(), 4);
assert_eq!(result.get_ref(0).as_u64().unwrap(), Some(3));
assert_eq!(result.get_ref(1).as_u64().unwrap(), Some(4));
assert!(result.get_ref(2).is_null());
assert_eq!(result.get_ref(3).as_u64().unwrap(), Some(1));
assert_eq!(result.value(0), 3);
assert_eq!(result.value(1), 4);
assert!(result.is_null(2));
assert_eq!(result.value(3), 1);
}

#[test]
fn test_dim_error() {
let func = VectorDimFunction;

let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
Some("[2.0,3.0,3.0]".to_string()),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[6.0,5.0,4.0]".to_string()),
Some("[3.0,2.0,2.0]".to_string()),
]));

let result = func.eval(&FunctionContext::default(), &[input0, input1]);

match result {
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
assert_eq!(
err_msg,
"The length of the args is not correct, expect exactly one, have: 2"
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(
e.to_string()
.starts_with("Execution error: vec_dim function requires 1 argument, got 2")
)
}
}

@@ -15,17 +15,17 @@
use std::borrow::Cow;
use std::fmt::Display;

use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use nalgebra::DVectorView;
use snafu::ensure;

use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::veclit_to_binlit;

const NAME: &str = "vec_div";

@@ -52,7 +52,7 @@ impl Function for VectorDivFunction {
}

fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}

fn signature(&self) -> Signature {
@@ -62,64 +62,36 @@ impl Function for VectorDivFunction {
)
}

fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &Option<Cow<[f32]>>,
v1: &Option<Cow<[f32]>>|
-> datafusion_common::Result<ScalarValue> {
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
let v0 = DVectorView::from_slice(v0, v0.len());
let v1 = DVectorView::from_slice(v1, v1.len());
if v0.len() != v1.len() {
return Err(DataFusionError::Execution(format!(
"vectors length not match: {}",
self.name()
)));
}

let arg0 = &columns[0];
let arg1 = &columns[1];

let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}

let arg0_const = as_veclit_if_const(arg0)?;
let arg1_const = as_veclit_if_const(arg1)?;

for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};

let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};

if let (Some(arg0), Some(arg1)) = (arg0, arg1) {
ensure!(
arg0.len() == arg1.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the vectors must match for division, have: {} vs {}",
arg0.len(),
arg1.len()
),
}
);
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
let vec1 = DVectorView::from_slice(&arg1, arg1.len());
let vec_res = vec0.component_div(&vec1);

let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
let result = veclit_to_binlit((v0.component_div(&v1)).as_slice());
Some(result)
} else {
result.push_null();
}
}
None
};
Ok(ScalarValue::BinaryView(result))
};

Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_vectors(args)
}
}

@@ -133,8 +105,9 @@ impl Display for VectorDivFunction {
mod tests {
use std::sync::Arc;

use common_query::error;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
use datafusion_common::config::ConfigOptions;

use super::*;

@@ -144,69 +117,80 @@ mod tests {

let vec0 = vec![1.0, 2.0, 3.0];
let vec1 = vec![1.0, 1.0];
let (len0, len1) = (vec0.len(), vec1.len());
let input0 = Arc::new(StringVector::from(vec![Some(format!("{vec0:?}"))]));
let input1 = Arc::new(StringVector::from(vec![Some(format!("{vec1:?}"))]));
let input0 = Arc::new(StringViewArray::from(vec![Some(format!("{vec0:?}"))]));
let input1 = Arc::new(StringViewArray::from(vec![Some(format!("{vec1:?}"))]));

let err = func
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap_err();
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert_eq!(
e.to_string(),
"Execution error: vectors length not match: vec_div"
);

match err {
error::Error::InvalidFuncArgs { err_msg, .. } => {
assert_eq!(
err_msg,
format!(
"The length of the vectors must match for division, have: {} vs {}",
len0, len1
)
)
}
_ => unreachable!(),
}

let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[8.0,10.0,12.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
None,
]));

let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[2.0,2.0,2.0]".to_string()),
None,
Some("[3.0,3.0,3.0]".to_string()),
]));

let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();

let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice())
result.value(0),
veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[4.0, 5.0, 6.0]).as_slice())
result.value(1),
veclit_to_binlit(&[4.0, 5.0, 6.0]).as_slice()
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert!(result.is_null(2));
assert!(result.is_null(3));

let input0 = Arc::new(StringVector::from(vec![Some("[1.0,-2.0]".to_string())]));
let input1 = Arc::new(StringVector::from(vec![Some("[0.0,0.0]".to_string())]));
let input0 = Arc::new(StringViewArray::from(vec![Some("[1.0,-2.0]".to_string())]));
let input1 = Arc::new(StringViewArray::from(vec![Some("[0.0,0.0]".to_string())]));

let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 2,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(2))
.unwrap();

let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[f64::INFINITY as f32, f64::NEG_INFINITY as f32]).as_slice())
result.value(0),
veclit_to_binlit(&[f64::INFINITY as f32, f64::NEG_INFINITY as f32]).as_slice()
);
}
}

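The final assertion above leans on IEEE 754 semantics rather than an error path: component-wise f32 division by zero yields signed infinities, which round-trip through `veclit_to_binlit` unchanged. A one-line check of that arithmetic:

```rust
fn main() {
    // IEEE 754: a finite value divided by 0.0 is a signed infinity, not an error.
    assert_eq!(1.0f32 / 0.0, f32::INFINITY);
    assert_eq!(-2.0f32 / 0.0, f32::NEG_INFINITY);
}
```
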
@@ -12,19 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Cow;
use std::fmt::Display;

use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use common_query::error::Result;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
use snafu::ensure;

use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::as_veclit;

const NAME: &str = "vec_kth_elem";

@@ -63,72 +62,44 @@ impl Function for VectorKthElemFunction {
)
}

fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let v0 = as_veclit(v0)?;

let arg0 = &columns[0];
let arg1 = &columns[1];
let v1 = match v1 {
ScalarValue::Int64(None) => return Ok(ScalarValue::Float32(None)),
ScalarValue::Int64(Some(v1)) if *v1 >= 0 => *v1 as usize,
_ => {
return Err(DataFusionError::Execution(format!(
"2nd argument not a valid index or expected datatype: {}",
self.name()
)));
}
};

let len = arg0.len();
let mut result = Float32VectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
let result = v0
.map(|v0| {
if v1 >= v0.len() {
Err(DataFusionError::Execution(format!(
"index out of bound: {}",
self.name()
)))
} else {
Ok(v0[v1])
}
})
.transpose()?;
Ok(ScalarValue::Float32(result))
};

let arg0_const = as_veclit_if_const(arg0)?;

for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let Some(arg0) = arg0 else {
result.push_null();
continue;
};

let arg1 = arg1.get(i).as_f64_lossy();
let Some(arg1) = arg1 else {
result.push_null();
continue;
};

ensure!(
arg1 >= 0.0 && arg1.fract() == 0.0,
InvalidFuncArgsSnafu {
err_msg: format!(
"Invalid argument: k must be a non-negative integer, but got k = {}.",
arg1
),
}
);

let k = arg1 as usize;

ensure!(
k < arg0.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"Out of range: k must be in the range [0, {}], but got k = {}.",
arg0.len() - 1,
k
),
}
);

let value = arg0[k];

result.push(Some(value));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_args(args)
}
}

@@ -142,8 +113,10 @@ impl Display for VectorKthElemFunction {
mod tests {
use std::sync::Arc;

use common_query::error;
use datatypes::vectors::{Int64Vector, StringVector};
use arrow_schema::Field;
use datafusion::arrow::array::{Array, ArrayRef, AsArray, Int64Array, StringViewArray};
use datafusion::arrow::datatypes::Float32Type;
use datafusion_common::config::ConfigOptions;

use super::*;

@@ -151,55 +124,66 @@ mod tests {
fn test_vec_kth_elem() {
let func = VectorKthElemFunction;

let input0 = Arc::new(StringVector::from(vec![
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
None,
]));
let input1 = Arc::new(Int64Vector::from(vec![Some(0), Some(2), None, Some(1)]));
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(0), Some(2), None, Some(1)]));

let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::Float32, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();

let result = result.as_ref();
let result = result.as_primitive::<Float32Type>();
assert_eq!(result.len(), 4);
assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(1.0));
assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(6.0));
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert_eq!(result.value(0), 1.0);
assert_eq!(result.value(1), 6.0);
assert!(result.is_null(2));
assert!(result.is_null(3));

let input0 = Arc::new(StringVector::from(vec![Some("[1.0,2.0,3.0]".to_string())]));
let input1 = Arc::new(Int64Vector::from(vec![Some(3)]));
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![Some(
"[1.0,2.0,3.0]".to_string(),
)]));
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(3)]));

let err = func
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap_err();
match err {
error::Error::InvalidFuncArgs { err_msg, .. } => {
assert_eq!(
err_msg,
format!("Out of range: k must be in the range [0, 2], but got k = 3.")
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::Float32, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(
e.to_string()
.starts_with("Execution error: index out of bound: vec_kth_elem")
);

let input0 = Arc::new(StringVector::from(vec![Some("[1.0,2.0,3.0]".to_string())]));
let input1 = Arc::new(Int64Vector::from(vec![Some(-1)]));
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![Some(
"[1.0,2.0,3.0]".to_string(),
)]));
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(-1)]));

let err = func
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap_err();
match err {
error::Error::InvalidFuncArgs { err_msg, .. } => {
assert_eq!(
err_msg,
format!("Invalid argument: k must be a non-negative integer, but got k = -1.")
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::Float32, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(e.to_string().starts_with(
"Execution error: 2nd argument not a valid index or expected datatype: vec_kth_elem"
));
}
}

@@ -15,17 +15,17 @@
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
|
||||
use common_query::error::{InvalidFuncArgsSnafu, Result};
|
||||
use datafusion_expr::Signature;
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::scalars::ScalarVectorBuilder;
|
||||
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
|
||||
use common_query::error::Result;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion::logical_expr::ColumnarValue;
|
||||
use datafusion_common::{DataFusionError, ScalarValue};
|
||||
use datafusion_expr::{ScalarFunctionArgs, Signature};
|
||||
use nalgebra::DVectorView;
use snafu::ensure;

use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::veclit_to_binlit;

const NAME: &str = "vec_mul";

@@ -52,7 +52,7 @@ impl Function for VectorMulFunction {
}

fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}

fn signature(&self) -> Signature {
@@ -62,64 +62,36 @@ impl Function for VectorMulFunction {
)
}

fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &Option<Cow<[f32]>>,
v1: &Option<Cow<[f32]>>|
-> datafusion_common::Result<ScalarValue> {
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
let v0 = DVectorView::from_slice(v0, v0.len());
let v1 = DVectorView::from_slice(v1, v1.len());
if v0.len() != v1.len() {
return Err(DataFusionError::Execution(format!(
"vectors length not match: {}",
self.name()
)));
}

let arg0 = &columns[0];
let arg1 = &columns[1];

let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}

let arg0_const = as_veclit_if_const(arg0)?;
let arg1_const = as_veclit_if_const(arg1)?;

for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};

let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};

if let (Some(arg0), Some(arg1)) = (arg0, arg1) {
ensure!(
arg0.len() == arg1.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the vectors must match for multiplying, have: {} vs {}",
arg0.len(),
arg1.len()
),
}
);
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
let vec1 = DVectorView::from_slice(&arg1, arg1.len());
let vec_res = vec1.component_mul(&vec0);

let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
let result = veclit_to_binlit((v0.component_mul(&v1)).as_slice());
Some(result)
} else {
result.push_null();
}
}
None
};
Ok(ScalarValue::BinaryView(result))
};

Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_vectors(args)
}
}

@@ -133,8 +105,9 @@ impl Display for VectorMulFunction {
mod tests {
use std::sync::Arc;

use common_query::error;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
use datafusion_common::config::ConfigOptions;

use super::*;

@@ -144,56 +117,59 @@ mod tests {

let vec0 = vec![1.0, 2.0, 3.0];
let vec1 = vec![1.0, 1.0];
let (len0, len1) = (vec0.len(), vec1.len());
let input0 = Arc::new(StringVector::from(vec![Some(format!("{vec0:?}"))]));
let input1 = Arc::new(StringVector::from(vec![Some(format!("{vec1:?}"))]));
let input0 = Arc::new(StringViewArray::from(vec![Some(format!("{vec0:?}"))]));
let input1 = Arc::new(StringViewArray::from(vec![Some(format!("{vec1:?}"))]));

let err = func
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap_err();
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(
e.to_string()
.starts_with("Execution error: vectors length not match: vec_mul")
);

match err {
error::Error::InvalidFuncArgs { err_msg, .. } => {
assert_eq!(
err_msg,
format!(
"The length of the vectors must match for multiplying, have: {} vs {}",
len0, len1
)
)
}
_ => unreachable!(),
}

let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[8.0,10.0,12.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
None,
]));

let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[2.0,2.0,2.0]".to_string()),
None,
Some("[3.0,3.0,3.0]".to_string()),
]));

let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();

let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice())
result.value(0),
veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[16.0, 20.0, 24.0]).as_slice())
result.value(1),
veclit_to_binlit(&[16.0, 20.0, 24.0]).as_slice()
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert!(result.is_null(2));
assert!(result.is_null(3));
}
}

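For readers following the refactor: the new `invoke_with_args` path funnels each row through a small per-row closure instead of hand-rolled vector builders. A minimal sketch of that kernel under stated assumptions — only `nalgebra` is used here; the repo-internal `VectorCalculator` driver and `veclit_to_binlit` byte encoder are elided:

```rust
// Minimal sketch of the per-row vec_mul kernel, assuming nalgebra.
// Mirrors the `body` closure above: a None on either side yields a null
// row; mismatched lengths become an execution error; otherwise the
// element-wise (Hadamard) product is returned.
use nalgebra::DVectorView;

fn mul_kernel(v0: Option<&[f32]>, v1: Option<&[f32]>) -> Result<Option<Vec<f32>>, String> {
    match (v0, v1) {
        (Some(v0), Some(v1)) => {
            if v0.len() != v1.len() {
                return Err("vectors length not match: vec_mul".to_string());
            }
            let a = DVectorView::from_slice(v0, v0.len());
            let b = DVectorView::from_slice(v1, v1.len());
            Ok(Some(a.component_mul(&b).as_slice().to_vec()))
        }
        // A null input propagates to a null output row.
        _ => Ok(None),
    }
}
```
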
@@ -12,20 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Cow;
use std::fmt::Display;

use common_query::error::{InvalidFuncArgsSnafu, Result};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use nalgebra::DVectorView;
use snafu::ensure;

use crate::function::{Function, FunctionContext};
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::function::Function;
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};

const NAME: &str = "vec_norm";

@@ -53,7 +52,7 @@ impl Function for VectorNormFunction {
}

fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}

fn signature(&self) -> Signature {
@@ -66,55 +65,27 @@ impl Function for VectorNormFunction {
)
}

fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];

let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}

let arg0_const = as_veclit_if_const(arg0)?;

for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let Some(arg0) = arg0 else {
result.push_null();
continue;
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let v0 = as_veclit(v0)?;
let Some(v0) = v0 else {
return Ok(ScalarValue::BinaryView(None));
};

let vec0 = DVectorView::from_slice(&arg0, arg0.len());
let vec1 = DVectorView::from_slice(&arg0, arg0.len());
let vec2scalar = vec1.component_mul(&vec0);
let scalar_var = vec2scalar.sum().sqrt();
let v0 = DVectorView::from_slice(&v0, v0.len());
let result =
veclit_to_binlit(v0.unscale(v0.component_mul(&v0).sum().sqrt()).as_slice());
Ok(ScalarValue::BinaryView(Some(result)))
};

let vec = DVectorView::from_slice(&arg0, arg0.len());
// Use unscale to avoid division by zero and keep as much precision as possible
let vec_res = vec.unscale(scalar_var);

let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
}

Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_single_argument(args)
}
}

@@ -128,7 +99,9 @@ impl Display for VectorNormFunction {
mod tests {
use std::sync::Arc;

use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
use datafusion_common::config::ConfigOptions;

use super::*;

@@ -136,7 +109,7 @@ mod tests {
fn test_vec_norm() {
let func = VectorNormFunction;

let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[0.0,2.0,3.0]".to_string()),
Some("[1.0,2.0,3.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
@@ -144,26 +117,36 @@ mod tests {
None,
]));

let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0)],
arg_fields: vec![],
number_rows: 5,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.invoke_with_args(args)
.and_then(|x| x.to_array(5))
.unwrap();

let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 5);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[0.0, 0.5547002, 0.8320503]).as_slice())
result.value(0),
veclit_to_binlit(&[0.0, 0.5547002, 0.8320503]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[0.26726124, 0.5345225, 0.8017837]).as_slice())
result.value(1),
veclit_to_binlit(&[0.26726124, 0.5345225, 0.8017837]).as_slice()
);
assert_eq!(
result.get_ref(2).as_binary().unwrap(),
Some(veclit_to_binlit(&[0.5025707, 0.5743665, 0.64616233]).as_slice())
result.value(2),
veclit_to_binlit(&[0.5025707, 0.5743665, 0.64616233]).as_slice()
);
assert_eq!(
result.get_ref(3).as_binary().unwrap(),
Some(veclit_to_binlit(&[0.5025707, -0.5743665, 0.64616233]).as_slice())
result.value(3),
veclit_to_binlit(&[0.5025707, -0.5743665, 0.64616233]).as_slice()
);
assert!(result.get_ref(4).is_null());
assert!(result.is_null(4));
}
}

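The condensed vec_norm body computes plain L2 normalization, v / ||v||2. A hedged standalone equivalent, assuming only nalgebra:

```rust
// Standalone sketch of what the condensed vec_norm expression computes:
// L2 normalization, v / sqrt(sum(v_i^2)). `unscale(s)` divides every
// component by s, matching the old comment about avoiding an explicit
// division.
use nalgebra::DVectorView;

fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let dv = DVectorView::from_slice(v, v.len());
    // sqrt of the sum of squares, i.e. the L2 norm.
    let norm = dv.component_mul(&dv).sum().sqrt();
    dv.unscale(norm).as_slice().to_vec()
}

// l2_normalize(&[1.0, 2.0, 3.0]) ≈ [0.26726124, 0.5345225, 0.8017837],
// the values asserted in test_vec_norm above.
```
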
@@ -15,17 +15,17 @@
use std::borrow::Cow;
use std::fmt::Display;

use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use nalgebra::DVectorView;
use snafu::ensure;

use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::veclit_to_binlit;

const NAME: &str = "vec_sub";

@@ -51,7 +51,7 @@ impl Function for VectorSubFunction {
}

fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}

fn signature(&self) -> Signature {
@@ -61,66 +61,36 @@ impl Function for VectorSubFunction {
)
}

fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
let arg1 = &columns[1];
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &Option<Cow<[f32]>>,
v1: &Option<Cow<[f32]>>|
-> datafusion_common::Result<ScalarValue> {
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
let v0 = DVectorView::from_slice(v0, v0.len());
let v1 = DVectorView::from_slice(v1, v1.len());
if v0.len() != v1.len() {
return Err(DataFusionError::Execution(format!(
"vectors length not match: {}",
self.name()
)));
}

ensure!(
arg0.len() == arg1.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The lengths of the vector are not aligned, args 0: {}, args 1: {}",
arg0.len(),
arg1.len(),
)
}
);

let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}

let arg0_const = as_veclit_if_const(arg0)?;
let arg1_const = as_veclit_if_const(arg1)?;

for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
let result = veclit_to_binlit((v0 - v1).as_slice());
Some(result)
} else {
None
};
let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
let (Some(arg0), Some(arg1)) = (arg0, arg1) else {
result.push_null();
continue;
};
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
let vec1 = DVectorView::from_slice(&arg1, arg1.len());
Ok(ScalarValue::BinaryView(result))
};

let vec_res = vec0 - vec1;
let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
}

Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_vectors(args)
}
}

@@ -134,8 +104,9 @@ impl Display for VectorSubFunction {
mod tests {
use std::sync::Arc;

use common_query::error::Error;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, ArrayRef, AsArray, StringViewArray};
use datafusion_common::config::ConfigOptions;

use super::*;

@@ -143,63 +114,71 @@ mod tests {
fn test_sub() {
let func = VectorSubFunction;

let input0 = Arc::new(StringVector::from(vec![
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
Some("[2.0,3.0,3.0]".to_string()),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[6.0,5.0,4.0]".to_string()),
Some("[3.0,2.0,2.0]".to_string()),
None,
]));

let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();

let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[0.0, 1.0, 2.0]).as_slice())
result.value(0),
veclit_to_binlit(&[0.0, 1.0, 2.0]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[-2.0, 0.0, 2.0]).as_slice())
result.value(1),
veclit_to_binlit(&[-2.0, 0.0, 2.0]).as_slice()
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert!(result.is_null(2));
assert!(result.is_null(3));
}

#[test]
fn test_sub_error() {
let func = VectorSubFunction;

let input0 = Arc::new(StringVector::from(vec![
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
Some("[2.0,3.0,3.0]".to_string()),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[6.0,5.0,4.0]".to_string()),
Some("[3.0,2.0,2.0]".to_string()),
]));

let result = func.eval(&FunctionContext::default(), &[input0, input1]);

match result {
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
assert_eq!(
err_msg,
"The lengths of the vector are not aligned, args 0: 4, args 1: 3"
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(e.to_string().starts_with(
"Internal error: Arguments has mixed length. Expected length: 4, found length: 3."
));
}
}

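A note on the reworked test_sub_error: the per-function InvalidFuncArgs alignment check is gone because argument alignment is now checked generically when ColumnarValues are lowered to arrays. A small sketch of where the new message originates, assuming the datafusion crate:

```rust
// Sketch: DataFusion's ColumnarValue::values_to_arrays broadcasts scalar
// arguments but rejects arrays of differing lengths, producing the
// "Arguments has mixed length" internal error asserted above.
use std::sync::Arc;

use datafusion::arrow::array::Int64Array;
use datafusion::logical_expr::ColumnarValue;

fn demo_mixed_length() {
    let a = ColumnarValue::Array(Arc::new(Int64Array::from(vec![1, 2, 3, 4])));
    let b = ColumnarValue::Array(Arc::new(Int64Array::from(vec![1, 2, 3])));
    let err = ColumnarValue::values_to_arrays(&[a, b]).unwrap_err();
    assert!(err.to_string().contains("Arguments has mixed length"));
}
```
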
@@ -12,18 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Cow;
use std::fmt::Display;
use std::sync::Arc;

use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datafusion::arrow::array::{Array, AsArray, BinaryViewBuilder};
use datafusion::arrow::datatypes::Int64Type;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{ScalarValue, utils};
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use snafu::ensure;

use crate::function::{Function, FunctionContext};
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::function::Function;
use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};

const NAME: &str = "vec_subvector";

@@ -52,7 +54,7 @@ impl Function for VectorSubvectorFunction {
}

fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}

fn signature(&self) -> Signature {
@@ -65,50 +67,28 @@ impl Function for VectorSubvectorFunction {
)
}

fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 3,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly three, have: {}",
columns.len()
)
}
);

let arg0 = &columns[0];
let arg1 = &columns[1];
let arg2 = &columns[2];

ensure!(
arg0.len() == arg1.len() && arg1.len() == arg2.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The lengths of the vector are not aligned, args 0: {}, args 1: {}, args 2: {}",
arg0.len(),
arg1.len(),
arg2.len()
)
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [arg0, arg1, arg2] = utils::take_function_args(self.name(), args)?;
let arg1 = arg1.as_primitive::<Int64Type>();
let arg2 = arg2.as_primitive::<Int64Type>();

let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
let mut builder = BinaryViewBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
return Ok(ColumnarValue::Array(Arc::new(builder.finish())));
}

let arg0_const = as_veclit_if_const(arg0)?;

for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let arg1 = arg1.get(i).as_i64();
let arg2 = arg2.get(i).as_i64();
let v = ScalarValue::try_from_array(&arg0, i)?;
let arg0 = as_veclit(&v)?;
let arg1 = arg1.is_valid(i).then(|| arg1.value(i));
let arg2 = arg2.is_valid(i).then(|| arg2.value(i));
let (Some(arg0), Some(arg1), Some(arg2)) = (arg0, arg1, arg2) else {
result.push_null();
builder.append_null();
continue;
};

@@ -126,10 +106,10 @@ impl Function for VectorSubvectorFunction {

let subvector = &arg0[arg1 as usize..arg2 as usize];
let binlit = veclit_to_binlit(subvector);
result.push(Some(&binlit));
builder.append_value(&binlit);
}

Ok(result.to_vector())
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}

@@ -143,89 +123,102 @@ impl Display for VectorSubvectorFunction {
mod tests {
use std::sync::Arc;

use common_query::error::Error;
use datatypes::vectors::{Int64Vector, StringVector};
use arrow_schema::Field;
use datafusion::arrow::array::{ArrayRef, Int64Array, StringViewArray};
use datafusion_common::config::ConfigOptions;

use super::*;
use crate::function::FunctionContext;

#[test]
fn test_subvector() {
let func = VectorSubvectorFunction;

let input0 = Arc::new(StringVector::from(vec![
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0, 2.0, 3.0, 4.0, 5.0]".to_string()),
Some("[6.0, 7.0, 8.0, 9.0, 10.0]".to_string()),
None,
Some("[11.0, 12.0, 13.0]".to_string()),
]));
let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(0), Some(0), Some(1)]));
let input2 = Arc::new(Int64Vector::from(vec![Some(3), Some(5), Some(2), Some(3)]));
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(0), Some(0), Some(1)]));
let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3), Some(5), Some(2), Some(3)]));

let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(input0),
ColumnarValue::Array(input1),
ColumnarValue::Array(input2),
],
arg_fields: vec![],
number_rows: 5,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1, input2])
.invoke_with_args(args)
.and_then(|x| x.to_array(5))
.unwrap();

let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(result.value(0), veclit_to_binlit(&[2.0, 3.0]).as_slice());
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[2.0, 3.0]).as_slice())
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0]).as_slice())
);
assert!(result.get_ref(2).is_null());
assert_eq!(
result.get_ref(3).as_binary().unwrap(),
Some(veclit_to_binlit(&[12.0, 13.0]).as_slice())
result.value(1),
veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0]).as_slice()
);
assert!(result.is_null(2));
assert_eq!(result.value(3), veclit_to_binlit(&[12.0, 13.0]).as_slice());
}
#[test]
fn test_subvector_error() {
let func = VectorSubvectorFunction;

let input0 = Arc::new(StringVector::from(vec![
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0, 2.0, 3.0]".to_string()),
Some("[4.0, 5.0, 6.0]".to_string()),
]));
let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(2)]));
let input2 = Arc::new(Int64Vector::from(vec![Some(3)]));
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2)]));
let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3)]));

let result = func.eval(&FunctionContext::default(), &[input0, input1, input2]);

match result {
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
assert_eq!(
err_msg,
"The lengths of the vector are not aligned, args 0: 2, args 1: 2, args 2: 1"
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(input0),
ColumnarValue::Array(input1),
ColumnarValue::Array(input2),
],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(e.to_string().starts_with(
"Internal error: Arguments has mixed length. Expected length: 2, found length: 1."
));
}

#[test]
fn test_subvector_invalid_indices() {
let func = VectorSubvectorFunction;

let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[1.0, 2.0, 3.0]".to_string()),
Some("[4.0, 5.0, 6.0]".to_string()),
]));
let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(3)]));
let input2 = Arc::new(Int64Vector::from(vec![Some(3), Some(4)]));
let input1 = Arc::new(Int64Array::from(vec![Some(1), Some(3)]));
let input2 = Arc::new(Int64Array::from(vec![Some(3), Some(4)]));

let result = func.eval(&FunctionContext::default(), &[input0, input1, input2]);

match result {
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
assert_eq!(
err_msg,
"Invalid start and end indices: start=3, end=4, vec_len=3"
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(input0),
ColumnarValue::Array(input1),
ColumnarValue::Array(input2),
],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(e.to_string().starts_with("External error: Invalid function args: Invalid start and end indices: start=3, end=4, vec_len=3"));
}
}

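The index semantics survive the port unchanged: per row, vec_subvector takes a half-open range over the parsed vector. A minimal sketch of the rule that the `Invalid start and end indices` check enforces (a simplification; the real function raises InvalidFuncArgs rather than returning None):

```rust
// Sketch of the per-row slicing rule: start is inclusive, end is
// exclusive, both 0-based; out-of-range indices are rejected rather
// than clamped.
fn subvector(v: &[f32], start: i64, end: i64) -> Option<&[f32]> {
    if start < 0 || end < start || end as usize > v.len() {
        return None; // the real function raises an InvalidFuncArgs error here
    }
    Some(&v[start as usize..end as usize])
}

// test_subvector above expects subvector(&[1.0, 2.0, 3.0, 4.0, 5.0], 1, 3)
// to yield [2.0, 3.0].
```
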
@@ -37,8 +37,7 @@ impl FunctionState {
use catalog::CatalogManagerRef;
use common_base::AffectedRows;
use common_meta::rpc::procedure::{
AddRegionFollowerRequest, MigrateRegionRequest, ProcedureStateResponse,
RemoveRegionFollowerRequest,
ManageRegionFollowerRequest, MigrateRegionRequest, ProcedureStateResponse,
};
use common_query::Output;
use common_query::error::Result;
@@ -75,13 +74,9 @@ impl FunctionState {
})
}

async fn add_region_follower(&self, _request: AddRegionFollowerRequest) -> Result<()> {
Ok(())
}

async fn remove_region_follower(
async fn manage_region_follower(
&self,
_request: RemoveRegionFollowerRequest,
_request: ManageRegionFollowerRequest,
) -> Result<()> {
Ok(())
}

@@ -16,11 +16,12 @@ use std::fmt;
use std::sync::Arc;

use common_query::error::Result;
use datafusion::arrow::array::StringViewArray;
use datafusion::arrow::datatypes::DataType;
use datafusion_expr::{Signature, Volatility};
use datatypes::vectors::{StringVector, VectorRef};
use datafusion::logical_expr::ColumnarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature, Volatility};

use crate::function::{Function, FunctionContext};
use crate::function::Function;

/// Generates build information
#[derive(Clone, Debug, Default)]
@@ -38,17 +39,18 @@ impl Function for BuildFunction {
}

fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Utf8)
Ok(DataType::Utf8View)
}

fn signature(&self) -> Signature {
Signature::nullary(Volatility::Immutable)
}

fn eval(&self, _func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
fn invoke_with_args(&self, _: ScalarFunctionArgs) -> datafusion_common::Result<ColumnarValue> {
let build_info = common_version::build_info().to_string();
let v = Arc::new(StringVector::from(vec![build_info]));
Ok(v)
Ok(ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
build_info,
]))))
}
}

@@ -56,16 +58,29 @@ impl Function for BuildFunction {
mod tests {
use std::sync::Arc;

use arrow_schema::Field;
use datafusion::arrow::array::ArrayRef;
use datafusion_common::config::ConfigOptions;

use super::*;
#[test]
fn test_build_function() {
let build = BuildFunction;
assert_eq!("build", build.name());
assert_eq!(DataType::Utf8, build.return_type(&[]).unwrap());
assert_eq!(DataType::Utf8View, build.return_type(&[]).unwrap());
assert_eq!(build.signature(), Signature::nullary(Volatility::Immutable));
let build_info = common_version::build_info().to_string();
let vector = build.eval(&FunctionContext::default(), &[]).unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec![build_info]));
assert_eq!(expect, vector);
let actual = build
.invoke_with_args(ScalarFunctionArgs {
args: vec![],
arg_fields: vec![],
number_rows: 0,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options: Arc::new(ConfigOptions::new()),
})
.unwrap();
let actual = ColumnarValue::values_to_arrays(&[actual]).unwrap();
let expect = vec![Arc::new(StringViewArray::from(vec![build_info])) as ArrayRef];
assert_eq!(actual, expect);
}
}

@@ -280,6 +280,8 @@ fn build_struct(
&self,
args: datafusion::logical_expr::ScalarFunctionArgs,
) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
use common_error::ext::ErrorExt;

let columns = args.args
.iter()
.map(|arg| {
@@ -293,7 +295,7 @@ fn build_struct(
})
})
.collect::<common_query::error::Result<Vec<_>>>()
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Column conversion error: {}", e)))?;
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Column conversion error: {}", e.output_msg())))?;

// Safety check: Ensure under the `greptime` catalog for security
#user_path::ensure_greptime!(self.func_ctx);
@@ -314,14 +316,14 @@ fn build_struct(
.#handler
.as_ref()
.context(#snafu_type)
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Handler error: {}", e)))?;
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Handler error: {}", e.output_msg())))?;

let mut builder = store_api::storage::ConcreteDataType::#ret()
.create_mutable_vector(rows_num);

if columns_num == 0 {
let result = #fn_name(handler, query_ctx, &[]).await
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Function execution error: {}", e)))?;
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Function execution error: {}", e.output_msg())))?;

builder.push_value_ref(result.as_value_ref());
} else {
@@ -331,7 +333,7 @@ fn build_struct(
.collect();

let result = #fn_name(handler, query_ctx, &args).await
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Function execution error: {}", e)))?;
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Function execution error: {}", e.output_msg())))?;

builder.push_value_ref(result.as_value_ref());
}

@@ -752,7 +752,6 @@ pub enum Error {
location: Location,
},

#[cfg(feature = "pg_kvbackend")]
#[snafu(display("Failed to load TLS certificate from path: {}", path))]
LoadTlsCertificate {
path: String,
@@ -1181,13 +1180,14 @@ impl ErrorExt for Error {
| InvalidRole { .. }
| EmptyDdlTasks { .. } => StatusCode::InvalidArguments,

LoadTlsCertificate { .. } => StatusCode::Internal,

#[cfg(feature = "pg_kvbackend")]
PostgresExecution { .. }
| CreatePostgresPool { .. }
| GetPostgresConnection { .. }
| PostgresTransaction { .. }
| PostgresTlsConfig { .. }
| LoadTlsCertificate { .. }
| InvalidTlsConfig { .. } => StatusCode::Internal,
#[cfg(feature = "mysql_kvbackend")]
MySqlExecution { .. } | CreateMySqlPool { .. } | MySqlTransaction { .. } => {

@@ -13,15 +13,17 @@
// limitations under the License.

use std::any::Any;
use std::fs;
use std::sync::Arc;

use common_telemetry::info;
use common_telemetry::{debug, info};
use etcd_client::{
Client, DeleteOptions, GetOptions, PutOptions, Txn, TxnOp, TxnOpResponse, TxnResponse,
Certificate, Client, DeleteOptions, GetOptions, Identity, PutOptions, TlsOptions, Txn, TxnOp,
TxnOpResponse, TxnResponse,
};
use snafu::{ResultExt, ensure};

use crate::error::{self, Error, Result};
use crate::error::{self, Error, LoadTlsCertificateSnafu, Result};
use crate::kv_backend::txn::{Txn as KvTxn, TxnResponse as KvTxnResponse};
use crate::kv_backend::{KvBackend, KvBackendRef, TxnService};
use crate::metrics::METRIC_META_TXN_REQUEST;
@@ -451,8 +453,76 @@ impl TryFrom<DeleteRangeRequest> for Delete {
}
}

#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum TlsMode {
#[default]
Disable,
Require,
}

/// TLS configuration for Etcd connections.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsOption {
pub mode: TlsMode,
pub cert_path: String,
pub key_path: String,
pub ca_cert_path: String,
}

/// Creates an Etcd [`TlsOptions`] from a [`TlsOption`].
///
/// This function builds the TLS options for etcd client connections based on the provided
/// [`TlsOption`]. It supports disabling TLS, setting a custom CA certificate, and configuring
/// client identity for mutual TLS authentication.
///
/// Note: All TlsMode variants except [`TlsMode::Disable`] will be treated as enabling TLS.
pub fn create_etcd_tls_options(tls_config: &TlsOption) -> Result<Option<TlsOptions>> {
// If TLS mode is disabled, return None to indicate no TLS configuration.
if matches!(tls_config.mode, TlsMode::Disable) {
return Ok(None);
}

info!("Creating etcd TLS with mode: {:?}", tls_config.mode);
// Start with default TLS options.
let mut etcd_tls_opts = TlsOptions::new();

// If a CA certificate path is provided, load the CA certificate and add it to the options.
if !tls_config.ca_cert_path.is_empty() {
debug!("Using CA certificate from {}", tls_config.ca_cert_path);
let ca_cert_pem = fs::read(&tls_config.ca_cert_path).context(LoadTlsCertificateSnafu {
path: &tls_config.ca_cert_path,
})?;
let ca_cert = Certificate::from_pem(ca_cert_pem);
etcd_tls_opts = etcd_tls_opts.ca_certificate(ca_cert);
}

// If both client certificate and key paths are provided, load them and set the client identity.
if !tls_config.cert_path.is_empty() && !tls_config.key_path.is_empty() {
info!("Loading client certificate for mutual TLS");
debug!(
"Using client certificate from {} and key from {}",
tls_config.cert_path, tls_config.key_path
);
let cert_pem = fs::read(&tls_config.cert_path).context(LoadTlsCertificateSnafu {
path: &tls_config.cert_path,
})?;
let key_pem = fs::read(&tls_config.key_path).context(LoadTlsCertificateSnafu {
path: &tls_config.key_path,
})?;
let identity = Identity::from_pem(cert_pem, key_pem);
etcd_tls_opts = etcd_tls_opts.identity(identity);
}

// Always enable native TLS roots for additional trust anchors.
etcd_tls_opts = etcd_tls_opts.with_native_roots();

Ok(Some(etcd_tls_opts))
}

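For orientation, the function supports three setups, selected purely by which paths are non-empty. A hedged sketch (paths are placeholders, not repo defaults):

```rust
// Sketch: the three configurations create_etcd_tls_options supports.
// An empty string simply skips the corresponding branch.
let server_tls_only = TlsOption {
    mode: TlsMode::Require,
    cert_path: String::new(),
    key_path: String::new(),
    ca_cert_path: String::new(), // trust native roots only
};
let custom_ca = TlsOption {
    ca_cert_path: "/etc/greptimedb/ca.crt".to_string(),
    ..server_tls_only.clone()
};
let mutual_tls = TlsOption {
    cert_path: "/etc/greptimedb/client.crt".to_string(),
    key_path: "/etc/greptimedb/client.key".to_string(),
    ..custom_ca.clone()
};
```

The resulting `Option<TlsOptions>` plugs into `etcd_client::ConnectOptions::with_tls`, exactly as the integration test below wires it.
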
#[cfg(test)]
mod tests {
use etcd_client::ConnectOptions;

use super::*;

#[test]
@@ -555,6 +625,8 @@ mod tests {
test_txn_compare_not_equal, test_txn_one_compare_op, text_txn_multi_compare_op,
unprepare_kv,
};
use crate::maybe_skip_etcd_tls_integration_test;
use crate::test_util::etcd_certs_dir;

async fn build_kv_backend() -> Option<EtcdStore> {
let endpoints = std::env::var("GT_ETCD_ENDPOINTS").unwrap_or_default();
@@ -654,4 +726,41 @@ mod tests {
test_txn_compare_not_equal(&kv_backend).await;
}
}

async fn create_etcd_client_with_tls(endpoints: &[String], tls_config: &TlsOption) -> Client {
let endpoints = endpoints
.iter()
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.collect::<Vec<_>>();
let connect_options =
ConnectOptions::new().with_tls(create_etcd_tls_options(tls_config).unwrap().unwrap());

Client::connect(&endpoints, Some(connect_options))
.await
.unwrap()
}

#[tokio::test]
async fn test_create_etcd_client_with_mtls_and_ca() {
maybe_skip_etcd_tls_integration_test!();
let endpoints = std::env::var("GT_ETCD_TLS_ENDPOINTS")
.unwrap()
.split(',')
.map(|s| s.to_string())
.collect::<Vec<_>>();

let cert_dir = etcd_certs_dir();
let tls_config = TlsOption {
mode: TlsMode::Require,
ca_cert_path: cert_dir.join("ca.crt").to_string_lossy().to_string(),
cert_path: cert_dir.join("client.crt").to_string_lossy().to_string(),
key_path: cert_dir
.join("client-key.pem")
.to_string_lossy()
.to_string(),
};
let mut client = create_etcd_client_with_tls(&endpoints, &tls_config).await;
let _ = client.get(b"hello", None).await.unwrap();
}
}

@@ -573,6 +573,7 @@ impl MySqlStore {
#[cfg(test)]
mod tests {
use common_telemetry::init_default_ut_logging;
use sqlx::mysql::{MySqlConnectOptions, MySqlSslMode};

use super::*;
use crate::kv_backend::test::{
@@ -584,6 +585,7 @@ mod tests {
text_txn_multi_compare_op, unprepare_kv,
};
use crate::maybe_skip_mysql_integration_test;
use crate::test_util::test_certs_dir;

async fn build_mysql_kv_backend(table_name: &str) -> Option<MySqlStore> {
init_default_ut_logging();
@@ -711,4 +713,71 @@ mod tests {
test_txn_compare_less(&kv_backend).await;
test_txn_compare_not_equal(&kv_backend).await;
}

#[tokio::test]
async fn test_mysql_with_tls() {
common_telemetry::init_default_ut_logging();
maybe_skip_mysql_integration_test!();
let endpoint = std::env::var("GT_MYSQL_ENDPOINTS").unwrap();

let opts = endpoint
.parse::<MySqlConnectOptions>()
.unwrap()
.ssl_mode(MySqlSslMode::Required);
let pool = MySqlPool::connect_with(opts).await.unwrap();
sqlx::query("SELECT 1").execute(&pool).await.unwrap();
}

#[tokio::test]
async fn test_mysql_with_mtls() {
common_telemetry::init_default_ut_logging();
maybe_skip_mysql_integration_test!();
let endpoint = std::env::var("GT_MYSQL_ENDPOINTS").unwrap();
let certs_dir = test_certs_dir();

let opts = endpoint
.parse::<MySqlConnectOptions>()
.unwrap()
.ssl_mode(MySqlSslMode::Required)
.ssl_client_cert(certs_dir.join("client.crt").to_string_lossy().to_string())
.ssl_client_key(certs_dir.join("client.key").to_string_lossy().to_string());
let pool = MySqlPool::connect_with(opts).await.unwrap();
sqlx::query("SELECT 1").execute(&pool).await.unwrap();
}

#[tokio::test]
async fn test_mysql_with_tls_verify_ca() {
common_telemetry::init_default_ut_logging();
maybe_skip_mysql_integration_test!();
let endpoint = std::env::var("GT_MYSQL_ENDPOINTS").unwrap();
let certs_dir = test_certs_dir();

let opts = endpoint
.parse::<MySqlConnectOptions>()
.unwrap()
.ssl_mode(MySqlSslMode::VerifyCa)
.ssl_ca(certs_dir.join("root.crt").to_string_lossy().to_string())
.ssl_client_cert(certs_dir.join("client.crt").to_string_lossy().to_string())
.ssl_client_key(certs_dir.join("client.key").to_string_lossy().to_string());
let pool = MySqlPool::connect_with(opts).await.unwrap();
sqlx::query("SELECT 1").execute(&pool).await.unwrap();
}

#[tokio::test]
async fn test_mysql_with_tls_verify_ident() {
common_telemetry::init_default_ut_logging();
maybe_skip_mysql_integration_test!();
let endpoint = std::env::var("GT_MYSQL_ENDPOINTS").unwrap();
let certs_dir = test_certs_dir();

let opts = endpoint
.parse::<MySqlConnectOptions>()
.unwrap()
.ssl_mode(MySqlSslMode::VerifyIdentity)
.ssl_ca(certs_dir.join("root.crt").to_string_lossy().to_string())
.ssl_client_cert(certs_dir.join("client.crt").to_string_lossy().to_string())
.ssl_client_key(certs_dir.join("client.key").to_string_lossy().to_string());
let pool = MySqlPool::connect_with(opts).await.unwrap();
sqlx::query("SELECT 1").execute(&pool).await.unwrap();
}
}

@@ -903,6 +903,7 @@ mod tests {
test_txn_compare_less, test_txn_compare_not_equal, test_txn_one_compare_op,
text_txn_multi_compare_op, unprepare_kv,
};
use crate::test_util::test_certs_dir;
use crate::{maybe_skip_postgres_integration_test, maybe_skip_postgres15_integration_test};

async fn build_pg_kv_backend(table_name: &str) -> Option<PgStore> {
@@ -993,6 +994,97 @@ mod tests {
unprepare_kv(&kv, prefix).await;
}

#[tokio::test]
async fn test_pg_with_tls() {
common_telemetry::init_default_ut_logging();
maybe_skip_postgres_integration_test!();
let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
let tls_connector = create_postgres_tls_connector(&TlsOption {
mode: TlsMode::Require,
cert_path: String::new(),
key_path: String::new(),
ca_cert_path: String::new(),
watch: false,
})
.unwrap();
let mut cfg = Config::new();
cfg.url = Some(endpoints);
let pool = cfg
.create_pool(Some(Runtime::Tokio1), tls_connector)
.unwrap();
let client = pool.get().await.unwrap();
client.execute("SELECT 1", &[]).await.unwrap();
}

#[tokio::test]
async fn test_pg_with_mtls() {
common_telemetry::init_default_ut_logging();
maybe_skip_postgres_integration_test!();
let certs_dir = test_certs_dir();
let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
let tls_connector = create_postgres_tls_connector(&TlsOption {
mode: TlsMode::Require,
cert_path: certs_dir.join("client.crt").display().to_string(),
key_path: certs_dir.join("client.key").display().to_string(),
ca_cert_path: String::new(),
watch: false,
})
.unwrap();
let mut cfg = Config::new();
cfg.url = Some(endpoints);
let pool = cfg
.create_pool(Some(Runtime::Tokio1), tls_connector)
.unwrap();
let client = pool.get().await.unwrap();
client.execute("SELECT 1", &[]).await.unwrap();
}

#[tokio::test]
async fn test_pg_verify_ca() {
common_telemetry::init_default_ut_logging();
maybe_skip_postgres_integration_test!();
let certs_dir = test_certs_dir();
let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
let tls_connector = create_postgres_tls_connector(&TlsOption {
mode: TlsMode::VerifyCa,
cert_path: certs_dir.join("client.crt").display().to_string(),
key_path: certs_dir.join("client.key").display().to_string(),
ca_cert_path: certs_dir.join("root.crt").display().to_string(),
watch: false,
})
.unwrap();
let mut cfg = Config::new();
cfg.url = Some(endpoints);
let pool = cfg
.create_pool(Some(Runtime::Tokio1), tls_connector)
.unwrap();
let client = pool.get().await.unwrap();
client.execute("SELECT 1", &[]).await.unwrap();
}

#[tokio::test]
async fn test_pg_verify_full() {
common_telemetry::init_default_ut_logging();
maybe_skip_postgres_integration_test!();
let certs_dir = test_certs_dir();
let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
let tls_connector = create_postgres_tls_connector(&TlsOption {
mode: TlsMode::VerifyFull,
cert_path: certs_dir.join("client.crt").display().to_string(),
key_path: certs_dir.join("client.key").display().to_string(),
ca_cert_path: certs_dir.join("root.crt").display().to_string(),
watch: false,
})
.unwrap();
let mut cfg = Config::new();
cfg.url = Some(endpoints);
let pool = cfg
.create_pool(Some(Runtime::Tokio1), tls_connector)
.unwrap();
let client = pool.get().await.unwrap();
client.execute("SELECT 1", &[]).await.unwrap();
}

#[tokio::test]
async fn test_pg_put() {
maybe_skip_postgres_integration_test!();

@@ -25,8 +25,8 @@ use crate::error::{
};
use crate::rpc::ddl::{SubmitDdlTaskRequest, SubmitDdlTaskResponse};
use crate::rpc::procedure::{
self, AddRegionFollowerRequest, MigrateRegionRequest, MigrateRegionResponse,
ProcedureStateResponse, RemoveRegionFollowerRequest,
self, ManageRegionFollowerRequest, MigrateRegionRequest, MigrateRegionResponse,
ProcedureStateResponse,
};

/// The context of procedure executor.
@@ -45,26 +45,14 @@ pub trait ProcedureExecutor: Send + Sync {
request: SubmitDdlTaskRequest,
) -> Result<SubmitDdlTaskResponse>;

/// Add a region follower
async fn add_region_follower(
/// Submit a manage region follower task
async fn manage_region_follower(
&self,
_ctx: &ExecutorContext,
_request: AddRegionFollowerRequest,
_request: ManageRegionFollowerRequest,
) -> Result<()> {
UnsupportedSnafu {
operation: "add_region_follower",
}
.fail()
}

/// Remove a region follower
async fn remove_region_follower(
&self,
_ctx: &ExecutorContext,
_request: RemoveRegionFollowerRequest,
) -> Result<()> {
UnsupportedSnafu {
operation: "remove_region_follower",
operation: "manage_region_follower",
}
.fail()
}

@@ -23,6 +23,7 @@ use api::v1::meta::{
use common_error::ext::ErrorExt;
use common_procedure::{ProcedureId, ProcedureInfo, ProcedureState};
use snafu::ResultExt;
use table::metadata::TableId;

use crate::error::{ParseProcedureIdSnafu, Result};

@@ -44,6 +45,30 @@ pub struct AddRegionFollowerRequest {
pub peer_id: u64,
}

#[derive(Debug, Clone)]
pub struct AddTableFollowerRequest {
pub catalog_name: String,
pub schema_name: String,
pub table_name: String,
pub table_id: TableId,
}

#[derive(Debug, Clone)]
pub struct RemoveTableFollowerRequest {
pub catalog_name: String,
pub schema_name: String,
pub table_name: String,
pub table_id: TableId,
}

#[derive(Debug, Clone)]
pub enum ManageRegionFollowerRequest {
AddRegionFollower(AddRegionFollowerRequest),
RemoveRegionFollower(RemoveRegionFollowerRequest),
AddTableFollower(AddTableFollowerRequest),
RemoveTableFollower(RemoveTableFollowerRequest),
}

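With these four variants, call sites collapse the old add/remove method pair into one entry point. A hedged construction sketch — `executor` and `ctx` are placeholder bindings supplied by the caller, and the field values are illustrative only:

```rust
// Sketch: constructing the unified request and dispatching it through
// the single manage_region_follower entry point.
let request = ManageRegionFollowerRequest::AddTableFollower(AddTableFollowerRequest {
    catalog_name: "greptime".to_string(),
    schema_name: "public".to_string(),
    table_name: "metrics".to_string(),
    table_id: 1024,
});
// executor: any ProcedureExecutor; ctx: an ExecutorContext from the caller.
executor.manage_region_follower(&ctx, request).await?;
```
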
/// A request to remove a region follower.
#[derive(Debug, Clone)]
pub struct RemoveRegionFollowerRequest {

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::path::PathBuf;
use std::sync::Arc;

use api::region::RegionResponse;
@@ -299,3 +300,39 @@ macro_rules! maybe_skip_postgres15_integration_test {
}
};
}

#[macro_export]
/// Skip the test if the environment variable `GT_ETCD_TLS_ENDPOINTS` is not set.
///
/// The format of the environment variable is:
/// ```text
/// GT_ETCD_TLS_ENDPOINTS=localhost:9092,localhost:9093
/// ```
macro_rules! maybe_skip_etcd_tls_integration_test {
() => {
if std::env::var("GT_ETCD_TLS_ENDPOINTS").is_err() {
common_telemetry::warn!("The etcd TLS endpoints are not set, skipping the test");
return;
}
};
}

/// Returns the directory of the etcd TLS certs.
pub fn etcd_certs_dir() -> PathBuf {
let project_path = env!("CARGO_MANIFEST_DIR");
let project_path = PathBuf::from(project_path);
let base = project_path.ancestors().nth(3).unwrap();
base.join("tests-integration")
.join("fixtures")
.join("etcd-tls-certs")
}

/// Returns the directory of the test certs.
pub fn test_certs_dir() -> PathBuf {
let project_path = env!("CARGO_MANIFEST_DIR");
let project_path = PathBuf::from(project_path);
let base = project_path.ancestors().nth(3).unwrap();
base.join("tests-integration")
.join("fixtures")
.join("certs")
}

@@ -17,7 +17,7 @@ use datatypes::prelude::ConcreteDataType;
use datatypes::vectors::{Helper, VectorRef};
use snafu::ResultExt;

use crate::error::{self, GeneralDataFusionSnafu, IntoVectorSnafu, Result};
use crate::error::{self, IntoVectorSnafu, Result};
use crate::prelude::ScalarValue;

/// Represents the result from an expression
@@ -43,9 +43,7 @@ impl ColumnarValue {
Ok(match self {
ColumnarValue::Vector(v) => v,
ColumnarValue::Scalar(s) => {
let v = s
.to_array_of_size(num_rows)
.context(GeneralDataFusionSnafu)?;
let v = s.to_array_of_size(num_rows)?;
let data_type = v.data_type().clone();
Helper::try_into_vector(v).context(IntoVectorSnafu { data_type })?
}

@@ -78,7 +78,7 @@ pub enum Error {
location: Location,
},

#[snafu(display("General DataFusion error"))]
#[snafu(transparent)]
GeneralDataFusion {
#[snafu(source)]
error: DataFusionError,

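The switch to `#[snafu(transparent)]` is what lets the call sites above drop their `.context(GeneralDataFusionSnafu)` wrappers. A self-contained illustration of the attribute under stated assumptions — this is a demo type with recent snafu, not the repo's actual error enum:

```rust
// Sketch: #[snafu(transparent)] delegates Display to the wrapped source
// and generates a From impl for it, so plain `?` converts the inner
// error -- the same mechanics behind `s.to_array_of_size(num_rows)?`
// and `plan.recompute_schema()?` above.
use snafu::Snafu;

#[derive(Debug, Snafu)]
enum DemoError {
    #[snafu(transparent)]
    Io { source: std::io::Error },
}

fn read_config() -> Result<String, DemoError> {
    // From<std::io::Error> is generated via the transparent attribute.
    Ok(std::fs::read_to_string("config.toml")?)
}
```
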
@@ -24,9 +24,8 @@ use datafusion_common::{Column, TableReference};
|
||||
use datafusion_expr::dml::InsertOp;
|
||||
use datafusion_expr::{DmlStatement, TableSource, WriteOp, col};
|
||||
pub use expr::{build_filter_from_timestamp, build_same_type_ts_filter};
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::error::{GeneralDataFusionSnafu, Result};
|
||||
use crate::error::Result;
|
||||
|
||||
/// Rename columns by applying a new projection. Returns an error if the column to be
|
||||
/// renamed does not exist. The `renames` parameter is a `Vector` with elements
|
||||
@@ -122,7 +121,7 @@ pub fn add_insert_to_logical_plan(
|
||||
WriteOp::Insert(InsertOp::Append),
|
||||
Arc::new(input),
|
||||
));
|
||||
let plan = plan.recompute_schema().context(GeneralDataFusionSnafu)?;
|
||||
let plan = plan.recompute_schema()?;
|
||||
Ok(plan)
|
||||
}
|
||||
|
||||
|
||||
@@ -173,6 +173,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_from_tz_string() {
|
||||
unsafe {
|
||||
std::env::remove_var("TZ");
|
||||
}
|
||||
assert_eq!(
|
||||
Timezone::Named(Tz::UTC),
|
||||
Timezone::from_tz_string("SYSTEM").unwrap()
|
||||
|
||||
@@ -72,7 +72,7 @@ impl RegionServer {
|
||||
})?
|
||||
};
|
||||
|
||||
let entries = mito.all_ssts_from_manifest().collect::<Vec<_>>();
|
||||
let entries = mito.all_ssts_from_manifest().await;
|
||||
let schema = ManifestSstEntry::schema().arrow_schema().clone();
|
||||
let batch = ManifestSstEntry::to_record_batch(&entries)
|
||||
.map_err(DataFusionError::from)
|
||||
|
||||
@@ -252,7 +252,7 @@ impl RegionEngine for MockRegionEngine {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn get_last_seq_num(&self, _: RegionId) -> Result<Option<SequenceNumber>, BoxedError> {
|
||||
async fn get_committed_sequence(&self, _: RegionId) -> Result<SequenceNumber, BoxedError> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
|
||||
@@ -115,8 +115,8 @@ impl RegionEngine for FileRegionEngine {
|
||||
None
|
||||
}
|
||||
|
||||
async fn get_last_seq_num(&self, _: RegionId) -> Result<Option<SequenceNumber>, BoxedError> {
|
||||
Ok(None)
|
||||
async fn get_committed_sequence(&self, _: RegionId) -> Result<SequenceNumber, BoxedError> {
|
||||
Ok(Default::default())
|
||||
}
|
||||
|
||||
fn set_region_role(&self, region_id: RegionId, role: RegionRole) -> Result<(), BoxedError> {
|
||||
|
||||
@@ -376,34 +376,16 @@ impl Instance {
|
||||
ctx: QueryContextRef,
|
||||
) -> server_error::Result<bool> {
|
||||
let db_string = ctx.get_db_string();
|
||||
// fast cache check
|
||||
let cache = self
|
||||
.otlp_metrics_table_legacy_cache
|
||||
.entry(db_string)
|
||||
.entry(db_string.clone())
|
||||
.or_default();
|
||||
|
||||
// check cache
|
||||
let hit_cache = names
|
||||
.iter()
|
||||
.filter_map(|name| cache.get(*name))
|
||||
.collect::<Vec<_>>();
|
||||
if !hit_cache.is_empty() {
|
||||
let hit_legacy = hit_cache.iter().any(|en| *en.value());
|
||||
let hit_prom = hit_cache.iter().any(|en| !*en.value());
|
||||
|
||||
// hit but have true and false, means both legacy and new mode are used
|
||||
// we cannot handle this case, so return error
|
||||
// add doc links in err msg later
|
||||
ensure!(!(hit_legacy && hit_prom), OtlpMetricModeIncompatibleSnafu);
|
||||
|
||||
let flag = hit_legacy;
|
||||
// set cache for all names
|
||||
names.iter().for_each(|name| {
|
||||
if !cache.contains_key(*name) {
|
||||
cache.insert(name.to_string(), flag);
|
||||
}
|
||||
});
|
||||
if let Some(flag) = fast_legacy_check(&cache, names)? {
|
||||
return Ok(flag);
|
||||
}
|
||||
// release cache reference to avoid lock contention
|
||||
drop(cache);
|
||||
|
||||
let catalog = ctx.current_catalog();
|
||||
let schema = ctx.current_schema();
|
||||
@@ -430,7 +412,10 @@ impl Instance {
|
||||
|
||||
// means no existing table is found, use new mode
|
||||
if table_ids.is_empty() {
|
||||
// set cache
|
||||
let cache = self
|
||||
.otlp_metrics_table_legacy_cache
|
||||
.entry(db_string)
|
||||
.or_default();
|
||||
names.iter().for_each(|name| {
|
||||
cache.insert(name.to_string(), false);
|
||||
});
|
||||
@@ -455,6 +440,10 @@ impl Instance {
|
||||
.unwrap_or(&OTLP_LEGACY_DEFAULT_VALUE)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let cache = self
|
||||
.otlp_metrics_table_legacy_cache
|
||||
.entry(db_string)
|
||||
.or_default();
|
||||
if !options.is_empty() {
|
||||
// check value consistency
|
||||
let has_prom = options.iter().any(|opt| *opt == OTLP_METRIC_COMPAT_PROM);
|
||||
@@ -477,6 +466,39 @@ impl Instance {
|
||||
}
|
||||
}
|
||||
|
||||
fn fast_legacy_check(
|
||||
cache: &DashMap<String, bool>,
|
||||
names: &[&String],
|
||||
) -> server_error::Result<Option<bool>> {
|
||||
let hit_cache = names
|
||||
.iter()
|
||||
.filter_map(|name| cache.get(*name))
|
||||
.collect::<Vec<_>>();
|
||||
if !hit_cache.is_empty() {
|
||||
let hit_legacy = hit_cache.iter().any(|en| *en.value());
|
||||
let hit_prom = hit_cache.iter().any(|en| !*en.value());
|
||||
|
||||
// a hit with both true and false means the legacy and new modes are mixed,
|
||||
// which we cannot handle, so return an error
|
||||
// TODO: add doc links to the error message later
|
||||
ensure!(!(hit_legacy && hit_prom), OtlpMetricModeIncompatibleSnafu);
|
||||
|
||||
let flag = hit_legacy;
|
||||
// drop hit_cache to release references before inserting to avoid deadlock
|
||||
drop(hit_cache);
|
||||
|
||||
// set cache for all names
|
||||
names.iter().for_each(|name| {
|
||||
if !cache.contains_key(*name) {
|
||||
cache.insert(name.to_string(), flag);
|
||||
}
|
||||
});
|
||||
Ok(Some(flag))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
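The ordering inside `fast_legacy_check` is the point of the refactor: a `DashMap` read guard locks its shard, so inserting into the same map while any guard from `get` is still alive can deadlock. A minimal standalone sketch of the safe pattern, assuming only the `dashmap` crate rather than the surrounding server types:

```rust
use dashmap::DashMap;

fn main() {
    let cache: DashMap<String, bool> = DashMap::new();
    cache.insert("metric1".to_string(), true);

    // Copy values out so the read guards returned by `get` are dropped
    // inside the closure; holding them across the writes below could
    // deadlock on the same shard.
    let hits: Vec<bool> = ["metric1", "metric2"]
        .iter()
        .filter_map(|name| cache.get(*name).map(|entry| *entry.value()))
        .collect();

    let flag = hits.iter().any(|v| *v);
    // No guard is alive here, so writing back is safe.
    for name in ["metric1", "metric2"] {
        if !cache.contains_key(name) {
            cache.insert(name.to_string(), flag);
        }
    }
    assert_eq!(cache.get("metric2").map(|e| *e.value()), Some(true));
}
```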
||||
|
||||
/// If the relevant variables are set, the timeout is enforced for all PostgreSQL statements.
|
||||
/// For MySQL, it applies only to read-only statements.
|
||||
fn derive_timeout(stmt: &Statement, query_ctx: &QueryContextRef) -> Option<Duration> {
|
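A toy model of that rule, with the protocol channel and read-only flag as assumed inputs (the real `derive_timeout` inspects the `Statement` and the query context instead):

```rust
use std::time::Duration;

enum Channel {
    Postgres,
    Mysql,
}

// PostgreSQL: the timeout applies to every statement.
// MySQL: it applies only to read-only statements.
fn derive_timeout(channel: Channel, read_only: bool, configured: Option<Duration>) -> Option<Duration> {
    match channel {
        Channel::Postgres => configured,
        Channel::Mysql if read_only => configured,
        Channel::Mysql => None,
    }
}

fn main() {
    let t = Some(Duration::from_secs(5));
    assert_eq!(derive_timeout(Channel::Postgres, false, t), t);
    assert_eq!(derive_timeout(Channel::Mysql, false, t), None);
    assert_eq!(derive_timeout(Channel::Mysql, true, t), t);
}
```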
||||
@@ -1041,6 +1063,10 @@ fn should_capture_statement(stmt: Option<&Statement>) -> bool {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::{Arc, Barrier};
|
||||
use std::thread;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use common_base::Plugins;
|
||||
use query::query_engine::options::QueryOptions;
|
||||
@@ -1050,6 +1076,122 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_fast_legacy_check_deadlock_prevention() {
|
||||
// Create a DashMap to simulate the cache
|
||||
let cache = DashMap::new();
|
||||
|
||||
// Pre-populate cache with some entries
|
||||
cache.insert("metric1".to_string(), true); // legacy mode
|
||||
cache.insert("metric2".to_string(), false); // prom mode
|
||||
cache.insert("metric3".to_string(), true); // legacy mode
|
||||
|
||||
// Test case 1: Normal operation with cache hits
|
||||
let metric1 = "metric1".to_string();
|
||||
let metric4 = "metric4".to_string();
|
||||
let names1 = vec![&metric1, &metric4];
|
||||
let result = fast_legacy_check(&cache, &names1);
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), Some(true)); // should return legacy mode
|
||||
|
||||
// Verify that metric4 was added to cache
|
||||
assert!(cache.contains_key("metric4"));
|
||||
assert!(*cache.get("metric4").unwrap().value());
|
||||
|
||||
// Test case 2: No cache hits
|
||||
let metric5 = "metric5".to_string();
|
||||
let metric6 = "metric6".to_string();
|
||||
let names2 = vec![&metric5, &metric6];
|
||||
let result = fast_legacy_check(&cache, &names2);
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), None); // should return None as no cache hits
|
||||
|
||||
// Test case 3: Incompatible modes should return error
|
||||
let cache_incompatible = DashMap::new();
|
||||
cache_incompatible.insert("metric1".to_string(), true); // legacy
|
||||
cache_incompatible.insert("metric2".to_string(), false); // prom
|
||||
let metric1_test = "metric1".to_string();
|
||||
let metric2_test = "metric2".to_string();
|
||||
let names3 = vec![&metric1_test, &metric2_test];
|
||||
let result = fast_legacy_check(&cache_incompatible, &names3);
|
||||
assert!(result.is_err()); // should error due to incompatible modes
|
||||
|
||||
// Test case 4: Intensive concurrent access to test deadlock prevention
|
||||
// This test specifically targets the scenario where multiple threads
|
||||
// access the same cache entries simultaneously
|
||||
let cache_concurrent = Arc::new(DashMap::new());
|
||||
cache_concurrent.insert("shared_metric".to_string(), true);
|
||||
|
||||
let num_threads = 8;
|
||||
let operations_per_thread = 100;
|
||||
let barrier = Arc::new(Barrier::new(num_threads));
|
||||
let success_flag = Arc::new(AtomicBool::new(true));
|
||||
|
||||
let handles: Vec<_> = (0..num_threads)
|
||||
.map(|thread_id| {
|
||||
let cache_clone = Arc::clone(&cache_concurrent);
|
||||
let barrier_clone = Arc::clone(&barrier);
|
||||
let success_flag_clone = Arc::clone(&success_flag);
|
||||
|
||||
thread::spawn(move || {
|
||||
// Wait for all threads to be ready
|
||||
barrier_clone.wait();
|
||||
|
||||
let start_time = Instant::now();
|
||||
for i in 0..operations_per_thread {
|
||||
// Each operation references existing cache entry and adds new ones
|
||||
let shared_metric = "shared_metric".to_string();
|
||||
let new_metric = format!("thread_{}_metric_{}", thread_id, i);
|
||||
let names = vec![&shared_metric, &new_metric];
|
||||
|
||||
match fast_legacy_check(&cache_clone, &names) {
|
||||
Ok(_) => {}
|
||||
Err(_) => {
|
||||
success_flag_clone.store(false, Ordering::Relaxed);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// If the test takes too long, it likely means a deadlock occurred
|
||||
if start_time.elapsed() > Duration::from_secs(10) {
|
||||
success_flag_clone.store(false, Ordering::Relaxed);
|
||||
return;
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Join all threads with timeout
|
||||
let start_time = Instant::now();
|
||||
for (i, handle) in handles.into_iter().enumerate() {
|
||||
let join_result = handle.join();
|
||||
|
||||
// Check if we're taking too long (potential deadlock)
|
||||
if start_time.elapsed() > Duration::from_secs(30) {
|
||||
panic!("Test timed out - possible deadlock detected!");
|
||||
}
|
||||
|
||||
if join_result.is_err() {
|
||||
panic!("Thread {} panicked during execution", i);
|
||||
}
|
||||
}
|
||||
|
||||
// Verify all operations completed successfully
|
||||
assert!(
|
||||
success_flag.load(Ordering::Relaxed),
|
||||
"Some operations failed"
|
||||
);
|
||||
|
||||
// Verify that many new entries were added (proving operations completed)
|
||||
let final_count = cache_concurrent.len();
|
||||
assert!(
|
||||
final_count > 1 + num_threads * operations_per_thread / 2,
|
||||
"Expected more cache entries, got {}",
|
||||
final_count
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exec_validation() {
|
||||
let query_ctx = QueryContext::arc();
|
||||
|
||||
@@ -43,9 +43,9 @@ use table::table::adapter::DfTableProviderAdapter;
|
||||
use table::table_name::TableName;
|
||||
|
||||
use crate::error::{
|
||||
CatalogSnafu, Error, ExternalSnafu, IncompleteGrpcRequestSnafu, NotSupportedSnafu,
|
||||
PermissionSnafu, PlanStatementSnafu, Result, SubstraitDecodeLogicalPlanSnafu,
|
||||
TableNotFoundSnafu, TableOperationSnafu,
|
||||
CatalogSnafu, DataFusionSnafu, Error, ExternalSnafu, IncompleteGrpcRequestSnafu,
|
||||
NotSupportedSnafu, PermissionSnafu, PlanStatementSnafu, Result,
|
||||
SubstraitDecodeLogicalPlanSnafu, TableNotFoundSnafu, TableOperationSnafu,
|
||||
};
|
||||
use crate::instance::{Instance, attach_timer};
|
||||
use crate::metrics::{
|
||||
@@ -395,14 +395,10 @@ impl Instance {
|
||||
let analyzed_plan = state
|
||||
.analyzer()
|
||||
.execute_and_check(insert_into, state.config_options(), |_, _| {})
|
||||
.context(common_query::error::GeneralDataFusionSnafu)
|
||||
.context(SubstraitDecodeLogicalPlanSnafu)?;
|
||||
.context(DataFusionSnafu)?;
|
||||
|
||||
// Optimize the plan
|
||||
let optimized_plan = state
|
||||
.optimize(&analyzed_plan)
|
||||
.context(common_query::error::GeneralDataFusionSnafu)
|
||||
.context(SubstraitDecodeLogicalPlanSnafu)?;
|
||||
let optimized_plan = state.optimize(&analyzed_plan).context(DataFusionSnafu)?;
|
||||
|
||||
let output = SqlQueryHandler::do_exec_plan(self, None, optimized_plan, ctx.clone()).await?;
|
||||
|
||||
|
||||
@@ -91,6 +91,21 @@ impl Filters {
|
||||
Filters::Single(filter)
|
||||
}
|
||||
}
|
||||
/// Aggregation function with optional range and alias.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AggFunc {
|
||||
/// Function name, e.g., "count", "sum", etc.
|
||||
pub name: String,
|
||||
/// Arguments to the function, e.g., column references or literals such as `LogExpr::NamedIdent("column1".to_string())`.
|
||||
pub args: Vec<LogExpr>,
|
||||
pub alias: Option<String>,
|
||||
}
|
||||
|
||||
impl AggFunc {
|
||||
pub fn new(name: String, args: Vec<LogExpr>, alias: Option<String>) -> Self {
|
||||
Self { name, args, alias }
|
||||
}
|
||||
}
|
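To make the new shape concrete, here is a self-contained sketch building the equivalent of `count(column1) AS total`; the types are trimmed mirrors for illustration, not the real log-query definitions (which also derive `Serialize`/`Deserialize`):

```rust
#[derive(Debug, Clone)]
enum LogExpr {
    NamedIdent(String),
}

#[derive(Debug, Clone)]
struct AggFunc {
    name: String,
    args: Vec<LogExpr>,
    alias: Option<String>,
}

impl AggFunc {
    fn new(name: String, args: Vec<LogExpr>, alias: Option<String>) -> Self {
        Self { name, args, alias }
    }
}

fn main() {
    // count(column1) AS total
    let agg = AggFunc::new(
        "count".to_string(),
        vec![LogExpr::NamedIdent("column1".to_string())],
        Some("total".to_string()),
    );
    println!("{agg:?}");
}
```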
||||
|
||||
/// Expression to calculate on log after filtering.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -103,13 +118,11 @@ pub enum LogExpr {
|
||||
args: Vec<LogExpr>,
|
||||
alias: Option<String>,
|
||||
},
|
||||
/// Aggregation function with optional grouping.
|
||||
AggrFunc {
|
||||
name: String,
|
||||
args: Vec<LogExpr>,
|
||||
/// Optional range parameter, standing for the time range used for both step and align.
|
||||
range: Option<String>,
|
||||
/// Function name, arguments, and optional alias.
|
||||
expr: Vec<AggFunc>,
|
||||
by: Vec<LogExpr>,
|
||||
alias: Option<String>,
|
||||
},
|
||||
Decompose {
|
||||
expr: Box<LogExpr>,
|
||||
|
||||
@@ -44,8 +44,9 @@ use common_meta::range_stream::PaginationStream;
|
||||
use common_meta::rpc::KeyValue;
|
||||
use common_meta::rpc::ddl::{SubmitDdlTaskRequest, SubmitDdlTaskResponse};
|
||||
use common_meta::rpc::procedure::{
|
||||
AddRegionFollowerRequest, MigrateRegionRequest, MigrateRegionResponse, ProcedureStateResponse,
|
||||
RemoveRegionFollowerRequest,
|
||||
AddRegionFollowerRequest, AddTableFollowerRequest, ManageRegionFollowerRequest,
|
||||
MigrateRegionRequest, MigrateRegionResponse, ProcedureStateResponse,
|
||||
RemoveRegionFollowerRequest, RemoveTableFollowerRequest,
|
||||
};
|
||||
use common_meta::rpc::store::{
|
||||
BatchDeleteRequest, BatchDeleteResponse, BatchGetRequest, BatchGetResponse, BatchPutRequest,
|
||||
@@ -246,6 +247,10 @@ pub trait RegionFollowerClient: Sync + Send + Debug {
|
||||
|
||||
async fn remove_region_follower(&self, request: RemoveRegionFollowerRequest) -> Result<()>;
|
||||
|
||||
async fn add_table_follower(&self, request: AddTableFollowerRequest) -> Result<()>;
|
||||
|
||||
async fn remove_table_follower(&self, request: RemoveTableFollowerRequest) -> Result<()>;
|
||||
|
||||
async fn start(&self, urls: &[&str]) -> Result<()>;
|
||||
|
||||
async fn start_with(&self, leader_provider: LeaderProviderRef) -> Result<()>;
|
||||
@@ -286,39 +291,41 @@ impl ProcedureExecutor for MetaClient {
|
||||
.context(meta_error::ExternalSnafu)
|
||||
}
|
||||
|
||||
async fn add_region_follower(
|
||||
async fn manage_region_follower(
|
||||
&self,
|
||||
_ctx: &ExecutorContext,
|
||||
request: AddRegionFollowerRequest,
|
||||
request: ManageRegionFollowerRequest,
|
||||
) -> MetaResult<()> {
|
||||
if let Some(region_follower) = &self.region_follower {
|
||||
region_follower
|
||||
.add_region_follower(request)
|
||||
.await
|
||||
.map_err(BoxedError::new)
|
||||
.context(meta_error::ExternalSnafu)
|
||||
} else {
|
||||
UnsupportedSnafu {
|
||||
operation: "add_region_follower",
|
||||
match request {
|
||||
ManageRegionFollowerRequest::AddRegionFollower(add_region_follower_request) => {
|
||||
region_follower
|
||||
.add_region_follower(add_region_follower_request)
|
||||
.await
|
||||
}
|
||||
ManageRegionFollowerRequest::RemoveRegionFollower(
|
||||
remove_region_follower_request,
|
||||
) => {
|
||||
region_follower
|
||||
.remove_region_follower(remove_region_follower_request)
|
||||
.await
|
||||
}
|
||||
ManageRegionFollowerRequest::AddTableFollower(add_table_follower_request) => {
|
||||
region_follower
|
||||
.add_table_follower(add_table_follower_request)
|
||||
.await
|
||||
}
|
||||
ManageRegionFollowerRequest::RemoveTableFollower(remove_table_follower_request) => {
|
||||
region_follower
|
||||
.remove_table_follower(remove_table_follower_request)
|
||||
.await
|
||||
}
|
||||
}
|
||||
.fail()
|
||||
}
|
||||
}
|
||||
|
||||
async fn remove_region_follower(
|
||||
&self,
|
||||
_ctx: &ExecutorContext,
|
||||
request: RemoveRegionFollowerRequest,
|
||||
) -> MetaResult<()> {
|
||||
if let Some(region_follower) = &self.region_follower {
|
||||
region_follower
|
||||
.remove_region_follower(request)
|
||||
.await
|
||||
.map_err(BoxedError::new)
|
||||
.context(meta_error::ExternalSnafu)
|
||||
.map_err(BoxedError::new)
|
||||
.context(meta_error::ExternalSnafu)
|
||||
} else {
|
||||
UnsupportedSnafu {
|
||||
operation: "remove_region_follower",
|
||||
operation: "manage_region_follower",
|
||||
}
|
||||
.fail()
|
||||
}
|
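The change above folds four near-identical RPC entry points into one method that matches on a request enum. A sketch of that dispatch pattern with placeholder payloads (the real request types carry full procedure parameters):

```rust
enum ManageFollowerRequest {
    AddRegionFollower(u64),
    RemoveRegionFollower(u64),
    AddTableFollower(u32),
    RemoveTableFollower(u32),
}

fn manage_follower(request: ManageFollowerRequest) -> String {
    // One entry point; each arm forwards to the matching operation.
    match request {
        ManageFollowerRequest::AddRegionFollower(id) => format!("add region follower for {id}"),
        ManageFollowerRequest::RemoveRegionFollower(id) => format!("remove region follower for {id}"),
        ManageFollowerRequest::AddTableFollower(id) => format!("add table follower for {id}"),
        ManageFollowerRequest::RemoveTableFollower(id) => format!("remove table follower for {id}"),
    }
}

fn main() {
    println!("{}", manage_follower(ManageFollowerRequest::AddRegionFollower(42)));
}
```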
||||
|
||||
@@ -21,45 +21,23 @@ use api::v1::meta::procedure_service_server::ProcedureServiceServer;
|
||||
use api::v1::meta::store_server::StoreServer;
|
||||
use common_base::Plugins;
|
||||
use common_config::Configurable;
|
||||
#[cfg(feature = "pg_kvbackend")]
|
||||
use common_error::ext::BoxedError;
|
||||
#[cfg(any(feature = "pg_kvbackend", feature = "mysql_kvbackend"))]
|
||||
use common_meta::distributed_time_constants::META_LEASE_SECS;
|
||||
use common_meta::kv_backend::chroot::ChrootKvBackend;
|
||||
use common_meta::kv_backend::etcd::EtcdStore;
|
||||
use common_meta::kv_backend::memory::MemoryKvBackend;
|
||||
#[cfg(feature = "mysql_kvbackend")]
|
||||
use common_meta::kv_backend::rds::MySqlStore;
|
||||
#[cfg(feature = "pg_kvbackend")]
|
||||
use common_meta::kv_backend::rds::PgStore;
|
||||
#[cfg(feature = "pg_kvbackend")]
|
||||
use common_meta::kv_backend::rds::postgres::create_postgres_tls_connector;
|
||||
#[cfg(feature = "pg_kvbackend")]
|
||||
use common_meta::kv_backend::rds::postgres::{TlsMode as PgTlsMode, TlsOption as PgTlsOption};
|
||||
use common_meta::kv_backend::{KvBackendRef, ResettableKvBackendRef};
|
||||
use common_telemetry::info;
|
||||
#[cfg(feature = "pg_kvbackend")]
|
||||
use deadpool_postgres::{Config, Runtime};
|
||||
use either::Either;
|
||||
use etcd_client::{Client, ConnectOptions};
|
||||
use servers::configurator::ConfiguratorRef;
|
||||
use servers::export_metrics::ExportMetricsTask;
|
||||
use servers::http::{HttpServer, HttpServerBuilder};
|
||||
use servers::metrics_handler::MetricsHandler;
|
||||
use servers::server::Server;
|
||||
use servers::tls::TlsOption;
|
||||
#[cfg(any(feature = "pg_kvbackend", feature = "mysql_kvbackend"))]
|
||||
use snafu::OptionExt;
|
||||
use snafu::ResultExt;
|
||||
#[cfg(feature = "mysql_kvbackend")]
|
||||
use sqlx::mysql::MySqlConnectOptions;
|
||||
#[cfg(feature = "mysql_kvbackend")]
|
||||
use sqlx::mysql::MySqlPool;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::mpsc::{self, Receiver, Sender};
|
||||
use tokio::sync::{Mutex, oneshot};
|
||||
#[cfg(feature = "pg_kvbackend")]
|
||||
use tokio_postgres::NoTls;
|
||||
use tonic::codec::CompressionEncoding;
|
||||
use tonic::transport::server::{Router, TcpIncoming};
|
||||
|
||||
@@ -67,10 +45,6 @@ use crate::cluster::{MetaPeerClientBuilder, MetaPeerClientRef};
|
||||
#[cfg(any(feature = "pg_kvbackend", feature = "mysql_kvbackend"))]
|
||||
use crate::election::CANDIDATE_LEASE_SECS;
|
||||
use crate::election::etcd::EtcdElection;
|
||||
#[cfg(feature = "mysql_kvbackend")]
|
||||
use crate::election::rds::mysql::MySqlElection;
|
||||
#[cfg(feature = "pg_kvbackend")]
|
||||
use crate::election::rds::postgres::PgElection;
|
||||
use crate::metasrv::builder::MetasrvBuilder;
|
||||
use crate::metasrv::{
|
||||
BackendImpl, ElectionRef, Metasrv, MetasrvOptions, SelectTarget, SelectorRef,
|
||||
@@ -82,6 +56,7 @@ use crate::selector::round_robin::RoundRobinSelector;
|
||||
use crate::selector::weight_compute::RegionNumsBasedWeightCompute;
|
||||
use crate::service::admin;
|
||||
use crate::service::admin::admin_axum_router;
|
||||
use crate::utils::etcd::create_etcd_client_with_tls;
|
||||
use crate::{Result, error};
|
||||
|
||||
pub struct MetasrvInstance {
|
||||
@@ -306,8 +281,11 @@ pub async fn metasrv_builder(
|
||||
use std::time::Duration;
|
||||
|
||||
use common_meta::distributed_time_constants::POSTGRES_KEEP_ALIVE_SECS;
|
||||
use common_meta::kv_backend::rds::PgStore;
|
||||
use deadpool_postgres::Config;
|
||||
|
||||
use crate::election::rds::postgres::ElectionPgClient;
|
||||
use crate::election::rds::postgres::{ElectionPgClient, PgElection};
|
||||
use crate::utils::postgres::create_postgres_pool;
|
||||
|
||||
let candidate_lease_ttl = Duration::from_secs(CANDIDATE_LEASE_SECS);
|
||||
let execution_timeout = Duration::from_secs(META_LEASE_SECS);
|
||||
@@ -319,8 +297,8 @@ pub async fn metasrv_builder(
|
||||
cfg.keepalives = Some(true);
|
||||
cfg.keepalives_idle = Some(Duration::from_secs(POSTGRES_KEEP_ALIVE_SECS));
|
||||
// We use a separate pool for election since we need a different session keep-alive idle time.
|
||||
let pool =
|
||||
create_postgres_pool_with(&opts.store_addrs, cfg, opts.backend_tls.clone()).await?;
|
||||
let pool = create_postgres_pool(&opts.store_addrs, Some(cfg), opts.backend_tls.clone())
|
||||
.await?;
|
||||
|
||||
let election_client = ElectionPgClient::new(
|
||||
pool,
|
||||
@@ -340,7 +318,8 @@ pub async fn metasrv_builder(
|
||||
)
|
||||
.await?;
|
||||
|
||||
let pool = create_postgres_pool(&opts.store_addrs, opts.backend_tls.clone()).await?;
|
||||
let pool =
|
||||
create_postgres_pool(&opts.store_addrs, None, opts.backend_tls.clone()).await?;
|
||||
let kv_backend = PgStore::with_pg_pool(
|
||||
pool,
|
||||
opts.meta_schema_name.as_deref(),
|
||||
@@ -356,9 +335,12 @@ pub async fn metasrv_builder(
|
||||
(None, BackendImpl::MysqlStore) => {
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::election::rds::mysql::ElectionMysqlClient;
|
||||
use common_meta::kv_backend::rds::MySqlStore;
|
||||
|
||||
let pool = create_mysql_pool(&opts.store_addrs).await?;
|
||||
use crate::election::rds::mysql::{ElectionMysqlClient, MySqlElection};
|
||||
use crate::utils::mysql::create_mysql_pool;
|
||||
|
||||
let pool = create_mysql_pool(&opts.store_addrs, opts.backend_tls.as_ref()).await?;
|
||||
let kv_backend =
|
||||
MySqlStore::with_mysql_pool(pool, &opts.meta_table_name, opts.max_txn_ops)
|
||||
.await
|
||||
@@ -366,7 +348,7 @@ pub async fn metasrv_builder(
|
||||
// Since the election acquires a lock on the table, we need a separate table for election.
|
||||
let election_table_name = opts.meta_table_name.clone() + "_election";
|
||||
// We use a separate pool for election since we need a different session keep-alive idle time.
|
||||
let pool = create_mysql_pool(&opts.store_addrs).await?;
|
||||
let pool = create_mysql_pool(&opts.store_addrs, opts.backend_tls.as_ref()).await?;
|
||||
let execution_timeout = Duration::from_secs(META_LEASE_SECS);
|
||||
let statement_timeout = Duration::from_secs(META_LEASE_SECS);
|
||||
let idle_session_timeout = Duration::from_secs(META_LEASE_SECS);
|
||||
@@ -452,259 +434,3 @@ pub(crate) fn build_default_meta_peer_client(
|
||||
// Safety: all required fields set at initialization
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub async fn create_etcd_client(store_addrs: &[String]) -> Result<Client> {
|
||||
create_etcd_client_with_tls(store_addrs, None).await
|
||||
}
|
||||
|
||||
fn build_connection_options(tls_config: Option<&TlsOption>) -> Result<Option<ConnectOptions>> {
|
||||
use std::fs;
|
||||
|
||||
use common_telemetry::debug;
|
||||
use etcd_client::{Certificate, ConnectOptions, Identity, TlsOptions};
|
||||
use servers::tls::TlsMode;
|
||||
|
||||
// If TLS options are not provided, return None
|
||||
let Some(tls_config) = tls_config else {
|
||||
return Ok(None);
|
||||
};
|
||||
// If TLS is disabled, return None
|
||||
if matches!(tls_config.mode, TlsMode::Disable) {
|
||||
return Ok(None);
|
||||
}
|
||||
let mut etcd_tls_opts = TlsOptions::new();
|
||||
// Set CA certificate if provided
|
||||
if !tls_config.ca_cert_path.is_empty() {
|
||||
debug!("Using CA certificate from {}", tls_config.ca_cert_path);
|
||||
let ca_cert_pem = fs::read(&tls_config.ca_cert_path).context(error::FileIoSnafu {
|
||||
path: &tls_config.ca_cert_path,
|
||||
})?;
|
||||
let ca_cert = Certificate::from_pem(ca_cert_pem);
|
||||
etcd_tls_opts = etcd_tls_opts.ca_certificate(ca_cert);
|
||||
}
|
||||
// Set client identity (cert + key) if both are provided
|
||||
if !tls_config.cert_path.is_empty() && !tls_config.key_path.is_empty() {
|
||||
debug!(
|
||||
"Using client certificate from {} and key from {}",
|
||||
tls_config.cert_path, tls_config.key_path
|
||||
);
|
||||
let cert_pem = fs::read(&tls_config.cert_path).context(error::FileIoSnafu {
|
||||
path: &tls_config.cert_path,
|
||||
})?;
|
||||
let key_pem = fs::read(&tls_config.key_path).context(error::FileIoSnafu {
|
||||
path: &tls_config.key_path,
|
||||
})?;
|
||||
let identity = Identity::from_pem(cert_pem, key_pem);
|
||||
etcd_tls_opts = etcd_tls_opts.identity(identity);
|
||||
}
|
||||
// Enable native TLS roots for additional trust anchors
|
||||
etcd_tls_opts = etcd_tls_opts.with_native_roots();
|
||||
Ok(Some(ConnectOptions::new().with_tls(etcd_tls_opts)))
|
||||
}
|
||||
|
||||
pub async fn create_etcd_client_with_tls(
|
||||
store_addrs: &[String],
|
||||
tls_config: Option<&TlsOption>,
|
||||
) -> Result<Client> {
|
||||
let etcd_endpoints = store_addrs
|
||||
.iter()
|
||||
.map(|x| x.trim())
|
||||
.filter(|x| !x.is_empty())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let connect_options = build_connection_options(tls_config)?;
|
||||
|
||||
Client::connect(&etcd_endpoints, connect_options)
|
||||
.await
|
||||
.context(error::ConnectEtcdSnafu)
|
||||
}
|
||||
|
||||
#[cfg(feature = "pg_kvbackend")]
|
||||
/// Converts servers::tls::TlsOption to postgres::TlsOption to avoid circular dependencies
|
||||
fn convert_tls_option(tls_option: &TlsOption) -> PgTlsOption {
|
||||
let mode = match tls_option.mode {
|
||||
servers::tls::TlsMode::Disable => PgTlsMode::Disable,
|
||||
servers::tls::TlsMode::Prefer => PgTlsMode::Prefer,
|
||||
servers::tls::TlsMode::Require => PgTlsMode::Require,
|
||||
servers::tls::TlsMode::VerifyCa => PgTlsMode::VerifyCa,
|
||||
servers::tls::TlsMode::VerifyFull => PgTlsMode::VerifyFull,
|
||||
};
|
||||
|
||||
PgTlsOption {
|
||||
mode,
|
||||
cert_path: tls_option.cert_path.clone(),
|
||||
key_path: tls_option.key_path.clone(),
|
||||
ca_cert_path: tls_option.ca_cert_path.clone(),
|
||||
watch: tls_option.watch,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "pg_kvbackend")]
|
||||
/// Creates a pool for the Postgres backend with optional TLS.
|
||||
///
|
||||
/// It only uses the first store addr to create a pool.
|
||||
pub async fn create_postgres_pool(
|
||||
store_addrs: &[String],
|
||||
tls_config: Option<TlsOption>,
|
||||
) -> Result<deadpool_postgres::Pool> {
|
||||
create_postgres_pool_with(store_addrs, Config::new(), tls_config).await
|
||||
}
|
||||
|
||||
#[cfg(feature = "pg_kvbackend")]
|
||||
/// Creates a pool for the Postgres backend with config and optional TLS.
|
||||
///
|
||||
/// It only uses the first store addr and the given config to create the pool.
|
||||
pub async fn create_postgres_pool_with(
|
||||
store_addrs: &[String],
|
||||
mut cfg: Config,
|
||||
tls_config: Option<TlsOption>,
|
||||
) -> Result<deadpool_postgres::Pool> {
|
||||
let postgres_url = store_addrs.first().context(error::InvalidArgumentsSnafu {
|
||||
err_msg: "empty store addrs",
|
||||
})?;
|
||||
cfg.url = Some(postgres_url.to_string());
|
||||
|
||||
let pool = if let Some(tls_config) = tls_config {
|
||||
let pg_tls_config = convert_tls_option(&tls_config);
|
||||
let tls_connector =
|
||||
create_postgres_tls_connector(&pg_tls_config).map_err(|e| error::Error::Other {
|
||||
source: BoxedError::new(e),
|
||||
location: snafu::Location::new(file!(), line!(), 0),
|
||||
})?;
|
||||
cfg.create_pool(Some(Runtime::Tokio1), tls_connector)
|
||||
.context(error::CreatePostgresPoolSnafu)?
|
||||
} else {
|
||||
cfg.create_pool(Some(Runtime::Tokio1), NoTls)
|
||||
.context(error::CreatePostgresPoolSnafu)?
|
||||
};
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
#[cfg(feature = "mysql_kvbackend")]
|
||||
async fn setup_mysql_options(store_addrs: &[String]) -> Result<MySqlConnectOptions> {
|
||||
let mysql_url = store_addrs.first().context(error::InvalidArgumentsSnafu {
|
||||
err_msg: "empty store addrs",
|
||||
})?;
|
||||
// Avoid `SET` commands in sqlx
|
||||
let opts: MySqlConnectOptions = mysql_url
|
||||
.parse()
|
||||
.context(error::ParseMySqlUrlSnafu { mysql_url })?;
|
||||
let opts = opts
|
||||
.no_engine_substitution(false)
|
||||
.pipes_as_concat(false)
|
||||
.timezone(None)
|
||||
.set_names(false);
|
||||
Ok(opts)
|
||||
}
|
||||
|
||||
#[cfg(feature = "mysql_kvbackend")]
|
||||
pub async fn create_mysql_pool(store_addrs: &[String]) -> Result<MySqlPool> {
|
||||
let opts = setup_mysql_options(store_addrs).await?;
|
||||
let pool = MySqlPool::connect_with(opts)
|
||||
.await
|
||||
.context(error::CreateMySqlPoolSnafu)?;
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use servers::tls::TlsMode;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_etcd_client_tls_without_certs() {
|
||||
let endpoints: Vec<String> = match std::env::var("GT_ETCD_TLS_ENDPOINTS") {
|
||||
Ok(endpoints_str) => endpoints_str
|
||||
.split(',')
|
||||
.map(|s| s.trim().to_string())
|
||||
.collect(),
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
let tls_config = TlsOption {
|
||||
mode: TlsMode::Require,
|
||||
ca_cert_path: String::new(),
|
||||
cert_path: String::new(),
|
||||
key_path: String::new(),
|
||||
watch: false,
|
||||
};
|
||||
|
||||
let _client = create_etcd_client_with_tls(&endpoints, Some(&tls_config))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_etcd_client_tls_with_client_certs() {
|
||||
let endpoints: Vec<String> = match std::env::var("GT_ETCD_TLS_ENDPOINTS") {
|
||||
Ok(endpoints_str) => endpoints_str
|
||||
.split(',')
|
||||
.map(|s| s.trim().to_string())
|
||||
.collect(),
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
let cert_dir = std::env::current_dir()
|
||||
.unwrap()
|
||||
.join("tests-integration")
|
||||
.join("fixtures")
|
||||
.join("etcd-tls-certs");
|
||||
|
||||
if cert_dir.join("client.crt").exists() && cert_dir.join("client-key.pem").exists() {
|
||||
let tls_config = TlsOption {
|
||||
mode: TlsMode::Require,
|
||||
ca_cert_path: String::new(),
|
||||
cert_path: cert_dir.join("client.crt").to_string_lossy().to_string(),
|
||||
key_path: cert_dir
|
||||
.join("client-key.pem")
|
||||
.to_string_lossy()
|
||||
.to_string(),
|
||||
watch: false,
|
||||
};
|
||||
|
||||
let _client = create_etcd_client_with_tls(&endpoints, Some(&tls_config))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_etcd_client_tls_with_full_certs() {
|
||||
let endpoints: Vec<String> = match std::env::var("GT_ETCD_TLS_ENDPOINTS") {
|
||||
Ok(endpoints_str) => endpoints_str
|
||||
.split(',')
|
||||
.map(|s| s.trim().to_string())
|
||||
.collect(),
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
let cert_dir = std::env::current_dir()
|
||||
.unwrap()
|
||||
.join("tests-integration")
|
||||
.join("fixtures")
|
||||
.join("etcd-tls-certs");
|
||||
|
||||
if cert_dir.join("ca.crt").exists()
|
||||
&& cert_dir.join("client.crt").exists()
|
||||
&& cert_dir.join("client-key.pem").exists()
|
||||
{
|
||||
let tls_config = TlsOption {
|
||||
mode: TlsMode::Require,
|
||||
ca_cert_path: cert_dir.join("ca.crt").to_string_lossy().to_string(),
|
||||
cert_path: cert_dir.join("client.crt").to_string_lossy().to_string(),
|
||||
key_path: cert_dir
|
||||
.join("client-key.pem")
|
||||
.to_string_lossy()
|
||||
.to_string(),
|
||||
watch: false,
|
||||
};
|
||||
|
||||
let _client = create_etcd_client_with_tls(&endpoints, Some(&tls_config))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ mod tests {
|
||||
use common_time::util::current_time_millis;
|
||||
use common_workload::DatanodeWorkloadType;
|
||||
|
||||
use crate::discovery::utils::{self, is_datanode_accept_ingest_workload};
|
||||
use crate::discovery::utils::{self, accept_ingest_workload};
|
||||
use crate::key::{DatanodeLeaseKey, LeaseValue};
|
||||
use crate::test_util::create_meta_peer_client;
|
||||
|
||||
@@ -219,7 +219,7 @@ mod tests {
|
||||
let peers = utils::alive_datanodes(
|
||||
client.as_ref(),
|
||||
Duration::from_secs(lease_secs),
|
||||
Some(is_datanode_accept_ingest_workload),
|
||||
Some(accept_ingest_workload),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -144,19 +144,22 @@ pub async fn alive_datanode(
|
||||
Ok(v)
|
||||
}
|
||||
|
||||
/// Returns true if the datanode can accept ingest workload based on its workload types.
|
||||
/// Determines if a datanode is capable of accepting ingest workloads.
|
||||
/// Returns `true` if the datanode's workload types include ingest capability,
|
||||
/// or if the node is not of type [NodeWorkloads::Datanode].
|
||||
///
|
||||
/// A datanode is considered to accept the ingest workload if it supports either:
|
||||
/// - Hybrid workload (both ingest and query workloads)
|
||||
/// - Ingest workload (only ingest workload)
|
||||
pub fn is_datanode_accept_ingest_workload(datanode_workloads: &NodeWorkloads) -> bool {
|
||||
pub fn accept_ingest_workload(datanode_workloads: &NodeWorkloads) -> bool {
|
||||
match &datanode_workloads {
|
||||
NodeWorkloads::Datanode(workloads) => workloads
|
||||
.types
|
||||
.iter()
|
||||
.filter_map(|w| DatanodeWorkloadType::from_i32(*w))
|
||||
.any(|w| w.accept_ingest()),
|
||||
_ => false,
|
||||
// If the [NodeWorkloads] type is not [NodeWorkloads::Datanode], returns true.
|
||||
_ => true,
|
||||
}
|
||||
}
|
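Note the semantic change: the non-datanode branch now returns `true`, so the filter only rejects datanodes whose declared workload types all lack ingest capability. A standalone mirror of that contract, with the proto-backed types replaced by plain enums:

```rust
enum NodeWorkloads {
    Datanode(Vec<Workload>),
    Flownode,
}

#[derive(Clone, Copy)]
enum Workload {
    Hybrid,
    Ingest,
    Query,
}

impl Workload {
    fn accept_ingest(&self) -> bool {
        matches!(*self, Workload::Hybrid | Workload::Ingest)
    }
}

fn accept_ingest_workload(workloads: &NodeWorkloads) -> bool {
    match workloads {
        NodeWorkloads::Datanode(types) => types.iter().any(|t| t.accept_ingest()),
        // Non-datanode workloads are never filtered out.
        _ => true,
    }
}

fn main() {
    assert!(accept_ingest_workload(&NodeWorkloads::Datanode(vec![Workload::Hybrid])));
    assert!(!accept_ingest_workload(&NodeWorkloads::Datanode(vec![Workload::Query])));
    assert!(accept_ingest_workload(&NodeWorkloads::Flownode));
}
```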
||||
|
||||
|
||||
@@ -984,8 +984,8 @@ mod tests {
|
||||
use common_telemetry::init_default_ut_logging;
|
||||
|
||||
use super::*;
|
||||
use crate::bootstrap::create_mysql_pool;
|
||||
use crate::error;
|
||||
use crate::utils::mysql::create_mysql_pool;
|
||||
|
||||
async fn create_mysql_client(
|
||||
table_name: Option<&str>,
|
||||
@@ -1000,7 +1000,7 @@ mod tests {
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
let pool = create_mysql_pool(&[endpoint]).await.unwrap();
|
||||
let pool = create_mysql_pool(&[endpoint], None).await.unwrap();
|
||||
let mut client = ElectionMysqlClient::new(
|
||||
pool,
|
||||
execution_timeout,
|
||||
|
||||
@@ -826,8 +826,8 @@ mod tests {
|
||||
use common_meta::maybe_skip_postgres_integration_test;
|
||||
|
||||
use super::*;
|
||||
use crate::bootstrap::create_postgres_pool;
|
||||
use crate::error;
|
||||
use crate::utils::postgres::create_postgres_pool;
|
||||
|
||||
async fn create_postgres_client(
|
||||
table_name: Option<&str>,
|
||||
@@ -842,7 +842,7 @@ mod tests {
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
let pool = create_postgres_pool(&[endpoint], None).await.unwrap();
|
||||
let pool = create_postgres_pool(&[endpoint], None, None).await.unwrap();
|
||||
let mut pg_client = ElectionPgClient::new(
|
||||
pool,
|
||||
execution_timeout,
|
||||
|
||||
@@ -981,6 +981,14 @@ pub enum Error {
|
||||
#[snafu(source)]
|
||||
source: common_meta::error::Error,
|
||||
},
|
||||
|
||||
#[snafu(display("Failed to build tls options"))]
|
||||
BuildTlsOptions {
|
||||
#[snafu(implicit)]
|
||||
location: Location,
|
||||
#[snafu(source)]
|
||||
source: common_meta::error::Error,
|
||||
},
|
||||
}
|
||||
|
||||
impl Error {
|
||||
@@ -1116,6 +1124,7 @@ impl ErrorExt for Error {
|
||||
| Error::InitDdlManager { source, .. }
|
||||
| Error::InitReconciliationManager { source, .. } => source.status_code(),
|
||||
|
||||
Error::BuildTlsOptions { source, .. } => source.status_code(),
|
||||
Error::Other { source, .. } => source.status_code(),
|
||||
Error::NoEnoughAvailableNode { .. } => StatusCode::RuntimeResourcesExhausted,
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ use common_meta::error::{ExternalSnafu, Result as MetaResult};
|
||||
use common_meta::peer::{Peer, PeerAllocator};
|
||||
use snafu::{ResultExt, ensure};
|
||||
|
||||
use crate::discovery::utils::accept_ingest_workload;
|
||||
use crate::error::{Result, TooManyPartitionsSnafu};
|
||||
use crate::metasrv::{SelectorContext, SelectorRef};
|
||||
use crate::selector::SelectorOptions;
|
||||
@@ -69,6 +70,7 @@ impl MetasrvPeerAllocator {
|
||||
min_required_items,
|
||||
allow_duplication: true,
|
||||
exclude_peer_ids: HashSet::new(),
|
||||
workload_filter: Some(accept_ingest_workload),
|
||||
},
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -261,12 +261,8 @@ impl WalPruneManager {
|
||||
Err(error::Error::PruneTaskAlreadyRunning { topic, .. }) => {
|
||||
warn!("Prune task for topic {} is already running", topic);
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"Failed to submit prune task for topic {}: {}",
|
||||
topic_name.clone(),
|
||||
e
|
||||
);
|
||||
Err(err) => {
|
||||
error!(err; "Failed to prune remote WAL for topic {}", topic_name.as_str());
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -40,6 +40,7 @@ use tokio::sync::mpsc::{Receiver, Sender};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::{MissedTickBehavior, interval, interval_at};
|
||||
|
||||
use crate::discovery::utils::accept_ingest_workload;
|
||||
use crate::error::{self, Result};
|
||||
use crate::failure_detector::PhiAccrualFailureDetectorOptions;
|
||||
use crate::metasrv::{RegionStatAwareSelectorRef, SelectTarget, SelectorContext, SelectorRef};
|
||||
@@ -584,6 +585,7 @@ impl RegionSupervisor {
|
||||
min_required_items: regions.len(),
|
||||
allow_duplication: true,
|
||||
exclude_peer_ids,
|
||||
workload_filter: Some(accept_ingest_workload),
|
||||
};
|
||||
let peers = selector.select(&self.selector_context, opt).await?;
|
||||
ensure!(
|
||||
|
||||
@@ -22,6 +22,7 @@ pub mod weight_compute;
|
||||
pub mod weighted_choose;
|
||||
use std::collections::HashSet;
|
||||
|
||||
use api::v1::meta::heartbeat_request::NodeWorkloads;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use store_api::storage::RegionId;
|
||||
use strum::AsRefStr;
|
||||
@@ -63,6 +64,8 @@ pub struct SelectorOptions {
|
||||
pub allow_duplication: bool,
|
||||
/// The peers to exclude from the selection.
|
||||
pub exclude_peer_ids: HashSet<u64>,
|
||||
/// The filter to select the peers based on their workloads.
|
||||
pub workload_filter: Option<fn(&NodeWorkloads) -> bool>,
|
||||
}
|
||||
|
||||
impl Default for SelectorOptions {
|
||||
@@ -71,6 +74,7 @@ impl Default for SelectorOptions {
|
||||
min_required_items: 1,
|
||||
allow_duplication: false,
|
||||
exclude_peer_ids: HashSet::new(),
|
||||
workload_filter: None,
|
||||
}
|
||||
}
|
||||
}
|
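`workload_filter` is a plain function pointer, so the `None` default means "keep everything" and a `Some` is threaded straight through to discovery. A small sketch of how the `Option<fn(&T) -> bool>` shape behaves at a call site (names are illustrative only):

```rust
fn filter_candidates<T>(candidates: Vec<T>, filter: Option<fn(&T) -> bool>) -> Vec<T> {
    candidates
        .into_iter()
        // `None` keeps every candidate; `Some(f)` keeps only those `f` accepts.
        .filter(|c| filter.map_or(true, |f| f(c)))
        .collect()
}

fn main() {
    let even_only: fn(&i32) -> bool = |n| n % 2 == 0;
    assert_eq!(filter_candidates(vec![1, 2, 3, 4], Some(even_only)), vec![2, 4]);
    assert_eq!(filter_candidates(vec![1, 2, 3], None), vec![1, 2, 3]);
}
```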
||||
|
||||
@@ -139,6 +139,7 @@ mod tests {
|
||||
min_required_items: i,
|
||||
allow_duplication: false,
|
||||
exclude_peer_ids: HashSet::new(),
|
||||
workload_filter: None,
|
||||
};
|
||||
|
||||
let selected_peers: HashSet<_> =
|
||||
@@ -154,6 +155,7 @@ mod tests {
|
||||
min_required_items: 6,
|
||||
allow_duplication: false,
|
||||
exclude_peer_ids: HashSet::new(),
|
||||
workload_filter: None,
|
||||
};
|
||||
|
||||
let selected_result =
|
||||
@@ -165,6 +167,7 @@ mod tests {
|
||||
min_required_items: i,
|
||||
allow_duplication: true,
|
||||
exclude_peer_ids: HashSet::new(),
|
||||
workload_filter: None,
|
||||
};
|
||||
|
||||
let selected_peers =
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
use common_meta::peer::Peer;
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::discovery::utils::is_datanode_accept_ingest_workload;
|
||||
use crate::error::{ListActiveDatanodesSnafu, Result};
|
||||
use crate::metasrv::SelectorContext;
|
||||
use crate::selector::common::{choose_items, filter_out_excluded_peers};
|
||||
@@ -35,7 +34,7 @@ impl Selector for LeaseBasedSelector {
|
||||
// 1. get alive datanodes.
|
||||
let alive_datanodes = ctx
|
||||
.peer_discovery
|
||||
.active_datanodes(Some(is_datanode_accept_ingest_workload))
|
||||
.active_datanodes(opts.workload_filter)
|
||||
.await
|
||||
.context(ListActiveDatanodesSnafu)?;
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ use common_telemetry::debug;
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::cluster::MetaPeerClientRef;
|
||||
use crate::discovery::utils::is_datanode_accept_ingest_workload;
|
||||
use crate::error::{ListActiveDatanodesSnafu, Result};
|
||||
use crate::metasrv::SelectorContext;
|
||||
use crate::selector::common::{choose_items, filter_out_excluded_peers};
|
||||
@@ -54,7 +53,7 @@ where
|
||||
// 1. get alive datanodes.
|
||||
let alive_datanodes = ctx
|
||||
.peer_discovery
|
||||
.active_datanodes(Some(is_datanode_accept_ingest_workload))
|
||||
.active_datanodes(opts.workload_filter)
|
||||
.await
|
||||
.context(ListActiveDatanodesSnafu)?;
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ use std::sync::atomic::AtomicUsize;
|
||||
use common_meta::peer::Peer;
|
||||
use snafu::{ResultExt, ensure};
|
||||
|
||||
use crate::discovery::utils::is_datanode_accept_ingest_workload;
|
||||
use crate::error::{
|
||||
ListActiveDatanodesSnafu, ListActiveFlownodesSnafu, NoEnoughAvailableNodeSnafu, Result,
|
||||
};
|
||||
@@ -59,7 +58,7 @@ impl RoundRobinSelector {
|
||||
// 1. get alive datanodes.
|
||||
let alive_datanodes = ctx
|
||||
.peer_discovery
|
||||
.active_datanodes(Some(is_datanode_accept_ingest_workload))
|
||||
.active_datanodes(opts.workload_filter)
|
||||
.await
|
||||
.context(ListActiveDatanodesSnafu)?;
|
||||
|
||||
@@ -71,7 +70,7 @@ impl RoundRobinSelector {
|
||||
}
|
||||
SelectTarget::Flownode => ctx
|
||||
.peer_discovery
|
||||
.active_flownodes(None)
|
||||
.active_flownodes(opts.workload_filter)
|
||||
.await
|
||||
.context(ListActiveFlownodesSnafu)?,
|
||||
};
|
||||
@@ -150,6 +149,7 @@ mod test {
|
||||
min_required_items: 4,
|
||||
allow_duplication: true,
|
||||
exclude_peer_ids: HashSet::new(),
|
||||
workload_filter: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
@@ -167,6 +167,7 @@ mod test {
|
||||
min_required_items: 2,
|
||||
allow_duplication: true,
|
||||
exclude_peer_ids: HashSet::new(),
|
||||
workload_filter: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
@@ -208,6 +209,7 @@ mod test {
|
||||
min_required_items: 1,
|
||||
allow_duplication: true,
|
||||
exclude_peer_ids: HashSet::from([2, 5]),
|
||||
workload_filter: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -12,7 +12,12 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
pub mod etcd;
|
||||
pub mod insert_forwarder;
|
||||
#[cfg(feature = "mysql_kvbackend")]
|
||||
pub mod mysql;
|
||||
#[cfg(feature = "pg_kvbackend")]
|
||||
pub mod postgres;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! define_ticker {
|
||||
|
||||
56
src/meta-srv/src/utils/etcd.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use common_meta::kv_backend::etcd::create_etcd_tls_options;
|
||||
use etcd_client::{Client, ConnectOptions};
|
||||
use servers::tls::{TlsMode, TlsOption};
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::error::{self, BuildTlsOptionsSnafu, Result};
|
||||
|
||||
/// Creates an etcd client with an optional TLS configuration.
|
||||
pub async fn create_etcd_client_with_tls(
|
||||
store_addrs: &[String],
|
||||
tls_config: Option<&TlsOption>,
|
||||
) -> Result<Client> {
|
||||
let etcd_endpoints = store_addrs
|
||||
.iter()
|
||||
.map(|x| x.trim())
|
||||
.filter(|x| !x.is_empty())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let connect_options = tls_config
|
||||
.map(|c| create_etcd_tls_options(&convert_tls_option(c)))
|
||||
.transpose()
|
||||
.context(BuildTlsOptionsSnafu)?
|
||||
.flatten()
|
||||
.map(|tls_options| ConnectOptions::new().with_tls(tls_options));
|
||||
|
||||
Client::connect(&etcd_endpoints, connect_options)
|
||||
.await
|
||||
.context(error::ConnectEtcdSnafu)
|
||||
}
|
||||
|
||||
fn convert_tls_option(tls_option: &TlsOption) -> common_meta::kv_backend::etcd::TlsOption {
|
||||
let mode = match tls_option.mode {
|
||||
TlsMode::Disable => common_meta::kv_backend::etcd::TlsMode::Disable,
|
||||
_ => common_meta::kv_backend::etcd::TlsMode::Require,
|
||||
};
|
||||
common_meta::kv_backend::etcd::TlsOption {
|
||||
mode,
|
||||
cert_path: tls_option.cert_path.clone(),
|
||||
key_path: tls_option.key_path.clone(),
|
||||
ca_cert_path: tls_option.ca_cert_path.clone(),
|
||||
}
|
||||
}
|
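For reference, the raw `etcd-client` pieces that `create_etcd_tls_options` assembles under the hood look roughly like this; a sketch assuming the crate's `tls` feature, with PEM bytes read by the caller:

```rust
use etcd_client::{Certificate, ConnectOptions, Identity, TlsOptions};

/// Builds connect options carrying a CA certificate and a client identity.
/// The PEM buffers are placeholders; a real caller would `fs::read` them.
fn tls_connect_options(ca_pem: Vec<u8>, cert_pem: Vec<u8>, key_pem: Vec<u8>) -> ConnectOptions {
    let tls = TlsOptions::new()
        .ca_certificate(Certificate::from_pem(ca_pem))
        .identity(Identity::from_pem(cert_pem, key_pem));
    ConnectOptions::new().with_tls(tls)
}

fn main() {
    let _opts = tls_connect_options(Vec::new(), Vec::new(), Vec::new());
}
```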
||||
85
src/meta-srv/src/utils/mysql.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use common_telemetry::info;
|
||||
use servers::tls::{TlsMode, TlsOption};
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
use sqlx::mysql::{MySqlConnectOptions, MySqlPool, MySqlSslMode};
|
||||
|
||||
use crate::error::{self, Result};
|
||||
|
||||
async fn setup_mysql_options(
|
||||
store_addrs: &[String],
|
||||
tls_config: Option<&TlsOption>,
|
||||
) -> Result<MySqlConnectOptions> {
|
||||
let mysql_url = store_addrs.first().context(error::InvalidArgumentsSnafu {
|
||||
err_msg: "empty store addrs",
|
||||
})?;
|
||||
// Avoid `SET` commands in sqlx
|
||||
let opts: MySqlConnectOptions = mysql_url
|
||||
.parse()
|
||||
.context(error::ParseMySqlUrlSnafu { mysql_url })?;
|
||||
let mut opts = opts
|
||||
.no_engine_substitution(false)
|
||||
.pipes_as_concat(false)
|
||||
.timezone(None)
|
||||
.set_names(false);
|
||||
|
||||
let Some(tls_config) = tls_config else {
|
||||
return Ok(opts);
|
||||
};
|
||||
|
||||
match tls_config.mode {
|
||||
TlsMode::Disable => return Ok(opts),
|
||||
TlsMode::Prefer => {
|
||||
opts = opts.ssl_mode(MySqlSslMode::Preferred);
|
||||
}
|
||||
TlsMode::Require => {
|
||||
opts = opts.ssl_mode(MySqlSslMode::Required);
|
||||
}
|
||||
TlsMode::VerifyCa => {
|
||||
opts = opts.ssl_mode(MySqlSslMode::VerifyCa);
|
||||
opts = opts.ssl_ca(&tls_config.ca_cert_path);
|
||||
}
|
||||
TlsMode::VerifyFull => {
|
||||
opts = opts.ssl_mode(MySqlSslMode::VerifyIdentity);
|
||||
opts = opts.ssl_ca(&tls_config.ca_cert_path);
|
||||
}
|
||||
}
|
||||
info!(
|
||||
"Setting up MySQL options with TLS mode: {:?}",
|
||||
tls_config.mode
|
||||
);
|
||||
|
||||
if !tls_config.cert_path.is_empty() && !tls_config.key_path.is_empty() {
|
||||
info!("Loading client certificate for mutual TLS");
|
||||
opts = opts.ssl_client_cert(&tls_config.cert_path);
|
||||
opts = opts.ssl_client_key(&tls_config.key_path);
|
||||
}
|
||||
|
||||
Ok(opts)
|
||||
}
|
||||
|
||||
/// Creates a MySQL pool.
|
||||
pub async fn create_mysql_pool(
|
||||
store_addrs: &[String],
|
||||
tls_config: Option<&TlsOption>,
|
||||
) -> Result<MySqlPool> {
|
||||
let opts = setup_mysql_options(store_addrs, tls_config).await?;
|
||||
let pool = MySqlPool::connect_with(opts)
|
||||
.await
|
||||
.context(error::CreateMySqlPoolSnafu)?;
|
||||
|
||||
Ok(pool)
|
||||
}
|
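The same SSL-mode mapping shown on a hand-built `MySqlConnectOptions` instead of a parsed URL; host, credentials, and the CA path are placeholders:

```rust
use sqlx::mysql::{MySqlConnectOptions, MySqlSslMode};

fn main() {
    // VerifyCa pins the server certificate to the given CA bundle; nothing
    // connects until the options are handed to a pool.
    let _opts = MySqlConnectOptions::new()
        .host("127.0.0.1")
        .port(3306)
        .username("meta")
        .ssl_mode(MySqlSslMode::VerifyCa)
        .ssl_ca("/etc/greptimedb/ca.crt");
}
```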
||||
74
src/meta-srv/src/utils/postgres.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use common_error::ext::BoxedError;
|
||||
use common_meta::kv_backend::rds::postgres::{
|
||||
TlsMode as PgTlsMode, TlsOption as PgTlsOption, create_postgres_tls_connector,
|
||||
};
|
||||
use deadpool_postgres::{Config, Runtime};
|
||||
use servers::tls::TlsOption;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
use tokio_postgres::NoTls;
|
||||
|
||||
use crate::error::{self, Result};
|
||||
|
||||
/// Converts [`TlsOption`] to [`PgTlsOption`] to avoid circular dependencies
|
||||
fn convert_tls_option(tls_option: &TlsOption) -> PgTlsOption {
|
||||
let mode = match tls_option.mode {
|
||||
servers::tls::TlsMode::Disable => PgTlsMode::Disable,
|
||||
servers::tls::TlsMode::Prefer => PgTlsMode::Prefer,
|
||||
servers::tls::TlsMode::Require => PgTlsMode::Require,
|
||||
servers::tls::TlsMode::VerifyCa => PgTlsMode::VerifyCa,
|
||||
servers::tls::TlsMode::VerifyFull => PgTlsMode::VerifyFull,
|
||||
};
|
||||
|
||||
PgTlsOption {
|
||||
mode,
|
||||
cert_path: tls_option.cert_path.clone(),
|
||||
key_path: tls_option.key_path.clone(),
|
||||
ca_cert_path: tls_option.ca_cert_path.clone(),
|
||||
watch: tls_option.watch,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a pool for the Postgres backend with config and optional TLS.
|
||||
///
|
||||
/// It only uses the first store addr and the given config to create the pool.
|
||||
pub async fn create_postgres_pool(
|
||||
store_addrs: &[String],
|
||||
cfg: Option<Config>,
|
||||
tls_config: Option<TlsOption>,
|
||||
) -> Result<deadpool_postgres::Pool> {
|
||||
let mut cfg = cfg.unwrap_or_default();
|
||||
let postgres_url = store_addrs.first().context(error::InvalidArgumentsSnafu {
|
||||
err_msg: "empty store addrs",
|
||||
})?;
|
||||
cfg.url = Some(postgres_url.to_string());
|
||||
|
||||
let pool = if let Some(tls_config) = tls_config {
|
||||
let pg_tls_config = convert_tls_option(&tls_config);
|
||||
let tls_connector =
|
||||
create_postgres_tls_connector(&pg_tls_config).map_err(|e| error::Error::Other {
|
||||
source: BoxedError::new(e),
|
||||
location: snafu::Location::new(file!(), line!(), 0),
|
||||
})?;
|
||||
cfg.create_pool(Some(Runtime::Tokio1), tls_connector)
|
||||
.context(error::CreatePostgresPoolSnafu)?
|
||||
} else {
|
||||
cfg.create_pool(Some(Runtime::Tokio1), NoTls)
|
||||
.context(error::CreatePostgresPoolSnafu)?
|
||||
};
|
||||
|
||||
Ok(pool)
|
||||
}
|
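And the non-TLS leg in isolation: `deadpool-postgres` pool creation is lazy, so this compiles and runs without a live server (the URL is a placeholder):

```rust
use deadpool_postgres::{Config, Runtime};
use tokio_postgres::NoTls;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut cfg = Config::new();
    // Mirrors `create_postgres_pool`: only the first store address is used.
    cfg.url = Some("postgres://user:pass@127.0.0.1:5432/meta".to_string());
    let pool = cfg.create_pool(Some(Runtime::Tokio1), NoTls)?;
    // Connections are established lazily on first checkout.
    drop(pool);
    Ok(())
}
```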
||||
@@ -257,10 +257,10 @@ impl RegionEngine for MetricEngine {
|
||||
self.handle_query(region_id, request).await
|
||||
}
|
||||
|
||||
async fn get_last_seq_num(
|
||||
async fn get_committed_sequence(
|
||||
&self,
|
||||
region_id: RegionId,
|
||||
) -> Result<Option<SequenceNumber>, BoxedError> {
|
||||
) -> Result<SequenceNumber, BoxedError> {
|
||||
self.inner
|
||||
.get_last_seq_num(region_id)
|
||||
.await
|
||||
|
||||
@@ -111,6 +111,8 @@ mod tests {
|
||||
let mito = env.mito();
|
||||
let debug_format = mito
|
||||
.all_ssts_from_manifest()
|
||||
.await
|
||||
.into_iter()
|
||||
.map(|mut e| {
|
||||
e.file_path = e.file_path.replace(&e.file_id, "<file_id>");
|
||||
e.index_file_path = e
|
||||
@@ -125,12 +127,12 @@ mod tests {
|
||||
assert_eq!(
|
||||
debug_format,
|
||||
r#"
|
||||
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 47244640257(11, 1), table_id: 11, region_number: 1, region_group: 0, region_sequence: 1, file_id: "<file_id>", level: 0, file_path: "test_metric_region/11_0000000001/data/<file_id>.parquet", file_size: 3157, index_file_path: Some("test_metric_region/11_0000000001/data/index/<file_id>.puffin"), index_file_size: Some(235), num_rows: 10, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 9::Millisecond, sequence: Some(20), origin_region_id: 47244640257(11, 1), node_id: None }
|
||||
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 47244640258(11, 2), table_id: 11, region_number: 2, region_group: 0, region_sequence: 2, file_id: "<file_id>", level: 0, file_path: "test_metric_region/11_0000000002/data/<file_id>.parquet", file_size: 3157, index_file_path: Some("test_metric_region/11_0000000002/data/index/<file_id>.puffin"), index_file_size: Some(235), num_rows: 10, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 9::Millisecond, sequence: Some(10), origin_region_id: 47244640258(11, 2), node_id: None }
|
||||
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 47261417473(11, 16777217), table_id: 11, region_number: 16777217, region_group: 1, region_sequence: 1, file_id: "<file_id>", level: 0, file_path: "test_metric_region/11_0000000001/metadata/<file_id>.parquet", file_size: 3201, index_file_path: None, index_file_size: None, num_rows: 8, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 0::Millisecond, sequence: Some(8), origin_region_id: 47261417473(11, 16777217), node_id: None }
|
||||
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 47261417474(11, 16777218), table_id: 11, region_number: 16777218, region_group: 1, region_sequence: 2, file_id: "<file_id>", level: 0, file_path: "test_metric_region/11_0000000002/metadata/<file_id>.parquet", file_size: 3185, index_file_path: None, index_file_size: None, num_rows: 4, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 0::Millisecond, sequence: Some(4), origin_region_id: 47261417474(11, 16777218), node_id: None }
|
||||
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 94489280554(22, 42), table_id: 22, region_number: 42, region_group: 0, region_sequence: 42, file_id: "<file_id>", level: 0, file_path: "test_metric_region/22_0000000042/data/<file_id>.parquet", file_size: 3157, index_file_path: Some("test_metric_region/22_0000000042/data/index/<file_id>.puffin"), index_file_size: Some(235), num_rows: 10, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 9::Millisecond, sequence: Some(10), origin_region_id: 94489280554(22, 42), node_id: None }
|
||||
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 94506057770(22, 16777258), table_id: 22, region_number: 16777258, region_group: 1, region_sequence: 42, file_id: "<file_id>", level: 0, file_path: "test_metric_region/22_0000000042/metadata/<file_id>.parquet", file_size: 3185, index_file_path: None, index_file_size: None, num_rows: 4, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 0::Millisecond, sequence: Some(4), origin_region_id: 94506057770(22, 16777258), node_id: None }"#
|
||||
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 47244640257(11, 1), table_id: 11, region_number: 1, region_group: 0, region_sequence: 1, file_id: "<file_id>", level: 0, file_path: "test_metric_region/11_0000000001/data/<file_id>.parquet", file_size: 3157, index_file_path: Some("test_metric_region/11_0000000001/data/index/<file_id>.puffin"), index_file_size: Some(235), num_rows: 10, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 9::Millisecond, sequence: Some(20), origin_region_id: 47244640257(11, 1), node_id: None, visible: true }
|
||||
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 47244640258(11, 2), table_id: 11, region_number: 2, region_group: 0, region_sequence: 2, file_id: "<file_id>", level: 0, file_path: "test_metric_region/11_0000000002/data/<file_id>.parquet", file_size: 3157, index_file_path: Some("test_metric_region/11_0000000002/data/index/<file_id>.puffin"), index_file_size: Some(235), num_rows: 10, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 9::Millisecond, sequence: Some(10), origin_region_id: 47244640258(11, 2), node_id: None, visible: true }
|
||||
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 47261417473(11, 16777217), table_id: 11, region_number: 16777217, region_group: 1, region_sequence: 1, file_id: "<file_id>", level: 0, file_path: "test_metric_region/11_0000000001/metadata/<file_id>.parquet", file_size: 3429, index_file_path: None, index_file_size: None, num_rows: 8, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 0::Millisecond, sequence: Some(8), origin_region_id: 47261417473(11, 16777217), node_id: None, visible: true }
|
||||
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 47261417474(11, 16777218), table_id: 11, region_number: 16777218, region_group: 1, region_sequence: 2, file_id: "<file_id>", level: 0, file_path: "test_metric_region/11_0000000002/metadata/<file_id>.parquet", file_size: 3413, index_file_path: None, index_file_size: None, num_rows: 4, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 0::Millisecond, sequence: Some(4), origin_region_id: 47261417474(11, 16777218), node_id: None, visible: true }
|
||||
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 94489280554(22, 42), table_id: 22, region_number: 42, region_group: 0, region_sequence: 42, file_id: "<file_id>", level: 0, file_path: "test_metric_region/22_0000000042/data/<file_id>.parquet", file_size: 3157, index_file_path: Some("test_metric_region/22_0000000042/data/index/<file_id>.puffin"), index_file_size: Some(235), num_rows: 10, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 9::Millisecond, sequence: Some(10), origin_region_id: 94489280554(22, 42), node_id: None, visible: true }
|
||||
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 94506057770(22, 16777258), table_id: 22, region_number: 16777258, region_group: 1, region_sequence: 42, file_id: "<file_id>", level: 0, file_path: "test_metric_region/22_0000000042/metadata/<file_id>.parquet", file_size: 3413, index_file_path: None, index_file_size: None, num_rows: 4, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 0::Millisecond, sequence: Some(4), origin_region_id: 94506057770(22, 16777258), node_id: None, visible: true }"#
|
||||
);
|
||||
|
||||
// list from storage
|
||||
|
||||
@@ -89,7 +89,7 @@ impl MetricEngineInner {
|
||||
Ok(scanner)
|
||||
}
|
||||
|
||||
pub async fn get_last_seq_num(&self, region_id: RegionId) -> Result<Option<SequenceNumber>> {
|
||||
pub async fn get_last_seq_num(&self, region_id: RegionId) -> Result<SequenceNumber> {
|
||||
let region_id = if self.is_physical_region(region_id) {
|
||||
region_id
|
||||
} else {
|
||||
@@ -97,7 +97,7 @@ impl MetricEngineInner {
|
||||
utils::to_data_region_id(physical_region_id)
|
||||
};
|
||||
self.mito
|
||||
.get_last_seq_num(region_id)
|
||||
.get_committed_sequence(region_id)
|
||||
.await
|
||||
.context(MitoReadOperationSnafu)
|
||||
}

@@ -19,7 +19,7 @@ common-recordbatch.workspace = true
common-telemetry.workspace = true
common-time.workspace = true
datatypes.workspace = true
memcomparable = "0.2"
memcomparable = { git = "https://github.com/v0y4g3r/memcomparable.git", rev = "a07122dc03556bbd88ad66234cbea7efd3b23efb" }
paste.workspace = true
serde.workspace = true
snafu.workspace = true

@@ -51,7 +51,6 @@ index.workspace = true
itertools.workspace = true
lazy_static = "1.4"
log-store = { workspace = true }
memcomparable = "0.2"
mito-codec.workspace = true
moka = { workspace = true, features = ["sync", "future"] }
object-store.workspace = true

@@ -13,10 +13,11 @@
// limitations under the License.

use std::sync::Arc;
use std::time::Duration;
use std::time::{Duration, Instant};

use async_stream::try_stream;
use common_time::Timestamp;
use either::Either;
use futures::{Stream, TryStreamExt};
use object_store::services::Fs;
use object_store::util::{join_dir, with_instrument_layers};
@@ -26,7 +27,7 @@ use snafu::ResultExt;
use store_api::metadata::RegionMetadataRef;
use store_api::region_request::PathType;
use store_api::sst_entry::StorageSstEntry;
use store_api::storage::{RegionId, SequenceNumber};
use store_api::storage::{FileId, RegionId, SequenceNumber};

use crate::cache::CacheManagerRef;
use crate::cache::file_cache::{FileCacheRef, FileType, IndexKey};
@@ -34,9 +35,9 @@ use crate::cache::write_cache::SstUploadRequest;
use crate::config::{BloomFilterConfig, FulltextIndexConfig, InvertedIndexConfig};
use crate::error::{CleanDirSnafu, DeleteIndexSnafu, DeleteSstSnafu, OpenDalSnafu, Result};
use crate::metrics::{COMPACTION_STAGE_ELAPSED, FLUSH_ELAPSED};
use crate::read::Source;
use crate::read::{FlatSource, Source};
use crate::region::options::IndexOptions;
use crate::sst::file::{FileHandle, FileId, RegionFileId};
use crate::sst::file::{FileHandle, RegionFileId};
use crate::sst::index::IndexerBuilderImpl;
use crate::sst::index::intermediate::IntermediateManager;
use crate::sst::index::puffin_manager::PuffinManagerFactory;
@@ -44,6 +45,7 @@ use crate::sst::location::{self, region_dir_from_table_dir};
use crate::sst::parquet::reader::ParquetReaderBuilder;
use crate::sst::parquet::writer::ParquetWriter;
use crate::sst::parquet::{SstInfo, WriteOptions};
use crate::sst::{DEFAULT_WRITE_BUFFER_SIZE, DEFAULT_WRITE_CONCURRENCY};

pub type AccessLayerRef = Arc<AccessLayer>;
/// SST write results.
@@ -66,6 +68,7 @@ pub struct Metrics {
pub(crate) update_index: Duration,
pub(crate) upload_parquet: Duration,
pub(crate) upload_puffin: Duration,
pub(crate) compact_memtable: Duration,
}

impl Metrics {
@@ -77,6 +80,7 @@ impl Metrics {
update_index: Default::default(),
upload_parquet: Default::default(),
upload_puffin: Default::default(),
compact_memtable: Default::default(),
}
}

@@ -87,6 +91,7 @@ impl Metrics {
self.update_index += other.update_index;
self.upload_parquet += other.upload_parquet;
self.upload_puffin += other.upload_puffin;
self.compact_memtable += other.compact_memtable;
self
}
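Note: with `compact_memtable` added to `Metrics`, per-stage durations still fold together with the same `+=` pattern shown above. A hedged sketch of accumulating several per-file metrics into one value; the folding method's name is not visible in this hunk, so `merge` is an assumption:

```rust
// `merge` is an assumed name; the diff only shows the `+=` body and the
// `self` return value of the accumulator method.
let total = per_file_metrics
    .into_iter()
    .fold(Metrics::new(WriteType::Flush), |acc, m| acc.merge(m));
```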

@@ -108,6 +113,11 @@ impl Metrics {
FLUSH_ELAPSED
.with_label_values(&["upload_puffin"])
.observe(self.upload_puffin.as_secs_f64());
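// Only record the memtable compaction stage when it actually ran (non-zero duration).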
if !self.compact_memtable.is_zero() {
FLUSH_ELAPSED
.with_label_values(&["compact_memtable"])
.observe(self.compact_memtable.as_secs_f64());
}
}
WriteType::Compaction => {
COMPACTION_STAGE_ELAPSED
@@ -288,9 +298,16 @@ impl AccessLayer {
)
.await
.with_file_cleaner(cleaner);
let ssts = writer
.write_all(request.source, request.max_sequence, write_opts)
.await?;
let ssts = match request.source {
Either::Left(source) => {
writer
.write_all(source, request.max_sequence, write_opts)
.await?
}
Either::Right(flat_source) => {
writer.write_all_flat(flat_source, write_opts).await?
}
};
let metrics = writer.into_metrics();
(ssts, metrics)
};
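Note: `SstWriteRequest::source` becomes an `Either<Source, FlatSource>` (see the struct change below), and this match dispatches each variant to the matching writer entry point. A sketch of how the two variants are built, using constructors that appear elsewhere in this diff; `reader` and `stream` are illustrative variable names:

```rust
// Row-oriented path (legacy): wrap a batch reader.
let row_source = Either::Left(Source::Reader(reader));
// Flat-format path: wrap a record-batch stream.
let flat_source = Either::Right(FlatSource::Stream(stream));
```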
@@ -310,6 +327,53 @@ impl AccessLayer {
Ok((sst_info, metrics))
}

/// Puts encoded SST bytes to the write cache (if enabled) and uploads it to the object store.
pub(crate) async fn put_sst(
&self,
data: &bytes::Bytes,
region_id: RegionId,
sst_info: &SstInfo,
cache_manager: &CacheManagerRef,
) -> Result<Metrics> {
if let Some(write_cache) = cache_manager.write_cache() {
// Write to cache and upload to remote store
let upload_request = SstUploadRequest {
dest_path_provider: RegionFilePathFactory::new(
self.table_dir.clone(),
self.path_type,
),
remote_store: self.object_store.clone(),
};
write_cache
.put_and_upload_sst(data, region_id, sst_info, upload_request)
.await
} else {
let start = Instant::now();
let cleaner = TempFileCleaner::new(region_id, self.object_store.clone());
let path_provider = RegionFilePathFactory::new(self.table_dir.clone(), self.path_type);
let sst_file_path =
path_provider.build_sst_file_path(RegionFileId::new(region_id, sst_info.file_id));
let mut writer = self
.object_store
.writer_with(&sst_file_path)
.chunk(DEFAULT_WRITE_BUFFER_SIZE.as_bytes() as usize)
.concurrent(DEFAULT_WRITE_CONCURRENCY)
.await
.context(OpenDalSnafu)?;
if let Err(err) = writer.write(data.clone()).await.context(OpenDalSnafu) {
cleaner.clean_by_file_id(sst_info.file_id).await;
return Err(err);
}
if let Err(err) = writer.close().await.context(OpenDalSnafu) {
cleaner.clean_by_file_id(sst_info.file_id).await;
return Err(err);
}
let mut metrics = Metrics::new(WriteType::Flush);
metrics.write_batch = start.elapsed();
Ok(metrics)
}
}
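Note: the fallback branch writes the already-encoded bytes straight to the object store with the same chunked, concurrent writer settings as regular SST writes, and cleans up the partial file on any failure. A hypothetical call site, where `encoded`, `sst_info`, and `cache_manager` come from an earlier in-memory encode step (all three names are assumptions):

```rust
let metrics = access_layer
    .put_sst(&encoded, region_id, &sst_info, &cache_manager)
    .await?;
```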

/// Lists the SST entries from the storage layer in the table directory.
pub fn storage_sst_entries(&self) -> impl Stream<Item = Result<StorageSstEntry>> + use<> {
let object_store = self.object_store.clone();
@@ -363,7 +427,7 @@ pub enum OperationType {
pub struct SstWriteRequest {
pub op_type: OperationType,
pub metadata: RegionMetadataRef,
pub source: Source,
pub source: Either<Source, FlatSource>,
pub cache_manager: CacheManagerRef,
#[allow(dead_code)]
pub storage: Option<String>,

@@ -35,7 +35,7 @@ use moka::notification::RemovalCause;
use moka::sync::Cache;
use parquet::file::metadata::ParquetMetaData;
use puffin::puffin_manager::cache::{PuffinMetadataCache, PuffinMetadataCacheRef};
use store_api::storage::{ConcreteDataType, RegionId, TimeSeriesRowSelector};
use store_api::storage::{ConcreteDataType, FileId, RegionId, TimeSeriesRowSelector};

use crate::cache::cache_size::parquet_meta_size;
use crate::cache::file_cache::{FileType, IndexKey};
@@ -43,7 +43,7 @@ use crate::cache::index::inverted_index::{InvertedIndexCache, InvertedIndexCache
use crate::cache::write_cache::WriteCacheRef;
use crate::metrics::{CACHE_BYTES, CACHE_EVICTION, CACHE_HIT, CACHE_MISS};
use crate::read::Batch;
use crate::sst::file::{FileId, RegionFileId};
use crate::sst::file::RegionFileId;

/// Metrics type key for sst meta.
const SST_META_TYPE: &str = "sst_meta";
src/mito2/src/cache/file_cache.rs (vendored, 3 changes)
@@ -30,12 +30,11 @@ use object_store::util::join_path;
use object_store::{ErrorKind, ObjectStore, Reader};
use parquet::file::metadata::ParquetMetaData;
use snafu::ResultExt;
use store_api::storage::RegionId;
use store_api::storage::{FileId, RegionId};

use crate::cache::FILE_TYPE;
use crate::error::{OpenDalSnafu, Result};
use crate::metrics::{CACHE_BYTES, CACHE_HIT, CACHE_MISS};
use crate::sst::file::FileId;
use crate::sst::parquet::helper::fetch_byte_ranges;
use crate::sst::parquet::metadata::MetadataLoader;

@@ -20,11 +20,10 @@ use async_trait::async_trait;
use bytes::Bytes;
use index::bloom_filter::error::Result;
use index::bloom_filter::reader::BloomFilterReader;
use store_api::storage::ColumnId;
use store_api::storage::{ColumnId, FileId};

use crate::cache::index::{INDEX_METADATA_TYPE, IndexCache, PageKey};
use crate::metrics::{CACHE_HIT, CACHE_MISS};
use crate::sst::file::FileId;

const INDEX_TYPE_BLOOM_FILTER_INDEX: &str = "bloom_filter_index";
src/mito2/src/cache/index/inverted_index.rs (vendored, 2 changes)
@@ -21,10 +21,10 @@ use bytes::Bytes;
use index::inverted_index::error::Result;
use index::inverted_index::format::reader::InvertedIndexReader;
use prost::Message;
use store_api::storage::FileId;

use crate::cache::index::{INDEX_METADATA_TYPE, IndexCache, PageKey};
use crate::metrics::{CACHE_HIT, CACHE_MISS};
use crate::sst::file::FileId;

const INDEX_TYPE_INVERTED_INDEX: &str = "inverted_index";
src/mito2/src/cache/index/result_cache.rs (vendored, 3 changes)
@@ -19,10 +19,9 @@ use index::bloom_filter::applier::InListPredicate;
use index::inverted_index::search::predicate::{Predicate, RangePredicate};
use moka::notification::RemovalCause;
use moka::sync::Cache;
use store_api::storage::ColumnId;
use store_api::storage::{ColumnId, FileId};

use crate::metrics::{CACHE_BYTES, CACHE_EVICTION, CACHE_HIT, CACHE_MISS};
use crate::sst::file::FileId;
use crate::sst::index::fulltext_index::applier::builder::{
FulltextQuery, FulltextRequest, FulltextTerm,
};
src/mito2/src/cache/write_cache.rs (vendored, 79 changes)
@@ -37,8 +37,8 @@ use crate::sst::file::RegionFileId;
use crate::sst::index::IndexerBuilderImpl;
use crate::sst::index::intermediate::IntermediateManager;
use crate::sst::index::puffin_manager::PuffinManagerFactory;
use crate::sst::parquet::WriteOptions;
use crate::sst::parquet::writer::ParquetWriter;
use crate::sst::parquet::{SstInfo, WriteOptions};
use crate::sst::{DEFAULT_WRITE_BUFFER_SIZE, DEFAULT_WRITE_CONCURRENCY};

/// A cache for uploading files to remote object stores.
@@ -101,6 +101,66 @@ impl WriteCache {
self.file_cache.clone()
}

/// Put encoded SST data to the cache and upload to the remote object store.
pub(crate) async fn put_and_upload_sst(
&self,
data: &bytes::Bytes,
region_id: RegionId,
sst_info: &SstInfo,
upload_request: SstUploadRequest,
) -> Result<Metrics> {
let file_id = sst_info.file_id;
let mut metrics = Metrics::new(WriteType::Flush);

// Create index key for the SST file
let parquet_key = IndexKey::new(region_id, file_id, FileType::Parquet);

// Write to cache first
let cache_start = Instant::now();
let cache_path = self.file_cache.cache_file_path(parquet_key);
let mut cache_writer = self
.file_cache
.local_store()
.writer(&cache_path)
.await
.context(crate::error::OpenDalSnafu)?;

cache_writer
.write(data.clone())
.await
.context(crate::error::OpenDalSnafu)?;
cache_writer
.close()
.await
.context(crate::error::OpenDalSnafu)?;

// Register in file cache
let index_value = IndexValue {
file_size: data.len() as u32,
};
self.file_cache.put(parquet_key, index_value).await;
metrics.write_batch = cache_start.elapsed();

// Upload to remote store
let upload_start = Instant::now();
let region_file_id = RegionFileId::new(region_id, file_id);
let remote_path = upload_request
.dest_path_provider
.build_sst_file_path(region_file_id);

if let Err(e) = self
.upload(parquet_key, &remote_path, &upload_request.remote_store)
.await
{
// Clean up cache on failure
self.remove(parquet_key).await;
return Err(e);
}

metrics.upload_parquet = upload_start.elapsed();
Ok(metrics)
}
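Note: `put_and_upload_sst` reports the cache write as `write_batch` and the remote upload as `upload_parquet`, and it evicts the cache entry if the upload fails, so a failed call leaves no stale local file behind. A hedged sketch of reading those timings back; `common_telemetry::debug!` is assumed to be the repo's logging macro, and the surrounding variables are assumptions:

```rust
let m = write_cache
    .put_and_upload_sst(&data, region_id, &sst_info, upload_request)
    .await?;
common_telemetry::debug!(
    "sst cached in {:?}, uploaded in {:?}",
    m.write_batch,
    m.upload_parquet
);
```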

/// Writes SST to the cache and then uploads it to the remote object store.
pub(crate) async fn write_and_upload_sst(
&self,
@@ -139,9 +199,14 @@ impl WriteCache {
.await
.with_file_cleaner(cleaner);

let sst_info = writer
.write_all(write_request.source, write_request.max_sequence, write_opts)
.await?;
let sst_info = match write_request.source {
either::Left(source) => {
writer
.write_all(source, write_request.max_sequence, write_opts)
.await?
}
either::Right(flat_source) => writer.write_all_flat(flat_source, write_opts).await?,
};
let mut metrics = writer.into_metrics();

// Upload sst file to remote object store.
@@ -469,7 +534,7 @@ mod tests {
let write_request = SstWriteRequest {
op_type: OperationType::Flush,
metadata,
source,
source: either::Left(source),
storage: None,
max_sequence: None,
cache_manager: Default::default(),
@@ -567,7 +632,7 @@ mod tests {
let write_request = SstWriteRequest {
op_type: OperationType::Flush,
metadata,
source,
source: either::Left(source),
storage: None,
max_sequence: None,
cache_manager: cache_manager.clone(),
@@ -646,7 +711,7 @@ mod tests {
let write_request = SstWriteRequest {
op_type: OperationType::Flush,
metadata,
source,
source: either::Left(source),
storage: None,
max_sequence: None,
cache_manager: cache_manager.clone(),

@@ -55,10 +55,10 @@ use crate::error::{
TimeRangePredicateOverflowSnafu, TimeoutSnafu,
};
use crate::metrics::{COMPACTION_STAGE_ELAPSED, INFLIGHT_COMPACTION_COUNT};
use crate::read::BoxedBatchReader;
use crate::read::projection::ProjectionMapper;
use crate::read::scan_region::{PredicateGroup, ScanInput};
use crate::read::seq_scan::SeqScan;
use crate::read::{BoxedBatchReader, BoxedRecordBatchStream};
use crate::region::options::MergeMode;
use crate::region::version::VersionControlRef;
use crate::region::{ManifestContextRef, RegionLeaderState, RegionRoleState};
@@ -638,9 +638,26 @@ struct CompactionSstReaderBuilder<'a> {
impl CompactionSstReaderBuilder<'_> {
/// Builds [BoxedBatchReader] that reads all SST files and yields batches in primary key order.
async fn build_sst_reader(self) -> Result<BoxedBatchReader> {
let scan_input = self.build_scan_input(false)?;

SeqScan::new(scan_input, true)
.build_reader_for_compaction()
.await
}

/// Builds [BoxedRecordBatchStream] that reads all SST files and yields batches in flat format for compaction.
async fn build_flat_sst_reader(self) -> Result<BoxedRecordBatchStream> {
let scan_input = self.build_scan_input(true)?;

SeqScan::new(scan_input, true)
.build_flat_reader_for_compaction()
.await
}

fn build_scan_input(self, flat_format: bool) -> Result<ScanInput> {
let mut scan_input = ScanInput::new(
self.sst_layer,
ProjectionMapper::all(&self.metadata, false)?,
ProjectionMapper::all(&self.metadata, flat_format)?,
)
.with_files(self.inputs.to_vec())
.with_append_mode(self.append_mode)
@@ -649,7 +666,8 @@ impl CompactionSstReaderBuilder<'_> {
.with_filter_deleted(self.filter_deleted)
// We ignore file not found error during compaction.
.with_ignore_file_not_found(true)
.with_merge_mode(self.merge_mode);
.with_merge_mode(self.merge_mode)
.with_flat_format(flat_format);

// This serves as a workaround of https://github.com/GreptimeTeam/greptimedb/issues/3944
// by converting time ranges into predicate.
@@ -658,9 +676,7 @@ impl CompactionSstReaderBuilder<'_> {
scan_input.with_predicate(time_range_to_predicate(time_range, &self.metadata)?);
}

SeqScan::new(scan_input, true)
.build_reader_for_compaction()
.await
Ok(scan_input)
}
}
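Note: the builder now funnels both formats through one `build_scan_input`, so append mode, deletion filtering, and the time-range workaround are configured once and only the output format differs. This mirrors the call site in `DefaultCompactor` further down; a condensed sketch, where `flat_format` comes from the engine config:

```rust
// Only one branch runs, so moving `builder` in each arm is fine.
let source = if flat_format {
    either::Right(FlatSource::Stream(builder.build_flat_sst_reader().await?))
} else {
    either::Left(Source::Reader(builder.build_sst_reader().await?))
};
```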

@@ -80,9 +80,10 @@ pub(crate) const TIME_BUCKETS: TimeBuckets = TimeBuckets([

#[cfg(test)]
mod tests {
use store_api::storage::FileId;

use super::*;
use crate::compaction::test_util::new_file_handle;
use crate::sst::file::FileId;

#[test]
fn test_time_bucket() {

@@ -42,7 +42,7 @@ use crate::manifest::action::{RegionEdit, RegionMetaAction, RegionMetaActionList
use crate::manifest::manager::{RegionManifestManager, RegionManifestOptions, RemoveFileOptions};
use crate::manifest::storage::manifest_compress_type;
use crate::metrics;
use crate::read::Source;
use crate::read::{FlatSource, Source};
use crate::region::opener::new_manifest_dir;
use crate::region::options::RegionOptions;
use crate::region::version::VersionRef;
@@ -342,6 +342,9 @@ impl Compactor for DefaultCompactor {
.clone();
let append_mode = compaction_region.current_version.options.append_mode;
let merge_mode = compaction_region.current_version.options.merge_mode();
let flat_format = compaction_region
.engine_config
.enable_experimental_flat_format;
let inverted_index_config = compaction_region.engine_config.inverted_index.clone();
let fulltext_index_config = compaction_region.engine_config.fulltext_index.clone();
let bloom_filter_index_config =
@@ -359,7 +362,7 @@ impl Compactor for DefaultCompactor {
.iter()
.map(|f| f.file_id().to_string())
.join(",");
let reader = CompactionSstReaderBuilder {
let builder = CompactionSstReaderBuilder {
metadata: region_metadata.clone(),
sst_layer: sst_layer.clone(),
cache: cache_manager.clone(),
@@ -368,15 +371,20 @@ impl Compactor for DefaultCompactor {
filter_deleted: output.filter_deleted,
time_range: output.output_time_range,
merge_mode,
}
.build_sst_reader()
.await?;
};
let source = if flat_format {
let reader = builder.build_flat_sst_reader().await?;
either::Right(FlatSource::Stream(reader))
} else {
let reader = builder.build_sst_reader().await?;
either::Left(Source::Reader(reader))
};
let (sst_infos, metrics) = sst_layer
.write_sst(
SstWriteRequest {
op_type: OperationType::Compact,
metadata: region_metadata,
source: Source::Reader(reader),
source,
cache_manager,
storage,
max_sequence: max_sequence.map(NonZero::get),
@@ -475,6 +483,7 @@ impl Compactor for DefaultCompactor {
.map(|seconds| Duration::from_secs(seconds as u64)),
flushed_entry_id: None,
flushed_sequence: None,
committed_sequence: None,
};

let action_list = RegionMetaActionList::with_action(RegionMetaAction::Edit(edit.clone()));

@@ -147,9 +147,10 @@ pub fn new_picker(

#[cfg(test)]
mod tests {
use store_api::storage::FileId;

use super::*;
use crate::compaction::test_util::new_file_handle;
use crate::sst::file::FileId;
use crate::test_util::new_noop_file_purger;

#[test]

@@ -15,8 +15,9 @@
use std::num::NonZeroU64;

use common_time::Timestamp;
use store_api::storage::FileId;

use crate::sst::file::{FileHandle, FileId, FileMeta, Level};
use crate::sst::file::{FileHandle, FileMeta, Level};
use crate::test_util::new_noop_file_purger;

/// Test util to create file handles.

@@ -350,11 +350,13 @@ fn find_latest_window_in_seconds<'a>(
mod tests {
use std::collections::HashSet;

use store_api::storage::FileId;

use super::*;
use crate::compaction::test_util::{
new_file_handle, new_file_handle_with_sequence, new_file_handle_with_size_and_sequence,
};
use crate::sst::file::{FileId, Level};
use crate::sst::file::Level;

#[test]
fn test_get_latest_window_in_seconds() {

@@ -206,12 +206,12 @@ mod tests {

use common_time::Timestamp;
use common_time::range::TimestampRange;
use store_api::storage::RegionId;
use store_api::storage::{FileId, RegionId};

use crate::compaction::compactor::CompactionVersion;
use crate::compaction::window::{WindowedCompactionPicker, file_time_bucket_span};
use crate::region::options::RegionOptions;
use crate::sst::file::{FileId, FileMeta, Level};
use crate::sst::file::{FileMeta, Level};
use crate::sst::file_purger::NoopFilePurger;
use crate::sst::version::SstVersion;
use crate::test_util::memtable_util::metadata_for_test;

@@ -141,6 +141,10 @@ pub struct MitoConfig {
/// To align with the old behavior, the default value is 0 (no restrictions).
#[serde(with = "humantime_serde")]
pub min_compaction_interval: Duration,

/// Whether to enable experimental flat format.
/// When enabled, forces using BulkMemtable and BulkMemtableBuilder.
pub enable_experimental_flat_format: bool,
}

impl Default for MitoConfig {
@@ -177,6 +181,7 @@ impl Default for MitoConfig {
bloom_filter_index: BloomFilterConfig::default(),
memtable: MemtableConfig::default(),
min_compaction_interval: Duration::from_secs(0),
enable_experimental_flat_format: false,
};

// Adjust buffer and cache size according to system memory if we can.
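Note: `enable_experimental_flat_format` defaults to `false` and gates both the compaction path above and the scan path further down. A minimal sketch of turning it on programmatically; in a TOML config file the key would sit under the `[region_engine.mito]` section:

```rust
// Enable the experimental flat format; all other fields keep their defaults.
let config = MitoConfig {
    enable_experimental_flat_format: true,
    ..MitoConfig::default()
};
```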

@@ -23,6 +23,8 @@ mod basic_test;
#[cfg(test)]
mod batch_open_test;
#[cfg(test)]
mod bump_committed_sequence_test;
#[cfg(test)]
mod catchup_test;
#[cfg(test)]
mod close_test;
@@ -53,6 +55,8 @@ mod prune_test;
#[cfg(test)]
mod row_selector_test;
#[cfg(test)]
mod scan_corrupt;
#[cfg(test)]
mod scan_test;
#[cfg(test)]
mod set_role_state_test;
@@ -414,16 +418,20 @@ impl MitoEngine {
}

/// Lists all SSTs from the manifest of all regions in the engine.
pub fn all_ssts_from_manifest(&self) -> impl Iterator<Item = ManifestSstEntry> + use<'_> {
pub async fn all_ssts_from_manifest(&self) -> Vec<ManifestSstEntry> {
let node_id = self.inner.workers.file_ref_manager().node_id();
self.inner
.workers
.all_regions()
.flat_map(|region| region.manifest_sst_entries())
.map(move |mut entry| {
entry.node_id = node_id;
entry
})
let regions = self.inner.workers.all_regions();

let mut results = Vec::new();
for region in regions {
let mut entries = region.manifest_sst_entries().await;
for e in &mut entries {
e.node_id = node_id;
}
results.extend(entries);
}

results
}
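Note: `all_ssts_from_manifest` changes from a lazy iterator to an async method returning an owned `Vec`, since each region's manifest listing now awaits. A hedged usage sketch; the field names are taken from the `ManifestSstEntry` debug output near the top of this diff:

```rust
let entries = engine.all_ssts_from_manifest().await;
for entry in &entries {
    // file_path and num_rows appear in the test snapshot above.
    println!("{} ({} rows)", entry.file_path, entry.num_rows);
}
```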

/// Lists all SSTs from the storage layer of all regions in the engine.
@@ -465,6 +473,7 @@ fn is_valid_region_edit(edit: &RegionEdit) -> bool {
compaction_time_window: None,
flushed_entry_id: None,
flushed_sequence: None,
..
}
)
}
@@ -658,10 +667,11 @@ impl EngineInner {
receiver.await.context(RecvSnafu)?
}

fn get_last_seq_num(&self, region_id: RegionId) -> Result<Option<SequenceNumber>> {
/// Returns the sequence of latest committed data.
fn get_committed_sequence(&self, region_id: RegionId) -> Result<SequenceNumber> {
// Reading a region doesn't need to go through the region worker thread.
let region = self.find_region(region_id)?;
Ok(Some(region.find_committed_sequence()))
self.find_region(region_id)
.map(|r| r.find_committed_sequence())
}

/// Handles the scan `request` and returns a [ScanRegion].
@@ -685,8 +695,7 @@ impl EngineInner {
.with_ignore_fulltext_index(self.config.fulltext_index.apply_on_query.disabled())
.with_ignore_bloom_filter(self.config.bloom_filter_index.apply_on_query.disabled())
.with_start_time(query_start)
// TODO(yingwen): Enable it after flat format is supported.
.with_flat_format(false);
.with_flat_format(self.config.enable_experimental_flat_format);

#[cfg(feature = "enterprise")]
let scan_region = self.maybe_fill_extension_range_provider(scan_region, region);
@@ -829,12 +838,12 @@ impl RegionEngine for MitoEngine {
.map_err(BoxedError::new)
}

async fn get_last_seq_num(
async fn get_committed_sequence(
&self,
region_id: RegionId,
) -> Result<Option<SequenceNumber>, BoxedError> {
) -> Result<SequenceNumber, BoxedError> {
self.inner
.get_last_seq_num(region_id)
.get_committed_sequence(region_id)
.map_err(BoxedError::new)
}

@@ -1018,6 +1027,7 @@ mod tests {
compaction_time_window: None,
flushed_entry_id: None,
flushed_sequence: None,
committed_sequence: None,
};
assert!(is_valid_region_edit(&edit));

@@ -1029,6 +1039,7 @@ mod tests {
compaction_time_window: None,
flushed_entry_id: None,
flushed_sequence: None,
committed_sequence: None,
};
assert!(!is_valid_region_edit(&edit));

@@ -1040,6 +1051,7 @@ mod tests {
compaction_time_window: None,
flushed_entry_id: None,
flushed_sequence: None,
committed_sequence: None,
};
assert!(!is_valid_region_edit(&edit));

@@ -1051,6 +1063,7 @@ mod tests {
compaction_time_window: Some(Duration::from_secs(1)),
flushed_entry_id: None,
flushed_sequence: None,
committed_sequence: None,
};
assert!(!is_valid_region_edit(&edit));
let edit = RegionEdit {
@@ -1060,6 +1073,7 @@ mod tests {
compaction_time_window: None,
flushed_entry_id: Some(1),
flushed_sequence: None,
committed_sequence: None,
};
assert!(!is_valid_region_edit(&edit));
let edit = RegionEdit {
@@ -1069,6 +1083,7 @@ mod tests {
compaction_time_window: None,
flushed_entry_id: None,
flushed_sequence: Some(1),
committed_sequence: None,
};
assert!(!is_valid_region_edit(&edit));
}
Some files were not shown because too many files have changed in this diff.