Mirror of https://github.com/kyutai-labs/delayed-streams-modeling.git, synced 2025-12-22 19:09:57 +00:00

Enable ruff in the pre-commit hooks (#124)

* Enable ruff in the pre-commit hooks.
* Disable the old hooks.
* Install uv in the CI.

.github/actions/moshi_build/action.yml (vendored, 4 changes)

@@ -26,3 +26,7 @@ runs:
      run: |
        source env/bin/activate
        pre-commit install
    - name: Install uv
      uses: astral-sh/setup-uv@v6
      with:
        version: "0.8.13"

.pre-commit-config.yaml

@@ -1,4 +1,18 @@
repos:
  - repo: local
    hooks:
      - id: ruff
        name: ruff
        language: system
        entry: bash -c 'uvx ruff check'
        pass_filenames: false
        always_run: true
      - id: ruff-format
        name: ruff format
        language: system
        entry: bash -c 'uvx ruff format --check'
        pass_filenames: false
        always_run: true
  # Get rid of Jupyter Notebook output because we don't want to keep it in Git
  - repo: https://github.com/kynan/nbstripout
    rev: 0.8.1
@@ -9,14 +23,3 @@ repos:
    hooks:
      - id: check-added-large-files
        args: ["--maxkb=2048"]
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.11.7
    hooks:
      # Run the linter.
      - id: ruff
        types_or: [python, pyi] # Don't run on `jupyter` files
        args: [--fix]
      # Run the formatter.
      - id: ruff-format
        types_or: [python, pyi] # Don't run on `jupyter` files

@@ -117,16 +117,22 @@ def main():
                break
            time.sleep(1)
    else:
        last_time = time.time()

        def _on_frame(frame):
            nonlocal _frames_cnt
            nonlocal last_time
            if (frame != -1).all():
                _frames_cnt += 1
                print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)
                print("{}", time.time() - last_time)
                last_time = time.time()

        start_time = time.time()
        result = tts_model.generate(
            [entries], [condition_attributes], on_frame=_on_frame
        )
        print(f"\nTotal time: {time.time() - start_time:.2f}s")
        with tts_model.mimi.streaming(1), torch.no_grad():
            pcms = []
            for frame in result.frames[tts_model.delay_steps :]:

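The hunk above wires an `_on_frame` callback into `tts_model.generate` to report progress (Mimi produces frames at 12.5 Hz, hence the division by 12.5) and then decodes the delayed frames back to audio. Below is a minimal sketch of that pattern for reference; it assumes a `tts_model`, `entries`, and `condition_attributes` prepared as elsewhere in this diff, and the helper name `timed_generate` is illustrative, not something defined in the repository.

import time

import numpy as np
import torch


def timed_generate(tts_model, entries, condition_attributes):
    # Illustrative helper (not from the repo): same callback + decode pattern
    # as the script changed in the hunk above.
    frames_cnt = 0

    def _on_frame(frame):
        nonlocal frames_cnt
        # Frames still containing -1 are not fully generated yet (delayed
        # codebooks pending), so skip them when counting progress.
        if (frame != -1).all():
            frames_cnt += 1
            # Mimi frames arrive at 12.5 per second of audio.
            print(f"generated {frames_cnt / 12.5:.2f}s", end="\r", flush=True)

    start = time.time()
    result = tts_model.generate([entries], [condition_attributes], on_frame=_on_frame)
    print(f"\ntotal time: {time.time() - start:.2f}s")

    # Decode everything after the model's delay steps with Mimi in streaming mode.
    pcms = []
    with tts_model.mimi.streaming(1), torch.no_grad():
        for frame in result.frames[tts_model.delay_steps :]:
            pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
            pcms.append(np.clip(pcm[0, 0], -1, 1))
    return np.concatenate(pcms, axis=-1)

The per-frame latency print from the script is omitted here for brevity.
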
@@ -80,7 +80,6 @@
"        self.lm_gen.streaming_forever(batch_size)\n",
"\n",
"    def run(self, in_pcms: torch.Tensor):\n",
"        device = self.lm_gen.lm_model.device\n",
"        ntokens = 0\n",
"        first_frame = True\n",
"        chunks = [\n",

@@ -21,9 +21,6 @@
"metadata": {},
"outputs": [],
"source": [
"import argparse\n",
"import sys\n",
"\n",
"import numpy as np\n",
"import torch\n",
"from moshi.models.loaders import CheckpointInfo\n",
@@ -64,9 +61,7 @@
"# CFG coef goes here because the model was trained with CFG distillation,\n",
"# so it's not _actually_ doing CFG at inference time.\n",
"# Also, if you are generating a dialog, you should have two voices in the list.\n",
"condition_attributes = tts_model.make_condition_attributes(\n",
"    [voice_path], cfg_coef=2.0\n",
")"
"condition_attributes = tts_model.make_condition_attributes([voice_path], cfg_coef=2.0)"
]
},
{
@@ -79,17 +74,22 @@
"print(\"Generating audio...\")\n",
"\n",
"pcms = []\n",
"\n",
"\n",
"def _on_frame(frame):\n",
"    print(\"Step\", len(pcms), end=\"\\r\")\n",
"    if (frame != -1).all():\n",
"        pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()\n",
"        pcms.append(np.clip(pcm[0, 0], -1, 1))\n",
"\n",
"\n",
"# You could also generate multiple audios at once by extending the following lists.\n",
"all_entries = [entries]\n",
"all_condition_attributes = [condition_attributes]\n",
"with tts_model.mimi.streaming(len(all_entries)):\n",
"    result = tts_model.generate(all_entries, all_condition_attributes, on_frame=_on_frame)\n",
"    result = tts_model.generate(\n",
"        all_entries, all_condition_attributes, on_frame=_on_frame\n",
"    )\n",
"\n",
"print(\"Done generating.\")\n",
"audio = np.concatenate(pcms, axis=-1)"
@@ -102,9 +102,7 @@
"metadata": {},
"outputs": [],
"source": [
"display(\n",
"    Audio(audio, rate=tts_model.mimi.sample_rate, autoplay=True)\n",
")"
"display(Audio(audio, rate=tts_model.mimi.sample_rate, autoplay=True))"
]
},
{

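The final cell plays the waveform inline with IPython's Audio widget. Outside a notebook you would more likely write it to disk; here is a minimal sketch using the soundfile package (an assumption, the repository may use a different writer), with `audio` and `tts_model` as produced in the cells above and `tts_output.wav` as an illustrative output path.

import soundfile as sf

# `audio` is a mono waveform clipped to [-1, 1]; Mimi exposes its sample rate
# on the model (24 kHz for the public Mimi checkpoints).
sf.write("tts_output.wav", audio, int(tts_model.mimi.sample_rate))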