feat: trt partial support

Googlefan
2024-09-12 03:56:16 +00:00
parent 0f29d2f5fe
commit a38be530f9
9 changed files with 45 additions and 18 deletions

View File

@@ -1,6 +1,6 @@
BERT_MODEL_PATH=models/deberta.onnx
MODEL_PATH=models/model_tsukuyomi.onnx
MODEL_PATH=models/tsukuyomi.sbv2
MODELS_PATH=models
STYLE_VECTORS_PATH=models/style_vectors_tsukuyomi.json
TOKENIZER_PATH=models/tokenizer.json
ADDR=localhost:3000
ADDR=localhost:3000
RUST_LOG=warn
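
The server and the example read these values through dotenvy and std::env. A minimal sketch of how the updated .env is consumed, mirroring code changed later in this commit (the wrapper function is illustrative, not part of the diff):

use std::{env, fs};

fn load_assets() -> anyhow::Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
    // Let .env values override anything already set in the process environment.
    dotenvy::dotenv_override().ok();
    // env_logger picks up RUST_LOG=warn from the file loaded above.
    env_logger::init();
    let bert = fs::read(env::var("BERT_MODEL_PATH")?)?; // models/deberta.onnx
    let tokenizer = fs::read(env::var("TOKENIZER_PATH")?)?; // models/tokenizer.json
    let sbv2 = fs::read(env::var("MODEL_PATH")?)?; // models/tsukuyomi.sbv2
    Ok((bert, tokenizer, sbv2))
}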

Cargo.lock generated
View File

@@ -1756,6 +1756,7 @@ version = "0.1.1"
dependencies = [
"anyhow",
"dotenvy",
"env_logger",
"hound",
"jpreprocess",
"ndarray",

View File

@@ -5,3 +5,4 @@ members = ["sbv2_api", "sbv2_core"]
[workspace.dependencies]
anyhow = "1.0.86"
dotenvy = "0.15.7"
env_logger = "0.11.5"

View File

@@ -7,7 +7,7 @@ edition = "2021"
anyhow.workspace = true
axum = "0.7.5"
dotenvy.workspace = true
env_logger = "0.11.5"
env_logger.workspace = true
log = "0.4.22"
sbv2_core = { version = "0.1.1", path = "../sbv2_core" }
serde = { version = "1.0.210", features = ["derive"] }
@@ -17,4 +17,5 @@ tokio = { version = "1.40.0", features = ["full"] }
cuda = ["sbv2_core/cuda"]
cuda_tf32 = ["sbv2_core/cuda_tf32"]
dynamic = ["sbv2_core/dynamic"]
directml = ["sbv2_core/directml"]
directml = ["sbv2_core/directml"]
tensorrt = ["sbv2_core/tensorrt"]

View File

@@ -135,7 +135,7 @@ impl AppState {
#[tokio::main]
async fn main() -> anyhow::Result<()> {
dotenvy::dotenv().ok();
dotenvy::dotenv_override().ok();
env_logger::init();
let app = Router::new()
.route("/", get(|| async { "Hello, World!" }))

View File

@@ -10,6 +10,7 @@ repository = "https://github.com/tuna2134/sbv2-api"
[dependencies]
anyhow.workspace = true
dotenvy.workspace = true
env_logger.workspace = true
hound = "3.5.1"
jpreprocess = { version = "0.10.0", features = ["naist-jdic"] }
ndarray = "0.16.1"
@@ -29,3 +30,4 @@ cuda = ["ort/cuda"]
cuda_tf32 = []
dynamic = ["ort/load-dynamic"]
directml = ["ort/directml"]
tensorrt = ["ort/tensorrt"]

View File

@@ -4,18 +4,15 @@ use sbv2_core::tts;
use std::env;
fn main() -> anyhow::Result<()> {
dotenvy::dotenv().ok();
dotenvy::dotenv_override().ok();
env_logger::init();
let text = "眠たい";
let ident = "aaa";
let mut tts_holder = tts::TTSModelHolder::new(
&fs::read(env::var("BERT_MODEL_PATH")?)?,
&fs::read(env::var("TOKENIZER_PATH")?)?,
)?;
tts_holder.load(
ident,
fs::read(env::var("STYLE_VECTORS_PATH")?)?,
fs::read(env::var("MODEL_PATH")?)?,
)?;
tts_holder.load_sbv2file(ident, fs::read(env::var("MODEL_PATH")?)?)?;
let (bert_ori, phones, tones, lang_ids) = tts_holder.parse_text(text)?;
@@ -32,6 +29,14 @@ fn main() -> anyhow::Result<()> {
)?;
std::fs::write("output.wav", data)?;
let now = Instant::now();
for _ in 0..10 {
tts_holder.parse_text(text)?;
}
println!(
"Time taken(parse_text): {}ms/it",
now.elapsed().as_millis() / 10
);
let now = Instant::now();
for _ in 0..10 {
tts_holder.synthesize(
ident,
@@ -44,6 +49,9 @@ fn main() -> anyhow::Result<()> {
1.0,
)?;
}
println!("Time taken: {}", now.elapsed().as_millis());
println!(
"Time taken(synthesize): {}ms/it",
now.elapsed().as_millis() / 10
);
Ok(())
}
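
The benchmark now reports an average per iteration for both parse_text and synthesize instead of a single total. The same pattern as a small reusable helper (the helper and its name are illustrative, not part of the commit):

use std::time::Instant;

// Run `f` `iters` times and print the average wall-clock time per iteration.
fn bench<F: FnMut() -> anyhow::Result<()>>(label: &str, iters: u32, mut f: F) -> anyhow::Result<()> {
    let now = Instant::now();
    for _ in 0..iters {
        f()?;
    }
    println!("Time taken({}): {}ms/it", label, now.elapsed().as_millis() / iters as u128);
    Ok(())
}

// e.g. bench("parse_text", 10, || tts_holder.parse_text(text).map(|_| ()))?;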

View File

@@ -4,11 +4,25 @@ use ndarray::{array, s, Array1, Array2, Axis};
use ort::{GraphOptimizationLevel, Session};
use std::io::Cursor;
#[allow(clippy::vec_init_then_push)]
pub fn load_model<P: AsRef<[u8]>>(model_file: P) -> Result<Session> {
#[allow(clippy::vec_init_then_push, unused_variables)]
pub fn load_model<P: AsRef<[u8]>>(model_file: P, bert: bool) -> Result<Session> {
let mut exp = Vec::new();
#[cfg(feature = "tensorrt")]
{
if bert {
exp.push(
ort::TensorRTExecutionProvider::default()
.with_fp16(true)
.with_profile_min_shapes("input_ids:1x1,attention_mask:1x1")
.with_profile_max_shapes("input_ids:1x100,attention_mask:1x100")
.with_profile_opt_shapes("input_ids:1x25,attention_mask:1x25")
.build(),
);
}
}
#[cfg(feature = "cuda")]
{
#[allow(unused_mut)]
let mut cuda = ort::CUDAExecutionProvider::default()
.with_conv_algorithm_search(ort::CUDAExecutionProviderCuDNNConvAlgoSearch::Default);
#[cfg(feature = "cuda_tf32")]

View File

@@ -41,7 +41,7 @@ pub struct TTSModelHolder {
impl TTSModelHolder {
pub fn new<P: AsRef<[u8]>>(bert_model_bytes: P, tokenizer_bytes: P) -> Result<Self> {
let bert = model::load_model(bert_model_bytes)?;
let bert = model::load_model(bert_model_bytes, true)?;
let jtalk = jtalk::JTalk::new()?;
let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?;
Ok(TTSModelHolder {
@@ -55,7 +55,7 @@ impl TTSModelHolder {
pub fn models(&self) -> Vec<String> {
self.models.iter().map(|m| m.ident.to_string()).collect()
}
pub fn load_sbv2file<I: Into<TTSIdent>, P: AsRef<[u8]>>(
&mut self,
ident: I,
@@ -94,7 +94,7 @@ impl TTSModelHolder {
let ident = ident.into();
if self.find_model(ident.clone()).is_err() {
self.models.push(TTSModel {
vits2: model::load_model(vits2_bytes)?,
vits2: model::load_model(vits2_bytes, false)?,
style_vectors: style::load_style(style_vectors_bytes)?,
ident,
})
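
At the call sites only the shared BERT session opts into the TensorRT shape profiles; per-voice VITS2 sessions pass false, since the profiles target the BERT input_ids/attention_mask shapes. Usage recap (identifiers as in this commit):

// BERT: dynamic input_ids/attention_mask, so the TensorRT profiles apply.
let bert = model::load_model(bert_model_bytes, true)?;
// VITS2 voice model: loaded without the TensorRT shape profiles.
let vits2 = model::load_model(vits2_bytes, false)?;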