diff --git a/stt-rs/Cargo.lock b/stt-rs/Cargo.lock index 6c777ad..e87e292 100644 --- a/stt-rs/Cargo.lock +++ b/stt-rs/Cargo.lock @@ -1427,6 +1427,7 @@ dependencies = [ "anyhow", "candle-core", "candle-nn", + "candle-transformers", "clap", "hf-hub", "kaudio", diff --git a/stt-rs/Cargo.toml b/stt-rs/Cargo.toml index b2934e8..6810450 100644 --- a/stt-rs/Cargo.toml +++ b/stt-rs/Cargo.toml @@ -7,6 +7,7 @@ edition = "2024" anyhow = "1.0" candle = { version = "0.9.1", package = "candle-core" } candle-nn = "0.9.1" +candle-transformers = "0.9.1" clap = { version = "4.4.12", features = ["derive"] } hf-hub = "0.4.3" kaudio = "0.2.1" diff --git a/stt-rs/src/main.rs b/stt-rs/src/main.rs index 4f81327..71faacb 100644 --- a/stt-rs/src/main.rs +++ b/stt-rs/src/main.rs @@ -15,6 +15,10 @@ struct Args { #[arg(long, default_value = "kyutai/stt-1b-en_fr-candle")] hf_repo: String, + /// Path to the model file in the repo. + #[arg(long, default_value = "model.safetensors")] + model_path: String, + /// Run the model on cpu. #[arg(long)] cpu: bool, @@ -120,25 +124,39 @@ struct Model { impl Model { fn load_from_hf(args: &Args, dev: &Device) -> Result { - let dtype = dev.bf16_default_to_f32(); - // Retrieve the model files from the Hugging Face Hub let api = hf_hub::api::sync::Api::new()?; let repo = api.model(args.hf_repo.to_string()); let config_file = repo.get("config.json")?; let config: Config = serde_json::from_str(&std::fs::read_to_string(&config_file)?)?; let tokenizer_file = repo.get(&config.tokenizer_name)?; - let model_file = repo.get("model.safetensors")?; + let model_file = repo.get(&args.model_path)?; let mimi_file = repo.get(&config.mimi_name)?; + let is_quantized = model_file.to_str().unwrap().ends_with(".gguf"); let text_tokenizer = sentencepiece::SentencePieceProcessor::open(&tokenizer_file)?; - let vb_lm = - unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], dtype, dev)? }; + + let lm = if is_quantized { + let vb_lm = candle_transformers::quantized_var_builder::VarBuilder::from_gguf( + &model_file, + dev, + )?; + moshi::lm::LmModel::new( + &config.model_config(args.vad), + moshi::nn::MaybeQuantizedVarBuilder::Quantized(vb_lm), + )? + } else { + let dtype = dev.bf16_default_to_f32(); + let vb_lm = unsafe { + candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], dtype, dev)? + }; + moshi::lm::LmModel::new( + &config.model_config(args.vad), + moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm), + )? + }; + let audio_tokenizer = moshi::mimi::load(mimi_file.to_str().unwrap(), Some(32), dev)?; - let lm = moshi::lm::LmModel::new( - &config.model_config(args.vad), - moshi::nn::MaybeQuantizedVarBuilder::Real(vb_lm), - )?; let asr_delay_in_tokens = (config.stt_config.audio_delay_seconds * 12.5) as usize; let state = moshi::asr::State::new(1, asr_delay_in_tokens, 0., audio_tokenizer, lm)?; Ok(Model {