Rename examples and add pre-commit

This commit is contained in:
Václav Volhejn
2025-06-25 10:50:14 +02:00
parent 8bd3f59631
commit 7b818c2636
8 changed files with 347 additions and 58 deletions

View File

@@ -41,18 +41,17 @@ uv run scripts/streaming_stt.py \
# Rev16 === cer: 6.57% wer: 10.08% corpus_wer: 11.43% RTF = 40.34
# Earnings21 === cer: 5.73% wer: 9.84% corpus_wer: 10.38% RTF = 73.15
import dataclasses
import julius
import jiwer
from datasets import load_dataset, Dataset
from whisper.normalizers import EnglishTextNormalizer
import argparse
import torch
import moshi.models
import tqdm
import dataclasses
import time
import jiwer
import julius
import moshi.models
import torch
import tqdm
from datasets import Dataset, load_dataset
from whisper.normalizers import EnglishTextNormalizer
_NORMALIZER = EnglishTextNormalizer()
@@ -120,9 +119,9 @@ class AsrMetrics:
self.num_sequences += 1
def compute(self) -> dict:
assert (
self.num_sequences > 0
), "Unable to compute with total number of comparisons <= 0" # type: ignore
assert self.num_sequences > 0, (
"Unable to compute with total number of comparisons <= 0"
) # type: ignore
return {
"cer": (self.cer_sum / self.num_sequences),
"wer": (self.wer_sum / self.num_sequences),

View File

@@ -19,15 +19,15 @@ uv run scripts/streaming_stt_timestamps.py \
```
"""
import itertools
import dataclasses
import julius
import sphn
import argparse
import dataclasses
import itertools
import math
import torch
import julius
import moshi.models
import sphn
import torch
import tqdm

View File

@@ -10,17 +10,16 @@
import argparse
import asyncio
import json
import msgpack
import sphn
import struct
import time
import numpy as np
import msgpack
import sphn
import websockets
# Desired audio properties
TARGET_SAMPLE_RATE = 24000
TARGET_CHANNELS = 1 # Mono
HEADERS = {"kyutai-api-key": "open_token"}
all_text = []
transcript = []
finished = False
@@ -44,11 +43,13 @@ async def receive_messages(websocket):
print("received:", data)
if data["type"] == "Word":
all_text.append(data["text"])
transcript.append({
"speaker": "SPEAKER_00",
"text": data["text"],
"timestamp": [data["start_time"], data["start_time"]],
})
transcript.append(
{
"speaker": "SPEAKER_00",
"text": data["text"],
"timestamp": [data["start_time"], data["start_time"]],
}
)
if data["type"] == "EndWord":
if len(transcript) > 0:
transcript[-1]["timestamp"][1] = data["stop_time"]
@@ -64,15 +65,19 @@ async def send_messages(websocket, rtf: float):
global finished
audio_data = load_and_process_audio(args.in_file)
try:
# Start with a second of silence
chunk = { "type": "Audio", "pcm": [0.0] * 24000 }
# Start with a second of silence.
# This is needed for the 2.6B model for technical reasons.
chunk = {"type": "Audio", "pcm": [0.0] * 24000}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
chunk_size = 1920 # Send data in chunks
start_time = time.time()
for i in range(0, len(audio_data), chunk_size):
chunk = { "type": "Audio", "pcm": [float(x) for x in audio_data[i : i + chunk_size]] }
chunk = {
"type": "Audio",
"pcm": [float(x) for x in audio_data[i : i + chunk_size]],
}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
expected_send_time = start_time + (i + 1) / 24000 / rtf
@@ -81,13 +86,15 @@ async def send_messages(websocket, rtf: float):
await asyncio.sleep(expected_send_time - current_time)
else:
await asyncio.sleep(0.001)
chunk = { "type": "Audio", "pcm": [0.0] * 1920 * 5 }
chunk = {"type": "Audio", "pcm": [0.0] * 1920 * 5}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
msg = msgpack.packb({"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True)
msg = msgpack.packb(
{"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True
)
await websocket.send(msg)
for _ in range(35):
chunk = { "type": "Audio", "pcm": [0.0] * 1920 }
chunk = {"type": "Audio", "pcm": [0.0] * 1920}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
while True:
@@ -100,11 +107,10 @@ async def send_messages(websocket, rtf: float):
print("Connection closed while sending messages.")
async def stream_audio(url: str, rtf: float, api_key: str):
async def stream_audio(url: str, rtf: float):
"""Stream audio data to a WebSocket server."""
headers = {"kyutai-api-key": api_key}
async with websockets.connect(url, additional_headers=headers) as websocket:
async with websockets.connect(url, additional_headers=HEADERS) as websocket:
send_task = asyncio.create_task(send_messages(websocket, rtf))
receive_task = asyncio.create_task(receive_messages(websocket))
await asyncio.gather(send_task, receive_task)
@@ -115,7 +121,6 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("in_file")
parser.add_argument("--transcript")
parser.add_argument("--api-key", default="open_token")
parser.add_argument(
"--url",
help="The url of the server to which to send the audio",
@@ -125,7 +130,7 @@ if __name__ == "__main__":
args = parser.parse_args()
url = f"{args.url}/api/asr-streaming"
asyncio.run(stream_audio(url, args.rtf, args.api_key))
asyncio.run(stream_audio(url, args.rtf))
print(" ".join(all_text))
if args.transcript is not None:
with open(args.transcript, "w") as fobj:

View File

@@ -11,19 +11,16 @@
# ///
import argparse
from dataclasses import dataclass
import json
import numpy as np
import queue
import sounddevice as sd
from huggingface_hub import hf_hub_download
import mlx.core as mx
import mlx.nn as nn
from moshi_mlx import models, utils
import rustymimi
import sentencepiece
import sounddevice as sd
from huggingface_hub import hf_hub_download
from moshi_mlx import models, utils
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -69,6 +66,7 @@ if __name__ == "__main__":
)
block_queue = queue.Queue()
def audio_callback(indata, _frames, _time, _status):
block_queue.put(indata.copy())
@@ -84,7 +82,9 @@ if __name__ == "__main__":
block = block_queue.get()
block = block[None, :, 0]
other_audio_tokens = audio_tokenizer.encode_step(block[None, 0:1])
other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[:, :, :other_codebooks]
other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[
:, :, :other_codebooks
]
text_token = gen.step(other_audio_tokens[0])
text_token = text_token[0].item()
audio_tokens = gen.last_audio_tokens()
@@ -93,4 +93,3 @@ if __name__ == "__main__":
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
_text = _text.replace("▁", " ")
print(_text, end="", flush=True)

View File

@@ -9,9 +9,9 @@
# ///
import argparse
import asyncio
import msgpack
import signal
import msgpack
import numpy as np
import sounddevice as sd
import websockets
@@ -21,6 +21,7 @@ TARGET_SAMPLE_RATE = 24000
TARGET_CHANNELS = 1 # Mono
audio_queue = asyncio.Queue()
async def receive_messages(websocket):
"""Receive and process messages from the WebSocket server."""
try:
@@ -47,22 +48,26 @@ async def send_messages(websocket):
except websockets.ConnectionClosed:
print("Connection closed while sending messages.")
async def stream_audio(url: str, api_key: str):
"""Stream audio data to a WebSocket server."""
print("Starting microphone recording...")
print("Press Ctrl+C to stop recording")
loop = asyncio.get_event_loop()
def audio_callback(indata, frames, time, status):
loop.call_soon_threadsafe(audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy())
# Start audio stream
def audio_callback(indata, frames, time, status):
loop.call_soon_threadsafe(
audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy()
)
# Start audio stream
with sd.InputStream(
samplerate=TARGET_SAMPLE_RATE,
channels=TARGET_CHANNELS,
dtype='float32',
dtype="float32",
callback=audio_callback,
blocksize=1920 # 80ms blocks
blocksize=1920, # 80ms blocks
):
headers = {"kyutai-api-key": api_key}
async with websockets.connect(url, additional_headers=headers) as websocket:
@@ -79,11 +84,15 @@ if __name__ == "__main__":
default="ws://127.0.0.1:8080",
)
parser.add_argument("--api-key", default="open_token")
parser.add_argument("--list-devices", action="store_true", help="List available audio devices")
parser.add_argument("--device", type=int, help="Input device ID (use --list-devices to see options)")
parser.add_argument(
"--list-devices", action="store_true", help="List available audio devices"
)
parser.add_argument(
"--device", type=int, help="Input device ID (use --list-devices to see options)"
)
args = parser.parse_args()
def handle_sigint(signum, frame):
print("Interrupted by user")
exit(0)
@@ -94,9 +103,9 @@ if __name__ == "__main__":
print("Available audio devices:")
print(sd.query_devices())
exit(0)
if args.device is not None:
sd.default.device[0] = args.device # Set input device
url = f"{args.url}/api/asr-streaming"
asyncio.run(stream_audio(url, args.api_key))