Mirror of https://github.com/kyutai-labs/delayed-streams-modeling.git, synced 2025-12-22 19:09:57 +00:00
Add quantization support and GGUF loading to standalone STT Rust script (#120)
* scripts to int8-quantize the model
* target bf16 to uint8, 2x size reduction
* able to load the model
* quantized inference working
* remove unused scripts
* conditional init depending on whether the model is quantized
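The quantization scripts themselves were removed before merging and are not part of this diff. As a rough sketch of what that step can look like with candle-core 0.9's quantized API (this is not the author's script; the file names, the Q8_0 choice, and the keep-odd-shapes-in-f32 rule are all assumptions), one can convert a safetensors checkpoint into a GGUF file of the kind the loader below consumes:

// Illustrative sketch, not the removed script: convert a safetensors
// checkpoint into a Q8_0 GGUF file with candle-core 0.9.
// "model.safetensors" and "model.q8_0.gguf" are hypothetical file names.
use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
use candle_core::{DType, Device};

fn main() -> anyhow::Result<()> {
    let tensors = candle_core::safetensors::load("model.safetensors", &Device::Cpu)?;
    let mut qtensors = Vec::new();
    for (name, tensor) in tensors {
        // Q8_0 packs 32 weights per block; keep tensors whose last dim
        // is not a multiple of the block size (norms, biases) in f32.
        let dtype = if tensor.dims().last().is_some_and(|&d| d % 32 == 0) {
            GgmlDType::Q8_0
        } else {
            GgmlDType::F32
        };
        let t = tensor.to_dtype(DType::F32)?; // bf16 -> f32 before quantizing
        qtensors.push((name, QTensor::quantize(&t, dtype)?));
    }
    let refs: Vec<(&str, &QTensor)> = qtensors.iter().map(|(n, t)| (n.as_str(), t)).collect();
    let mut out = std::fs::File::create("model.q8_0.gguf")?;
    gguf_file::write(&mut out, &[], &refs)?;
    Ok(())
}

Q8_0 stores blocks of 32 int8 values plus one f16 scale, about 8.5 bits per weight, which is where the roughly 2x reduction from 16-bit bf16 comes from.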
stt-rs/Cargo.lock (generated, 1 line changed)
@@ -1427,6 +1427,7 @@ dependencies = [
  "anyhow",
  "candle-core",
  "candle-nn",
+ "candle-transformers",
  "clap",
  "hf-hub",
  "kaudio",
stt-rs/Cargo.toml
@@ -7,6 +7,7 @@ edition = "2024"
 anyhow = "1.0"
 candle = { version = "0.9.1", package = "candle-core" }
 candle-nn = "0.9.1"
+candle-transformers = "0.9.1"
 clap = { version = "4.4.12", features = ["derive"] }
 hf-hub = "0.4.3"
 kaudio = "0.2.1"
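The new candle-transformers dependency supplies the GGUF entry point the loader relies on. A minimal sketch of that entry point on its own (candle-transformers 0.9 assumed; the file name and the "lm" prefix are hypothetical):

// Minimal sketch: open a GGUF checkpoint with candle-transformers'
// quantized VarBuilder, the entry point used by the new code path.
use candle_core::Device;
use candle_transformers::quantized_var_builder::VarBuilder;

fn main() -> anyhow::Result<()> {
    let dev = Device::Cpu;
    // "model.q8_0.gguf" is a hypothetical file name.
    let vb = VarBuilder::from_gguf("model.q8_0.gguf", &dev)?;
    // Prefixing works like the regular candle VarBuilder.
    let _scoped = vb.pp("lm");
    Ok(())
}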
stt-rs/src/main.rs
@@ -15,6 +15,10 @@ struct Args {
     #[arg(long, default_value = "kyutai/stt-1b-en_fr-candle")]
     hf_repo: String,
 
+    /// Path to the model file in the repo.
+    #[arg(long, default_value = "model.safetensors")]
+    model_path: String,
+
     /// Run the model on cpu.
     #[arg(long)]
     cpu: bool,
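With this flag (clap derives it as --model-path), a quantized checkpoint is selected by pointing it at a .gguf file in the repo; the loader below dispatches on the extension. A hypothetical invocation, assuming the script takes the audio file as a positional argument and that the repo hosts a file named model.q8_0.gguf:

    cargo run -r -- --model-path model.q8_0.gguf audio.wav

The default of model.safetensors preserves the previous behavior for anyone not opting in.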
@@ -120,25 +124,39 @@ struct Model {
 
 impl Model {
     fn load_from_hf(args: &Args, dev: &Device) -> Result<Self> {
-        let dtype = dev.bf16_default_to_f32();
-
         // Retrieve the model files from the Hugging Face Hub
         let api = hf_hub::api::sync::Api::new()?;
         let repo = api.model(args.hf_repo.to_string());
         let config_file = repo.get("config.json")?;
         let config: Config = serde_json::from_str(&std::fs::read_to_string(&config_file)?)?;
         let tokenizer_file = repo.get(&config.tokenizer_name)?;
-        let model_file = repo.get("model.safetensors")?;
+        let model_file = repo.get(&args.model_path)?;
         let mimi_file = repo.get(&config.mimi_name)?;
+        let is_quantized = model_file.to_str().unwrap().ends_with(".gguf");
 
         let text_tokenizer = sentencepiece::SentencePieceProcessor::open(&tokenizer_file)?;
-        let vb_lm =
-            unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], dtype, dev)? };
+
+        let lm = if is_quantized {
+            let vb_lm = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+                &model_file,
+                dev,
+            )?;
+            moshi::lm::LmModel::new(
+                &config.model_config(args.vad),
+                moshi::nn::MaybeQuantizedVarBuilder::Quantized(vb_lm),
+            )?
+        } else {
+            let dtype = dev.bf16_default_to_f32();
+            let vb_lm = unsafe {
+                candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], dtype, dev)?
+            };
+            moshi::lm::LmModel::new(
+                &config.model_config(args.vad),
+                moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm),
+            )?
+        };
 
         let audio_tokenizer = moshi::mimi::load(mimi_file.to_str().unwrap(), Some(32), dev)?;
-        let lm = moshi::lm::LmModel::new(
-            &config.model_config(args.vad),
-            moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm),
-        )?;
         let asr_delay_in_tokens = (config.stt_config.audio_delay_seconds * 12.5) as usize;
         let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?;
         Ok(Model {
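A note on what the new code path does: the loader dispatches on the checkpoint's extension. A .gguf file goes through candle-transformers' quantized VarBuilder, anything else keeps the previous memory-mapped safetensors path, and moshi's MaybeQuantizedVarBuilder enum (Quantized vs. Real variants) lets the same LmModel::new constructor accept either. Because the if/else is an expression bound to lm, the Mimi and ASR setup after it stays shared; the dtype computation moves into the else branch since GGUF tensors carry their own quantized types. In the shared tail, asr_delay_in_tokens converts the configured delay from seconds to tokens at Mimi's 12.5 Hz frame rate.

To see what the quantized path is reading, here is a minimal sketch (candle-core 0.9 assumed; the file and tensor names are hypothetical) that opens a GGUF checkpoint and dequantizes one tensor; in normal inference, dequantization happens inside the quantized matmul kernels rather than up front:

// Minimal sketch: peek inside a GGUF checkpoint with candle-core 0.9.
// "model.q8_0.gguf" and "some.weight" are hypothetical names.
use candle_core::quantized::gguf_file;
use candle_core::Device;

fn main() -> anyhow::Result<()> {
    let mut file = std::fs::File::open("model.q8_0.gguf")?;
    let content = gguf_file::Content::read(&mut file)?;
    for (name, info) in content.tensor_infos.iter() {
        println!("{name}: shape {:?}, dtype {:?}", info.shape, info.ggml_dtype);
    }
    // Tensors stay quantized in memory; dequantize explicitly if needed.
    let qtensor = content.tensor(&mut file, "some.weight", &Device::Cpu)?;
    let dequantized = qtensor.dequantize(&Device::Cpu)?;
    println!("dequantized dtype: {:?}", dequantized.dtype());
    Ok(())
}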