From b86ef2201f626af040dc292a097e122ece86ec04 Mon Sep 17 00:00:00 2001
From: Vishal Patil
Date: Fri, 14 Feb 2025 10:16:06 -0500
Subject: [PATCH] Able to compile

---
 self-intention/src/main.rs | 45 ++++++++++++++++++++++++--------------
 1 file changed, 29 insertions(+), 16 deletions(-)

diff --git a/self-intention/src/main.rs b/self-intention/src/main.rs
index 5db4984..a092ec0 100644
--- a/self-intention/src/main.rs
+++ b/self-intention/src/main.rs
@@ -1,43 +1,56 @@
 use anyhow::Ok;
 use anyhow::Result;
+use candle_core::Var;
 use candle_core::{DType, Device, Tensor, D};
 use candle_nn::ops::softmax;
-use candle_nn::{loss, Linear, Module, Optimizer, Sequential, VarMap, SGD};
+use candle_nn::VarBuilder;
+use candle_nn::{linear, Linear, Module, Optimizer, Sequential, VarMap, SGD};
 
 struct SelfAttention {
-    d: u32, // Embedding size
+    d: usize, // Embedding size
+    scale: Tensor,
     w_q: Linear,
     w_k: Linear,
     w_v: Linear,
 }
 
 impl SelfAttention {
-    fn new(d: u32) -> Self {
+    fn new(d: usize, vb: VarBuilder) -> Result<Self> {
         let d = d;
-        let w_q = Linear::new(d, d);
-        let w_k = Linear::new(d, d);
-        let w_v = Linear::new(d, d);
+        let scale = Tensor::new((d as f32).sqrt(), vb.device())?;
+        let w_q = linear(d, d, vb.pp("w_q"))?;
+        let w_k = linear(d, d, vb.pp("w_k"))?;
+        let w_v = linear(d, d, vb.pp("w_v"))?;
 
-        Self { d, w_q, w_k, w_v }
+        Ok(Self {
+            d,
+            scale,
+            w_q,
+            w_k,
+            w_v,
+        })
     }
 }
 
-impl Module for SelfAttention {
-    fn forward(&self, x: &Tensor) -> Tensor {
-        let q = self.w_q.forward(x);
-        let k = self.w_k.forward(x);
-        let v = self.w_v.forward(x);
+impl SelfAttention {
+    fn attention(&self, x: &Tensor) -> Result<Tensor> {
+        let q = self.w_q.forward(x)?;
+        let k = self.w_k.forward(x)?;
+        let v = self.w_v.forward(x)?;
 
-        let qk = q.matmul(&k.transpose(1, 0));
-        let qk = qk / D::from_f32((self.d as f32).sqrt());
-        let qk = softmax(&qk, 1);
+        let qk = q.matmul(&k.transpose(1, 0)?)?;
 
-        qk.matmul(&v)
+        let qk = qk.broadcast_div(&self.scale)?;
+        let qk = softmax(&qk, 1)?;
+
+        Ok(qk.matmul(&v)?)
     }
 }
 
 fn main() -> Result<()> {
     let device = Device::cuda_if_available(0)?;
     let varmap = VarMap::new();
+    let vs = VarBuilder::from_varmap(&varmap, DType::F32, &device);
+    let self_attn = SelfAttention::new(4, vs)?;
     Ok(())
 }
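
Note (not part of the commit above): a minimal usage sketch of the new SelfAttention block, assuming it replaces the existing main in self-intention/src/main.rs so the struct is in scope. The (3, 4) input shape and the use of Tensor::rand are illustrative assumptions, not something the patch itself does; the imports are repeated here only to keep the sketch self-contained.

    use anyhow::Result;
    use candle_core::{DType, Device, Tensor};
    use candle_nn::{VarBuilder, VarMap};

    fn main() -> Result<()> {
        let device = Device::cuda_if_available(0)?;
        let varmap = VarMap::new();
        let vs = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        // Build the attention block with embedding size d = 4, as in the patch.
        let self_attn = SelfAttention::new(4, vs)?;

        // Hypothetical dummy input of shape (seq_len, d) = (3, 4).
        let x = Tensor::rand(0f32, 1f32, (3, 4), &device)?;

        // attention() keeps the input shape, so out is also (3, 4).
        let out = self_attn.attention(&x)?;
        println!("attention output shape: {:?}", out.shape());
        Ok(())
    }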