Mirror of https://github.com/vishpat/candle-coursera-ml.git, synced 2025-12-22 22:19:58 +00:00
Cleaned up
@@ -7,7 +7,6 @@ use candle_nn::{linear, Linear, Module, VarMap};
 struct SelfAttention {
     d: usize, // Embedding size
-    scale: Tensor,
     w_q: Linear,
     w_k: Linear,
     w_v: Linear,
@@ -16,18 +15,12 @@ struct SelfAttention {
 impl SelfAttention {
     fn new(d: usize, vb: VarBuilder) -> Result<Self> {
-        let d = d;
-        let scale = Tensor::new((d as f32).sqrt(), vb.device())?;
-
         let w_q = linear(d, d, vb.pp("w_q"))?;
         let w_k = linear(d, d, vb.pp("w_k"))?;
         let w_v = linear(d, d, vb.pp("w_v"))?;

-        Ok(Self {
-            d,
-            scale,
-            w_q,
-            w_k,
-            w_v,
-        })
+        Ok(Self { d, w_q, w_k, w_v })
     }
 }
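For orientation, a minimal sketch of how a constructor like this is typically driven in candle follows. The VarMap/VarBuilder plumbing, the build_attention helper, and the embedding size of 4 are illustrative assumptions on my part, not part of this commit; in a real training loop you would also keep the VarMap around to access the trainable variables.

use candle_core::{DType, Device, Result};
use candle_nn::{VarBuilder, VarMap};

fn build_attention() -> Result<SelfAttention> {
    let device = Device::Cpu;
    // The VarMap owns the trainable weights; the VarBuilder hands out
    // namespaced views, so the vb.pp("w_q") calls above register their
    // parameters under names like "w_q.weight".
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    SelfAttention::new(4, vb) // d = 4 is an arbitrary example embedding size
}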
@@ -38,8 +31,8 @@ impl Module for SelfAttention {
         let v = self.w_v.forward(x)?;

         let qk = q.matmul(&k.transpose(1, 0)?)?;
-
-        let qk = qk.broadcast_div(&self.scale)?;
+        let scale = Tensor::new((self.d as f32).sqrt(), qk.device())?;
+        let qk = qk.broadcast_div(&scale)?;
         let attn_pct = softmax(&qk, 1)?;
         let attn_scores = attn_pct.matmul(&v)?;
         Ok(attn_scores)
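After this change the scale factor sqrt(d) is recomputed inside forward rather than cached on the struct; the forward pass still computes scaled dot-product attention, softmax(Q K^T / sqrt(d)) V. Below is a hedged usage sketch on a toy input; the run_demo helper, the sequence length of 3, and the reuse of d = 4 from the previous sketch are assumptions for illustration, not code from this repository.

use candle_core::{Device, Result, Tensor};
use candle_nn::Module;

fn run_demo(attn: &SelfAttention) -> Result<Tensor> {
    // Toy input: 3 token embeddings of size d = 4 (both sizes are made up here).
    let x = Tensor::randn(0f32, 1f32, (3, 4), &Device::Cpu)?;
    // Projects q/k/v, forms Q K^T / sqrt(d), applies a row-wise softmax, and
    // weights V, exactly as in the forward method above; output shape is (3, 4).
    attn.forward(&x)
}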