mirror of
https://github.com/vishpat/candle-coursera-ml.git
synced 2025-12-22 22:19:58 +00:00
Checkpoint
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
use anyhow::Result;
|
||||
use candle_core::{DType, Device, Tensor, D};
|
||||
use candle_nn::{loss, ops, Linear, Module, Optimizer, VarBuilder, VarMap};
|
||||
use candle_nn::{loss, Linear, Module, Optimizer, VarBuilder, VarMap};
|
||||
use clap::Parser;
|
||||
use rand::Rng;
|
||||
use std::f64::consts::PI;
|
||||
@@ -8,7 +8,7 @@ use std::rc::Rc;
|
||||
|
||||
const INPUT_COUNT: usize = 3;
|
||||
const LAYER1_COUNT: usize = 100;
|
||||
const LAYER2_COUNT: usize = 100;
|
||||
const LAYER2_COUNT: usize = 50;
|
||||
const OUTPUT_COUNT: usize = 1;
|
||||
|
||||
struct Dataset {
|
||||
@@ -122,7 +122,7 @@ fn main() -> Result<()> {
|
||||
let test_images = dataset.test_data.to_device(&device)?;
|
||||
let test_labels = dataset
|
||||
.test_labels
|
||||
.to_dtype(DType::U32)?
|
||||
.to_dtype(DType::F32)?
|
||||
.to_device(&device)?;
|
||||
|
||||
for epoch in 1..args.epochs {
|
||||
@@ -131,19 +131,9 @@ fn main() -> Result<()> {
|
||||
sgd.backward_step(&loss)?;
|
||||
|
||||
let test_logits = model.forward(&test_images)?;
|
||||
let sum_ok = test_logits
|
||||
.argmax(D::Minus1)?
|
||||
.eq(&test_labels)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_scalar::<f32>()?;
|
||||
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
|
||||
let test_loss = loss::mse(&test_logits.squeeze(1)?, &test_labels)?;
|
||||
if args.progress && epoch % 100 == 0 {
|
||||
println!(
|
||||
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
|
||||
loss.to_scalar::<f32>()?,
|
||||
100. * test_accuracy
|
||||
);
|
||||
println!("{epoch:4} test loss: {:8.5}", test_loss.to_scalar::<f32>()?);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user