Checkpoint

Vishal Patil
2025-02-14 09:24:06 -05:00
parent 769c864c22
commit 9375ad6455
2 changed files with 57 additions and 0 deletions

self-intention/Cargo.toml Normal file

@@ -0,0 +1,14 @@
[package]
name = "self-intention"
version = "0.1.0"
edition = "2021"

[dependencies]
anyhow = "1.0.40"
csv = "1.1.6"
clap = { version = "4.5.1", features = ["derive"] }
rand = "0.8.5"
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.8.2", features = [
"cuda",
] }
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.8.2" }
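
A note on the dependency pins: the "cuda" feature on candle-core requires a local CUDA toolkit at build time. As a sketch, a CPU-only variant of the pin (hypothetical, for machines without a GPU) would simply drop the feature list:

candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.8.2" }

Device::cuda_if_available in main.rs already falls back to the CPU at runtime when no GPU is present, so only the build-time requirement changes.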

self-intention/src/main.rs Normal file

@@ -0,0 +1,43 @@
use anyhow::Result;
use candle_core::{DType, Device, Tensor, D};
use candle_nn::ops::softmax;
use candle_nn::{linear, Linear, Module, VarBuilder, VarMap};

struct SelfAttention {
    d: usize, // Embedding size
    w_q: Linear,
    w_k: Linear,
    w_v: Linear,
}

impl SelfAttention {
    fn new(d: usize, vb: VarBuilder) -> Result<Self> {
        // Each projection maps the embedding dimension back onto itself (d x d).
        let w_q = linear(d, d, vb.pp("w_q"))?;
        let w_k = linear(d, d, vb.pp("w_k"))?;
        let w_v = linear(d, d, vb.pp("w_v"))?;
        Ok(Self { d, w_q, w_k, w_v })
    }
}

impl Module for SelfAttention {
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        let q = self.w_q.forward(x)?;
        let k = self.w_k.forward(x)?;
        let v = self.w_v.forward(x)?;
        // Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.
        let qk = q.matmul(&k.t()?)?;
        let qk = (qk / (self.d as f64).sqrt())?;
        let qk = softmax(&qk, D::Minus1)?;
        qk.matmul(&v)
    }
}

fn main() -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    // Smoke test: a 4-token sequence with embedding size 8.
    let attn = SelfAttention::new(8, vb)?;
    let x = Tensor::randn(0f32, 1f32, (4, 8), &device)?;
    let y = attn.forward(&x)?;
    println!("{y}");
    Ok(())
}
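
For reference, the forward pass above implements standard scaled dot-product attention over a single (seq_len, d) input:

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left( \frac{Q K^{\top}}{\sqrt{d}} \right) V

The 1/\sqrt{d} scaling keeps the dot products from growing with the embedding size, which would otherwise saturate the softmax and shrink its gradients.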