This commit is contained in:
Googlefan
2025-02-22 08:00:17 +00:00
parent 14d631eeaa
commit 506ee4d883
60 changed files with 927 additions and 517 deletions

View File

@@ -0,0 +1,24 @@
# Crate manifest for the HTTP inference server (axum-based REST API).
[package]
name = "sbv2_api"
# NOTE(review): sibling crates in this commit use 0.2.0-alpha2 — confirm whether
# this crate is intentionally one pre-release behind.
version = "0.2.0-alpha"
edition = "2021"
[dependencies]
anyhow.workspace = true
axum = "0.7.5"
dotenvy.workspace = true
env_logger.workspace = true
log = "0.4.22"
sbv2_core = { version = "0.2.0-alpha2", path = "../sbv2_core" }
serde = { version = "1.0.210", features = ["derive"] }
tokio = { version = "1.40.0", features = ["full"] }
utoipa = { version = "5.0.0", features = ["axum_extras"] }
utoipa-scalar = { version = "0.2.0", features = ["axum"] }
# Each feature simply forwards to the matching execution-provider feature of
# sbv2_core; this crate adds no backend code of its own.
[features]
coreml = ["sbv2_core/coreml"]
cuda = ["sbv2_core/cuda"]
cuda_tf32 = ["sbv2_core/cuda_tf32"]
dynamic = ["sbv2_core/dynamic"]
directml = ["sbv2_core/directml"]
tensorrt = ["sbv2_core/tensorrt"]

5
crates/sbv2_api/build.rs Normal file
View File

@@ -0,0 +1,5 @@
/// Cargo build script: when the `coreml` feature is enabled, ask the linker to
/// pull in Apple's compiler runtime (needed by the CoreML execution provider).
fn main() {
    // `cfg!(feature = "...")` does NOT reflect the package's features inside a
    // build script (the build script is compiled without them); Cargo instead
    // exposes each activated feature as a `CARGO_FEATURE_<NAME>` env var.
    if std::env::var_os("CARGO_FEATURE_COREML").is_some() {
        println!("cargo:rustc-link-arg=-fapple-link-rtlib");
    }
}

View File

@@ -0,0 +1,27 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
/// Result alias for handlers: any `anyhow`-compatible error becomes an [`AppError`].
pub type AppResult<T> = std::result::Result<T, AppError>;

/// Newtype over `anyhow::Error` so `?` in handlers maps failures to HTTP 500.
pub struct AppError(anyhow::Error);

impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        // Render the error as a plain-text 500 response.
        let body = format!("Something went wrong: {}", self.0);
        (StatusCode::INTERNAL_SERVER_ERROR, body).into_response()
    }
}

impl<E> From<E> for AppError
where
    E: Into<anyhow::Error>,
{
    // Blanket conversion: anything convertible to anyhow::Error works with `?`.
    fn from(err: E) -> Self {
        AppError(err.into())
    }
}

174
crates/sbv2_api/src/main.rs Normal file
View File

@@ -0,0 +1,174 @@
use axum::{
extract::State,
http::header::CONTENT_TYPE,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use sbv2_core::tts::{SynthesizeOptions, TTSModelHolder};
use serde::Deserialize;
use std::env;
use std::sync::Arc;
use tokio::fs;
use tokio::sync::Mutex;
use utoipa::{OpenApi, ToSchema};
use utoipa_scalar::{Scalar, Servable};
mod error;
use crate::error::AppResult;
/// OpenAPI document root: registers the two handler paths and the request
/// schema so utoipa can generate the spec served at /docs.
#[derive(OpenApi)]
#[openapi(paths(models, synthesize), components(schemas(SynthesizeRequest)))]
struct ApiDoc;
/// GET /models — return the identifiers of all models known to the holder.
#[utoipa::path(
    get,
    path = "/models",
    responses(
        (status = 200, description = "Return model list", body = Vec<String>),
    )
)]
async fn models(State(state): State<AppState>) -> AppResult<impl IntoResponse> {
    // Briefly locks the shared holder just to snapshot the model list.
    Ok(Json(state.tts_model.lock().await.models()))
}
fn sdp_default() -> f32 {
0.0
}
fn length_default() -> f32 {
1.0
}
/// JSON body of POST /synthesize.
#[derive(Deserialize, ToSchema)]
struct SynthesizeRequest {
    // Text to synthesize.
    text: String,
    // Identifier of a previously loaded model.
    ident: String,
    // Optional; defaults to 0.0 when omitted from the request.
    #[serde(default = "sdp_default")]
    sdp_ratio: f32,
    // Optional; defaults to 1.0 when omitted from the request.
    #[serde(default = "length_default")]
    length_scale: f32,
}
/// POST /synthesize — run TTS for `text` with model `ident` and return the
/// resulting WAV bytes.
#[utoipa::path(
    post,
    path = "/synthesize",
    request_body = SynthesizeRequest,
    responses(
        (status = 200, description = "Return audio/wav", body = Vec<u8>, content_type = "audio/wav")
    )
)]
async fn synthesize(
    State(state): State<AppState>,
    Json(SynthesizeRequest {
        text,
        ident,
        sdp_ratio,
        length_scale,
    }): Json<SynthesizeRequest>,
) -> AppResult<impl IntoResponse> {
    log::debug!("processing request: text={text}, ident={ident}, sdp_ratio={sdp_ratio}, length_scale={length_scale}");
    // Scope the mutex guard to the synthesis call only, so the model holder is
    // released before the response is built.
    let buffer = {
        let mut tts_model = state.tts_model.lock().await;
        // Style id 0 is hard-coded here; per-request styles are not exposed.
        tts_model.easy_synthesize(
            &ident,
            &text,
            0,
            SynthesizeOptions {
                sdp_ratio,
                length_scale,
                ..Default::default()
            },
        )?
    };
    Ok(([(CONTENT_TYPE, "audio/wav")], buffer))
}
/// Shared server state: the model holder behind an async mutex so handlers on
/// different connections serialize access to the ONNX sessions.
#[derive(Clone)]
struct AppState {
    tts_model: Arc<Mutex<TTSModelHolder>>,
}
impl AppState {
    /// Build the shared server state.
    ///
    /// Reads BERT_MODEL_PATH and TOKENIZER_PATH, then scans MODELS_PATH
    /// (default "models") and loads every `<ident>.sbv2` file and every
    /// `model_<ident>.onnx` + `style_vectors_<ident>.json` pair. Failures on
    /// individual models are logged and skipped; only missing core assets
    /// abort startup.
    pub async fn new() -> anyhow::Result<Self> {
        let mut tts_model = TTSModelHolder::new(
            &fs::read(env::var("BERT_MODEL_PATH")?).await?,
            &fs::read(env::var("TOKENIZER_PATH")?).await?,
            // NOTE(review): "MODElS" (lowercase l) looks like a typo, but the
            // same spelling appears elsewhere in this commit; renaming would
            // break existing .env files, so it is kept as-is.
            env::var("HOLDER_MAX_LOADED_MODElS")
                .ok()
                .and_then(|x| x.parse().ok()),
        )?;
        let models = env::var("MODELS_PATH").unwrap_or("models".to_string());
        let mut dir = fs::read_dir(&models).await?;
        let mut entries = vec![];
        while let Ok(Some(e)) = dir.next_entry().await {
            let name = e.file_name().to_string_lossy().to_string();
            if let Some(ident) = name
                .strip_prefix("model_")
                .and_then(|rest| rest.strip_suffix(".onnx"))
            {
                // ONNX models are deferred: their style vectors are read in
                // the second pass below. strip_prefix/strip_suffix replaces
                // the old char-vec slice indexed with byte lengths, which
                // panicked on non-ASCII identifiers.
                entries.push(ident.to_string());
            } else if let Some(entry) = name.strip_suffix(".sbv2") {
                log::info!("Try loading: {entry}");
                let sbv2_bytes = match fs::read(format!("{models}/{entry}.sbv2")).await {
                    Ok(b) => b,
                    Err(e) => {
                        log::warn!("Error loading sbv2_bytes from file {entry}: {e}");
                        continue;
                    }
                };
                if let Err(e) = tts_model.load_sbv2file(entry, sbv2_bytes) {
                    // Skip so a failed model is not reported as loaded.
                    log::warn!("Error loading {entry}: {e}");
                    continue;
                }
                log::info!("Loaded: {entry}");
            }
        }
        for entry in entries {
            log::info!("Try loading: {entry}");
            let style_vectors_bytes =
                match fs::read(format!("{models}/style_vectors_{entry}.json")).await {
                    Ok(b) => b,
                    Err(e) => {
                        log::warn!("Error loading style_vectors_bytes from file {entry}: {e}");
                        continue;
                    }
                };
            let vits2_bytes = match fs::read(format!("{models}/model_{entry}.onnx")).await {
                Ok(b) => b,
                Err(e) => {
                    log::warn!("Error loading vits2_bytes from file {entry}: {e}");
                    continue;
                }
            };
            if let Err(e) = tts_model.load(&entry, style_vectors_bytes, vits2_bytes) {
                // Skip so a failed model is not reported as loaded.
                log::warn!("Error loading {entry}: {e}");
                continue;
            }
            log::info!("Loaded: {entry}");
        }
        Ok(Self {
            tts_model: Arc::new(Mutex::new(tts_model)),
        })
    }
}
/// Server entry point: loads .env, initializes logging, wires the routes
/// (/, /synthesize, /models, /docs) and serves on ADDR (default 0.0.0.0:3000).
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Load .env first so env_logger and AppState::new can see its variables.
    dotenvy::dotenv_override().ok();
    env_logger::init();
    let app = Router::new()
        .route("/", get(|| async { "Hello, World!" }))
        .route("/synthesize", post(synthesize))
        .route("/models", get(models))
        // All model loading happens here, before the listener is bound.
        .with_state(AppState::new().await?)
        .merge(Scalar::with_url("/docs", ApiDoc::openapi()));
    let addr = env::var("ADDR").unwrap_or("0.0.0.0:3000".to_string());
    let listener = tokio::net::TcpListener::bind(&addr).await?;
    log::info!("Listening on {addr}");
    axum::serve(listener, app).await?;
    Ok(())
}

View File

@@ -0,0 +1,15 @@
# Crate manifest for the Python bindings (built with maturin, see pyproject.toml).
[package]
name = "sbv2_bindings"
version = "0.2.0-alpha2"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "sbv2_bindings"
# cdylib only: this crate is consumed as a Python extension module, not as a Rust lib.
crate-type = ["cdylib"]
[dependencies]
anyhow.workspace = true
ndarray.workspace = true
pyo3 = { version = "0.22.0", features = ["anyhow"] }
sbv2_core = { version = "0.2.0-alpha2", path = "../sbv2_core" }

View File

@@ -0,0 +1,17 @@
# Maturin build configuration for the sbv2_bindings Python package.
[build-system]
requires = ["maturin>=1.7,<2.0"]
build-backend = "maturin"
[project]
name = "sbv2_bindings"
requires-python = ">=3.8"
classifiers = [
    "Programming Language :: Rust",
    "Programming Language :: Python :: Implementation :: CPython",
    "Programming Language :: Python :: Implementation :: PyPy",
]
# Version is taken from Cargo.toml by maturin.
dynamic = ["version"]
[tool.maturin]
features = ["pyo3/extension-module"]
strip = true

View File

@@ -0,0 +1,11 @@
use pyo3::prelude::*;
mod sbv2;
pub mod style;
/// sbv2 bindings module
///
/// Python extension entry point: registers the TTSModel and StyleVector
/// classes on the `sbv2_bindings` module.
#[pymodule]
fn sbv2_bindings(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<sbv2::TTSModel>()?;
    m.add_class::<style::StyleVector>()?;
    Ok(())
}

View File

@@ -0,0 +1,164 @@
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use sbv2_core::tts::{SynthesizeOptions, TTSModelHolder};
use crate::style::StyleVector;
use std::fs;
/// TTSModel class
///
/// Class used for speech synthesis.
///
/// Parameters
/// ----------
/// bert_model_bytes : bytes
///     Binary data of the BERT model
/// tokenizer_bytes : bytes
///     Binary data of the tokenizer
#[pyclass]
pub struct TTSModel {
    // Owns the BERT/tokenizer sessions plus every loaded voice model.
    pub model: TTSModelHolder,
}
#[pymethods]
impl TTSModel {
    /// Create a TTSModel from in-memory model data.
    ///
    /// Parameters
    /// ----------
    /// bert_model_bytes : bytes
    ///     Binary data of the BERT model
    /// tokenizer_bytes : bytes
    ///     Binary data of the tokenizer
    /// max_loaded_models : int | None
    ///     Number of models allowed to reside in VRAM at the same time
    #[pyo3(signature = (bert_model_bytes, tokenizer_bytes, max_loaded_models=None))]
    #[new]
    fn new(
        bert_model_bytes: Vec<u8>,
        tokenizer_bytes: Vec<u8>,
        max_loaded_models: Option<usize>,
    ) -> anyhow::Result<Self> {
        Ok(Self {
            model: TTSModelHolder::new(bert_model_bytes, tokenizer_bytes, max_loaded_models)?,
        })
    }
    /// Create a TTSModel instance from file paths.
    ///
    /// Parameters
    /// ----------
    /// bert_model_path : str
    ///     Path to the BERT model
    /// tokenizer_path : str
    ///     Path to the tokenizer
    /// max_loaded_models : int | None
    ///     Number of models allowed to reside in VRAM at the same time
    #[pyo3(signature = (bert_model_path, tokenizer_path, max_loaded_models=None))]
    #[staticmethod]
    fn from_path(
        bert_model_path: String,
        tokenizer_path: String,
        max_loaded_models: Option<usize>,
    ) -> anyhow::Result<Self> {
        Ok(Self {
            model: TTSModelHolder::new(
                fs::read(bert_model_path)?,
                fs::read(tokenizer_path)?,
                max_loaded_models,
            )?,
        })
    }
    /// Load an SBV2 file.
    ///
    /// Parameters
    /// ----------
    /// ident : str
    ///     Identifier
    /// sbv2file_bytes : bytes
    ///     Binary data of the SBV2 file
    fn load_sbv2file(&mut self, ident: String, sbv2file_bytes: Vec<u8>) -> anyhow::Result<()> {
        self.model.load_sbv2file(ident, sbv2file_bytes)?;
        Ok(())
    }
    /// Load an SBV2 file from a path.
    ///
    /// Parameters
    /// ----------
    /// ident : str
    ///     Identifier
    /// sbv2file_path : str
    ///     Path to the SBV2 file
    fn load_sbv2file_from_path(
        &mut self,
        ident: String,
        sbv2file_path: String,
    ) -> anyhow::Result<()> {
        self.model.load_sbv2file(ident, fs::read(sbv2file_path)?)?;
        Ok(())
    }
    /// Fetch a style vector.
    ///
    /// Parameters
    /// ----------
    /// ident : str
    ///     Identifier
    /// style_id : int
    ///     Style ID
    /// weight : float
    ///     Weight
    ///
    /// Returns
    /// -------
    /// style_vector : StyleVector
    ///     The style vector
    fn get_style_vector(
        &mut self,
        ident: String,
        style_id: i32,
        weight: f32,
    ) -> anyhow::Result<StyleVector> {
        Ok(StyleVector::new(
            self.model.get_style_vector(ident, style_id, weight)?,
        ))
    }
    /// Synthesize speech from text.
    ///
    /// Parameters
    /// ----------
    /// text : str
    ///     Input text
    /// ident : str
    ///     Identifier
    /// style_id : int
    ///     Style ID
    /// sdp_ratio : float
    ///     SDP ratio
    /// length_scale : float
    ///     Speech-length scale
    ///
    /// Returns
    /// -------
    /// voice_data : bytes
    ///     WAV audio data
    fn synthesize<'p>(
        &'p mut self,
        py: Python<'p>,
        text: String,
        ident: String,
        style_id: i32,
        sdp_ratio: f32,
        length_scale: f32,
    ) -> anyhow::Result<Bound<'p, PyBytes>> {
        let data = self.model.easy_synthesize(
            ident.as_str(),
            &text,
            style_id,
            SynthesizeOptions {
                sdp_ratio,
                length_scale,
                ..Default::default()
            },
        )?;
        Ok(PyBytes::new_bound(py, &data))
    }
    // Unload a model; returns True if the identifier was actually loaded.
    fn unload(&mut self, ident: String) -> bool {
        self.model.unload(ident)
    }
}

View File

@@ -0,0 +1,19 @@
use ndarray::Array1;
use pyo3::prelude::*;
/// StyleVector class
///
/// Class representing a style (speaker-embedding) vector.
#[pyclass]
#[derive(Clone)]
pub struct StyleVector(Array1<f32>);

impl StyleVector {
    /// Wrap a raw ndarray style vector.
    pub fn new(data: Array1<f32>) -> Self {
        StyleVector(data)
    }
    /// Return a copy of the underlying vector.
    pub fn get(&self) -> Array1<f32> {
        self.0.clone()
    }
}

View File

@@ -0,0 +1,38 @@
# Core inference crate: text processing plus ONNX-based synthesis.
[package]
name = "sbv2_core"
description = "Style-Bert-VITSの推論ライブラリ"
version = "0.2.0-alpha2"
edition = "2021"
license = "MIT"
readme = "../README.md"
repository = "https://github.com/tuna2134/sbv2-api"
documentation = "https://docs.rs/sbv2_core"
[dependencies]
anyhow.workspace = true
dotenvy.workspace = true
env_logger.workspace = true
hound = "3.5.1"
jpreprocess = { version = "0.12.0", features = ["naist-jdic"] }
ndarray.workspace = true
num_cpus = "1.16.0"
once_cell.workspace = true
# Optional: only pulled in by the `std` feature; everything ONNX-related hangs off it.
ort = { git = "https://github.com/pykeio/ort.git", version = "2.0.0-rc.9", optional = true }
regex = "1.10.6"
serde = { version = "1.0.210", features = ["derive"] }
serde_json = "1.0.128"
tar = "0.4.41"
thiserror = "2.0.11"
tokenizers = { version = "0.21.0", default-features = false }
zstd = "0.13.2"
# `std` gates the ort runtime; the accelerator features each add an execution
# provider on top of it. `no_std` swaps tokenizers to its wasm-friendly backend.
[features]
cuda = ["ort/cuda", "std"]
cuda_tf32 = ["std", "cuda"]
std = ["dep:ort", "tokenizers/progressbar", "tokenizers/onig", "tokenizers/esaxx_fast"]
dynamic = ["ort/load-dynamic", "std"]
directml = ["ort/directml", "std"]
tensorrt = ["ort/tensorrt", "std"]
coreml = ["ort/coreml", "std"]
default = ["std"]
no_std = ["tokenizers/unstable_wasm"]

View File

@@ -0,0 +1,196 @@
# Generates src/mora_list.json: katakana mora -> (consonant, vowel) tables.
# NOTE(review): in the reviewed copy every single-codepoint katakana literal had
# been stripped to "" by text extraction; they are restored here from the
# (consonant, vowel) pairs and the surviving generated JSON, which preserves
# the same order. Verify against the upstream Style-Bert-VITS2 mora list.
import json

# Minimal mora set. Order matters downstream: consumers sort keys longest-first
# so multi-character moras are matched before their single-character prefixes.
__MORA_LIST_MINIMUM: list[tuple[str, str | None, str]] = [
    ("ヴォ", "v", "o"),
    ("ヴェ", "v", "e"),
    ("ヴィ", "v", "i"),
    ("ヴァ", "v", "a"),
    ("ヴ", "v", "u"),
    ("ン", None, "N"),
    ("ワ", "w", "a"),
    ("ロ", "r", "o"),
    ("レ", "r", "e"),
    ("ル", "r", "u"),
    ("リョ", "ry", "o"),
    ("リュ", "ry", "u"),
    ("リャ", "ry", "a"),
    ("リェ", "ry", "e"),
    ("リ", "r", "i"),
    ("ラ", "r", "a"),
    ("ヨ", "y", "o"),
    ("ユ", "y", "u"),
    ("ヤ", "y", "a"),
    ("モ", "m", "o"),
    ("メ", "m", "e"),
    ("ム", "m", "u"),
    ("ミョ", "my", "o"),
    ("ミュ", "my", "u"),
    ("ミャ", "my", "a"),
    ("ミェ", "my", "e"),
    ("ミ", "m", "i"),
    ("マ", "m", "a"),
    ("ポ", "p", "o"),
    ("ボ", "b", "o"),
    ("ホ", "h", "o"),
    ("ペ", "p", "e"),
    ("ベ", "b", "e"),
    ("ヘ", "h", "e"),
    ("プ", "p", "u"),
    ("ブ", "b", "u"),
    ("フォ", "f", "o"),
    ("フェ", "f", "e"),
    ("フィ", "f", "i"),
    ("ファ", "f", "a"),
    ("フ", "f", "u"),
    ("ピョ", "py", "o"),
    ("ピュ", "py", "u"),
    ("ピャ", "py", "a"),
    ("ピェ", "py", "e"),
    ("ピ", "p", "i"),
    ("ビョ", "by", "o"),
    ("ビュ", "by", "u"),
    ("ビャ", "by", "a"),
    ("ビェ", "by", "e"),
    ("ビ", "b", "i"),
    ("ヒョ", "hy", "o"),
    ("ヒュ", "hy", "u"),
    ("ヒャ", "hy", "a"),
    ("ヒェ", "hy", "e"),
    ("ヒ", "h", "i"),
    ("パ", "p", "a"),
    ("バ", "b", "a"),
    ("ハ", "h", "a"),
    ("ノ", "n", "o"),
    ("ネ", "n", "e"),
    ("ヌ", "n", "u"),
    ("ニョ", "ny", "o"),
    ("ニュ", "ny", "u"),
    ("ニャ", "ny", "a"),
    ("ニェ", "ny", "e"),
    ("ニ", "n", "i"),
    ("ナ", "n", "a"),
    ("ドゥ", "d", "u"),
    ("ド", "d", "o"),
    ("トゥ", "t", "u"),
    ("ト", "t", "o"),
    ("デョ", "dy", "o"),
    ("デュ", "dy", "u"),
    ("デャ", "dy", "a"),
    # ("デェ", "dy", "e"),
    ("ディ", "d", "i"),
    ("デ", "d", "e"),
    ("テョ", "ty", "o"),
    ("テュ", "ty", "u"),
    ("テャ", "ty", "a"),
    ("ティ", "t", "i"),
    ("テ", "t", "e"),
    ("ツォ", "ts", "o"),
    ("ツェ", "ts", "e"),
    ("ツィ", "ts", "i"),
    ("ツァ", "ts", "a"),
    ("ツ", "ts", "u"),
    ("ッ", None, "q"),  # changed from "cl" to "q"
    ("チョ", "ch", "o"),
    ("チュ", "ch", "u"),
    ("チャ", "ch", "a"),
    ("チェ", "ch", "e"),
    ("チ", "ch", "i"),
    ("ダ", "d", "a"),
    ("タ", "t", "a"),
    ("ゾ", "z", "o"),
    ("ソ", "s", "o"),
    ("ゼ", "z", "e"),
    ("セ", "s", "e"),
    ("ズィ", "z", "i"),
    ("ズ", "z", "u"),
    ("スィ", "s", "i"),
    ("ス", "s", "u"),
    ("ジョ", "j", "o"),
    ("ジュ", "j", "u"),
    ("ジャ", "j", "a"),
    ("ジェ", "j", "e"),
    ("ジ", "j", "i"),
    ("ショ", "sh", "o"),
    ("シュ", "sh", "u"),
    ("シャ", "sh", "a"),
    ("シェ", "sh", "e"),
    ("シ", "sh", "i"),
    ("ザ", "z", "a"),
    ("サ", "s", "a"),
    ("ゴ", "g", "o"),
    ("コ", "k", "o"),
    ("ゲ", "g", "e"),
    ("ケ", "k", "e"),
    ("グヮ", "gw", "a"),
    ("グ", "g", "u"),
    ("クヮ", "kw", "a"),
    ("ク", "k", "u"),
    ("ギョ", "gy", "o"),
    ("ギュ", "gy", "u"),
    ("ギャ", "gy", "a"),
    ("ギェ", "gy", "e"),
    ("ギ", "g", "i"),
    ("キョ", "ky", "o"),
    ("キュ", "ky", "u"),
    ("キャ", "ky", "a"),
    ("キェ", "ky", "e"),
    ("キ", "k", "i"),
    ("ガ", "g", "a"),
    ("カ", "k", "a"),
    ("オ", None, "o"),
    ("エ", None, "e"),
    ("ウォ", "w", "o"),
    ("ウェ", "w", "e"),
    ("ウィ", "w", "i"),
    ("ウ", None, "u"),
    ("イェ", "y", "e"),
    ("イ", None, "i"),
    ("ア", None, "a"),
]
# Rare/archaic and small kana that can still appear in readings.
__MORA_LIST_ADDITIONAL: list[tuple[str, str | None, str]] = [
    ("ヴョ", "by", "o"),
    ("ヴュ", "by", "u"),
    ("ヴャ", "by", "a"),
    ("ヲ", None, "o"),
    ("ヱ", None, "e"),
    ("ヰ", None, "i"),
    ("ヮ", "w", "a"),
    ("ョ", "y", "o"),
    ("ュ", "y", "u"),
    ("ヅ", "z", "u"),
    ("ヂ", "j", "i"),
    ("ヶ", "k", "e"),
    ("ャ", "y", "a"),
    ("ォ", None, "o"),
    ("ェ", None, "e"),
    ("ゥ", None, "u"),
    ("ィ", None, "i"),
    ("ァ", None, "a"),
]
data = {"minimum": [], "additional": []}
for mora, consonant, vowel in __MORA_LIST_MINIMUM:
    data["minimum"].append(
        {
            "mora": mora,
            "consonant": consonant,
            "vowel": vowel,
        }
    )
for mora, consonant, vowel in __MORA_LIST_ADDITIONAL:
    data["additional"].append(
        {
            "mora": mora,
            "consonant": consonant,
            "vowel": vowel,
        }
    )
# ensure_ascii=False keeps the katakana readable in the generated file.
with open("src/mora_list.json", "w") as f:
    json.dump(data, f, ensure_ascii=False, indent=4)

View File

@@ -0,0 +1,24 @@
use crate::error::Result;
use ndarray::{Array2, Ix2};
use ort::session::Session;
use ort::value::TensorRef;
/// Run the BERT model on one tokenized sentence.
///
/// Both inputs are fed as (1, len) i64 tensors; the session's "output" tensor
/// is returned as an owned 2-D f32 array.
pub fn predict(
    session: &mut Session,
    token_ids: Vec<i64>,
    attention_masks: Vec<i64>,
) -> Result<Array2<f32>> {
    let outputs = session.run(
        ort::inputs! {
            "input_ids" => TensorRef::from_array_view((vec![1, token_ids.len() as i64], token_ids.as_slice()))?,
            "attention_mask" => TensorRef::from_array_view((vec![1, attention_masks.len() as i64], attention_masks.as_slice()))?,
        }
    )?;
    // to_owned detaches the result from the session-owned output buffer.
    let output = outputs["output"]
        .try_extract_tensor::<f32>()?
        .into_dimensionality::<Ix2>()?
        .to_owned();
    Ok(output)
}

View File

@@ -0,0 +1,30 @@
use thiserror::Error;
#[derive(Error, Debug)]
pub enum Error {
#[error("Tokenizer error: {0}")]
TokenizerError(#[from] tokenizers::Error),
#[error("JPreprocess error: {0}")]
JPreprocessError(#[from] jpreprocess::error::JPreprocessError),
#[error("Lindera error: {0}")]
LinderaError(String),
#[cfg(feature = "std")]
#[error("ONNX error: {0}")]
OrtError(#[from] ort::Error),
#[error("NDArray error: {0}")]
NdArrayError(#[from] ndarray::ShapeError),
#[error("Value error: {0}")]
ValueError(String),
#[error("Serde_json error: {0}")]
SerdeJsonError(#[from] serde_json::Error),
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
#[error("hound error: {0}")]
HoundError(#[from] hound::Error),
#[error("model not found error")]
ModelNotFoundError(String),
#[error("other")]
OtherError(String),
}
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -0,0 +1,406 @@
use crate::error::{Error, Result};
use crate::mora::{MORA_KATA_TO_MORA_PHONEMES, VOWELS};
use crate::norm::{replace_punctuation, PUNCTUATIONS};
use jpreprocess::{kind, DefaultTokenizer, JPreprocess, SystemDictionaryConfig, UserDictionary};
use once_cell::sync::Lazy;
use regex::Regex;
use std::cmp::Reverse;
use std::collections::HashSet;
use std::sync::Arc;
type JPreprocessType = JPreprocess<DefaultTokenizer>;
/// Build a JPreprocess instance from the bundled NAIST system dictionary plus
/// the user dictionary compiled into the crate at ./dic/all.dic/all.bin.
fn initialize_jtalk() -> Result<JPreprocessType> {
    let sdic =
        SystemDictionaryConfig::Bundled(kind::JPreprocessDictionaryKind::NaistJdic).load()?;
    // Lindera's error type is not exposed here, so it is stringified.
    let u = UserDictionary::load(include_bytes!("./dic/all.dic/all.bin"))
        .map_err(|e| Error::LinderaError(e.to_string()))?;
    let jpreprocess = JPreprocess::with_dictionaries(sdic, Some(u));
    Ok(jpreprocess)
}
/// Build a `HashSet` from the listed elements — a small local stand-in for a
/// set literal, used by `fix_phone_tone` for tone-set comparisons.
macro_rules! hash_set {
    ($($elem:expr),* $(,)?) => {{
        let mut set = HashSet::new();
        $(
            set.insert($elem);
        )*
        set
    }};
}
/// Frontend wrapper: owns the (expensive to build) JPreprocess instance and
/// hands out per-text `JTalkProcess` values that share it via `Arc`.
pub struct JTalk {
    pub jpreprocess: Arc<JPreprocessType>,
}

impl JTalk {
    /// Load dictionaries and construct the preprocessor.
    pub fn new() -> Result<Self> {
        let jpreprocess = Arc::new(initialize_jtalk()?);
        Ok(Self { jpreprocess })
    }
    /// Normalize numbers etc. to words: run NJD preprocessing and join the
    /// resulting surface strings back into one string.
    pub fn num2word(&self, text: &str) -> Result<String> {
        let mut parsed = self.jpreprocess.text_to_njd(text)?;
        parsed.preprocess();
        let texts: Vec<String> = parsed
            .nodes
            .iter()
            .map(|x| x.get_string().to_string())
            .collect();
        Ok(texts.join(""))
    }
    /// Run the jpreprocess frontend on `text` and wrap the NJD feature strings
    /// for g2p processing.
    pub fn process_text(&self, text: &str) -> Result<JTalkProcess> {
        let parsed = self.jpreprocess.run_frontend(text)?;
        let jtalk_process = JTalkProcess::new(Arc::clone(&self.jpreprocess), parsed);
        Ok(jtalk_process)
    }
}
// Matches katakana runs; used to validate kata_to_phoneme_list input.
static KATAKANA_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"[\u30A0-\u30FF]+").unwrap());
// All mora keys sorted longest-first so multi-character moras (e.g. "キョ")
// are replaced before their single-character prefixes.
static MORA_PATTERN: Lazy<Vec<String>> = Lazy::new(|| {
    let mut sorted_keys: Vec<String> = MORA_KATA_TO_MORA_PHONEMES.keys().cloned().collect();
    sorted_keys.sort_by_key(|b| Reverse(b.len()));
    sorted_keys
});
// A phoneme character followed by zero or more long-vowel marks.
static LONG_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"(\w)(ー*)").unwrap());
/// One frontend run: the NJD feature strings for a single input text plus the
/// preprocessor needed to turn them into full-context labels.
pub struct JTalkProcess {
    jpreprocess: Arc<JPreprocessType>,
    // Raw NJD feature CSV strings from run_frontend.
    parsed: Vec<String>,
}
impl JTalkProcess {
fn new(jpreprocess: Arc<JPreprocessType>, parsed: Vec<String>) -> Self {
Self {
jpreprocess,
parsed,
}
}
fn fix_phone_tone(&self, phone_tone_list: Vec<(String, i32)>) -> Result<Vec<(String, i32)>> {
let tone_values: HashSet<i32> = phone_tone_list
.iter()
.map(|(_letter, tone)| *tone)
.collect();
if tone_values.len() == 1 {
assert!(tone_values == hash_set![0], "{:?}", tone_values);
Ok(phone_tone_list)
} else if tone_values.len() == 2 {
if tone_values == hash_set![0, 1] {
return Ok(phone_tone_list);
} else if tone_values == hash_set![-1, 0] {
return Ok(phone_tone_list
.iter()
.map(|x| {
let new_tone = if x.1 == -1 { 0 } else { 1 };
(x.0.clone(), new_tone)
})
.collect());
} else {
return Err(Error::ValueError("Invalid tone values 0".to_string()));
}
} else {
return Err(Error::ValueError("Invalid tone values 1".to_string()));
}
}
pub fn g2p(&self) -> Result<(Vec<String>, Vec<i32>, Vec<i32>)> {
let phone_tone_list_wo_punct = self.g2phone_tone_wo_punct()?;
let (seq_text, seq_kata) = self.text_to_seq_kata()?;
let sep_phonemes = JTalkProcess::handle_long(
seq_kata
.iter()
.map(|x| JTalkProcess::kata_to_phoneme_list(x.clone()).unwrap())
.collect(),
);
let phone_w_punct: Vec<String> = sep_phonemes
.iter()
.flat_map(|x| x.iter())
.cloned()
.collect();
let mut phone_tone_list =
JTalkProcess::align_tones(phone_w_punct, phone_tone_list_wo_punct)?;
let mut sep_tokenized: Vec<Vec<String>> = Vec::new();
for seq_text_item in &seq_text {
let text = seq_text_item.clone();
if !PUNCTUATIONS.contains(&text.as_str()) {
sep_tokenized.push(text.chars().map(|x| x.to_string()).collect());
} else {
sep_tokenized.push(vec![text]);
}
}
let mut word2ph = Vec::new();
for (token, phoneme) in sep_tokenized.iter().zip(sep_phonemes.iter()) {
let phone_len = phoneme.len() as i32;
let word_len = token.len() as i32;
word2ph.append(&mut JTalkProcess::distribute_phone(phone_len, word_len));
}
let mut new_phone_tone_list = vec![("_".to_string(), 0)];
new_phone_tone_list.append(&mut phone_tone_list);
new_phone_tone_list.push(("_".to_string(), 0));
let mut new_word2ph = vec![1];
new_word2ph.extend(word2ph.clone());
new_word2ph.push(1);
let phones: Vec<String> = new_phone_tone_list.iter().map(|(x, _)| x.clone()).collect();
let tones: Vec<i32> = new_phone_tone_list.iter().map(|(_, x)| *x).collect();
Ok((phones, tones, new_word2ph))
}
fn distribute_phone(n_phone: i32, n_word: i32) -> Vec<i32> {
let mut phones_per_word = vec![0; n_word as usize];
for _ in 0..n_phone {
let min_task = phones_per_word.iter().min().unwrap();
let min_index = phones_per_word
.iter()
.position(|&x| x == *min_task)
.unwrap();
phones_per_word[min_index] += 1;
}
phones_per_word
}
fn align_tones(
phone_with_punct: Vec<String>,
phone_tone_list: Vec<(String, i32)>,
) -> Result<Vec<(String, i32)>> {
let mut result: Vec<(String, i32)> = Vec::new();
let mut tone_index = 0;
for phone in phone_with_punct.clone() {
if tone_index >= phone_tone_list.len() {
result.push((phone, 0));
} else if phone == phone_tone_list[tone_index].0 {
result.push((phone, phone_tone_list[tone_index].1));
tone_index += 1;
} else if PUNCTUATIONS.contains(&phone.as_str()) {
result.push((phone, 0));
} else {
println!("phones {:?}", phone_with_punct);
println!("phone_tone_list: {:?}", phone_tone_list);
println!("result: {:?}", result);
println!("tone_index: {:?}", tone_index);
println!("phone: {:?}", phone);
return Err(Error::ValueError(format!("Mismatched phoneme: {}", phone)));
}
}
Ok(result)
}
fn handle_long(mut sep_phonemes: Vec<Vec<String>>) -> Vec<Vec<String>> {
for i in 0..sep_phonemes.len() {
if sep_phonemes[i].is_empty() {
continue;
}
if sep_phonemes[i][0] == "" {
if i != 0 {
let prev_phoneme = sep_phonemes[i - 1].last().unwrap();
if VOWELS.contains(&prev_phoneme.as_str()) {
sep_phonemes[i][0] = prev_phoneme.clone();
} else {
sep_phonemes[i][0] = "".to_string();
}
} else {
sep_phonemes[i][0] = "".to_string();
}
}
if sep_phonemes[i].contains(&"".to_string()) {
for e in 0..sep_phonemes[i].len() {
if sep_phonemes[i][e] == "" {
sep_phonemes[i][e] =
sep_phonemes[i][e - 1].chars().last().unwrap().to_string();
}
}
}
}
sep_phonemes
}
fn kata_to_phoneme_list(mut text: String) -> Result<Vec<String>> {
let chars: HashSet<String> = text.chars().map(|x| x.to_string()).collect();
if chars.is_subset(&HashSet::from_iter(
PUNCTUATIONS.iter().map(|x| x.to_string()),
)) {
return Ok(text.chars().map(|x| x.to_string()).collect());
}
if !KATAKANA_PATTERN.is_match(&text) {
return Err(Error::ValueError(format!(
"Input must be katakana only: {}",
text
)));
}
for mora in MORA_PATTERN.iter() {
let mora = mora.to_string();
let (consonant, vowel) = MORA_KATA_TO_MORA_PHONEMES.get(&mora).unwrap();
if consonant.is_none() {
text = text.replace(&mora, &format!(" {}", vowel));
} else {
text = text.replace(
&mora,
&format!(" {} {}", consonant.as_ref().unwrap(), vowel),
);
}
}
let long_replacement = |m: &regex::Captures| {
let result = m.get(1).unwrap().as_str().to_string();
let mut second = String::new();
for _ in 0..m.get(2).unwrap().as_str().char_indices().count() {
second += &format!(" {}", m.get(1).unwrap().as_str());
}
result + &second
};
text = LONG_PATTERN
.replace_all(&text, long_replacement)
.to_string();
let data = text.trim().split(' ').map(|x| x.to_string()).collect();
Ok(data)
}
pub fn text_to_seq_kata(&self) -> Result<(Vec<String>, Vec<String>)> {
let mut seq_kata = vec![];
let mut seq_text = vec![];
for parts in &self.parsed {
let (string, pron) = self.parse_to_string_and_pron(parts.clone());
let mut yomi = pron.replace('', "");
let word = replace_punctuation(string);
assert!(!yomi.is_empty(), "Empty yomi: {}", word);
if yomi == "" {
if !word
.chars()
.all(|x| PUNCTUATIONS.contains(&x.to_string().as_str()))
{
yomi = "'".repeat(word.len());
} else {
yomi = word.clone();
}
} else if yomi == "" {
assert!(word == "?", "yomi `` comes from: {}", word);
yomi = "?".to_string();
}
seq_text.push(word);
seq_kata.push(yomi);
}
Ok((seq_text, seq_kata))
}
fn parse_to_string_and_pron(&self, parts: String) -> (String, String) {
let part_lists: Vec<String> = parts.split(',').map(|x| x.to_string()).collect();
(part_lists[0].clone(), part_lists[9].clone())
}
fn g2phone_tone_wo_punct(&self) -> Result<Vec<(String, i32)>> {
let prosodies = self.g2p_prosody()?;
let mut results: Vec<(String, i32)> = Vec::new();
let mut current_phrase: Vec<(String, i32)> = Vec::new();
let mut current_tone = 0;
for (i, letter) in prosodies.iter().enumerate() {
if letter == "^" {
assert!(i == 0);
} else if ["$", "?", "_", "#"].contains(&letter.as_str()) {
results.extend(self.fix_phone_tone(current_phrase.clone())?);
if ["$", "?"].contains(&letter.as_str()) {
assert!(i == prosodies.len() - 1);
}
current_phrase = Vec::new();
current_tone = 0;
} else if letter == "[" {
current_tone += 1;
} else if letter == "]" {
current_tone -= 1;
} else {
let new_letter = if letter == "cl" {
"q".to_string()
} else {
letter.clone()
};
current_phrase.push((new_letter, current_tone));
}
}
Ok(results)
}
fn g2p_prosody(&self) -> Result<Vec<String>> {
let labels = self.jpreprocess.make_label(self.parsed.clone());
let mut phones: Vec<String> = Vec::new();
for (i, label) in labels.iter().enumerate() {
let mut p3 = label.phoneme.c.clone().unwrap();
if "AIUEO".contains(&p3) {
// 文字をlowerする
p3 = p3.to_lowercase();
}
if p3 == "sil" {
assert!(i == 0 || i == labels.len() - 1);
if i == 0 {
phones.push("^".to_string());
} else if i == labels.len() - 1 {
let e3 = label.accent_phrase_prev.clone().unwrap().is_interrogative;
if e3 {
phones.push("$".to_string());
} else {
phones.push("?".to_string());
}
}
continue;
} else if p3 == "pau" {
phones.push("_".to_string());
continue;
} else {
phones.push(p3.clone());
}
let a1 = if let Some(mora) = &label.mora {
mora.relative_accent_position as i32
} else {
-50
};
let a2 = if let Some(mora) = &label.mora {
mora.position_forward as i32
} else {
-50
};
let a3 = if let Some(mora) = &label.mora {
mora.position_backward as i32
} else {
-50
};
let f1 = if let Some(accent_phrase) = &label.accent_phrase_curr {
accent_phrase.mora_count as i32
} else {
-50
};
let a2_next = if let Some(mora) = &labels[i + 1].mora {
mora.position_forward as i32
} else {
-50
};
if a3 == 1 && a2_next == 1 && "aeiouAEIOUNcl".contains(&p3) {
phones.push("#".to_string());
} else if a1 == 0 && a2_next == a2 + 1 && a2 != f1 {
phones.push("]".to_string());
} else if a2 == 1 && a2_next == 2 {
phones.push("[".to_string());
}
}
Ok(phones)
}
}

View File

@@ -0,0 +1,16 @@
// Inference needs the ONNX runtime, so model-facing modules are gated behind
// the (default) `std` feature; pure text processing stays available without it.
#[cfg(feature = "std")]
pub mod bert;
pub mod error;
pub mod jtalk;
#[cfg(feature = "std")]
pub mod model;
pub mod mora;
pub mod nlp;
pub mod norm;
pub mod sbv2file;
pub mod style;
pub mod tokenizer;
#[cfg(feature = "std")]
pub mod tts;
pub mod tts_util;
pub mod utils;

View File

@@ -0,0 +1,34 @@
use std::env;
use std::fs;
/// Smoke-test binary: synthesizes a fixed sentence with models taken from
/// environment variables and writes output.wav.
#[cfg(feature = "std")]
fn main_inner() -> anyhow::Result<()> {
    use sbv2_core::tts;
    dotenvy::dotenv_override().ok();
    env_logger::init();
    let text = "今日の天気は快晴です。";
    let ident = "aaa";
    let mut tts_holder = tts::TTSModelHolder::new(
        &fs::read(env::var("BERT_MODEL_PATH")?)?,
        &fs::read(env::var("TOKENIZER_PATH")?)?,
        // NOTE(review): "MODElS" (lowercase l) looks like a typo, but the same
        // spelling is used by sbv2_api — confirm before renaming.
        env::var("HOLDER_MAX_LOADED_MODElS")
            .ok()
            .and_then(|x| x.parse().ok()),
    )?;
    tts_holder.load_sbv2file(ident, fs::read(env::var("MODEL_PATH")?)?)?;
    let audio = tts_holder.easy_synthesize(ident, &text, 0, tts::SynthesizeOptions::default())?;
    fs::write("output.wav", audio)?;
    Ok(())
}

/// Without the `std` feature there is no inference backend; do nothing.
#[cfg(not(feature = "std"))]
fn main_inner() -> anyhow::Result<()> {
    Ok(())
}

fn main() {
    // Report errors on stdout instead of unwinding.
    if let Err(e) = main_inner() {
        println!("Error: {e}");
    }
}

View File

@@ -0,0 +1,101 @@
use crate::error::Result;
use ndarray::{array, Array1, Array2, Array3, Axis, Ix3};
use ort::session::{builder::GraphOptimizationLevel, Session};
/// Build an ONNX Runtime session from in-memory model bytes.
///
/// Execution providers are appended in feature order (TensorRT, CUDA,
/// DirectML, CoreML) with CPU always last as the fallback; ort tries them in
/// that order. `bert` only matters for TensorRT, where it constrains the
/// dynamic input shapes to 1..=100-token windows.
#[allow(clippy::vec_init_then_push, unused_variables)]
pub fn load_model<P: AsRef<[u8]>>(model_file: P, bert: bool) -> Result<Session> {
    let mut exp = Vec::new();
    #[cfg(feature = "tensorrt")]
    {
        if bert {
            exp.push(
                ort::execution_providers::TensorRTExecutionProvider::default()
                    .with_fp16(true)
                    .with_profile_min_shapes("input_ids:1x1,attention_mask:1x1")
                    .with_profile_max_shapes("input_ids:1x100,attention_mask:1x100")
                    .with_profile_opt_shapes("input_ids:1x25,attention_mask:1x25")
                    .build(),
            );
        }
    }
    #[cfg(feature = "cuda")]
    {
        #[allow(unused_mut)]
        let mut cuda = ort::execution_providers::CUDAExecutionProvider::default()
            .with_conv_algorithm_search(
                ort::execution_providers::CUDAExecutionProviderCuDNNConvAlgoSearch::Default,
            );
        // cuda_tf32 trades precision for speed on Ampere+ GPUs.
        #[cfg(feature = "cuda_tf32")]
        {
            cuda = cuda.with_tf32(true);
        }
        exp.push(cuda.build());
    }
    #[cfg(feature = "directml")]
    {
        exp.push(ort::execution_providers::DirectMLExecutionProvider::default().build());
    }
    #[cfg(feature = "coreml")]
    {
        exp.push(ort::execution_providers::CoreMLExecutionProvider::default().build());
    }
    // CPU fallback is always available, regardless of features.
    exp.push(ort::execution_providers::CPUExecutionProvider::default().build());
    Ok(Session::builder()?
        .with_execution_providers(exp)?
        .with_optimization_level(GraphOptimizationLevel::Level3)?
        .with_intra_threads(num_cpus::get_physical())?
        .with_parallel_execution(true)?
        .with_inter_threads(num_cpus::get_physical())?
        .commit_from_memory(model_file.as_ref())?)
}
/// Run the VITS2 decoder for one utterance.
///
/// Every 1-D input gets a leading batch axis of 1; scalar knobs (sid,
/// sdp_ratio, length_scale) are passed as 1-element tensors. Returns the
/// model's "output" tensor as an owned 3-D f32 audio array.
#[allow(clippy::too_many_arguments)]
pub fn synthesize(
    session: &mut Session,
    bert_ori: Array2<f32>,
    x_tst: Array1<i64>,
    tones: Array1<i64>,
    lang_ids: Array1<i64>,
    style_vector: Array1<f32>,
    sdp_ratio: f32,
    length_scale: f32,
) -> Result<Array3<f32>> {
    let bert_ori = bert_ori.insert_axis(Axis(0));
    // as_standard_layout guarantees the contiguous memory TensorRef needs.
    let bert_ori = bert_ori.as_standard_layout();
    let bert = ort::value::TensorRef::from_array_view(&bert_ori)?;
    let mut x_tst_lengths = array![x_tst.shape()[0] as i64];
    let x_tst_lengths = ort::value::TensorRef::from_array_view(&mut x_tst_lengths)?;
    let mut x_tst = x_tst.insert_axis(Axis(0));
    let x_tst = ort::value::TensorRef::from_array_view(&mut x_tst)?;
    let mut lang_ids = lang_ids.insert_axis(Axis(0));
    let lang_ids = ort::value::TensorRef::from_array_view(&mut lang_ids)?;
    let mut tones = tones.insert_axis(Axis(0));
    let tones = ort::value::TensorRef::from_array_view(&mut tones)?;
    let mut style_vector = style_vector.insert_axis(Axis(0));
    let style_vector = ort::value::TensorRef::from_array_view(&mut style_vector)?;
    // Speaker id is fixed to 1; multi-speaker selection is not exposed here.
    let sid = vec![1_i64];
    let sid = ort::value::TensorRef::from_array_view((vec![1_i64], sid.as_slice()))?;
    let sdp_ratio = vec![sdp_ratio];
    let sdp_ratio = ort::value::TensorRef::from_array_view((vec![1_i64], sdp_ratio.as_slice()))?;
    let length_scale = vec![length_scale];
    let length_scale =
        ort::value::TensorRef::from_array_view((vec![1_i64], length_scale.as_slice()))?;
    let outputs = session.run(ort::inputs! {
        "x_tst" => x_tst,
        "x_tst_lengths" => x_tst_lengths,
        "sid" => sid,
        "tones" => tones,
        "language" => lang_ids,
        "bert" => bert,
        "style_vec" => style_vector,
        "sdp_ratio" => sdp_ratio,
        "length_scale" => length_scale,
    })?;
    let audio_array = outputs["output"]
        .try_extract_tensor::<f32>()?
        .into_dimensionality::<Ix3>()?
        .to_owned();
    Ok(audio_array)
}

View File

@@ -0,0 +1,40 @@
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A single mora: the katakana surface plus its phoneme decomposition.
#[derive(Debug, Serialize, Deserialize)]
pub struct Mora {
    pub mora: String,
    pub consonant: Option<String>,
    pub vowel: String,
}

/// On-disk layout of the embedded `mora_list.json`.
#[derive(Debug, Serialize, Deserialize)]
pub struct MoraFile {
    pub minimum: Vec<Mora>,
    pub additional: Vec<Mora>,
}

// Parse the embedded JSON exactly once; previously each list static parsed
// the whole file independently.
static MORA_FILE: Lazy<MoraFile> =
    Lazy::new(|| serde_json::from_str(include_str!("./mora_list.json")).unwrap());

/// Katakana mora -> (optional consonant, vowel), over both mora lists.
pub static MORA_KATA_TO_MORA_PHONEMES: Lazy<HashMap<String, (Option<String>, String)>> =
    Lazy::new(|| {
        MORA_FILE
            .minimum
            .iter()
            .chain(MORA_FILE.additional.iter())
            .map(|mora| {
                (
                    mora.mora.clone(),
                    (mora.consonant.clone(), mora.vowel.clone()),
                )
            })
            .collect()
    });

/// Phonemes treated as vowels, including the syllabic nasal "N".
pub const VOWELS: [&str; 6] = ["a", "i", "u", "e", "o", "N"];

View File

@@ -0,0 +1,816 @@
{
"minimum": [
{
"mora": "ヴォ",
"consonant": "v",
"vowel": "o"
},
{
"mora": "ヴェ",
"consonant": "v",
"vowel": "e"
},
{
"mora": "ヴィ",
"consonant": "v",
"vowel": "i"
},
{
"mora": "ヴァ",
"consonant": "v",
"vowel": "a"
},
{
"mora": "ヴ",
"consonant": "v",
"vowel": "u"
},
{
"mora": "ン",
"consonant": null,
"vowel": "N"
},
{
"mora": "ワ",
"consonant": "w",
"vowel": "a"
},
{
"mora": "ロ",
"consonant": "r",
"vowel": "o"
},
{
"mora": "レ",
"consonant": "r",
"vowel": "e"
},
{
"mora": "ル",
"consonant": "r",
"vowel": "u"
},
{
"mora": "リョ",
"consonant": "ry",
"vowel": "o"
},
{
"mora": "リュ",
"consonant": "ry",
"vowel": "u"
},
{
"mora": "リャ",
"consonant": "ry",
"vowel": "a"
},
{
"mora": "リェ",
"consonant": "ry",
"vowel": "e"
},
{
"mora": "リ",
"consonant": "r",
"vowel": "i"
},
{
"mora": "ラ",
"consonant": "r",
"vowel": "a"
},
{
"mora": "ヨ",
"consonant": "y",
"vowel": "o"
},
{
"mora": "ユ",
"consonant": "y",
"vowel": "u"
},
{
"mora": "ヤ",
"consonant": "y",
"vowel": "a"
},
{
"mora": "モ",
"consonant": "m",
"vowel": "o"
},
{
"mora": "メ",
"consonant": "m",
"vowel": "e"
},
{
"mora": "ム",
"consonant": "m",
"vowel": "u"
},
{
"mora": "ミョ",
"consonant": "my",
"vowel": "o"
},
{
"mora": "ミュ",
"consonant": "my",
"vowel": "u"
},
{
"mora": "ミャ",
"consonant": "my",
"vowel": "a"
},
{
"mora": "ミェ",
"consonant": "my",
"vowel": "e"
},
{
"mora": "ミ",
"consonant": "m",
"vowel": "i"
},
{
"mora": "マ",
"consonant": "m",
"vowel": "a"
},
{
"mora": "ポ",
"consonant": "p",
"vowel": "o"
},
{
"mora": "ボ",
"consonant": "b",
"vowel": "o"
},
{
"mora": "ホ",
"consonant": "h",
"vowel": "o"
},
{
"mora": "ペ",
"consonant": "p",
"vowel": "e"
},
{
"mora": "ベ",
"consonant": "b",
"vowel": "e"
},
{
"mora": "ヘ",
"consonant": "h",
"vowel": "e"
},
{
"mora": "プ",
"consonant": "p",
"vowel": "u"
},
{
"mora": "ブ",
"consonant": "b",
"vowel": "u"
},
{
"mora": "フォ",
"consonant": "f",
"vowel": "o"
},
{
"mora": "フェ",
"consonant": "f",
"vowel": "e"
},
{
"mora": "フィ",
"consonant": "f",
"vowel": "i"
},
{
"mora": "ファ",
"consonant": "f",
"vowel": "a"
},
{
"mora": "フ",
"consonant": "f",
"vowel": "u"
},
{
"mora": "ピョ",
"consonant": "py",
"vowel": "o"
},
{
"mora": "ピュ",
"consonant": "py",
"vowel": "u"
},
{
"mora": "ピャ",
"consonant": "py",
"vowel": "a"
},
{
"mora": "ピェ",
"consonant": "py",
"vowel": "e"
},
{
"mora": "ピ",
"consonant": "p",
"vowel": "i"
},
{
"mora": "ビョ",
"consonant": "by",
"vowel": "o"
},
{
"mora": "ビュ",
"consonant": "by",
"vowel": "u"
},
{
"mora": "ビャ",
"consonant": "by",
"vowel": "a"
},
{
"mora": "ビェ",
"consonant": "by",
"vowel": "e"
},
{
"mora": "ビ",
"consonant": "b",
"vowel": "i"
},
{
"mora": "ヒョ",
"consonant": "hy",
"vowel": "o"
},
{
"mora": "ヒュ",
"consonant": "hy",
"vowel": "u"
},
{
"mora": "ヒャ",
"consonant": "hy",
"vowel": "a"
},
{
"mora": "ヒェ",
"consonant": "hy",
"vowel": "e"
},
{
"mora": "ヒ",
"consonant": "h",
"vowel": "i"
},
{
"mora": "パ",
"consonant": "p",
"vowel": "a"
},
{
"mora": "バ",
"consonant": "b",
"vowel": "a"
},
{
"mora": "ハ",
"consonant": "h",
"vowel": "a"
},
{
"mora": "",
"consonant": "n",
"vowel": "o"
},
{
"mora": "ネ",
"consonant": "n",
"vowel": "e"
},
{
"mora": "ヌ",
"consonant": "n",
"vowel": "u"
},
{
"mora": "ニョ",
"consonant": "ny",
"vowel": "o"
},
{
"mora": "ニュ",
"consonant": "ny",
"vowel": "u"
},
{
"mora": "ニャ",
"consonant": "ny",
"vowel": "a"
},
{
"mora": "ニェ",
"consonant": "ny",
"vowel": "e"
},
{
"mora": "ニ",
"consonant": "n",
"vowel": "i"
},
{
"mora": "ナ",
"consonant": "n",
"vowel": "a"
},
{
"mora": "ドゥ",
"consonant": "d",
"vowel": "u"
},
{
"mora": "ド",
"consonant": "d",
"vowel": "o"
},
{
"mora": "トゥ",
"consonant": "t",
"vowel": "u"
},
{
"mora": "ト",
"consonant": "t",
"vowel": "o"
},
{
"mora": "デョ",
"consonant": "dy",
"vowel": "o"
},
{
"mora": "デュ",
"consonant": "dy",
"vowel": "u"
},
{
"mora": "デャ",
"consonant": "dy",
"vowel": "a"
},
{
"mora": "ディ",
"consonant": "d",
"vowel": "i"
},
{
"mora": "デ",
"consonant": "d",
"vowel": "e"
},
{
"mora": "テョ",
"consonant": "ty",
"vowel": "o"
},
{
"mora": "テュ",
"consonant": "ty",
"vowel": "u"
},
{
"mora": "テャ",
"consonant": "ty",
"vowel": "a"
},
{
"mora": "ティ",
"consonant": "t",
"vowel": "i"
},
{
"mora": "テ",
"consonant": "t",
"vowel": "e"
},
{
"mora": "ツォ",
"consonant": "ts",
"vowel": "o"
},
{
"mora": "ツェ",
"consonant": "ts",
"vowel": "e"
},
{
"mora": "ツィ",
"consonant": "ts",
"vowel": "i"
},
{
"mora": "ツァ",
"consonant": "ts",
"vowel": "a"
},
{
"mora": "ツ",
"consonant": "ts",
"vowel": "u"
},
{
"mora": "ッ",
"consonant": null,
"vowel": "q"
},
{
"mora": "チョ",
"consonant": "ch",
"vowel": "o"
},
{
"mora": "チュ",
"consonant": "ch",
"vowel": "u"
},
{
"mora": "チャ",
"consonant": "ch",
"vowel": "a"
},
{
"mora": "チェ",
"consonant": "ch",
"vowel": "e"
},
{
"mora": "チ",
"consonant": "ch",
"vowel": "i"
},
{
"mora": "ダ",
"consonant": "d",
"vowel": "a"
},
{
"mora": "タ",
"consonant": "t",
"vowel": "a"
},
{
"mora": "ゾ",
"consonant": "z",
"vowel": "o"
},
{
"mora": "ソ",
"consonant": "s",
"vowel": "o"
},
{
"mora": "ゼ",
"consonant": "z",
"vowel": "e"
},
{
"mora": "セ",
"consonant": "s",
"vowel": "e"
},
{
"mora": "ズィ",
"consonant": "z",
"vowel": "i"
},
{
"mora": "ズ",
"consonant": "z",
"vowel": "u"
},
{
"mora": "スィ",
"consonant": "s",
"vowel": "i"
},
{
"mora": "ス",
"consonant": "s",
"vowel": "u"
},
{
"mora": "ジョ",
"consonant": "j",
"vowel": "o"
},
{
"mora": "ジュ",
"consonant": "j",
"vowel": "u"
},
{
"mora": "ジャ",
"consonant": "j",
"vowel": "a"
},
{
"mora": "ジェ",
"consonant": "j",
"vowel": "e"
},
{
"mora": "ジ",
"consonant": "j",
"vowel": "i"
},
{
"mora": "ショ",
"consonant": "sh",
"vowel": "o"
},
{
"mora": "シュ",
"consonant": "sh",
"vowel": "u"
},
{
"mora": "シャ",
"consonant": "sh",
"vowel": "a"
},
{
"mora": "シェ",
"consonant": "sh",
"vowel": "e"
},
{
"mora": "シ",
"consonant": "sh",
"vowel": "i"
},
{
"mora": "ザ",
"consonant": "z",
"vowel": "a"
},
{
"mora": "サ",
"consonant": "s",
"vowel": "a"
},
{
"mora": "ゴ",
"consonant": "g",
"vowel": "o"
},
{
"mora": "コ",
"consonant": "k",
"vowel": "o"
},
{
"mora": "ゲ",
"consonant": "g",
"vowel": "e"
},
{
"mora": "ケ",
"consonant": "k",
"vowel": "e"
},
{
"mora": "グヮ",
"consonant": "gw",
"vowel": "a"
},
{
"mora": "グ",
"consonant": "g",
"vowel": "u"
},
{
"mora": "クヮ",
"consonant": "kw",
"vowel": "a"
},
{
"mora": "ク",
"consonant": "k",
"vowel": "u"
},
{
"mora": "ギョ",
"consonant": "gy",
"vowel": "o"
},
{
"mora": "ギュ",
"consonant": "gy",
"vowel": "u"
},
{
"mora": "ギャ",
"consonant": "gy",
"vowel": "a"
},
{
"mora": "ギェ",
"consonant": "gy",
"vowel": "e"
},
{
"mora": "ギ",
"consonant": "g",
"vowel": "i"
},
{
"mora": "キョ",
"consonant": "ky",
"vowel": "o"
},
{
"mora": "キュ",
"consonant": "ky",
"vowel": "u"
},
{
"mora": "キャ",
"consonant": "ky",
"vowel": "a"
},
{
"mora": "キェ",
"consonant": "ky",
"vowel": "e"
},
{
"mora": "キ",
"consonant": "k",
"vowel": "i"
},
{
"mora": "ガ",
"consonant": "g",
"vowel": "a"
},
{
"mora": "カ",
"consonant": "k",
"vowel": "a"
},
{
"mora": "オ",
"consonant": null,
"vowel": "o"
},
{
"mora": "エ",
"consonant": null,
"vowel": "e"
},
{
"mora": "ウォ",
"consonant": "w",
"vowel": "o"
},
{
"mora": "ウェ",
"consonant": "w",
"vowel": "e"
},
{
"mora": "ウィ",
"consonant": "w",
"vowel": "i"
},
{
"mora": "ウ",
"consonant": null,
"vowel": "u"
},
{
"mora": "イェ",
"consonant": "y",
"vowel": "e"
},
{
"mora": "イ",
"consonant": null,
"vowel": "i"
},
{
"mora": "ア",
"consonant": null,
"vowel": "a"
}
],
"additional": [
{
"mora": "ヴョ",
"consonant": "by",
"vowel": "o"
},
{
"mora": "ヴュ",
"consonant": "by",
"vowel": "u"
},
{
"mora": "ヴャ",
"consonant": "by",
"vowel": "a"
},
{
"mora": "ヲ",
"consonant": null,
"vowel": "o"
},
{
"mora": "ヱ",
"consonant": null,
"vowel": "e"
},
{
"mora": "ヰ",
"consonant": null,
"vowel": "i"
},
{
"mora": "ヮ",
"consonant": "w",
"vowel": "a"
},
{
"mora": "ョ",
"consonant": "y",
"vowel": "o"
},
{
"mora": "ュ",
"consonant": "y",
"vowel": "u"
},
{
"mora": "ヅ",
"consonant": "z",
"vowel": "u"
},
{
"mora": "ヂ",
"consonant": "j",
"vowel": "i"
},
{
"mora": "ヶ",
"consonant": "k",
"vowel": "e"
},
{
"mora": "ャ",
"consonant": "y",
"vowel": "a"
},
{
"mora": "ォ",
"consonant": null,
"vowel": "o"
},
{
"mora": "ェ",
"consonant": null,
"vowel": "e"
},
{
"mora": "ゥ",
"consonant": null,
"vowel": "u"
},
{
"mora": "ィ",
"consonant": null,
"vowel": "i"
},
{
"mora": "ァ",
"consonant": null,
"vowel": "a"
}
]
}

View File

@@ -0,0 +1,24 @@
use crate::norm::SYMBOLS;
use once_cell::sync::Lazy;
use std::collections::HashMap;
// Map each phoneme symbol to its index in `SYMBOLS`.
static SYMBOL_TO_ID: Lazy<HashMap<String, i32>> = Lazy::new(|| {
    SYMBOLS
        .iter()
        .enumerate()
        .map(|(index, symbol)| (symbol.to_string(), index as i32))
        .collect()
});
/// Convert cleaned phones and tones into numeric id sequences.
///
/// Returns `(phone_ids, tone_ids, lang_ids)`: phone symbols are looked up in
/// the symbol table, tones are offset by +6, and every position gets
/// language id 1.
pub fn cleaned_text_to_sequence(
    cleaned_phones: Vec<String>,
    tones: Vec<i32>,
) -> (Vec<i64>, Vec<i64>, Vec<i64>) {
    // Panics on a phone missing from the symbol table, as the original did.
    let phone_ids: Vec<i64> = cleaned_phones
        .iter()
        .map(|phone| i64::from(*SYMBOL_TO_ID.get(phone).unwrap()))
        .collect();
    let tone_ids: Vec<i64> = tones.into_iter().map(|tone| i64::from(tone + 6)).collect();
    let lang_ids = vec![1_i64; phone_ids.len()];
    (phone_ids, tone_ids, lang_ids)
}

View File

@@ -0,0 +1,127 @@
use once_cell::sync::Lazy;
use std::collections::{HashMap, HashSet};
// Punctuation normalization table applied by `replace_punctuation`.
// NOTE(review): many keys below render as empty strings in this view — the
// original full-width punctuation characters (、。！？ etc.) appear to have been
// lost in transport. Repeated "" keys overwrite each other in the HashMap, so
// this table must be restored from the upstream source before shipping.
static REPLACE_MAP: Lazy<HashMap<&str, &str>> = Lazy::new(|| {
    let mut map = HashMap::new();
    map.insert("", ",");
    map.insert("", ",");
    map.insert("", ",");
    map.insert("", ".");
    map.insert("", "!");
    map.insert("", "?");
    map.insert("\n", ".");
    map.insert("", ".");
    map.insert("", "...");
    map.insert("···", "...");
    map.insert("・・・", "...");
    map.insert("·", ",");
    map.insert("", ",");
    map.insert("", ",");
    map.insert("$", ".");
    map.insert("", "'");
    map.insert("", "'");
    map.insert("\"", "'");
    map.insert("", "'");
    map.insert("", "'");
    map.insert("", "'");
    map.insert("", "'");
    map.insert("(", "'");
    map.insert(")", "'");
    map.insert("", "'");
    map.insert("", "'");
    map.insert("", "'");
    map.insert("", "'");
    map.insert("[", "'");
    map.insert("]", "'");
    // Normalize every hyphen/dash variant (post-NFKC) to the plain ASCII
    // hyphen-minus U+002D.
    map.insert("\u{02d7}", "\u{002d}"); // U+02D7 Modifier Letter Minus Sign
    map.insert("\u{2010}", "\u{002d}"); // U+2010 Hyphen
    map.insert("\u{2012}", "\u{002d}"); // U+2012 Figure Dash
    map.insert("\u{2013}", "\u{002d}"); // U+2013 En Dash
    map.insert("\u{2014}", "\u{002d}"); // U+2014 Em Dash
    map.insert("\u{2015}", "\u{002d}"); // U+2015 Horizontal Bar
    map.insert("\u{2043}", "\u{002d}"); // U+2043 Hyphen Bullet
    map.insert("\u{2212}", "\u{002d}"); // U+2212 Minus Sign
    map.insert("\u{23af}", "\u{002d}"); // U+23AF Horizontal Line Extension
    map.insert("\u{23e4}", "\u{002d}"); // U+23E4 Straightness
    map.insert("\u{2500}", "\u{002d}"); // U+2500 Box Drawings Light Horizontal
    map.insert("\u{2501}", "\u{002d}"); // U+2501 Box Drawings Heavy Horizontal
    map.insert("\u{2e3a}", "\u{002d}"); // U+2E3A Two-Em Dash
    map.insert("\u{2e3b}", "\u{002d}"); // U+2E3B Three-Em Dash
    map.insert("", "'");
    map.insert("", "'");
    map
});
// Chinese phoneme inventory (pinyin-style initials and finals).
const ZH_SYMBOLS: [&str; 65] = [
    "E", "En", "a", "ai", "an", "ang", "ao", "b", "c", "ch", "d", "e", "ei", "en", "eng", "er",
    "f", "g", "h", "i", "i0", "ia", "ian", "iang", "iao", "ie", "in", "ing", "iong", "ir", "iu",
    "j", "k", "l", "m", "n", "o", "ong", "ou", "p", "q", "r", "s", "sh", "t", "u", "ua", "uai",
    "uan", "uang", "ui", "un", "uo", "v", "van", "ve", "vn", "w", "x", "y", "z", "zh", "AA", "EE",
    "OO",
];
// Japanese phoneme inventory ("q" is the sokuon ッ, ":" marks long vowels).
pub const JP_SYMBOLS: [&str; 42] = [
    "N", "a", "a:", "b", "by", "ch", "d", "dy", "e", "e:", "f", "g", "gy", "h", "hy", "i", "i:",
    "j", "k", "ky", "m", "my", "n", "ny", "o", "o:", "p", "py", "q", "r", "ry", "s", "sh", "t",
    "ts", "ty", "u", "u:", "w", "y", "z", "zy",
];
// English phoneme inventory.
pub const EN_SYMBOLS: [&str; 39] = [
    "aa", "ae", "ah", "ao", "aw", "ay", "b", "ch", "d", "dh", "eh", "er", "ey", "f", "g", "hh",
    "ih", "iy", "jh", "k", "l", "m", "n", "ng", "ow", "oy", "p", "r", "s", "sh", "t", "th", "uh",
    "uw", "V", "w", "y", "z", "zh",
];
// Punctuation marks preserved by the cleanup regex.
// NOTE(review): the third entry renders as "" in this view — the original
// character (possibly "…") was lost; restore from the upstream source.
pub static PUNCTUATIONS: [&str; 7] = ["!", "?", "", ",", ".", "'", "-"];
// Punctuation plus the special tokens appended at the end of the symbol table.
pub static PUNCTUATION_SYMBOLS: Lazy<Vec<&str>> = Lazy::new(|| {
    let mut symbols = PUNCTUATIONS.to_vec();
    symbols.append(&mut vec!["SP", "UNK"]);
    symbols
});
// Padding symbol placed at index 0 of `SYMBOLS`.
const PAD: &str = "_";
/// Sorted, deduplicated union of the Chinese, Japanese and English phoneme
/// inventories. A BTreeSet yields the same sorted-unique result the original
/// HashSet-then-sort produced.
pub static NORMAL_SYMBOLS: Lazy<Vec<&str>> = Lazy::new(|| {
    ZH_SYMBOLS
        .iter()
        .chain(JP_SYMBOLS.iter())
        .chain(EN_SYMBOLS.iter())
        .copied()
        .collect::<std::collections::BTreeSet<&str>>()
        .into_iter()
        .collect()
});
/// Full symbol table: the pad symbol first, then the sorted normal symbols,
/// then punctuation and special tokens — same ordering as the original.
pub static SYMBOLS: Lazy<Vec<&str>> = Lazy::new(|| {
    std::iter::once(PAD)
        .chain(NORMAL_SYMBOLS.iter().copied())
        .chain(PUNCTUATION_SYMBOLS.iter().copied())
        .collect()
});
// Matches runs of characters NOT in the allowed set; `replace_punctuation`
// deletes every match. Allowed ranges: hiragana (U+3040-309F), katakana
// (U+30A0-30FF), CJK ideographs (U+4E00-9FFF and extension A U+3400-4DBF),
// the iteration mark 々 (U+3005), ASCII and full-width Latin letters, Greek
// and Greek Extended, plus the marks listed in `PUNCTUATIONS` (the trailing
// "-" lands at the end of the character class, so it stays literal).
static PUNCTUATION_CLEANUP_PATTERN: Lazy<regex::Regex> = Lazy::new(|| {
    let pattern = r"[^\u{3040}-\u{309F}\u{30A0}-\u{30FF}\u{4E00}-\u{9FFF}\u{3400}-\u{4DBF}\u{3005}"
        .to_owned()
        + r"\u{0041}-\u{005A}\u{0061}-\u{007A}"
        + r"\u{FF21}-\u{FF3A}\u{FF41}-\u{FF5A}"
        + r"\u{0370}-\u{03FF}\u{1F00}-\u{1FFF}"
        + &PUNCTUATIONS.join("")
        + r"]+";
    regex::Regex::new(&pattern).unwrap()
});
/// Normalize Japanese text before g2p: strip tilde/wave-dash variants, then
/// apply the punctuation replacement table.
pub fn normalize_text(text: &str) -> String {
    // Normalize the Japanese text.
    let text = text.replace('~', "");
    // NOTE(review): the next line's source character renders as an empty char
    // literal in this view ('' is not valid Rust) — the original character
    // (likely a wave-dash variant) was lost; restore from upstream.
    let text = text.replace('', "");
    let text = text.replace('〜', "");
    replace_punctuation(text)
}
/// Apply every `REPLACE_MAP` substitution to `text`, then strip all characters
/// outside the allowed Japanese/Latin/Greek/punctuation set.
///
/// Improvement: the needless `let content = …; content` binding at the end
/// (clippy::let_and_return) is replaced with a direct tail expression.
pub fn replace_punctuation(mut text: String) -> String {
    // NOTE(review): HashMap iteration order is unspecified, so replacements
    // run in arbitrary order — fine while no key overlaps another's output.
    for (from, to) in REPLACE_MAP.iter() {
        text = text.replace(from, to);
    }
    PUNCTUATION_CLEANUP_PATTERN
        .replace_all(&text, "")
        .to_string()
}

View File

@@ -0,0 +1,37 @@
use std::io::{Cursor, Read};
use tar::Archive;
use zstd::decode_all;
use crate::error::{Error, Result};
/// Parse a .sbv2 file binary
///
/// # Examples
///
/// ```rs
/// parse_sbv2file("tsukuyomi", std::fs::read("tsukuyomi.sbv2")?)?;
/// ```
pub fn parse_sbv2file<P: AsRef<[u8]>>(sbv2_bytes: P) -> Result<(Vec<u8>, Vec<u8>)> {
let mut arc = Archive::new(Cursor::new(decode_all(Cursor::new(sbv2_bytes.as_ref()))?));
let mut vits2 = None;
let mut style_vectors = None;
let mut et = arc.entries()?;
while let Some(Ok(mut e)) = et.next() {
let pth = String::from_utf8_lossy(&e.path_bytes()).to_string();
let mut b = Vec::with_capacity(e.size() as usize);
e.read_to_end(&mut b)?;
match pth.as_str() {
"model.onnx" => vits2 = Some(b),
"style_vectors.json" => style_vectors = Some(b),
_ => continue,
}
}
if style_vectors.is_none() {
return Err(Error::ModelNotFoundError("style_vectors".to_string()));
}
if vits2.is_none() {
return Err(Error::ModelNotFoundError("vits2".to_string()));
}
Ok((style_vectors.unwrap(), vits2.unwrap()))
}

View File

@@ -0,0 +1,28 @@
use crate::error::Result;
use ndarray::{s, Array1, Array2};
use serde::Deserialize;
/// JSON layout of style_vectors.json: a 2-D float matrix stored row by row.
#[derive(Deserialize)]
pub struct Data {
    // (rows, cols) of the matrix below.
    pub shape: [usize; 2],
    // Row-major payload; flattened into an `Array2` by `load_style`.
    pub data: Vec<Vec<f32>>,
}
/// Deserialize style_vectors.json bytes into a 2-D style matrix.
///
/// # Errors
/// Fails on malformed JSON or when the payload length does not match `shape`.
pub fn load_style<P: AsRef<[u8]>>(path: P) -> Result<Array2<f32>> {
    let parsed: Data = serde_json::from_slice(path.as_ref())?;
    let flat: Vec<f32> = parsed.data.into_iter().flatten().collect();
    Ok(Array2::from_shape_vec(parsed.shape, flat)?)
}
/// Blend from the mean style (row 0) toward the style at `style_id` by
/// `weight`: `mean + (style - mean) * weight`.
///
/// NOTE(review): panics when `style_id` is out of range for the matrix —
/// confirm callers validate the id first.
pub fn get_style_vector(
    style_vectors: &Array2<f32>,
    style_id: i32,
    weight: f32,
) -> Result<Array1<f32>> {
    let mean = style_vectors.slice(s![0, ..]).to_owned();
    let target = style_vectors.slice(s![style_id as usize, ..]).to_owned();
    Ok(&mean + &((target - &mean) * weight))
}

View File

@@ -0,0 +1,21 @@
use crate::error::Result;
pub use tokenizers::Tokenizer;
/// Build a tokenizer from serialized tokenizer.json bytes.
pub fn get_tokenizer<P: AsRef<[u8]>>(p: P) -> Result<Tokenizer> {
    Ok(Tokenizer::from_bytes(p)?)
}
/// Tokenize `text` one character at a time and return `(token_ids, attention_masks)`.
///
/// The sequence is wrapped in id 1 at the start and id 2 at the end
/// (presumably the [CLS]/[SEP]-style special tokens — TODO confirm against
/// tokenizer.json), each with attention mask 1.
pub fn tokenize(text: &str, tokenizer: &Tokenizer) -> Result<(Vec<i64>, Vec<i64>)> {
    let mut token_ids: Vec<i64> = vec![1];
    let mut attention_masks: Vec<i64> = vec![1];
    for ch in text.chars() {
        let encoding = tokenizer.encode(ch.to_string(), false)?;
        token_ids.extend(encoding.get_ids().iter().map(|&id| i64::from(id)));
        attention_masks.extend(
            encoding
                .get_attention_mask()
                .iter()
                .map(|&mask| i64::from(mask)),
        );
    }
    token_ids.push(2);
    attention_masks.push(1);
    Ok((token_ids, attention_masks))
}

350
crates/sbv2_core/src/tts.rs Normal file
View File

@@ -0,0 +1,350 @@
use crate::error::{Error, Result};
use crate::{jtalk, model, style, tokenizer, tts_util};
use ndarray::{concatenate, Array1, Array2, Array3, Axis};
use ort::session::Session;
use tokenizers::Tokenizer;
/// Identifier a voice model is registered under.
/// Improvement: public types should be `Debug` for diagnostics.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct TTSIdent(String);
impl std::fmt::Display for TTSIdent {
    /// Display the ident as its bare inner string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
impl<S> From<S> for TTSIdent
where
S: AsRef<str>,
{
fn from(value: S) -> Self {
TTSIdent(value.as_ref().to_string())
}
}
/// A single registered voice model and its style-vector table.
pub struct TTSModel {
    // VITS2 ONNX session; None while evicted under `max_loaded_models`.
    vits2: Option<Session>,
    // Style embedding matrix (row 0 is the mean style — see `style::get_style_vector`).
    style_vectors: Array2<f32>,
    // Name this model is registered under.
    ident: TTSIdent,
    // Raw ONNX bytes kept for reload; only retained when eviction is enabled.
    bytes: Option<Vec<u8>>,
}
/// High-level Style-Bert-VITS2's API
pub struct TTSModelHolder {
    // Tokenizer shared across models for BERT input preparation.
    tokenizer: Tokenizer,
    // Shared BERT encoder ONNX session.
    bert: Session,
    // Registered voice models, in load order.
    models: Vec<TTSModel>,
    // Japanese text-processing frontend (normalization, g2p).
    jtalk: jtalk::JTalk,
    // When Some(n), at most n VITS2 sessions stay resident; extras keep only bytes.
    max_loaded_models: Option<usize>,
}
impl TTSModelHolder {
/// Initialize a new TTSModelHolder
///
/// # Examples
///
/// ```rs
/// let mut tts_holder = TTSModelHolder::new(std::fs::read("deberta.onnx")?, std::fs::read("tokenizer.json")?, None)?;
/// ```
pub fn new<P: AsRef<[u8]>>(
bert_model_bytes: P,
tokenizer_bytes: P,
max_loaded_models: Option<usize>,
) -> Result<Self> {
let bert = model::load_model(bert_model_bytes, true)?;
let jtalk = jtalk::JTalk::new()?;
let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?;
Ok(TTSModelHolder {
bert,
models: vec![],
jtalk,
tokenizer,
max_loaded_models,
})
}
/// Return a list of model names
pub fn models(&self) -> Vec<String> {
self.models.iter().map(|m| m.ident.to_string()).collect()
}
/// Load a .sbv2 file binary
///
/// # Examples
///
/// ```rs
/// tts_holder.load_sbv2file("tsukuyomi", std::fs::read("tsukuyomi.sbv2")?)?;
/// ```
pub fn load_sbv2file<I: Into<TTSIdent>, P: AsRef<[u8]>>(
&mut self,
ident: I,
sbv2_bytes: P,
) -> Result<()> {
let (style_vectors, vits2) = crate::sbv2file::parse_sbv2file(sbv2_bytes)?;
self.load(ident, style_vectors, vits2)?;
Ok(())
}
/// Load a style vector and onnx model binary
///
/// # Examples
///
/// ```rs
/// tts_holder.load("tsukuyomi", std::fs::read("style_vectors.json")?, std::fs::read("model.onnx")?)?;
/// ```
pub fn load<I: Into<TTSIdent>, P: AsRef<[u8]>>(
&mut self,
ident: I,
style_vectors_bytes: P,
vits2_bytes: P,
) -> Result<()> {
let ident = ident.into();
if self.find_model(ident.clone()).is_err() {
let mut load = true;
if let Some(max) = self.max_loaded_models {
if self.models.iter().filter(|x| x.vits2.is_some()).count() >= max {
load = false;
}
}
self.models.push(TTSModel {
vits2: if load {
Some(model::load_model(&vits2_bytes, false)?)
} else {
None
},
style_vectors: style::load_style(style_vectors_bytes)?,
ident,
bytes: if self.max_loaded_models.is_some() {
Some(vits2_bytes.as_ref().to_vec())
} else {
None
},
})
}
Ok(())
}
/// Unload a model
pub fn unload<I: Into<TTSIdent>>(&mut self, ident: I) -> bool {
let ident = ident.into();
if let Some((i, _)) = self
.models
.iter()
.enumerate()
.find(|(_, m)| m.ident == ident)
{
self.models.remove(i);
true
} else {
false
}
}
    /// Parse text and return the input for synthesize
    ///
    /// Returns `(bert_features, phones, tones, lang_ids)` produced by
    /// `tts_util::parse_text_blocking` with this holder's jtalk/tokenizer and
    /// a closure running the shared BERT session.
    ///
    /// # Note
    /// This function is for low-level usage, use `easy_synthesize` for high-level usage.
    #[allow(clippy::type_complexity)]
    pub fn parse_text(
        &mut self,
        text: &str,
    ) -> Result<(Array2<f32>, Array1<i64>, Array1<i64>, Array1<i64>)> {
        // Field borrows are disjoint: jtalk/tokenizer shared, bert mutable.
        crate::tts_util::parse_text_blocking(
            text,
            &self.jtalk,
            &self.tokenizer,
            |token_ids, attention_masks| {
                crate::bert::predict(&mut self.bert, token_ids, attention_masks)
            },
        )
    }
fn find_model<I: Into<TTSIdent>>(&mut self, ident: I) -> Result<&mut TTSModel> {
let ident = ident.into();
self.models
.iter_mut()
.find(|m| m.ident == ident)
.ok_or(Error::ModelNotFoundError(ident.to_string()))
}
    /// Ensure the VITS2 session for `ident` is resident, reloading it from the
    /// retained bytes if it was evicted; returns `Ok(true)` once loaded.
    fn find_and_load_model<I: Into<TTSIdent>>(&mut self, ident: I) -> Result<bool> {
        let ident = ident.into();
        let (bytes, style_vectors) = {
            let model = self
                .models
                .iter()
                .find(|m| m.ident == ident)
                .ok_or(Error::ModelNotFoundError(ident.to_string()))?;
            if model.vits2.is_some() {
                // Session already resident; nothing to do.
                return Ok(true);
            }
            // NOTE(review): `bytes` is only Some when `max_loaded_models` was
            // set at load time — this unwrap assumes eviction mode; confirm.
            (model.bytes.clone().unwrap(), model.style_vectors.clone())
        };
        // Drop the evicted entry, then rebuild it with a live session.
        self.unload(ident.clone());
        let s = model::load_model(&bytes, false)?;
        if let Some(max) = self.max_loaded_models {
            if self.models.iter().filter(|x| x.vits2.is_some()).count() >= max {
                // Over budget: evict the front-most model to make room.
                self.unload(self.models.first().unwrap().ident.clone());
            }
        }
        self.models.push(TTSModel {
            bytes: Some(bytes.to_vec()),
            vits2: Some(s),
            style_vectors,
            ident: ident.clone(),
        });
        // Re-resolve to verify the freshly pushed model is visible and loaded.
        let model = self
            .models
            .iter()
            .find(|m| m.ident == ident)
            .ok_or(Error::ModelNotFoundError(ident.to_string()))?;
        if model.vits2.is_some() {
            return Ok(true);
        }
        Err(Error::ModelNotFoundError(ident.to_string()))
    }
/// Get style vector by style id and weight
///
/// # Note
/// This function is for low-level usage, use `easy_synthesize` for high-level usage.
pub fn get_style_vector<I: Into<TTSIdent>>(
&mut self,
ident: I,
style_id: i32,
weight: f32,
) -> Result<Array1<f32>> {
style::get_style_vector(&self.find_model(ident)?.style_vectors, style_id, weight)
}
    /// Synthesize text to audio
    ///
    /// Ensures the model is resident, builds the blended style vector, then
    /// renders the text (optionally per newline-separated chunk) and returns
    /// WAV bytes.
    ///
    /// # Examples
    ///
    /// ```rs
    /// let audio = tts_holder.easy_synthesize("tsukuyomi", "こんにちは", 0, SynthesizeOptions::default())?;
    /// ```
    pub fn easy_synthesize<I: Into<TTSIdent> + Copy>(
        &mut self,
        ident: I,
        text: &str,
        style_id: i32,
        options: SynthesizeOptions,
    ) -> Result<Vec<u8>> {
        // Reload the VITS2 session if it was evicted under `max_loaded_models`.
        self.find_and_load_model(ident)?;
        let style_vector = self.get_style_vector(ident, style_id, options.style_weight)?;
        let audio_array = if options.split_sentences {
            // Render each non-empty line separately and join with silence.
            let texts: Vec<&str> = text.split('\n').collect();
            let mut audios = vec![];
            for (i, t) in texts.iter().enumerate() {
                if t.is_empty() {
                    continue;
                }
                let (bert_ori, phones, tones, lang_ids) = self.parse_text(t)?;
                let vits2 = self
                    .find_model(ident)?
                    .vits2
                    .as_mut()
                    .ok_or(Error::ModelNotFoundError(ident.into().to_string()))?;
                let audio = model::synthesize(
                    vits2,
                    bert_ori.to_owned(),
                    phones,
                    tones,
                    lang_ids,
                    style_vector.clone(),
                    options.sdp_ratio,
                    options.length_scale,
                )?;
                // NOTE(review): `audio` is not used again, so this clone looks avoidable.
                audios.push(audio.clone());
                if i != texts.len() - 1 {
                    // NOTE(review): 22050 zero samples between chunks — at the
                    // 44100 Hz rate `tts_util::array_to_vec` writes, that is
                    // 0.5 s of silence; confirm the intended duration.
                    audios.push(Array3::zeros((1, 1, 22050)));
                }
            }
            // Join chunks along the time (last) axis.
            concatenate(
                Axis(2),
                &audios.iter().map(|x| x.view()).collect::<Vec<_>>(),
            )?
        } else {
            let (bert_ori, phones, tones, lang_ids) = self.parse_text(text)?;
            let vits2 = self
                .find_model(ident)?
                .vits2
                .as_mut()
                .ok_or(Error::ModelNotFoundError(ident.into().to_string()))?;
            model::synthesize(
                vits2,
                bert_ori.to_owned(),
                phones,
                tones,
                lang_ids,
                style_vector,
                options.sdp_ratio,
                options.length_scale,
            )?
        };
        // Encode the raw float audio as WAV bytes.
        tts_util::array_to_vec(audio_array)
    }
/// Synthesize text to audio
///
/// # Note
/// This function is for low-level usage, use `easy_synthesize` for high-level usage.
#[allow(clippy::too_many_arguments)]
pub fn synthesize<I: Into<TTSIdent> + Copy>(
&mut self,
ident: I,
bert_ori: Array2<f32>,
phones: Array1<i64>,
tones: Array1<i64>,
lang_ids: Array1<i64>,
style_vector: Array1<f32>,
sdp_ratio: f32,
length_scale: f32,
) -> Result<Vec<u8>> {
self.find_and_load_model(ident)?;
let vits2 = self
.find_model(ident)?
.vits2
.as_mut()
.ok_or(Error::ModelNotFoundError(ident.into().to_string()))?;
let audio_array = model::synthesize(
vits2,
bert_ori.to_owned(),
phones,
tones,
lang_ids,
style_vector,
sdp_ratio,
length_scale,
)?;
tts_util::array_to_vec(audio_array)
}
}
/// Synthesize options
///
/// # Fields
/// - `sdp_ratio`: SDP ratio
/// - `length_scale`: Length scale
/// - `style_weight`: Style weight
/// - `split_sentences`: Split sentences
// Improvement: derive Debug/Clone/PartialEq so a plain-data options struct
// can be logged, duplicated and compared by callers.
#[derive(Debug, Clone, PartialEq)]
pub struct SynthesizeOptions {
    pub sdp_ratio: f32,
    pub length_scale: f32,
    pub style_weight: f32,
    pub split_sentences: bool,
}

impl Default for SynthesizeOptions {
    fn default() -> Self {
        Self {
            sdp_ratio: 0.0,
            length_scale: 1.0,
            style_weight: 1.0,
            split_sentences: true,
        }
    }
}

View File

@@ -0,0 +1,180 @@
use std::io::Cursor;
use crate::error::Result;
use crate::{jtalk, nlp, norm, tokenizer, utils};
use hound::{SampleFormat, WavSpec, WavWriter};
use ndarray::{concatenate, s, Array, Array1, Array2, Array3, Axis};
use tokenizers::Tokenizer;
/// Parse text and return the input for synthesize
///
/// Async variant: normalizes the text, runs g2p and tokenization, awaits the
/// supplied BERT predictor, then expands character-level BERT features to
/// phone level. Returns `(bert_features, phones, tones, lang_ids)`.
///
/// # Note
/// This function is for low-level usage, use `easy_synthesize` for high-level usage.
#[allow(clippy::type_complexity)]
pub async fn parse_text(
    text: &str,
    jtalk: &jtalk::JTalk,
    tokenizer: &Tokenizer,
    bert_predict: impl FnOnce(
        Vec<i64>,
        Vec<i64>,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<ndarray::Array2<f32>>>>,
    >,
) -> Result<(Array2<f32>, Array1<i64>, Array1<i64>, Array1<i64>)> {
    // Normalize (numbers to words, punctuation mapping) before g2p.
    let text = jtalk.num2word(text)?;
    let normalized_text = norm::normalize_text(&text);
    let process = jtalk.process_text(&normalized_text)?;
    let (phones, tones, mut word2ph) = process.g2p()?;
    let (phones, tones, lang_ids) = nlp::cleaned_text_to_sequence(phones, tones);
    // Intersperse a blank (id 0) around every element.
    let phones = utils::intersperse(&phones, 0);
    let tones = utils::intersperse(&tones, 0);
    let lang_ids = utils::intersperse(&lang_ids, 0);
    // Each character now maps to twice as many phones, plus one extra blank.
    for item in &mut word2ph {
        *item *= 2;
    }
    word2ph[0] += 1;
    let text = {
        let (seq_text, _) = process.text_to_seq_kata()?;
        seq_text.join("")
    };
    let (token_ids, attention_masks) = tokenizer::tokenize(&text, tokenizer)?;
    let bert_content = bert_predict(token_ids, attention_masks).await?;
    // NOTE(review): the condition compares against `text` but the failure
    // message prints `normalized_text`'s length — the diagnostic is
    // misleading when it fires; likely should print `text.chars().count()`.
    assert!(
        word2ph.len() == text.chars().count() + 2,
        "{} {}",
        word2ph.len(),
        normalized_text.chars().count()
    );
    // Expand character-level BERT rows to phone level: repeat row i word2ph[i] times.
    let mut phone_level_feature = vec![];
    for (i, reps) in word2ph.iter().enumerate() {
        let repeat_feature = {
            // reps_cols is always 1, so the inner loop copies the row reps times.
            let (reps_rows, reps_cols) = (*reps, 1);
            let arr_len = bert_content.slice(s![i, ..]).len();
            let mut results: Array2<f32> = Array::zeros((reps_rows as usize, arr_len * reps_cols));
            for j in 0..reps_rows {
                for k in 0..reps_cols {
                    let mut view = results.slice_mut(s![j, k * arr_len..(k + 1) * arr_len]);
                    view.assign(&bert_content.slice(s![i, ..]));
                }
            }
            results
        };
        phone_level_feature.push(repeat_feature);
    }
    let phone_level_feature = concatenate(
        Axis(0),
        &phone_level_feature
            .iter()
            .map(|x| x.view())
            .collect::<Vec<_>>(),
    )?;
    // Transpose so phones become the trailing axis before handing to the model.
    let bert_ori = phone_level_feature.t();
    Ok((
        bert_ori.to_owned(),
        phones.into(),
        tones.into(),
        lang_ids.into(),
    ))
}
/// Parse text and return the input for synthesize
///
/// Normalizes the text, runs g2p and tokenization, calls the supplied BERT
/// predictor, then expands the character-level BERT features to phone level.
/// Returns `(bert_features, phones, tones, lang_ids)`.
///
/// # Note
/// This function is for low-level usage, use `easy_synthesize` for high-level usage.
///
/// # Panics
/// Panics when `word2ph` no longer matches the tokenized text length.
#[allow(clippy::type_complexity)]
pub fn parse_text_blocking(
    text: &str,
    jtalk: &jtalk::JTalk,
    tokenizer: &Tokenizer,
    bert_predict: impl FnOnce(Vec<i64>, Vec<i64>) -> Result<ndarray::Array2<f32>>,
) -> Result<(Array2<f32>, Array1<i64>, Array1<i64>, Array1<i64>)> {
    // Normalize (numbers to words, punctuation mapping) before g2p.
    let text = jtalk.num2word(text)?;
    let normalized_text = norm::normalize_text(&text);
    let process = jtalk.process_text(&normalized_text)?;
    let (phones, tones, mut word2ph) = process.g2p()?;
    let (phones, tones, lang_ids) = nlp::cleaned_text_to_sequence(phones, tones);
    // Intersperse a blank (id 0) around every element.
    let phones = utils::intersperse(&phones, 0);
    let tones = utils::intersperse(&tones, 0);
    let lang_ids = utils::intersperse(&lang_ids, 0);
    // Each character now maps to twice as many phones, plus one extra blank.
    for item in &mut word2ph {
        *item *= 2;
    }
    word2ph[0] += 1;
    let text = {
        let (seq_text, _) = process.text_to_seq_kata()?;
        seq_text.join("")
    };
    let (token_ids, attention_masks) = tokenizer::tokenize(&text, tokenizer)?;
    let bert_content = bert_predict(token_ids, attention_masks)?;
    // word2ph must cover every character plus the two special tokens added by
    // `tokenize`. Fix: the failure message now prints the value actually
    // compared (the original printed `normalized_text`'s length instead).
    assert!(
        word2ph.len() == text.chars().count() + 2,
        "{} {}",
        word2ph.len(),
        text.chars().count() + 2
    );
    // Expand character-level BERT rows to phone level: repeat row i
    // word2ph[i] times. (The original carried an always-1 column-repeat
    // factor; it has been dropped with no behavior change.)
    let mut phone_level_feature = vec![];
    for (i, reps) in word2ph.iter().enumerate() {
        let row = bert_content.slice(s![i, ..]);
        let mut repeated: Array2<f32> = Array::zeros((*reps as usize, row.len()));
        for mut out_row in repeated.outer_iter_mut() {
            out_row.assign(&row);
        }
        phone_level_feature.push(repeated);
    }
    let phone_level_feature = concatenate(
        Axis(0),
        &phone_level_feature
            .iter()
            .map(|x| x.view())
            .collect::<Vec<_>>(),
    )?;
    // Transpose so phones become the trailing axis before handing to the model.
    let bert_ori = phone_level_feature.t();
    Ok((
        bert_ori.to_owned(),
        phones.into(),
        tones.into(),
        lang_ids.into(),
    ))
}
/// Encode a 3-D float tensor as a mono 32-bit float WAV (44.1 kHz) and return
/// the complete file bytes. Only index 0 of the middle axis is written, with
/// every outer-axis entry concatenated back-to-back.
pub fn array_to_vec(audio_array: Array3<f32>) -> Result<Vec<u8>> {
    let spec = WavSpec {
        channels: 1,
        sample_rate: 44100,
        bits_per_sample: 32,
        sample_format: SampleFormat::Float,
    };
    let mut buffer = Cursor::new(Vec::new());
    let mut writer = WavWriter::new(&mut buffer, spec)?;
    for batch in 0..audio_array.shape()[0] {
        for &sample in audio_array.slice(s![batch, 0, ..]).iter() {
            writer.write_sample(sample)?;
        }
    }
    // finalize() consumes the writer and flushes the WAV header.
    writer.finalize()?;
    Ok(buffer.into_inner())
}

View File

@@ -0,0 +1,12 @@
/// Build a vector of length `2 * slice.len() + 1` in which the elements of
/// `slice` occupy the even indices (0, 2, ...) and every remaining slot holds
/// `sep` — so the result always ends with two separators.
///
/// NOTE(review): the common Bert-VITS2 reference implementation fills the odd
/// indices (`result[1::2] = lst`, blank first); this version places slice
/// elements at even indices instead — confirm it matches the model's
/// training-time layout.
pub fn intersperse<T>(slice: &[T], sep: T) -> Vec<T>
where
    T: Clone,
{
    let mut out = Vec::with_capacity(slice.len() * 2 + 1);
    for item in slice {
        out.push(item.clone());
        out.push(sep.clone());
    }
    // Trailing extra separator keeps the length at 2n + 1.
    out.push(sep.clone());
    out
}

View File

@@ -0,0 +1,15 @@
[package]
name = "sbv2_wasm"
version = "0.1.0"
edition = "2021"
[lib]
crate-type = ["cdylib", "rlib"]
[dependencies]
wasm-bindgen = "0.2.93"
sbv2_core = { path = "../sbv2_core", default-features = false, features = ["no_std"] }
once_cell.workspace = true
js-sys = "0.3.70"
ndarray.workspace = true
wasm-bindgen-futures = "0.4.43"

View File

@@ -0,0 +1,2 @@
# StyleBertVITS2 wasm
refer to https://github.com/tuna2134/sbv2-api

View File

@@ -0,0 +1,31 @@
{
"$schema": "https://biomejs.dev/schemas/1.9.2/schema.json",
"vcs": {
"enabled": false,
"clientKind": "git",
"useIgnoreFile": false
},
"files": {
"ignoreUnknown": false,
"ignore": []
},
"formatter": {
"enabled": true,
"indentStyle": "tab",
"ignore": ["dist/", "pkg/"]
},
"organizeImports": {
"enabled": true
},
"linter": {
"enabled": true,
"rules": {
"recommended": true
}
},
"javascript": {
"formatter": {
"quoteStyle": "double"
}
}
}

5
crates/sbv2_wasm/build.sh Executable file
View File

@@ -0,0 +1,5 @@
# Build the wasm package, shrink the binary, and stage it into dist/.
wasm-pack build --target web sbv2_wasm
# Optimize and strip the generated wasm in place.
wasm-opt -O3 -o ./sbv2_wasm/pkg/sbv2_wasm_bg.wasm ./sbv2_wasm/pkg/sbv2_wasm_bg.wasm
wasm-strip ./sbv2_wasm/pkg/sbv2_wasm_bg.wasm
mkdir -p ./sbv2_wasm/dist
# NOTE(review): this copies from ./sbv2_wasm/sbv2_wasm/pkg/ while the steps
# above write to ./sbv2_wasm/pkg/ — confirm the doubled path segment is intended.
cp ./sbv2_wasm/sbv2_wasm/pkg/sbv2_wasm_bg.wasm ./sbv2_wasm/dist/sbv2_wasm_bg.wasm

View File

@@ -0,0 +1,51 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Style Bert VITS2 Web</title>
<!-- NOTE(review): jsDelivr ESM URLs normally end in "/+esm"; here
     "sbv2@0.1.1+esm" reads as a version suffix — confirm this URL resolves. -->
<script type="importmap">
{
"imports": {
"onnxruntime-web": "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.19.2/dist/ort.all.min.mjs",
"sbv2": "https://cdn.jsdelivr.net/npm/sbv2@0.1.1+esm"
}
}
</script>
<!-- Bootstrap: initialize the wasm module and shared BERT assets, then hand
     the ready ModelHolder to window.onready. -->
<script type="module" async defer>
import { ModelHolder } from "sbv2";
await ModelHolder.globalInit(
await (
await fetch("https://esm.sh/sbv2@0.1.1/dist/sbv2_wasm_bg.wasm", { cache: "force-cache" })
).arrayBuffer(),
);
const holder = await ModelHolder.create(
await (
await fetch("/models/tokenizer.json", { cache: "force-cache" })
).text(),
await (
await fetch("/models/deberta.onnx", { cache: "force-cache" })
).arrayBuffer(),
);
if (typeof window.onready == "function") {
window.onready(holder);
}
</script>
<!-- Demo: load one voice model and synthesize a greeting to the console. -->
<script type="module" async defer>
window.onready = async function (holder) {
await holder.load(
"amitaro",
await (await fetch("/models/amitaro.sbv2")).arrayBuffer(),
);
const wave = await holder.synthesize("amitaro", "おはよう");
console.log(wave);
};
</script>
</head>
<body>
<div id="root"></div>
</body>
</html>

View File

@@ -0,0 +1,11 @@
import { ModelHolder } from "./dist/index.js";
import fs from "node:fs/promises";

// Fix: globalInit returns a promise; await it before ModelHolder.create so
// initialization cannot race model construction (the browser demo in
// index.html already awaits it).
await ModelHolder.globalInit(await fs.readFile("./dist/sbv2_wasm_bg.wasm"));
const holder = await ModelHolder.create(
	(await fs.readFile("../models/tokenizer.json")).toString("utf-8"),
	await fs.readFile("../models/deberta.onnx"),
);
// NOTE(review): the ident "tsukuyomi" is loaded from iroha2.sbv2 — confirm intended.
await holder.load("tsukuyomi", await fs.readFile("../models/iroha2.sbv2"));
await fs.writeFile("out.wav", await holder.synthesize("tsukuyomi", "おはよう"));
holder.unload("tsukuyomi");

View File

@@ -0,0 +1,29 @@
{
"name": "sbv2",
"version": "0.1.1",
"description": "Style Bert VITS2 wasm",
"main": "dist/index.js",
"types": "dist/index.d.ts",
"type": "module",
"scripts": {
"build": "tsc && esbuild src-js/index.ts --outfile=dist/index.js --minify --format=esm --bundle --external:onnxruntime-web",
"format": "biome format --write ."
},
"keywords": [],
"author": "tuna2134",
"license": "MIT",
"devDependencies": {
"@biomejs/biome": "^1.9.4",
"@types/node": "^22.13.5",
"esbuild": "^0.25.0",
"typescript": "^5.7.3"
},
"dependencies": {
"onnxruntime-web": "^1.20.1"
},
"files": [
"dist/*",
"package.json",
"README.md"
]
}

504
crates/sbv2_wasm/pnpm-lock.yaml generated Normal file
View File

@@ -0,0 +1,504 @@
lockfileVersion: '9.0'
settings:
autoInstallPeers: true
excludeLinksFromLockfile: false
importers:
.:
dependencies:
onnxruntime-web:
specifier: ^1.20.1
version: 1.20.1
devDependencies:
'@biomejs/biome':
specifier: ^1.9.4
version: 1.9.4
'@types/node':
specifier: ^22.13.5
version: 22.13.5
esbuild:
specifier: ^0.25.0
version: 0.25.0
typescript:
specifier: ^5.7.3
version: 5.7.3
packages:
'@biomejs/biome@1.9.4':
resolution: {integrity: sha512-1rkd7G70+o9KkTn5KLmDYXihGoTaIGO9PIIN2ZB7UJxFrWw04CZHPYiMRjYsaDvVV7hP1dYNRLxSANLaBFGpog==}
engines: {node: '>=14.21.3'}
hasBin: true
'@biomejs/cli-darwin-arm64@1.9.4':
resolution: {integrity: sha512-bFBsPWrNvkdKrNCYeAp+xo2HecOGPAy9WyNyB/jKnnedgzl4W4Hb9ZMzYNbf8dMCGmUdSavlYHiR01QaYR58cw==}
engines: {node: '>=14.21.3'}
cpu: [arm64]
os: [darwin]
'@biomejs/cli-darwin-x64@1.9.4':
resolution: {integrity: sha512-ngYBh/+bEedqkSevPVhLP4QfVPCpb+4BBe2p7Xs32dBgs7rh9nY2AIYUL6BgLw1JVXV8GlpKmb/hNiuIxfPfZg==}
engines: {node: '>=14.21.3'}
cpu: [x64]
os: [darwin]
'@biomejs/cli-linux-arm64-musl@1.9.4':
resolution: {integrity: sha512-v665Ct9WCRjGa8+kTr0CzApU0+XXtRgwmzIf1SeKSGAv+2scAlW6JR5PMFo6FzqqZ64Po79cKODKf3/AAmECqA==}
engines: {node: '>=14.21.3'}
cpu: [arm64]
os: [linux]
'@biomejs/cli-linux-arm64@1.9.4':
resolution: {integrity: sha512-fJIW0+LYujdjUgJJuwesP4EjIBl/N/TcOX3IvIHJQNsAqvV2CHIogsmA94BPG6jZATS4Hi+xv4SkBBQSt1N4/g==}
engines: {node: '>=14.21.3'}
cpu: [arm64]
os: [linux]
'@biomejs/cli-linux-x64-musl@1.9.4':
resolution: {integrity: sha512-gEhi/jSBhZ2m6wjV530Yy8+fNqG8PAinM3oV7CyO+6c3CEh16Eizm21uHVsyVBEB6RIM8JHIl6AGYCv6Q6Q9Tg==}
engines: {node: '>=14.21.3'}
cpu: [x64]
os: [linux]
'@biomejs/cli-linux-x64@1.9.4':
resolution: {integrity: sha512-lRCJv/Vi3Vlwmbd6K+oQ0KhLHMAysN8lXoCI7XeHlxaajk06u7G+UsFSO01NAs5iYuWKmVZjmiOzJ0OJmGsMwg==}
engines: {node: '>=14.21.3'}
cpu: [x64]
os: [linux]
'@biomejs/cli-win32-arm64@1.9.4':
resolution: {integrity: sha512-tlbhLk+WXZmgwoIKwHIHEBZUwxml7bRJgk0X2sPyNR3S93cdRq6XulAZRQJ17FYGGzWne0fgrXBKpl7l4M87Hg==}
engines: {node: '>=14.21.3'}
cpu: [arm64]
os: [win32]
'@biomejs/cli-win32-x64@1.9.4':
resolution: {integrity: sha512-8Y5wMhVIPaWe6jw2H+KlEm4wP/f7EW3810ZLmDlrEEy5KvBsb9ECEfu/kMWD484ijfQ8+nIi0giMgu9g1UAuuA==}
engines: {node: '>=14.21.3'}
cpu: [x64]
os: [win32]
'@esbuild/aix-ppc64@0.25.0':
resolution: {integrity: sha512-O7vun9Sf8DFjH2UtqK8Ku3LkquL9SZL8OLY1T5NZkA34+wG3OQF7cl4Ql8vdNzM6fzBbYfLaiRLIOZ+2FOCgBQ==}
engines: {node: '>=18'}
cpu: [ppc64]
os: [aix]
'@esbuild/android-arm64@0.25.0':
resolution: {integrity: sha512-grvv8WncGjDSyUBjN9yHXNt+cq0snxXbDxy5pJtzMKGmmpPxeAmAhWxXI+01lU5rwZomDgD3kJwulEnhTRUd6g==}
engines: {node: '>=18'}
cpu: [arm64]
os: [android]
'@esbuild/android-arm@0.25.0':
resolution: {integrity: sha512-PTyWCYYiU0+1eJKmw21lWtC+d08JDZPQ5g+kFyxP0V+es6VPPSUhM6zk8iImp2jbV6GwjX4pap0JFbUQN65X1g==}
engines: {node: '>=18'}
cpu: [arm]
os: [android]
'@esbuild/android-x64@0.25.0':
resolution: {integrity: sha512-m/ix7SfKG5buCnxasr52+LI78SQ+wgdENi9CqyCXwjVR2X4Jkz+BpC3le3AoBPYTC9NHklwngVXvbJ9/Akhrfg==}
engines: {node: '>=18'}
cpu: [x64]
os: [android]
'@esbuild/darwin-arm64@0.25.0':
resolution: {integrity: sha512-mVwdUb5SRkPayVadIOI78K7aAnPamoeFR2bT5nszFUZ9P8UpK4ratOdYbZZXYSqPKMHfS1wdHCJk1P1EZpRdvw==}
engines: {node: '>=18'}
cpu: [arm64]
os: [darwin]
'@esbuild/darwin-x64@0.25.0':
resolution: {integrity: sha512-DgDaYsPWFTS4S3nWpFcMn/33ZZwAAeAFKNHNa1QN0rI4pUjgqf0f7ONmXf6d22tqTY+H9FNdgeaAa+YIFUn2Rg==}
engines: {node: '>=18'}
cpu: [x64]
os: [darwin]
'@esbuild/freebsd-arm64@0.25.0':
resolution: {integrity: sha512-VN4ocxy6dxefN1MepBx/iD1dH5K8qNtNe227I0mnTRjry8tj5MRk4zprLEdG8WPyAPb93/e4pSgi1SoHdgOa4w==}
engines: {node: '>=18'}
cpu: [arm64]
os: [freebsd]
'@esbuild/freebsd-x64@0.25.0':
resolution: {integrity: sha512-mrSgt7lCh07FY+hDD1TxiTyIHyttn6vnjesnPoVDNmDfOmggTLXRv8Id5fNZey1gl/V2dyVK1VXXqVsQIiAk+A==}
engines: {node: '>=18'}
cpu: [x64]
os: [freebsd]
'@esbuild/linux-arm64@0.25.0':
resolution: {integrity: sha512-9QAQjTWNDM/Vk2bgBl17yWuZxZNQIF0OUUuPZRKoDtqF2k4EtYbpyiG5/Dk7nqeK6kIJWPYldkOcBqjXjrUlmg==}
engines: {node: '>=18'}
cpu: [arm64]
os: [linux]
'@esbuild/linux-arm@0.25.0':
resolution: {integrity: sha512-vkB3IYj2IDo3g9xX7HqhPYxVkNQe8qTK55fraQyTzTX/fxaDtXiEnavv9geOsonh2Fd2RMB+i5cbhu2zMNWJwg==}
engines: {node: '>=18'}
cpu: [arm]
os: [linux]
'@esbuild/linux-ia32@0.25.0':
resolution: {integrity: sha512-43ET5bHbphBegyeqLb7I1eYn2P/JYGNmzzdidq/w0T8E2SsYL1U6un2NFROFRg1JZLTzdCoRomg8Rvf9M6W6Gg==}
engines: {node: '>=18'}
cpu: [ia32]
os: [linux]
'@esbuild/linux-loong64@0.25.0':
resolution: {integrity: sha512-fC95c/xyNFueMhClxJmeRIj2yrSMdDfmqJnyOY4ZqsALkDrrKJfIg5NTMSzVBr5YW1jf+l7/cndBfP3MSDpoHw==}
engines: {node: '>=18'}
cpu: [loong64]
os: [linux]
'@esbuild/linux-mips64el@0.25.0':
resolution: {integrity: sha512-nkAMFju7KDW73T1DdH7glcyIptm95a7Le8irTQNO/qtkoyypZAnjchQgooFUDQhNAy4iu08N79W4T4pMBwhPwQ==}
engines: {node: '>=18'}
cpu: [mips64el]
os: [linux]
'@esbuild/linux-ppc64@0.25.0':
resolution: {integrity: sha512-NhyOejdhRGS8Iwv+KKR2zTq2PpysF9XqY+Zk77vQHqNbo/PwZCzB5/h7VGuREZm1fixhs4Q/qWRSi5zmAiO4Fw==}
engines: {node: '>=18'}
cpu: [ppc64]
os: [linux]
'@esbuild/linux-riscv64@0.25.0':
resolution: {integrity: sha512-5S/rbP5OY+GHLC5qXp1y/Mx//e92L1YDqkiBbO9TQOvuFXM+iDqUNG5XopAnXoRH3FjIUDkeGcY1cgNvnXp/kA==}
engines: {node: '>=18'}
cpu: [riscv64]
os: [linux]
'@esbuild/linux-s390x@0.25.0':
resolution: {integrity: sha512-XM2BFsEBz0Fw37V0zU4CXfcfuACMrppsMFKdYY2WuTS3yi8O1nFOhil/xhKTmE1nPmVyvQJjJivgDT+xh8pXJA==}
engines: {node: '>=18'}
cpu: [s390x]
os: [linux]
'@esbuild/linux-x64@0.25.0':
resolution: {integrity: sha512-9yl91rHw/cpwMCNytUDxwj2XjFpxML0y9HAOH9pNVQDpQrBxHy01Dx+vaMu0N1CKa/RzBD2hB4u//nfc+Sd3Cw==}
engines: {node: '>=18'}
cpu: [x64]
os: [linux]
'@esbuild/netbsd-arm64@0.25.0':
resolution: {integrity: sha512-RuG4PSMPFfrkH6UwCAqBzauBWTygTvb1nxWasEJooGSJ/NwRw7b2HOwyRTQIU97Hq37l3npXoZGYMy3b3xYvPw==}
engines: {node: '>=18'}
cpu: [arm64]
os: [netbsd]
'@esbuild/netbsd-x64@0.25.0':
resolution: {integrity: sha512-jl+qisSB5jk01N5f7sPCsBENCOlPiS/xptD5yxOx2oqQfyourJwIKLRA2yqWdifj3owQZCL2sn6o08dBzZGQzA==}
engines: {node: '>=18'}
cpu: [x64]
os: [netbsd]
'@esbuild/openbsd-arm64@0.25.0':
resolution: {integrity: sha512-21sUNbq2r84YE+SJDfaQRvdgznTD8Xc0oc3p3iW/a1EVWeNj/SdUCbm5U0itZPQYRuRTW20fPMWMpcrciH2EJw==}
engines: {node: '>=18'}
cpu: [arm64]
os: [openbsd]
'@esbuild/openbsd-x64@0.25.0':
resolution: {integrity: sha512-2gwwriSMPcCFRlPlKx3zLQhfN/2WjJ2NSlg5TKLQOJdV0mSxIcYNTMhk3H3ulL/cak+Xj0lY1Ym9ysDV1igceg==}
engines: {node: '>=18'}
cpu: [x64]
os: [openbsd]
'@esbuild/sunos-x64@0.25.0':
resolution: {integrity: sha512-bxI7ThgLzPrPz484/S9jLlvUAHYMzy6I0XiU1ZMeAEOBcS0VePBFxh1JjTQt3Xiat5b6Oh4x7UC7IwKQKIJRIg==}
engines: {node: '>=18'}
cpu: [x64]
os: [sunos]
'@esbuild/win32-arm64@0.25.0':
resolution: {integrity: sha512-ZUAc2YK6JW89xTbXvftxdnYy3m4iHIkDtK3CLce8wg8M2L+YZhIvO1DKpxrd0Yr59AeNNkTiic9YLf6FTtXWMw==}
engines: {node: '>=18'}
cpu: [arm64]
os: [win32]
'@esbuild/win32-ia32@0.25.0':
resolution: {integrity: sha512-eSNxISBu8XweVEWG31/JzjkIGbGIJN/TrRoiSVZwZ6pkC6VX4Im/WV2cz559/TXLcYbcrDN8JtKgd9DJVIo8GA==}
engines: {node: '>=18'}
cpu: [ia32]
os: [win32]
'@esbuild/win32-x64@0.25.0':
resolution: {integrity: sha512-ZENoHJBxA20C2zFzh6AI4fT6RraMzjYw4xKWemRTRmRVtN9c5DcH9r/f2ihEkMjOW5eGgrwCslG/+Y/3bL+DHQ==}
engines: {node: '>=18'}
cpu: [x64]
os: [win32]
'@protobufjs/aspromise@1.1.2':
resolution: {integrity: sha512-j+gKExEuLmKwvz3OgROXtrJ2UG2x8Ch2YZUxahh+s1F2HZ+wAceUNLkvy6zKCPVRkU++ZWQrdxsUeQXmcg4uoQ==}
'@protobufjs/base64@1.1.2':
resolution: {integrity: sha512-AZkcAA5vnN/v4PDqKyMR5lx7hZttPDgClv83E//FMNhR2TMcLUhfRUBHCmSl0oi9zMgDDqRUJkSxO3wm85+XLg==}
'@protobufjs/codegen@2.0.4':
resolution: {integrity: sha512-YyFaikqM5sH0ziFZCN3xDC7zeGaB/d0IUb9CATugHWbd1FRFwWwt4ld4OYMPWu5a3Xe01mGAULCdqhMlPl29Jg==}
'@protobufjs/eventemitter@1.1.0':
resolution: {integrity: sha512-j9ednRT81vYJ9OfVuXG6ERSTdEL1xVsNgqpkxMsbIabzSo3goCjDIveeGv5d03om39ML71RdmrGNjG5SReBP/Q==}
'@protobufjs/fetch@1.1.0':
resolution: {integrity: sha512-lljVXpqXebpsijW71PZaCYeIcE5on1w5DlQy5WH6GLbFryLUrBD4932W/E2BSpfRJWseIL4v/KPgBFxDOIdKpQ==}
'@protobufjs/float@1.0.2':
resolution: {integrity: sha512-Ddb+kVXlXst9d+R9PfTIxh1EdNkgoRe5tOX6t01f1lYWOvJnSPDBlG241QLzcyPdoNTsblLUdujGSE4RzrTZGQ==}
'@protobufjs/inquire@1.1.0':
resolution: {integrity: sha512-kdSefcPdruJiFMVSbn801t4vFK7KB/5gd2fYvrxhuJYg8ILrmn9SKSX2tZdV6V+ksulWqS7aXjBcRXl3wHoD9Q==}
'@protobufjs/path@1.1.2':
resolution: {integrity: sha512-6JOcJ5Tm08dOHAbdR3GrvP+yUUfkjG5ePsHYczMFLq3ZmMkAD98cDgcT2iA1lJ9NVwFd4tH/iSSoe44YWkltEA==}
'@protobufjs/pool@1.1.0':
resolution: {integrity: sha512-0kELaGSIDBKvcgS4zkjz1PeddatrjYcmMWOlAuAPwAeccUrPHdUqo/J6LiymHHEiJT5NrF1UVwxY14f+fy4WQw==}
'@protobufjs/utf8@1.1.0':
resolution: {integrity: sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==}
'@types/node@22.13.5':
resolution: {integrity: sha512-+lTU0PxZXn0Dr1NBtC7Y8cR21AJr87dLLU953CWA6pMxxv/UDc7jYAY90upcrie1nRcD6XNG5HOYEDtgW5TxAg==}
esbuild@0.25.0:
resolution: {integrity: sha512-BXq5mqc8ltbaN34cDqWuYKyNhX8D/Z0J1xdtdQ8UcIIIyJyz+ZMKUt58tF3SrZ85jcfN/PZYhjR5uDQAYNVbuw==}
engines: {node: '>=18'}
hasBin: true
flatbuffers@1.12.0:
resolution: {integrity: sha512-c7CZADjRcl6j0PlvFy0ZqXQ67qSEZfrVPynmnL+2zPc+NtMvrF8Y0QceMo7QqnSPc7+uWjUIAbvCQ5WIKlMVdQ==}
guid-typescript@1.0.9:
resolution: {integrity: sha512-Y8T4vYhEfwJOTbouREvG+3XDsjr8E3kIr7uf+JZ0BYloFsttiHU0WfvANVsR7TxNUJa/WpCnw/Ino/p+DeBhBQ==}
long@5.3.1:
resolution: {integrity: sha512-ka87Jz3gcx/I7Hal94xaN2tZEOPoUOEVftkQqZx2EeQRN7LGdfLlI3FvZ+7WDplm+vK2Urx9ULrvSowtdCieng==}
onnxruntime-common@1.20.1:
resolution: {integrity: sha512-YiU0s0IzYYC+gWvqD1HzLc46Du1sXpSiwzKb63PACIJr6LfL27VsXSXQvt68EzD3V0D5Bc0vyJTjmMxp0ylQiw==}
onnxruntime-web@1.20.1:
resolution: {integrity: sha512-TePF6XVpLL1rWVMIl5Y9ACBQcyCNFThZON/jgElNd9Txb73CIEGlklhYR3UEr1cp5r0rbGI6nDwwrs79g7WjoA==}
platform@1.3.6:
resolution: {integrity: sha512-fnWVljUchTro6RiCFvCXBbNhJc2NijN7oIQxbwsyL0buWJPG85v81ehlHI9fXrJsMNgTofEoWIQeClKpgxFLrg==}
protobufjs@7.4.0:
resolution: {integrity: sha512-mRUWCc3KUU4w1jU8sGxICXH/gNS94DvI1gxqDvBzhj1JpcsimQkYiOJfwsPUykUI5ZaspFbSgmBLER8IrQ3tqw==}
engines: {node: '>=12.0.0'}
typescript@5.7.3:
resolution: {integrity: sha512-84MVSjMEHP+FQRPy3pX9sTVV/INIex71s9TL2Gm5FG/WG1SqXeKyZ0k7/blY/4FdOzI12CBy1vGc4og/eus0fw==}
engines: {node: '>=14.17'}
hasBin: true
undici-types@6.20.0:
resolution: {integrity: sha512-Ny6QZ2Nju20vw1SRHe3d9jVu6gJ+4e3+MMpqu7pqE5HT6WsTSlce++GQmK5UXS8mzV8DSYHrQH+Xrf2jVcuKNg==}
snapshots:
'@biomejs/biome@1.9.4':
optionalDependencies:
'@biomejs/cli-darwin-arm64': 1.9.4
'@biomejs/cli-darwin-x64': 1.9.4
'@biomejs/cli-linux-arm64': 1.9.4
'@biomejs/cli-linux-arm64-musl': 1.9.4
'@biomejs/cli-linux-x64': 1.9.4
'@biomejs/cli-linux-x64-musl': 1.9.4
'@biomejs/cli-win32-arm64': 1.9.4
'@biomejs/cli-win32-x64': 1.9.4
'@biomejs/cli-darwin-arm64@1.9.4':
optional: true
'@biomejs/cli-darwin-x64@1.9.4':
optional: true
'@biomejs/cli-linux-arm64-musl@1.9.4':
optional: true
'@biomejs/cli-linux-arm64@1.9.4':
optional: true
'@biomejs/cli-linux-x64-musl@1.9.4':
optional: true
'@biomejs/cli-linux-x64@1.9.4':
optional: true
'@biomejs/cli-win32-arm64@1.9.4':
optional: true
'@biomejs/cli-win32-x64@1.9.4':
optional: true
'@esbuild/aix-ppc64@0.25.0':
optional: true
'@esbuild/android-arm64@0.25.0':
optional: true
'@esbuild/android-arm@0.25.0':
optional: true
'@esbuild/android-x64@0.25.0':
optional: true
'@esbuild/darwin-arm64@0.25.0':
optional: true
'@esbuild/darwin-x64@0.25.0':
optional: true
'@esbuild/freebsd-arm64@0.25.0':
optional: true
'@esbuild/freebsd-x64@0.25.0':
optional: true
'@esbuild/linux-arm64@0.25.0':
optional: true
'@esbuild/linux-arm@0.25.0':
optional: true
'@esbuild/linux-ia32@0.25.0':
optional: true
'@esbuild/linux-loong64@0.25.0':
optional: true
'@esbuild/linux-mips64el@0.25.0':
optional: true
'@esbuild/linux-ppc64@0.25.0':
optional: true
'@esbuild/linux-riscv64@0.25.0':
optional: true
'@esbuild/linux-s390x@0.25.0':
optional: true
'@esbuild/linux-x64@0.25.0':
optional: true
'@esbuild/netbsd-arm64@0.25.0':
optional: true
'@esbuild/netbsd-x64@0.25.0':
optional: true
'@esbuild/openbsd-arm64@0.25.0':
optional: true
'@esbuild/openbsd-x64@0.25.0':
optional: true
'@esbuild/sunos-x64@0.25.0':
optional: true
'@esbuild/win32-arm64@0.25.0':
optional: true
'@esbuild/win32-ia32@0.25.0':
optional: true
'@esbuild/win32-x64@0.25.0':
optional: true
'@protobufjs/aspromise@1.1.2': {}
'@protobufjs/base64@1.1.2': {}
'@protobufjs/codegen@2.0.4': {}
'@protobufjs/eventemitter@1.1.0': {}
'@protobufjs/fetch@1.1.0':
dependencies:
'@protobufjs/aspromise': 1.1.2
'@protobufjs/inquire': 1.1.0
'@protobufjs/float@1.0.2': {}
'@protobufjs/inquire@1.1.0': {}
'@protobufjs/path@1.1.2': {}
'@protobufjs/pool@1.1.0': {}
'@protobufjs/utf8@1.1.0': {}
'@types/node@22.13.5':
dependencies:
undici-types: 6.20.0
esbuild@0.25.0:
optionalDependencies:
'@esbuild/aix-ppc64': 0.25.0
'@esbuild/android-arm': 0.25.0
'@esbuild/android-arm64': 0.25.0
'@esbuild/android-x64': 0.25.0
'@esbuild/darwin-arm64': 0.25.0
'@esbuild/darwin-x64': 0.25.0
'@esbuild/freebsd-arm64': 0.25.0
'@esbuild/freebsd-x64': 0.25.0
'@esbuild/linux-arm': 0.25.0
'@esbuild/linux-arm64': 0.25.0
'@esbuild/linux-ia32': 0.25.0
'@esbuild/linux-loong64': 0.25.0
'@esbuild/linux-mips64el': 0.25.0
'@esbuild/linux-ppc64': 0.25.0
'@esbuild/linux-riscv64': 0.25.0
'@esbuild/linux-s390x': 0.25.0
'@esbuild/linux-x64': 0.25.0
'@esbuild/netbsd-arm64': 0.25.0
'@esbuild/netbsd-x64': 0.25.0
'@esbuild/openbsd-arm64': 0.25.0
'@esbuild/openbsd-x64': 0.25.0
'@esbuild/sunos-x64': 0.25.0
'@esbuild/win32-arm64': 0.25.0
'@esbuild/win32-ia32': 0.25.0
'@esbuild/win32-x64': 0.25.0
flatbuffers@1.12.0: {}
guid-typescript@1.0.9: {}
long@5.3.1: {}
onnxruntime-common@1.20.1: {}
onnxruntime-web@1.20.1:
dependencies:
flatbuffers: 1.12.0
guid-typescript: 1.0.9
long: 5.3.1
onnxruntime-common: 1.20.1
platform: 1.3.6
protobufjs: 7.4.0
platform@1.3.6: {}
protobufjs@7.4.0:
dependencies:
'@protobufjs/aspromise': 1.1.2
'@protobufjs/base64': 1.1.2
'@protobufjs/codegen': 2.0.4
'@protobufjs/eventemitter': 1.1.0
'@protobufjs/fetch': 1.1.0
'@protobufjs/float': 1.0.2
'@protobufjs/inquire': 1.1.0
'@protobufjs/path': 1.1.2
'@protobufjs/pool': 1.1.0
'@protobufjs/utf8': 1.1.0
'@types/node': 22.13.5
long: 5.3.1
typescript@5.7.3: {}
undici-types@6.20.0: {}

View File

@@ -0,0 +1,106 @@
import * as wasm from "../pkg/sbv2_wasm.js";
import { InferenceSession, Tensor } from "onnxruntime-web";
/**
 * Owns the tokenizer, the shared DeBERTa BERT session, and a registry of
 * loaded VITS2 voice models, wiring them into the wasm synthesis pipeline.
 */
export class ModelHolder {
  // Registry: model name -> [VITS2 session, style-vector table from its .sbv2 file].
  private models: Map<string, [InferenceSession, wasm.StyleVectorWrap]> =
    new Map();
  constructor(
    private tok: wasm.TokenizerWrap,
    private deberta: InferenceSession,
  ) {}
  // One-time wasm module initialization; must complete before create()/load().
  public static async globalInit(buf: ArrayBufferLike) {
    await wasm.default(buf);
  }
  // Builds a holder from tokenizer JSON text and the DeBERTa onnx bytes.
  public static async create(tok: string, deberta: ArrayBufferLike) {
    return new ModelHolder(
      wasm.load_tokenizer(tok),
      await InferenceSession.create(deberta, {
        // Providers in preference order; ort-web falls back down the list.
        executionProviders: ["webnn", "webgpu", "wasm", "cpu"],
        graphOptimizationLevel: "all",
      }),
    );
  }
  /**
   * Synthesizes `text` with the model registered under `name` and returns
   * whatever wasm.synthesize produces (the encoded audio bytes).
   * Throws if no model with that name has been loaded.
   */
  public async synthesize(
    name: string,
    text: string,
    style_id: number = 0,
    style_weight: number = 1.0,
    sdp_ratio: number = 0.4,
    speed: number = 1.0,
  ) {
    const mod = this.models.get(name);
    if (!mod) throw new Error(`No model named ${name}`);
    const [vits2, style] = mod;
    return wasm.synthesize(
      text,
      this.tok,
      // BERT callback: run DeBERTa on token ids + attention mask and hand
      // back [dims, flat data] for the wasm side to rebuild an ndarray.
      async (a: BigInt64Array, b: BigInt64Array) => {
        try {
          const res = (
            await this.deberta.run({
              input_ids: new Tensor("int64", a, [1, a.length]),
              attention_mask: new Tensor("int64", b, [1, b.length]),
            })
          )["output"];
          // NOTE(review): assumes the DeBERTa graph exposes an output tensor
          // named "output" — confirm against the exported onnx model.
          return [new Uint32Array(res.dims), await res.getData(true)];
        } catch (e) {
          console.warn(e);
          throw e;
        }
      },
      // VITS2 callback: run the voice model on bert features, phone ids,
      // tones, language ids, and the style vector.
      async (
        [a_shape, a_array]: any,
        b_d: any,
        c_d: any,
        d_d: any,
        e_d: any,
        f: number,
        g: number,
      ) => {
        try {
          // All inputs gain a leading batch dimension of 1.
          const a = new Tensor("float32", a_array, [1, ...a_shape]);
          const b = new Tensor("int64", b_d, [1, b_d.length]);
          const c = new Tensor("int64", c_d, [1, c_d.length]);
          const d = new Tensor("int64", d_d, [1, d_d.length]);
          const e = new Tensor("float32", e_d, [1, e_d.length]);
          const res = (
            await vits2.run({
              x_tst: b,
              x_tst_lengths: new Tensor("int64", [b_d.length]),
              sid: new Tensor("int64", [0]),
              tones: c,
              language: d,
              bert: a,
              style_vec: e,
              sdp_ratio: new Tensor("float32", [f]),
              length_scale: new Tensor("float32", [g]),
            })
          ).output;
          return [new Uint32Array(res.dims), await res.getData(true)];
        } catch (e) {
          console.warn(e);
          throw e;
        }
      },
      sdp_ratio,
      // The wasm side takes a length scale — the inverse of playback speed.
      1.0 / speed,
      style_id,
      style_weight,
      style,
    );
  }
  // Parses an .sbv2 file and registers its VITS2 session + style table under `name`.
  public async load(name: string, b: Uint8Array) {
    const [style, vits2_b] = wasm.load_sbv2file(b);
    const vits2 = await InferenceSession.create(vits2_b as Uint8Array, {
      executionProviders: ["webnn", "webgpu", "wasm", "cpu"],
      graphOptimizationLevel: "all",
    });
    this.models.set(name, [vits2, style]);
  }
  // Drops the named model from the registry; resolves false if it was absent.
  public async unload(name: string) {
    return this.models.delete(name);
  }
  // Iterator over the names of all currently loaded models.
  public modelList() {
    return this.models.keys();
  }
}

View File

@@ -0,0 +1,102 @@
/// Copies a Rust `Vec<u8>` into a freshly allocated JS `Uint8Array`.
pub fn vec8_to_array8(v: Vec<u8>) -> js_sys::Uint8Array {
    let out = js_sys::Uint8Array::new_with_length(v.len() as u32);
    out.copy_from(v.as_slice());
    out
}
/// Copies a Rust `Vec<f32>` into a freshly allocated JS `Float32Array`.
pub fn vec_f32_to_array_f32(v: Vec<f32>) -> js_sys::Float32Array {
    let out = js_sys::Float32Array::new_with_length(v.len() as u32);
    out.copy_from(v.as_slice());
    out
}
/// Copies a JS `Uint8Array` into an owned Rust `Vec<u8>`.
pub fn array8_to_vec8(buf: js_sys::Uint8Array) -> Vec<u8> {
    let mut out = vec![0u8; buf.length() as usize];
    buf.copy_to(out.as_mut_slice());
    out
}
/// Copies a Rust `Vec<i64>` into a freshly allocated JS `BigInt64Array`.
pub fn vec64_to_array64(v: Vec<i64>) -> js_sys::BigInt64Array {
    let out = js_sys::BigInt64Array::new_with_length(v.len() as u32);
    out.copy_from(v.as_slice());
    out
}
/// Moves a vector of `JsValue`s into a JS `Array`, preserving order.
pub fn vec_to_array(v: Vec<wasm_bindgen::JsValue>) -> js_sys::Array {
    let out = js_sys::Array::new_with_length(v.len() as u32);
    for (idx, item) in v.into_iter().enumerate() {
        out.set(idx as u32, item);
    }
    out
}
/// Intermediate decoding of a JS `[Uint32Array shape, Float32Array data]` pair.
struct A {
    shape: Vec<u32>,
    data: Vec<f32>,
}
impl TryFrom<wasm_bindgen::JsValue> for A {
    type Error = sbv2_core::error::Error;
    /// Reads the first two elements of the JS array as shape and data;
    /// any further elements are ignored, and a missing element leaves the
    /// corresponding field empty.
    fn try_from(value: wasm_bindgen::JsValue) -> Result<Self, Self::Error> {
        let pair: js_sys::Array = value.into();
        let mut shape = Vec::new();
        let mut data = Vec::new();
        for (idx, item) in pair.iter().enumerate() {
            if idx == 0 {
                let typed: js_sys::Uint32Array = item.into();
                shape = vec![0; typed.length() as usize];
                typed.copy_to(&mut shape);
            } else if idx == 1 {
                let typed: js_sys::Float32Array = item.into();
                data = vec![0.0; typed.length() as usize];
                typed.copy_to(&mut data);
            }
        }
        Ok(A { shape, data })
    }
}
/// Converts a JS `[shape, data]` pair into an `ndarray::Array2<f32>`.
///
/// Errors if the shape does not have exactly two dimensions, or if the data
/// length does not match the product of the shape.
pub fn array_to_array2_f32(
    a: wasm_bindgen::JsValue,
) -> sbv2_core::error::Result<ndarray::Array2<f32>> {
    let a = A::try_from(a)?;
    if a.shape.len() != 2 {
        return Err(sbv2_core::error::Error::OtherError(
            "Length mismatch".to_string(),
        ));
    }
    let shape = [a.shape[0] as usize, a.shape[1] as usize];
    // `a.data` is already an owned Vec<f32>; move it instead of cloning
    // (the previous `.to_vec()` duplicated the whole buffer).
    let arr = ndarray::Array2::from_shape_vec(shape, a.data)
        .map_err(|e| sbv2_core::error::Error::OtherError(e.to_string()))?;
    Ok(arr)
}
/// Converts a JS `[shape, data]` pair into an `ndarray::Array3<f32>`.
///
/// Errors if the shape does not have exactly three dimensions, or if the
/// data length does not match the product of the shape.
pub fn array_to_array3_f32(
    a: wasm_bindgen::JsValue,
) -> sbv2_core::error::Result<ndarray::Array3<f32>> {
    let a = A::try_from(a)?;
    if a.shape.len() != 3 {
        return Err(sbv2_core::error::Error::OtherError(
            "Length mismatch".to_string(),
        ));
    }
    let shape = [
        a.shape[0] as usize,
        a.shape[1] as usize,
        a.shape[2] as usize,
    ];
    // `a.data` is already an owned Vec<f32>; move it instead of cloning
    // (the previous `.to_vec()` duplicated the whole buffer).
    let arr = ndarray::Array3::from_shape_vec(shape, a.data)
        .map_err(|e| sbv2_core::error::Error::OtherError(e.to_string()))?;
    Ok(arr)
}
/// Encodes an `Array2<f32>` as a JS `[Uint32Array shape, Float32Array data]` pair.
pub fn array2_f32_to_array(a: ndarray::Array2<f32>) -> js_sys::Array {
    let dims: Vec<wasm_bindgen::JsValue> =
        a.shape().iter().map(|d| (*d as u32).into()).collect();
    let data = js_sys::Float32Array::new_with_length(a.len() as u32);
    data.copy_from(&a.into_flat().to_vec());
    vec_to_array(vec![vec_to_array(dims).into(), data.into()])
}

123
crates/sbv2_wasm/src/lib.rs Normal file
View File

@@ -0,0 +1,123 @@
use once_cell::sync::Lazy;
use sbv2_core::*;
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
mod array_helper;
// Lazily-initialized global JTalk instance shared by all synthesis calls.
// `unwrap`: failing to initialize JTalk is unrecoverable for this module.
static JTALK: Lazy<jtalk::JTalk> = Lazy::new(|| jtalk::JTalk::new().unwrap());
/// Opaque wasm-bindgen handle around the sbv2_core tokenizer so JS can hold it.
#[wasm_bindgen]
pub struct TokenizerWrap {
    tokenizer: tokenizer::Tokenizer,
}
/// Builds a [`TokenizerWrap`] from tokenizer JSON passed in as a JS string.
///
/// Returns a `JsError` when the JS string is not valid UTF-8 or the JSON
/// cannot be parsed by the tokenizer.
#[wasm_bindgen]
pub fn load_tokenizer(s: js_sys::JsString) -> Result<TokenizerWrap, JsError> {
    match s.as_string() {
        Some(json) => {
            let tokenizer = tokenizer::Tokenizer::from_bytes(json.as_bytes())
                .map_err(|e| JsError::new(&e.to_string()))?;
            Ok(TokenizerWrap { tokenizer })
        }
        None => Err(JsError::new("invalid utf8")),
    }
}
/// Opaque wasm-bindgen handle around a model's style-vector matrix.
#[wasm_bindgen]
pub struct StyleVectorWrap {
    style_vector: ndarray::Array2<f32>,
}
/// Parses an `.sbv2` file and returns a JS pair
/// `[StyleVectorWrap, Uint8Array]`, where the second element is the raw
/// VITS2 onnx model bytes.
#[wasm_bindgen]
pub fn load_sbv2file(buf: js_sys::Uint8Array) -> Result<js_sys::Array, JsError> {
    let bytes = array_helper::array8_to_vec8(buf);
    let (style_vectors, vits2) = sbv2file::parse_sbv2file(bytes)?;
    let style = StyleVectorWrap {
        style_vector: style::load_style(style_vectors)?,
    };
    let model_bytes = array_helper::vec8_to_array8(vits2);
    Ok(array_helper::vec_to_array(vec![
        style.into(),
        model_bytes.into(),
    ]))
}
/// Runs the full text-to-speech pipeline, delegating both neural-network
/// passes to JS callbacks (so inference runs in onnxruntime-web).
///
/// * `bert_predict_fn` — async JS fn `(token_ids, attention_mask) -> [dims, data]`.
/// * `synthesize_fn` — async JS fn taking the bert features, phones, tones,
///   language ids, style vector, sdp_ratio and length_scale, returning
///   `[dims, data]` for the audio tensor.
///
/// Returns the encoded audio as a JS `Uint8Array`, or a `JsError` when text
/// processing, either callback, or array conversion fails.
#[allow(clippy::too_many_arguments)]
#[wasm_bindgen]
pub async fn synthesize(
    text: &str,
    tokenizer: &TokenizerWrap,
    bert_predict_fn: js_sys::Function,
    synthesize_fn: js_sys::Function,
    sdp_ratio: f32,
    length_scale: f32,
    style_id: i32,
    style_weight: f32,
    style_vectors: &StyleVectorWrap,
) -> Result<js_sys::Uint8Array, JsError> {
    // Adapter: packs the ndarray inputs into JS values, invokes the JS
    // synthesis callback, awaits its promise, and decodes the result back
    // into an Array3<f32>.
    let synthesize_wrap = |bert_ori: ndarray::Array2<f32>,
                           x_tst: ndarray::Array1<i64>,
                           tones: ndarray::Array1<i64>,
                           lang_ids: ndarray::Array1<i64>,
                           style_vector: ndarray::Array1<f32>,
                           sdp_ratio: f32,
                           length_scale: f32| async move {
        // Positional argument array matching the JS callback's signature.
        let arr = array_helper::vec_to_array(vec![
            array_helper::array2_f32_to_array(bert_ori).into(),
            array_helper::vec64_to_array64(x_tst.to_vec()).into(),
            array_helper::vec64_to_array64(tones.to_vec()).into(),
            array_helper::vec64_to_array64(lang_ids.to_vec()).into(),
            array_helper::vec_f32_to_array_f32(style_vector.to_vec()).into(),
            sdp_ratio.into(),
            length_scale.into(),
        ]);
        // Call with a fresh empty object as `this`.
        let res = synthesize_fn
            .apply(&js_sys::Object::new().into(), &arr)
            .map_err(|e| {
                error::Error::OtherError(e.as_string().unwrap_or("unknown".to_string()))
            })?;
        // The callback returns a promise; await it on the Rust side.
        let res = JsFuture::from(Into::<js_sys::Promise>::into(res))
            .await
            .map_err(|e| {
                sbv2_core::error::Error::OtherError(e.as_string().unwrap_or("unknown".to_string()))
            })?;
        array_helper::array_to_array3_f32(res)
    };
    // Text front-end: phonemize/tokenize, calling back into JS for the BERT
    // feature pass.
    let (bert_ori, phones, tones, lang_ids) = tts_util::parse_text(
        text,
        &JTALK,
        &tokenizer.tokenizer,
        |token_ids: Vec<i64>, attention_masks: Vec<i64>| {
            Box::pin(async move {
                let arr = array_helper::vec_to_array(vec![
                    array_helper::vec64_to_array64(token_ids).into(),
                    array_helper::vec64_to_array64(attention_masks).into(),
                ]);
                let res = bert_predict_fn
                    .apply(&js_sys::Object::new().into(), &arr)
                    .map_err(|e| {
                        error::Error::OtherError(e.as_string().unwrap_or("unknown".to_string()))
                    })?;
                let res = JsFuture::from(Into::<js_sys::Promise>::into(res))
                    .await
                    .map_err(|e| {
                        sbv2_core::error::Error::OtherError(
                            e.as_string().unwrap_or("unknown".to_string()),
                        )
                    })?;
                array_helper::array_to_array2_f32(res)
            })
        },
    )
    .await?;
    // Back-end: run the VITS2 model via the JS callback with the blended
    // style vector, then encode the audio tensor to bytes for JS.
    let audio = synthesize_wrap(
        bert_ori.to_owned(),
        phones,
        tones,
        lang_ids,
        style::get_style_vector(&style_vectors.style_vector, style_id, style_weight)?,
        sdp_ratio,
        length_scale,
    )
    .await?;
    Ok(array_helper::vec8_to_array8(tts_util::array_to_vec(audio)?))
}

View File

@@ -0,0 +1,15 @@
{
"compilerOptions": {
"target": "ESNext",
"module": "ESNext",
"rootDir": "./src-js",
"outDir": "./dist",
"moduleResolution": "node",
"esModuleInterop": true,
"forceConsistentCasingInFileNames": true,
"strict": true,
"skipLibCheck": true,
"declaration": true,
"emitDeclarationOnly": true
}
}