mirror of
https://github.com/kyutai-labs/delayed-streams-modeling.git
synced 2025-12-22 19:09:57 +00:00
Enable ruff in the pre-commit hooks (#124)
* Enable ruff in the pre-commit hooks.
* Disable the old hooks.
* Install uv in the CI.
This commit is contained in:
4
.github/actions/moshi_build/action.yml
vendored
4
.github/actions/moshi_build/action.yml
vendored
@@ -26,3 +26,7 @@ runs:
|
|||||||
run: |
|
run: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
pre-commit install
|
pre-commit install
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v6
|
||||||
|
with:
|
||||||
|
version: "0.8.13"
|
||||||
|
|||||||
@@ -1,4 +1,18 @@
|
|||||||
repos:
|
repos:
|
||||||
|
- repo: local
|
||||||
|
hooks:
|
||||||
|
- id: ruff
|
||||||
|
name: ruff
|
||||||
|
language: system
|
||||||
|
entry: bash -c 'uvx ruff check'
|
||||||
|
pass_filenames: false
|
||||||
|
always_run: true
|
||||||
|
- id: ruff-format
|
||||||
|
name: ruff format
|
||||||
|
language: system
|
||||||
|
entry: bash -c 'uvx ruff format --check'
|
||||||
|
pass_filenames: false
|
||||||
|
always_run: true
|
||||||
# Get rid of Jupyter Notebook output because we don't want to keep it in Git
|
# Get rid of Jupyter Notebook output because we don't want to keep it in Git
|
||||||
- repo: https://github.com/kynan/nbstripout
|
- repo: https://github.com/kynan/nbstripout
|
||||||
rev: 0.8.1
|
rev: 0.8.1
|
||||||
@@ -9,14 +23,3 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: check-added-large-files
|
- id: check-added-large-files
|
||||||
args: ["--maxkb=2048"]
|
args: ["--maxkb=2048"]
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
||||||
# Ruff version.
|
|
||||||
rev: v0.11.7
|
|
||||||
hooks:
|
|
||||||
# Run the linter.
|
|
||||||
- id: ruff
|
|
||||||
types_or: [python, pyi] # Don't run on `jupyter` files
|
|
||||||
args: [--fix]
|
|
||||||
# Run the formatter.
|
|
||||||
- id: ruff-format
|
|
||||||
types_or: [python, pyi] # Don't run on `jupyter` files
|
|
||||||
|
|||||||
@@ -117,16 +117,22 @@ def main():
|
|||||||
break
|
break
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
else:
|
else:
|
||||||
|
last_time = time.time()
|
||||||
|
|
||||||
def _on_frame(frame):
|
def _on_frame(frame):
|
||||||
nonlocal _frames_cnt
|
nonlocal _frames_cnt
|
||||||
|
nonlocal last_time
|
||||||
if (frame != -1).all():
|
if (frame != -1).all():
|
||||||
_frames_cnt += 1
|
_frames_cnt += 1
|
||||||
print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)
|
print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)
|
||||||
|
print("{}", time.time() - last_time)
|
||||||
|
last_time = time.time()
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
result = tts_model.generate(
|
result = tts_model.generate(
|
||||||
[entries], [condition_attributes], on_frame=_on_frame
|
[entries], [condition_attributes], on_frame=_on_frame
|
||||||
)
|
)
|
||||||
|
print(f"\nTotal time: {time.time() - start_time:.2f}s")
|
||||||
with tts_model.mimi.streaming(1), torch.no_grad():
|
with tts_model.mimi.streaming(1), torch.no_grad():
|
||||||
pcms = []
|
pcms = []
|
||||||
for frame in result.frames[tts_model.delay_steps :]:
|
for frame in result.frames[tts_model.delay_steps :]:
|
||||||
|
|||||||
@@ -80,7 +80,6 @@
|
|||||||
" self.lm_gen.streaming_forever(batch_size)\n",
|
" self.lm_gen.streaming_forever(batch_size)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def run(self, in_pcms: torch.Tensor):\n",
|
" def run(self, in_pcms: torch.Tensor):\n",
|
||||||
" device = self.lm_gen.lm_model.device\n",
|
|
||||||
" ntokens = 0\n",
|
" ntokens = 0\n",
|
||||||
" first_frame = True\n",
|
" first_frame = True\n",
|
||||||
" chunks = [\n",
|
" chunks = [\n",
|
||||||
|
|||||||
@@ -21,9 +21,6 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import argparse\n",
|
|
||||||
"import sys\n",
|
|
||||||
"\n",
|
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"import torch\n",
|
"import torch\n",
|
||||||
"from moshi.models.loaders import CheckpointInfo\n",
|
"from moshi.models.loaders import CheckpointInfo\n",
|
||||||
@@ -64,9 +61,7 @@
|
|||||||
"# CFG coef goes here because the model was trained with CFG distillation,\n",
|
"# CFG coef goes here because the model was trained with CFG distillation,\n",
|
||||||
"# so it's not _actually_ doing CFG at inference time.\n",
|
"# so it's not _actually_ doing CFG at inference time.\n",
|
||||||
"# Also, if you are generating a dialog, you should have two voices in the list.\n",
|
"# Also, if you are generating a dialog, you should have two voices in the list.\n",
|
||||||
"condition_attributes = tts_model.make_condition_attributes(\n",
|
"condition_attributes = tts_model.make_condition_attributes([voice_path], cfg_coef=2.0)"
|
||||||
" [voice_path], cfg_coef=2.0\n",
|
|
||||||
")"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -79,17 +74,22 @@
|
|||||||
"print(\"Generating audio...\")\n",
|
"print(\"Generating audio...\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"pcms = []\n",
|
"pcms = []\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
"def _on_frame(frame):\n",
|
"def _on_frame(frame):\n",
|
||||||
" print(\"Step\", len(pcms), end=\"\\r\")\n",
|
" print(\"Step\", len(pcms), end=\"\\r\")\n",
|
||||||
" if (frame != -1).all():\n",
|
" if (frame != -1).all():\n",
|
||||||
" pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()\n",
|
" pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()\n",
|
||||||
" pcms.append(np.clip(pcm[0, 0], -1, 1))\n",
|
" pcms.append(np.clip(pcm[0, 0], -1, 1))\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
"# You could also generate multiple audios at once by extending the following lists.\n",
|
"# You could also generate multiple audios at once by extending the following lists.\n",
|
||||||
"all_entries = [entries]\n",
|
"all_entries = [entries]\n",
|
||||||
"all_condition_attributes = [condition_attributes]\n",
|
"all_condition_attributes = [condition_attributes]\n",
|
||||||
"with tts_model.mimi.streaming(len(all_entries)):\n",
|
"with tts_model.mimi.streaming(len(all_entries)):\n",
|
||||||
" result = tts_model.generate(all_entries, all_condition_attributes, on_frame=_on_frame)\n",
|
" result = tts_model.generate(\n",
|
||||||
|
" all_entries, all_condition_attributes, on_frame=_on_frame\n",
|
||||||
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(\"Done generating.\")\n",
|
"print(\"Done generating.\")\n",
|
||||||
"audio = np.concatenate(pcms, axis=-1)"
|
"audio = np.concatenate(pcms, axis=-1)"
|
||||||
@@ -102,9 +102,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"display(\n",
|
"display(Audio(audio, rate=tts_model.mimi.sample_rate, autoplay=True))"
|
||||||
" Audio(audio, rate=tts_model.mimi.sample_rate, autoplay=True)\n",
|
|
||||||
")"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user