Line data Source code
1 : #include "IntegerCast.h"
2 : #include "LinewiseInput.h"
3 : #include "Parsing.h"
4 : #include "PuzzleImpl.h"
5 :
6 : #include <absl/container/flat_hash_map.h>
7 : #include <absl/hash/hash.h>
8 : #include <algorithm>
9 : #include <libassert/assert.hpp>
10 : #include <re2/re2.h>
11 :
12 : #include <execution>
13 : #include <ranges>
14 : #include <string>
15 : #include <string_view>
16 : #include <tuple>
17 :
18 : namespace {
19 :
20 : using ConditionRecord = std::tuple<std::string, std::vector<unsigned>>;
21 :
22 2000 : ConditionRecord parseLine(std::string_view const line) {
23 2000 : auto it = std::ranges::find(line, ' ');
24 2000 : DEBUG_ASSERT(it != line.end());
25 2000 : std::string springs(line.begin(), it);
26 2000 : std::string_view groupsStr(std::next(it), line.end());
27 2000 : std::vector<unsigned> groups = parseIntegerRange<unsigned>(groupsStr);
28 2000 : return {std::move(springs), std::move(groups)};
29 2000 : }
30 :
31 2 : std::vector<ConditionRecord> parse(std::string_view const input) {
32 2 : LinewiseInput lines(input);
33 2 : return std::ranges::to<std::vector<ConditionRecord>>(lines |
34 2 : std::ranges::views::transform(parseLine));
35 2 : }
36 :
37 1000 : void unfoldRecord(std::tuple<std::string, std::vector<unsigned>> &r) {
38 1000 : std::string &springs = std::get<0>(r);
39 1000 : std::vector<unsigned> &groups = std::get<1>(r);
40 :
41 1000 : auto numSprings = std::ssize(springs);
42 1000 : auto numGroups = std::ssize(groups);
43 :
44 1000 : springs.reserve(springs.size() * 5u + 4u);
45 1000 : groups.reserve(groups.size() * 5u);
46 :
47 5000 : for (int i = 0; i < 4; ++i) {
48 4000 : springs.push_back('?');
49 4000 : std::copy(springs.begin(), std::next(springs.begin(), numSprings), std::back_inserter(springs));
50 4000 : std::copy(groups.begin(), std::next(groups.begin(), numGroups), std::back_inserter(groups));
51 4000 : }
52 1000 : }
53 :
54 : struct ConditionRecordPart {
55 2000 : ConditionRecordPart(ConditionRecord const &r) : springs(std::get<0>(r)), groups(std::get<1>(r)) {}
56 : ConditionRecordPart(std::string_view const s, std::span<unsigned const> const g)
57 388223 : : springs(s), groups(g) {}
58 :
59 : std::string_view springs;
60 : std::span<unsigned const> groups;
61 :
62 1061260 : friend bool operator==(ConditionRecordPart const &lhs, ConditionRecordPart const &rhs) {
63 1061260 : return lhs.springs == rhs.springs && std::ranges::equal(lhs.groups, rhs.groups);
64 1061260 : }
65 :
66 2174383 : template <typename H> friend H AbslHashValue(H h, ConditionRecordPart const &r) {
67 2174383 : H state = H::combine(std::move(h), r.springs);
68 2174383 : return H::combine_contiguous(std::move(state), r.groups.data(), r.groups.size());
69 2174383 : }
70 : };
71 :
72 : class ArrangementCounter {
73 : public:
74 390191 : size_t count(ConditionRecordPart const &r) const {
75 :
76 390191 : if (auto it = _cache.find(r); it != _cache.end())
77 119014 : return it->second;
78 :
79 271177 : if (r.springs.empty()) {
80 7474 : size_t const cnt = r.groups.empty() ? 1u : 0u;
81 7474 : _cache.emplace(r, cnt);
82 7474 : return cnt;
83 7474 : }
84 :
85 263703 : if (r.groups.empty()) {
86 4513 : size_t const cnt = r.springs.contains('#') ? 0u : 1u;
87 4513 : _cache.emplace(r, cnt);
88 4513 : return cnt;
89 4513 : }
90 :
91 259190 : size_t cnt = 0;
92 259190 : char const spring = r.springs.front();
93 259190 : if (spring == '.' || spring == '?') {
94 222909 : cnt += count({r.springs.substr(1), r.groups});
95 222909 : }
96 259190 : if ((spring == '#' || spring == '?') && !r.groups.empty() &&
97 259190 : canConsumeGroup(r.springs, r.groups.front())) {
98 165471 : unsigned const consume =
99 165471 : std::min(r.groups.front() + 1u, integerCast<unsigned>(r.springs.size()));
100 165471 : cnt += count({r.springs.substr(consume), r.groups.subspan(1)});
101 165471 : }
102 :
103 259190 : _cache.emplace(r, cnt);
104 259190 : return cnt;
105 263703 : }
106 :
107 : private:
108 222063 : static bool canConsumeGroup(std::string_view const s, unsigned groupSize) {
109 222063 : return s.size() >= groupSize && !s.substr(0, groupSize).contains('.') &&
110 222063 : (s.size() == groupSize || s[groupSize] != '#');
111 222063 : }
112 :
113 : using Cache = absl::flat_hash_map<ConditionRecordPart, size_t>;
114 : mutable Cache _cache;
115 : };
116 :
117 : } // namespace
118 :
119 1 : template <> std::string solvePart1<2023, 12>(std::string_view const input) {
120 1 : std::vector<ConditionRecord> records = parse(input);
121 :
122 1 : return std::to_string(std::transform_reduce(std::execution::par, records.begin(), records.end(),
123 1000 : size_t(0), std::plus<>(), [](auto const &r) {
124 1000 : ArrangementCounter ctr;
125 1000 : return ctr.count(r);
126 1000 : }));
127 1 : }
128 :
129 1 : template <> std::string solvePart2<2023, 12>(std::string_view const input) {
130 1 : std::vector<ConditionRecord> records = parse(input);
131 1 : std::for_each(std::execution::par, records.begin(), records.end(),
132 1000 : [](ConditionRecord &r) { unfoldRecord(r); });
133 :
134 1 : return std::to_string(std::transform_reduce(std::execution::par, records.begin(), records.end(),
135 1000 : size_t(0), std::plus<>(), [](auto const &r) {
136 1000 : ArrangementCounter ctr;
137 1000 : return ctr.count(r);
138 1000 : }));
139 1 : }
|