sbv2-api/crates/sbv2_wasm/src-js/index.ts

import * as wasm from "../pkg/sbv2_wasm.js";
import { InferenceSession, Tensor } from "onnxruntime-web";

/**
 * Holds the shared tokenizer and DeBERTa session plus any number of loaded
 * VITS2 voice models, keyed by name.
 */
export class ModelHolder {
  private models: Map<string, [InferenceSession, wasm.StyleVectorWrap]> =
    new Map();

  constructor(
    private tok: wasm.TokenizerWrap,
    private deberta: InferenceSession,
  ) {}

  /** Initialize the wasm-bindgen module from its raw .wasm bytes; call once before any other API. */
  public static async globalInit(buf: ArrayBufferLike) {
    await wasm.default(buf);
  }

  /** Build a holder from a tokenizer definition (JSON string) and the DeBERTa ONNX model bytes. */
  public static async create(tok: string, deberta: ArrayBufferLike) {
    return new ModelHolder(
      wasm.load_tokenizer(tok),
      await InferenceSession.create(deberta, {
        executionProviders: ["webnn", "webgpu", "wasm", "cpu"],
        graphOptimizationLevel: "all",
      }),
    );
  }

  /** Synthesize `text` with the voice model previously loaded under `name`. */
  public async synthesize(
    name: string,
    text: string,
    style_id: number = 0,
    style_weight: number = 1.0,
    sdp_ratio: number = 0.4,
    speed: number = 1.0,
  ) {
    const mod = this.models.get(name);
    if (!mod) throw new Error(`No model named ${name}`);
    const [vits2, style] = mod;
    return wasm.synthesize(
      text,
      this.tok,
      // Callback 1: run the DeBERTa session to produce BERT features for the
      // tokenized input. Returns the output tensor as [dims, data].
      async (a: BigInt64Array, b: BigInt64Array) => {
        try {
          const res = (
            await this.deberta.run({
              input_ids: new Tensor("int64", a, [1, a.length]),
              attention_mask: new Tensor("int64", b, [1, b.length]),
            })
          )["output"];
          return [new Uint32Array(res.dims), await res.getData(true)];
        } catch (e) {
          console.warn(e);
          throw e;
        }
      },
      // Callback 2: run the VITS2 session to generate audio from the BERT
      // features, phoneme ids, tones, language ids and style vector.
      async (
        [a_shape, a_array]: any, // BERT features as [shape, data]
        b_d: any, // phoneme ids (x_tst)
        c_d: any, // tones
        d_d: any, // language ids
        e_d: any, // style vector
        f: number, // sdp_ratio
        g: number, // length_scale
      ) => {
        try {
          const a = new Tensor("float32", a_array, [1, ...a_shape]);
          const b = new Tensor("int64", b_d, [1, b_d.length]);
          const c = new Tensor("int64", c_d, [1, c_d.length]);
          const d = new Tensor("int64", d_d, [1, d_d.length]);
          const e = new Tensor("float32", e_d, [1, e_d.length]);
          const res = (
            await vits2.run({
              x_tst: b,
              x_tst_lengths: new Tensor("int64", [b_d.length]),
              sid: new Tensor("int64", [0]),
              tones: c,
              language: d,
              bert: a,
              style_vec: e,
              sdp_ratio: new Tensor("float32", [f]),
              length_scale: new Tensor("float32", [g]),
              noise_scale: new Tensor("float32", [0.677]),
              noise_scale_w: new Tensor("float32", [0.8]),
            })
          ).output;
          return [new Uint32Array(res.dims), await res.getData(true)];
        } catch (e) {
          console.warn(e);
          throw e;
        }
      },
      sdp_ratio,
      1.0 / speed, // length_scale is the inverse of the requested speed
      style_id,
      style_weight,
      style,
    );
  }

  /** Load an .sbv2 file (style vectors plus VITS2 ONNX model) and register it under `name`. */
  public async load(name: string, b: Uint8Array) {
    const [style, vits2_b] = wasm.load_sbv2file(b);
    const vits2 = await InferenceSession.create(vits2_b as Uint8Array, {
      executionProviders: ["webnn", "webgpu", "wasm", "cpu"],
      graphOptimizationLevel: "all",
    });
    this.models.set(name, [vits2, style]);
  }

  /** Remove the model registered under `name`; resolves to false if it was not loaded. */
  public async unload(name: string) {
    return this.models.delete(name);
  }

  /** Names of all currently loaded models. */
  public modelList() {
    return this.models.keys();
  }
}
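
// Usage sketch (illustrative, not part of the original module): it assumes the
// caller has already fetched the wasm-bindgen binary, the tokenizer JSON, the
// DeBERTa ONNX model and an .sbv2 voice file; `wasmBytes`, `tokenizerJson`,
// `debertaBytes`, `sbv2Bytes` and the model name "voice" are placeholder names.
//
//   await ModelHolder.globalInit(wasmBytes);
//   const holder = await ModelHolder.create(tokenizerJson, debertaBytes);
//   await holder.load("voice", sbv2Bytes);
//   // Returns whatever wasm.synthesize produces (the synthesized audio data).
//   const audio = await holder.synthesize("voice", "こんにちは");
//   await holder.unload("voice");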