Able to compile

Vishal Patil
2025-02-14 10:16:06 -05:00
parent 9375ad6455
commit b86ef2201f

@@ -1,43 +1,56 @@
 use anyhow::Ok;
 use anyhow::Result;
 use candle_core::Var;
 use candle_core::{DType, Device, Tensor, D};
 use candle_nn::ops::softmax;
-use candle_nn::{loss, Linear, Module, Optimizer, Sequential, VarMap, SGD};
+use candle_nn::VarBuilder;
+use candle_nn::{linear, Linear, Module, Optimizer, Sequential, VarMap, SGD};
 
 struct SelfAttention {
-    d: u32, // Embedding size
+    d: usize, // Embedding size
+    scale: Tensor,
     w_q: Linear,
     w_k: Linear,
     w_v: Linear,
 }
 
 impl SelfAttention {
-    fn new(d: u32) -> Self {
-        let d = d;
-        let w_q = Linear::new(d, d);
-        let w_k = Linear::new(d, d);
-        let w_v = Linear::new(d, d);
-        Self { d, w_q, w_k, w_v }
+    fn new(d: usize, vb: VarBuilder) -> Result<Self> {
+        let scale = Tensor::new((d as f32).sqrt(), vb.device())?;
+        let w_q = linear(d, d, vb.pp("w_q"))?;
+        let w_k = linear(d, d, vb.pp("w_k"))?;
+        let w_v = linear(d, d, vb.pp("w_v"))?;
+        Ok(Self {
+            d,
+            scale,
+            w_q,
+            w_k,
+            w_v,
+        })
     }
 }
 
-impl Module for SelfAttention {
-    fn forward(&self, x: &Tensor) -> Tensor {
-        let q = self.w_q.forward(x);
-        let k = self.w_k.forward(x);
-        let v = self.w_v.forward(x);
-        let qk = q.matmul(&k.transpose(1, 0));
-        let qk = qk / D::from_f32((self.d as f32).sqrt());
-        let qk = softmax(&qk, 1);
-        qk.matmul(&v)
+impl SelfAttention {
+    fn attention(&self, x: &Tensor) -> Result<Tensor> {
+        let q = self.w_q.forward(x)?;
+        let k = self.w_k.forward(x)?;
+        let v = self.w_v.forward(x)?;
+        let qk = q.matmul(&k.transpose(1, 0)?)?;
+        let qk = qk.broadcast_div(&self.scale)?;
+        let qk = softmax(&qk, 1)?;
+        Ok(qk.matmul(&v)?)
     }
 }
+
+fn main() -> Result<()> {
+    let device = Device::cuda_if_available(0)?;
+    let varmap = VarMap::new();
+    let vs = VarBuilder::from_varmap(&varmap, DType::F32, &device);
+    let self_attn = SelfAttention::new(4, vs)?;
+    Ok(())
+}
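
For readers skimming the diff, a minimal sketch of exercising the new API end to end. The run_attention_demo function and the toy (3, 4) input are hypothetical additions, not part of this commit; the candle calls (Tensor::rand, dims) are the standard ones and mirror the setup in main() above.

// Hypothetical demo (not part of this commit): push a toy input through
// the new SelfAttention and check the output shape.
fn run_attention_demo() -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    let varmap = VarMap::new();
    let vs = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let self_attn = SelfAttention::new(4, vs)?;

    // Three "tokens", each with embedding size d = 4.
    let x = Tensor::rand(0f32, 1f32, (3, 4), &device)?;
    let out = self_attn.attention(&x)?;

    // softmax(q k^T / sqrt(d)) v preserves the input shape.
    assert_eq!(out.dims(), &[3, 4]);
    Ok(())
}

Note the design choice in the commit: scale is precomputed once as a 0-dim tensor on the construction device, so attention can use broadcast_div instead of the old scalar division and never re-allocates the scalar per call.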