Files
sbv2-api/sbv2_core/src/model.rs
tuna2134 41c9d2c092 fixed
2024-09-10 06:00:31 +00:00

30 lines
879 B
Rust

use crate::error::Result;
use ndarray::{array, Array1, Array2, Axis};
use ort::Session;
/// Runs one inference pass of `session` on a single (unbatched) utterance.
///
/// Every per-utterance input is given a leading axis of length 1 via
/// `insert_axis(Axis(0))`, so the session receives a batch of one. The speaker
/// id (`sid`) is fixed to `0`.
///
/// # Arguments
///
/// * `session` - the loaded `ort` session to run.
/// * `bert_ori` - 2-D feature matrix fed as the `bert` input (3-D after the
///   batch axis is added). NOTE(review): the meaning/order of its two axes is
///   not visible here — confirm against the exported model.
/// * `x_tst` - id sequence fed as `x_tst`; its length is also fed as the
///   single-element `x_tst_lengths` input.
/// * `tones` - fed as `tones` (presumably one entry per `x_tst` element —
///   TODO confirm).
/// * `lang_ids` - fed as `language` (presumably one entry per `x_tst`
///   element — TODO confirm).
/// * `style_vector` - 1-D vector, currently fed under the `ja_bert` input name.
///
/// # Errors
///
/// Returns an error if building the input map (`ort::inputs!`) or running the
/// session fails.
///
/// # Output
///
/// The session outputs are currently discarded and `Ok(())` is returned.
pub fn synthesize(
    session: &Session,
    bert_ori: Array2<f32>,
    x_tst: Array1<i64>,
    tones: Array1<i64>,
    lang_ids: Array1<i64>,
    style_vector: Array1<f32>,
) -> Result<()> {
    // Read the sequence length before `x_tst` is consumed by `insert_axis`.
    // `usize -> i64` cannot overflow for any i64 array that fits in memory.
    let x_tst_lengths: Array1<i64> = array![x_tst.len() as i64];

    // Add a leading batch axis of length 1 to every per-utterance input.
    let x_tst = x_tst.insert_axis(Axis(0));
    let tones = tones.insert_axis(Axis(0));
    let lang_ids = lang_ids.insert_axis(Axis(0));
    let bert = bert_ori.insert_axis(Axis(0));
    let style_vector = style_vector.insert_axis(Axis(0));

    // Underscore-prefixed so the (not yet consumed) result does not trigger
    // `unused_variables`; it is still dropped at the end of the function.
    // TODO: extract and return the synthesized output once callers need it.
    let _outputs = session.run(ort::inputs! {
        "x_tst" => x_tst,
        "x_tst_lengths" => x_tst_lengths,
        "sid" => array![0_i64],
        "tones" => tones,
        "language" => lang_ids,
        "bert" => bert,
        // NOTE(review): the style vector is sent under the `ja_bert` name;
        // confirm this matches the exported model's input names (it may
        // expect a dedicated style input instead).
        "ja_bert" => style_vector,
    }?)?;
    Ok(())
}