mirror of
https://github.com/vishpat/candle-coursera-ml.git
synced 2025-12-22 22:19:58 +00:00
Checkpoint
This commit is contained in:
@@ -70,30 +70,18 @@ fn load_dataset(device: &Device) -> Result<Dataset> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn r_square(y_true: &Tensor, y_pred: &Tensor, device: &Device) -> Result<Tensor> {
|
fn r_square(labels: &Tensor, predictions: &Tensor) -> Result<f32> {
|
||||||
let samples = y_true.shape().dims1()?;
|
let mean = labels.mean(D::Minus1)?;
|
||||||
let pred_samples = y_pred.shape().dims1()?;
|
|
||||||
if samples != pred_samples {
|
let ssr = labels.sub(predictions)?;
|
||||||
return Err(anyhow::anyhow!(
|
let ssr = ssr.mul(&ssr)?.sum(D::Minus1)?;
|
||||||
"y_true and y_pred must have the same number of samples"
|
|
||||||
));
|
let sst = labels.broadcast_sub(&mean)?;
|
||||||
}
|
let sst = sst.mul(&sst)?.sum(D::Minus1)?;
|
||||||
let y_mean = y_true.mean(0).unwrap();
|
|
||||||
let ss_tot = y_true
|
let tmp = ssr.div(&sst)?.to_scalar::<f32>()?;
|
||||||
.broadcast_sub(&y_mean)
|
|
||||||
.unwrap()
|
Ok(1.0 - tmp)
|
||||||
.sqr()
|
|
||||||
.unwrap()
|
|
||||||
.sum(D::Minus1)
|
|
||||||
.unwrap();
|
|
||||||
let ss_res = y_true
|
|
||||||
.sub(&y_pred)
|
|
||||||
.unwrap()
|
|
||||||
.sqr()
|
|
||||||
.unwrap()
|
|
||||||
.sum(D::Minus1)
|
|
||||||
.unwrap();
|
|
||||||
Ok(Tensor::new(1.0 as f32, &device)?.sub(&ss_res.broadcast_div(&ss_tot)?)?)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Mlp {
|
struct Mlp {
|
||||||
@@ -177,7 +165,7 @@ fn main() -> Result<()> {
|
|||||||
if args.progress && epoch % 100 == 0 {
|
if args.progress && epoch % 100 == 0 {
|
||||||
println!(
|
println!(
|
||||||
"{epoch:4} test r2: {:?}",
|
"{epoch:4} test r2: {:?}",
|
||||||
r_square(&test_values, &test_logits.squeeze(1)?, &device)?.to_scalar::<f32>()?
|
r_square(&test_values, &test_logits.squeeze(1)?)?
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user