Add Rust server usage example (#32)

* Run Ruff on tts_mlx.py

* Add tts_rust_server.py example

* Remove unused HF repo arguments and reset audio output data in TTS server script
Author: Václav Volhejn
Date:   2025-07-03 09:47:50 +02:00
Committed by: GitHub
Parent: d92e4c2695
Commit: ef52b8ef0f
4 changed files with 233 additions and 22 deletions
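The new tts_rust_server.py example added by this commit is not part of the hunks shown below (only the tts_mlx.py cleanup is). As a rough orientation only, a streaming-TTS client talking to a websocket server could be structured like the following sketch; the endpoint URL, the JSON message shapes, and the assumption that audio comes back as binary frames are illustrative guesses, not the actual protocol used by the Rust server.

# Illustrative only: a minimal websocket TTS client. URL and message schema are
# hypothetical and do NOT describe the real tts_rust_server.py protocol.
import asyncio
import json

import websockets  # pip install websockets


async def synthesize(text: str, url: str = "ws://localhost:8080/api/tts_streaming") -> bytes:
    async with websockets.connect(url) as ws:
        await ws.send(json.dumps({"type": "text", "text": text}))
        await ws.send(json.dumps({"type": "eos"}))  # hypothetical end-of-input marker
        chunks = []
        async for message in ws:
            if isinstance(message, bytes):
                chunks.append(message)  # assumed: binary frames carry PCM audio
        return b"".join(chunks)


if __name__ == "__main__":
    pcm = asyncio.run(synthesize("Hello from the Rust server example."))
    print(f"received {len(pcm)} bytes of audio")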

tts_mlx.py
@@ -14,19 +14,20 @@ import queue
 import sys
 import time
-import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
+import numpy as np
 import sentencepiece
-import sphn
-import time
 import sounddevice as sd
-from moshi_mlx.client_utils import make_log
+import sphn
 from moshi_mlx import models
+from moshi_mlx.client_utils import make_log
+from moshi_mlx.models.tts import (
+    DEFAULT_DSM_TTS_REPO,
+    DEFAULT_DSM_TTS_VOICE_REPO,
+    TTSModel,
+)
 from moshi_mlx.utils.loaders import hf_get
-from moshi_mlx.models.tts import TTSModel, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO
 def log(level: str, msg: str):
@@ -34,15 +35,32 @@ def log(level: str, msg: str):
 def main():
-    parser = argparse.ArgumentParser(prog='moshi-tts', description='Run Moshi')
+    parser = argparse.ArgumentParser(
+        description="Run Kyutai TTS using the PyTorch implementation"
+    )
     parser.add_argument("inp", type=str, help="Input file, use - for stdin")
-    parser.add_argument("out", type=str, help="Output file to generate, use - for playing the audio")
-    parser.add_argument("--hf-repo", type=str, default=DEFAULT_DSM_TTS_REPO,
-                        help="HF repo in which to look for the pretrained models.")
-    parser.add_argument("--voice-repo", default=DEFAULT_DSM_TTS_VOICE_REPO,
-                        help="HF repo in which to look for pre-computed voice embeddings.")
-    parser.add_argument("--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav")
-    parser.add_argument("--quantize", type=int, help="The quantization to be applied, e.g. 8 for 8 bits.")
+    parser.add_argument(
+        "out", type=str, help="Output file to generate, use - for playing the audio"
+    )
+    parser.add_argument(
+        "--hf-repo",
+        type=str,
+        default=DEFAULT_DSM_TTS_REPO,
+        help="HF repo in which to look for the pretrained models.",
+    )
+    parser.add_argument(
+        "--voice-repo",
+        default=DEFAULT_DSM_TTS_VOICE_REPO,
+        help="HF repo in which to look for pre-computed voice embeddings.",
+    )
+    parser.add_argument(
+        "--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
+    )
+    parser.add_argument(
+        "--quantize",
+        type=int,
+        help="The quantization to be applied, e.g. 8 for 8 bits.",
+    )
     args = parser.parse_args()
     mx.random.seed(299792458)
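For reference, the command-line surface defined in this hunk can be exercised on its own; the snippet below rebuilds just the parser to show how the flags resolve. The repo-ID defaults are placeholders here, not the real values of DEFAULT_DSM_TTS_REPO and DEFAULT_DSM_TTS_VOICE_REPO, and the example invocation values are arbitrary.

import argparse

# Minimal reconstruction of the CLI defined in the diff above.
parser = argparse.ArgumentParser(description="Run Kyutai TTS")
parser.add_argument("inp", type=str, help="Input file, use - for stdin")
parser.add_argument("out", type=str, help="Output file, use - to play the audio")
parser.add_argument("--hf-repo", type=str, default="<DEFAULT_DSM_TTS_REPO>")
parser.add_argument("--voice-repo", default="<DEFAULT_DSM_TTS_VOICE_REPO>")
parser.add_argument("--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav")
parser.add_argument("--quantize", type=int, help="e.g. 8 for 8-bit quantization")

# Equivalent to: python tts_mlx.py some_text.txt - --quantize 8
args = parser.parse_args(["some_text.txt", "-", "--quantize", "8"])
assert args.out == "-" and args.quantize == 8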
@@ -96,7 +114,7 @@ def main():
     if tts_model.valid_cfg_conditionings:
         # Model was trained with CFG distillation.
         cfg_coef_conditioning = tts_model.cfg_coef
-        tts_model.cfg_coef = 1.
+        tts_model.cfg_coef = 1.0
         cfg_is_no_text = False
         cfg_is_no_prefix = False
     else:
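The cfg_coef reset above is easiest to read with the usual classifier-free guidance blend in mind; the sketch below shows why a coefficient of 1.0 collapses the blend to the conditional prediction alone, which is what a CFG-distilled model expects at inference time. Whether moshi_mlx applies exactly this formulation internally is an assumption.

import numpy as np

# Common classifier-free guidance blend (assumed formulation, for intuition only):
#     guided = uncond + cfg_coef * (cond - uncond)
def apply_cfg(cond: np.ndarray, uncond: np.ndarray, cfg_coef: float) -> np.ndarray:
    return uncond + cfg_coef * (cond - uncond)

cond = np.array([1.0, 2.0])
uncond = np.array([0.5, 1.0])
# With cfg_coef == 1.0 the blend reduces to the conditional output, so a model
# distilled with CFG needs no extra guidance applied at inference.
assert np.allclose(apply_cfg(cond, uncond, 1.0), cond)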
@@ -118,9 +136,12 @@ def main():
         voices = [tts_model.get_voice_path(args.voice)]
     else:
         voices = []
-    all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)]
+    all_attributes = [
+        tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
+    ]
     wav_frames = queue.Queue()
     def _on_frame(frame):
         if (frame == -1).any():
             return
@@ -146,16 +167,20 @@ def main():
         return result
     if args.out == "-":
         def audio_callback(outdata, _a, _b, _c):
             try:
                 pcm_data = wav_frames.get(block=False)
                 outdata[:, 0] = pcm_data
             except queue.Empty:
                 outdata[:] = 0
-        with sd.OutputStream(samplerate=mimi.sample_rate,
-                             blocksize=1920,
-                             channels=1,
-                             callback=audio_callback):
+        with sd.OutputStream(
+            samplerate=mimi.sample_rate,
+            blocksize=1920,
+            channels=1,
+            callback=audio_callback,
+        ):
             run()
             time.sleep(3)
             while True:
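The OutputStream change in this last hunk is pure reformatting, but the surrounding pattern (a queue fed by the TTS frame callback and drained by the sounddevice audio callback) is easy to try in isolation. The sketch below reproduces it with a synthetic tone: the 1920-sample block size matches the diff, while the sample rate and signal are arbitrary choices for the example.

import queue
import time

import numpy as np
import sounddevice as sd

# Stand-alone sketch of the queue-driven playback used in tts_mlx.py: a producer
# pushes fixed-size PCM frames, and the audio callback pulls one frame per block,
# falling back to silence when the queue is empty.
sample_rate = 24000
block = 1920
wav_frames: queue.Queue = queue.Queue()

# Producer: one second of a 440 Hz tone, split into blocks (the real script
# pushes frames decoded from the TTS model instead).
t = np.arange(sample_rate) / sample_rate
tone = (0.2 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)
for i in range(0, len(tone), block):
    wav_frames.put(tone[i : i + block])


def audio_callback(outdata, _frames, _time, _status):
    try:
        pcm_data = wav_frames.get(block=False)
        outdata[: len(pcm_data), 0] = pcm_data
        outdata[len(pcm_data) :] = 0
    except queue.Empty:
        outdata[:] = 0  # no frame ready yet: play silence


with sd.OutputStream(
    samplerate=sample_rate,
    blocksize=block,
    channels=1,
    callback=audio_callback,
):
    time.sleep(1.5)  # keep the stream open long enough to drain the queue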