Merge branch 'main' into main

This commit is contained in:
Googlefan256
2024-09-11 20:22:05 +09:00
committed by GitHub
4 changed files with 11 additions and 26 deletions

View File

@@ -17,3 +17,4 @@ tokio = { version = "1.40.0", features = ["full"] }
cuda = ["sbv2_core/cuda"]
cuda_tf32 = ["sbv2_core/cuda_tf32"]
dynamic = ["sbv2_core/dynamic"]
directml = ["sbv2_core/directml"]

View File

@@ -1,7 +1,11 @@
[package]
name = "sbv2_core"
description = "Style-Bert-VITSの推論ライブラリ"
version = "0.1.0"
edition = "2021"
license = "MIT"
readme = "../README.md"
repository = "https://github.com/tuna2134/sbv2-api"
[dependencies]
anyhow.workspace = true
@@ -24,3 +28,4 @@ zstd = "0.13.2"
cuda = ["ort/cuda"]
cuda_tf32 = []
dynamic = ["ort/load-dynamic"]
directml = ["ort/directml"]

View File

@@ -17,6 +17,10 @@ pub fn load_model<P: AsRef<[u8]>>(model_file: P) -> Result<Session> {
}
exp.push(cuda.build());
}
#[cfg(feature = "directml")]
{
exp.push(ort::DirectMLExecutionProvider::default().build());
}
exp.push(ort::CPUExecutionProvider::default().build());
Ok(Session::builder()?
.with_execution_providers(exp)?
@@ -26,6 +30,7 @@ pub fn load_model<P: AsRef<[u8]>>(model_file: P) -> Result<Session> {
.with_inter_threads(num_cpus::get_physical())?
.commit_from_memory(model_file.as_ref())?)
}
#[allow(clippy::too_many_arguments)]
pub fn synthesize(
session: &Session,

View File

@@ -2,11 +2,6 @@ pub fn intersperse<T>(slice: &[T], sep: T) -> Vec<T>
where
T: Clone,
{
/*
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
*/
let mut result = vec![sep.clone(); slice.len() * 2 + 1];
result
.iter_mut()
@@ -15,24 +10,3 @@ where
.for_each(|(r, s)| *r = s.clone());
result
}
/*
fn tile<T: Clone>(arr: &Array2<T>, reps: (usize, usize)) -> Array2<T> {
let (rows, cols) = arr.dim();
let (rep_rows, rep_cols) = reps;
let mut result = Array::zeros((rows * rep_rows, cols * rep_cols));
for i in 0..rep_rows {
for j in 0..rep_cols {
let view = result.slice_mut(s![
i * rows..(i + 1) * rows,
j * cols..(j + 1) * cols
]);
view.assign(arr);
}
}
result
}
*/