mirror of
https://github.com/neodyland/sbv2-api.git
synced 2025-12-22 23:49:58 +00:00
breaking: add support for length_scale and sdp_ratio, and add the /models endpoint
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
BERT_MODEL_PATH=models/deberta.onnx
|
||||
MODEL_PATH=models/model_tsukuyomi.onnx
|
||||
MODELS_PATH=models
|
||||
STYLE_VECTORS_PATH=models/style_vectors.json
|
||||
TOKENIZER_PATH=models/tokenizer.json
|
||||
TOKENIZER_PATH=models/tokenizer.json
|
||||
ADDR=localhost:3000
|
||||
99
Cargo.lock
generated
99
Cargo.lock
generated
@@ -26,6 +26,55 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstream"
|
||||
version = "0.6.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"anstyle-parse",
|
||||
"anstyle-query",
|
||||
"anstyle-wincon",
|
||||
"colorchoice",
|
||||
"is_terminal_polyfill",
|
||||
"utf8parse",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstyle"
|
||||
version = "1.0.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1"
|
||||
|
||||
[[package]]
|
||||
name = "anstyle-parse"
|
||||
version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb"
|
||||
dependencies = [
|
||||
"utf8parse",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstyle-query"
|
||||
version = "1.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a"
|
||||
dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstyle-wincon"
|
||||
version = "3.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.87"
|
||||
@@ -188,6 +237,12 @@ version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "colorchoice"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0"
|
||||
|
||||
[[package]]
|
||||
name = "console"
|
||||
version = "0.15.8"
|
||||
@@ -457,6 +512,29 @@ dependencies = [
|
||||
"encoding_rs",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_filter"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab"
|
||||
dependencies = [
|
||||
"log",
|
||||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.11.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
"env_filter",
|
||||
"humantime",
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.9"
|
||||
@@ -653,6 +731,12 @@ version = "1.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
|
||||
|
||||
[[package]]
|
||||
name = "humantime"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "1.4.1"
|
||||
@@ -725,6 +809,12 @@ dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "is_terminal_polyfill"
|
||||
version = "1.70.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.11.0"
|
||||
@@ -1642,6 +1732,8 @@ dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
"dotenvy",
|
||||
"env_logger",
|
||||
"log",
|
||||
"sbv2_core",
|
||||
"serde",
|
||||
"tokio",
|
||||
@@ -1652,6 +1744,7 @@ name = "sbv2_core"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"dotenvy",
|
||||
"hound",
|
||||
"jpreprocess",
|
||||
"ndarray",
|
||||
@@ -2101,6 +2194,12 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "utf8parse"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.5"
|
||||
|
||||
@@ -4,3 +4,4 @@ members = ["sbv2_api", "sbv2_core"]
|
||||
|
||||
[workspace.dependencies]
|
||||
anyhow = "1.0.86"
|
||||
dotenvy = "0.15.7"
|
||||
@@ -90,8 +90,18 @@ model = get_net_g(
|
||||
)
|
||||
|
||||
|
||||
def forward(*args):
|
||||
return model.infer(*args)
|
||||
def forward(x, x_len, sid, tone, lang, bert, style, length_scale, sdp_ratio):
    """Adapter installed as ``model.forward`` so the exporter traces ``model.infer``.

    The first seven arguments are handed to ``model.infer`` positionally and
    unchanged; ``length_scale`` and ``sdp_ratio`` are passed by keyword so they
    become explicit inputs of the exported graph instead of baked-in defaults.
    """
    positional_inputs = (x, x_len, sid, tone, lang, bert, style)
    tunables = {"sdp_ratio": sdp_ratio, "length_scale": length_scale}
    return model.infer(*positional_inputs, **tunables)
|
||||
|
||||
|
||||
model.forward = forward
|
||||
@@ -106,6 +116,8 @@ torch.onnx.export(
|
||||
lang_ids,
|
||||
bert,
|
||||
style_vec_tensor,
|
||||
torch.tensor(1.0),
|
||||
torch.tensor(0.0),
|
||||
),
|
||||
f"../models/model_{out_name}.onnx",
|
||||
verbose=True,
|
||||
@@ -124,6 +136,8 @@ torch.onnx.export(
|
||||
"language",
|
||||
"bert",
|
||||
"style_vec",
|
||||
"length_scale",
|
||||
"sdp_ratio",
|
||||
],
|
||||
output_names=["output"],
|
||||
)
|
||||
|
||||
@@ -1 +1 @@
|
||||
docker run -it --rm -p 3000:3000 --name sbv2 -v ./models:/work/models --env-file .env.sample sbv2
|
||||
docker run -it --rm -p 3000:3000 --name sbv2 -v ./models:/work/models --env-file .env sbv2
|
||||
@@ -6,7 +6,9 @@ edition = "2021"
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
axum = "0.7.5"
|
||||
dotenvy = "0.15.7"
|
||||
dotenvy.workspace = true
|
||||
env_logger = "0.11.5"
|
||||
log = "0.4.22"
|
||||
sbv2_core = { version = "0.1.0", path = "../sbv2_core" }
|
||||
serde = { version = "1.0.210", features = ["derive"] }
|
||||
tokio = { version = "1.40.0", features = ["full"] }
|
||||
@@ -14,4 +16,4 @@ tokio = { version = "1.40.0", features = ["full"] }
|
||||
[features]
|
||||
cuda = ["sbv2_core/cuda"]
|
||||
cuda_tf32 = ["sbv2_core/cuda_tf32"]
|
||||
dynamic = ["sbv2_core/dynamic"]
|
||||
dynamic = ["sbv2_core/dynamic"]
|
||||
|
||||
@@ -15,33 +15,39 @@ use tokio::sync::Mutex;
|
||||
mod error;
|
||||
use crate::error::AppResult;
|
||||
|
||||
async fn models(State(state): State<AppState>) -> AppResult<impl IntoResponse> {
|
||||
Ok(Json(state.tts_model.lock().await.models()))
|
||||
}
|
||||
|
||||
/// Serde fallback for `SynthesizeRequest::sdp_ratio` when the request body omits it.
fn sdp_default() -> f32 {
    0.0_f32
}
|
||||
|
||||
/// Serde fallback for `SynthesizeRequest::length_scale` when the request body omits it.
fn length_default() -> f32 {
    1.0_f32
}
|
||||
/// JSON body accepted by the `synthesize` handler.
///
/// `sdp_ratio` and `length_scale` may be omitted from the payload; serde then
/// fills them in by calling `sdp_default` and `length_default` respectively.
#[derive(Deserialize)]
|
||||
struct SynthesizeRequest {
|
||||
/// Input text; the handler passes it to `parse_text`.
text: String,
|
||||
/// Identifier of the model to use; the handler passes it to `get_style_vector`.
ident: String,
|
||||
#[serde(default = "sdp_default")]
|
||||
/// Forwarded unchanged to `TTSModelHolder::synthesize` as `sdp_ratio`.
sdp_ratio: f32,
|
||||
#[serde(default = "length_default")]
|
||||
/// Forwarded unchanged to `TTSModelHolder::synthesize` as `length_scale`.
length_scale: f32,
|
||||
}
|
||||
|
||||
async fn synthesize(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(SynthesizeRequest { text, ident }): Json<SynthesizeRequest>,
|
||||
State(state): State<AppState>,
|
||||
Json(SynthesizeRequest {
|
||||
text,
|
||||
ident,
|
||||
sdp_ratio,
|
||||
length_scale,
|
||||
}): Json<SynthesizeRequest>,
|
||||
) -> AppResult<impl IntoResponse> {
|
||||
log::debug!("processing request: text={text}, ident={ident}, sdp_ratio={sdp_ratio}, length_scale={length_scale}");
|
||||
let buffer = {
|
||||
let mut tts_model = state.tts_model.lock().await;
|
||||
let tts_model = if let Some(tts_model) = &*tts_model {
|
||||
tts_model
|
||||
} else {
|
||||
let mut tts_holder = TTSModelHolder::new(
|
||||
&fs::read(env::var("BERT_MODEL_PATH")?).await?,
|
||||
&fs::read(env::var("TOKENIZER_PATH")?).await?,
|
||||
)?;
|
||||
tts_holder.load(
|
||||
"tsukuyomi",
|
||||
fs::read(env::var("STYLE_VECTORS_PATH")?).await?,
|
||||
fs::read(env::var("MODEL_PATH")?).await?,
|
||||
)?;
|
||||
*tts_model = Some(tts_holder);
|
||||
tts_model.as_ref().unwrap()
|
||||
};
|
||||
let tts_model = state.tts_model.lock().await;
|
||||
let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(&text)?;
|
||||
let style_vector = tts_model.get_style_vector(&ident, 0, 1.0)?;
|
||||
tts_model.synthesize(
|
||||
@@ -51,26 +57,78 @@ async fn synthesize(
|
||||
tones,
|
||||
lang_ids,
|
||||
style_vector,
|
||||
sdp_ratio,
|
||||
length_scale,
|
||||
)?
|
||||
};
|
||||
Ok(([(CONTENT_TYPE, "audio/wav")], buffer))
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
tts_model: Arc<Mutex<Option<TTSModelHolder>>>,
|
||||
tts_model: Arc<Mutex<TTSModelHolder>>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub async fn new() -> anyhow::Result<Self> {
|
||||
let mut tts_model = TTSModelHolder::new(
|
||||
&fs::read(env::var("BERT_MODEL_PATH")?).await?,
|
||||
&fs::read(env::var("TOKENIZER_PATH")?).await?,
|
||||
)?;
|
||||
let models = env::var("MODELS_PATH").unwrap_or("models".to_string());
|
||||
let mut f = fs::read_dir(&models).await?;
|
||||
let mut entries = vec![];
|
||||
while let Ok(Some(e)) = f.next_entry().await {
|
||||
let name = e.file_name().to_string_lossy().to_string();
|
||||
if name.ends_with(".onnx") && name.starts_with("model_") {
|
||||
let name_len = name.len();
|
||||
let name = name.chars();
|
||||
entries.push(
|
||||
name.collect::<Vec<_>>()[6..name_len - 5]
|
||||
.iter()
|
||||
.collect::<String>(),
|
||||
);
|
||||
}
|
||||
}
|
||||
for entry in entries {
|
||||
log::info!("Try loading: {entry}");
|
||||
let style_vectors_bytes =
|
||||
match fs::read(format!("{models}/style_vectors_{entry}.json")).await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
log::warn!("Error loading style_vectors_bytes from file {entry}: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let vits2_bytes = match fs::read(format!("{models}/model_{entry}.onnx")).await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
log::warn!("Error loading vits2_bytes from file {entry}: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if let Err(e) = tts_model.load(&entry, style_vectors_bytes, vits2_bytes) {
|
||||
log::warn!("Error loading {entry}: {e}");
|
||||
};
|
||||
}
|
||||
Ok(Self {
|
||||
tts_model: Arc::new(Mutex::new(tts_model)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
dotenvy::dotenv().ok();
|
||||
env_logger::init();
|
||||
let app = Router::new()
|
||||
.route("/", get(|| async { "Hello, World!" }))
|
||||
.route("/synthesize", post(synthesize))
|
||||
.with_state(Arc::new(AppState {
|
||||
tts_model: Arc::new(Mutex::new(None)),
|
||||
}));
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
|
||||
.route("/models", get(models))
|
||||
.with_state(AppState::new().await?);
|
||||
let addr = env::var("ADDR").unwrap_or("0.0.0.0:3000".to_string());
|
||||
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
||||
log::info!("Listening on {addr}");
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -5,6 +5,7 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
dotenvy.workspace = true
|
||||
hound = "3.5.1"
|
||||
jpreprocess = { version = "0.10.0", features = ["naist-jdic"] }
|
||||
ndarray = "0.16.1"
|
||||
|
||||
@@ -1,41 +1,47 @@
|
||||
use std::{fs, time::Instant};
|
||||
|
||||
use sbv2_core::{error, tts};
|
||||
use sbv2_core::tts;
|
||||
use std::env;
|
||||
|
||||
fn main() -> error::Result<()> {
|
||||
fn main() -> anyhow::Result<()> {
|
||||
dotenvy::dotenv().ok();
|
||||
let text = "眠たい";
|
||||
|
||||
let mut tts_model = tts::TTSModelHolder::new(
|
||||
fs::read("models/debert.onnx")?,
|
||||
fs::read("models/model_opt.onnx")?,
|
||||
let ident = "aaa";
|
||||
let mut tts_holder = tts::TTSModelHolder::new(
|
||||
&fs::read(env::var("BERT_MODEL_PATH")?)?,
|
||||
&fs::read(env::var("TOKENIZER_PATH")?)?,
|
||||
)?;
|
||||
tts_model.load(
|
||||
"tsukuyomi",
|
||||
fs::read("models/style_vectors.json")?,
|
||||
fs::read("models/tokenizer.json")?,
|
||||
tts_holder.load(
|
||||
ident,
|
||||
fs::read(env::var("STYLE_VECTORS_PATH")?)?,
|
||||
fs::read(env::var("MODEL_PATH")?)?,
|
||||
)?;
|
||||
|
||||
let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(text)?;
|
||||
let (bert_ori, phones, tones, lang_ids) = tts_holder.parse_text(text)?;
|
||||
|
||||
let style_vector = tts_model.get_style_vector("tsukuyomi", 0, 1.0)?;
|
||||
let data = tts_model.synthesize(
|
||||
"tsukuyomi",
|
||||
let style_vector = tts_holder.get_style_vector(ident, 0, 1.0)?;
|
||||
let data = tts_holder.synthesize(
|
||||
ident,
|
||||
bert_ori.to_owned(),
|
||||
phones.clone(),
|
||||
tones.clone(),
|
||||
lang_ids.clone(),
|
||||
style_vector.clone(),
|
||||
0.0,
|
||||
0.5,
|
||||
)?;
|
||||
std::fs::write("output.wav", data)?;
|
||||
let now = Instant::now();
|
||||
for _ in 0..10 {
|
||||
tts_model.synthesize(
|
||||
"tsukuyomi",
|
||||
tts_holder.synthesize(
|
||||
ident,
|
||||
bert_ori.to_owned(),
|
||||
phones.clone(),
|
||||
tones.clone(),
|
||||
lang_ids.clone(),
|
||||
style_vector.clone(),
|
||||
0.0,
|
||||
1.0,
|
||||
)?;
|
||||
}
|
||||
println!("Time taken: {}", now.elapsed().as_millis());
|
||||
|
||||
@@ -26,7 +26,7 @@ pub fn load_model<P: AsRef<[u8]>>(model_file: P) -> Result<Session> {
|
||||
.with_inter_threads(num_cpus::get_physical())?
|
||||
.commit_from_memory(model_file.as_ref())?)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn synthesize(
|
||||
session: &Session,
|
||||
bert_ori: Array2<f32>,
|
||||
@@ -34,6 +34,8 @@ pub fn synthesize(
|
||||
tones: Array1<i64>,
|
||||
lang_ids: Array1<i64>,
|
||||
style_vector: Array1<f32>,
|
||||
sdp_ratio: f32,
|
||||
length_scale: f32,
|
||||
) -> Result<Vec<u8>> {
|
||||
let bert = bert_ori.insert_axis(Axis(0));
|
||||
let x_tst_lengths: Array1<i64> = array![x_tst.shape()[0] as i64];
|
||||
@@ -49,6 +51,8 @@ pub fn synthesize(
|
||||
"language" => lang_ids,
|
||||
"bert" => bert,
|
||||
"style_vec" => style_vector,
|
||||
"sdp_ratio" => array![sdp_ratio],
|
||||
"length_scale" => array![length_scale],
|
||||
}?)?;
|
||||
|
||||
let audio_array = outputs
|
||||
|
||||
@@ -48,6 +48,9 @@ impl TTSModelHolder {
|
||||
tokenizer,
|
||||
})
|
||||
}
|
||||
pub fn models(&self) -> Vec<String> {
|
||||
self.models.iter().map(|m| m.ident.to_string()).collect()
|
||||
}
|
||||
pub fn load<I: Into<TTSIdent>, P: AsRef<[u8]>>(
|
||||
&mut self,
|
||||
ident: I,
|
||||
@@ -156,7 +159,7 @@ impl TTSModelHolder {
|
||||
) -> Result<Array1<f32>> {
|
||||
style::get_style_vector(&self.find_model(ident)?.style_vectors, style_id, weight)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn synthesize<I: Into<TTSIdent>>(
|
||||
&self,
|
||||
ident: I,
|
||||
@@ -165,6 +168,8 @@ impl TTSModelHolder {
|
||||
tones: Array1<i64>,
|
||||
lang_ids: Array1<i64>,
|
||||
style_vector: Array1<f32>,
|
||||
sdp_ratio: f32,
|
||||
length_scale: f32,
|
||||
) -> Result<Vec<u8>> {
|
||||
let buffer = model::synthesize(
|
||||
&self.find_model(ident)?.vits2,
|
||||
@@ -173,6 +178,8 @@ impl TTSModelHolder {
|
||||
tones,
|
||||
lang_ids,
|
||||
style_vector,
|
||||
sdp_ratio,
|
||||
length_scale,
|
||||
)?;
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user