Fixed self attention

This commit is contained in:
Vishal Patil
2025-02-14 19:54:12 -05:00
parent 9e008f91e0
commit e7ddfccc85

View File

@@ -1,5 +1,5 @@
use anyhow::Ok;
use anyhow::Result;
use candle_core::Result as CandleResult;
use candle_core::{DType, Device, Tensor};
use candle_nn::ops::softmax;
use candle_nn::VarBuilder;
@@ -31,8 +31,8 @@ impl SelfAttention {
}
}
impl SelfAttention {
fn attention(&self, x: &Tensor) -> Result<Tensor> {
impl Module for SelfAttention {
fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
let q = self.w_q.forward(x)?;
let k = self.w_k.forward(x)?;
let v = self.w_v.forward(x)?;
@@ -40,9 +40,9 @@ impl SelfAttention {
let qk = q.matmul(&k.transpose(1, 0)?)?;
let qk = qk.broadcast_div(&self.scale)?;
let qk = softmax(&qk, 1)?;
Ok(qk.matmul(&v)?)
let attn_pct = softmax(&qk, 1)?;
let attn_scores = attn_pct.matmul(&v)?;
Ok(attn_scores)
}
}
fn main() -> Result<()> {
@@ -60,7 +60,7 @@ fn main() -> Result<()> {
&device,
)?;
let attn = self_attn.attention(&encoding_matrix)?;
let attn = self_attn.forward(&encoding_matrix)?;
println!("{}", attn);
Ok(())
}