commit e075937ee7 (parent ac94add3ed)
Author: tuna2134
Date: 2024-09-09 08:46:48 +00:00
11 changed files with 24211 additions and 0 deletions

.gitignore (vendored, new file, +2 lines)

@@ -0,0 +1,2 @@
target
models/*.onnx

Cargo.lock (generated, new file, +1786 lines)
Diff suppressed because the file is too large.

Cargo.toml (new file, +6 lines)

@@ -0,0 +1,6 @@
[workspace]
resolver = "2"
members = ["sbv2_core"]

[workspace.dependencies]
anyhow = "1.0.86"

models/.gitkeep (new file, empty)

sbv2_core/Cargo.toml (new file, +14 lines)

@@ -0,0 +1,14 @@
[package]
name = "sbv2_core"
version = "0.1.0"
edition = "2021"

[dependencies]
anyhow.workspace = true
jpreprocess = { version = "0.10.0", features = ["naist-jdic"] }
ndarray = "0.16.1"
once_cell = "1.19.0"
ort = { git = "https://github.com/pykeio/ort.git", version = "2.0.0-rc.5" }
regex = "1.10.6"
thiserror = "1.0.63"
tokenizers = "0.20.0"

sbv2_core/src/bert.rs (new file, +26 lines)

@@ -0,0 +1,26 @@
use crate::error::{Error, Result};
use ndarray::Array2;
use ort::{GraphOptimizationLevel, Session};

pub fn load_model() -> Result<Session> {
    let session = Session::builder()?
        .with_optimization_level(GraphOptimizationLevel::Level1)?
        .with_intra_threads(1)?
        .commit_from_file("models/debert.onnx")?;
    Ok(session)
}

pub fn predict(session: &Session, token_ids: Vec<i64>, attention_masks: Vec<i64>) -> Result<()> {
    // Shape errors convert into Error::NdArrayError via `?`, so no unwrap is needed.
    let outputs = session.run(ort::inputs! {
        "input_ids" => Array2::from_shape_vec((1, token_ids.len()), token_ids)?,
        "attention_mask" => Array2::from_shape_vec((1, attention_masks.len()), attention_masks)?,
    }?)?;
    let output = outputs
        .get("output")
        .ok_or_else(|| Error::ValueError("model produced no `output` tensor".to_string()))?;
    println!("{:?}", output);
    Ok(())
}
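For reference, a sketch of how predict could hand the hidden states back to the caller instead of printing them. This assumes the graph's result tensor really is named "output" and uses ort's try_extract_tensor, whose exact shape is still shifting across the 2.0 release candidates; the predict_vec name is illustrative, not part of this commit.

// Sketch only, same module context as bert.rs above.
pub fn predict_vec(
    session: &Session,
    token_ids: Vec<i64>,
    attention_masks: Vec<i64>,
) -> Result<ndarray::ArrayD<f32>> {
    let outputs = session.run(ort::inputs! {
        "input_ids" => Array2::from_shape_vec((1, token_ids.len()), token_ids)?,
        "attention_mask" => Array2::from_shape_vec((1, attention_masks.len()), attention_masks)?,
    }?)?;
    // try_extract_tensor borrows from `outputs`, so copy into an owned array.
    let hidden = outputs["output"].try_extract_tensor::<f32>()?;
    Ok(hidden.to_owned())
}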

sbv2_core/src/error.rs (new file, +17 lines)

@@ -0,0 +1,17 @@
use thiserror::Error;

#[derive(Error, Debug)]
pub enum Error {
    #[error("Tokenizer error: {0}")]
    TokenizerError(#[from] tokenizers::Error),
    #[error("JPreprocess error: {0}")]
    JPreprocessError(#[from] jpreprocess::error::JPreprocessError),
    #[error("ONNX error: {0}")]
    OrtError(#[from] ort::Error),
    #[error("NDArray error: {0}")]
    NdArrayError(#[from] ndarray::ShapeError),
    #[error("Value error: {0}")]
    ValueError(String),
}

pub type Result<T> = std::result::Result<T, Error>;
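The #[from] conversions mean any of these library errors bubbles up through `?` behind the single Result alias. A minimal sketch, assuming the error types above; the load_tokenizer name is illustrative:

use sbv2_core::error::Result;

// `?` converts tokenizers::Error into sbv2_core::error::Error via #[from].
fn load_tokenizer(path: &str) -> Result<tokenizers::Tokenizer> {
    Ok(tokenizers::Tokenizer::from_file(path)?)
}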

sbv2_core/src/lib.rs (new file, +18 lines)

@@ -0,0 +1,18 @@
pub mod bert;
pub mod error;
pub mod text;

pub fn add(left: usize, right: usize) -> usize {
    left + right
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_works() {
        let result = add(2, 2);
        assert_eq!(result, 4);
    }
}

sbv2_core/src/main.rs (new file, +24 lines)

@@ -0,0 +1,24 @@
use sbv2_core::{bert, error, text};

fn main() -> error::Result<()> {
    let text = "こんにちは,世界!";

    // Normalize the input, then run grapheme-to-phoneme conversion.
    let normalized_text = text::normalize_text(text);
    println!("{}", normalized_text);
    let jtalk = text::JTalk::new()?;
    let phones = jtalk.g2p(&normalized_text)?;
    println!("{:?}", phones);

    // Tokenize the text and feed it through the BERT ONNX model.
    let tokenizer = text::get_tokenizer()?;
    println!("{:?}", tokenizer);
    let (token_ids, attention_masks) = text::tokenize(&normalized_text, &tokenizer)?;
    println!("{:?}", token_ids);
    let session = bert::load_model()?;
    bert::predict(&session, token_ids, attention_masks)?;
    Ok(())
}
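Both paths in this binary are hard-coded relative to the working directory, so a run needs models/debert.onnx and tokenizer.json present at the workspace root (tokenizer.json is committed below; models/*.onnx is gitignored):

cargo run -p sbv2_core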

sbv2_core/src/text.rs (new file, +202 lines)

@@ -0,0 +1,202 @@
use crate::error::{Error, Result};
use jpreprocess::*;
use once_cell::sync::Lazy;
use regex::Regex;
use std::collections::HashSet;
use tokenizers::Tokenizer;

type JPreprocessType = JPreprocess<DefaultFetcher>;

fn get_jtalk() -> Result<JPreprocessType> {
    let config = JPreprocessConfig {
        dictionary: SystemDictionaryConfig::Bundled(kind::JPreprocessDictionaryKind::NaistJdic),
        user_dictionary: None,
    };
    let jpreprocess = JPreprocess::from_config(config)?;
    Ok(jpreprocess)
}

static JTALK_G2P_G_A1_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"/A:([0-9\-]+)\+").unwrap());
static JTALK_G2P_G_A2_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"\+(\d+)\+").unwrap());
static JTALK_G2P_G_A3_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"\+(\d+)/").unwrap());
static JTALK_G2P_G_E3_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"!(\d+)_").unwrap());
static JTALK_G2P_G_F1_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"/F:(\d+)_").unwrap());
static JTALK_G2P_G_P3_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"\-(.*?)\+").unwrap());

/// Extract a numeric feature from a full-context label, falling back to -50
/// (a sentinel no real feature takes) when the pattern does not match.
fn numeric_feature_by_regex(regex: &Regex, text: &str) -> i32 {
    if let Some(mat) = regex.captures(text) {
        mat[1].parse::<i32>().unwrap()
    } else {
        -50
    }
}

macro_rules! hash_set {
    ($($elem:expr),* $(,)?) => {{
        let mut set = HashSet::new();
        $(
            set.insert($elem);
        )*
        set
    }};
}

pub struct JTalk {
    pub jpreprocess: JPreprocessType,
}

impl JTalk {
    pub fn new() -> Result<Self> {
        let jpreprocess = get_jtalk()?;
        Ok(Self { jpreprocess })
    }

    fn fix_phone_tone(&self, phone_tone_list: Vec<(String, i32)>) -> Result<Vec<(String, i32)>> {
        let tone_values: HashSet<i32> = phone_tone_list
            .iter()
            .map(|(_letter, tone)| *tone)
            .collect();
        if tone_values.len() == 1 {
            assert!(tone_values == hash_set![0], "{:?}", tone_values);
            Ok(phone_tone_list)
        } else if tone_values.len() == 2 {
            if tone_values == hash_set![0, 1] {
                Ok(phone_tone_list)
            } else if tone_values == hash_set![-1, 0] {
                Ok(phone_tone_list
                    .iter()
                    .map(|x| {
                        let new_tone = if x.1 == -1 { 0 } else { x.1 };
                        (x.0.clone(), new_tone)
                    })
                    .collect())
            } else {
                Err(Error::ValueError("Invalid tone values 0".to_string()))
            }
        } else {
            Err(Error::ValueError("Invalid tone values 1".to_string()))
        }
    }

    pub fn g2p(&self, text: &str) -> Result<Vec<(String, i32)>> {
        let phone_tone_list_wo_punct = self.g2phone_tone_wo_punct(text)?;
        Ok(phone_tone_list_wo_punct)
    }

    fn g2phone_tone_wo_punct(&self, text: &str) -> Result<Vec<(String, i32)>> {
        let prosodies = self.g2p_prosody(text)?;
        let mut results: Vec<(String, i32)> = Vec::new();
        let mut current_phrase: Vec<(String, i32)> = Vec::new();
        let mut current_tone = 0;
        for (i, letter) in prosodies.iter().enumerate() {
            if letter == "^" {
                assert!(i == 0);
            } else if ["$", "?", "_", "#"].contains(&letter.as_str()) {
                results.extend(self.fix_phone_tone(current_phrase.clone())?);
                if ["$", "?"].contains(&letter.as_str()) {
                    assert!(i == prosodies.len() - 1);
                }
                current_phrase = Vec::new();
                current_tone = 0;
            } else if letter == "[" {
                current_tone += 1;
            } else if letter == "]" {
                current_tone -= 1;
            } else {
                let new_letter = if letter == "cl" {
                    "q".to_string()
                } else {
                    letter.clone()
                };
                current_phrase.push((new_letter, current_tone));
            }
        }
        Ok(results)
    }

    fn g2p_prosody(&self, text: &str) -> Result<Vec<String>> {
        let labels = self.jpreprocess.extract_fullcontext(text)?;
        let mut phones: Vec<String> = Vec::new();
        for (i, label) in labels.iter().enumerate() {
            let mut p3 = {
                let label_text = label.to_string();
                let matched = JTALK_G2P_G_P3_PATTERN.captures(&label_text).unwrap();
                matched[1].to_string()
            };
            if "AIUEO".contains(&p3) {
                // Lowercase the character
                p3 = p3.to_lowercase();
            }
            if p3 == "sil" {
                assert!(i == 0 || i == labels.len() - 1);
                if i == 0 {
                    phones.push("^".to_string());
                } else if i == labels.len() - 1 {
                    let e3 = numeric_feature_by_regex(&JTALK_G2P_G_E3_PATTERN, &label.to_string());
                    if e3 == 0 {
                        phones.push("$".to_string());
                    } else if e3 == 1 {
                        phones.push("?".to_string());
                    }
                }
                continue;
            } else if p3 == "pau" {
                phones.push("_".to_string());
                continue;
            } else {
                phones.push(p3.clone());
            }
            let a1 = numeric_feature_by_regex(&JTALK_G2P_G_A1_PATTERN, &label.to_string());
            let a2 = numeric_feature_by_regex(&JTALK_G2P_G_A2_PATTERN, &label.to_string());
            let a3 = numeric_feature_by_regex(&JTALK_G2P_G_A3_PATTERN, &label.to_string());
            let f1 = numeric_feature_by_regex(&JTALK_G2P_G_F1_PATTERN, &label.to_string());
            let a2_next =
                numeric_feature_by_regex(&JTALK_G2P_G_A2_PATTERN, &labels[i + 1].to_string());
            if a3 == 1 && a2_next == 1 && "aeiouAEIOUNcl".contains(&p3) {
                phones.push("#".to_string());
            } else if a1 == 0 && a2_next == a2 + 1 && a2 != f1 {
                phones.push("]".to_string());
            } else if a2 == 1 && a2_next == 2 {
                phones.push("[".to_string());
            }
        }
        Ok(phones)
    }
}

pub fn normalize_text(text: &str) -> String {
    // Normalize the Japanese text
    let text = text.replace("~", "");
    let text = text.replace("", "");
    let text = text.replace("", "");
    text
}

pub fn get_tokenizer() -> Result<Tokenizer> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    Ok(tokenizer)
}

pub fn tokenize(text: &str, tokenizer: &Tokenizer) -> Result<(Vec<i64>, Vec<i64>)> {
    // Encode one character at a time, prepending token id 1 and appending
    // token id 2 as the sequence delimiters (presumably this tokenizer's
    // special tokens).
    let mut token_ids = vec![1];
    let mut attention_masks = vec![1];
    for content in text.chars() {
        let token = tokenizer.encode(content.to_string(), false)?;
        let ids = token.get_ids();
        token_ids.extend(ids.iter().map(|&x| x as i64));
        attention_masks.extend(token.get_attention_mask().iter().map(|&x| x as i64));
    }
    token_ids.push(2);
    attention_masks.push(1);
    Ok((token_ids, attention_masks))
}
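As a quick sanity check of tokenize's framing invariant, a minimal sketch; the literal ids 1 and 2 are simply whatever special tokens this tokenizer.json defines:

use sbv2_core::error::Result;
use sbv2_core::text::{get_tokenizer, tokenize};

fn main() -> Result<()> {
    let tokenizer = get_tokenizer()?;
    let (ids, masks) = tokenize("こんにちは", &tokenizer)?;
    assert_eq!(ids.first(), Some(&1)); // id prepended by tokenize
    assert_eq!(ids.last(), Some(&2)); // id appended by tokenize
    assert_eq!(ids.len(), masks.len()); // one mask entry per token id
    Ok(())
}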

tokenizer.json (new file, +22116 lines)
Diff suppressed because the file is too large.