Separate out Dijkstra implementation

This separates the Dijkstra implementation into its own generic function.

Whilst doing this we can simplify the code and also save yet more memory!

Still don't get correct results for part 2.
This commit is contained in:
2022-01-12 09:30:21 +00:00
parent 29005ed7ba
commit a1a7d11583
3 changed files with 227 additions and 176 deletions

160
2021/graph-utils.h Normal file
View File

@@ -0,0 +1,160 @@
//
// Created by mgretton on 10/01/2022.
//
#ifndef ADVENT_OF_CODE_GRAPH_UTILS_H
#define ADVENT_OF_CODE_GRAPH_UTILS_H
#include <algorithm>
#include <limits>
#include <map>
#include <numeric>
#include <unordered_map>
#include <unordered_set>
#include <utility>
/** \brief Implement Dijkstra's algorithm.
*
* @tparam Cost Cost type
* @tparam Node Node class
* @tparam TransitionManager Class type that manages state transitions
* @tparam FinishedFn Function type to call to test whether a node matches finished state
* @param initial Initial node
* @param initial_cost Cost of initial node
* @param transition_manager Object to manage state transitions
* @return Node, Cost pair indicating success. On failure default node with
* maximum cost is returned.
*
* \a Cost Must be a numeric type. \a Node must be a class which is default constructable, be less
* than comparable.
*
* \a transition_manager must be an object with the following public interfaces:
* \code
* auto is_finished(Node const& node) -> bool;
*
* template <typename Inserter>
* void generate_children(Node const& node, Inserter inserter);
* \endcode
*
* The \c generate_children() method is called whenever we examine the node \c node, and should
* call \c inserter(next_node, cost_delta) for every node directly reachable from \c node. Where
* \c next_node is the next node to visit and \c cost_delta is the incremental cost to visit the new
* node. \c inserter should be called for all possible nodes that are visitable, it will ensure
* that duplicates are managed correctly/
*
* The \c is_finished() method should return true if the given node is in a finished state.
*/
template<typename Cost, typename Node, typename TransitionManager>
auto dijkstra(Node const& initial, Cost initial_cost, TransitionManager transition_manager)
-> std::pair<Node, Cost>
{
/** Helper struct to order node pointers. */
struct NodePCmp
{
auto operator()(Node const* lhs, Node const* rhs) const noexcept -> bool
{
if (lhs == nullptr && rhs != nullptr) {
return true;
}
if (rhs == nullptr) {
return false;
}
return *lhs < *rhs;
}
auto operator()(Node const* lhs, Node const& rhs) const noexcept -> bool
{
if (lhs == nullptr) {
return true;
}
return *lhs < rhs;
}
auto operator()(Node const& lhs, Node const* rhs) const noexcept -> bool
{
if (rhs == nullptr) {
return false;
}
return lhs < *rhs;
}
};
/* We maintain two maps - one from node to cost and the other from cost to nodes. */
/** \c nodes maintains a map of all the nodes we've visited or want to visit. Nodes that cost
* less than the current cost at the front of costs have been visited. The rest haven't.
*/
std::map<Node const*, Cost, NodePCmp> nodes;
/** \c costs maintains a map of costs to the nodes that cost that much to visit. */
std::map<Cost, std::unordered_set<Node const*>> costs;
Cost current_cost{std::numeric_limits<Cost>::min()};
/* Helper lambda to clean up after ourselves. */
auto cleanup = [](auto& nodes) {
for (auto const& it : nodes) {
delete it.first;
}
};
/* Helper lambda to insert into the maps. */
auto inserter = [&costs, &nodes, &current_cost](Node const& node, Cost cost) {
cost += current_cost;
auto node_it{nodes.find(&node)};
/* Skip inserting nodes we've already visited, or that would cost more to visit. */
if (node_it != nodes.end() && node_it->second <= cost) {
return;
}
Node const* nodep{nullptr};
if (node_it == nodes.end()) {
nodep = new Node(node);
nodes.insert({nodep, cost});
}
else {
/* Node has a cheaper cost than we thought: Remove the node from its old cost list */
nodep = node_it->first;
auto cost_it{costs.find(node_it->second)};
assert(cost_it != costs.end());
cost_it->second.erase(nodep);
/* Now update the cost in the nodes map and use the nodep as the node pointer. */
node_it->second = cost;
}
auto [cost_it, success] = costs.insert({cost, {}});
cost_it->second.insert(nodep);
};
Node* init{new Node(initial)};
nodes.insert({init, initial_cost});
costs.insert({initial_cost, {init}});
std::uint64_t iter{0};
while (!costs.empty()) {
auto cost_it{costs.begin()};
current_cost = cost_it->first;
assert(iter < nodes.size());
assert(std::accumulate(costs.begin(), costs.end(), std::size_t{0}, [](auto a, auto c) {
return a + c.second.size();
}) == nodes.size() - iter);
for (auto& nodep : cost_it->second) {
if (iter++ % 100'000 == 0) {
std::cout << "Iteration: " << iter << " cost " << current_cost
<< " total number of nodes: " << nodes.size()
<< ", number of costs left to visit: " << costs.size()
<< ", number of nodes left: " << (nodes.size() - iter) << '\n';
}
if (transition_manager.is_finished(*nodep)) {
auto result{std::make_pair(*nodep, current_cost)};
cleanup(nodes);
return result;
}
transition_manager.generate_children(*nodep, inserter);
}
costs.erase(costs.begin());
}
cleanup(nodes);
return {Node{}, std::numeric_limits<Cost>::max()};
}
#endif // ADVENT_OF_CODE_GRAPH_UTILS_H

View File

@@ -2,11 +2,14 @@
#include <cassert>
#include <functional>
#include <iostream>
#include <limits>
#include <map>
#include <regex>
#include <set>
#include <string>
#include "graph-utils.h"
// Map:
// #############
// #ab.c.d.e.fg#
@@ -166,31 +169,27 @@ private:
std::array<Type, 15> State::finished_ = {'.', '.', '.', '.', '.', '.', '.', 'A',
'B', 'C', 'D', 'A', 'B', 'C', 'D'};
struct StateCmp
struct StateTranstitionManager
{
bool operator()(State const* lhs, State const* rhs) const noexcept
{
if (lhs == nullptr && rhs != nullptr) {
return true;
}
if (rhs == nullptr) {
return true;
}
return *lhs < *rhs;
}
};
bool is_finished(State const& state) { return state.finished(); }
struct CostStateCmp
{
bool operator()(State const* lhs, State const* rhs) const noexcept
template<typename Inserter>
void generate_children(State const& state, Inserter inserter)
{
if (lhs == nullptr && rhs != nullptr) {
return true;
for (unsigned i = 0; i < state.size(); ++i) {
if (state.node(i) == '.') {
continue;
}
auto [it_begin, it_end] = valid_moves.equal_range(i);
for (auto move_it{it_begin}; move_it != it_end; ++move_it) {
State next_state(state);
UInt cost_delta = move_it->second.second * multipliers[state.node(i)];
if (next_state.move(i, move_it->second.first, cost_delta)) {
inserter(next_state, cost_delta);
}
}
}
if (rhs == nullptr) {
return true;
}
return lhs->cost() < rhs->cost() || (lhs->cost() == rhs->cost() && *lhs < *rhs);
}
};
@@ -207,7 +206,7 @@ auto main() -> int
std::cerr << "Incorrect first line.\n";
return 1;
}
State* initial_state = new State;
State initial_state{};
std::getline(std::cin, line);
if (!std::regex_search(line, m, line2_re)) {
@@ -215,7 +214,7 @@ auto main() -> int
return 1;
}
for (unsigned i = 0; i < 7; ++i) {
initial_state->node(i) = m.str(i + 1)[0];
initial_state.node(i) = m.str(i + 1)[0];
}
std::getline(std::cin, line);
@@ -224,7 +223,7 @@ auto main() -> int
return 1;
}
for (unsigned i = 0; i < 4; ++i) {
initial_state->node(7 + i) = m.str(i + 1)[0];
initial_state.node(7 + i) = m.str(i + 1)[0];
}
std::getline(std::cin, line);
@@ -233,68 +232,11 @@ auto main() -> int
return 1;
}
for (unsigned i = 0; i < 4; ++i) {
initial_state->node(11 + i) = m.str(i + 1)[0];
initial_state.node(11 + i) = m.str(i + 1)[0];
}
std::set<State*, StateCmp> states;
std::set<State*, CostStateCmp> costs;
std::set<State*, StateCmp> visited;
states.insert(initial_state);
costs.insert(initial_state);
while (!costs.empty()) {
assert(costs.size() == states.size());
auto it{costs.begin()};
State* state{*it};
visited.insert(state);
states.erase(state);
costs.erase(it);
if (visited.size() % 10'000 == 0) {
std::cout << "Visited: " << visited.size() << " number of states: " << states.size()
<< " Min energy: " << state->cost() << '\n';
}
if (state->finished()) {
std::cout << "Done with cost " << state->cost() << '\n';
return 0;
}
for (unsigned i = 0; i < state->size(); ++i) {
if (state->node(i) == '.') {
continue;
}
auto [it_begin, it_end] = valid_moves.equal_range(i);
for (auto move_it{it_begin}; move_it != it_end; ++move_it) {
State* next_state = new State{*state};
UInt cost_delta = move_it->second.second * multipliers[state->node(i)];
bool keep{false};
if (next_state->move(i, move_it->second.first, cost_delta) &&
!visited.contains(next_state)) {
auto [insert_it, success] = states.insert(next_state);
if (!success) {
auto old_state{*insert_it};
if (next_state->cost() < old_state->cost()) {
keep = true;
costs.erase(old_state);
states.erase(old_state);
states.insert(next_state);
costs.insert(next_state);
}
}
else {
keep = true;
costs.insert(next_state);
}
}
if (!keep) {
delete next_state;
}
}
}
}
std::cerr << "FAILED\n";
return 1;
auto [finished_state, finished_cost] =
dijkstra(initial_state, UInt{0}, StateTranstitionManager{});
std::cout << "Done with cost " << finished_cost << '\n';
return 0;
}

View File

@@ -7,6 +7,7 @@
#include <set>
#include <string>
#include "graph-utils.h"
// Map:
// #############
// #ab.c.d.e.fg#
@@ -50,8 +51,8 @@ std::multimap<Position, std::pair<Position, UInt>> valid_moves{
{10, {2, 6}}, {10, {3, 4}}, {10, {4, 2}}, {10, {5, 2}}, {10, {6, 3}}, {10, {14, 1}},
{11, {7, 1}}, {12, {8, 1}}, {13, {9, 1}}, {14, {10, 1}}, {11, {15, 1}}, {12, {16, 1}},
{13, {17, 1}}, {14, {18, 1}}, {15, {11, 1}}, {16, {12, 1}}, {17, {13, 1}}, {18, {14, 1}},
{14, {10, 1}}, {15, {19, 1}}, {16, {20, 1}}, {17, {21, 1}}, {18, {22, 1}}, {19, {15, 1}},
{20, {16, 1}}, {21, {17, 1}}, {22, {18, 1}}};
{15, {19, 1}}, {16, {20, 1}}, {17, {21, 1}}, {18, {22, 1}}, {19, {15, 1}}, {20, {16, 1}},
{21, {17, 1}}, {22, {18, 1}}};
std::map<Type, UInt> multipliers{{'A', 1}, {'B', 10}, {'C', 100}, {'D', 1000}};
@@ -190,36 +191,41 @@ std::array<Type, State::size_> State::finished_ = {'.', '.', '.', '.', '.', '.',
'B', 'C', 'D', 'A', 'B', 'C', 'D', 'A',
'B', 'C', 'D', 'A', 'B', 'C', 'D'};
struct StateCmp
struct StateTranstitionManager
{
bool operator()(State const* lhs, State const* rhs) const noexcept
{
if (lhs == nullptr && rhs != nullptr) {
return true;
}
if (rhs == nullptr) {
return true;
}
return *lhs < *rhs;
}
};
bool is_finished(State const& state) { return state.finished(); }
struct CostStateCmp
{
bool operator()(State const* lhs, State const* rhs) const noexcept
template<typename Inserter>
void generate_children(State const& state, Inserter inserter)
{
if (lhs == nullptr && rhs != nullptr) {
return true;
for (unsigned i = 0; i < state.size(); ++i) {
if (state.node(i) == '.') {
continue;
}
auto [it_begin, it_end] = valid_moves.equal_range(i);
for (auto move_it{it_begin}; move_it != it_end; ++move_it) {
State next_state(state);
UInt cost_delta = move_it->second.second * multipliers[state.node(i)];
if (next_state.move(i, move_it->second.first, cost_delta)) {
inserter(next_state, cost_delta);
}
}
}
if (rhs == nullptr) {
return true;
}
return lhs->cost() < rhs->cost() || (lhs->cost() == rhs->cost() && *lhs < *rhs);
}
};
auto main() -> int
{
std::cout << "digraph { \n";
for (auto move_it : valid_moves) {
if (move_it.first < move_it.second.first) {
continue;
}
std::cout << (int)move_it.first << " -> " << (int)move_it.second.first << " [label=\""
<< move_it.second.second << "\" ];\n";
}
std::cout << "};\n";
std::regex line2_re{R"(#(.)(.)\.(.)\.(.)\.(.)\.(.)(.)#)"};
std::regex line3_re{"###(.)#(.)#(.)#(.)###"};
std::regex line4_re{"#(.)#(.)#(.)#(.)#"};
@@ -231,7 +237,7 @@ auto main() -> int
std::cerr << "Incorrect first line.\n";
return 1;
}
State* initial_state = new State;
State initial_state{};
std::getline(std::cin, line);
if (!std::regex_search(line, m, line2_re)) {
@@ -239,7 +245,7 @@ auto main() -> int
return 1;
}
for (unsigned i = 0; i < 7; ++i) {
initial_state->node(i) = m.str(i + 1)[0];
initial_state.node(i) = m.str(i + 1)[0];
}
std::getline(std::cin, line);
@@ -248,7 +254,7 @@ auto main() -> int
return 1;
}
for (unsigned i = 0; i < 4; ++i) {
initial_state->node(7 + i) = m.str(i + 1)[0];
initial_state.node(7 + i) = m.str(i + 1)[0];
}
std::getline(std::cin, line);
@@ -257,7 +263,7 @@ auto main() -> int
return 1;
}
for (unsigned i = 0; i < 4; ++i) {
initial_state->node(11 + i) = m.str(i + 1)[0];
initial_state.node(11 + i) = m.str(i + 1)[0];
}
std::getline(std::cin, line);
@@ -266,7 +272,7 @@ auto main() -> int
return 1;
}
for (unsigned i = 0; i < 4; ++i) {
initial_state->node(15 + i) = m.str(i + 1)[0];
initial_state.node(15 + i) = m.str(i + 1)[0];
}
std::getline(std::cin, line);
@@ -275,68 +281,11 @@ auto main() -> int
return 1;
}
for (unsigned i = 0; i < 4; ++i) {
initial_state->node(19 + i) = m.str(i + 1)[0];
initial_state.node(19 + i) = m.str(i + 1)[0];
}
std::set<State*, StateCmp> states;
std::set<State*, CostStateCmp> costs;
std::set<State*, StateCmp> visited;
states.insert(initial_state);
costs.insert(initial_state);
while (!costs.empty()) {
assert(costs.size() == states.size());
auto it{costs.begin()};
State* state{*it};
visited.insert(state);
states.erase(state);
costs.erase(it);
if (visited.size() % 10'000 == 0) {
std::cout << "Visited: " << visited.size() << " number of states: " << states.size()
<< " Min energy: " << state->cost() << '\n';
}
if (state->finished()) {
std::cout << "Done with cost " << state->cost() << '\n';
return 0;
}
for (unsigned i = 0; i < state->size(); ++i) {
if (state->node(i) == '.') {
continue;
}
auto [it_begin, it_end] = valid_moves.equal_range(i);
for (auto move_it{it_begin}; move_it != it_end; ++move_it) {
State* next_state = new State{*state};
UInt cost_delta = move_it->second.second * multipliers[state->node(i)];
bool keep{false};
if (next_state->move(i, move_it->second.first, cost_delta) &&
!visited.contains(next_state)) {
auto [insert_it, success] = states.insert(next_state);
if (!success) {
auto old_state{*insert_it};
if (next_state->cost() < old_state->cost()) {
keep = true;
costs.erase(old_state);
states.erase(old_state);
states.insert(next_state);
costs.insert(next_state);
}
}
else {
keep = true;
costs.insert(next_state);
}
}
if (!keep) {
delete next_state;
}
}
}
}
std::cerr << "FAILED\n";
return 1;
auto [finished_state, finished_cost] =
dijkstra(initial_state, UInt{0}, StateTranstitionManager{});
std::cout << "Done with cost " << finished_cost << '\n';
return 0;
}