diff --git a/2023/puzzle-12-02.cc b/2023/puzzle-12-02.cc index 32b4aab..47f55a9 100644 --- a/2023/puzzle-12-02.cc +++ b/2023/puzzle-12-02.cc @@ -1,6 +1,8 @@ +#include #include #include #include +#include #include #include #include @@ -29,75 +31,79 @@ auto split(std::string_view line) -> std::pair return std::make_pair(pattern, nums); } -[[nodiscard]] auto validate_subpattern(std::string_view pattern, std::string_view start, - UInt num_hashs, UInt gaps) -> bool -{ - if (pattern.size() < start.size()) { - return false; - } - UInt hash_count{0}; +std::unordered_map counts; - auto pit{pattern.begin()}; - for (auto sit{start.begin()}; sit != start.end(); ++pit, ++sit) { - if (*sit == '#') { ++hash_count; } - if (*pit != '?' && *pit != *sit) { - return false; +auto valid_sequences(std::string_view pattern, UIntVec::const_iterator begin, + UIntVec::const_iterator end) -> UInt; + +auto generate_pattern(std::string pattern, UIntVec::const_iterator begin, + UIntVec::const_iterator end) -> UInt +{ + while (begin != end) { + if (pattern.empty()) { + // We still have ###s to add but nowhere to put them - this isn't a match. + return 0; + } + + if (pattern[0] == '?') { + // We don't know what the first character is so let's try both options recursively. + auto new_pattern{pattern}; + new_pattern[0] = '.'; + auto count = valid_sequences(new_pattern, begin, end); + new_pattern[0] = '#'; + count += valid_sequences(new_pattern, begin, end); + return count; + } + + if (pattern[0] == '.') { + // If the pattern begins with a '.' then we just skip over it. + pattern = pattern.substr(1); + continue; + } + + /* We must have a '#' now. */ + assert(pattern[0] == '#'); + + if (pattern.size() < *begin) { + // Not enough space. + return 0; + } + + for (std::size_t count = 0; count < *begin; ++count) { + // Check for dots. + if (pattern[count] == '.') { return 0; } + } + + // Move along the pattern. + pattern = pattern.substr(*begin++); + + if (begin != end) { + if (pattern.empty()) { return 0; } + if (pattern[0] == '#') { return 0; } + pattern = pattern.substr(1); } } - for (; pit != pattern.end(); ++pit) { - if (*pit == '#') { ++hash_count; } - } - - if (num_hashs < hash_count) { - return false; - } - if (start.size() + (num_hashs - hash_count) + gaps > pattern.size()) { - return false; - } - - return true; + return pattern.find('#') == std::string_view::npos; } -void generate_pattern(std::string start, UIntVec::const_iterator it, UIntVec::const_iterator end, - std::string_view pattern, UInt hash_count, UInt& count) +auto valid_sequences(std::string_view pattern, UIntVec::const_iterator begin, + UIntVec::const_iterator end) -> UInt { - auto pos = start.size(); - auto num_hashs = *it++; - start += std::string(num_hashs, '#'); - - if (it == end) { - if (start.size() <= pattern.size()) { - start.resize(pattern.size(), '.'); - if (validate_subpattern(pattern.substr(pos), start.substr(pos), num_hashs, 0)) { - ++count; - if (count % 1'000'000 == 0) { - std::cout << start << '\n'; - } - } - } - return; + std::string hash = std::string(pattern); + for (auto it{begin}; it != end; ++it) { + hash += ','; + hash += std::to_string(*it); } - start += '.'; - while (validate_subpattern(pattern.substr(pos), start.substr(pos), hash_count, end - it - 1)) { - generate_pattern(start, it, end, pattern, hash_count - num_hashs, count); - start += '.'; - } -} - -auto valid_sequences(std::string_view pattern, UIntVec const& nums) -> UInt -{ - UInt count{0}; - UInt hash_count{0}; - for (auto num : nums) { hash_count += num; } - std::string start; - while (validate_subpattern(pattern, start, hash_count, nums.size() - 1)) { - generate_pattern(start, nums.begin(), nums.end(), pattern, hash_count, count); - start += '.'; + auto it = counts.find(hash); + if (it == counts.end()) { + bool success; + std::tie(it, success) = counts.insert( + {hash, generate_pattern(std::string(pattern), begin, end)}); } - return count; + return it->second; } auto main() -> int try { @@ -115,7 +121,7 @@ auto main() -> int try { unfolded_nums.insert(unfolded_nums.begin(), nums.begin(), nums.end()); } - line_count += valid_sequences(unfolded_pattern, unfolded_nums); + line_count += valid_sequences(unfolded_pattern, unfolded_nums.begin(), unfolded_nums.end()); std::cout << line << ": " << line_count << '\n'; count += line_count; }