From 477cbb13f56f023922bde712ea5000df3803949c Mon Sep 17 00:00:00 2001 From: mxhagen Date: Fri, 21 Feb 2025 20:43:22 +0100 Subject: [PATCH] add dijkstra --- src/graph.rs | 197 ++++++++++++++++++++++++++++++++++++++++++++++++++- src/main.rs | 39 ++++++++-- 2 files changed, 229 insertions(+), 7 deletions(-) diff --git a/src/graph.rs b/src/graph.rs index ee64bba..606135a 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -1,11 +1,12 @@ use std::{ - collections::{HashMap, HashSet, VecDeque}, + collections::{BTreeMap, HashMap, HashSet, VecDeque}, + fmt::Debug, hash::Hash, }; -pub trait VertexType: Eq + Hash + Ord + Clone {} +pub trait VertexType: Eq + Hash + Ord + Clone + Debug {} -impl VertexType for T where T: Eq + Hash + Ord + Clone {} +impl VertexType for T where T: Eq + Hash + Ord + Clone + Debug {} pub struct Graph { pub vertices: HashSet, @@ -154,6 +155,196 @@ where } } +pub struct WeightedGraph { + pub vertices: HashSet, + pub edges: HashMap>, +} + +impl WeightedGraph +where + V: VertexType, +{ + /// create a weighted graph from a set of vertices and weighted edges. + pub fn new(vertices: I, edges: J) -> Self + where + I: IntoIterator, + J: IntoIterator, + { + let vertices = vertices.into_iter().collect(); + let mut parsed_edges = HashMap::new(); + for (from, to, c) in edges { + parsed_edges + .entry(from) + .and_modify(|e: &mut HashSet<(u64, V)>| { + e.insert((c, to.clone())); + }) + .or_insert_with(|| HashSet::from([(c, to)])); + } + + Self { + vertices, + edges: parsed_edges, + } + } + + /// get the number of vertices of the graph + pub fn vertex_count(&self) -> usize { + self.vertices.len() + } + + /// get the number of edges of the graph + pub fn edge_count(&self) -> usize { + self.edges.values().map(HashSet::len).sum() + } + + /// check if an edge is contained in the graph + /// (only checks the provided direction) + pub fn has_edge(&self, edge: (&V, &V)) -> bool { + let (from, to) = edge; + self.edges + .get(from) + .map_or(false, |v| v.iter().any(|(_, x)| x == to)) + } + + /// check if an edge (a, b) and its' reverse (b, a) are both contained in the graph + pub fn has_bidirectional_edge(&self, edge: (&V, &V)) -> bool { + let (a, b) = edge; + self.has_edge((a, b)) && self.has_edge((b, a)) + } + + /// check if a vertex is contained in the graph + pub fn has_vertex(&self, vertex: &V) -> bool { + self.vertices.contains(vertex) + } + + /// get a slice containing all neighbors of the edge + pub fn neighbors(&self, vertex: &V) -> Vec<&(u64, V)> { + self.edges + .get(vertex) + .map(|es| es.iter().collect()) + .unwrap_or_default() + } + + /// find a path between two nodes using breadth-first-search. + /// this finds a path with the least possible edges. + /// + /// does not take into account edge weights. + pub fn find_path_bfs(&self, from: &V, to: &V) -> Option> { + let mut q = VecDeque::with_capacity(self.vertices.len()); + let mut visited = HashSet::with_capacity(self.vertices.len()); + + q.push_back(vec![from]); + visited.insert(from); + + while let Some(mut path) = q.pop_front() { + let current = path.last().unwrap(); + + for (_, neighbor) in self.neighbors(current) { + if neighbor == to { + return path + .into_iter() + .cloned() + .chain([to.clone()]) + .collect::>() + .into(); + } + + path.push(neighbor); + q.push_back(path.clone()); + path.pop(); + } + } + + None + } + + /// find a path between two nodes using depth-first-search. + /// + /// this is short-circuiting and therefore does not guarantee + /// the found path to contain the least possible edges. + /// + /// does not take into account edge weights. + pub fn find_path_dfs(&self, from: &V, to: &V) -> Option> { + let mut q = Vec::with_capacity(self.vertices.len()); + + q.push(vec![from]); + + while let Some(mut path) = q.pop() { + let ¤t = path.last().unwrap(); + + if current == to { + return path.into_iter().cloned().collect::>().into(); + } + + for (_, neighbor) in self.neighbors(current).iter().rev() { + if path.contains(&neighbor) { + continue; + } + + path.push(neighbor); + q.push(path.clone()); + path.pop(); + } + } + + None + } + + pub fn find_path_dijkstra(&self, from: &V, to: &V) -> Option> { + use std::cmp::Reverse as Rev; + use std::collections::BinaryHeap; + + let mut dist: BTreeMap<_, _> = self + .vertices + .iter() + .zip(std::iter::repeat(u64::MAX)) + .collect(); + + let mut prev: HashMap<_, _> = self.vertices.iter().zip(std::iter::repeat(None)).collect(); + let mut q: BinaryHeap<_> = self.vertices.iter().map(Rev).collect(); + dist.entry(from).and_modify(|c| *c = 0); + + let mut i = 0; + while !q.is_empty() { + let u = q.iter().min_by_key(|v| dist[v.0]).unwrap().0.clone(); + + if &u == to { + let mut s = VecDeque::new(); + let mut u = Some(to.clone()); + if prev[&u.clone()?].is_some() || u.clone()? == *from { + while u.is_some() { + s.push_front(u.clone()?); + u = prev[&u?].clone(); + } + return Some(s.iter().cloned().collect()); + } + + return None; + } + + q.retain(|v| *v != Rev(&u)); // remove u + i += 1; + if i > 5 { + panic!(); + } + + for (weight, v) in self.neighbors(&u) { + let alt = dist[&u] + weight; + if alt < dist[&v] { + *dist.get_mut(v)? = alt; + *prev.get_mut(v)? = Some(u.clone()); + } + } + } + + None + } + + pub fn iter(&self) -> impl Iterator { + self.vertices.iter() + } +} + #[macro_export] macro_rules! graph { ( $( $v:tt : $( $e:tt ),* );* $(;)? ) => {{ diff --git a/src/main.rs b/src/main.rs index acbd595..987d6dc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ use std::collections::HashSet; mod graph; -pub use graph::Graph; +pub use graph::{Graph, WeightedGraph}; fn main() { // test graph: @@ -22,7 +22,6 @@ fn main() { assert!(g.has_vertex(v), "G should contain vertex {v}"); } - // edges for (a, neighbors) in &g.edges { for b in neighbors { @@ -34,7 +33,6 @@ fn main() { } } - // dfs, bfs let n = g.vertex_count(); let vs = g.vertices.iter().collect::>(); @@ -66,12 +64,45 @@ fn main() { } } - // iterator implementation let all_vs = g.vertices.clone(); let seen = g.into_iter().collect::>(); assert_eq!(seen, all_vs, "Iterating G should yield all vertices"); + // weighted test graph: + // ┏━5━━ b ━━2━┓ + // ┃ ┃ ┃ + // -> a 1 d ━━2━━ e ━━6━━ f + // ┃ ┃ ┃ + // ┗━3━━ c ━━4━┛ + + let g = WeightedGraph::new( + ['a', 'b', 'c', 'd', 'e', 'f'], + [ + ('a', 'b', 5), + ('a', 'c', 2), + ('b', 'c', 1), + ('b', 'd', 2), + ('c', 'b', 1), + ('c', 'd', 4), + ('d', 'e', 2), + ('e', 'f', 6), + ], + ); + + // disjkstra + let path = g.find_path_dijkstra(&'a', &'f'); + assert!( + path.is_some(), + "Path from a to f should be found in weighted test graph using Dijkstra" + ); + + let path = path.unwrap(); + assert_eq!( + path, + vec!['a', 'c', 'b', 'd', 'e', 'f'], + "Dijkstra should find cheapest way in weighted example graph" + ); // yay println!("All tests passed.");