diff --git a/self-attention/src/main.rs b/self-attention/src/main.rs
index c1f4690..b4ff465 100644
--- a/self-attention/src/main.rs
+++ b/self-attention/src/main.rs
@@ -1,5 +1,5 @@
-use anyhow::Ok;
 use anyhow::Result;
+use candle_core::Result as CandleResult;
 use candle_core::{DType, Device, Tensor};
 use candle_nn::ops::softmax;
 use candle_nn::VarBuilder;
@@ -31,8 +31,8 @@ impl SelfAttention {
     }
 }
 
-impl SelfAttention {
-    fn attention(&self, x: &Tensor) -> Result<Tensor> {
+impl Module for SelfAttention {
+    fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
         let q = self.w_q.forward(x)?;
         let k = self.w_k.forward(x)?;
         let v = self.w_v.forward(x)?;
@@ -40,9 +40,9 @@ impl SelfAttention {
         let qk = q.matmul(&k.transpose(1, 0)?)?;
         let qk = qk.broadcast_div(&self.scale)?;
-        let qk = softmax(&qk, 1)?;
-
-        Ok(qk.matmul(&v)?)
+        let attn_pct = softmax(&qk, 1)?;
+        let attn_scores = attn_pct.matmul(&v)?;
+        Ok(attn_scores)
     }
 }
 
 fn main() -> Result<()> {
@@ -60,7 +60,7 @@ fn main() -> Result<()> {
         &device,
     )?;
 
-    let attn = self_attn.attention(&encoding_matrix)?;
+    let attn = self_attn.forward(&encoding_matrix)?;
     println!("{}", attn);
     Ok(())
 }
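
For reference, the scaled dot-product attention that `forward` now implements can be exercised in isolation. The sketch below is not part of the diff: it recomputes the same formula, softmax(Q·Kᵀ/√d_k)·V, on a random encoding matrix, with identity projections standing in for the learned `w_q`, `w_k`, and `w_v` layers (the 3×4 shape and the scale value are illustrative assumptions):

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::ops::softmax;

fn main() -> Result<()> {
    let device = Device::Cpu;

    // 3 tokens with 4-dimensional encodings (illustrative sizes).
    let x = Tensor::randn(0f32, 1f32, (3, 4), &device)?;

    // Identity projections stand in for the learned w_q/w_k/w_v layers,
    // so Q = K = V = x for this sketch.
    let (q, k, v) = (&x, &x, &x);

    // scale = sqrt(d_k), kept as a scalar tensor like `self.scale` in the diff.
    let scale = Tensor::new((4f32).sqrt(), &device)?;

    // softmax(Q · Kᵀ / sqrt(d_k)) · V
    let qk = q.matmul(&k.transpose(1, 0)?)?.broadcast_div(&scale)?;
    let weights = softmax(&qk, 1)?; // dim 1: each query row sums to 1
    let out = weights.matmul(v)?;   // (3, 3) · (3, 4) -> (3, 4)

    println!("{}", out);
    Ok(())
}
```

Softmax is taken over dimension 1 here for the same reason as in the diff: each row holds one query position's scores against every key, so normalizing along that row turns the scores into attention weights that sum to 1.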