mirror of
https://github.com/vishpat/candle-coursera-ml.git
synced 2025-12-22 22:19:58 +00:00
Added Multi-Head attention
This commit is contained in:
@@ -63,6 +63,63 @@ impl Module for SelfAttention {
|
|||||||
Ok(attn_scores)
|
Ok(attn_scores)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct MultiHeadAttention {
|
||||||
|
heads: usize,
|
||||||
|
d: usize,
|
||||||
|
masked: bool,
|
||||||
|
w_qs: Vec<Linear>,
|
||||||
|
w_ks: Vec<Linear>,
|
||||||
|
w_vs: Vec<Linear>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MultiHeadAttention {
|
||||||
|
fn new(heads: usize, d: usize, masked: bool, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let w_qs = (0..heads)
|
||||||
|
.map(|_| linear(d, d, vb.pp("w_q")))
|
||||||
|
.collect::<CandleResult<_>>()?;
|
||||||
|
let w_ks = (0..heads)
|
||||||
|
.map(|_| linear(d, d, vb.pp("w_k")))
|
||||||
|
.collect::<CandleResult<_>>()?;
|
||||||
|
let w_vs = (0..heads)
|
||||||
|
.map(|_| linear(d, d, vb.pp("w_v")))
|
||||||
|
.collect::<CandleResult<_>>()?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
heads,
|
||||||
|
d,
|
||||||
|
masked,
|
||||||
|
w_qs,
|
||||||
|
w_ks,
|
||||||
|
w_vs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for MultiHeadAttention {
|
||||||
|
fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
|
||||||
|
let mut attns = Vec::with_capacity(self.heads);
|
||||||
|
for i in 0..self.heads {
|
||||||
|
let q = self.w_qs[i].forward(x)?;
|
||||||
|
let k = self.w_ks[i].forward(x)?;
|
||||||
|
let v = self.w_vs[i].forward(x)?;
|
||||||
|
|
||||||
|
let sims = q.matmul(&k.transpose(1, 0)?)?;
|
||||||
|
let scale = Tensor::new((self.d as f32).sqrt(), sims.device())?;
|
||||||
|
let mut scaled_sims = sims.broadcast_div(&scale)?;
|
||||||
|
if self.masked {
|
||||||
|
let mask = get_mask(scaled_sims.dims()[0], scaled_sims.device())?;
|
||||||
|
scaled_sims = masked_fill(&scaled_sims, &mask, f32::NEG_INFINITY)?;
|
||||||
|
}
|
||||||
|
let attn_pct = softmax(&scaled_sims, 1)?;
|
||||||
|
let attn_scores = attn_pct.matmul(&v)?;
|
||||||
|
attns.push(attn_scores);
|
||||||
|
}
|
||||||
|
let attns = Tensor::stack(&attns, 0)?;
|
||||||
|
Ok(attns)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let device = Device::cuda_if_available(0)?;
|
let device = Device::cuda_if_available(0)?;
|
||||||
let varmap = VarMap::new();
|
let varmap = VarMap::new();
|
||||||
Reference in New Issue
Block a user