diff --git a/self-attention/src/main.rs b/self-attention/src/main.rs
index b4ff465..462dd0a 100644
--- a/self-attention/src/main.rs
+++ b/self-attention/src/main.rs
@@ -7,7 +7,6 @@ use candle_nn::{linear, Linear, Module, VarMap};
 
 struct SelfAttention {
     d: usize, // Embedding size
-    scale: Tensor,
     w_q: Linear,
     w_k: Linear,
     w_v: Linear,
@@ -16,18 +15,12 @@ struct SelfAttention {
 impl SelfAttention {
     fn new(d: usize, vb: VarBuilder) -> Result<Self> {
         let d = d;
-        let scale = Tensor::new((d as f32).sqrt(), vb.device())?;
+
         let w_q = linear(d, d, vb.pp("w_q"))?;
         let w_k = linear(d, d, vb.pp("w_k"))?;
         let w_v = linear(d, d, vb.pp("w_v"))?;
-        Ok(Self {
-            d,
-            scale,
-            w_q,
-            w_k,
-            w_v,
-        })
+        Ok(Self { d, w_q, w_k, w_v })
     }
 }
@@ -38,8 +31,8 @@ impl Module for SelfAttention {
         let v = self.w_v.forward(x)?;
 
         let qk = q.matmul(&k.transpose(1, 0)?)?;
-
-        let qk = qk.broadcast_div(&self.scale)?;
+        let scale = Tensor::new((self.d as f32).sqrt(), qk.device())?;
+        let qk = qk.broadcast_div(&scale)?;
         let attn_pct = softmax(&qk, 1)?;
         let attn_scores = attn_pct.matmul(&v)?;
         Ok(attn_scores)
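
For context, here is a minimal usage sketch of the module after this change. It is illustrative, not part of the patch: it assumes the post-diff SelfAttention defined in this same main.rs, the crate's existing candle-core and candle-nn dependencies, and a 2-D (seq_len, d) input to match the transpose(1, 0) in forward. Imports are repeated only to keep the sketch self-contained.

// Illustrative only: exercising the refactored SelfAttention end to end.
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Module, VarBuilder, VarMap};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let d = 8;       // embedding size (hypothetical value)
    let seq_len = 4; // number of tokens (hypothetical value)

    // Randomly initialised w_q, w_k, w_v weights via a VarMap-backed VarBuilder.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let attn = SelfAttention::new(d, vb)?;

    // One (seq_len, d) block of token embeddings.
    let x = Tensor::randn(0f32, 1f32, (seq_len, d), &device)?;
    let out = attn.forward(&x)?;
    println!("attention output shape: {:?}", out.shape()); // (seq_len, d)
    Ok(())
}

Because the scale tensor is now built inside forward from qk.device(), the module no longer pins the scaling constant to whatever device vb happened to live on at construction time.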