diff --git a/self-intention/Cargo.toml b/self-intention/Cargo.toml
new file mode 100644
index 0000000..5141677
--- /dev/null
+++ b/self-intention/Cargo.toml
@@ -0,0 +1,14 @@
+[package]
+name = "self-intention"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+anyhow = "1.0.40"
+csv = "1.1.6"
+clap = { version = "4.5.1", features = ["derive"] }
+rand = "0.8.5"
+candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.8.2", features = [
+    "cuda",
+] }
+candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.8.2" }
diff --git a/self-intention/src/main.rs b/self-intention/src/main.rs
new file mode 100644
index 0000000..5db4984
--- /dev/null
+++ b/self-intention/src/main.rs
@@ -0,0 +1,46 @@
+use anyhow::Result;
+use candle_core::{Device, Tensor, D};
+use candle_nn::ops::softmax;
+use candle_nn::{linear, Linear, Module, VarBuilder, VarMap};
+
+struct SelfAttention {
+    d: usize, // Embedding size
+    w_q: Linear,
+    w_k: Linear,
+    w_v: Linear,
+}
+
+impl SelfAttention {
+    fn new(d: usize, vb: VarBuilder) -> Result<Self> {
+        // Each projection maps the embedding dimension d onto itself.
+        let w_q = linear(d, d, vb.pp("w_q"))?;
+        let w_k = linear(d, d, vb.pp("w_k"))?;
+        let w_v = linear(d, d, vb.pp("w_v"))?;
+
+        Ok(Self { d, w_q, w_k, w_v })
+    }
+}
+
+impl Module for SelfAttention {
+    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
+        // Project the input into queries, keys and values.
+        let q = self.w_q.forward(x)?;
+        let k = self.w_k.forward(x)?;
+        let v = self.w_v.forward(x)?;
+
+        // Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.
+        let qk = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
+        let qk = (qk / (self.d as f64).sqrt())?;
+        let qk = softmax(&qk, D::Minus1)?;
+
+        qk.matmul(&v)
+    }
+}
+
+fn main() -> Result<()> {
+    let device = Device::cuda_if_available(0)?;
+    // Trainable parameters will be registered in this VarMap.
+    let varmap = VarMap::new();
+
+    Ok(())
+}
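A minimal usage sketch (not part of the patch): assuming it is placed next to the SelfAttention definition in src/main.rs, the snippet below shows one way to drive the module: build a VarBuilder over the VarMap so the linear projections get named parameters, then run a forward pass on a toy input. The demo function name, the embedding size 4, and the random 3x4 input are illustrative assumptions.

    use anyhow::Result;
    use candle_core::{DType, Device, Tensor};
    use candle_nn::{Module, VarBuilder, VarMap};

    fn demo() -> Result<()> {
        let device = Device::cuda_if_available(0)?;
        let varmap = VarMap::new();
        // The VarBuilder hands out named slices of the VarMap to the layers.
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        // Hypothetical sizes: a sequence of 3 tokens with embedding size 4.
        let attn = SelfAttention::new(4, vb)?;
        let x = Tensor::randn(0f32, 1f32, (3, 4), &device)?;
        let y = attn.forward(&x)?;
        println!("{y}");

        Ok(())
    }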