add relu and correct mse.

This commit is contained in:
markichnich 2024-03-21 22:01:42 +01:00
parent 5787924b35
commit 561d6f52c4

View File

@ -23,7 +23,12 @@ pub const SIGMOID: ActivationFn<f64> = ActivationFn {
f_prime: |x| x * (1.0 - x), f_prime: |x| x * (1.0 - x),
}; };
pub const MSE: LossFn<f64> = LossFn { pub const RELU: ActivationFn<f64> = ActivationFn {
f: |y_hat, y| (y_hat - y).powi(2), f: |x| x.max(0.0),
f_prime: |y_hat, y| -2.0 * (y_hat - y), f_prime: |x| if x > 0.0 { 1.0 } else { 0.0 },
};
pub const MSE: LossFn<f64> = LossFn {
f: |y_hat, y| (y - y_hat).powi(2),
f_prime: |y_hat, y| -2.0 * (y - y_hat),
}; };