Proof of concept C API for Ceres
This introduces a simple C API for a subset of Ceres. This opens the door to
using languages like Python to call Ceres, since it is much easier to bind to C
than it is to bind to C++. It will mean giving up the native Ceres autodiff.
The implementation in this patch does not attempt to do everything but is only
just enough to get started. Subsequent patches will increase the surface area
of Ceres that is covered by the C API.
Change-Id: Ic51804bac6865e1a2e476553248aabc91dff3409
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 94132be..dfe5589 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -43,6 +43,10 @@
ADD_EXECUTABLE(curve_fitting curve_fitting.cc)
TARGET_LINK_LIBRARIES(curve_fitting ceres)
+ADD_EXECUTABLE(curve_fitting_c curve_fitting.c)
+TARGET_LINK_LIBRARIES(curve_fitting_c ceres)
+
+
ADD_EXECUTABLE(robust_curve_fitting robust_curve_fitting.cc)
TARGET_LINK_LIBRARIES(robust_curve_fitting ceres)
diff --git a/examples/curve_fitting.c b/examples/curve_fitting.c
new file mode 100644
index 0000000..fb75d7d
--- /dev/null
+++ b/examples/curve_fitting.c
@@ -0,0 +1,174 @@
+/* Ceres Solver - A fast non-linear least squares minimizer
+ * Copyright 2013 Google Inc. All rights reserved.
+ * http://code.google.com/p/ceres-solver/
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * - Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ * - Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ * - Neither the name of Google Inc. nor the names of its contributors may be
+ * used to endorse or promote products derived from this software without
+ * specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ *
+ * Author: mierle@gmail.com (Keir Mierle)
+ *
+ * This is a port of curve_fitting.cc to the minimal C API for Ceres.
+ */
+
+#include <math.h>
+#include <stdio.h>
+#include <string.h> // For NULL
+#include "ceres/c_api.h"
+
+// Data generated using the following octave code.
+// randn('seed', 23497);
+// m = 0.3;
+// c = 0.1;
+// x=[0:0.075:5];
+// y = exp(m * x + c);
+// noise = randn(size(x)) * 0.2;
+// y_observed = y + noise;
+// data = [x', y_observed'];
+
+int num_observations = 67;
+double data[] = {
+ 0.000000e+00, 1.133898e+00,
+ 7.500000e-02, 1.334902e+00,
+ 1.500000e-01, 1.213546e+00,
+ 2.250000e-01, 1.252016e+00,
+ 3.000000e-01, 1.392265e+00,
+ 3.750000e-01, 1.314458e+00,
+ 4.500000e-01, 1.472541e+00,
+ 5.250000e-01, 1.536218e+00,
+ 6.000000e-01, 1.355679e+00,
+ 6.750000e-01, 1.463566e+00,
+ 7.500000e-01, 1.490201e+00,
+ 8.250000e-01, 1.658699e+00,
+ 9.000000e-01, 1.067574e+00,
+ 9.750000e-01, 1.464629e+00,
+ 1.050000e+00, 1.402653e+00,
+ 1.125000e+00, 1.713141e+00,
+ 1.200000e+00, 1.527021e+00,
+ 1.275000e+00, 1.702632e+00,
+ 1.350000e+00, 1.423899e+00,
+ 1.425000e+00, 1.543078e+00,
+ 1.500000e+00, 1.664015e+00,
+ 1.575000e+00, 1.732484e+00,
+ 1.650000e+00, 1.543296e+00,
+ 1.725000e+00, 1.959523e+00,
+ 1.800000e+00, 1.685132e+00,
+ 1.875000e+00, 1.951791e+00,
+ 1.950000e+00, 2.095346e+00,
+ 2.025000e+00, 2.361460e+00,
+ 2.100000e+00, 2.169119e+00,
+ 2.175000e+00, 2.061745e+00,
+ 2.250000e+00, 2.178641e+00,
+ 2.325000e+00, 2.104346e+00,
+ 2.400000e+00, 2.584470e+00,
+ 2.475000e+00, 1.914158e+00,
+ 2.550000e+00, 2.368375e+00,
+ 2.625000e+00, 2.686125e+00,
+ 2.700000e+00, 2.712395e+00,
+ 2.775000e+00, 2.499511e+00,
+ 2.850000e+00, 2.558897e+00,
+ 2.925000e+00, 2.309154e+00,
+ 3.000000e+00, 2.869503e+00,
+ 3.075000e+00, 3.116645e+00,
+ 3.150000e+00, 3.094907e+00,
+ 3.225000e+00, 2.471759e+00,
+ 3.300000e+00, 3.017131e+00,
+ 3.375000e+00, 3.232381e+00,
+ 3.450000e+00, 2.944596e+00,
+ 3.525000e+00, 3.385343e+00,
+ 3.600000e+00, 3.199826e+00,
+ 3.675000e+00, 3.423039e+00,
+ 3.750000e+00, 3.621552e+00,
+ 3.825000e+00, 3.559255e+00,
+ 3.900000e+00, 3.530713e+00,
+ 3.975000e+00, 3.561766e+00,
+ 4.050000e+00, 3.544574e+00,
+ 4.125000e+00, 3.867945e+00,
+ 4.200000e+00, 4.049776e+00,
+ 4.275000e+00, 3.885601e+00,
+ 4.350000e+00, 4.110505e+00,
+ 4.425000e+00, 4.345320e+00,
+ 4.500000e+00, 4.161241e+00,
+ 4.575000e+00, 4.363407e+00,
+ 4.650000e+00, 4.161576e+00,
+ 4.725000e+00, 4.619728e+00,
+ 4.800000e+00, 4.737410e+00,
+ 4.875000e+00, 4.727863e+00,
+ 4.950000e+00, 4.669206e+00,
+};
+
+/* This is the equivalent of a use-defined CostFunction in the C++ Ceres API.
+ * This is passed as a callback to the Ceres C API, which internally converts
+ * the callback into a CostFunction. */
+int exponential_residual(void* user_data,
+ double** parameters,
+ double* residuals,
+ double** jacobians) {
+ double* measurement = (double*) user_data;
+ double x = measurement[0];
+ double y = measurement[1];
+ double m = parameters[0][0];
+ double c = parameters[1][0];
+
+ residuals[0] = y - exp(m * x + c);
+ if (jacobians == NULL) {
+ return 1;
+ }
+ if (jacobians[0] != NULL) {
+ jacobians[0][0] = - m * exp(m * x + c); /* dr/dm */
+ }
+ if (jacobians[1] != NULL) {
+ jacobians[1][0] = - exp(m * x + c); /* dr/dc */
+ }
+ return 1;
+}
+
+int main(int argc, char** argv) {
+ ceres_init(argc, argv);
+
+ /* Note: Typically it is better to compact m and c into one block,
+ * but in this case use separate blocks for illustration. */
+ double m = 0.0;
+ double c = 0.0;
+ double *parameter_pointers[] = { &m, &c };
+ int parameter_sizes[] = { 1, 1 };
+
+ ceres_problem_t* problem = ceres_create_problem();
+ for (int i = 0; i < num_observations; ++i) {
+ ceres_problem_add_residual_block(
+ problem,
+ exponential_residual, /* Cost function */
+ NULL, /* No loss function */
+ &data[2 * i], /* Points to the (x,y) measurement */
+ 1, /* Number of residuals */
+ 2, /* Number of parameter blocks */
+ parameter_sizes,
+ parameter_pointers);
+ }
+
+ ceres_solve(problem);
+
+ printf("Initial m: 0.0, c: 0.0\n");
+ printf("Final m: %g, c: %g\n", m, c);
+ return 0;
+}
diff --git a/include/ceres/c_api.h b/include/ceres/c_api.h
new file mode 100644
index 0000000..8d74a27
--- /dev/null
+++ b/include/ceres/c_api.h
@@ -0,0 +1,94 @@
+/* Ceres Solver - A fast non-linear least squares minimizer
+ * Copyright 2013 Google Inc. All rights reserved.
+ * http://code.google.com/p/ceres-solver/
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * - Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ * - Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ * - Neither the name of Google Inc. nor the names of its contributors may be
+ * used to endorse or promote products derived from this software without
+ * specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ *
+ * Author: mierle@gmail.com (Keir Mierle)
+ *
+ * A minimal C API for Ceres. Not all functionality is included. This API is
+ * not intended for clients of Ceres, but is instead intended for easing the
+ * process of binding Ceres to other languages.
+ *
+ * Currently this is a work in progress.
+ */
+
+#ifndef CERES_PUBLIC_C_API_H_
+#define CERES_PUBLIC_C_API_H_
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* Init the Ceres private data. Must be called before anything else. */
+void ceres_init();
+
+/* Equivalent to CostFunction::Evaluate() in the C++ API.
+ *
+ * The user is may keep private information inside the opaque user_data object.
+ * The pointer here is the same one passed in the ceres_add_residual_block(). */
+typedef int (*ceres_cost_function_t)(void* user_data,
+ double** parameters,
+ double* residuals,
+ double** jacobians);
+
+/* Equivalent to LossFunction::Evaluate() from the C++ API. */
+typedef int (*ceres_loss_function_t)(void* user_data,
+ double squared_norm,
+ double out[3]);
+
+/* Equivalent to Problem from the C++ API. */
+struct ceres_problem_s;
+typedef struct ceres_problem_s ceres_problem_t;
+
+struct ceres_residual_block_id_s;
+typedef struct ceres_residual_block_id_s ceres_residual_block_id_t;
+
+/* Create and destroy a problem */
+/* TODO(keir): Add options for the problem. */
+ceres_problem_t* ceres_create_problem();
+void ceres_free_problem(ceres_problem_t* problem);
+
+/* Add a residual block. */
+/* TODO(keir): Add support for loss functions */
+ceres_residual_block_id_t* ceres_problem_add_residual_block(
+ ceres_problem_t* problem,
+ ceres_cost_function_t cost_function,
+ ceres_loss_function_t loss_function,
+ void* user_data,
+ int num_residuals,
+ int num_parameter_blocks,
+ int* parameter_block_sizes,
+ double** parameters);
+
+void ceres_solve(ceres_problem_t* problem);
+
+/* TODO(keir): Figure out a way to pass a config in. */
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* CERES_PUBLIC_C_API_H_ */
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt
index 549c94e..392d058 100644
--- a/internal/ceres/CMakeLists.txt
+++ b/internal/ceres/CMakeLists.txt
@@ -38,6 +38,7 @@
block_random_access_sparse_matrix.cc
block_sparse_matrix.cc
block_structure.cc
+ c_api.cc
canonical_views_clustering.cc
cgnr_solver.cc
compressed_col_sparse_matrix_utils.cc
diff --git a/internal/ceres/c_api.cc b/internal/ceres/c_api.cc
new file mode 100644
index 0000000..4d7d59b
--- /dev/null
+++ b/internal/ceres/c_api.cc
@@ -0,0 +1,131 @@
+// Ceres Solver - A fast non-linear least squares minimizer
+// Copyright 2013 Google Inc. All rights reserved.
+// http://code.google.com/p/ceres-solver/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// * Neither the name of Google Inc. nor the names of its contributors may be
+// used to endorse or promote products derived from this software without
+// specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: mierle@gmail.com (Keir Mierle)
+//
+// An incomplete C API for Ceres.
+//
+// TODO(keir): Figure out why logging does not seem to work.
+
+#include <vector>
+#include <iostream> // XXX remove me
+#include "ceres/c_api.h"
+#include "ceres/cost_function.h"
+#include "ceres/problem.h"
+#include "ceres/solver.h"
+#include "ceres/types.h" // for std
+#include "glog/logging.h"
+
+using ceres::Problem;
+
+void ceres_init() {
+ // This is not ideal, but it's not clear what to do if there is no gflags and
+ // no access to command line arguments.
+ google::InitGoogleLogging("<unknown>");
+}
+
+ceres_problem_t* ceres_create_problem() {
+ return reinterpret_cast<ceres_problem_t*>(new Problem);
+}
+
+void ceres_free_problem(ceres_problem_t* problem) {
+ delete reinterpret_cast<Problem*>(problem);
+}
+
+class CallbackCostFunction : public ceres::CostFunction {
+ public:
+ CallbackCostFunction(ceres_cost_function_t cost_function,
+ void* user_data,
+ int num_residuals,
+ int num_parameter_blocks,
+ int* parameter_block_sizes)
+ : cost_function_(cost_function),
+ user_data_(user_data) {
+ set_num_residuals(num_residuals);
+ for (int i = 0; i < num_parameter_blocks; ++i) {
+ mutable_parameter_block_sizes()->push_back(parameter_block_sizes[i]);
+ }
+ }
+
+ virtual ~CallbackCostFunction() {}
+
+ virtual bool Evaluate(double const* const* parameters,
+ double* residuals,
+ double** jacobians) const {
+ return (*cost_function_)(user_data_,
+ const_cast<double**>(parameters),
+ residuals,
+ jacobians);
+ }
+
+ private:
+ ceres_cost_function_t cost_function_;
+ void* user_data_;
+};
+
+ceres_residual_block_id_t* ceres_problem_add_residual_block(
+ ceres_problem_t* problem,
+ ceres_cost_function_t cost_function,
+ ceres_loss_function_t loss_function,
+ void* user_data,
+ int num_residuals,
+ int num_parameter_blocks,
+ int* parameter_block_sizes,
+ double** parameters) {
+ Problem* ceres_problem = reinterpret_cast<Problem*>(problem);
+
+ ceres::CostFunction* callback_cost_function =
+ new CallbackCostFunction(cost_function,
+ user_data,
+ num_residuals,
+ num_parameter_blocks,
+ parameter_block_sizes);
+
+ std::vector<double*> parameter_blocks(parameters,
+ parameters + num_parameter_blocks);
+ return reinterpret_cast<ceres_residual_block_id_t*>(
+ ceres_problem->AddResidualBlock(callback_cost_function,
+ NULL, /* Ignore loss for now */
+ parameter_blocks));
+}
+
+void ceres_solve(ceres_problem_t* c_problem) {
+ Problem* problem = reinterpret_cast<Problem*>(c_problem);
+
+ // TODO(keir): Obviously, this way of setting options won't scale or last.
+ // Instead, figure out a way to specify some of the options without
+ // duplicating everything.
+ ceres::Solver::Options options;
+ options.max_num_iterations = 25;
+ options.linear_solver_type = ceres::DENSE_QR;
+ options.minimizer_progress_to_stdout = true;
+
+ ceres::Solver::Summary summary;
+ ceres::Solve(options, problem, &summary);
+ std::cout << summary.FullReport() << "\n";
+}