Initial commit
commit c84187f446

.gitignore (vendored, new file, 1 line)
@@ -0,0 +1 @@
/target

Cargo.lock (generated, new file, 75 lines)
@@ -0,0 +1,75 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3

[[package]]
name = "aicaramba"
version = "0.1.0"
dependencies = [
 "rand",
]

[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"

[[package]]
name = "getrandom"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5"
dependencies = [
 "cfg-if",
 "libc",
 "wasi",
]

[[package]]
name = "libc"
version = "0.2.153"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd"

[[package]]
name = "ppv-lite86"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"

[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
 "libc",
 "rand_chacha",
 "rand_core",
]

[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
 "ppv-lite86",
 "rand_core",
]

[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
 "getrandom",
]

[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"

Cargo.toml (new file, 10 lines)
@@ -0,0 +1,10 @@
[package]
name = "aicaramba"
version = "0.1.0"
edition = "2021"
license = "MIT"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
rand = "0.8.5" # required unconditionally by src/matrix.rs; see the feature-gating TODO there

src/bin/xor.rs (new file, 18 lines)
@@ -0,0 +1,18 @@
use aicaramba::functions::*;
use aicaramba::neural_net::NeuralNet;

fn main() {
    let mut net = NeuralNet::new(vec![2, 3, 1], SIGMOID, MSE, 0.05);
    let epochs = 10_000;

    let inputs = vec![
        vec![0.0, 0.0],
        vec![0.0, 1.0],
        vec![1.0, 0.0],
        vec![1.0, 1.0],
    ];

    let expected = vec![vec![0.0], vec![1.0], vec![1.0], vec![0.0]];

    net.train_basic(inputs, expected, epochs);
}

src/functions.rs (new file, 29 lines)
@@ -0,0 +1,29 @@
use crate::matrix::MatElem;

#[derive(Clone, Copy, Debug)]
pub struct ActivationFn<T>
where
    T: MatElem,
{
    pub f: fn(T) -> T,
    pub f_prime: fn(T) -> T,
}

#[derive(Clone, Copy, Debug)]
pub struct LossFn<T>
where
    T: MatElem,
{
    pub f: fn(T, T) -> T,
    pub f_prime: fn(T, T) -> T,
}

pub const SIGMOID: ActivationFn<f64> = ActivationFn {
    f: |x| 1.0 / (1.0 + f64::exp(-x)),
    // receives the already-activated value a = f(z), since f'(z) = a * (1 - a)
    f_prime: |x| x * (1.0 - x),
};

pub const MSE: LossFn<f64> = LossFn {
    f: |y_hat, y| (y_hat - y).powi(2),
    // the minus sign folds the gradient-descent direction into the derivative
    f_prime: |y_hat, y| -2.0 * (y_hat - y),
};
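
Aside, not part of the commit: because `ActivationFn` is just a pair of plain `fn` pointers, further activations slot in the same way. A minimal sketch of a hypothetical `RELU` constant, following the same convention as `SIGMOID` above, namely that `f_prime` receives the already-activated value:

pub const RELU: ActivationFn<f64> = ActivationFn {
    f: |x| if x > 0.0 { x } else { 0.0 },
    // like SIGMOID, f_prime gets the activated output; a ReLU output is
    // positive exactly when its input was, so the derivative is 1 or 0
    f_prime: |x| if x > 0.0 { 1.0 } else { 0.0 },
};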

src/lib.rs (new file, 3 lines)
@@ -0,0 +1,3 @@
pub mod functions;
pub mod matrix;
pub mod neural_net;

src/matrix.rs (new file, 300 lines)
@@ -0,0 +1,300 @@
use rand::Rng;
use std::ops::{Add, Mul, Sub};

// NOTE: might want to rethink the design (2d-vec?) to enable `matrix[i][j]`
// indexing and make a nice row-iterator implementation possible
#[derive(Debug, Clone)]
pub struct Mat<T>
where
    T: MatElem,
{
    pub rows: usize,
    pub cols: usize,
    // row-major storage: element (i, j) lives at data[i * cols + j]
    data: Vec<T>,
}

/// Shorthand/alias trait for types that are valid as matrix elements.
pub trait MatElem:
    PartialEq + Clone + Default + Add<Output = Self> + Sub<Output = Self> + Mul<Output = Self>
{
}

impl<T> MatElem for T where
    T: PartialEq + Clone + Default + Add<Output = T> + Sub<Output = T> + Mul<Output = T>
{
}

impl<T> Mat<T>
where
    T: MatElem,
{
    pub fn new(rows: usize, cols: usize, data: Vec<T>) -> Mat<T> {
        assert!(data.len() == rows * cols, "Invalid size");
        Mat { rows, cols, data }
    }

    pub fn at(&self, row: usize, col: usize) -> &T {
        &self.data[row * self.cols + col]
    }

    pub fn at_mut(&mut self, row: usize, col: usize) -> &mut T {
        &mut self.data[row * self.cols + col]
    }

    pub fn default_with_size(rows: usize, cols: usize) -> Mat<T> {
        Mat {
            rows,
            cols,
            data: vec![T::default(); cols * rows],
        }
    }

    pub fn add(&self, other: &Mat<T>) -> Mat<T> {
        if self.rows != other.rows || self.cols != other.cols {
            panic!("Attempted to add matrices with differing shapes.");
        }
        self.elementwise(other, |a, b| a + b)
    }

    pub fn sub(&self, other: &Mat<T>) -> Mat<T> {
        if self.rows != other.rows || self.cols != other.cols {
            panic!("Attempted to subtract matrices with differing shapes.");
        }
        self.elementwise(other, |a, b| a - b)
    }

    pub fn elementwise_mul(&self, other: &Mat<T>) -> Mat<T> {
        if self.rows != other.rows || self.cols != other.cols {
            panic!("Attempted to elementwise-multiply matrices of differing shapes.");
        }
        self.elementwise(other, |a, b| a * b)
    }

    pub fn elementwise(&self, other: &Mat<T>, f: fn(T, T) -> T) -> Mat<T> {
        if self.rows != other.rows || self.cols != other.cols {
            panic!("Attempted to apply element-wise operation to matrices with differing shapes.");
        }

        let data = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(a, b)| f(a.clone(), b.clone()))
            .collect::<Vec<_>>();

        Mat {
            rows: self.rows,
            cols: self.cols,
            data,
        }
    }

    pub fn dot(&self, other: &Mat<T>) -> Mat<T> {
        if self.cols != other.rows {
            panic!(
                "Attempted to take dot product of incompatible matrix shapes. (A.cols != B.rows)"
            );
        }

        let mut data = vec![T::default(); self.rows * other.cols];

        for i in 0..self.rows {
            for j in 0..other.cols {
                let mut sum = T::default();
                for k in 0..self.cols {
                    sum = sum
                        + self.data[i * self.cols + k].clone()
                            * other.data[k * other.cols + j].clone();
                }
                data[i * other.cols + j] = sum;
            }
        }

        Mat {
            rows: self.rows,
            cols: other.cols,
            data,
        }
    }

    pub fn transpose(&self) -> Mat<T> {
        let mut buffer = vec![T::default(); self.cols * self.rows];

        for i in 0..self.rows {
            for j in 0..self.cols {
                buffer[j * self.rows + i] = self.data[i * self.cols + j].clone();
            }
        }

        Mat {
            rows: self.cols,
            cols: self.rows,
            data: buffer,
        }
    }

    pub fn map<F>(&self, f: F) -> Mat<T>
    where
        F: FnMut(T) -> T,
    {
        Mat {
            rows: self.rows,
            cols: self.cols,
            data: self.data.clone().into_iter().map(f).collect(),
        }
    }
}

pub trait Collect<T>
where
    T: MatElem,
{
    fn collect_mat(self, rows: usize, cols: usize) -> Mat<T>;
}

impl<I, T> Collect<T> for I
where
    I: IntoIterator<Item = T>,
    T: MatElem,
{
    fn collect_mat(self, rows: usize, cols: usize) -> Mat<T> {
        let data = self.into_iter().collect::<Vec<T>>();
        if data.len() != rows * cols {
            panic!("Collecting iterator into matrix failed due to incompatible matrix shape.");
        }
        Mat { rows, cols, data }
    }
}

// the `random` constructor is only available if `rand` can sample the element type
impl<T> Mat<T>
where
    T: MatElem,
    rand::distributions::Standard: rand::distributions::Distribution<T>,
{
    // TODO: gate behind a randomization feature
    pub fn random(rows: usize, cols: usize) -> Mat<T> {
        let mut rng = rand::thread_rng();
        let mut data = Vec::with_capacity(rows * cols);

        for _ in 0..rows * cols {
            data.push(rng.gen());
        }

        Mat { rows, cols, data }
    }
}

// NOTE: might want to change this to separate row- and col-iterators in the
// future; a `flat_iter` could then mirror the current behavior.
impl<T> IntoIterator for Mat<T>
where
    T: MatElem,
{
    type Item = T;
    type IntoIter = std::vec::IntoIter<Self::Item>;

    fn into_iter(self) -> Self::IntoIter {
        self.data.into_iter()
    }
}

impl<T> From<Vec<T>> for Mat<T>
where
    T: MatElem,
{
    // a flat Vec becomes a column vector (n x 1)
    fn from(value: Vec<T>) -> Self {
        Mat {
            rows: value.len(),
            cols: 1,
            data: value,
        }
    }
}

impl<T> From<Vec<Vec<T>>> for Mat<T>
where
    T: MatElem,
{
    fn from(value: Vec<Vec<T>>) -> Self {
        let rows = value.len();
        let cols = value.first().map(Vec::len).unwrap_or(0);
        Mat {
            rows,
            cols,
            data: value.into_iter().flatten().collect(),
        }
    }
}

impl<T> PartialEq for Mat<T>
where
    T: MatElem,
{
    fn eq(&self, other: &Self) -> bool {
        self.rows == other.rows && self.cols == other.cols && self.data == other.data
    }
}

impl<T> std::ops::Index<usize> for Mat<T>
where
    T: MatElem,
{
    type Output = [T];

    // indexing by row yields a slice, so `mat[i][j]` already works for reads
    fn index(&self, index: usize) -> &Self::Output {
        &self.data[index * self.cols..(index + 1) * self.cols]
    }
}

impl<T> std::fmt::Display for Mat<T>
where
    T: MatElem + std::fmt::Display,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for row in 0..self.rows {
            for col in 0..self.cols {
                write!(f, "{}", self.data[row * self.cols + col])?;
                if col < self.cols - 1 {
                    write!(f, "\t")?;
                }
            }
            writeln!(f)?;
        }
        Ok(())
    }
}

#[macro_export]
macro_rules! matrix {
    ( $( $($val:expr),+ );* $(;)? ) => {
        {
            let mut data = Vec::<f64>::new();
            let mut rows = 0;
            let mut cols = 0;
            $(
                let row_data = vec![$($val),+];
                // measure before extending so each $val is evaluated only once
                let row_len = row_data.len();
                data.extend(row_data);
                rows += 1;
                if cols == 0 {
                    cols = row_len;
                } else if cols != row_len {
                    panic!("Inconsistent number of elements in the matrix rows");
                }
            )*

            // go through the constructor: `data` is private outside this module
            $crate::matrix::Mat::new(rows, cols, data)
        }
    };
}
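
Aside, not part of the commit: a minimal usage sketch of the `Mat` API and the `matrix!` macro above, assuming the crate is consumed as the `aicaramba` library; the asserted values follow directly from the `dot` and `transpose` definitions:

use aicaramba::matrix::Mat;

fn main() {
    // 2x3 matrix from the row-major macro; `;` separates rows
    let a = aicaramba::matrix![1.0, 2.0, 3.0; 4.0, 5.0, 6.0];
    assert_eq!((a.rows, a.cols), (2, 3));

    // a flat Vec converts into a column vector (3x1)
    let x = Mat::from(vec![1.0, 0.0, 1.0]);

    // (2x3) . (3x1) = (2x1)
    let y = a.dot(&x);
    assert_eq!(*y.at(0, 0), 4.0);  // 1*1 + 2*0 + 3*1
    assert_eq!(*y.at(1, 0), 10.0); // 4*1 + 5*0 + 6*1

    // transpose flips the shape to (3x2); Index<usize> yields row slices
    let t = a.transpose();
    assert_eq!(t[0], [1.0, 4.0]);
}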

src/neural_net.rs (new file, 135 lines)
@@ -0,0 +1,135 @@
use crate::functions::*;
use crate::matrix::{Mat, MatElem};

/// Contains the following values:
/// - `architecture: Vec<usize>`: The node counts for each layer (e.g. `vec![2, 3, 1]`)
/// - `weights: Vec<Mat>`: The weight matrices between consecutive layers.
/// - `biases: Vec<Mat>`: The bias matrices of the layers.
/// - `learning_rate: T`: The scalar learning rate.
/// - `activation: ActivationFn`: Struct containing the activation function and its derivative
/// - `loss: LossFn`: Struct containing the loss function and its derivative
/// - `data: Vec<Mat>`: A buffer for the activated values during the forward and backward pass.
pub struct NeuralNet<T>
where
    T: MatElem,
{
    architecture: Vec<usize>,
    weights: Vec<Mat<T>>,
    biases: Vec<Mat<T>>,

    learning_rate: T,

    activation: ActivationFn<T>,
    loss: LossFn<T>,

    data: Vec<Mat<T>>,
}

impl<T> NeuralNet<T>
where
    T: MatElem,
    rand::distributions::Standard: rand::distributions::Distribution<T>,
{
    pub fn new(
        layers: Vec<usize>,
        activation: ActivationFn<T>,
        loss: LossFn<T>,
        learning_rate: T,
    ) -> Self {
        let mut weights = vec![];
        let mut biases = vec![];

        // weights[i] maps layer i (cols) to layer i + 1 (rows)
        for i in 0..layers.len() - 1 {
            weights.push(Mat::random(layers[i + 1], layers[i]));
            biases.push(Mat::random(layers[i + 1], 1));
        }

        NeuralNet {
            architecture: layers,
            weights,
            biases,
            data: vec![],
            activation,
            loss,
            learning_rate,
        }
    }
}

impl<T> NeuralNet<T>
where
    T: MatElem,
{
    pub fn forward(&mut self, inputs: Mat<T>) -> Mat<T> {
        if self.architecture[0] != inputs.rows {
            panic!("Input vector does not have the correct number of rows.");
        }

        let mut current = inputs;
        self.data = vec![current.clone()];

        for i in 0..self.architecture.len() - 1 {
            // layer i + 1 activations: f(W_i * a_i + b_i)
            current = self.weights[i]
                .dot(&current)
                .add(&self.biases[i])
                .map(self.activation.f);

            self.data.push(current.clone());
        }

        current
    }

    pub fn backprop(&mut self, prediction: Mat<T>, truth: Mat<T>) {
        let mut losses = prediction.elementwise(&truth, self.loss.f_prime);
        let mut gradients = prediction.map(self.activation.f_prime);

        for i in (0..self.architecture.len() - 1).rev() {
            gradients = gradients
                .elementwise_mul(&losses)
                .map(|x| x * self.learning_rate.clone());

            // the descent direction is expected to be folded into the loss
            // derivative (as in MSE's -2 * (y_hat - y)), hence `add` here
            self.weights[i] = self.weights[i].add(&gradients.dot(&self.data[i].transpose()));
            self.biases[i] = self.biases[i].add(&gradients);

            losses = self.weights[i].transpose().dot(&losses);
            gradients = self.data[i].map(self.activation.f_prime);
        }
    }

    // TODO: add batch-wise training
    // TODO: refactor to use matrices instead of 2d-vecs
    pub fn train_basic(&mut self, inputs: Vec<Vec<T>>, truth: Vec<Vec<T>>, epochs: u32)
    where
        T: std::fmt::Display,
    {
        let width = epochs.ilog10() as usize + 1;

        for i in 1..=epochs {
            for j in 0..inputs.len() {
                let outputs = self.forward(Mat::from(inputs[j].clone()));
                self.backprop(outputs, Mat::from(truth[j].clone()));
            }

            // report the summed loss roughly 20 times over the training run
            if epochs < 20 || i % (epochs / 20) == 0 {
                let mut loss = T::default();
                for j in 0..inputs.len() {
                    let outputs = self.forward(Mat::from(inputs[j].clone()));
                    loss = loss
                        + outputs
                            .into_iter()
                            .zip(&truth[j])
                            .fold(T::default(), |sum, (y_hat, y)| {
                                sum + (self.loss.f)(y_hat, y.clone())
                            });
                }
                println!("epoch: {i:0>width$} / {epochs:0>width$} ;\tloss: {loss:.5}");
            }
        }
    }
}
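
Aside, not part of the commit: spelling out what `src/bin/xor.rs` drives. For `vec![2, 3, 1]` the constructor builds W1 (3x2), b1 (3x1), W2 (1x3), b2 (1x1), and `forward` computes a1 = sigmoid(W1 * x + b1) and y_hat = sigmoid(W2 * a1 + b2). A sketch of inference after training; the exact numbers depend on the random initialization, so only the printout is shown:

use aicaramba::functions::{MSE, SIGMOID};
use aicaramba::matrix::Mat;
use aicaramba::neural_net::NeuralNet;

fn main() {
    let mut net = NeuralNet::new(vec![2, 3, 1], SIGMOID, MSE, 0.05);
    net.train_basic(
        vec![vec![0.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0], vec![1.0, 1.0]],
        vec![vec![0.0], vec![1.0], vec![1.0], vec![0.0]],
        10_000,
    );

    // forward expects a column vector with architecture[0] rows
    let prediction = net.forward(Mat::from(vec![1.0, 0.0]));
    // sigmoid keeps outputs in (0, 1); after training this should sit near 1.0
    println!("XOR(1, 0) ~ {}", prediction.at(0, 0));
}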