mirror of
https://github.com/neodyland/sbv2-api.git
synced 2026-01-07 23:12:57 +00:00
breaking: support of length_scale, sdp_ratio, /models endpoint
This commit is contained in:
@@ -90,8 +90,18 @@ model = get_net_g(
|
||||
)
|
||||
|
||||
|
||||
def forward(*args):
|
||||
return model.infer(*args)
|
||||
def forward(x, x_len, sid, tone, lang, bert, style, length_scale, sdp_ratio):
|
||||
return model.infer(
|
||||
x,
|
||||
x_len,
|
||||
sid,
|
||||
tone,
|
||||
lang,
|
||||
bert,
|
||||
style,
|
||||
sdp_ratio=sdp_ratio,
|
||||
length_scale=length_scale,
|
||||
)
|
||||
|
||||
|
||||
model.forward = forward
|
||||
@@ -106,6 +116,8 @@ torch.onnx.export(
|
||||
lang_ids,
|
||||
bert,
|
||||
style_vec_tensor,
|
||||
torch.tensor(1.0),
|
||||
torch.tensor(0.0),
|
||||
),
|
||||
f"../models/model_{out_name}.onnx",
|
||||
verbose=True,
|
||||
@@ -124,6 +136,8 @@ torch.onnx.export(
|
||||
"language",
|
||||
"bert",
|
||||
"style_vec",
|
||||
"length_scale",
|
||||
"sdp_ratio",
|
||||
],
|
||||
output_names=["output"],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user