Cleaned up

commit 51d4b2605c (parent e7ddfccc85)
Author: Vishal Patil
Date: 2025-02-14 19:58:15 -05:00


@@ -7,7 +7,6 @@ use candle_nn::{linear, Linear, Module, VarMap};
 struct SelfAttention {
     d: usize, // Embedding size
-    scale: Tensor,
     w_q: Linear,
     w_k: Linear,
     w_v: Linear,
@@ -16,18 +15,12 @@ struct SelfAttention {
 impl SelfAttention {
     fn new(d: usize, vb: VarBuilder) -> Result<Self> {
         let d = d;
-        let scale = Tensor::new((d as f32).sqrt(), vb.device())?;
         let w_q = linear(d, d, vb.pp("w_q"))?;
         let w_k = linear(d, d, vb.pp("w_k"))?;
         let w_v = linear(d, d, vb.pp("w_v"))?;
-        Ok(Self {
-            d,
-            scale,
-            w_q,
-            w_k,
-            w_v,
-        })
+        Ok(Self { d, w_q, w_k, w_v })
     }
 }
@@ -38,8 +31,8 @@ impl Module for SelfAttention {
         let v = self.w_v.forward(x)?;
         let qk = q.matmul(&k.transpose(1, 0)?)?;
-        let qk = qk.broadcast_div(&self.scale)?;
+        let scale = Tensor::new((self.d as f32).sqrt(), qk.device())?;
+        let qk = qk.broadcast_div(&scale)?;
         let attn_pct = softmax(&qk, 1)?;
         let attn_scores = attn_pct.matmul(&v)?;
         Ok(attn_scores)
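
For context, here is a minimal sketch of the whole module as it reads after this change: the scale: Tensor field and its precomputation in new() are gone, and the 1/sqrt(d) scaling is recomputed inside forward(). The q and k projections, the VarBuilder, softmax, Device, and DType imports, and the main() driver are not visible in the hunks above, so they are illustrative assumptions rather than part of the diff.

// Sketch of SelfAttention after this commit (not the exact file contents).
// Assumes candle-core / candle-nn; imports beyond the diff's
// `use candle_nn::{linear, Linear, Module, VarMap};` are added for completeness.
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{linear, ops::softmax, Linear, Module, VarBuilder, VarMap};

struct SelfAttention {
    d: usize, // Embedding size
    w_q: Linear,
    w_k: Linear,
    w_v: Linear,
}

impl SelfAttention {
    fn new(d: usize, vb: VarBuilder) -> Result<Self> {
        let w_q = linear(d, d, vb.pp("w_q"))?;
        let w_k = linear(d, d, vb.pp("w_k"))?;
        let w_v = linear(d, d, vb.pp("w_v"))?;
        Ok(Self { d, w_q, w_k, w_v })
    }
}

impl Module for SelfAttention {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // x is expected to be (seq_len, d); the projections keep that shape.
        let q = self.w_q.forward(x)?;
        let k = self.w_k.forward(x)?;
        let v = self.w_v.forward(x)?;
        // Raw attention logits: (seq_len, seq_len).
        let qk = q.matmul(&k.transpose(1, 0)?)?;
        // Scale is now recomputed per call instead of being stored on the struct.
        let scale = Tensor::new((self.d as f32).sqrt(), qk.device())?;
        let qk = qk.broadcast_div(&scale)?;
        // Row-wise softmax over keys, then weight the values.
        let attn_pct = softmax(&qk, 1)?;
        let attn_scores = attn_pct.matmul(&v)?;
        Ok(attn_scores)
    }
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let attn = SelfAttention::new(8, vb)?;
    // Toy input: 4 tokens with embedding size 8, shape (seq_len, d).
    let x = Tensor::randn(0f32, 1f32, (4, 8), &device)?;
    let out = attn.forward(&x)?;
    println!("{:?}", out.shape()); // expected: [4, 8]
    Ok(())
}

Recomputing the scale on each call keeps the struct down to the embedding size and the three projection layers; the trade-off is one extra scalar tensor allocation per forward pass.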