From 5787924b35839b05e5b5f5f4cf1e1c8536a3cbae Mon Sep 17 00:00:00 2001 From: markichnich Date: Thu, 21 Mar 2024 22:00:41 +0100 Subject: [PATCH] backprop logic correction --- src/neural_net.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/neural_net.rs b/src/neural_net.rs index e74bf07..9d6f523 100644 --- a/src/neural_net.rs +++ b/src/neural_net.rs @@ -90,8 +90,8 @@ where .elementwise_mul(&losses) .map(|x| x * self.learning_rate.clone()); - self.weights[i] = self.weights[i].add(&gradients.dot(&self.data[i].transpose())); - self.biases[i] = self.biases[i].add(&gradients); + self.weights[i] = self.weights[i].sub(&gradients.dot(&self.data[i].transpose())); + self.biases[i] = self.biases[i].sub(&gradients); losses = self.weights[i].transpose().dot(&losses); gradients = self.data[i].map(self.activation.f_prime); @@ -125,10 +125,7 @@ where sum + (self.loss.f)(y_hat, y.clone()) }); } - println!( - "epoch: {i:0>width$} / {epochs:0>width$} ;\tloss: {:.5}", - loss - ); + println!("epoch: {i:0>width$} / {epochs:0>width$} ;\tloss: {}", loss); } } }