Best loss

Vishal Patil
2025-01-16 22:05:33 -05:00
parent 89514173a6
commit d4434378a9


@@ -1,16 +1,19 @@
 use anyhow::Result;
-use candle_core::{DType, Device, Tensor, D};
+use candle_core::{DType, Device, Tensor};
 use candle_nn::{loss, Linear, Module, Optimizer, VarBuilder, VarMap};
 use clap::Parser;
+use rand::prelude::*;
 use rand::Rng;
 use std::f64::consts::PI;
 use std::rc::Rc;
 const INPUT_COUNT: usize = 3;
-const LAYER1_COUNT: usize = 100;
-const LAYER2_COUNT: usize = 50;
+const LAYER1_COUNT: usize = 50;
+const LAYER2_COUNT: usize = 10;
 const OUTPUT_COUNT: usize = 1;
+const BATCH_SIZE: usize = 1000;
 struct Dataset {
     pub training_data: Tensor,
     pub training_labels: Tensor,
@@ -45,11 +48,11 @@ fn load_tensors(samples: u32, device: &Device) -> Result<(Tensor, Tensor)> {
     let x_values = x_values.into_iter().flatten().collect::<Vec<f32>>();
     let x_values = x_values.as_slice();
-    let x_tensor = Tensor::from_slice(x_values, &[samples as usize, 3], &device)?;
+    let x_tensor = Tensor::from_slice(x_values, &[samples as usize, 3], device)?;
     let y_values = y_values.into_iter().collect::<Vec<f32>>();
     let y_values = y_values.as_slice();
-    let y_tensor = Tensor::from_slice(y_values, &[samples as usize], &device)?;
+    let y_tensor = Tensor::from_slice(y_values, &[samples as usize], device)?;
     Ok((x_tensor, y_tensor))
 }
@@ -125,11 +128,24 @@ fn main() -> Result<()> {
         .to_dtype(DType::F32)?
         .to_device(&device)?;
-    for epoch in 1..args.epochs {
-        let logits = model.forward(&dataset.training_data)?;
-        let loss = loss::mse(&logits.squeeze(1)?, &dataset.training_labels)?;
-        sgd.backward_step(&loss)?;
+    let (training_size, _) = dataset.training_data.shape().dims2()?;
+    let n_batches = training_size / BATCH_SIZE;
+    let mut batch_idxs = (0..n_batches).collect::<Vec<usize>>();
+    for epoch in 1..args.epochs {
+        batch_idxs.shuffle(&mut rand::thread_rng());
+        for batch_idx in batch_idxs.iter() {
+            let train_data = dataset
+                .training_data
+                .narrow(0, batch_idx * BATCH_SIZE, BATCH_SIZE)?;
+            let train_labels =
+                dataset
+                    .training_labels
+                    .narrow(0, batch_idx * BATCH_SIZE, BATCH_SIZE)?;
+            let logits = model.forward(&train_data)?;
+            let loss = loss::mse(&logits.squeeze(1)?, &train_labels)?;
+            sgd.backward_step(&loss)?;
+        }
         let test_logits = model.forward(&test_images)?;
         let test_loss = loss::mse(&test_logits.squeeze(1)?, &test_labels)?;
         if args.progress && epoch % 100 == 0 {
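
The hunk above replaces the single full-batch update per epoch with shuffled mini-batch SGD: the training tensors are sliced into fixed-size batches with narrow, the batch order is reshuffled every epoch, and one optimizer step is taken per batch. The sketch below pulls that loop out into a standalone helper so the structure is easier to read. It is a minimal sketch, not part of the commit: the helper name train_one_epoch is hypothetical, and it assumes the model and optimizer can be used through the candle_nn Module and Optimizer traits (the hunk itself calls model.forward and sgd.backward_step on concrete types).

use anyhow::Result;
use candle_core::Tensor;
use candle_nn::{loss, Module, Optimizer};
use rand::prelude::*;

// Hypothetical helper mirroring the shuffled mini-batch loop added in this commit.
fn train_one_epoch<M: Module, O: Optimizer>(
    model: &M,
    sgd: &mut O,
    training_data: &Tensor,   // shape [samples, INPUT_COUNT]
    training_labels: &Tensor, // shape [samples]
    batch_size: usize,
) -> Result<()> {
    let (training_size, _) = training_data.shape().dims2()?;
    // Any trailing partial batch is dropped, as in the commit.
    let n_batches = training_size / batch_size;
    let mut batch_idxs = (0..n_batches).collect::<Vec<usize>>();
    // Visit the batches in a fresh random order each epoch.
    batch_idxs.shuffle(&mut rand::thread_rng());
    for batch_idx in batch_idxs {
        // Slice a contiguous batch out of the training tensors along dim 0.
        let data = training_data.narrow(0, batch_idx * batch_size, batch_size)?;
        let labels = training_labels.narrow(0, batch_idx * batch_size, batch_size)?;
        let logits = model.forward(&data)?;
        let loss = loss::mse(&logits.squeeze(1)?, &labels)?;
        // Backprop and apply one SGD update per batch.
        sgd.backward_step(&loss)?;
    }
    Ok(())
}

Compared with the previous full-batch version, this performs n_batches parameter updates per epoch instead of one, which is what the commit relies on to reach a lower loss with the smaller LAYER1_COUNT/LAYER2_COUNT network.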