1use std::collections::{
2 HashMap,
3 BinaryHeap,
4};
5use std::collections::hash_map::Entry::{
6 Occupied,
7 Vacant,
8};
9
10use std::hash::Hash;
11
12use scored::MinScored;
13use super::visit::{
14 EdgeRef,
15 GraphBase,
16 IntoEdges,
17 VisitMap,
18 Visitable,
19};
20
21use algo::Measure;
22
23pub fn astar<G, F, H, K, IsGoal>(graph: G, start: G::NodeId, mut is_goal: IsGoal,
68 mut edge_cost: F, mut estimate_cost: H)
69 -> Option<(K, Vec<G::NodeId>)>
70 where G: IntoEdges + Visitable,
71 IsGoal: FnMut(G::NodeId) -> bool,
72 G::NodeId: Eq + Hash,
73 F: FnMut(G::EdgeRef) -> K,
74 H: FnMut(G::NodeId) -> K,
75 K: Measure + Copy,
76{
77 let mut visited = graph.visit_map();
78 let mut visit_next = BinaryHeap::new();
79 let mut scores = HashMap::new();
80 let mut path_tracker = PathTracker::<G>::new();
81
82 let zero_score = K::default();
83 scores.insert(start, zero_score);
84 visit_next.push(MinScored(estimate_cost(start), start));
85
86 while let Some(MinScored(_, node)) = visit_next.pop() {
87 if is_goal(node) {
88 let path = path_tracker.reconstruct_path_to(node);
89 let cost = scores[&node];
90 return Some((cost, path));
91 }
92
93 if !visited.visit(node) {
96 continue
97 }
98
99 let node_score = scores[&node];
102
103 for edge in graph.edges(node) {
104 let next = edge.target();
105 if visited.is_visited(&next) {
106 continue
107 }
108
109 let mut next_score = node_score + edge_cost(edge);
110
111 match scores.entry(next) {
112 Occupied(ent) => {
113 let old_score = *ent.get();
114 if next_score < old_score {
115 *ent.into_mut() = next_score;
116 path_tracker.set_predecessor(next, node);
117 } else {
118 next_score = old_score;
119 }
120 },
121 Vacant(ent) => {
122 ent.insert(next_score);
123 path_tracker.set_predecessor(next, node);
124 }
125 }
126
127 let next_estimate_score = next_score + estimate_cost(next);
128 visit_next.push(MinScored(next_estimate_score, next));
129 }
130 }
131
132 None
133}
134
135struct PathTracker<G>
136 where G: GraphBase,
137 G::NodeId: Eq + Hash,
138{
139 came_from: HashMap<G::NodeId, G::NodeId>,
140}
141
142impl<G> PathTracker<G>
143 where G: GraphBase,
144 G::NodeId: Eq + Hash,
145{
146 fn new() -> PathTracker<G> {
147 PathTracker {
148 came_from: HashMap::new(),
149 }
150 }
151
152 fn set_predecessor(&mut self, node: G::NodeId, previous: G::NodeId) {
153 self.came_from.insert(node, previous);
154 }
155
156 fn reconstruct_path_to(&self, last: G::NodeId) -> Vec<G::NodeId> {
157 let mut path = vec![last];
158
159 let mut current = last;
160 while let Some(&previous) = self.came_from.get(¤t) {
161 path.push(previous);
162 current = previous;
163 }
164
165 path.reverse();
166
167 path
168 }
169}