mirror of
https://github.com/vishpat/candle-coursera-ml.git
synced 2025-12-22 22:19:58 +00:00
Checkpoint
This commit is contained in:
@@ -33,18 +33,15 @@ impl SelfAttention {
|
||||
|
||||
impl SelfAttention {
|
||||
fn attention(&self, x: &Tensor) -> Result<Tensor> {
|
||||
println!("x: {:?}", x);
|
||||
let q = self.w_q.forward(x)?;
|
||||
let k = self.w_k.forward(x)?;
|
||||
let v = self.w_v.forward(x)?;
|
||||
|
||||
println!("q: {:?}, k: {:?}", q, k);
|
||||
let qk = q.matmul(&k.transpose(1, 0)?)?;
|
||||
|
||||
let qk = qk.broadcast_div(&self.scale)?;
|
||||
let qk = softmax(&qk, 1)?;
|
||||
|
||||
println!("qk: {:?}, v: {:?}", qk, v);
|
||||
Ok(qk.matmul(&v)?)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user