Mirror of https://github.com/neodyland/sbv2-api.git (synced 2025-12-22 23:49:58 +00:00)
Commit: add g2p
.gitignore (vendored, new file, 2 lines)
@@ -0,0 +1,2 @@
target
models/*.onnx
Cargo.lock (generated, new file, 1786 lines; diff suppressed because it is too large)
Cargo.toml (new file, 6 lines)
@@ -0,0 +1,6 @@
[workspace]
resolver = "2"
members = ["sbv2_core"]

[workspace.dependencies]
anyhow = "1.0.86"
models/.gitkeep (new file, 0 lines)
sbv2_core/Cargo.toml (new file, 14 lines)
@@ -0,0 +1,14 @@
[package]
name = "sbv2_core"
version = "0.1.0"
edition = "2021"

[dependencies]
anyhow.workspace = true
jpreprocess = { version = "0.10.0", features = ["naist-jdic"] }
ndarray = "0.16.1"
once_cell = "1.19.0"
ort = { git = "https://github.com/pykeio/ort.git", version = "2.0.0-rc.5" }
regex = "1.10.6"
thiserror = "1.0.63"
tokenizers = "0.20.0"
sbv2_core/src/bert.rs (new file, 26 lines)
@@ -0,0 +1,26 @@
use crate::error::{Error, Result};
use ndarray::Array2;
use ort::{GraphOptimizationLevel, Session};

/// Load the BERT ONNX model from models/debert.onnx.
pub fn load_model() -> Result<Session> {
    let session = Session::builder()?
        .with_optimization_level(GraphOptimizationLevel::Level1)?
        .with_intra_threads(1)?
        .commit_from_file("models/debert.onnx")?;
    Ok(session)
}

/// Run the model on one tokenized sentence (batch size 1) and print the result.
pub fn predict(session: &Session, token_ids: Vec<i64>, attention_masks: Vec<i64>) -> Result<()> {
    let outputs = session.run(
        ort::inputs! {
            // Shape (1, seq_len); `?` turns a ShapeError into Error::NdArrayError.
            "input_ids" => Array2::from_shape_vec((1, token_ids.len()), token_ids)?,
            "attention_mask" => Array2::from_shape_vec((1, attention_masks.len()), attention_masks)?,
        }?
    )?;

    let output = outputs
        .get("output")
        .ok_or_else(|| Error::ValueError("model output `output` not found".to_string()))?;

    println!("{:?}", output);

    Ok(())
}
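For now predict only Debug-prints the tensor. A minimal sketch of pulling the hidden states out as f32 values instead, assuming the ort 2.0-rc `try_extract_tensor` API (the pinned git revision may differ) and the same `output` tensor name; `extract_hidden` is a hypothetical helper, not part of this commit:

    use ndarray::Array2;
    use ort::Session;
    use sbv2_core::error::{Error, Result};

    // Hypothetical helper, not part of this commit: run the model and copy
    // the hidden states out as a flat Vec<f32>. Assumes ort exposes
    // `try_extract_tensor` and the model's output tensor is named "output".
    fn extract_hidden(session: &Session, token_ids: Vec<i64>, masks: Vec<i64>) -> Result<Vec<f32>> {
        let outputs = session.run(ort::inputs! {
            "input_ids" => Array2::from_shape_vec((1, token_ids.len()), token_ids)?,
            "attention_mask" => Array2::from_shape_vec((1, masks.len()), masks)?,
        }?)?;
        let hidden = outputs
            .get("output")
            .ok_or_else(|| Error::ValueError("missing `output` tensor".to_string()))?
            .try_extract_tensor::<f32>()?;
        Ok(hidden.iter().copied().collect())
    }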
sbv2_core/src/error.rs (new file, 17 lines)
@@ -0,0 +1,17 @@
use thiserror::Error;

#[derive(Error, Debug)]
pub enum Error {
    #[error("Tokenizer error: {0}")]
    TokenizerError(#[from] tokenizers::Error),
    #[error("JPreprocess error: {0}")]
    JPreprocessError(#[from] jpreprocess::error::JPreprocessError),
    #[error("ONNX error: {0}")]
    OrtError(#[from] ort::Error),
    #[error("NDArray error: {0}")]
    NdArrayError(#[from] ndarray::ShapeError),
    #[error("Value error: {0}")]
    ValueError(String),
}

pub type Result<T> = std::result::Result<T, Error>;
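Each `#[from]` variant gives `?` an automatic conversion from the corresponding library error, while `ValueError` covers ad-hoc failures. A small illustration with a hypothetical helper, not part of this commit:

    use sbv2_core::error::{Error, Result};
    use tokenizers::Tokenizer;

    // `?` converts the tokenizers::Error into Error::TokenizerError via
    // #[from]; the missing-token case uses ValueError directly.
    fn first_token_id(text: &str, tokenizer: &Tokenizer) -> Result<u32> {
        let encoding = tokenizer.encode(text, false)?;
        encoding
            .get_ids()
            .first()
            .copied()
            .ok_or_else(|| Error::ValueError("empty encoding".to_string()))
    }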
sbv2_core/src/lib.rs (new file, 18 lines)
@@ -0,0 +1,18 @@
pub mod bert;
pub mod error;
pub mod text;

pub fn add(left: usize, right: usize) -> usize {
    left + right
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_works() {
        let result = add(2, 2);
        assert_eq!(result, 4);
    }
}
sbv2_core/src/main.rs (new file, 24 lines)
@@ -0,0 +1,24 @@
use sbv2_core::{bert, error, text};

fn main() -> error::Result<()> {
    let text = "こんにちは,世界!";

    let normalized_text = text::normalize_text(text);
    println!("{}", normalized_text);

    let jtalk = text::JTalk::new()?;
    // g2p is still a work in progress and currently returns ().
    let phones = jtalk.g2p(&normalized_text)?;
    println!("{:?}", phones);

    let tokenizer = text::get_tokenizer()?;
    println!("{:?}", tokenizer);

    let (token_ids, attention_masks) = text::tokenize(&normalized_text, &tokenizer)?;
    println!("{:?}", token_ids);

    let session = bert::load_model()?;

    bert::predict(&session, token_ids, attention_masks)?;

    Ok(())
}
sbv2_core/src/text.rs (new file, 202 lines)
@@ -0,0 +1,202 @@
use crate::error::{Error, Result};
use jpreprocess::*;
use once_cell::sync::Lazy;
use regex::Regex;
use std::collections::HashSet;
use tokenizers::Tokenizer;

type JPreprocessType = JPreprocess<DefaultFetcher>;

fn get_jtalk() -> Result<JPreprocessType> {
    let config = JPreprocessConfig {
        dictionary: SystemDictionaryConfig::Bundled(kind::JPreprocessDictionaryKind::NaistJdic),
        user_dictionary: None,
    };
    let jpreprocess = JPreprocess::from_config(config)?;
    Ok(jpreprocess)
}

// Regexes for pulling individual fields out of HTS full-context labels.
static JTALK_G2P_G_A1_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"/A:([0-9\-]+)\+").unwrap());
static JTALK_G2P_G_A2_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"\+(\d+)\+").unwrap());
static JTALK_G2P_G_A3_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"\+(\d+)/").unwrap());
static JTALK_G2P_G_E3_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"!(\d+)_").unwrap());
static JTALK_G2P_G_F1_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"/F:(\d+)_").unwrap());
static JTALK_G2P_G_P3_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"\-(.*?)\+").unwrap());

/// Extract the first capture group as an i32, or -50 as a "no match" sentinel.
fn numeric_feature_by_regex(regex: &Regex, text: &str) -> i32 {
    if let Some(mat) = regex.captures(text) {
        mat[1].parse::<i32>().unwrap()
    } else {
        -50
    }
}

macro_rules! hash_set {
    ($($elem:expr),* $(,)?) => {{
        let mut set = HashSet::new();
        $(
            set.insert($elem);
        )*
        set
    }};
}

pub struct JTalk {
    pub jpreprocess: JPreprocessType,
}

impl JTalk {
    pub fn new() -> Result<Self> {
        let jpreprocess = get_jtalk()?;
        Ok(Self { jpreprocess })
    }

    /// Normalize the tones of a phrase so every value is 0 or 1.
    fn fix_phone_tone(&self, phone_tone_list: Vec<(String, i32)>) -> Result<Vec<(String, i32)>> {
        let tone_values: HashSet<i32> = phone_tone_list
            .iter()
            .map(|(_letter, tone)| *tone)
            .collect();
        if tone_values.len() == 1 {
            assert!(tone_values == hash_set![0], "{:?}", tone_values);
            Ok(phone_tone_list)
        } else if tone_values.len() == 2 {
            if tone_values == hash_set![0, 1] {
                Ok(phone_tone_list)
            } else if tone_values == hash_set![-1, 0] {
                // Shift {-1, 0} up to {0, 1}.
                Ok(phone_tone_list
                    .iter()
                    .map(|x| {
                        let new_tone = if x.1 == -1 { 0 } else { 1 };
                        (x.0.clone(), new_tone)
                    })
                    .collect())
            } else {
                Err(Error::ValueError("Invalid tone values 0".to_string()))
            }
        } else {
            Err(Error::ValueError("Invalid tone values 1".to_string()))
        }
    }

    pub fn g2p(&self, text: &str) -> Result<()> {
        // TODO: g2p is still a work in progress; the phone/tone list is
        // computed but not yet returned.
        let _phone_tone_list_wo_punct = self.g2phone_tone_wo_punct(text)?;
        Ok(())
    }

    /// Convert text to a (phone, tone) list, consuming the prosody markers
    /// into per-phone tone values.
    fn g2phone_tone_wo_punct(&self, text: &str) -> Result<Vec<(String, i32)>> {
        let prosodies = self.g2p_prosody(text)?;

        let mut results: Vec<(String, i32)> = Vec::new();
        let mut current_phrase: Vec<(String, i32)> = Vec::new();
        let mut current_tone = 0;

        for (i, letter) in prosodies.iter().enumerate() {
            if letter == "^" {
                // Sentence start marker.
                assert!(i == 0);
            } else if ["$", "?", "_", "#"].contains(&letter.as_str()) {
                // Sentence end, question, pause, or accent phrase boundary:
                // flush the current phrase.
                results.extend(self.fix_phone_tone(current_phrase.clone())?);
                if ["$", "?"].contains(&letter.as_str()) {
                    assert!(i == prosodies.len() - 1);
                }
                current_phrase = Vec::new();
                current_tone = 0;
            } else if letter == "[" {
                // Pitch rise.
                current_tone += 1;
            } else if letter == "]" {
                // Pitch fall.
                current_tone -= 1;
            } else {
                let new_letter = if letter == "cl" {
                    "q".to_string()
                } else {
                    letter.clone()
                };
                current_phrase.push((new_letter, current_tone));
            }
        }

        Ok(results)
    }

    /// Extract phonemes plus prosody symbols from full-context labels:
    /// `^`/`$`/`?` sentence boundaries, `_` pause, `#` accent phrase
    /// boundary, `[` pitch rise, `]` pitch fall.
    fn g2p_prosody(&self, text: &str) -> Result<Vec<String>> {
        let labels = self.jpreprocess.extract_fullcontext(text)?;

        let mut phones: Vec<String> = Vec::new();
        for (i, label) in labels.iter().enumerate() {
            let mut p3 = {
                let label_text = label.to_string();
                let matched = JTALK_G2P_G_P3_PATTERN.captures(&label_text).unwrap();
                matched[1].to_string()
            };
            if "AIUEO".contains(&p3) {
                // Lowercase the devoiced vowel.
                p3 = p3.to_lowercase();
            }
            if p3 == "sil" {
                assert!(i == 0 || i == labels.len() - 1);
                if i == 0 {
                    phones.push("^".to_string());
                } else if i == labels.len() - 1 {
                    // E3 distinguishes declarative ($) from interrogative (?).
                    let e3 = numeric_feature_by_regex(&JTALK_G2P_G_E3_PATTERN, &label.to_string());
                    if e3 == 0 {
                        phones.push("$".to_string());
                    } else if e3 == 1 {
                        phones.push("?".to_string());
                    }
                }
                continue;
            } else if p3 == "pau" {
                phones.push("_".to_string());
                continue;
            } else {
                phones.push(p3.clone());
            }

            let a1 = numeric_feature_by_regex(&JTALK_G2P_G_A1_PATTERN, &label.to_string());
            let a2 = numeric_feature_by_regex(&JTALK_G2P_G_A2_PATTERN, &label.to_string());
            let a3 = numeric_feature_by_regex(&JTALK_G2P_G_A3_PATTERN, &label.to_string());

            let f1 = numeric_feature_by_regex(&JTALK_G2P_G_F1_PATTERN, &label.to_string());

            let a2_next =
                numeric_feature_by_regex(&JTALK_G2P_G_A2_PATTERN, &labels[i + 1].to_string());

            if a3 == 1 && a2_next == 1 && "aeiouAEIOUNcl".contains(&p3) {
                // Accent phrase boundary.
                phones.push("#".to_string());
            } else if a1 == 0 && a2_next == a2 + 1 && a2 != f1 {
                // Pitch falls after the accent nucleus.
                phones.push("]".to_string());
            } else if a2 == 1 && a2_next == 2 {
                // Pitch rises between the first and second moras.
                phones.push("[".to_string());
            }
        }

        Ok(phones)
    }
}

/// Normalize Japanese text (currently: unify tilde variants into a long vowel mark).
pub fn normalize_text(text: &str) -> String {
    let text = text.replace("~", "ー"); // fullwidth tilde
    let text = text.replace("~", "ー"); // ASCII tilde
    let text = text.replace("〜", "ー"); // wave dash

    text
}

pub fn get_tokenizer() -> Result<Tokenizer> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    Ok(tokenizer)
}

/// Tokenize character by character, wrapping the sequence in the special
/// token ids 1 (start) and 2 (end).
pub fn tokenize(text: &str, tokenizer: &Tokenizer) -> Result<(Vec<i64>, Vec<i64>)> {
    let mut token_ids = vec![1];
    let mut attention_masks = vec![1];
    for content in text.chars() {
        let token = tokenizer.encode(content.to_string(), false)?;
        let ids = token.get_ids();
        token_ids.extend(ids.iter().map(|&x| x as i64));
        attention_masks.extend(token.get_attention_mask().iter().map(|&x| x as i64));
    }
    token_ids.push(2);
    attention_masks.push(1);
    Ok((token_ids, attention_masks))
}
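As a quick check of the text helpers: normalize_text folds the tilde variants into ー, and tokenize encodes character by character between the special ids 1 and 2. A minimal smoke test, assuming tokenizer.json sits in the working directory; not part of this commit:

    use sbv2_core::{error::Result, text};

    // Hypothetical smoke test, not part of this commit.
    fn main() -> Result<()> {
        let normalized = text::normalize_text("こんにちは〜");
        assert_eq!(normalized, "こんにちはー"); // wave dash folded to long vowel mark

        let tokenizer = text::get_tokenizer()?; // reads ./tokenizer.json
        let (ids, masks) = text::tokenize(&normalized, &tokenizer)?;
        assert_eq!(ids.first(), Some(&1)); // leading special token
        assert_eq!(ids.last(), Some(&2)); // trailing special token
        assert_eq!(ids.len(), masks.len());
        Ok(())
    }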
tokenizer.json (new file, 22116 lines; diff suppressed because it is too large)