diff --git a/self-attention/Cargo.toml b/attention/Cargo.toml
similarity index 100%
rename from self-attention/Cargo.toml
rename to attention/Cargo.toml
diff --git a/self-attention/src/main.rs b/attention/src/main.rs
similarity index 58%
rename from self-attention/src/main.rs
rename to attention/src/main.rs
index b90c400..c5fda1c 100644
--- a/self-attention/src/main.rs
+++ b/attention/src/main.rs
@@ -63,6 +63,63 @@ impl Module for SelfAttention {
         Ok(attn_scores)
     }
 }
+
+struct MultiHeadAttention {
+    heads: usize,
+    d: usize,
+    masked: bool,
+    w_qs: Vec<Linear>,
+    w_ks: Vec<Linear>,
+    w_vs: Vec<Linear>,
+}
+
+impl MultiHeadAttention {
+    fn new(heads: usize, d: usize, masked: bool, vb: VarBuilder) -> Result<Self> {
+        let w_qs = (0..heads)
+            .map(|i| linear(d, d, vb.pp(format!("w_q{i}"))))
+            .collect::<CandleResult<Vec<_>>>()?;
+        let w_ks = (0..heads)
+            .map(|i| linear(d, d, vb.pp(format!("w_k{i}"))))
+            .collect::<CandleResult<Vec<_>>>()?;
+        let w_vs = (0..heads)
+            .map(|i| linear(d, d, vb.pp(format!("w_v{i}"))))
+            .collect::<CandleResult<Vec<_>>>()?;
+
+        Ok(Self {
+            heads,
+            d,
+            masked,
+            w_qs,
+            w_ks,
+            w_vs,
+        })
+    }
+}
+
+impl Module for MultiHeadAttention {
+    fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
+        let mut attns = Vec::with_capacity(self.heads);
+        for i in 0..self.heads {
+            let q = self.w_qs[i].forward(x)?;
+            let k = self.w_ks[i].forward(x)?;
+            let v = self.w_vs[i].forward(x)?;
+
+            let sims = q.matmul(&k.transpose(1, 0)?)?;
+            let scale = Tensor::new((self.d as f32).sqrt(), sims.device())?;
+            let mut scaled_sims = sims.broadcast_div(&scale)?;
+            if self.masked {
+                let mask = get_mask(scaled_sims.dims()[0], scaled_sims.device())?;
+                scaled_sims = masked_fill(&scaled_sims, &mask, f32::NEG_INFINITY)?;
+            }
+            let attn_pct = softmax(&scaled_sims, 1)?;
+            let attn_scores = attn_pct.matmul(&v)?;
+            attns.push(attn_scores);
+        }
+        let attns = Tensor::stack(&attns, 0)?;
+        Ok(attns)
+    }
+}
+
 fn main() -> Result<()> {
     let device = Device::cuda_if_available(0)?;
     let varmap = VarMap::new();
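
A minimal usage sketch (not part of the patch), assuming the VarMap and device created in main() and an arbitrary 3x2 input; the names vb, mha, and x below are illustrative, not taken from the repository:

    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    // Two heads over 2-dimensional embeddings, with causal masking enabled.
    let mha = MultiHeadAttention::new(2, 2, true, vb.pp("mha"))?;
    // Token encodings of shape (seq_len, d) = (3, 2).
    let x = Tensor::new(&[[1.0f32, 0.5], [0.3, 1.2], [0.9, -0.4]], &device)?;
    // Each head's attention output is stacked along dim 0: (heads, seq_len, d) = (2, 3, 2).
    let out = mha.forward(&x)?;
    println!("{out}");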