breaking: support of length_scale, sdp_ratio, /models endpoint

This commit is contained in:
Googlefan
2024-09-11 04:42:11 +00:00
parent 83b69083ca
commit 441e35b9a6
12 changed files with 243 additions and 49 deletions

View File

@@ -90,8 +90,18 @@ model = get_net_g(
)
def forward(*args):
return model.infer(*args)
def forward(x, x_len, sid, tone, lang, bert, style, length_scale, sdp_ratio):
return model.infer(
x,
x_len,
sid,
tone,
lang,
bert,
style,
sdp_ratio=sdp_ratio,
length_scale=length_scale,
)
model.forward = forward
@@ -106,6 +116,8 @@ torch.onnx.export(
lang_ids,
bert,
style_vec_tensor,
torch.tensor(1.0),
torch.tensor(0.0),
),
f"../models/model_{out_name}.onnx",
verbose=True,
@@ -124,6 +136,8 @@ torch.onnx.export(
"language",
"bert",
"style_vec",
"length_scale",
"sdp_ratio",
],
output_names=["output"],
)