diff --git a/src/bin/xor.rs b/src/bin/xor.rs index e7cfb56..e8874f8 100644 --- a/src/bin/xor.rs +++ b/src/bin/xor.rs @@ -1,9 +1,10 @@ use aicaramba::functions::*; +use aicaramba::matrix::Mat; use aicaramba::neural_net::NeuralNet; fn main() { - let mut net = NeuralNet::new(vec![2, 3, 1], SIGMOID, MSE, 0.05); - let epochs = 10_000; + let mut net = NeuralNet::new(vec![2, 3, 1], RELU, MSE, 0.05); + let epochs = 500; let inputs = vec![ vec![0.0, 0.0], @@ -14,5 +15,11 @@ fn main() { let expected = vec![vec![0.0], vec![1.0], vec![1.0], vec![0.0]]; - net.train_basic(inputs, expected, epochs); + net.train_basic(inputs.clone(), expected, epochs); + + for input in inputs { + let output = net.forward(Mat::from(input.clone())); + let o = output.into_iter().collect::>(); + println!("{} ^ {} = {:.20}", input[0], input[1], o[0]); + } }