feat(vector): add conversion between vector and string (#5029)

* feat(vector): add conversion between vector and string

Signed-off-by: Zhenchi <zhongzc_arch@outlook.com>

* fix sqlness

Signed-off-by: Zhenchi <zhongzc_arch@outlook.com>

* address comments

Signed-off-by: Zhenchi <zhongzc_arch@outlook.com>

---------

Signed-off-by: Zhenchi <zhongzc_arch@outlook.com>
This commit is contained in:
Zhenchi
2024-11-20 16:42:00 +08:00
committed by GitHub
parent 027284ed1b
commit 0aab68c23b
17 changed files with 655 additions and 169 deletions

View File

@@ -12,20 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
mod convert;
mod distance;
use std::sync::Arc;
use distance::{CosDistanceFunction, DotProductFunction, L2SqDistanceFunction};
use crate::function_registry::FunctionRegistry;
pub(crate) struct VectorFunction;
impl VectorFunction {
pub fn register(registry: &FunctionRegistry) {
registry.register(Arc::new(CosDistanceFunction));
registry.register(Arc::new(DotProductFunction));
registry.register(Arc::new(L2SqDistanceFunction));
// conversion
registry.register(Arc::new(convert::ParseVectorFunction));
registry.register(Arc::new(convert::VectorToStringFunction));
// distance
registry.register(Arc::new(distance::CosDistanceFunction));
registry.register(Arc::new(distance::DotProductFunction));
registry.register(Arc::new(distance::L2SqDistanceFunction));
}
}

View File

@@ -0,0 +1,19 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod parse_vector;
mod vector_to_string;
pub use parse_vector::ParseVectorFunction;
pub use vector_to_string::VectorToStringFunction;

View File

@@ -0,0 +1,160 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, InvalidVectorStringSnafu, Result};
use common_query::prelude::{Signature, Volatility};
use datatypes::prelude::ConcreteDataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::types::parse_string_to_vector_type_value;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use snafu::{ensure, ResultExt};
use crate::function::{Function, FunctionContext};
const NAME: &str = "parse_vec";
#[derive(Debug, Clone, Default)]
pub struct ParseVectorFunction;
impl Function for ParseVectorFunction {
fn name(&self) -> &str {
NAME
}
fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
Ok(ConcreteDataType::binary_datatype())
}
fn signature(&self) -> Signature {
Signature::exact(
vec![ConcreteDataType::string_datatype()],
Volatility::Immutable,
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
columns.len()
),
}
);
let column = &columns[0];
let size = column.len();
let mut result = BinaryVectorBuilder::with_capacity(size);
for i in 0..size {
let value = column.get(i).as_string();
if let Some(value) = value {
let res = parse_string_to_vector_type_value(&value, None)
.context(InvalidVectorStringSnafu { vec_str: &value })?;
result.push(Some(&res));
} else {
result.push_null();
}
}
Ok(result.to_vector())
}
}
impl Display for ParseVectorFunction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", NAME.to_ascii_uppercase())
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use common_base::bytes::Bytes;
use datatypes::value::Value;
use datatypes::vectors::StringVector;
use super::*;
#[test]
fn test_parse_vector() {
let func = ParseVectorFunction;
let input = Arc::new(StringVector::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
]));
let result = func.eval(FunctionContext::default(), &[input]).unwrap();
let result = result.as_ref();
assert_eq!(result.len(), 3);
assert_eq!(
result.get(0),
Value::Binary(Bytes::from(
[1.0f32, 2.0, 3.0]
.iter()
.flat_map(|e| e.to_le_bytes())
.collect::<Vec<u8>>()
))
);
assert_eq!(
result.get(1),
Value::Binary(Bytes::from(
[4.0f32, 5.0, 6.0]
.iter()
.flat_map(|e| e.to_le_bytes())
.collect::<Vec<u8>>()
))
);
assert!(result.get(2).is_null());
}
#[test]
fn test_parse_vector_error() {
let func = ParseVectorFunction;
let input = Arc::new(StringVector::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
Some("[7.0,8.0,9.0".to_string()),
]));
let result = func.eval(FunctionContext::default(), &[input]);
assert!(result.is_err());
let input = Arc::new(StringVector::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
Some("7.0,8.0,9.0]".to_string()),
]));
let result = func.eval(FunctionContext::default(), &[input]);
assert!(result.is_err());
let input = Arc::new(StringVector::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
Some("[7.0,hello,9.0]".to_string()),
]));
let result = func.eval(FunctionContext::default(), &[input]);
assert!(result.is_err());
}
}

View File

@@ -0,0 +1,139 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use common_query::prelude::{Signature, Volatility};
use datatypes::prelude::ConcreteDataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::types::vector_type_value_to_string;
use datatypes::value::Value;
use datatypes::vectors::{MutableVector, StringVectorBuilder, VectorRef};
use snafu::ensure;
use crate::function::{Function, FunctionContext};
const NAME: &str = "vec_to_string";
#[derive(Debug, Clone, Default)]
pub struct VectorToStringFunction;
impl Function for VectorToStringFunction {
fn name(&self) -> &str {
NAME
}
fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
Ok(ConcreteDataType::string_datatype())
}
fn signature(&self) -> Signature {
Signature::exact(
vec![ConcreteDataType::binary_datatype()],
Volatility::Immutable,
)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
columns.len()
),
}
);
let column = &columns[0];
let size = column.len();
let mut result = StringVectorBuilder::with_capacity(size);
for i in 0..size {
let value = column.get(i);
match value {
Value::Binary(bytes) => {
let len = bytes.len();
if len % std::mem::size_of::<f32>() != 0 {
return InvalidFuncArgsSnafu {
err_msg: format!("Invalid binary length of vector: {}", len),
}
.fail();
}
let dim = len / std::mem::size_of::<f32>();
// Safety: `dim` is calculated from the length of `bytes` and is guaranteed to be valid
let res = vector_type_value_to_string(&bytes, dim as _).unwrap();
result.push(Some(&res));
}
Value::Null => {
result.push_null();
}
_ => {
return InvalidFuncArgsSnafu {
err_msg: format!("Invalid value type: {:?}", value.data_type()),
}
.fail();
}
}
}
Ok(result.to_vector())
}
}
impl Display for VectorToStringFunction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", NAME.to_ascii_uppercase())
}
}
#[cfg(test)]
mod tests {
use datatypes::value::Value;
use datatypes::vectors::BinaryVectorBuilder;
use super::*;
#[test]
fn test_vector_to_string() {
let func = VectorToStringFunction;
let mut builder = BinaryVectorBuilder::with_capacity(3);
builder.push(Some(
[1.0f32, 2.0, 3.0]
.iter()
.flat_map(|e| e.to_le_bytes())
.collect::<Vec<_>>()
.as_slice(),
));
builder.push(Some(
[4.0f32, 5.0, 6.0]
.iter()
.flat_map(|e| e.to_le_bytes())
.collect::<Vec<_>>()
.as_slice(),
));
builder.push_null();
let vector = builder.to_vector();
let result = func.eval(FunctionContext::default(), &[vector]).unwrap();
assert_eq!(result.len(), 3);
assert_eq!(result.get(0), Value::String("[1,2,3]".to_string().into()));
assert_eq!(result.get(1), Value::String("[4,5,6]".to_string().into()));
assert_eq!(result.get(2), Value::Null);
}
}

View File

@@ -245,6 +245,14 @@ pub enum Error {
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Invalid vector string: {}", vec_str))]
InvalidVectorString {
vec_str: String,
source: DataTypeError,
#[snafu(implicit)]
location: Location,
},
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -273,7 +281,8 @@ impl ErrorExt for Error {
| Error::IntoVector { source, .. }
| Error::FromScalarValue { source, .. }
| Error::ConvertArrowSchema { source, .. }
| Error::FromArrowArray { source, .. } => source.status_code(),
| Error::FromArrowArray { source, .. }
| Error::InvalidVectorString { source, .. } => source.status_code(),
Error::MissingTableMutationHandler { .. }
| Error::MissingProcedureServiceHandler { .. }

View File

@@ -102,7 +102,7 @@ pub fn vector_type_value_to_string(val: &[u8], dim: u32) -> Result<String> {
/// Parses a string to a vector type value
/// Valid input format: "[1.0,2.0,3.0]", "[1.0, 2.0, 3.0]"
pub fn parse_string_to_vector_type_value(s: &str, dim: u32) -> Result<Vec<u8>> {
pub fn parse_string_to_vector_type_value(s: &str, dim: Option<u32>) -> Result<Vec<u8>> {
// Trim the brackets
let trimmed = s.trim();
if !trimmed.starts_with('[') || !trimmed.ends_with(']') {
@@ -115,7 +115,7 @@ pub fn parse_string_to_vector_type_value(s: &str, dim: u32) -> Result<Vec<u8>> {
let content = trimmed[1..trimmed.len() - 1].trim();
if content.is_empty() {
if dim != 0 {
if dim.map_or(false, |d| d != 0) {
return InvalidVectorSnafu {
msg: format!("Failed to parse {s} to Vector value: wrong dimension"),
}
@@ -139,7 +139,7 @@ pub fn parse_string_to_vector_type_value(s: &str, dim: u32) -> Result<Vec<u8>> {
.collect::<Result<Vec<f32>>>()?;
// Check dimension
if elements.len() != dim as usize {
if dim.map_or(false, |d| d as usize != elements.len()) {
return InvalidVectorSnafu {
msg: format!("Failed to parse {s} to Vector value: wrong dimension"),
}
@@ -180,7 +180,7 @@ mod tests {
];
for (s, expected) in cases.iter() {
let val = parse_string_to_vector_type_value(s, dim).unwrap();
let val = parse_string_to_vector_type_value(s, Some(dim)).unwrap();
let s = vector_type_value_to_string(&val, dim).unwrap();
assert_eq!(s, *expected);
}
@@ -188,7 +188,7 @@ mod tests {
let dim = 0;
let cases = [("[]", "[]"), ("[ ]", "[]"), ("[ ]", "[]")];
for (s, expected) in cases.iter() {
let val = parse_string_to_vector_type_value(s, dim).unwrap();
let val = parse_string_to_vector_type_value(s, Some(dim)).unwrap();
let s = vector_type_value_to_string(&val, dim).unwrap();
assert_eq!(s, *expected);
}
@@ -211,15 +211,15 @@ mod tests {
fn test_parse_string_to_vector_type_value_not_properly_enclosed_in_brackets() {
let dim = 3;
let s = "1.0,2.0,3.0";
let res = parse_string_to_vector_type_value(s, dim);
let res = parse_string_to_vector_type_value(s, Some(dim));
assert!(res.is_err());
let s = "[1.0,2.0,3.0";
let res = parse_string_to_vector_type_value(s, dim);
let res = parse_string_to_vector_type_value(s, Some(dim));
assert!(res.is_err());
let s = "1.0,2.0,3.0]";
let res = parse_string_to_vector_type_value(s, dim);
let res = parse_string_to_vector_type_value(s, Some(dim));
assert!(res.is_err());
}
@@ -227,7 +227,7 @@ mod tests {
fn test_parse_string_to_vector_type_value_wrong_dimension() {
let dim = 3;
let s = "[1.0,2.0]";
let res = parse_string_to_vector_type_value(s, dim);
let res = parse_string_to_vector_type_value(s, Some(dim));
assert!(res.is_err());
}
@@ -235,7 +235,7 @@ mod tests {
fn test_parse_string_to_vector_type_value_elements_are_not_all_float32() {
let dim = 3;
let s = "[1.0,2.0,ah]";
let res = parse_string_to_vector_type_value(s, dim);
let res = parse_string_to_vector_type_value(s, Some(dim));
assert!(res.is_err());
}
}

View File

@@ -80,7 +80,7 @@ impl BinaryVector {
let v = if let Some(binary) = binary {
let bytes_size = dim as usize * std::mem::size_of::<f32>();
if let Ok(s) = String::from_utf8(binary.to_vec()) {
let v = parse_string_to_vector_type_value(&s, dim)?;
let v = parse_string_to_vector_type_value(&s, Some(dim))?;
Some(v)
} else if binary.len() == dim as usize * std::mem::size_of::<f32>() {
Some(binary.to_vec())

View File

@@ -21,7 +21,7 @@ use common_recordbatch::{RecordBatch, SendableRecordBatchStream};
use common_telemetry::{debug, error};
use datatypes::prelude::{ConcreteDataType, Value};
use datatypes::schema::SchemaRef;
use datatypes::types::{json_type_value_to_string, vector_type_value_to_string};
use datatypes::types::json_type_value_to_string;
use futures::StreamExt;
use opensrv_mysql::{
Column, ColumnFlags, ColumnType, ErrorKind, OkResponse, QueryResultWriter, RowWriter,
@@ -217,11 +217,6 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
.context(ConvertSqlValueSnafu)?;
row_writer.write_col(s)?;
}
ConcreteDataType::Vector(d) => {
let s = vector_type_value_to_string(&v, d.dim)
.context(ConvertSqlValueSnafu)?;
row_writer.write_col(s)?;
}
_ => {
row_writer.write_col(v.deref())?;
}
@@ -303,7 +298,7 @@ pub(crate) fn create_mysql_column(
ConcreteDataType::Duration(_) => Ok(ColumnType::MYSQL_TYPE_TIME),
ConcreteDataType::Decimal128(_) => Ok(ColumnType::MYSQL_TYPE_DECIMAL),
ConcreteDataType::Json(_) => Ok(ColumnType::MYSQL_TYPE_JSON),
ConcreteDataType::Vector(_) => Ok(ColumnType::MYSQL_TYPE_STRING),
ConcreteDataType::Vector(_) => Ok(ColumnType::MYSQL_TYPE_BLOB),
_ => error::UnsupportedDataTypeSnafu {
data_type,
reason: "not implemented",

View File

@@ -27,9 +27,7 @@ use datafusion_expr::LogicalPlan;
use datatypes::arrow::datatypes::DataType as ArrowDataType;
use datatypes::prelude::{ConcreteDataType, Value};
use datatypes::schema::Schema;
use datatypes::types::{
json_type_value_to_string, vector_type_value_to_string, IntervalType, TimestampType,
};
use datatypes::types::{json_type_value_to_string, IntervalType, TimestampType};
use datatypes::value::ListValue;
use pgwire::api::portal::{Format, Portal};
use pgwire::api::results::{DataRowEncoder, FieldInfo};
@@ -178,7 +176,7 @@ fn encode_array(
.collect::<PgWireResult<Vec<Option<f64>>>>()?;
builder.encode_field(&array)
}
&ConcreteDataType::Binary(_) => {
&ConcreteDataType::Binary(_) | &ConcreteDataType::Vector(_) => {
let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output();
match *bytea_output {
@@ -370,24 +368,6 @@ fn encode_array(
.collect::<PgWireResult<Vec<Option<String>>>>()?;
builder.encode_field(&array)
}
&ConcreteDataType::Vector(d) => {
let array = value_list
.items()
.iter()
.map(|v| match v {
Value::Null => Ok(None),
Value::Binary(v) => {
let s = vector_type_value_to_string(v, d.dim)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
Ok(Some(s))
}
_ => Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!("Invalid list item type, find {v:?}, expected vector",),
}))),
})
.collect::<PgWireResult<Vec<Option<String>>>>()?;
builder.encode_field(&array)
}
_ => Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!(
"cannot write array type {:?} in postgres protocol: unimplemented",
@@ -423,11 +403,6 @@ pub(super) fn encode_value(
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
builder.encode_field(&s)
}
ConcreteDataType::Vector(d) => {
let s = vector_type_value_to_string(v, d.dim)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
builder.encode_field(&s)
}
_ => {
let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output();
match *bytea_output {
@@ -503,7 +478,7 @@ pub(super) fn type_gt_to_pg(origin: &ConcreteDataType) -> Result<Type> {
&ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8),
&ConcreteDataType::Float32(_) => Ok(Type::FLOAT4),
&ConcreteDataType::Float64(_) => Ok(Type::FLOAT8),
&ConcreteDataType::Binary(_) => Ok(Type::BYTEA),
&ConcreteDataType::Binary(_) | &ConcreteDataType::Vector(_) => Ok(Type::BYTEA),
&ConcreteDataType::String(_) => Ok(Type::VARCHAR),
&ConcreteDataType::Date(_) => Ok(Type::DATE),
&ConcreteDataType::DateTime(_) | &ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP),
@@ -546,7 +521,6 @@ pub(super) fn type_gt_to_pg(origin: &ConcreteDataType) -> Result<Type> {
}
.fail()
}
&ConcreteDataType::Vector(_) => Ok(Type::FLOAT4_ARRAY),
}
}

View File

@@ -133,7 +133,7 @@ fn parse_string_to_value(
Ok(Value::Binary(v.into()))
}
ConcreteDataType::Vector(d) => {
let v = parse_string_to_vector_type_value(&s, d.dim).context(DatatypeSnafu)?;
let v = parse_string_to_vector_type_value(&s, Some(d.dim)).context(DatatypeSnafu)?;
Ok(Value::Binary(v.into()))
}
_ => {