diff --git a/src/neural_net.rs b/src/neural_net.rs index e74bf07..9d6f523 100644 --- a/src/neural_net.rs +++ b/src/neural_net.rs @@ -90,8 +90,8 @@ where .elementwise_mul(&losses) .map(|x| x * self.learning_rate.clone()); - self.weights[i] = self.weights[i].add(&gradients.dot(&self.data[i].transpose())); - self.biases[i] = self.biases[i].add(&gradients); + self.weights[i] = self.weights[i].sub(&gradients.dot(&self.data[i].transpose())); + self.biases[i] = self.biases[i].sub(&gradients); losses = self.weights[i].transpose().dot(&losses); gradients = self.data[i].map(self.activation.f_prime); @@ -125,10 +125,7 @@ where sum + (self.loss.f)(y_hat, y.clone()) }); } - println!( - "epoch: {i:0>width$} / {epochs:0>width$} ;\tloss: {:.5}", - loss - ); + println!("epoch: {i:0>width$} / {epochs:0>width$} ;\tloss: {}", loss); } } }