Move IterationCallbacks into their own file.
1. Merge TrustRegionLoggingCallback and LineSearchLoggingCallback
into a single callback.
2. Move the callbacks into callback.h
3. Update SolverImpl to use the new callbacks.
Change-Id: I9e82173cf2b828d023d96c57d1cba17f4832aeae
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt
index 1dd4090..18ff92b 100644
--- a/internal/ceres/CMakeLists.txt
+++ b/internal/ceres/CMakeLists.txt
@@ -43,6 +43,7 @@
c_api.cc
canonical_views_clustering.cc
cgnr_solver.cc
+ callbacks.cc
compressed_col_sparse_matrix_utils.cc
compressed_row_jacobian_writer.cc
compressed_row_sparse_matrix.cc
diff --git a/internal/ceres/callbacks.cc b/internal/ceres/callbacks.cc
new file mode 100644
index 0000000..c551162
--- /dev/null
+++ b/internal/ceres/callbacks.cc
@@ -0,0 +1,107 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2014 Google Inc. All rights reserved.
+// http://code.google.com/p/ceres-solver/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+// used to endorse or promote products derived from this software without
+// specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: sameeragarwal@google.com (Sameer Agarwal)
+
+#include <iostream> // NO LINT
+#include "ceres/callbacks.h"
+#include "ceres/program.h"
+#include "ceres/stringprintf.h"
+#include "glog/logging.h"
+
+namespace ceres {
+namespace internal {
+
+StateUpdatingCallback::StateUpdatingCallback(Program* program,
+ double* parameters)
+ : program_(program), parameters_(parameters) {}
+
+StateUpdatingCallback::~StateUpdatingCallback() {}
+
+CallbackReturnType StateUpdatingCallback::operator()(
+ const IterationSummary& summary) {
+ if (summary.step_is_successful) {
+ program_->StateVectorToParameterBlocks(parameters_);
+ program_->CopyParameterBlockStateToUserState();
+ }
+ return SOLVER_CONTINUE;
+}
+
+LoggingCallback::LoggingCallback(const MinimizerType minimizer_type,
+ const bool log_to_stdout)
+ : minimizer_type(minimizer_type),
+ log_to_stdout_(log_to_stdout) {}
+
+LoggingCallback::~LoggingCallback() {}
+
+CallbackReturnType LoggingCallback::operator()(
+ const IterationSummary& summary) {
+ string output;
+ if (minimizer_type == LINE_SEARCH) {
+ const char* kReportRowFormat =
+ "% 4d: f:% 8e d:% 3.2e g:% 3.2e h:% 3.2e "
+ "s:% 3.2e e:% 3d it:% 3.2e tt:% 3.2e";
+ output = StringPrintf(kReportRowFormat,
+ summary.iteration,
+ summary.cost,
+ summary.cost_change,
+ summary.gradient_max_norm,
+ summary.step_norm,
+ summary.step_size,
+ summary.line_search_function_evaluations,
+ summary.iteration_time_in_seconds,
+ summary.cumulative_time_in_seconds);
+ } else if (minimizer_type == TRUST_REGION) {
+ const char* kReportRowFormat =
+ "% 4d: f:% 8e d:% 3.2e g:% 3.2e h:% 3.2e "
+ "rho:% 3.2e mu:% 3.2e li:% 3d it:% 3.2e tt:% 3.2e";
+ output = StringPrintf(kReportRowFormat,
+ summary.iteration,
+ summary.cost,
+ summary.cost_change,
+ summary.gradient_max_norm,
+ summary.step_norm,
+ summary.relative_decrease,
+ summary.trust_region_radius,
+ summary.linear_solver_iterations,
+ summary.iteration_time_in_seconds,
+ summary.cumulative_time_in_seconds);
+ } else {
+ LOG(FATAL) << "Unknown minimizer type.";
+ }
+
+ if (log_to_stdout_) {
+ cout << output << endl;
+ } else {
+ VLOG(1) << output;
+ }
+ return SOLVER_CONTINUE;
+}
+
+} // namespace internal
+} // namespace ceres
diff --git a/internal/ceres/callbacks.h b/internal/ceres/callbacks.h
new file mode 100644
index 0000000..93704df
--- /dev/null
+++ b/internal/ceres/callbacks.h
@@ -0,0 +1,71 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2014 Google Inc. All rights reserved.
+// http://code.google.com/p/ceres-solver/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+// used to endorse or promote products derived from this software without
+// specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: sameeragarwal@google.com (Sameer Agarwal)
+
+#ifndef CERES_INTERNAL_CALLBACKS_H_
+#define CERES_INTERNAL_CALLBACKS_H_
+
+#include <string>
+#include "ceres/iteration_callback.h"
+#include "ceres/internal/port.h"
+
+namespace ceres {
+namespace internal {
+
+class Program;
+
+// Callback for updating the externally visible state of parameter
+// blocks.
+class StateUpdatingCallback : public IterationCallback {
+ public:
+ StateUpdatingCallback(Program* program, double* parameters);
+ virtual ~StateUpdatingCallback();
+ virtual CallbackReturnType operator()(const IterationSummary& summary);
+ private:
+ Program* program_;
+ double* parameters_;
+};
+
+// Callback for logging the state of the minimizer to STDERR or
+// STDOUT depending on the user's preferences and logging level.
+class LoggingCallback : public IterationCallback {
+ public:
+ LoggingCallback(MinimizerType minimizer_type, bool log_to_stdout);
+ virtual ~LoggingCallback();
+ virtual CallbackReturnType operator()(const IterationSummary& summary);
+
+ private:
+ const MinimizerType minimizer_type;
+ const bool log_to_stdout_;
+};
+
+} // namespace internal
+} // namespace ceres
+
+#endif // CERES_INTERNAL_CALLBACKS_H_
diff --git a/internal/ceres/solver_impl.cc b/internal/ceres/solver_impl.cc
index 8148716..101439c 100644
--- a/internal/ceres/solver_impl.cc
+++ b/internal/ceres/solver_impl.cc
@@ -35,6 +35,7 @@
#include <numeric>
#include <string>
#include "ceres/array_utils.h"
+#include "ceres/callbacks.h"
#include "ceres/coordinate_descent_minimizer.h"
#include "ceres/cxsparse.h"
#include "ceres/evaluator.h"
@@ -61,26 +62,6 @@
namespace internal {
namespace {
-// Callback for updating the user's parameter blocks. Updates are only
-// done if the step is successful.
-class StateUpdatingCallback : public IterationCallback {
- public:
- StateUpdatingCallback(Program* program, double* parameters)
- : program_(program), parameters_(parameters) {}
-
- CallbackReturnType operator()(const IterationSummary& summary) {
- if (summary.step_is_successful) {
- program_->StateVectorToParameterBlocks(parameters_);
- program_->CopyParameterBlockStateToUserState();
- }
- return SOLVER_CONTINUE;
- }
-
- private:
- Program* program_;
- double* parameters_;
-};
-
void SetSummaryFinalCost(Solver::Summary* summary) {
summary->final_cost = summary->initial_cost;
// We need the loop here, instead of just looking at the last
@@ -91,106 +72,6 @@
}
}
-// Callback for logging the state of the minimizer to STDERR or STDOUT
-// depending on the user's preferences and logging level.
-class TrustRegionLoggingCallback : public IterationCallback {
- public:
- explicit TrustRegionLoggingCallback(bool log_to_stdout)
- : log_to_stdout_(log_to_stdout) {}
-
- ~TrustRegionLoggingCallback() {}
-
- CallbackReturnType operator()(const IterationSummary& summary) {
- const char* kReportRowFormat =
- "% 4d: f:% 8e d:% 3.2e g:% 3.2e h:% 3.2e "
- "rho:% 3.2e mu:% 3.2e li:% 3d it:% 3.2e tt:% 3.2e";
- string output = StringPrintf(kReportRowFormat,
- summary.iteration,
- summary.cost,
- summary.cost_change,
- summary.gradient_max_norm,
- summary.step_norm,
- summary.relative_decrease,
- summary.trust_region_radius,
- summary.linear_solver_iterations,
- summary.iteration_time_in_seconds,
- summary.cumulative_time_in_seconds);
- if (log_to_stdout_) {
- cout << output << endl;
- } else {
- VLOG(1) << output;
- }
- return SOLVER_CONTINUE;
- }
-
- private:
- const bool log_to_stdout_;
-};
-
-// Callback for logging the state of the minimizer to STDERR or STDOUT
-// depending on the user's preferences and logging level.
-class LineSearchLoggingCallback : public IterationCallback {
- public:
- explicit LineSearchLoggingCallback(bool log_to_stdout)
- : log_to_stdout_(log_to_stdout) {}
-
- ~LineSearchLoggingCallback() {}
-
- CallbackReturnType operator()(const IterationSummary& summary) {
- const char* kReportRowFormat =
- "% 4d: f:% 8e d:% 3.2e g:% 3.2e h:% 3.2e "
- "s:% 3.2e e:% 3d it:% 3.2e tt:% 3.2e";
- string output = StringPrintf(kReportRowFormat,
- summary.iteration,
- summary.cost,
- summary.cost_change,
- summary.gradient_max_norm,
- summary.step_norm,
- summary.step_size,
- summary.line_search_function_evaluations,
- summary.iteration_time_in_seconds,
- summary.cumulative_time_in_seconds);
- if (log_to_stdout_) {
- cout << output << endl;
- } else {
- VLOG(1) << output;
- }
- return SOLVER_CONTINUE;
- }
-
- private:
- const bool log_to_stdout_;
-};
-
-
-// Basic callback to record the execution of the solver to a file for
-// offline analysis.
-class FileLoggingCallback : public IterationCallback {
- public:
- explicit FileLoggingCallback(const string& filename)
- : fptr_(NULL) {
- fptr_ = fopen(filename.c_str(), "w");
- CHECK_NOTNULL(fptr_);
- }
-
- virtual ~FileLoggingCallback() {
- if (fptr_ != NULL) {
- fclose(fptr_);
- }
- }
-
- virtual CallbackReturnType operator()(const IterationSummary& summary) {
- fprintf(fptr_,
- "%4d %e %e\n",
- summary.iteration,
- summary.cost,
- summary.cumulative_time_in_seconds);
- return SOLVER_CONTINUE;
- }
- private:
- FILE* fptr_;
-};
-
// Iterate over each of the groups in order of their priority and fill
// summary with their sizes.
void SummarizeOrdering(ParameterBlockOrdering* ordering,
@@ -422,8 +303,8 @@
// vector.
program->ParameterBlocksToStateVector(parameters.data());
- TrustRegionLoggingCallback logging_callback(
- options.minimizer_progress_to_stdout);
+ LoggingCallback logging_callback(TRUST_REGION,
+ options.minimizer_progress_to_stdout);
if (options.logging_type != SILENT) {
minimizer_options.callbacks.insert(minimizer_options.callbacks.begin(),
&logging_callback);
@@ -487,8 +368,8 @@
// Collect the discontiguous parameters into a contiguous state vector.
program->ParameterBlocksToStateVector(parameters.data());
- LineSearchLoggingCallback logging_callback(
- options.minimizer_progress_to_stdout);
+ LoggingCallback logging_callback(LINE_SEARCH,
+ options.minimizer_progress_to_stdout);
if (options.logging_type != SILENT) {
minimizer_options.callbacks.insert(minimizer_options.callbacks.begin(),
&logging_callback);