This commit is contained in:
tuna2134
2024-09-11 10:30:19 +00:00
4 changed files with 8 additions and 27 deletions

View File

@@ -20,3 +20,4 @@ tokio = { version = "1.40.0", features = ["full"] }
cuda = ["sbv2_core/cuda"]
cuda_tf32 = ["sbv2_core/cuda_tf32"]
dynamic = ["sbv2_core/dynamic"]
directml = ["sbv2_core/directml"]

View File

@@ -21,4 +21,5 @@ tokenizers = "0.20.0"
[features]
cuda = ["ort/cuda"]
cuda_tf32 = []
dynamic = ["ort/load-dynamic"]
dynamic = ["ort/load-dynamic"]
directml = ["ort/directml"]

View File

@@ -17,6 +17,10 @@ pub fn load_model<P: AsRef<[u8]>>(model_file: P) -> Result<Session> {
}
exp.push(cuda.build());
}
#[cfg(feature = "directml")]
{
exp.push(ort::DirectMLExecutionProvider::default().build());
}
exp.push(ort::CPUExecutionProvider::default().build());
Ok(Session::builder()?
.with_execution_providers(exp)?
@@ -26,6 +30,7 @@ pub fn load_model<P: AsRef<[u8]>>(model_file: P) -> Result<Session> {
.with_inter_threads(num_cpus::get_physical())?
.commit_from_memory(model_file.as_ref())?)
}
#[allow(clippy::too_many_arguments)]
pub fn synthesize(
session: &Session,

View File

@@ -2,11 +2,6 @@ pub fn intersperse<T>(slice: &[T], sep: T) -> Vec<T>
where
T: Clone,
{
/*
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
*/
let mut result = vec![sep.clone(); slice.len() * 2 + 1];
result
.iter_mut()
@@ -15,24 +10,3 @@ where
.for_each(|(r, s)| *r = s.clone());
result
}
/*
fn tile<T: Clone>(arr: &Array2<T>, reps: (usize, usize)) -> Array2<T> {
let (rows, cols) = arr.dim();
let (rep_rows, rep_cols) = reps;
let mut result = Array::zeros((rows * rep_rows, cols * rep_cols));
for i in 0..rep_rows {
for j in 0..rep_cols {
let view = result.slice_mut(s![
i * rows..(i + 1) * rows,
j * cols..(j + 1) * cols
]);
view.assign(arr);
}
}
result
}
*/