diff --git a/output.wav b/output.wav index 7883ed9..0a301dd 100644 Binary files a/output.wav and b/output.wav differ diff --git a/sbv2_core/src/jtalk.rs b/sbv2_core/src/jtalk.rs index 201a651..c553abb 100644 --- a/sbv2_core/src/jtalk.rs +++ b/sbv2_core/src/jtalk.rs @@ -223,27 +223,6 @@ impl JTalkProcess { } fn kata_to_phoneme_list(mut text: String) -> Result> { - /* - if set(text).issubset(set(PUNCTUATIONS)): - return list(text) - # `text` がカタカナ(`ー`含む)のみからなるかどうかをチェック - if __KATAKANA_PATTERN.fullmatch(text) is None: - raise ValueError(f"Input must be katakana only: {text}") - - def mora2phonemes(mora: str) -> str: - consonant, vowel = MORA_KATA_TO_MORA_PHONEMES[mora] - if consonant is None: - return f" {vowel}" - return f" {consonant} {vowel}" - - spaced_phonemes = __MORA_PATTERN.sub(lambda m: mora2phonemes(m.group()), text) - - # 長音記号「ー」の処理 - long_replacement = lambda m: m.group(1) + (" " + m.group(1)) * len(m.group(2)) # type: ignore - spaced_phonemes = __LONG_PATTERN.sub(long_replacement, spaced_phonemes) - - return spaced_phonemes.strip().split(" ") - */ if PUNCTUATIONS.contains(&text.as_str()) { return Ok(text.chars().map(|x| x.to_string()).collect()); } diff --git a/sbv2_core/src/main.rs b/sbv2_core/src/main.rs index 63ad56b..b0eb401 100644 --- a/sbv2_core/src/main.rs +++ b/sbv2_core/src/main.rs @@ -1,7 +1,7 @@ use sbv2_core::{error, tts}; fn main() -> error::Result<()> { - let text = "おはようございます。"; + let text = "眠たい"; let tts_model = tts::TTSModel::new( "models/debert.onnx", @@ -12,7 +12,9 @@ fn main() -> error::Result<()> { let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(text)?; let style_vector = tts_model.get_style_vector(0, 1.0)?; - tts_model.synthesize(bert_ori.to_owned(), phones, tones, lang_ids, style_vector)?; + let data = tts_model.synthesize(bert_ori.to_owned(), phones, tones, lang_ids, style_vector)?; + + std::fs::write("output.wav", data)?; Ok(()) } diff --git a/sbv2_core/src/model.rs b/sbv2_core/src/model.rs index d5c0eb2..ae6c535 100644 --- a/sbv2_core/src/model.rs +++ b/sbv2_core/src/model.rs @@ -1,7 +1,8 @@ use crate::error::Result; use hound::{SampleFormat, WavSpec, WavWriter}; -use ndarray::{array, Array1, Array2, Axis}; +use ndarray::{array, s, Array1, Array2, Axis}; use ort::{GraphOptimizationLevel, Session}; +use std::io::Cursor; pub fn load_model(model_file: &str) -> Result { let session = Session::builder()? @@ -11,24 +12,6 @@ pub fn load_model(model_file: &str) -> Result { Ok(session) } -fn write_wav(file_path: &str, audio: &[f32], sample_rate: u32) -> Result<()> { - let spec = WavSpec { - channels: 1, // モノラルの場合。ステレオなどの場合は2に変更 - sample_rate, - bits_per_sample: 16, - sample_format: SampleFormat::Int, - }; - - let mut writer = WavWriter::create(file_path, spec)?; - for &sample in audio { - let int_sample = (sample * i16::MAX as f32).clamp(i16::MIN as f32, i16::MAX as f32) as i16; - writer.write_sample(int_sample)?; - } - writer.finalize()?; - - Ok(()) -} - pub fn synthesize( session: &Session, bert_ori: Array2, @@ -36,7 +19,7 @@ pub fn synthesize( tones: Array1, lang_ids: Array1, style_vector: Array1, -) -> Result<()> { +) -> Result> { let bert = bert_ori.insert_axis(Axis(0)); let x_tst_lengths: Array1 = array![x_tst.shape()[0] as i64]; let x_tst = x_tst.insert_axis(Axis(0)); @@ -53,7 +36,30 @@ pub fn synthesize( "ja_bert" => style_vector, }?)?; - let audio_array = outputs.get("output").unwrap().try_extract_tensor::()?; - write_wav("output.wav", audio_array.as_slice().unwrap(), 44100)?; - Ok(()) + let audio_array = outputs + .get("output") + .unwrap() + .try_extract_tensor::()? + .to_owned(); + + let buffer = { + let spec = WavSpec { + channels: 1, + sample_rate: 44100, + bits_per_sample: 32, + sample_format: SampleFormat::Float, + }; + let mut cursor = Cursor::new(Vec::new()); + let mut writer = WavWriter::new(&mut cursor, spec)?; + for i in 0..audio_array.shape()[0] { + let output = audio_array.slice(s![i, 0, ..]).to_vec(); + for sample in output { + writer.write_sample(sample)?; + } + } + writer.finalize()?; + cursor.into_inner() + }; + + Ok(buffer) } diff --git a/sbv2_core/src/tts.rs b/sbv2_core/src/tts.rs index 26b7434..a7537d8 100644 --- a/sbv2_core/src/tts.rs +++ b/sbv2_core/src/tts.rs @@ -107,8 +107,8 @@ impl TTSModel { tones: Array1, lang_ids: Array1, style_vector: Array1, - ) -> Result<()> { - model::synthesize( + ) -> Result> { + let buffer = model::synthesize( &self.vits2, bert_ori.to_owned(), phones, @@ -116,6 +116,6 @@ impl TTSModel { lang_ids, style_vector, )?; - Ok(()) + Ok(buffer) } }