diff --git a/src/functions.rs b/src/functions.rs index 3358431..aa332e3 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -23,7 +23,12 @@ pub const SIGMOID: ActivationFn = ActivationFn { f_prime: |x| x * (1.0 - x), }; -pub const MSE: LossFn = LossFn { - f: |y_hat, y| (y_hat - y).powi(2), - f_prime: |y_hat, y| -2.0 * (y_hat - y), +pub const RELU: ActivationFn = ActivationFn { + f: |x| x.max(0.0), + f_prime: |x| if x > 0.0 { 1.0 } else { 0.0 }, +}; + +pub const MSE: LossFn = LossFn { + f: |y_hat, y| (y - y_hat).powi(2), + f_prime: |y_hat, y| -2.0 * (y - y_hat), };