From fd1caf8ab5c2ad99a1deba337351afed4d6b7141 Mon Sep 17 00:00:00 2001 From: Googlefan Date: Wed, 11 Sep 2024 11:20:25 +0000 Subject: [PATCH] feat: sbv2file format --- Cargo.lock | 41 ++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 2 +- convert/convert_model.py | 24 +++++++++++++++++++++++ convert/requirements.txt | 3 ++- sbv2_api/src/main.rs | 15 +++++++++++++++ sbv2_core/Cargo.toml | 4 +++- sbv2_core/src/tts.rs | 32 ++++++++++++++++++++++++++++++- 7 files changed, 117 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 33a8baa..4808321 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -228,6 +228,8 @@ version = "1.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b62ac837cdb5cb22e10a256099b4fc502b1dfe560cb282963a974d7abd80e476" dependencies = [ + "jobserver", + "libc", "shlex", ] @@ -848,6 +850,15 @@ dependencies = [ "thiserror", ] +[[package]] +name = "jobserver" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +dependencies = [ + "libc", +] + [[package]] name = "jpreprocess" version = "0.10.0" @@ -1754,8 +1765,10 @@ dependencies = [ "regex", "serde", "serde_json", + "tar", "thiserror", "tokenizers", + "zstd", ] [[package]] @@ -2368,3 +2381,31 @@ name = "zeroize" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" + +[[package]] +name = "zstd" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.13+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index 7c30b9d..e6acc62 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,4 +4,4 @@ members = ["sbv2_api", "sbv2_core"] [workspace.dependencies] anyhow = "1.0.86" -dotenvy = "0.15.7" \ No newline at end of file +dotenvy = "0.15.7" diff --git a/convert/convert_model.py b/convert/convert_model.py index a8097db..17ddc6c 100644 --- a/convert/convert_model.py +++ b/convert/convert_model.py @@ -1,5 +1,6 @@ import numpy as np import json +from io import BytesIO from style_bert_vits2.nlp import bert_models from style_bert_vits2.constants import Languages from style_bert_vits2.models.infer import get_net_g, get_text @@ -11,6 +12,9 @@ from style_bert_vits2.constants import ( DEFAULT_STYLE_WEIGHT, Languages, ) +import os +from tarfile import open as taropen, TarInfo +from zstandard import ZstdCompressor from style_bert_vits2.tts_model import TTSModel import numpy as np from argparse import ArgumentParser @@ -141,3 +145,23 @@ torch.onnx.export( ], output_names=["output"], ) +os.system(f"onnxsim ../models/model_{out_name}.onnx ../models/model_{out_name}.onnx") +onnxfile = open(f"../models/model_{out_name}.onnx", "rb").read() +stylefile = open(f"../models/style_vectors_{out_name}.json", "rb").read() +version = bytes("1", "utf8") +with taropen(f"../models/tmp_{out_name}.sbv2tar", "w") as w: + + def add_tar(f, b): + t = TarInfo(f) + t.size = len(b) + w.addfile(t, BytesIO(b)) + + add_tar("version.txt", version) + add_tar("model.onnx", onnxfile) + add_tar("style_vectors.json", stylefile) +open(f"../models/{out_name}.sbv2", "wb").write( + ZstdCompressor(threads=-1, level=22).compress( + open(f"../models/tmp_{out_name}.sbv2tar", "rb").read() + ) +) +os.unlink(f"../models/tmp_{out_name}.sbv2tar") diff --git a/convert/requirements.txt b/convert/requirements.txt index 00790b8..5f17210 100644 --- a/convert/requirements.txt +++ b/convert/requirements.txt @@ -1,3 +1,4 @@ style-bert-vits2 onnxsim -numpy<3 \ No newline at end of file +numpy<2 +zstandard \ No newline at end of file diff --git a/sbv2_api/src/main.rs b/sbv2_api/src/main.rs index e956e93..3b015d2 100644 --- a/sbv2_api/src/main.rs +++ b/sbv2_api/src/main.rs @@ -88,6 +88,20 @@ impl AppState { .iter() .collect::(), ); + } else if name.ends_with(".sbv2") { + let entry = &name[..name.len() - 5]; + log::info!("Try loading: {entry}"); + let sbv2_bytes = match fs::read(format!("{models}/{entry}.sbv2")).await { + Ok(b) => b, + Err(e) => { + log::warn!("Error loading sbv2_bytes from file {entry}: {e}"); + continue; + } + }; + if let Err(e) = tts_model.load_sbv2file(&entry, sbv2_bytes) { + log::warn!("Error loading {entry}: {e}"); + }; + log::info!("Loaded: {entry}"); } } for entry in entries { @@ -110,6 +124,7 @@ impl AppState { if let Err(e) = tts_model.load(&entry, style_vectors_bytes, vits2_bytes) { log::warn!("Error loading {entry}: {e}"); }; + log::info!("Loaded: {entry}"); } Ok(Self { tts_model: Arc::new(Mutex::new(tts_model)), diff --git a/sbv2_core/Cargo.toml b/sbv2_core/Cargo.toml index 37203d9..8befc48 100644 --- a/sbv2_core/Cargo.toml +++ b/sbv2_core/Cargo.toml @@ -15,10 +15,12 @@ ort = { git = "https://github.com/pykeio/ort.git", version = "2.0.0-rc.6" } regex = "1.10.6" serde = { version = "1.0.210", features = ["derive"] } serde_json = "1.0.128" +tar = "0.4.41" thiserror = "1.0.63" tokenizers = "0.20.0" +zstd = "0.13.2" [features] cuda = ["ort/cuda"] cuda_tf32 = [] -dynamic = ["ort/load-dynamic"] \ No newline at end of file +dynamic = ["ort/load-dynamic"] diff --git a/sbv2_core/src/tts.rs b/sbv2_core/src/tts.rs index 5a8c454..d185e41 100644 --- a/sbv2_core/src/tts.rs +++ b/sbv2_core/src/tts.rs @@ -2,7 +2,10 @@ use crate::error::{Error, Result}; use crate::{bert, jtalk, model, nlp, norm, style, tokenizer, utils}; use ndarray::{concatenate, s, Array, Array1, Array2, Axis}; use ort::Session; +use std::io::{Cursor, Read}; +use tar::Archive; use tokenizers::Tokenizer; +use zstd::decode_all; #[derive(PartialEq, Eq, Clone)] pub struct TTSIdent(String); @@ -52,7 +55,34 @@ impl TTSModelHolder { pub fn models(&self) -> Vec { self.models.iter().map(|m| m.ident.to_string()).collect() } - + pub fn load_sbv2file, P: AsRef<[u8]>>( + &mut self, + ident: I, + sbv2_bytes: P, + ) -> Result<()> { + let mut arc = Archive::new(Cursor::new(decode_all(Cursor::new(sbv2_bytes.as_ref()))?)); + let mut vits2 = None; + let mut style_vectors = None; + let mut et = arc.entries()?; + while let Some(Ok(mut e)) = et.next() { + let pth = String::from_utf8_lossy(&e.path_bytes()).to_string(); + let mut b = Vec::with_capacity(e.size() as usize); + e.read_to_end(&mut b)?; + match pth.as_str() { + "model.onnx" => vits2 = Some(b), + "style_vectors.json" => style_vectors = Some(b), + _ => continue, + } + } + if style_vectors.is_none() { + return Err(Error::ModelNotFoundError("style_vectors".to_string())); + } + if vits2.is_none() { + return Err(Error::ModelNotFoundError("vits2".to_string())); + } + self.load(ident, style_vectors.unwrap(), vits2.unwrap())?; + Ok(()) + } pub fn load, P: AsRef<[u8]>>( &mut self, ident: I,