Extend the C API to support loss functions This extends the C API to support loss functions. Both user-supplied cost functions as well as the stock Ceres cost functions (Cauchy, Huber, etc) are supported. In addition, this adds a simple unit test for the C API. Supporting loss functions required changing the signature of the ceres_add_residual_block() function to also take a thunk for the loss function. Change-Id: Iefa58cf709adbb8f24588e5eb6aed9aef46b6d73
diff --git a/examples/curve_fitting.c b/examples/curve_fitting.c index 5e3cd16..1d1ec9b 100644 --- a/examples/curve_fitting.c +++ b/examples/curve_fitting.c
@@ -36,15 +36,18 @@ #include <string.h> // For NULL #include "ceres/c_api.h" -// Data generated using the following octave code. -// randn('seed', 23497); -// m = 0.3; -// c = 0.1; -// x=[0:0.075:5]; -// y = exp(m * x + c); -// noise = randn(size(x)) * 0.2; -// y_observed = y + noise; -// data = [x', y_observed']; +/* Data generated using the following octave code. + * + * randn('seed', 23497); + * m = 0.3; + * c = 0.1; + * x=[0:0.075:5]; + * y = exp(m * x + c); + * noise = randn(size(x)) * 0.2; + * y_observed = y + noise; + * data = [x', y_observed']; + * + */ int num_observations = 67; double data[] = { @@ -135,7 +138,7 @@ return 1; } if (jacobians[0] != NULL) { - jacobians[0][0] = - m * exp(m * x + c); /* dr/dm */ + jacobians[0][0] = - x * exp(m * x + c); /* dr/dm */ } if (jacobians[1] != NULL) { jacobians[1][0] = - exp(m * x + c); /* dr/dc */ @@ -154,17 +157,20 @@ int parameter_sizes[] = { 1, 1 }; ceres_problem_t* problem; - int i; - ceres_init(argc, argv); + /* Ceres has some internal stuff that needs to get initialized. */ + ceres_init(); problem = ceres_create_problem(); - for (i = 0; i < num_observations; ++i) { + + /* Add all the residuals. */ + for (int i = 0; i < num_observations; ++i) { ceres_problem_add_residual_block( problem, exponential_residual, /* Cost function */ - NULL, /* No loss function */ &data[2 * i], /* Points to the (x,y) measurement */ + NULL, /* No loss function */ + NULL, /* No loss function user data */ 1, /* Number of residuals */ 2, /* Number of parameter blocks */ parameter_sizes, @@ -172,6 +178,7 @@ } ceres_solve(problem); + ceres_free_problem(problem); printf("Initial m: 0.0, c: 0.0\n"); printf("Final m: %g, c: %g\n", m, c);
diff --git a/include/ceres/c_api.h b/include/ceres/c_api.h index 8d74a27..8b4eaf5 100644 --- a/include/ceres/c_api.h +++ b/include/ceres/c_api.h
@@ -55,9 +55,55 @@ double** jacobians); /* Equivalent to LossFunction::Evaluate() from the C++ API. */ -typedef int (*ceres_loss_function_t)(void* user_data, - double squared_norm, - double out[3]); +typedef void (*ceres_loss_function_t)(void* user_data, + double squared_norm, + double out[3]); + +/* Create callback data for Ceres' stock loss functions. + * + * Ceres has several loss functions available by default, and these functions + * expose those to the C API. To use the stock loss functions, call + * ceres_create_*_loss_data(), which internally creates an instance of one of + * the stock loss functions (for example ceres::CauchyLoss), and pass the + * returned "loss_function_data" along with the ceres_stock_loss_function to + * ceres_add_residual_block(). + * + * For example: + * + * void* cauchy_loss_function_data = + * ceres_create_cauchy_loss_function_data(1.2, 0.0); + * ceres_problem_add_residual_block( + * problem, + * my_cost_function, + * my_cost_function_data, + * ceres_stock_loss_function, + * cauchy_loss_function_data, + * 1, + * 2, + * parameter_sizes, + * parameter_pointers); + * ... + * ceres_free_stock_loss_function_data(cauchy_loss_function_data); + * + * See loss_function.h for the details of each loss function. + */ +void* ceres_create_huber_loss_function_data(double a); +void* ceres_create_softl1_loss_function_data(double a); +void* ceres_create_cauchy_loss_function_data(double a); +void* ceres_create_arctan_loss_function_data(double a); +void* ceres_create_tolerant_loss_function_data(double a, double b); + +/* Free the given stock loss function data. */ +void ceres_free_stock_loss_function_data(void* loss_function_data); + +/* This is an implementation of ceres_loss_function_t contained within Ceres + * itself, intended as a way to access the various stock Ceres loss functions + * from the C API. This should be passed to ceres_add_residual() below, in + * combination with a user_data pointer generated by + * ceres_create_stock_loss_function() above. */ +void ceres_stock_loss_function(void* user_data, + double squared_norm, + double out[3]); /* Equivalent to Problem from the C++ API. */ struct ceres_problem_s; @@ -72,12 +118,12 @@ void ceres_free_problem(ceres_problem_t* problem); /* Add a residual block. */ -/* TODO(keir): Add support for loss functions */ ceres_residual_block_id_t* ceres_problem_add_residual_block( ceres_problem_t* problem, ceres_cost_function_t cost_function, + void* cost_function_data, ceres_loss_function_t loss_function, - void* user_data, + void* loss_function_data, int num_residuals, int num_parameter_blocks, int* parameter_block_sizes,
diff --git a/internal/ceres/CMakeLists.txt b/internal/ceres/CMakeLists.txt index 3b8b2f0..dfa567c 100644 --- a/internal/ceres/CMakeLists.txt +++ b/internal/ceres/CMakeLists.txt
@@ -240,6 +240,7 @@ CERES_TEST(block_random_access_dense_matrix) CERES_TEST(block_random_access_sparse_matrix) CERES_TEST(block_sparse_matrix) + CERES_TEST(c_api) CERES_TEST(canonical_views_clustering) CERES_TEST(compressed_row_sparse_matrix) CERES_TEST(conditioned_cost_function)
diff --git a/internal/ceres/c_api.cc b/internal/ceres/c_api.cc index 4d7d59b..02bc129 100644 --- a/internal/ceres/c_api.cc +++ b/internal/ceres/c_api.cc
@@ -32,10 +32,13 @@ // // TODO(keir): Figure out why logging does not seem to work. -#include <vector> -#include <iostream> // XXX remove me #include "ceres/c_api.h" + +#include <vector> +#include <iostream> +#include <string> #include "ceres/cost_function.h" +#include "ceres/loss_function.h" #include "ceres/problem.h" #include "ceres/solver.h" #include "ceres/types.h" // for std @@ -57,6 +60,8 @@ delete reinterpret_cast<Problem*>(problem); } +// This cost function wraps a C-level function pointer from the user, to bridge +// between C and C++. class CallbackCostFunction : public ceres::CostFunction { public: CallbackCostFunction(ceres_cost_function_t cost_function, @@ -88,11 +93,56 @@ void* user_data_; }; +// This loss function wraps a C-level function pointer from the user, to bridge +// between C and C++. +class CallbackLossFunction : public ceres::LossFunction { + public: + explicit CallbackLossFunction(ceres_loss_function_t loss_function, + void* user_data) + : loss_function_(loss_function), user_data_(user_data) {} + virtual void Evaluate(double sq_norm, double* rho) const { + (*loss_function_)(user_data_, sq_norm, rho); + } + + private: + ceres_loss_function_t loss_function_; + void* user_data_; +}; + +// Wrappers for the stock loss functions. +void* ceres_create_huber_loss_function_data(double a) { + return new ceres::HuberLoss(a); +} +void* ceres_create_softl1_loss_function_data(double a) { + return new ceres::SoftLOneLoss(a); +} +void* ceres_create_cauchy_loss_function_data(double a) { + return new ceres::CauchyLoss(a); +} +void* ceres_create_arctan_loss_function_data(double a) { + return new ceres::ArctanLoss(a); +} +void* ceres_create_tolerant_loss_function_data(double a, double b) { + return new ceres::TolerantLoss(a, b); +} + +void ceres_free_stock_loss_function_data(void* loss_function_data) { + delete reinterpret_cast<ceres::LossFunction*>(loss_function_data); +} + +void ceres_stock_loss_function(void* user_data, + double squared_norm, + double out[3]) { + reinterpret_cast<ceres::LossFunction*>(user_data) + ->Evaluate(squared_norm, out); +} + ceres_residual_block_id_t* ceres_problem_add_residual_block( ceres_problem_t* problem, ceres_cost_function_t cost_function, + void* cost_function_data, ceres_loss_function_t loss_function, - void* user_data, + void* loss_function_data, int num_residuals, int num_parameter_blocks, int* parameter_block_sizes, @@ -101,16 +151,22 @@ ceres::CostFunction* callback_cost_function = new CallbackCostFunction(cost_function, - user_data, + cost_function_data, num_residuals, num_parameter_blocks, parameter_block_sizes); + ceres::LossFunction* callback_loss_function = NULL; + if (loss_function != NULL) { + callback_loss_function = new CallbackLossFunction(loss_function, + loss_function_data); + } + std::vector<double*> parameter_blocks(parameters, parameters + num_parameter_blocks); return reinterpret_cast<ceres_residual_block_id_t*>( ceres_problem->AddResidualBlock(callback_cost_function, - NULL, /* Ignore loss for now */ + callback_loss_function, parameter_blocks)); } @@ -121,7 +177,7 @@ // Instead, figure out a way to specify some of the options without // duplicating everything. ceres::Solver::Options options; - options.max_num_iterations = 25; + options.max_num_iterations = 100; options.linear_solver_type = ceres::DENSE_QR; options.minimizer_progress_to_stdout = true;
diff --git a/internal/ceres/c_api_test.cc b/internal/ceres/c_api_test.cc new file mode 100644 index 0000000..c6bfb37 --- /dev/null +++ b/internal/ceres/c_api_test.cc
@@ -0,0 +1,221 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2013 Google Inc. All rights reserved. +// http://code.google.com/p/ceres-solver/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: mierle@gmail.com (Keir Mierle) + +#include "ceres/c_api.h" + +#include <cmath> + +#include "glog/logging.h" +#include "gtest/gtest.h" + +// Duplicated from curve_fitting.cc. +int num_observations = 67; +double data[] = { + 0.000000e+00, 1.133898e+00, + 7.500000e-02, 1.334902e+00, + 1.500000e-01, 1.213546e+00, + 2.250000e-01, 1.252016e+00, + 3.000000e-01, 1.392265e+00, + 3.750000e-01, 1.314458e+00, + 4.500000e-01, 1.472541e+00, + 5.250000e-01, 1.536218e+00, + 6.000000e-01, 1.355679e+00, + 6.750000e-01, 1.463566e+00, + 7.500000e-01, 1.490201e+00, + 8.250000e-01, 1.658699e+00, + 9.000000e-01, 1.067574e+00, + 9.750000e-01, 1.464629e+00, + 1.050000e+00, 1.402653e+00, + 1.125000e+00, 1.713141e+00, + 1.200000e+00, 1.527021e+00, + 1.275000e+00, 1.702632e+00, + 1.350000e+00, 1.423899e+00, + 1.425000e+00, 1.543078e+00, + 1.500000e+00, 1.664015e+00, + 1.575000e+00, 1.732484e+00, + 1.650000e+00, 1.543296e+00, + 1.725000e+00, 1.959523e+00, + 1.800000e+00, 1.685132e+00, + 1.875000e+00, 1.951791e+00, + 1.950000e+00, 2.095346e+00, + 2.025000e+00, 2.361460e+00, + 2.100000e+00, 2.169119e+00, + 2.175000e+00, 2.061745e+00, + 2.250000e+00, 2.178641e+00, + 2.325000e+00, 2.104346e+00, + 2.400000e+00, 2.584470e+00, + 2.475000e+00, 1.914158e+00, + 2.550000e+00, 2.368375e+00, + 2.625000e+00, 2.686125e+00, + 2.700000e+00, 2.712395e+00, + 2.775000e+00, 2.499511e+00, + 2.850000e+00, 2.558897e+00, + 2.925000e+00, 2.309154e+00, + 3.000000e+00, 2.869503e+00, + 3.075000e+00, 3.116645e+00, + 3.150000e+00, 3.094907e+00, + 3.225000e+00, 2.471759e+00, + 3.300000e+00, 3.017131e+00, + 3.375000e+00, 3.232381e+00, + 3.450000e+00, 2.944596e+00, + 3.525000e+00, 3.385343e+00, + 3.600000e+00, 3.199826e+00, + 3.675000e+00, 3.423039e+00, + 3.750000e+00, 3.621552e+00, + 3.825000e+00, 3.559255e+00, + 3.900000e+00, 3.530713e+00, + 3.975000e+00, 3.561766e+00, + 4.050000e+00, 3.544574e+00, + 4.125000e+00, 3.867945e+00, + 4.200000e+00, 4.049776e+00, + 4.275000e+00, 3.885601e+00, + 4.350000e+00, 4.110505e+00, + 4.425000e+00, 4.345320e+00, + 4.500000e+00, 4.161241e+00, + 4.575000e+00, 4.363407e+00, + 4.650000e+00, 4.161576e+00, + 4.725000e+00, 4.619728e+00, + 4.800000e+00, 4.737410e+00, + 4.875000e+00, 4.727863e+00, + 4.950000e+00, 4.669206e+00, +}; + +// A test cost function, similar to the one in curve_fitting.c. +int exponential_residual(void* user_data, + double** parameters, + double* residuals, + double** jacobians) { + double* measurement = (double*) user_data; + double x = measurement[0]; + double y = measurement[1]; + double m = parameters[0][0]; + double c = parameters[1][0]; + + residuals[0] = y - exp(m * x + c); + if (jacobians == NULL) { + return 1; + } + if (jacobians[0] != NULL) { + jacobians[0][0] = - x * exp(m * x + c); // dr/dm + } + if (jacobians[1] != NULL) { + jacobians[1][0] = - exp(m * x + c); // dr/dc + } + return 1; +} + +namespace ceres { +namespace internal { + +TEST(C_API, SimpleEndToEndTest) { + double m = 0.0; + double c = 0.0; + double *parameter_pointers[] = { &m, &c }; + int parameter_sizes[] = { 1, 1 }; + + ceres_problem_t* problem = ceres_create_problem(); + for (int i = 0; i < num_observations; ++i) { + ceres_problem_add_residual_block( + problem, + exponential_residual, // Cost function + &data[2 * i], // Points to the (x,y) measurement + NULL, // Loss function + NULL, // Loss function user data + 1, // Number of residuals + 2, // Number of parameter blocks + parameter_sizes, + parameter_pointers); + } + + ceres_solve(problem); + + EXPECT_NEAR(0.3, m, 0.02); + EXPECT_NEAR(0.1, c, 0.04); + + ceres_free_problem(problem); +} + +template<typename T> +class ScopedSetValue { + public: + ScopedSetValue(T* variable, T new_value) + : variable_(variable), old_value_(*variable) { + *variable = new_value; + } + ~ScopedSetValue() { + *variable_ = old_value_; + } + + private: + T* variable_; + T old_value_; +}; + +TEST(C_API, LossFunctions) { + double m = 0.2; + double c = 0.03; + double *parameter_pointers[] = { &m, &c }; + int parameter_sizes[] = { 1, 1 }; + + // Create two outliers, but be careful to leave the data intact. + ScopedSetValue<double> outlier1x(&data[12], 2.5); + ScopedSetValue<double> outlier1y(&data[13], 1.0e3); + ScopedSetValue<double> outlier2x(&data[14], 3.2); + ScopedSetValue<double> outlier2y(&data[15], 30e3); + + // Create a cauchy cost function, and reuse it many times. + void* cauchy_loss_data = + ceres_create_cauchy_loss_function_data(5.0); + + ceres_problem_t* problem = ceres_create_problem(); + for (int i = 0; i < num_observations; ++i) { + ceres_problem_add_residual_block( + problem, + exponential_residual, // Cost function + &data[2 * i], // Points to the (x,y) measurement + ceres_stock_loss_function, + cauchy_loss_data, // Loss function user data + 1, // Number of residuals + 2, // Number of parameter blocks + parameter_sizes, + parameter_pointers); + } + + ceres_solve(problem); + + EXPECT_NEAR(0.3, m, 0.02); + EXPECT_NEAR(0.1, c, 0.04); + + ceres_free_stock_loss_function_data(cauchy_loss_data); + ceres_free_problem(problem); +} + +} // namespace internal +} // namespace ceres