Add quantization support and GGUF loading to standalone STT Rust script (#120)

* scripts to int8-quantize the model (sketched below)

* target bf16 to uint8, 2x size reduction

* able to load the model

* quantized model working

* remove unused scripts

* conditional init depending on quantization

Author: Bai Li
Date: 2025-08-25 00:28:48 -07:00 (committed via GitHub)
Parent: affc0a052b
Commit: 7d1e4b703a
3 changed files with 29 additions and 9 deletions
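
The int8 quantization scripts mentioned in the commit message were removed before merge and are not part of this diff. As a rough reference only, a bf16-to-Q8_0 conversion along those lines could be written with candle's quantized API as below; the file names, the block-size fallback, and the overall structure are assumptions, not the removed script.

// Hypothetical sketch of a quantization script: convert a bf16 safetensors
// checkpoint to a Q8_0 GGUF (roughly 2x smaller than bf16, since Q8_0 stores
// one byte per weight plus per-block scales). Not the actual removed script.
use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
use candle_core::{DType, Device, Result};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Load the original weights (path is an assumption).
    let tensors = candle_core::safetensors::load("model.safetensors", &device)?;
    let mut quantized = Vec::with_capacity(tensors.len());
    for (name, tensor) in tensors.iter() {
        // Q8_0 blocks cover 32 elements; tensors whose innermost dimension
        // is not a multiple of 32 are stored unquantized as f32.
        let dtype = if tensor.dims().last().is_some_and(|&d| d % 32 == 0) {
            GgmlDType::Q8_0
        } else {
            GgmlDType::F32
        };
        let qtensor = QTensor::quantize(&tensor.to_dtype(DType::F32)?, dtype)?;
        quantized.push((name.as_str(), qtensor));
    }
    let tensor_refs: Vec<(&str, &QTensor)> =
        quantized.iter().map(|(name, t)| (*name, t)).collect();
    // Write the GGUF with an empty metadata table (output name is an assumption).
    let mut out = std::fs::File::create("model.q8_0.gguf")?;
    gguf_file::write(&mut out, &[], &tensor_refs)?;
    Ok(())
}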

stt-rs/Cargo.lock (generated)

@@ -1427,6 +1427,7 @@ dependencies = [
  "anyhow",
  "candle-core",
  "candle-nn",
+ "candle-transformers",
  "clap",
  "hf-hub",
  "kaudio",

stt-rs/Cargo.toml

@@ -7,6 +7,7 @@ edition = "2024"
 anyhow = "1.0"
 candle = { version = "0.9.1", package = "candle-core" }
 candle-nn = "0.9.1"
+candle-transformers = "0.9.1"
 clap = { version = "4.4.12", features = ["derive"] }
 hf-hub = "0.4.3"
 kaudio = "0.2.1"
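
candle-transformers is added for its quantized_var_builder module, which supplies the GGUF-backed VarBuilder used in main.rs below; its version is pinned to match the other candle crates.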

stt-rs/src/main.rs

@@ -15,6 +15,10 @@ struct Args {
#[arg(long, default_value = "kyutai/stt-1b-en_fr-candle")]
hf_repo: String,
/// Path to the model file in the repo.
#[arg(long, default_value = "model.safetensors")]
model_path: String,
/// Run the model on cpu.
#[arg(long)]
cpu: bool,
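
With the new flag, the binary still defaults to the bf16 model.safetensors weights; pointing --model-path at a .gguf file switches to the quantized loader in the next hunk, e.g. `--model-path model.q8_0.gguf` alongside the tool's other arguments (the GGUF file name here is a hypothetical example; any repo path ending in .gguf selects the quantized branch).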
@@ -120,25 +124,39 @@ struct Model {
 impl Model {
     fn load_from_hf(args: &Args, dev: &Device) -> Result<Self> {
-        let dtype = dev.bf16_default_to_f32();
         // Retrieve the model files from the Hugging Face Hub
         let api = hf_hub::api::sync::Api::new()?;
         let repo = api.model(args.hf_repo.to_string());
         let config_file = repo.get("config.json")?;
         let config: Config = serde_json::from_str(&std::fs::read_to_string(&config_file)?)?;
         let tokenizer_file = repo.get(&config.tokenizer_name)?;
-        let model_file = repo.get("model.safetensors")?;
+        let model_file = repo.get(&args.model_path)?;
         let mimi_file = repo.get(&config.mimi_name)?;
+        let is_quantized = model_file.to_str().unwrap().ends_with(".gguf");
         let text_tokenizer = sentencepiece::SentencePieceProcessor::open(&tokenizer_file)?;
-        let vb_lm =
-            unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], dtype, dev)? };
+        let lm = if is_quantized {
+            let vb_lm = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+                &model_file,
+                dev,
+            )?;
+            moshi::lm::LmModel::new(
+                &config.model_config(args.vad),
+                moshi::nn::MaybeQuantizedVarBuilder::Quantized(vb_lm),
+            )?
+        } else {
+            let dtype = dev.bf16_default_to_f32();
+            let vb_lm = unsafe {
+                candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], dtype, dev)?
+            };
+            moshi::lm::LmModel::new(
+                &config.model_config(args.vad),
+                moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm),
+            )?
+        };
         let audio_tokenizer = moshi::mimi::load(mimi_file.to_str().unwrap(), Some(32), dev)?;
-        let lm = moshi::lm::LmModel::new(
-            &config.model_config(args.vad),
-            moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm),
-        )?;
         let asr_delay_in_tokens = (config.stt_config.audio_delay_seconds * 12.5) as usize;
         let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?;
         Ok(Model {
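
Design note: dispatching on the .gguf extension keeps a single load_from_hf path. The quantized branch builds the LM from a GGUF-backed quantized_var_builder::VarBuilder, while the dense branch keeps the original mmap'd safetensors loading, with the dtype computation (bf16 where supported, f32 otherwise) now scoped inside the else branch since the quantized path does not need it. Mimi audio-tokenizer loading is unchanged and shared by both branches.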