diff --git a/convert/convert_model.py b/convert/convert_model.py index 70b7765..e5badcc 100644 --- a/convert/convert_model.py +++ b/convert/convert_model.py @@ -94,7 +94,7 @@ model = get_net_g( ) -def forward(x, x_len, sid, tone, lang, bert, style, length_scale, sdp_ratio): +def forward(x, x_len, sid, tone, lang, bert, style, length_scale, sdp_ratio, noise_scale, noise_scale_w): return model.infer( x, x_len, @@ -105,6 +105,8 @@ def forward(x, x_len, sid, tone, lang, bert, style, length_scale, sdp_ratio): style, sdp_ratio=sdp_ratio, length_scale=length_scale, + noise_scale=noise_scale, + noise_scale_w=noise_scale_w, ) @@ -122,6 +124,8 @@ torch.onnx.export( style_vec_tensor, torch.tensor(1.0), torch.tensor(0.0), + torch.tensor(0.6777), + torch.tensor(0.8), ), f"../models/model_{out_name}.onnx", verbose=True, @@ -144,6 +148,8 @@ torch.onnx.export( "style_vec", "length_scale", "sdp_ratio", + "noise_scale", + "noise_scale_w" ], output_names=["output"], ) diff --git a/sbv2_core/src/model.rs b/sbv2_core/src/model.rs index a851909..603ea83 100644 --- a/sbv2_core/src/model.rs +++ b/sbv2_core/src/model.rs @@ -58,6 +58,8 @@ pub fn synthesize( style_vector: Array1, sdp_ratio: f32, length_scale: f32, + noise_scale: f32, + noise_scale_w: f32, ) -> Result> { let bert = bert_ori.insert_axis(Axis(0)); let x_tst_lengths: Array1 = array![x_tst.shape()[0] as i64]; @@ -75,6 +77,8 @@ pub fn synthesize( "style_vec" => style_vector, "sdp_ratio" => array![sdp_ratio], "length_scale" => array![length_scale], + "noise_scale" => array![noise_scale], + "noise_scale_w" => array![noise_scale_w] }?)?; let audio_array = outputs["output"] diff --git a/sbv2_core/src/tts.rs b/sbv2_core/src/tts.rs index 4b12961..0057260 100644 --- a/sbv2_core/src/tts.rs +++ b/sbv2_core/src/tts.rs @@ -310,6 +310,8 @@ impl TTSModelHolder { style_vector.clone(), options.sdp_ratio, options.length_scale, + 0.677, + 0.8, )?; audios.push(audio.clone()); if i != texts.len() - 1 { @@ -332,6 +334,8 @@ impl TTSModelHolder { style_vector, options.sdp_ratio, options.length_scale, + 0.677, + 0.8, )? }; tts_util::array_to_vec(audio_array)