mirror of https://github.com/kyutai-labs/delayed-streams-modeling.git
synced 2026-01-07 17:52:54 +00:00

Commit: Rename examples and add pre-commit
@@ -41,18 +41,17 @@ uv run scripts/streaming_stt.py \
# Rev16 === cer: 6.57% wer: 10.08% corpus_wer: 11.43% RTF = 40.34
# Earnings21 === cer: 5.73% wer: 9.84% corpus_wer: 10.38% RTF = 73.15

import dataclasses
import julius
import jiwer
from datasets import load_dataset, Dataset
from whisper.normalizers import EnglishTextNormalizer
import argparse

import torch
import moshi.models
import tqdm
import dataclasses
import time

import jiwer
import julius
import moshi.models
import torch
import tqdm
from datasets import Dataset, load_dataset
from whisper.normalizers import EnglishTextNormalizer

_NORMALIZER = EnglishTextNormalizer()

@@ -120,9 +119,9 @@ class AsrMetrics:
self.num_sequences += 1

def compute(self) -> dict:
assert (
self.num_sequences > 0
), "Unable to compute with total number of comparisons <= 0" # type: ignore
assert self.num_sequences > 0, (
"Unable to compute with total number of comparisons <= 0"
) # type: ignore
return {
"cer": (self.cer_sum / self.num_sequences),
"wer": (self.wer_sum / self.num_sequences),
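
The hunk above only shows the averaging step of AsrMetrics.compute(). As a rough, self-contained sketch of how such per-sequence error rates can be accumulated with jiwer and the Whisper EnglishTextNormalizer (the update() method, its argument names, and the class name below are assumptions for illustration, not lines from this commit):

```python
import dataclasses

import jiwer
from whisper.normalizers import EnglishTextNormalizer

_NORMALIZER = EnglishTextNormalizer()


@dataclasses.dataclass
class AsrMetricsSketch:
    # Running sums of per-sequence error rates; compute() averages them,
    # mirroring the cer_sum / wer_sum / num_sequences fields in the diff.
    cer_sum: float = 0.0
    wer_sum: float = 0.0
    num_sequences: int = 0

    def update(self, reference: str, hypothesis: str) -> None:
        # Hypothetical update step: normalize both sides before scoring.
        ref = _NORMALIZER(reference)
        hyp = _NORMALIZER(hypothesis)
        self.cer_sum += jiwer.cer(ref, hyp)
        self.wer_sum += jiwer.wer(ref, hyp)
        self.num_sequences += 1

    def compute(self) -> dict:
        assert self.num_sequences > 0, (
            "Unable to compute with total number of comparisons <= 0"
        )
        return {
            "cer": self.cer_sum / self.num_sequences,
            "wer": self.wer_sum / self.num_sequences,
        }
```
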
@@ -19,15 +19,15 @@ uv run scripts/streaming_stt_timestamps.py \
```
"""

import itertools
import dataclasses
import julius
import sphn
import argparse
import dataclasses
import itertools
import math

import torch
import julius
import moshi.models
import sphn
import torch
import tqdm


@@ -10,17 +10,16 @@
import argparse
import asyncio
import json
import msgpack
import sphn
import struct
import time

import numpy as np
import msgpack
import sphn
import websockets

# Desired audio properties
TARGET_SAMPLE_RATE = 24000
TARGET_CHANNELS = 1 # Mono
HEADERS = {"kyutai-api-key": "open_token"}
all_text = []
transcript = []
finished = False
@@ -44,11 +43,13 @@ async def receive_messages(websocket):
print("received:", data)
if data["type"] == "Word":
all_text.append(data["text"])
transcript.append({
"speaker": "SPEAKER_00",
"text": data["text"],
"timestamp": [data["start_time"], data["start_time"]],
})
transcript.append(
{
"speaker": "SPEAKER_00",
"text": data["text"],
"timestamp": [data["start_time"], data["start_time"]],
}
)
if data["type"] == "EndWord":
if len(transcript) > 0:
transcript[-1]["timestamp"][1] = data["stop_time"]
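
For context, the receiving side assembles word-level timestamps from paired "Word"/"EndWord" messages: each "Word" opens an interval, and the matching "EndWord" closes it. A minimal sketch of that logic in isolation; the msgpack.unpackb call is an assumption, since this hunk only shows the already-decoded data dict:

```python
import msgpack

all_text: list[str] = []
transcript: list[dict] = []


def handle_message(message: bytes) -> None:
    # Decode one websocket frame (assumed msgpack-encoded, as on the sending side).
    data = msgpack.unpackb(message, raw=False)
    if data["type"] == "Word":
        # A new word opens an interval [start_time, start_time] ...
        all_text.append(data["text"])
        transcript.append(
            {
                "speaker": "SPEAKER_00",
                "text": data["text"],
                "timestamp": [data["start_time"], data["start_time"]],
            }
        )
    elif data["type"] == "EndWord" and transcript:
        # ... which the matching EndWord closes with its stop_time.
        transcript[-1]["timestamp"][1] = data["stop_time"]
```
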
@@ -64,15 +65,19 @@ async def send_messages(websocket, rtf: float):
global finished
audio_data = load_and_process_audio(args.in_file)
try:
# Start with a second of silence
chunk = { "type": "Audio", "pcm": [0.0] * 24000 }
# Start with a second of silence.
# This is needed for the 2.6B model for technical reasons.
chunk = {"type": "Audio", "pcm": [0.0] * 24000}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)

chunk_size = 1920 # Send data in chunks
start_time = time.time()
for i in range(0, len(audio_data), chunk_size):
chunk = { "type": "Audio", "pcm": [float(x) for x in audio_data[i : i + chunk_size]] }
chunk = {
"type": "Audio",
"pcm": [float(x) for x in audio_data[i : i + chunk_size]],
}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
expected_send_time = start_time + (i + 1) / 24000 / rtf
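
The sender paces 1920-sample chunks (80 ms at 24 kHz) so that audio leaves at rtf times real time: the chunk starting at sample i must not be sent before start_time + (i + 1) / 24000 / rtf. A standalone sketch of just that scheduling, with a hypothetical send_chunk coroutine standing in for the msgpack/websocket send:

```python
import asyncio
import time

SAMPLE_RATE = 24000
CHUNK_SIZE = 1920  # 80 ms of audio at 24 kHz


async def paced_send(audio_data, send_chunk, rtf: float = 1.0) -> None:
    """Send audio in chunks no faster than `rtf` times real time."""
    start_time = time.time()
    for i in range(0, len(audio_data), CHUNK_SIZE):
        # send_chunk is a placeholder for packing the "Audio" message and sending it.
        await send_chunk([float(x) for x in audio_data[i : i + CHUNK_SIZE]])
        # Samples up to index i should not be on the wire before this wall-clock time.
        expected_send_time = start_time + (i + 1) / SAMPLE_RATE / rtf
        delay = expected_send_time - time.time()
        # Sleep until the schedule catches up; otherwise yield briefly to the event loop.
        await asyncio.sleep(delay if delay > 0 else 0.001)
```
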
@@ -81,13 +86,15 @@ async def send_messages(websocket, rtf: float):
await asyncio.sleep(expected_send_time - current_time)
else:
await asyncio.sleep(0.001)
chunk = { "type": "Audio", "pcm": [0.0] * 1920 * 5 }
chunk = {"type": "Audio", "pcm": [0.0] * 1920 * 5}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
msg = msgpack.packb({"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True)
msg = msgpack.packb(
{"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True
)
await websocket.send(msg)
for _ in range(35):
chunk = { "type": "Audio", "pcm": [0.0] * 1920 }
chunk = {"type": "Audio", "pcm": [0.0] * 1920}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
while True:
@@ -100,11 +107,10 @@ async def send_messages(websocket, rtf: float):
print("Connection closed while sending messages.")


async def stream_audio(url: str, rtf: float, api_key: str):
async def stream_audio(url: str, rtf: float):
"""Stream audio data to a WebSocket server."""

headers = {"kyutai-api-key": api_key}
async with websockets.connect(url, additional_headers=headers) as websocket:
async with websockets.connect(url, additional_headers=HEADERS) as websocket:
send_task = asyncio.create_task(send_messages(websocket, rtf))
receive_task = asyncio.create_task(receive_messages(websocket))
await asyncio.gather(send_task, receive_task)
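
Taken together, the client simply runs the sender and receiver concurrently over one websocket, authenticating with the kyutai-api-key header. A compact sketch with stub coroutines in place of the real send_messages / receive_messages from the hunks above:

```python
import asyncio

import websockets

HEADERS = {"kyutai-api-key": "open_token"}


async def send_messages(websocket, rtf: float) -> None:
    # Stub: the real sender streams msgpack "Audio" chunks, paced by rtf.
    await asyncio.sleep(0)


async def receive_messages(websocket) -> None:
    # Stub: the real receiver collects "Word" / "EndWord" messages.
    await asyncio.sleep(0)


async def stream_audio(url: str, rtf: float) -> None:
    """Run the sender and receiver concurrently over a single connection."""
    async with websockets.connect(url, additional_headers=HEADERS) as websocket:
        send_task = asyncio.create_task(send_messages(websocket, rtf))
        receive_task = asyncio.create_task(receive_messages(websocket))
        await asyncio.gather(send_task, receive_task)


# Example (requires a running server):
# asyncio.run(stream_audio("ws://127.0.0.1:8080/api/asr-streaming", rtf=1.0))
```
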
@@ -115,7 +121,6 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("in_file")
parser.add_argument("--transcript")
parser.add_argument("--api-key", default="open_token")
parser.add_argument(
"--url",
help="The url of the server to which to send the audio",
@@ -125,7 +130,7 @@ if __name__ == "__main__":
args = parser.parse_args()

url = f"{args.url}/api/asr-streaming"
asyncio.run(stream_audio(url, args.rtf, args.api_key))
asyncio.run(stream_audio(url, args.rtf))
print(" ".join(all_text))
if args.transcript is not None:
with open(args.transcript, "w") as fobj:
@@ -11,19 +11,16 @@
# ///

import argparse
from dataclasses import dataclass
import json
import numpy as np
import queue
import sounddevice as sd

from huggingface_hub import hf_hub_download
import mlx.core as mx
import mlx.nn as nn
from moshi_mlx import models, utils
import rustymimi
import sentencepiece

import sounddevice as sd
from huggingface_hub import hf_hub_download
from moshi_mlx import models, utils

if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -69,6 +66,7 @@ if __name__ == "__main__":
)

block_queue = queue.Queue()

def audio_callback(indata, _frames, _time, _status):
block_queue.put(indata.copy())
@@ -84,7 +82,9 @@ if __name__ == "__main__":
block = block_queue.get()
block = block[None, :, 0]
other_audio_tokens = audio_tokenizer.encode_step(block[None, 0:1])
other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[:, :, :other_codebooks]
other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[
:, :, :other_codebooks
]
text_token = gen.step(other_audio_tokens[0])
text_token = text_token[0].item()
audio_tokens = gen.last_audio_tokens()
@@ -93,4 +93,3 @@ if __name__ == "__main__":
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
_text = _text.replace("▁", " ")
print(_text, end="", flush=True)
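
The hunk above maps each predicted token id back to text with sentencepiece, replacing the "▁" word-boundary marker with a space. A small illustration of that decoding step; the tokenizer path is a placeholder, since the script actually fetches its tokenizer via hf_hub_download:

```python
import sentencepiece

# Placeholder path for illustration only.
text_tokenizer = sentencepiece.SentencePieceProcessor(model_file="tokenizer.model")


def token_to_text(text_token: int) -> str:
    # sentencepiece marks word boundaries with "▁"; map it back to a plain space.
    piece = text_tokenizer.id_to_piece(text_token)
    return piece.replace("▁", " ")


# Example: stream pieces to stdout as tokens arrive.
# print(token_to_text(text_token), end="", flush=True)
```
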
@@ -9,9 +9,9 @@
# ///
import argparse
import asyncio
import msgpack
import signal

import msgpack
import numpy as np
import sounddevice as sd
import websockets
@@ -21,6 +21,7 @@ TARGET_SAMPLE_RATE = 24000
TARGET_CHANNELS = 1 # Mono
audio_queue = asyncio.Queue()


async def receive_messages(websocket):
"""Receive and process messages from the WebSocket server."""
try:
@@ -47,22 +48,26 @@ async def send_messages(websocket):
except websockets.ConnectionClosed:
print("Connection closed while sending messages.")


async def stream_audio(url: str, api_key: str):
"""Stream audio data to a WebSocket server."""
print("Starting microphone recording...")
print("Press Ctrl+C to stop recording")

loop = asyncio.get_event_loop()
def audio_callback(indata, frames, time, status):
loop.call_soon_threadsafe(audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy())

# Start audio stream
def audio_callback(indata, frames, time, status):
loop.call_soon_threadsafe(
audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy()
)

# Start audio stream
with sd.InputStream(
samplerate=TARGET_SAMPLE_RATE,
channels=TARGET_CHANNELS,
dtype='float32',
dtype="float32",
callback=audio_callback,
blocksize=1920 # 80ms blocks
blocksize=1920, # 80ms blocks
):
headers = {"kyutai-api-key": api_key}
async with websockets.connect(url, additional_headers=headers) as websocket:
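
This hunk bridges the sounddevice callback thread into asyncio with loop.call_soon_threadsafe: the PortAudio callback hands each mono float32 block to an asyncio.Queue that the sender drains. A condensed, self-contained sketch of just that bridge, with a stub consumer that prints block shapes instead of sending them over the websocket:

```python
import asyncio

import numpy as np
import sounddevice as sd

TARGET_SAMPLE_RATE = 24000
TARGET_CHANNELS = 1  # Mono
audio_queue: asyncio.Queue = asyncio.Queue()


async def consume_blocks() -> None:
    # Stand-in consumer; the real script packs each block into an "Audio" message.
    while True:
        block = await audio_queue.get()
        print("got block", block.shape)


async def main() -> None:
    # get_running_loop() is the idiomatic equivalent of the get_event_loop()
    # call in the diff when already inside a coroutine.
    loop = asyncio.get_running_loop()

    def audio_callback(indata, frames, time, status):
        # Runs in PortAudio's thread: hand the mono float32 block to the event loop.
        loop.call_soon_threadsafe(
            audio_queue.put_nowait, indata[:, 0].astype(np.float32).copy()
        )

    with sd.InputStream(
        samplerate=TARGET_SAMPLE_RATE,
        channels=TARGET_CHANNELS,
        dtype="float32",
        callback=audio_callback,
        blocksize=1920,  # 80ms blocks
    ):
        await consume_blocks()


# asyncio.run(main())
```
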
@@ -79,11 +84,15 @@ if __name__ == "__main__":
default="ws://127.0.0.1:8080",
)
parser.add_argument("--api-key", default="open_token")
parser.add_argument("--list-devices", action="store_true", help="List available audio devices")
parser.add_argument("--device", type=int, help="Input device ID (use --list-devices to see options)")

parser.add_argument(
"--list-devices", action="store_true", help="List available audio devices"
)
parser.add_argument(
"--device", type=int, help="Input device ID (use --list-devices to see options)"
)

args = parser.parse_args()


def handle_sigint(signum, frame):
print("Interrupted by user")
exit(0)
@@ -94,9 +103,9 @@ if __name__ == "__main__":
print("Available audio devices:")
print(sd.query_devices())
exit(0)


if args.device is not None:
sd.default.device[0] = args.device # Set input device


url = f"{args.url}/api/asr-streaming"
asyncio.run(stream_audio(url, args.api_key))