mirror of
https://github.com/neodyland/sbv2-api.git
synced 2025-12-22 23:49:58 +00:00
feat: trt partial support
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
BERT_MODEL_PATH=models/deberta.onnx
|
||||
MODEL_PATH=models/model_tsukuyomi.onnx
|
||||
MODEL_PATH=models/tsukuyomi.sbv2
|
||||
MODELS_PATH=models
|
||||
STYLE_VECTORS_PATH=models/style_vectors_tsukuyomi.json
|
||||
TOKENIZER_PATH=models/tokenizer.json
|
||||
ADDR=localhost:3000
|
||||
RUST_LOG=warn
|
||||
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -1756,6 +1756,7 @@ version = "0.1.1"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"dotenvy",
|
||||
"env_logger",
|
||||
"hound",
|
||||
"jpreprocess",
|
||||
"ndarray",
|
||||
|
||||
@@ -5,3 +5,4 @@ members = ["sbv2_api", "sbv2_core"]
|
||||
[workspace.dependencies]
|
||||
anyhow = "1.0.86"
|
||||
dotenvy = "0.15.7"
|
||||
env_logger = "0.11.5"
|
||||
|
||||
@@ -7,7 +7,7 @@ edition = "2021"
|
||||
anyhow.workspace = true
|
||||
axum = "0.7.5"
|
||||
dotenvy.workspace = true
|
||||
env_logger = "0.11.5"
|
||||
env_logger.workspace = true
|
||||
log = "0.4.22"
|
||||
sbv2_core = { version = "0.1.1", path = "../sbv2_core" }
|
||||
serde = { version = "1.0.210", features = ["derive"] }
|
||||
@@ -18,3 +18,4 @@ cuda = ["sbv2_core/cuda"]
|
||||
cuda_tf32 = ["sbv2_core/cuda_tf32"]
|
||||
dynamic = ["sbv2_core/dynamic"]
|
||||
directml = ["sbv2_core/directml"]
|
||||
tensorrt = ["sbv2_core/tensorrt"]
|
||||
@@ -135,7 +135,7 @@ impl AppState {
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
dotenvy::dotenv().ok();
|
||||
dotenvy::dotenv_override().ok();
|
||||
env_logger::init();
|
||||
let app = Router::new()
|
||||
.route("/", get(|| async { "Hello, World!" }))
|
||||
|
||||
@@ -10,6 +10,7 @@ repository = "https://github.com/tuna2134/sbv2-api"
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
dotenvy.workspace = true
|
||||
env_logger.workspace = true
|
||||
hound = "3.5.1"
|
||||
jpreprocess = { version = "0.10.0", features = ["naist-jdic"] }
|
||||
ndarray = "0.16.1"
|
||||
@@ -29,3 +30,4 @@ cuda = ["ort/cuda"]
|
||||
cuda_tf32 = []
|
||||
dynamic = ["ort/load-dynamic"]
|
||||
directml = ["ort/directml"]
|
||||
tensorrt = ["ort/tensorrt"]
|
||||
|
||||
@@ -4,18 +4,15 @@ use sbv2_core::tts;
|
||||
use std::env;
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
dotenvy::dotenv().ok();
|
||||
dotenvy::dotenv_override().ok();
|
||||
env_logger::init();
|
||||
let text = "眠たい";
|
||||
let ident = "aaa";
|
||||
let mut tts_holder = tts::TTSModelHolder::new(
|
||||
&fs::read(env::var("BERT_MODEL_PATH")?)?,
|
||||
&fs::read(env::var("TOKENIZER_PATH")?)?,
|
||||
)?;
|
||||
tts_holder.load(
|
||||
ident,
|
||||
fs::read(env::var("STYLE_VECTORS_PATH")?)?,
|
||||
fs::read(env::var("MODEL_PATH")?)?,
|
||||
)?;
|
||||
tts_holder.load_sbv2file(ident, fs::read(env::var("MODEL_PATH")?)?)?;
|
||||
|
||||
let (bert_ori, phones, tones, lang_ids) = tts_holder.parse_text(text)?;
|
||||
|
||||
@@ -32,6 +29,14 @@ fn main() -> anyhow::Result<()> {
|
||||
)?;
|
||||
std::fs::write("output.wav", data)?;
|
||||
let now = Instant::now();
|
||||
for _ in 0..10 {
|
||||
tts_holder.parse_text(text)?;
|
||||
}
|
||||
println!(
|
||||
"Time taken(parse_text): {}ms/it",
|
||||
now.elapsed().as_millis() / 10
|
||||
);
|
||||
let now = Instant::now();
|
||||
for _ in 0..10 {
|
||||
tts_holder.synthesize(
|
||||
ident,
|
||||
@@ -44,6 +49,9 @@ fn main() -> anyhow::Result<()> {
|
||||
1.0,
|
||||
)?;
|
||||
}
|
||||
println!("Time taken: {}", now.elapsed().as_millis());
|
||||
println!(
|
||||
"Time taken(synthesize): {}ms/it",
|
||||
now.elapsed().as_millis() / 10
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -4,11 +4,25 @@ use ndarray::{array, s, Array1, Array2, Axis};
|
||||
use ort::{GraphOptimizationLevel, Session};
|
||||
use std::io::Cursor;
|
||||
|
||||
#[allow(clippy::vec_init_then_push)]
|
||||
pub fn load_model<P: AsRef<[u8]>>(model_file: P) -> Result<Session> {
|
||||
#[allow(clippy::vec_init_then_push, unused_variables)]
|
||||
pub fn load_model<P: AsRef<[u8]>>(model_file: P, bert: bool) -> Result<Session> {
|
||||
let mut exp = Vec::new();
|
||||
#[cfg(feature = "tensorrt")]
|
||||
{
|
||||
if bert {
|
||||
exp.push(
|
||||
ort::TensorRTExecutionProvider::default()
|
||||
.with_fp16(true)
|
||||
.with_profile_min_shapes("input_ids:1x1,attention_mask:1x1")
|
||||
.with_profile_max_shapes("input_ids:1x100,attention_mask:1x100")
|
||||
.with_profile_opt_shapes("input_ids:1x25,attention_mask:1x25")
|
||||
.build(),
|
||||
);
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
#[allow(unused_mut)]
|
||||
let mut cuda = ort::CUDAExecutionProvider::default()
|
||||
.with_conv_algorithm_search(ort::CUDAExecutionProviderCuDNNConvAlgoSearch::Default);
|
||||
#[cfg(feature = "cuda_tf32")]
|
||||
|
||||
@@ -41,7 +41,7 @@ pub struct TTSModelHolder {
|
||||
|
||||
impl TTSModelHolder {
|
||||
pub fn new<P: AsRef<[u8]>>(bert_model_bytes: P, tokenizer_bytes: P) -> Result<Self> {
|
||||
let bert = model::load_model(bert_model_bytes)?;
|
||||
let bert = model::load_model(bert_model_bytes, true)?;
|
||||
let jtalk = jtalk::JTalk::new()?;
|
||||
let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?;
|
||||
Ok(TTSModelHolder {
|
||||
@@ -94,7 +94,7 @@ impl TTSModelHolder {
|
||||
let ident = ident.into();
|
||||
if self.find_model(ident.clone()).is_err() {
|
||||
self.models.push(TTSModel {
|
||||
vits2: model::load_model(vits2_bytes)?,
|
||||
vits2: model::load_model(vits2_bytes, false)?,
|
||||
style_vectors: style::load_style(style_vectors_bytes)?,
|
||||
ident,
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user