breaking: support for length_scale, sdp_ratio, and a /models endpoint

Googlefan
2024-09-11 04:42:11 +00:00
parent 83b69083ca
commit 441e35b9a6
12 changed files with 243 additions and 49 deletions

View File

@@ -1,4 +1,6 @@
BERT_MODEL_PATH=models/deberta.onnx
MODEL_PATH=models/model_tsukuyomi.onnx
MODELS_PATH=models
STYLE_VECTORS_PATH=models/style_vectors.json
TOKENIZER_PATH=models/tokenizer.json
TOKENIZER_PATH=models/tokenizer.json
ADDR=localhost:3000

Cargo.lock (generated, 99 changed lines)
View File

@@ -26,6 +26,55 @@ dependencies = [
"memchr",
]
[[package]]
name = "anstream"
version = "0.6.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526"
dependencies = [
"anstyle",
"anstyle-parse",
"anstyle-query",
"anstyle-wincon",
"colorchoice",
"is_terminal_polyfill",
"utf8parse",
]
[[package]]
name = "anstyle"
version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1"
[[package]]
name = "anstyle-parse"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb"
dependencies = [
"utf8parse",
]
[[package]]
name = "anstyle-query"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a"
dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "anstyle-wincon"
version = "3.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8"
dependencies = [
"anstyle",
"windows-sys 0.52.0",
]
[[package]]
name = "anyhow"
version = "1.0.87"
@@ -188,6 +237,12 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "colorchoice"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0"
[[package]]
name = "console"
version = "0.15.8"
@@ -457,6 +512,29 @@ dependencies = [
"encoding_rs",
]
[[package]]
name = "env_filter"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab"
dependencies = [
"log",
"regex",
]
[[package]]
name = "env_logger"
version = "0.11.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d"
dependencies = [
"anstream",
"anstyle",
"env_filter",
"humantime",
"log",
]
[[package]]
name = "errno"
version = "0.3.9"
@@ -653,6 +731,12 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "humantime"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
name = "hyper"
version = "1.4.1"
@@ -725,6 +809,12 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "itertools"
version = "0.11.0"
@@ -1642,6 +1732,8 @@ dependencies = [
"anyhow",
"axum",
"dotenvy",
"env_logger",
"log",
"sbv2_core",
"serde",
"tokio",
@@ -1652,6 +1744,7 @@ name = "sbv2_core"
version = "0.1.0"
dependencies = [
"anyhow",
"dotenvy",
"hound",
"jpreprocess",
"ndarray",
@@ -2101,6 +2194,12 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "utf8parse"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "version_check"
version = "0.9.5"

View File

@@ -4,3 +4,4 @@ members = ["sbv2_api", "sbv2_core"]
[workspace.dependencies]
anyhow = "1.0.86"
dotenvy = "0.15.7"

View File

@@ -90,8 +90,18 @@ model = get_net_g(
)
def forward(*args):
return model.infer(*args)
def forward(x, x_len, sid, tone, lang, bert, style, length_scale, sdp_ratio):
return model.infer(
x,
x_len,
sid,
tone,
lang,
bert,
style,
sdp_ratio=sdp_ratio,
length_scale=length_scale,
)
model.forward = forward
@@ -106,6 +116,8 @@ torch.onnx.export(
lang_ids,
bert,
style_vec_tensor,
torch.tensor(1.0),
torch.tensor(0.0),
),
f"../models/model_{out_name}.onnx",
verbose=True,
@@ -124,6 +136,8 @@ torch.onnx.export(
"language",
"bert",
"style_vec",
"length_scale",
"sdp_ratio",
],
output_names=["output"],
)
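With the two extra inputs wired into the export, the resulting ONNX graph now takes scalar length_scale and sdp_ratio tensors at inference time. A minimal sketch for checking this with onnxruntime; the file name below is an assumption based on the export path used by the script (f"../models/model_{out_name}.onnx"):

import onnxruntime as ort

# Assumed path; substitute whatever out_name the export script produced.
sess = ort.InferenceSession("models/model_tsukuyomi.onnx")

# The two new scalar inputs "length_scale" and "sdp_ratio" should now be
# listed alongside the existing bert, style_vec and language inputs.
for inp in sess.get_inputs():
    print(inp.name, inp.shape, inp.type)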

View File

@@ -1 +1 @@
docker run -it --rm -p 3000:3000 --name sbv2 -v ./models:/work/models --env-file .env.sample sbv2
docker run -it --rm -p 3000:3000 --name sbv2 -v ./models:/work/models --env-file .env sbv2

View File

@@ -6,7 +6,9 @@ edition = "2021"
[dependencies]
anyhow.workspace = true
axum = "0.7.5"
dotenvy = "0.15.7"
dotenvy.workspace = true
env_logger = "0.11.5"
log = "0.4.22"
sbv2_core = { version = "0.1.0", path = "../sbv2_core" }
serde = { version = "1.0.210", features = ["derive"] }
tokio = { version = "1.40.0", features = ["full"] }
@@ -14,4 +16,4 @@ tokio = { version = "1.40.0", features = ["full"] }
[features]
cuda = ["sbv2_core/cuda"]
cuda_tf32 = ["sbv2_core/cuda_tf32"]
dynamic = ["sbv2_core/dynamic"]
dynamic = ["sbv2_core/dynamic"]

View File

@@ -15,33 +15,39 @@ use tokio::sync::Mutex;
mod error;
use crate::error::AppResult;
async fn models(State(state): State<AppState>) -> AppResult<impl IntoResponse> {
Ok(Json(state.tts_model.lock().await.models()))
}
fn sdp_default() -> f32 {
0.0
}
fn length_default() -> f32 {
1.0
}
#[derive(Deserialize)]
struct SynthesizeRequest {
text: String,
ident: String,
#[serde(default = "sdp_default")]
sdp_ratio: f32,
#[serde(default = "length_default")]
length_scale: f32,
}
async fn synthesize(
State(state): State<Arc<AppState>>,
Json(SynthesizeRequest { text, ident }): Json<SynthesizeRequest>,
State(state): State<AppState>,
Json(SynthesizeRequest {
text,
ident,
sdp_ratio,
length_scale,
}): Json<SynthesizeRequest>,
) -> AppResult<impl IntoResponse> {
log::debug!("processing request: text={text}, ident={ident}, sdp_ratio={sdp_ratio}, length_scale={length_scale}");
let buffer = {
let mut tts_model = state.tts_model.lock().await;
let tts_model = if let Some(tts_model) = &*tts_model {
tts_model
} else {
let mut tts_holder = TTSModelHolder::new(
&fs::read(env::var("BERT_MODEL_PATH")?).await?,
&fs::read(env::var("TOKENIZER_PATH")?).await?,
)?;
tts_holder.load(
"tsukuyomi",
fs::read(env::var("STYLE_VECTORS_PATH")?).await?,
fs::read(env::var("MODEL_PATH")?).await?,
)?;
*tts_model = Some(tts_holder);
tts_model.as_ref().unwrap()
};
let tts_model = state.tts_model.lock().await;
let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(&text)?;
let style_vector = tts_model.get_style_vector(&ident, 0, 1.0)?;
tts_model.synthesize(
@@ -51,26 +57,78 @@ async fn synthesize(
tones,
lang_ids,
style_vector,
sdp_ratio,
length_scale,
)?
};
Ok(([(CONTENT_TYPE, "audio/wav")], buffer))
}
#[derive(Clone)]
struct AppState {
tts_model: Arc<Mutex<Option<TTSModelHolder>>>,
tts_model: Arc<Mutex<TTSModelHolder>>,
}
impl AppState {
pub async fn new() -> anyhow::Result<Self> {
let mut tts_model = TTSModelHolder::new(
&fs::read(env::var("BERT_MODEL_PATH")?).await?,
&fs::read(env::var("TOKENIZER_PATH")?).await?,
)?;
let models = env::var("MODELS_PATH").unwrap_or("models".to_string());
let mut f = fs::read_dir(&models).await?;
let mut entries = vec![];
while let Ok(Some(e)) = f.next_entry().await {
let name = e.file_name().to_string_lossy().to_string();
if name.ends_with(".onnx") && name.starts_with("model_") {
let name_len = name.len();
let name = name.chars();
entries.push(
name.collect::<Vec<_>>()[6..name_len - 5]
.iter()
.collect::<String>(),
);
}
}
for entry in entries {
log::info!("Try loading: {entry}");
let style_vectors_bytes =
match fs::read(format!("{models}/style_vectors_{entry}.json")).await {
Ok(b) => b,
Err(e) => {
log::warn!("Error loading style_vectors_bytes from file {entry}: {e}");
continue;
}
};
let vits2_bytes = match fs::read(format!("{models}/model_{entry}.onnx")).await {
Ok(b) => b,
Err(e) => {
log::warn!("Error loading vits2_bytes from file {entry}: {e}");
continue;
}
};
if let Err(e) = tts_model.load(&entry, style_vectors_bytes, vits2_bytes) {
log::warn!("Error loading {entry}: {e}");
};
}
Ok(Self {
tts_model: Arc::new(Mutex::new(tts_model)),
})
}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
dotenvy::dotenv().ok();
env_logger::init();
let app = Router::new()
.route("/", get(|| async { "Hello, World!" }))
.route("/synthesize", post(synthesize))
.with_state(Arc::new(AppState {
tts_model: Arc::new(Mutex::new(None)),
}));
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
.route("/models", get(models))
.with_state(AppState::new().await?);
let addr = env::var("ADDR").unwrap_or("0.0.0.0:3000".to_string());
let listener = tokio::net::TcpListener::bind(&addr).await?;
log::info!("Listening on {addr}");
axum::serve(listener, app).await?;
Ok(())
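AppState::new now discovers models instead of reading a single MODEL_PATH: every model_<ident>.onnx under MODELS_PATH (default "models") is paired with style_vectors_<ident>.json, and the <ident> part is what the /models listing and the ident field of /synthesize refer to. A small sketch of that naming convention; the directory contents are illustrative:

import pathlib

models_dir = pathlib.Path("models")  # MODELS_PATH, defaults to "models"
for onnx in sorted(models_dir.glob("model_*.onnx")):
    ident = onnx.name[len("model_"):-len(".onnx")]  # model_tsukuyomi.onnx -> "tsukuyomi"
    style = models_dir / f"style_vectors_{ident}.json"
    # The server logs a warning and skips an entry if the paired
    # style vectors (or the model file itself) cannot be read.
    print(ident, style.exists())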

View File

@@ -5,6 +5,7 @@ edition = "2021"
[dependencies]
anyhow.workspace = true
dotenvy.workspace = true
hound = "3.5.1"
jpreprocess = { version = "0.10.0", features = ["naist-jdic"] }
ndarray = "0.16.1"

View File

@@ -1,41 +1,47 @@
use std::{fs, time::Instant};
use sbv2_core::{error, tts};
use sbv2_core::tts;
use std::env;
fn main() -> error::Result<()> {
fn main() -> anyhow::Result<()> {
dotenvy::dotenv().ok();
let text = "眠たい";
let mut tts_model = tts::TTSModelHolder::new(
fs::read("models/debert.onnx")?,
fs::read("models/model_opt.onnx")?,
let ident = "aaa";
let mut tts_holder = tts::TTSModelHolder::new(
&fs::read(env::var("BERT_MODEL_PATH")?)?,
&fs::read(env::var("TOKENIZER_PATH")?)?,
)?;
tts_model.load(
"tsukuyomi",
fs::read("models/style_vectors.json")?,
fs::read("models/tokenizer.json")?,
tts_holder.load(
ident,
fs::read(env::var("STYLE_VECTORS_PATH")?)?,
fs::read(env::var("MODEL_PATH")?)?,
)?;
let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(text)?;
let (bert_ori, phones, tones, lang_ids) = tts_holder.parse_text(text)?;
let style_vector = tts_model.get_style_vector("tsukuyomi", 0, 1.0)?;
let data = tts_model.synthesize(
"tsukuyomi",
let style_vector = tts_holder.get_style_vector(ident, 0, 1.0)?;
let data = tts_holder.synthesize(
ident,
bert_ori.to_owned(),
phones.clone(),
tones.clone(),
lang_ids.clone(),
style_vector.clone(),
0.0,
0.5,
)?;
std::fs::write("output.wav", data)?;
let now = Instant::now();
for _ in 0..10 {
tts_model.synthesize(
"tsukuyomi",
tts_holder.synthesize(
ident,
bert_ori.to_owned(),
phones.clone(),
tones.clone(),
lang_ids.clone(),
style_vector.clone(),
0.0,
1.0,
)?;
}
println!("Time taken: {}", now.elapsed().as_millis());

View File

@@ -26,7 +26,7 @@ pub fn load_model<P: AsRef<[u8]>>(model_file: P) -> Result<Session> {
.with_inter_threads(num_cpus::get_physical())?
.commit_from_memory(model_file.as_ref())?)
}
#[allow(clippy::too_many_arguments)]
pub fn synthesize(
session: &Session,
bert_ori: Array2<f32>,
@@ -34,6 +34,8 @@ pub fn synthesize(
tones: Array1<i64>,
lang_ids: Array1<i64>,
style_vector: Array1<f32>,
sdp_ratio: f32,
length_scale: f32,
) -> Result<Vec<u8>> {
let bert = bert_ori.insert_axis(Axis(0));
let x_tst_lengths: Array1<i64> = array![x_tst.shape()[0] as i64];
@@ -49,6 +51,8 @@ pub fn synthesize(
"language" => lang_ids,
"bert" => bert,
"style_vec" => style_vector,
"sdp_ratio" => array![sdp_ratio],
"length_scale" => array![length_scale],
}?)?;
let audio_array = outputs

View File

@@ -48,6 +48,9 @@ impl TTSModelHolder {
tokenizer,
})
}
pub fn models(&self) -> Vec<String> {
self.models.iter().map(|m| m.ident.to_string()).collect()
}
pub fn load<I: Into<TTSIdent>, P: AsRef<[u8]>>(
&mut self,
ident: I,
@@ -156,7 +159,7 @@ impl TTSModelHolder {
) -> Result<Array1<f32>> {
style::get_style_vector(&self.find_model(ident)?.style_vectors, style_id, weight)
}
#[allow(clippy::too_many_arguments)]
pub fn synthesize<I: Into<TTSIdent>>(
&self,
ident: I,
@@ -165,6 +168,8 @@ impl TTSModelHolder {
tones: Array1<i64>,
lang_ids: Array1<i64>,
style_vector: Array1<f32>,
sdp_ratio: f32,
length_scale: f32,
) -> Result<Vec<u8>> {
let buffer = model::synthesize(
&self.find_model(ident)?.vits2,
@@ -173,6 +178,8 @@ impl TTSModelHolder {
tones,
lang_ids,
style_vector,
sdp_ratio,
length_scale,
)?;
Ok(buffer)
}

View File

@@ -1,7 +1,7 @@
import requests
res = requests.post(
"http://localhost:3000/synthesize",
"http://localhost:3001/synthesize",
json={"text": "おはようございます", "ident": "tsukuyomi"},
)
with open("output.wav", "wb") as f:
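With this commit the request body can also carry the two new optional fields (the server defaults them to sdp_ratio=0.0 and length_scale=1.0), and the available identifiers can be queried via GET /models. A hedged sketch of such a client; the base URL depends on ADDR, and the parameter semantics follow upstream Style-Bert-VITS2, where length_scale stretches phoneme durations (larger is slower) and sdp_ratio blends in the stochastic duration predictor:

import requests

base = "http://localhost:3000"  # adjust to wherever ADDR points

# List the identifiers of all models discovered under MODELS_PATH.
print(requests.get(f"{base}/models").json())

res = requests.post(
    f"{base}/synthesize",
    json={
        "text": "おはようございます",
        "ident": "tsukuyomi",
        "sdp_ratio": 0.0,      # optional, server default 0.0
        "length_scale": 1.0,   # optional, server default 1.0
    },
)
with open("output.wav", "wb") as f:
    f.write(res.content)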