Line data Source code
1 : #include "IntegerCast.h"
2 : #include "LinewiseInput.h"
3 : #include "Parsing.h"
4 : #include "PuzzleImpl.h"
5 :
6 : #include <absl/container/flat_hash_map.h>
7 : #include <absl/hash/hash.h>
8 : #include <algorithm>
9 : #include <libassert/assert.hpp>
10 : #include <re2/re2.h>
11 :
12 : #include <execution>
13 : #include <ranges>
14 : #include <string_view>
15 : #include <tuple>
16 :
17 : namespace {
18 :
19 : using ConditionRecord = std::tuple<std::string, std::vector<unsigned>>;
20 :
21 2000 : ConditionRecord parseLine(std::string_view const line) {
22 2000 : auto it = std::ranges::find(line, ' ');
23 2000 : DEBUG_ASSERT(it != line.end());
24 2000 : std::string springs(line.begin(), it);
25 2000 : std::string_view groupsStr(std::next(it), line.end());
26 2000 : std::vector<unsigned> groups = parseIntegerRange<unsigned>(groupsStr);
27 2000 : return {std::move(springs), std::move(groups)};
28 2000 : }
29 :
30 2 : std::vector<ConditionRecord> parse(std::string_view const input) {
31 2 : LinewiseInput lines(input);
32 2 : return std::ranges::to<std::vector<ConditionRecord>>(lines |
33 2 : std::ranges::views::transform(parseLine));
34 2 : }
35 :
36 1000 : void unfoldRecord(std::tuple<std::string, std::vector<unsigned>> &r) {
37 1000 : std::string &springs = std::get<0>(r);
38 1000 : std::vector<unsigned> &groups = std::get<1>(r);
39 :
40 1000 : auto numSprings = std::ssize(springs);
41 1000 : auto numGroups = std::ssize(groups);
42 :
43 1000 : springs.reserve(springs.size() * 5u + 4u);
44 1000 : groups.reserve(groups.size() * 5u);
45 :
46 5000 : for (int i = 0; i < 4; ++i) {
47 4000 : springs.push_back('?');
48 4000 : std::copy(springs.begin(), std::next(springs.begin(), numSprings), std::back_inserter(springs));
49 4000 : std::copy(groups.begin(), std::next(groups.begin(), numGroups), std::back_inserter(groups));
50 4000 : }
51 1000 : }
52 :
53 : struct ConditionRecordPart {
54 2000 : ConditionRecordPart(ConditionRecord const &r) : springs(std::get<0>(r)), groups(std::get<1>(r)) {}
55 : ConditionRecordPart(std::string_view const s, std::span<unsigned const> const g)
56 388019 : : springs(s), groups(g) {}
57 :
58 : std::string_view springs;
59 : std::span<unsigned const> groups;
60 :
61 456866 : friend bool operator==(ConditionRecordPart const &lhs, ConditionRecordPart const &rhs) {
62 456866 : return lhs.springs == rhs.springs && std::ranges::equal(lhs.groups, rhs.groups);
63 456866 : }
64 :
65 1294750 : template <typename H> friend H AbslHashValue(H h, ConditionRecordPart const &r) {
66 1294750 : H state = H::combine(std::move(h), r.springs);
67 1294750 : return H::combine_contiguous(std::move(state), r.groups.data(), r.groups.size());
68 1294750 : }
69 : };
70 :
71 : class ArrangementCounter {
72 : public:
73 389962 : size_t count(ConditionRecordPart const &r) const {
74 :
75 389962 : if (auto it = _cache.find(r); it != _cache.end())
76 119007 : return it->second;
77 :
78 270955 : if (r.springs.empty()) {
79 7474 : size_t const cnt = r.groups.empty() ? 1u : 0u;
80 7474 : _cache.emplace(r, cnt);
81 7474 : return cnt;
82 7474 : }
83 :
84 263481 : if (r.groups.empty()) {
85 4513 : size_t const cnt = r.springs.contains('#') ? 0u : 1u;
86 4513 : _cache.emplace(r, cnt);
87 4513 : return cnt;
88 4513 : }
89 :
90 258968 : size_t cnt = 0;
91 258968 : char const spring = r.springs.front();
92 258968 : if (spring == '.' || spring == '?') {
93 222847 : cnt += count({r.springs.substr(1), r.groups});
94 222847 : }
95 258968 : if ((spring == '#' || spring == '?') && !r.groups.empty() &&
96 258968 : canConsumeGroup(r.springs, r.groups.front())) {
97 165450 : unsigned const consume =
98 165450 : std::min(r.groups.front() + 1u, integerCast<unsigned>(r.springs.size()));
99 165450 : cnt += count({r.springs.substr(consume), r.groups.subspan(1)});
100 165450 : }
101 :
102 258968 : _cache.emplace(r, cnt);
103 258968 : return cnt;
104 263481 : }
105 :
106 : private:
107 222081 : static bool canConsumeGroup(std::string_view const s, unsigned groupSize) {
108 222081 : return s.size() >= groupSize && !s.substr(0, groupSize).contains('.') &&
109 222081 : (s.size() == groupSize || s[groupSize] != '#');
110 222081 : }
111 :
112 : using Cache = absl::flat_hash_map<ConditionRecordPart, size_t>;
113 : mutable Cache _cache;
114 : };
115 :
116 : } // namespace
117 :
118 1 : template <> size_t part1<2023, 12>(std::string_view const input) {
119 1 : std::vector<ConditionRecord> records = parse(input);
120 :
121 1 : return std::transform_reduce(std::execution::par, records.begin(), records.end(), size_t(0),
122 1000 : std::plus<>(), [](auto const &r) {
123 1000 : ArrangementCounter ctr;
124 1000 : return ctr.count(r);
125 1000 : });
126 1 : }
127 :
128 1 : template <> size_t part2<2023, 12>(std::string_view const input) {
129 1 : std::vector<ConditionRecord> records = parse(input);
130 1 : std::for_each(std::execution::par, records.begin(), records.end(),
131 1000 : [](ConditionRecord &r) { unfoldRecord(r); });
132 :
133 1 : return std::transform_reduce(std::execution::par, records.begin(), records.end(), size_t(0),
134 1000 : std::plus<>(), [](auto const &r) {
135 1000 : ArrangementCounter ctr;
136 1000 : return ctr.count(r);
137 1000 : });
138 1 : }
|