diff --git a/Mesh/mwis.hpp b/Mesh/mwis.hpp index 6faafc4298c4b0fc614343215a4b2460298eb96d..9d8563d0b32ba7aae8325cbe8f7bdd744221f7ac 100644 --- a/Mesh/mwis.hpp +++ b/Mesh/mwis.hpp @@ -124,11 +124,12 @@ public: RestIterator rest_end) { weight best(std::numeric_limits<weight>::max()); - for (size_t i = 0; i < _selected_storage.size(); i++) - _selected_storage[i] = false; - - for (SolutionIterator it = solution_begin; it != solution_end; it++) - put(_selected, *it, true); + std::set<size_t> selectable_cliques; + for (RestIterator it = rest_begin; it != rest_end; it++) { + const std::vector<size_t> cs = get(_cliques, *it); + std::copy(cs.begin(), cs.end(), + std::inserter(selectable_cliques, selectable_cliques.begin())); + } /** * We can get away with very few steps, because the starting point is the @@ -138,18 +139,23 @@ public: */ static const size_t n = 2; for (size_t i = 0; i < n; i++) { - compute_weights(); + compute_weights(rest_begin, rest_end); + if (_max_size == std::numeric_limits<size_t>::max()) select(rest_begin, rest_end); else select(rest_begin, rest_end, _max_size - (size_t)std::distance(solution_begin, solution_end)); - best = std::min(best, evaluate()); + + best = std::min(best, evaluate(rest_begin, rest_end, + selectable_cliques.begin(), + selectable_cliques.end())); if (i != n - 1) { - compute_gradient(); - follow_gradient(weight(-1e-4 / 1)); + compute_gradient(selectable_cliques.begin(), selectable_cliques.end()); + follow_gradient(selectable_cliques.begin(), selectable_cliques.end(), + weight(-1e-4 / 1)); } } @@ -204,30 +210,30 @@ private: /** * Calculates the gradient according to which nodes are currently selected. */ - void compute_gradient() { - std::transform(_clique_contents.begin(), _clique_contents.end(), - _gradient.begin(), - [&](const std::vector<vertex> &vertices) { - weight sum{1}; - for (const vertex &v : vertices) { - sum -= weight(get(_selected, v) ? 1 : 0); - } - - return sum; - }); + template<typename Iterator> + void compute_gradient(Iterator cliques_begin, Iterator cliques_end) { + for (Iterator it = cliques_begin; it != cliques_end; it++) { + weight sum{1}; + for (const vertex &v : _clique_contents[*it]) { + sum -= weight(get(_selected, v) ? 1 : 0); + } + + _gradient[*it] = sum; + } } /** * Calculates the effective weight of each node (w_i - \sum_{c, x \in c} λ_c). */ - void compute_weights() { - for (size_t i = 0; i < _effective_weight_storage.size(); i++) { - weight result(_weight[i]); + template<typename Iterator> + void compute_weights(Iterator begin, Iterator end) { + for (Iterator it = begin; it != end; it++) { + weight result(get(_weight, *it)); - for (size_t clique : _clique_storage[i]) + for (size_t clique : get(_cliques, *it)) result -= _lambda[clique]; - _effective_weight_storage[i] = result; + put(_effective_weight, *it, result); } } @@ -235,11 +241,14 @@ private: * Update the Lagrange multipliers in accordance with the last computed * gradient, multiplied by scale. */ - void follow_gradient(weight scale) { - std::transform(_lambda.begin(), _lambda.end(), _gradient.begin(), - _lambda.begin(), [&](const weight &cur, const weight &grad) { - return std::max(weight(0), cur + scale * grad); - }); + template<typename Iterator> + void follow_gradient(Iterator cliques_begin, Iterator cliques_end, + weight scale) { + for (Iterator it = cliques_begin; it != cliques_end; it++) { + weight cur(_lambda[*it]); + weight grad(_gradient[*it]); + _lambda[*it] = cur + scale * grad; + } } /** @@ -287,13 +296,17 @@ private: * Calculates the upper bound for the current values of the Lagrange * multipliers and the current selection of nodes. */ - weight evaluate() { - weight result(std::accumulate(_lambda.begin(), _lambda.end(), - weight(0), std::plus<weight>())); + template<typename Iterator, typename CliqueIterator> + weight evaluate(Iterator begin, Iterator end, + CliqueIterator cliques_begin, CliqueIterator cliques_end) { + weight result(0); + for (CliqueIterator it = cliques_begin; it != cliques_end; it++) { + result += _lambda[*it]; + } - for (size_t i = 0; i < _selected_storage.size(); i++) { - if (_selected_storage[i]) - result += _effective_weight_storage[i]; + for (Iterator it = begin; it != end; it++) { + if (get(_selected, *it)) + result += get(_effective_weight, *it); } return result; @@ -589,7 +602,8 @@ public: } if (vertices.size() > 0) { - weight max(_bound(state.solution.begin(), state.solution.end(), + weight max(state.solution_value + + _bound(state.solution.begin(), state.solution.end(), state.selectable.begin(), state.selectable.end())); if (max <= *_best_weight) return;