From a1a7d115830e28d14583b5429e9a7a3279186156 Mon Sep 17 00:00:00 2001 From: Matthew Gretton-Dann Date: Wed, 12 Jan 2022 09:30:21 +0000 Subject: [PATCH] Separate out Dijkstra implementation This separates the Dijkstra implementation into its own generic function. Whilst doing this we can simplify the code and also save yet more memory! Still don't get correct results for part 2. --- 2021/graph-utils.h | 160 +++++++++++++++++++++++++++++++++++++++++++ 2021/puzzle-23-01.cc | 114 ++++++++---------------------- 2021/puzzle-23-02.cc | 129 +++++++++++----------------------- 3 files changed, 227 insertions(+), 176 deletions(-) create mode 100644 2021/graph-utils.h diff --git a/2021/graph-utils.h b/2021/graph-utils.h new file mode 100644 index 0000000..be60678 --- /dev/null +++ b/2021/graph-utils.h @@ -0,0 +1,160 @@ +// +// Created by mgretton on 10/01/2022. +// + +#ifndef ADVENT_OF_CODE_GRAPH_UTILS_H +#define ADVENT_OF_CODE_GRAPH_UTILS_H + +#include +#include +#include +#include +#include +#include +#include + +/** \brief Implement Dijkstra's algorithm. + * + * @tparam Cost Cost type + * @tparam Node Node class + * @tparam TransitionManager Class type that manages state transitions + * @tparam FinishedFn Function type to call to test whether a node matches finished state + * @param initial Initial node + * @param initial_cost Cost of initial node + * @param transition_manager Object to manage state transitions + * @return Node, Cost pair indicating success. On failure default node with + * maximum cost is returned. + * + * \a Cost Must be a numeric type. \a Node must be a class which is default constructable, be less + * than comparable. + * + * \a transition_manager must be an object with the following public interfaces: + * \code + * auto is_finished(Node const& node) -> bool; + * + * template + * void generate_children(Node const& node, Inserter inserter); + * \endcode + * + * The \c generate_children() method is called whenever we examine the node \c node, and should + * call \c inserter(next_node, cost_delta) for every node directly reachable from \c node. Where + * \c next_node is the next node to visit and \c cost_delta is the incremental cost to visit the new + * node. \c inserter should be called for all possible nodes that are visitable, it will ensure + * that duplicates are managed correctly/ + * + * The \c is_finished() method should return true if the given node is in a finished state. + */ +template +auto dijkstra(Node const& initial, Cost initial_cost, TransitionManager transition_manager) + -> std::pair +{ + /** Helper struct to order node pointers. */ + struct NodePCmp + { + auto operator()(Node const* lhs, Node const* rhs) const noexcept -> bool + { + if (lhs == nullptr && rhs != nullptr) { + return true; + } + if (rhs == nullptr) { + return false; + } + return *lhs < *rhs; + } + + auto operator()(Node const* lhs, Node const& rhs) const noexcept -> bool + { + if (lhs == nullptr) { + return true; + } + return *lhs < rhs; + } + + auto operator()(Node const& lhs, Node const* rhs) const noexcept -> bool + { + if (rhs == nullptr) { + return false; + } + return lhs < *rhs; + } + }; + + /* We maintain two maps - one from node to cost and the other from cost to nodes. */ + /** \c nodes maintains a map of all the nodes we've visited or want to visit. Nodes that cost + * less than the current cost at the front of costs have been visited. The rest haven't. + */ + std::map nodes; + /** \c costs maintains a map of costs to the nodes that cost that much to visit. */ + std::map> costs; + Cost current_cost{std::numeric_limits::min()}; + + /* Helper lambda to clean up after ourselves. */ + auto cleanup = [](auto& nodes) { + for (auto const& it : nodes) { + delete it.first; + } + }; + + /* Helper lambda to insert into the maps. */ + auto inserter = [&costs, &nodes, ¤t_cost](Node const& node, Cost cost) { + cost += current_cost; + auto node_it{nodes.find(&node)}; + /* Skip inserting nodes we've already visited, or that would cost more to visit. */ + if (node_it != nodes.end() && node_it->second <= cost) { + return; + } + + Node const* nodep{nullptr}; + if (node_it == nodes.end()) { + nodep = new Node(node); + nodes.insert({nodep, cost}); + } + else { + /* Node has a cheaper cost than we thought: Remove the node from its old cost list */ + nodep = node_it->first; + auto cost_it{costs.find(node_it->second)}; + assert(cost_it != costs.end()); + cost_it->second.erase(nodep); + + /* Now update the cost in the nodes map and use the nodep as the node pointer. */ + node_it->second = cost; + } + + auto [cost_it, success] = costs.insert({cost, {}}); + cost_it->second.insert(nodep); + }; + + Node* init{new Node(initial)}; + nodes.insert({init, initial_cost}); + costs.insert({initial_cost, {init}}); + + std::uint64_t iter{0}; + while (!costs.empty()) { + auto cost_it{costs.begin()}; + current_cost = cost_it->first; + assert(iter < nodes.size()); + assert(std::accumulate(costs.begin(), costs.end(), std::size_t{0}, [](auto a, auto c) { + return a + c.second.size(); + }) == nodes.size() - iter); + for (auto& nodep : cost_it->second) { + if (iter++ % 100'000 == 0) { + std::cout << "Iteration: " << iter << " cost " << current_cost + << " total number of nodes: " << nodes.size() + << ", number of costs left to visit: " << costs.size() + << ", number of nodes left: " << (nodes.size() - iter) << '\n'; + } + if (transition_manager.is_finished(*nodep)) { + auto result{std::make_pair(*nodep, current_cost)}; + cleanup(nodes); + return result; + } + transition_manager.generate_children(*nodep, inserter); + } + costs.erase(costs.begin()); + } + + cleanup(nodes); + return {Node{}, std::numeric_limits::max()}; +} + +#endif // ADVENT_OF_CODE_GRAPH_UTILS_H diff --git a/2021/puzzle-23-01.cc b/2021/puzzle-23-01.cc index b21b52a..c4d1715 100644 --- a/2021/puzzle-23-01.cc +++ b/2021/puzzle-23-01.cc @@ -2,11 +2,14 @@ #include #include #include +#include #include #include #include #include +#include "graph-utils.h" + // Map: // ############# // #ab.c.d.e.fg# @@ -166,31 +169,27 @@ private: std::array State::finished_ = {'.', '.', '.', '.', '.', '.', '.', 'A', 'B', 'C', 'D', 'A', 'B', 'C', 'D'}; -struct StateCmp +struct StateTranstitionManager { - bool operator()(State const* lhs, State const* rhs) const noexcept - { - if (lhs == nullptr && rhs != nullptr) { - return true; - } - if (rhs == nullptr) { - return true; - } - return *lhs < *rhs; - } -}; + bool is_finished(State const& state) { return state.finished(); } -struct CostStateCmp -{ - bool operator()(State const* lhs, State const* rhs) const noexcept + template + void generate_children(State const& state, Inserter inserter) { - if (lhs == nullptr && rhs != nullptr) { - return true; + for (unsigned i = 0; i < state.size(); ++i) { + if (state.node(i) == '.') { + continue; + } + + auto [it_begin, it_end] = valid_moves.equal_range(i); + for (auto move_it{it_begin}; move_it != it_end; ++move_it) { + State next_state(state); + UInt cost_delta = move_it->second.second * multipliers[state.node(i)]; + if (next_state.move(i, move_it->second.first, cost_delta)) { + inserter(next_state, cost_delta); + } + } } - if (rhs == nullptr) { - return true; - } - return lhs->cost() < rhs->cost() || (lhs->cost() == rhs->cost() && *lhs < *rhs); } }; @@ -207,7 +206,7 @@ auto main() -> int std::cerr << "Incorrect first line.\n"; return 1; } - State* initial_state = new State; + State initial_state{}; std::getline(std::cin, line); if (!std::regex_search(line, m, line2_re)) { @@ -215,7 +214,7 @@ auto main() -> int return 1; } for (unsigned i = 0; i < 7; ++i) { - initial_state->node(i) = m.str(i + 1)[0]; + initial_state.node(i) = m.str(i + 1)[0]; } std::getline(std::cin, line); @@ -224,7 +223,7 @@ auto main() -> int return 1; } for (unsigned i = 0; i < 4; ++i) { - initial_state->node(7 + i) = m.str(i + 1)[0]; + initial_state.node(7 + i) = m.str(i + 1)[0]; } std::getline(std::cin, line); @@ -233,68 +232,11 @@ auto main() -> int return 1; } for (unsigned i = 0; i < 4; ++i) { - initial_state->node(11 + i) = m.str(i + 1)[0]; + initial_state.node(11 + i) = m.str(i + 1)[0]; } - std::set states; - std::set costs; - std::set visited; - states.insert(initial_state); - costs.insert(initial_state); - - while (!costs.empty()) { - assert(costs.size() == states.size()); - auto it{costs.begin()}; - - State* state{*it}; - visited.insert(state); - states.erase(state); - costs.erase(it); - - if (visited.size() % 10'000 == 0) { - std::cout << "Visited: " << visited.size() << " number of states: " << states.size() - << " Min energy: " << state->cost() << '\n'; - } - - if (state->finished()) { - std::cout << "Done with cost " << state->cost() << '\n'; - return 0; - } - - for (unsigned i = 0; i < state->size(); ++i) { - if (state->node(i) == '.') { - continue; - } - auto [it_begin, it_end] = valid_moves.equal_range(i); - for (auto move_it{it_begin}; move_it != it_end; ++move_it) { - State* next_state = new State{*state}; - UInt cost_delta = move_it->second.second * multipliers[state->node(i)]; - bool keep{false}; - if (next_state->move(i, move_it->second.first, cost_delta) && - !visited.contains(next_state)) { - auto [insert_it, success] = states.insert(next_state); - if (!success) { - auto old_state{*insert_it}; - if (next_state->cost() < old_state->cost()) { - keep = true; - costs.erase(old_state); - states.erase(old_state); - states.insert(next_state); - costs.insert(next_state); - } - } - else { - keep = true; - costs.insert(next_state); - } - } - if (!keep) { - delete next_state; - } - } - } - } - - std::cerr << "FAILED\n"; - return 1; + auto [finished_state, finished_cost] = + dijkstra(initial_state, UInt{0}, StateTranstitionManager{}); + std::cout << "Done with cost " << finished_cost << '\n'; + return 0; } diff --git a/2021/puzzle-23-02.cc b/2021/puzzle-23-02.cc index a933daa..c68621a 100644 --- a/2021/puzzle-23-02.cc +++ b/2021/puzzle-23-02.cc @@ -7,6 +7,7 @@ #include #include +#include "graph-utils.h" // Map: // ############# // #ab.c.d.e.fg# @@ -50,8 +51,8 @@ std::multimap> valid_moves{ {10, {2, 6}}, {10, {3, 4}}, {10, {4, 2}}, {10, {5, 2}}, {10, {6, 3}}, {10, {14, 1}}, {11, {7, 1}}, {12, {8, 1}}, {13, {9, 1}}, {14, {10, 1}}, {11, {15, 1}}, {12, {16, 1}}, {13, {17, 1}}, {14, {18, 1}}, {15, {11, 1}}, {16, {12, 1}}, {17, {13, 1}}, {18, {14, 1}}, - {14, {10, 1}}, {15, {19, 1}}, {16, {20, 1}}, {17, {21, 1}}, {18, {22, 1}}, {19, {15, 1}}, - {20, {16, 1}}, {21, {17, 1}}, {22, {18, 1}}}; + {15, {19, 1}}, {16, {20, 1}}, {17, {21, 1}}, {18, {22, 1}}, {19, {15, 1}}, {20, {16, 1}}, + {21, {17, 1}}, {22, {18, 1}}}; std::map multipliers{{'A', 1}, {'B', 10}, {'C', 100}, {'D', 1000}}; @@ -190,36 +191,41 @@ std::array State::finished_ = {'.', '.', '.', '.', '.', '.', 'B', 'C', 'D', 'A', 'B', 'C', 'D', 'A', 'B', 'C', 'D', 'A', 'B', 'C', 'D'}; -struct StateCmp +struct StateTranstitionManager { - bool operator()(State const* lhs, State const* rhs) const noexcept - { - if (lhs == nullptr && rhs != nullptr) { - return true; - } - if (rhs == nullptr) { - return true; - } - return *lhs < *rhs; - } -}; + bool is_finished(State const& state) { return state.finished(); } -struct CostStateCmp -{ - bool operator()(State const* lhs, State const* rhs) const noexcept + template + void generate_children(State const& state, Inserter inserter) { - if (lhs == nullptr && rhs != nullptr) { - return true; + for (unsigned i = 0; i < state.size(); ++i) { + if (state.node(i) == '.') { + continue; + } + + auto [it_begin, it_end] = valid_moves.equal_range(i); + for (auto move_it{it_begin}; move_it != it_end; ++move_it) { + State next_state(state); + UInt cost_delta = move_it->second.second * multipliers[state.node(i)]; + if (next_state.move(i, move_it->second.first, cost_delta)) { + inserter(next_state, cost_delta); + } + } } - if (rhs == nullptr) { - return true; - } - return lhs->cost() < rhs->cost() || (lhs->cost() == rhs->cost() && *lhs < *rhs); } }; auto main() -> int { + std::cout << "digraph { \n"; + for (auto move_it : valid_moves) { + if (move_it.first < move_it.second.first) { + continue; + } + std::cout << (int)move_it.first << " -> " << (int)move_it.second.first << " [label=\"" + << move_it.second.second << "\" ];\n"; + } + std::cout << "};\n"; std::regex line2_re{R"(#(.)(.)\.(.)\.(.)\.(.)\.(.)(.)#)"}; std::regex line3_re{"###(.)#(.)#(.)#(.)###"}; std::regex line4_re{"#(.)#(.)#(.)#(.)#"}; @@ -231,7 +237,7 @@ auto main() -> int std::cerr << "Incorrect first line.\n"; return 1; } - State* initial_state = new State; + State initial_state{}; std::getline(std::cin, line); if (!std::regex_search(line, m, line2_re)) { @@ -239,7 +245,7 @@ auto main() -> int return 1; } for (unsigned i = 0; i < 7; ++i) { - initial_state->node(i) = m.str(i + 1)[0]; + initial_state.node(i) = m.str(i + 1)[0]; } std::getline(std::cin, line); @@ -248,7 +254,7 @@ auto main() -> int return 1; } for (unsigned i = 0; i < 4; ++i) { - initial_state->node(7 + i) = m.str(i + 1)[0]; + initial_state.node(7 + i) = m.str(i + 1)[0]; } std::getline(std::cin, line); @@ -257,7 +263,7 @@ auto main() -> int return 1; } for (unsigned i = 0; i < 4; ++i) { - initial_state->node(11 + i) = m.str(i + 1)[0]; + initial_state.node(11 + i) = m.str(i + 1)[0]; } std::getline(std::cin, line); @@ -266,7 +272,7 @@ auto main() -> int return 1; } for (unsigned i = 0; i < 4; ++i) { - initial_state->node(15 + i) = m.str(i + 1)[0]; + initial_state.node(15 + i) = m.str(i + 1)[0]; } std::getline(std::cin, line); @@ -275,68 +281,11 @@ auto main() -> int return 1; } for (unsigned i = 0; i < 4; ++i) { - initial_state->node(19 + i) = m.str(i + 1)[0]; + initial_state.node(19 + i) = m.str(i + 1)[0]; } - std::set states; - std::set costs; - std::set visited; - states.insert(initial_state); - costs.insert(initial_state); - - while (!costs.empty()) { - assert(costs.size() == states.size()); - auto it{costs.begin()}; - - State* state{*it}; - visited.insert(state); - states.erase(state); - costs.erase(it); - - if (visited.size() % 10'000 == 0) { - std::cout << "Visited: " << visited.size() << " number of states: " << states.size() - << " Min energy: " << state->cost() << '\n'; - } - - if (state->finished()) { - std::cout << "Done with cost " << state->cost() << '\n'; - return 0; - } - - for (unsigned i = 0; i < state->size(); ++i) { - if (state->node(i) == '.') { - continue; - } - auto [it_begin, it_end] = valid_moves.equal_range(i); - for (auto move_it{it_begin}; move_it != it_end; ++move_it) { - State* next_state = new State{*state}; - UInt cost_delta = move_it->second.second * multipliers[state->node(i)]; - bool keep{false}; - if (next_state->move(i, move_it->second.first, cost_delta) && - !visited.contains(next_state)) { - auto [insert_it, success] = states.insert(next_state); - if (!success) { - auto old_state{*insert_it}; - if (next_state->cost() < old_state->cost()) { - keep = true; - costs.erase(old_state); - states.erase(old_state); - states.insert(next_state); - costs.insert(next_state); - } - } - else { - keep = true; - costs.insert(next_state); - } - } - if (!keep) { - delete next_state; - } - } - } - } - - std::cerr << "FAILED\n"; - return 1; + auto [finished_state, finished_cost] = + dijkstra(initial_state, UInt{0}, StateTranstitionManager{}); + std::cout << "Done with cost " << finished_cost << '\n'; + return 0; }