This commit is contained in:
tuna2134
2024-09-10 08:52:14 +00:00
parent 72eb1f2aa8
commit 354481fccc
5 changed files with 36 additions and 49 deletions

Binary file not shown.

View File

@@ -223,27 +223,6 @@ impl JTalkProcess {
}
fn kata_to_phoneme_list(mut text: String) -> Result<Vec<String>> {
/*
if set(text).issubset(set(PUNCTUATIONS)):
return list(text)
# `text` がカタカナ(`ー`含む)のみからなるかどうかをチェック
if __KATAKANA_PATTERN.fullmatch(text) is None:
raise ValueError(f"Input must be katakana only: {text}")
def mora2phonemes(mora: str) -> str:
consonant, vowel = MORA_KATA_TO_MORA_PHONEMES[mora]
if consonant is None:
return f" {vowel}"
return f" {consonant} {vowel}"
spaced_phonemes = __MORA_PATTERN.sub(lambda m: mora2phonemes(m.group()), text)
# 長音記号「ー」の処理
long_replacement = lambda m: m.group(1) + (" " + m.group(1)) * len(m.group(2)) # type: ignore
spaced_phonemes = __LONG_PATTERN.sub(long_replacement, spaced_phonemes)
return spaced_phonemes.strip().split(" ")
*/
if PUNCTUATIONS.contains(&text.as_str()) {
return Ok(text.chars().map(|x| x.to_string()).collect());
}

View File

@@ -1,7 +1,7 @@
use sbv2_core::{error, tts};
fn main() -> error::Result<()> {
let text = "おはようございます。";
let text = "眠たい";
let tts_model = tts::TTSModel::new(
"models/debert.onnx",
@@ -12,7 +12,9 @@ fn main() -> error::Result<()> {
let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(text)?;
let style_vector = tts_model.get_style_vector(0, 1.0)?;
tts_model.synthesize(bert_ori.to_owned(), phones, tones, lang_ids, style_vector)?;
let data = tts_model.synthesize(bert_ori.to_owned(), phones, tones, lang_ids, style_vector)?;
std::fs::write("output.wav", data)?;
Ok(())
}

View File

@@ -1,7 +1,8 @@
use crate::error::Result;
use hound::{SampleFormat, WavSpec, WavWriter};
use ndarray::{array, Array1, Array2, Axis};
use ndarray::{array, s, Array1, Array2, Axis};
use ort::{GraphOptimizationLevel, Session};
use std::io::Cursor;
pub fn load_model(model_file: &str) -> Result<Session> {
let session = Session::builder()?
@@ -11,24 +12,6 @@ pub fn load_model(model_file: &str) -> Result<Session> {
Ok(session)
}
fn write_wav(file_path: &str, audio: &[f32], sample_rate: u32) -> Result<()> {
let spec = WavSpec {
channels: 1, // モラルの場合。ステレオなどの場合は2に変更
sample_rate,
bits_per_sample: 16,
sample_format: SampleFormat::Int,
};
let mut writer = WavWriter::create(file_path, spec)?;
for &sample in audio {
let int_sample = (sample * i16::MAX as f32).clamp(i16::MIN as f32, i16::MAX as f32) as i16;
writer.write_sample(int_sample)?;
}
writer.finalize()?;
Ok(())
}
pub fn synthesize(
session: &Session,
bert_ori: Array2<f32>,
@@ -36,7 +19,7 @@ pub fn synthesize(
tones: Array1<i64>,
lang_ids: Array1<i64>,
style_vector: Array1<f32>,
) -> Result<()> {
) -> Result<Vec<u8>> {
let bert = bert_ori.insert_axis(Axis(0));
let x_tst_lengths: Array1<i64> = array![x_tst.shape()[0] as i64];
let x_tst = x_tst.insert_axis(Axis(0));
@@ -53,7 +36,30 @@ pub fn synthesize(
"ja_bert" => style_vector,
}?)?;
let audio_array = outputs.get("output").unwrap().try_extract_tensor::<f32>()?;
write_wav("output.wav", audio_array.as_slice().unwrap(), 44100)?;
Ok(())
let audio_array = outputs
.get("output")
.unwrap()
.try_extract_tensor::<f32>()?
.to_owned();
let buffer = {
let spec = WavSpec {
channels: 1,
sample_rate: 44100,
bits_per_sample: 32,
sample_format: SampleFormat::Float,
};
let mut cursor = Cursor::new(Vec::new());
let mut writer = WavWriter::new(&mut cursor, spec)?;
for i in 0..audio_array.shape()[0] {
let output = audio_array.slice(s![i, 0, ..]).to_vec();
for sample in output {
writer.write_sample(sample)?;
}
}
writer.finalize()?;
cursor.into_inner()
};
Ok(buffer)
}

View File

@@ -107,8 +107,8 @@ impl TTSModel {
tones: Array1<i64>,
lang_ids: Array1<i64>,
style_vector: Array1<f32>,
) -> Result<()> {
model::synthesize(
) -> Result<Vec<u8>> {
let buffer = model::synthesize(
&self.vits2,
bert_ori.to_owned(),
phones,
@@ -116,6 +116,6 @@ impl TTSModel {
lang_ids,
style_vector,
)?;
Ok(())
Ok(buffer)
}
}