Add PyTorch notebook and documentation (#29)

* Add example for PyTorch implementation

* Document PyTorch and MLX examples

* Reorganize for TTS

* Remove waitlist signup CTA
This commit is contained in:
Václav Volhejn
2025-07-02 17:51:27 +02:00
committed by GitHub
parent 96ff217437
commit 07ac744609
3 changed files with 228 additions and 25 deletions

View File

@@ -17,16 +17,6 @@ from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
def audio_to_int16(audio: np.ndarray) -> np.ndarray:
if audio.dtype == np.int16:
return audio
elif audio.dtype == np.float32:
# Multiply by 32767 and not 32768 so that int16 doesn't overflow.
return (np.clip(audio, -1, 1) * 32767).astype(np.int16)
else:
raise TypeError(f"Unsupported audio data type: {audio.dtype}")
def play_audio(audio: np.ndarray, sample_rate: int):
# Requires the Portaudio library which might not be available in all environments.
import sounddevice as sd
@@ -86,7 +76,8 @@ def main():
)
print("Generating audio...")
# This doesn't do streaming generation,
# This doesn't do streaming generation, but the model allows it. For now, see Rust
# example.
result = tts_model.generate([entries], [condition_attributes])
frames = torch.cat(result.frames, dim=-1)