Add Rust server usage example (#32)

* Run Ruff on tts_mlx.py

* Add tts_rust_server.py example

* Remove unused HF repo arguments and reset audio output data in TTS server script
This commit is contained in:
Václav Volhejn
2025-07-03 09:47:50 +02:00
committed by GitHub
parent d92e4c2695
commit ef52b8ef0f
4 changed files with 233 additions and 22 deletions

View File

@@ -14,19 +14,20 @@ import queue
import sys
import time
import numpy as np
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import sentencepiece
import sphn
import time
import sounddevice as sd
from moshi_mlx.client_utils import make_log
import sphn
from moshi_mlx import models
from moshi_mlx.client_utils import make_log
from moshi_mlx.models.tts import (
DEFAULT_DSM_TTS_REPO,
DEFAULT_DSM_TTS_VOICE_REPO,
TTSModel,
)
from moshi_mlx.utils.loaders import hf_get
from moshi_mlx.models.tts import TTSModel, DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO
def log(level: str, msg: str):
@@ -34,15 +35,32 @@ def log(level: str, msg: str):
def main():
parser = argparse.ArgumentParser(prog='moshi-tts', description='Run Moshi')
parser = argparse.ArgumentParser(
description="Run Kyutai TTS using the PyTorch implementation"
)
parser.add_argument("inp", type=str, help="Input file, use - for stdin")
parser.add_argument("out", type=str, help="Output file to generate, use - for playing the audio")
parser.add_argument("--hf-repo", type=str, default=DEFAULT_DSM_TTS_REPO,
help="HF repo in which to look for the pretrained models.")
parser.add_argument("--voice-repo", default=DEFAULT_DSM_TTS_VOICE_REPO,
help="HF repo in which to look for pre-computed voice embeddings.")
parser.add_argument("--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav")
parser.add_argument("--quantize", type=int, help="The quantization to be applied, e.g. 8 for 8 bits.")
parser.add_argument(
"out", type=str, help="Output file to generate, use - for playing the audio"
)
parser.add_argument(
"--hf-repo",
type=str,
default=DEFAULT_DSM_TTS_REPO,
help="HF repo in which to look for the pretrained models.",
)
parser.add_argument(
"--voice-repo",
default=DEFAULT_DSM_TTS_VOICE_REPO,
help="HF repo in which to look for pre-computed voice embeddings.",
)
parser.add_argument(
"--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
)
parser.add_argument(
"--quantize",
type=int,
help="The quantization to be applied, e.g. 8 for 8 bits.",
)
args = parser.parse_args()
mx.random.seed(299792458)
@@ -96,7 +114,7 @@ def main():
if tts_model.valid_cfg_conditionings:
# Model was trained with CFG distillation.
cfg_coef_conditioning = tts_model.cfg_coef
tts_model.cfg_coef = 1.
tts_model.cfg_coef = 1.0
cfg_is_no_text = False
cfg_is_no_prefix = False
else:
@@ -118,9 +136,12 @@ def main():
voices = [tts_model.get_voice_path(args.voice)]
else:
voices = []
all_attributes = [tts_model.make_condition_attributes(voices, cfg_coef_conditioning)]
all_attributes = [
tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
]
wav_frames = queue.Queue()
def _on_frame(frame):
if (frame == -1).any():
return
@@ -146,16 +167,20 @@ def main():
return result
if args.out == "-":
def audio_callback(outdata, _a, _b, _c):
try:
pcm_data = wav_frames.get(block=False)
outdata[:, 0] = pcm_data
except queue.Empty:
outdata[:] = 0
with sd.OutputStream(samplerate=mimi.sample_rate,
blocksize=1920,
channels=1,
callback=audio_callback):
with sd.OutputStream(
samplerate=mimi.sample_rate,
blocksize=1920,
channels=1,
callback=audio_callback,
):
run()
time.sleep(3)
while True:

138
scripts/tts_rust_server.py Normal file
View File

@@ -0,0 +1,138 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "msgpack",
# "numpy",
# "sphn",
# "websockets",
# "sounddevice",
# "tqdm",
# ]
# ///
import argparse
import asyncio
import sys
from urllib.parse import urlencode
import msgpack
import numpy as np
import sounddevice as sd
import sphn
import tqdm
import websockets
SAMPLE_RATE = 24000
TTS_TEXT = "Hello, this is a test of the moshi text to speech system, this should result in some nicely sounding generated voice."
DEFAULT_DSM_TTS_VOICE_REPO = "kyutai/tts-voices"
AUTH_TOKEN = "public_token"
async def receive_messages(websocket: websockets.ClientConnection, output_queue):
with tqdm.tqdm(desc="Receiving audio", unit=" seconds generated") as pbar:
accumulated_samples = 0
last_seconds = 0
async for message_bytes in websocket:
msg = msgpack.unpackb(message_bytes)
if msg["type"] == "Audio":
pcm = np.array(msg["pcm"]).astype(np.float32)
await output_queue.put(pcm)
accumulated_samples += len(msg["pcm"])
current_seconds = accumulated_samples // SAMPLE_RATE
if current_seconds > last_seconds:
pbar.update(current_seconds - last_seconds)
last_seconds = current_seconds
print("End of audio.")
await output_queue.put(None) # Signal end of audio
async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None]):
if out == "-":
should_exit = False
def audio_callback(outdata, _a, _b, _c):
nonlocal should_exit
try:
pcm_data = output_queue.get_nowait()
if pcm_data is not None:
outdata[:, 0] = pcm_data
else:
should_exit = True
outdata[:] = 0
except asyncio.QueueEmpty:
outdata[:] = 0
with sd.OutputStream(
samplerate=SAMPLE_RATE,
blocksize=1920,
channels=1,
callback=audio_callback,
):
while True:
if should_exit:
break
await asyncio.sleep(1)
else:
frames = []
while True:
item = await output_queue.get()
if item is None:
break
frames.append(item)
sphn.write_wav(out, np.concat(frames, -1), SAMPLE_RATE)
print(f"Saved audio to {out}")
async def websocket_client():
parser = argparse.ArgumentParser(description="Use the TTS streaming API")
parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
parser.add_argument(
"out", type=str, help="Output file to generate, use - for playing the audio"
)
parser.add_argument(
"--voice",
default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
help="The voice to use, relative to the voice repo root. "
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
)
parser.add_argument(
"--url",
help="The URL of the server to which to send the audio",
default="ws://127.0.0.1:8080",
)
parser.add_argument("--api-key", default="public_token")
args = parser.parse_args()
params = {"voice": args.voice, "format": "PcmMessagePack"}
uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
print(uri)
# TODO: stream the text instead of sending it all at once
if args.inp == "-":
if sys.stdin.isatty(): # Interactive
print("Enter text to synthesize (Ctrl+D to end input):")
text_to_tts = sys.stdin.read().strip()
else:
with open(args.inp, "r") as fobj:
text_to_tts = fobj.read().strip()
headers = {"kyutai-api-key": args.api_key}
async with websockets.connect(uri, additional_headers=headers) as websocket:
await websocket.send(msgpack.packb({"type": "Text", "text": text_to_tts}))
await websocket.send(msgpack.packb({"type": "Eos"}))
output_queue = asyncio.Queue()
receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
await asyncio.gather(receive_task, output_audio_task)
if __name__ == "__main__":
asyncio.run(websocket_client())