From 769c864c2263e2f31f699a7da1fb1bd199d38a4e Mon Sep 17 00:00:00 2001 From: Vishal Patil Date: Thu, 13 Feb 2025 19:06:14 -0500 Subject: [PATCH] Checkpoint --- neural-networks-101/src/main.rs | 38 +++++++++++---------------------- 1 file changed, 13 insertions(+), 25 deletions(-) diff --git a/neural-networks-101/src/main.rs b/neural-networks-101/src/main.rs index 78e0d10..59fab2e 100644 --- a/neural-networks-101/src/main.rs +++ b/neural-networks-101/src/main.rs @@ -70,30 +70,18 @@ fn load_dataset(device: &Device) -> Result { }) } -fn r_square(y_true: &Tensor, y_pred: &Tensor, device: &Device) -> Result { - let samples = y_true.shape().dims1()?; - let pred_samples = y_pred.shape().dims1()?; - if samples != pred_samples { - return Err(anyhow::anyhow!( - "y_true and y_pred must have the same number of samples" - )); - } - let y_mean = y_true.mean(0).unwrap(); - let ss_tot = y_true - .broadcast_sub(&y_mean) - .unwrap() - .sqr() - .unwrap() - .sum(D::Minus1) - .unwrap(); - let ss_res = y_true - .sub(&y_pred) - .unwrap() - .sqr() - .unwrap() - .sum(D::Minus1) - .unwrap(); - Ok(Tensor::new(1.0 as f32, &device)?.sub(&ss_res.broadcast_div(&ss_tot)?)?) +fn r_square(labels: &Tensor, predictions: &Tensor) -> Result { + let mean = labels.mean(D::Minus1)?; + + let ssr = labels.sub(predictions)?; + let ssr = ssr.mul(&ssr)?.sum(D::Minus1)?; + + let sst = labels.broadcast_sub(&mean)?; + let sst = sst.mul(&sst)?.sum(D::Minus1)?; + + let tmp = ssr.div(&sst)?.to_scalar::()?; + + Ok(1.0 - tmp) } struct Mlp { @@ -177,7 +165,7 @@ fn main() -> Result<()> { if args.progress && epoch % 100 == 0 { println!( "{epoch:4} test r2: {:?}", - r_square(&test_values, &test_logits.squeeze(1)?, &device)?.to_scalar::()? + r_square(&test_values, &test_logits.squeeze(1)?)? ); } }