Use bfloat16 rather than half by default.

This commit is contained in:
Laurent
2025-07-05 23:02:58 +02:00
parent f9739881e6
commit bfc200f6ee
2 changed files with 2 additions and 2 deletions

View File

@@ -49,7 +49,7 @@ def main():
print("Loading model...")
checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
tts_model = TTSModel.from_checkpoint_info(
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda")
)
if args.inp == "-":