mirror of https://github.com/vishpat/candle-coursera-ml.git
synced 2025-12-22 22:19:58 +00:00
Best loss
@@ -1,16 +1,19 @@
 use anyhow::Result;
-use candle_core::{DType, Device, Tensor, D};
+use candle_core::{DType, Device, Tensor};
 use candle_nn::{loss, Linear, Module, Optimizer, VarBuilder, VarMap};
 use clap::Parser;
+use rand::prelude::*;
 use rand::Rng;
 use std::f64::consts::PI;
 use std::rc::Rc;
 
 const INPUT_COUNT: usize = 3;
-const LAYER1_COUNT: usize = 100;
-const LAYER2_COUNT: usize = 50;
+const LAYER1_COUNT: usize = 50;
+const LAYER2_COUNT: usize = 10;
 const OUTPUT_COUNT: usize = 1;
 
+const BATCH_SIZE: usize = 1000;
+
 struct Dataset {
     pub training_data: Tensor,
     pub training_labels: Tensor,
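
This hunk shrinks the network from 3 -> 100 -> 50 -> 1 to 3 -> 50 -> 10 -> 1 and introduces a fixed mini-batch size. The model definition itself lies outside the diff; purely as an illustration of the shape these constants imply, a two-hidden-layer regression MLP in candle-nn might look like the sketch below. The `Mlp` struct, its field names, and the ReLU activations are assumptions, not code from this repo.

    use candle_core::{Result, Tensor};
    use candle_nn::{Linear, Module, VarBuilder};

    // Hypothetical model matching the constants above: 3 inputs, two hidden
    // layers of 50 and 10 units, and a single regression output.
    struct Mlp {
        ln1: Linear,
        ln2: Linear,
        ln3: Linear,
    }

    impl Mlp {
        fn new(vs: VarBuilder) -> Result<Self> {
            Ok(Self {
                ln1: candle_nn::linear(INPUT_COUNT, LAYER1_COUNT, vs.pp("ln1"))?,
                ln2: candle_nn::linear(LAYER1_COUNT, LAYER2_COUNT, vs.pp("ln2"))?,
                ln3: candle_nn::linear(LAYER2_COUNT, OUTPUT_COUNT, vs.pp("ln3"))?,
            })
        }
    }

    impl Module for Mlp {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            let xs = self.ln1.forward(xs)?.relu()?;
            let xs = self.ln2.forward(&xs)?.relu()?;
            self.ln3.forward(&xs)
        }
    }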
@@ -45,11 +48,11 @@ fn load_tensors(samples: u32, device: &Device) -> Result<(Tensor, Tensor)> {
 
     let x_values = x_values.into_iter().flatten().collect::<Vec<f32>>();
     let x_values = x_values.as_slice();
-    let x_tensor = Tensor::from_slice(x_values, &[samples as usize, 3], &device)?;
+    let x_tensor = Tensor::from_slice(x_values, &[samples as usize, 3], device)?;
 
     let y_values = y_values.into_iter().collect::<Vec<f32>>();
     let y_values = y_values.as_slice();
-    let y_tensor = Tensor::from_slice(y_values, &[samples as usize], &device)?;
+    let y_tensor = Tensor::from_slice(y_values, &[samples as usize], device)?;
 
     Ok((x_tensor, y_tensor))
 }
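
The `&device` -> `device` change here is a borrow cleanup, not a behavioral one: `device` is already a `&Device` inside `load_tensors`, so `&device` produced a `&&Device` that only compiled through auto-deref (clippy flags this as `needless_borrow`). A minimal standalone sketch of the same call shape, with a hypothetical helper name:

    use candle_core::{Device, Result, Tensor};

    // `device` is already a reference, so it is passed through as-is;
    // writing `&device` here would borrow a second time for no benefit.
    fn rows_to_tensor(values: &[f32], rows: usize, device: &Device) -> Result<Tensor> {
        Tensor::from_slice(values, &[rows, 3], device)
    }

    fn main() -> Result<()> {
        let device = Device::Cpu;
        let t = rows_to_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, &device)?;
        println!("{:?}", t.shape()); // prints [2, 3]
        Ok(())
    }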
@@ -125,11 +128,24 @@ fn main() -> Result<()> {
         .to_dtype(DType::F32)?
         .to_device(&device)?;
 
-    for epoch in 1..args.epochs {
-        let logits = model.forward(&dataset.training_data)?;
-        let loss = loss::mse(&logits.squeeze(1)?, &dataset.training_labels)?;
-        sgd.backward_step(&loss)?;
-
+    let (training_size, _) = dataset.training_data.shape().dims2()?;
+    let n_batches = training_size / BATCH_SIZE;
+    let mut batch_idxs = (0..n_batches).collect::<Vec<usize>>();
+
+    for epoch in 1..args.epochs {
+        batch_idxs.shuffle(&mut rand::thread_rng());
+        for batch_idx in batch_idxs.iter() {
+            let train_data = dataset
+                .training_data
+                .narrow(0, batch_idx * BATCH_SIZE, BATCH_SIZE)?;
+            let train_labels =
+                dataset
+                    .training_labels
+                    .narrow(0, batch_idx * BATCH_SIZE, BATCH_SIZE)?;
+            let logits = model.forward(&train_data)?;
+            let loss = loss::mse(&logits.squeeze(1)?, &train_labels)?;
+            sgd.backward_step(&loss)?;
+        }
         let test_logits = model.forward(&test_images)?;
         let test_loss = loss::mse(&test_logits.squeeze(1)?, &test_labels)?;
         if args.progress && epoch % 100 == 0 {
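
This hunk replaces one full-batch gradient step per epoch with shuffled mini-batches: the batch order is reshuffled every epoch, and each batch is a contiguous slice of rows taken with `narrow`. Two properties worth noting: only the batch order is shuffled, not the samples within batches, and the integer division `training_size / BATCH_SIZE` drops any trailing partial batch. A self-contained sketch of just that batching scheme (the `epoch_batches` helper is illustrative, not from the repo):

    use candle_core::{Result, Tensor};
    use rand::prelude::*;

    // Yields one epoch's mini-batches in random order; rows beyond the last
    // full batch are silently skipped because n_batches truncates.
    fn epoch_batches(data: &Tensor, batch_size: usize) -> Result<Vec<Tensor>> {
        let (training_size, _) = data.shape().dims2()?;
        let n_batches = training_size / batch_size;
        let mut batch_idxs = (0..n_batches).collect::<Vec<usize>>();
        batch_idxs.shuffle(&mut rand::thread_rng());
        batch_idxs
            .iter()
            .map(|i| data.narrow(0, i * batch_size, batch_size))
            .collect()
    }

With 10,000 training rows and BATCH_SIZE = 1000 this yields 10 batches per epoch; a dataset of 10,500 rows would silently ignore the last 500.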