Mirror of https://github.com/vishpat/candle-coursera-ml.git

Added masking
@@ -7,20 +7,41 @@ use candle_nn::{linear, Linear, Module, VarMap};
 struct SelfAttention {
     d: usize, // Embedding size
+    masked: bool,
     w_q: Linear,
     w_k: Linear,
     w_v: Linear,
 }
 
+fn get_mask(size: usize, device: &Device) -> CandleResult<Tensor> {
+    let mask: Vec<_> = (0..size)
+        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .collect();
+    Tensor::from_slice(&mask, (size, size), device)
+}
+
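For intuition: the new helper builds a strictly upper-triangular matrix of u8 flags, where entry (i, j) is 1 exactly when j > i, i.e. when column j is a future position that row i must not attend to. A minimal check, assuming candle's Device::Cpu and the crate's CandleResult alias (demo_mask is a hypothetical helper, not part of the commit):

    fn demo_mask() -> CandleResult<()> {
        let device = Device::Cpu;
        let mask = get_mask(3, &device)?;
        // 1 marks a future position; row i may only see columns j <= i.
        assert_eq!(
            mask.to_vec2::<u8>()?,
            vec![vec![0, 1, 1], vec![0, 0, 1], vec![0, 0, 0]]
        );
        Ok(())
    }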
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> CandleResult<Tensor> {
+    let shape = mask.shape();
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
+    Ok(m)
+}
+
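mask.where_cond(&on_true, on_false) selects on_true wherever the mask is non-zero and keeps on_false elsewhere, so the helper overwrites exactly the flagged (future) entries. A small sketch combining the two helpers (demo_masked_fill and the all-ones scores are invented for illustration):

    fn demo_masked_fill() -> CandleResult<()> {
        let device = Device::Cpu;
        let scores = Tensor::ones((3, 3), DType::F32, &device)?;
        let mask = get_mask(3, &device)?;
        let hidden = masked_fill(&scores, &mask, f32::NEG_INFINITY)?;
        // Row 0 is now [1, -inf, -inf]; softmax will zero the -inf entries.
        println!("{hidden}");
        Ok(())
    }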
 impl SelfAttention {
-    fn new(d: usize, vb: VarBuilder) -> Result<Self> {
+    fn new(d: usize, masked: bool, vb: VarBuilder) -> Result<Self> {
         let d = d;
 
         let w_q = linear(d, d, vb.pp("w_q"))?;
         let w_k = linear(d, d, vb.pp("w_k"))?;
         let w_v = linear(d, d, vb.pp("w_v"))?;
 
-        Ok(Self { d, w_q, w_k, w_v })
+        Ok(Self {
+            d,
+            masked,
+            w_q,
+            w_k,
+            w_v,
+        })
     }
 }
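Aside from threading masked through, the constructor is unchanged: each candle_nn::linear call registers a d-by-d weight plus a bias under its vb.pp prefix, so a freshly built module owns six trainable tensors. A quick way to see that, assuming candle_nn's VarMap::all_vars (demo_params is hypothetical):

    fn demo_params() -> Result<()> {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
        let _attn = SelfAttention::new(2, true, vb)?;
        assert_eq!(varmap.all_vars().len(), 6); // w_q, w_k, w_v: weight + bias each
        Ok(())
    }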
@@ -30,10 +51,14 @@ impl Module for SelfAttention {
         let k = self.w_k.forward(x)?;
         let v = self.w_v.forward(x)?;
 
-        let qk = q.matmul(&k.transpose(1, 0)?)?;
-        let scale = Tensor::new((self.d as f32).sqrt(), qk.device())?;
-        let qk = qk.broadcast_div(&scale)?;
-        let attn_pct = softmax(&qk, 1)?;
+        let sims = q.matmul(&k.transpose(1, 0)?)?;
+        let scale = Tensor::new((self.d as f32).sqrt(), sims.device())?;
+        let mut scaled_sims = sims.broadcast_div(&scale)?;
+        if self.masked {
+            let mask = get_mask(scaled_sims.dims()[0], scaled_sims.device())?;
+            scaled_sims = masked_fill(&scaled_sims, &mask, f32::NEG_INFINITY)?;
+        }
+        let attn_pct = softmax(&scaled_sims, 1)?;
         let attn_scores = attn_pct.matmul(&v)?;
         Ok(attn_scores)
     }
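The forward pass is standard scaled dot-product attention, softmax(Q·K^T / sqrt(d))·V, with the causal mask applied after scaling and before the row-wise softmax. NEG_INFINITY is the right fill value because softmax maps it to exactly zero weight. A minimal check, assuming the file's softmax import from candle_nn (demo_softmax_mask is hypothetical):

    fn demo_softmax_mask() -> CandleResult<()> {
        let device = Device::Cpu;
        let row = Tensor::new(&[[2.0f32, f32::NEG_INFINITY, f32::NEG_INFINITY]], &device)?;
        let w = softmax(&row, 1)?;
        // All weight collapses onto the single unmasked position.
        assert_eq!(w.to_vec2::<f32>()?, vec![vec![1.0, 0.0, 0.0]]);
        Ok(())
    }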
@@ -42,7 +67,7 @@ fn main() -> Result<()> {
     let device = Device::cuda_if_available(0)?;
     let varmap = VarMap::new();
     let vs = VarBuilder::from_varmap(&varmap, DType::F32, &device);
-    let self_attn = SelfAttention::new(2, vs)?;
+    let self_attn = SelfAttention::new(2, true, vs)?;
 
     let encoding_matrix = Tensor::new(
         vec![
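The hunk is cut off at this point; the remaining lines populate the encoding matrix and run the forward pass. For completeness, a hypothetical tail of main under the same setup (the token values below are invented, not the repo's):

    let encoding_matrix = Tensor::new(
        &[[1.0f32, 0.5], [0.3, 0.8], [2.0, -1.0]], // three tokens, d = 2 (made up)
        &device,
    )?;
    let attn_scores = self_attn.forward(&encoding_matrix)?;
    println!("{attn_scores}");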