From fb12b7be28f1b368820d9dd231d3ad38d0b1de44 Mon Sep 17 00:00:00 2001 From: Remi Dettai Date: Thu, 3 Apr 2025 10:07:34 +0200 Subject: [PATCH] Tag UserInputAst --- query-grammar/src/infallible.rs | 5 +- query-grammar/src/user_input_ast.rs | 82 ++++++++++++++++++++--------- 2 files changed, 62 insertions(+), 25 deletions(-) diff --git a/query-grammar/src/infallible.rs b/query-grammar/src/infallible.rs index 36a39bda9..414054045 100644 --- a/query-grammar/src/infallible.rs +++ b/query-grammar/src/infallible.rs @@ -367,6 +367,9 @@ mod tests { message: "test error message".to_string(), }; - assert_eq!(serde_json::to_string(&error).unwrap(), "{\"pos\":42,\"message\":\"test error message\"}"); + assert_eq!( + serde_json::to_string(&error).unwrap(), + "{\"pos\":42,\"message\":\"test error message\"}" + ); } } diff --git a/query-grammar/src/user_input_ast.rs b/query-grammar/src/user_input_ast.rs index 25cb01f72..bd079175f 100644 --- a/query-grammar/src/user_input_ast.rs +++ b/query-grammar/src/user_input_ast.rs @@ -196,14 +196,39 @@ impl UserInputBound { } #[derive(PartialEq, Clone, Serialize)] -#[serde(rename_all = "snake_case")] +#[serde(into = "UserInputAstSerde")] pub enum UserInputAst { Clause(Vec<(Option, UserInputAst)>), Boost(Box, f64), + Leaf(Box), +} + +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum UserInputAstSerde { + Clause { + clause: Vec<(Option, UserInputAst)>, + }, + Boost { + underlying: Box, + boost: f64, + }, #[serde(untagged)] Leaf(Box), } +impl From for UserInputAstSerde { + fn from(ast: UserInputAst) -> Self { + match ast { + UserInputAst::Clause(clause) => UserInputAstSerde::Clause { clause }, + UserInputAst::Boost(underlying, boost) => { + UserInputAstSerde::Boost { underlying, boost } + } + UserInputAst::Leaf(leaf) => UserInputAstSerde::Leaf(leaf), + } + } +} + impl UserInputAst { #[must_use] pub fn unary(self, occur: Occur) -> UserInputAst { @@ -357,14 +382,11 @@ mod tests { #[test] fn test_boost_serialization() { let inner_ast = UserInputAst::Leaf(Box::new(UserInputLeaf::All)); - let boost_ast = UserInputAst::Boost( - Box::new(inner_ast), - 2.5, - ); + let boost_ast = UserInputAst::Boost(Box::new(inner_ast), 2.5); let json = serde_json::to_string(&boost_ast).unwrap(); assert_eq!( json, - r#"{"boost":[{"type":"all"},2.5]}"# + r#"{"type":"boost","underlying":{"type":"all"},"boost":2.5}"# ); } @@ -372,40 +394,52 @@ mod tests { fn test_boost_serialization2() { let boost_ast = UserInputAst::Boost( Box::new(UserInputAst::Clause(vec![ - (Some(Occur::Must), UserInputAst::Leaf(Box::new(UserInputLeaf::All))), - (Some(Occur::Should), UserInputAst::Leaf(Box::new(UserInputLeaf::Literal(UserInputLiteral { - field_name: Some("title".to_string()), - phrase: "hello".to_string(), - delimiter: Delimiter::None, - slop: 0, - prefix: false, - })))) + ( + Some(Occur::Must), + UserInputAst::Leaf(Box::new(UserInputLeaf::All)), + ), + ( + Some(Occur::Should), + UserInputAst::Leaf(Box::new(UserInputLeaf::Literal(UserInputLiteral { + field_name: Some("title".to_string()), + phrase: "hello".to_string(), + delimiter: Delimiter::None, + slop: 0, + prefix: false, + }))), + ), ])), 2.5, ); let json = serde_json::to_string(&boost_ast).unwrap(); assert_eq!( json, - r#"{"boost":[{"clause":[["must",{"type":"all"}],["should",{"type":"literal","field_name":"title","phrase":"hello","delimiter":"none","slop":0,"prefix":false}]]},2.5]}"# + r#"{"type":"boost","underlying":{"type":"clause","clause":[["must",{"type":"all"}],["should",{"type":"literal","field_name":"title","phrase":"hello","delimiter":"none","slop":0,"prefix":false}]]},"boost":2.5}"# ); } #[test] fn test_clause_serialization() { let clause = UserInputAst::Clause(vec![ - (Some(Occur::Must), UserInputAst::Leaf(Box::new(UserInputLeaf::All))), - (Some(Occur::Should), UserInputAst::Leaf(Box::new(UserInputLeaf::Literal(UserInputLiteral { - field_name: Some("title".to_string()), - phrase: "hello".to_string(), - delimiter: Delimiter::None, - slop: 0, - prefix: false, - })))) + ( + Some(Occur::Must), + UserInputAst::Leaf(Box::new(UserInputLeaf::All)), + ), + ( + Some(Occur::Should), + UserInputAst::Leaf(Box::new(UserInputLeaf::Literal(UserInputLiteral { + field_name: Some("title".to_string()), + phrase: "hello".to_string(), + delimiter: Delimiter::None, + slop: 0, + prefix: false, + }))), + ), ]); let json = serde_json::to_string(&clause).unwrap(); assert_eq!( json, - r#"{"clause":[["must",{"type":"all"}],["should",{"type":"literal","field_name":"title","phrase":"hello","delimiter":"none","slop":0,"prefix":false}]]}"# + r#"{"type":"clause","clause":[["must",{"type":"all"}],["should",{"type":"literal","field_name":"title","phrase":"hello","delimiter":"none","slop":0,"prefix":false}]]}"# ); } }