mirror of
https://github.com/neodyland/sbv2-api.git
synced 2025-12-22 23:49:58 +00:00
add g2p
This commit is contained in:
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
target
|
||||||
|
models/*.onnx
|
||||||
1786
Cargo.lock
generated
Normal file
1786
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
6
Cargo.toml
Normal file
6
Cargo.toml
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
[workspace]
|
||||||
|
resolver = "2"
|
||||||
|
members = ["sbv2_core"]
|
||||||
|
|
||||||
|
[workspace.dependencies]
|
||||||
|
anyhow = "1.0.86"
|
||||||
0
models/.gitkeep
Normal file
0
models/.gitkeep
Normal file
14
sbv2_core/Cargo.toml
Normal file
14
sbv2_core/Cargo.toml
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
[package]
|
||||||
|
name = "sbv2_core"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow.workspace = true
|
||||||
|
jpreprocess = { version = "0.10.0", features = ["naist-jdic"] }
|
||||||
|
ndarray = "0.16.1"
|
||||||
|
once_cell = "1.19.0"
|
||||||
|
ort = { git = "https://github.com/pykeio/ort.git", version = "2.0.0-rc.5" }
|
||||||
|
regex = "1.10.6"
|
||||||
|
thiserror = "1.0.63"
|
||||||
|
tokenizers = "0.20.0"
|
||||||
26
sbv2_core/src/bert.rs
Normal file
26
sbv2_core/src/bert.rs
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
use crate::error::Result;
|
||||||
|
use ndarray::Array2;
|
||||||
|
use ort::{GraphOptimizationLevel, Session};
|
||||||
|
|
||||||
|
pub fn load_model() -> Result<Session> {
|
||||||
|
let session = Session::builder()?
|
||||||
|
.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||||
|
.with_intra_threads(1)?
|
||||||
|
.commit_from_file("models/debert.onnx")?;
|
||||||
|
Ok(session)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn predict(session: &Session, token_ids: Vec<i64>, attention_masks: Vec<i64>) -> Result<()> {
|
||||||
|
let outputs = session.run(
|
||||||
|
ort::inputs! {
|
||||||
|
"input_ids" => Array2::from_shape_vec((1, token_ids.len()), token_ids).unwrap(),
|
||||||
|
"attention_mask" => Array2::from_shape_vec((1, attention_masks.len()), attention_masks).unwrap(),
|
||||||
|
}?
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let output = outputs.get("output").unwrap();
|
||||||
|
|
||||||
|
println!("{:?}", output);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
17
sbv2_core/src/error.rs
Normal file
17
sbv2_core/src/error.rs
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum Error {
|
||||||
|
#[error("Tokenizer error: {0}")]
|
||||||
|
TokenizerError(#[from] tokenizers::Error),
|
||||||
|
#[error("JPreprocess error: {0}")]
|
||||||
|
JPreprocessError(#[from] jpreprocess::error::JPreprocessError),
|
||||||
|
#[error("ONNX error: {0}")]
|
||||||
|
OrtError(#[from] ort::Error),
|
||||||
|
#[error("NDArray error: {0}")]
|
||||||
|
NdArrayError(#[from] ndarray::ShapeError),
|
||||||
|
#[error("Value error: {0}")]
|
||||||
|
ValueError(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
18
sbv2_core/src/lib.rs
Normal file
18
sbv2_core/src/lib.rs
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
pub mod bert;
|
||||||
|
pub mod error;
|
||||||
|
pub mod text;
|
||||||
|
|
||||||
|
pub fn add(left: usize, right: usize) -> usize {
|
||||||
|
left + right
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn it_works() {
|
||||||
|
let result = add(2, 2);
|
||||||
|
assert_eq!(result, 4);
|
||||||
|
}
|
||||||
|
}
|
||||||
24
sbv2_core/src/main.rs
Normal file
24
sbv2_core/src/main.rs
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
use sbv2_core::{bert, error, text};
|
||||||
|
|
||||||
|
fn main() -> error::Result<()> {
|
||||||
|
let text = "こんにちは,世界!";
|
||||||
|
|
||||||
|
let normalized_text = text::normalize_text(text);
|
||||||
|
println!("{}", normalized_text);
|
||||||
|
|
||||||
|
let jtalk = text::JTalk::new()?;
|
||||||
|
let phones = jtalk.g2p(&normalized_text)?;
|
||||||
|
println!("{:?}", phones);
|
||||||
|
|
||||||
|
let tokenizer = text::get_tokenizer()?;
|
||||||
|
println!("{:?}", tokenizer);
|
||||||
|
|
||||||
|
let (token_ids, attention_masks) = text::tokenize(&normalized_text, &tokenizer)?;
|
||||||
|
println!("{:?}", token_ids);
|
||||||
|
|
||||||
|
let session = bert::load_model()?;
|
||||||
|
|
||||||
|
bert::predict(&session, token_ids, attention_masks)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
202
sbv2_core/src/text.rs
Normal file
202
sbv2_core/src/text.rs
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
use crate::error::{Error, Result};
|
||||||
|
use jpreprocess::*;
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use regex::Regex;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
type JPreprocessType = JPreprocess<DefaultFetcher>;
|
||||||
|
|
||||||
|
fn get_jtalk() -> Result<JPreprocessType> {
|
||||||
|
let config = JPreprocessConfig {
|
||||||
|
dictionary: SystemDictionaryConfig::Bundled(kind::JPreprocessDictionaryKind::NaistJdic),
|
||||||
|
user_dictionary: None,
|
||||||
|
};
|
||||||
|
let jpreprocess = JPreprocess::from_config(config)?;
|
||||||
|
Ok(jpreprocess)
|
||||||
|
}
|
||||||
|
|
||||||
|
static JTALK_G2P_G_A1_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"/A:([0-9\-]+)\+").unwrap());
|
||||||
|
static JTALK_G2P_G_A2_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"\+(\d+)\+").unwrap());
|
||||||
|
static JTALK_G2P_G_A3_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"\+(\d+)/").unwrap());
|
||||||
|
static JTALK_G2P_G_E3_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"!(\d+)_").unwrap());
|
||||||
|
static JTALK_G2P_G_F1_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"/F:(\d+)_").unwrap());
|
||||||
|
static JTALK_G2P_G_P3_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"\-(.*?)\+").unwrap());
|
||||||
|
|
||||||
|
fn numeric_feature_by_regex(regex: &Regex, text: &str) -> i32 {
|
||||||
|
if let Some(mat) = regex.captures(text) {
|
||||||
|
mat[1].parse::<i32>().unwrap()
|
||||||
|
} else {
|
||||||
|
-50
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! hash_set {
|
||||||
|
($($elem:expr),* $(,)?) => {{
|
||||||
|
let mut set = HashSet::new();
|
||||||
|
$(
|
||||||
|
set.insert($elem);
|
||||||
|
)*
|
||||||
|
set
|
||||||
|
}};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct JTalk {
|
||||||
|
pub jpreprocess: JPreprocessType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl JTalk {
|
||||||
|
pub fn new() -> Result<Self> {
|
||||||
|
let jpreprocess = get_jtalk()?;
|
||||||
|
Ok(Self { jpreprocess })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fix_phone_tone(&self, phone_tone_list: Vec<(String, i32)>) -> Result<Vec<(String, i32)>> {
|
||||||
|
let tone_values: HashSet<i32> = phone_tone_list
|
||||||
|
.iter()
|
||||||
|
.map(|(_letter, tone)| tone.clone())
|
||||||
|
.collect();
|
||||||
|
if tone_values.len() == 1 {
|
||||||
|
assert!(tone_values == hash_set![0], "{:?}", tone_values);
|
||||||
|
return Ok(phone_tone_list);
|
||||||
|
} else if tone_values.len() == 2 {
|
||||||
|
if tone_values == hash_set![0, 1] {
|
||||||
|
return Ok(phone_tone_list);
|
||||||
|
} else if tone_values == hash_set![-1, 0] {
|
||||||
|
return Ok(phone_tone_list
|
||||||
|
.iter()
|
||||||
|
.map(|x| {
|
||||||
|
let new_tone = if x.1 == -1 { 0 } else { x.1 };
|
||||||
|
(x.0.clone(), new_tone)
|
||||||
|
})
|
||||||
|
.collect());
|
||||||
|
} else {
|
||||||
|
return Err(Error::ValueError("Invalid tone values 0".to_string()));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Err(Error::ValueError("Invalid tone values 1".to_string()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn g2p(&self, text: &str) -> Result<()> {
|
||||||
|
let phone_tone_list_wo_punct = self.g2phone_tone_wo_punct(text)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn g2phone_tone_wo_punct(&self, text: &str) -> Result<Vec<(String, i32)>> {
|
||||||
|
let prosodies = self.g2p_prosody(text)?;
|
||||||
|
|
||||||
|
let mut results: Vec<(String, i32)> = Vec::new();
|
||||||
|
let mut current_phrase: Vec<(String, i32)> = Vec::new();
|
||||||
|
let mut current_tone = 0;
|
||||||
|
|
||||||
|
for (i, letter) in prosodies.iter().enumerate() {
|
||||||
|
if letter == "^" {
|
||||||
|
assert!(i == 0);
|
||||||
|
} else if vec!["$", "?", "_", "#"].contains(&letter.as_str()) {
|
||||||
|
results.extend(self.fix_phone_tone(current_phrase.clone())?);
|
||||||
|
if vec!["$", "?"].contains(&letter.as_str()) {
|
||||||
|
assert!(i == prosodies.len() - 1);
|
||||||
|
}
|
||||||
|
current_phrase = Vec::new();
|
||||||
|
current_tone = 0;
|
||||||
|
} else if letter == "[" {
|
||||||
|
current_tone += 1;
|
||||||
|
} else if letter == "]" {
|
||||||
|
current_tone -= 1;
|
||||||
|
} else {
|
||||||
|
let new_letter = if letter == "cl" {
|
||||||
|
"q".to_string()
|
||||||
|
} else {
|
||||||
|
letter.clone()
|
||||||
|
};
|
||||||
|
current_phrase.push((new_letter, current_tone));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(results)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn g2p_prosody(&self, text: &str) -> Result<Vec<String>> {
|
||||||
|
let labels = self.jpreprocess.extract_fullcontext(text)?;
|
||||||
|
|
||||||
|
let mut phones: Vec<String> = Vec::new();
|
||||||
|
for (i, label) in labels.iter().enumerate() {
|
||||||
|
let mut p3 = {
|
||||||
|
let label_text = label.to_string();
|
||||||
|
let mattched = JTALK_G2P_G_P3_PATTERN.captures(&label_text).unwrap();
|
||||||
|
mattched[1].to_string()
|
||||||
|
};
|
||||||
|
if "AIUEO".contains(&p3) {
|
||||||
|
// 文字をlowerする
|
||||||
|
p3 = p3.to_lowercase();
|
||||||
|
}
|
||||||
|
if p3 == "sil" {
|
||||||
|
assert!(i == 0 || i == labels.len() - 1);
|
||||||
|
if i == 0 {
|
||||||
|
phones.push("^".to_string());
|
||||||
|
} else if i == labels.len() - 1 {
|
||||||
|
let e3 = numeric_feature_by_regex(&JTALK_G2P_G_E3_PATTERN, &label.to_string());
|
||||||
|
if e3 == 0 {
|
||||||
|
phones.push("$".to_string());
|
||||||
|
} else if e3 == 1 {
|
||||||
|
phones.push("?".to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
} else if p3 == "pau" {
|
||||||
|
phones.push("_".to_string());
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
phones.push(p3.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let a1 = numeric_feature_by_regex(&JTALK_G2P_G_A1_PATTERN, &label.to_string());
|
||||||
|
let a2 = numeric_feature_by_regex(&JTALK_G2P_G_A2_PATTERN, &label.to_string());
|
||||||
|
let a3 = numeric_feature_by_regex(&JTALK_G2P_G_A3_PATTERN, &label.to_string());
|
||||||
|
|
||||||
|
let f1 = numeric_feature_by_regex(&JTALK_G2P_G_F1_PATTERN, &label.to_string());
|
||||||
|
|
||||||
|
let a2_next =
|
||||||
|
numeric_feature_by_regex(&JTALK_G2P_G_A2_PATTERN, &labels[i + 1].to_string());
|
||||||
|
|
||||||
|
if a3 == 1 && a2_next == 1 && "aeiouAEIOUNcl".contains(&p3) {
|
||||||
|
phones.push("#".to_string());
|
||||||
|
} else if a1 == 0 && a2_next == a2 + 1 && a2 != f1 {
|
||||||
|
phones.push("]".to_string());
|
||||||
|
} else if a2 == 1 && a2_next == 2 {
|
||||||
|
phones.push("[".to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(phones)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn normalize_text(text: &str) -> String {
|
||||||
|
// 日本語のテキストを正規化する
|
||||||
|
let text = text.replace("~", "ー");
|
||||||
|
let text = text.replace("~", "ー");
|
||||||
|
let text = text.replace("〜", "ー");
|
||||||
|
|
||||||
|
text
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_tokenizer() -> Result<Tokenizer> {
|
||||||
|
let tokenizer = Tokenizer::from_file("tokenizer.json")?;
|
||||||
|
Ok(tokenizer)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tokenize(text: &str, tokenizer: &Tokenizer) -> Result<(Vec<i64>, Vec<i64>)> {
|
||||||
|
let mut token_ids = vec![1];
|
||||||
|
let mut attention_masks = vec![1];
|
||||||
|
for content in text.chars() {
|
||||||
|
let token = tokenizer.encode(content.to_string(), false)?;
|
||||||
|
let ids = token.get_ids();
|
||||||
|
token_ids.extend(ids.iter().map(|&x| x as i64));
|
||||||
|
attention_masks.extend(token.get_attention_mask().iter().map(|&x| x as i64));
|
||||||
|
}
|
||||||
|
token_ids.push(2);
|
||||||
|
attention_masks.push(1);
|
||||||
|
Ok((token_ids, attention_masks))
|
||||||
|
}
|
||||||
22116
tokenizer.json
Normal file
22116
tokenizer.json
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user