mirror of https://github.com/vishpat/candle-coursera-ml.git
synced 2025-12-22 22:19:58 +00:00
Added r-square (coefficient of determination) metric
@@ -1,5 +1,6 @@
+use anyhow::Ok;
 use anyhow::Result;
-use candle_core::{DType, Device, Tensor};
+use candle_core::{DType, Device, Tensor, D};
 use candle_nn::{loss, Linear, Module, Optimizer, VarBuilder, VarMap};
 use clap::Parser;
 use rand::prelude::*;
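The newly imported `D` provides `D::Minus1`, which the `r_square` function added below uses to sum over the last tensor dimension; `anyhow::Ok` pins the error type of `Ok(...)` expressions to `anyhow::Error`.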
@@ -9,16 +10,16 @@ use std::rc::Rc;
 
 const INPUT_COUNT: usize = 3;
 const LAYER1_COUNT: usize = 50;
-const LAYER2_COUNT: usize = 10;
+const LAYER2_COUNT: usize = 50;
 const OUTPUT_COUNT: usize = 1;
 
 const BATCH_SIZE: usize = 1000;
 
 struct Dataset {
     pub training_data: Tensor,
-    pub training_labels: Tensor,
+    pub training_values: Tensor,
     pub test_data: Tensor,
-    pub test_labels: Tensor,
+    pub test_values: Tensor,
 }
 
 fn func(x1: f32, x2: f32, x3: f32) -> f32 {
@@ -58,17 +59,43 @@ fn load_tensors(samples: u32, device: &Device) -> Result<(Tensor, Tensor)> {
 }
 
 fn load_dataset(device: &Device) -> Result<Dataset> {
-    let (training_data, training_labels) = load_tensors(5000, device)?;
-    let (test_data, test_labels) = load_tensors(2000, device)?;
+    let (training_data, training_values) = load_tensors(5000, device)?;
+    let (test_data, test_values) = load_tensors(2000, device)?;
 
     Ok(Dataset {
         training_data,
-        training_labels,
+        training_values,
         test_data,
-        test_labels,
+        test_values,
     })
 }
 
+fn r_square(y_true: &Tensor, y_pred: &Tensor, device: &Device) -> Result<Tensor> {
+    let samples = y_true.shape().dims1()?;
+    let pred_samples = y_pred.shape().dims1()?;
+    if samples != pred_samples {
+        return Err(anyhow::anyhow!(
+            "y_true and y_pred must have the same number of samples"
+        ));
+    }
+    let y_mean = y_true.mean(0).unwrap();
+    let ss_tot = y_true
+        .broadcast_sub(&y_mean)
+        .unwrap()
+        .sqr()
+        .unwrap()
+        .sum(D::Minus1)
+        .unwrap();
+    let ss_res = y_true
+        .sub(&y_pred)
+        .unwrap()
+        .sqr()
+        .unwrap()
+        .sum(D::Minus1)
+        .unwrap();
+    Ok(Tensor::new(1.0 as f32, &device)?.sub(&ss_res.broadcast_div(&ss_tot)?)?)
+}
+
 struct Mlp {
     ln1: Linear,
     ln2: Linear,
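For reference, `r_square` computes the coefficient of determination in its standard form, with the tensor operations mapping one-to-one onto these sums:

R^2 = 1 - \frac{SS_{\mathrm{res}}}{SS_{\mathrm{tot}}},
\qquad
SS_{\mathrm{res}} = \sum_{i=1}^{n} \left( y_i - \hat{y}_i \right)^2,
\qquad
SS_{\mathrm{tot}} = \sum_{i=1}^{n} \left( y_i - \bar{y} \right)^2

`ss_res` is the residual sum of squares against the predictions and `ss_tot` the total sum of squares around the mean of `y_true`, so a perfect fit scores 1.0, a constant mean predictor scores 0.0, and a fit worse than the mean goes negative.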
@@ -122,9 +149,9 @@ fn main() -> Result<()> {
     let model = Mlp::new(vs)?;
     let mut sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate)?;
 
-    let test_images = dataset.test_data.to_device(&device)?;
-    let test_labels = dataset
-        .test_labels
+    let test_data = dataset.test_data.to_device(&device)?;
+    let test_values = dataset
+        .test_values
         .to_dtype(DType::F32)?
         .to_device(&device)?;
 
@@ -138,18 +165,20 @@ fn main() -> Result<()> {
             let train_data = dataset
                 .training_data
                 .narrow(0, batch_idx * BATCH_SIZE, BATCH_SIZE)?;
-            let train_labels =
+            let train_values =
                 dataset
-                    .training_labels
+                    .training_values
                     .narrow(0, batch_idx * BATCH_SIZE, BATCH_SIZE)?;
             let logits = model.forward(&train_data)?;
-            let loss = loss::mse(&logits.squeeze(1)?, &train_labels)?;
+            let loss = loss::mse(&logits.squeeze(1)?, &train_values)?;
             sgd.backward_step(&loss)?;
         }
-        let test_logits = model.forward(&test_images)?;
-        let test_loss = loss::mse(&test_logits.squeeze(1)?, &test_labels)?;
+        let test_logits = model.forward(&test_data)?;
         if args.progress && epoch % 100 == 0 {
-            println!("{epoch:4} test loss: {:8.5}", test_loss.to_scalar::<f32>()?);
+            println!(
+                "{epoch:4} test r2: {:?}",
+                r_square(&test_values, &test_logits.squeeze(1)?, &device)?.to_scalar::<f32>()?
+            );
         }
     }
 
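A quick way to sanity-check the metric outside the training loop is a small standalone program covering the two boundary cases. This is a minimal sketch, not part of the commit: it assumes `candle-core` and `anyhow` as dependencies and inlines a lightly tidied copy of `r_square` (with `?` in place of the committed `.unwrap()` calls).

use anyhow::Result;
use candle_core::{Device, Tensor, D};

// Copy of r_square as added in this commit, condensed to use `?`.
fn r_square(y_true: &Tensor, y_pred: &Tensor, device: &Device) -> Result<Tensor> {
    let samples = y_true.shape().dims1()?;
    let pred_samples = y_pred.shape().dims1()?;
    if samples != pred_samples {
        return Err(anyhow::anyhow!(
            "y_true and y_pred must have the same number of samples"
        ));
    }
    // SS_tot: squared deviations of y_true from its mean.
    let y_mean = y_true.mean(0)?;
    let ss_tot = y_true.broadcast_sub(&y_mean)?.sqr()?.sum(D::Minus1)?;
    // SS_res: squared residuals between y_true and y_pred.
    let ss_res = y_true.sub(y_pred)?.sqr()?.sum(D::Minus1)?;
    // R^2 = 1 - SS_res / SS_tot.
    Ok(Tensor::new(1.0f32, device)?.sub(&ss_res.broadcast_div(&ss_tot)?)?)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let y_true = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device)?;

    // Perfect predictions: SS_res = 0, so R^2 should be exactly 1.0.
    let r2 = r_square(&y_true, &y_true, &device)?.to_scalar::<f32>()?;
    println!("perfect fit: r2 = {r2}");

    // Always predicting the mean (2.5): SS_res = SS_tot, so R^2 should be 0.0.
    let y_pred = Tensor::new(&[2.5f32, 2.5, 2.5, 2.5], &device)?;
    let r2 = r_square(&y_true, &y_pred, &device)?.to_scalar::<f32>()?;
    println!("mean-only predictor: r2 = {r2}");
    Ok(())
}

Run against the committed version, the results should be identical; the `?` form merely propagates candle errors instead of panicking.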