mirror of
https://github.com/kyutai-labs/delayed-streams-modeling.git
synced 2026-01-07 17:52:54 +00:00
Use bfloat16 rather than half by default.
This commit is contained in:
@@ -49,7 +49,7 @@ def main():
|
||||
print("Loading model...")
|
||||
checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
|
||||
tts_model = TTSModel.from_checkpoint_info(
|
||||
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
|
||||
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda")
|
||||
)
|
||||
|
||||
if args.inp == "-":
|
||||
|
||||
Reference in New Issue
Block a user