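Extend the C API to support loss functions

This extends the C API to support loss functions. Both user-supplied loss
functions and the stock Ceres loss functions (Huber, soft L1, Cauchy,
arctan, and tolerant) are supported. In addition, this adds a simple unit
test for the C API.

Supporting loss functions required changing the signature of
ceres_problem_add_residual_block(): the single user_data argument is
replaced by separate cost_function_data and loss_function_data pointers,
and the loss_function callback, which was previously ignored, is now
wrapped and passed to Problem::AddResidualBlock().

This also raises the hard-coded max_num_iterations used by the C API
solver from 25 to 100.

Change-Id: Iefa58cf709adbb8f24588e5eb6aed9aef46b6d73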
diff --git a/internal/ceres/c_api.cc b/internal/ceres/c_api.cc index 4d7d59b..02bc129 100644 --- a/internal/ceres/c_api.cc +++ b/internal/ceres/c_api.cc
@@ -32,10 +32,13 @@ // // TODO(keir): Figure out why logging does not seem to work. -#include <vector> -#include <iostream> // XXX remove me #include "ceres/c_api.h" + +#include <vector> +#include <iostream> +#include <string> #include "ceres/cost_function.h" +#include "ceres/loss_function.h" #include "ceres/problem.h" #include "ceres/solver.h" #include "ceres/types.h" // for std @@ -57,6 +60,8 @@ delete reinterpret_cast<Problem*>(problem); } +// Adapts a C cost callback (ceres_cost_function_t) and its opaque user_data +// pointer to ceres::CostFunction. class CallbackCostFunction : public ceres::CostFunction { public: CallbackCostFunction(ceres_cost_function_t cost_function, @@ -88,11 +93,56 @@ void* user_data_; }; +// Adapts a C loss callback (ceres_loss_function_t) and its opaque user_data +// pointer to ceres::LossFunction; Evaluate() forwards each call unchanged. +class CallbackLossFunction : public ceres::LossFunction { + public: + explicit CallbackLossFunction(ceres_loss_function_t loss_function, + void* user_data) + : loss_function_(loss_function), user_data_(user_data) {} + virtual void Evaluate(double sq_norm, double* rho) const { + (*loss_function_)(user_data_, sq_norm, rho); + } + + private: + ceres_loss_function_t loss_function_; + void* user_data_; +}; + +// Stock loss functions as opaque handles (always a ceres::LossFunction*). 
+void* ceres_create_huber_loss_function_data(double a) { + return static_cast<ceres::LossFunction*>(new ceres::HuberLoss(a)); +} +void* ceres_create_softl1_loss_function_data(double a) { + return static_cast<ceres::LossFunction*>(new ceres::SoftLOneLoss(a)); +} +void* ceres_create_cauchy_loss_function_data(double a) { + return static_cast<ceres::LossFunction*>(new ceres::CauchyLoss(a)); +} +void* ceres_create_arctan_loss_function_data(double a) { + return static_cast<ceres::LossFunction*>(new ceres::ArctanLoss(a)); +} +void* ceres_create_tolerant_loss_function_data(double a, double b) { + return static_cast<ceres::LossFunction*>(new ceres::TolerantLoss(a, b)); +} + +void ceres_free_stock_loss_function_data(void* loss_function_data) { + delete static_cast<ceres::LossFunction*>(loss_function_data); +} + +void ceres_stock_loss_function(void* user_data, + double squared_norm, + double out[3]) { + static_cast<const ceres::LossFunction*>(user_data) + ->Evaluate(squared_norm, out); +} + ceres_residual_block_id_t* ceres_problem_add_residual_block( ceres_problem_t* problem, ceres_cost_function_t cost_function, + void* cost_function_data, ceres_loss_function_t loss_function, - void* user_data, + void* loss_function_data, int num_residuals, int num_parameter_blocks, int* parameter_block_sizes, @@ -101,16 +151,22 @@ ceres::CostFunction* callback_cost_function = new CallbackCostFunction(cost_function, - user_data, + cost_function_data, num_residuals, num_parameter_blocks, parameter_block_sizes); + ceres::LossFunction* callback_loss_function = NULL; + if (loss_function != NULL) { + callback_loss_function = new CallbackLossFunction(loss_function, + loss_function_data); + } + std::vector<double*> parameter_blocks(parameters, parameters + num_parameter_blocks); return reinterpret_cast<ceres_residual_block_id_t*>( ceres_problem->AddResidualBlock(callback_cost_function, - NULL, /* Ignore loss for now */ + callback_loss_function, parameter_blocks)); } @@ -121,7 +177,7 @@ // Instead, figure out a way to specify some of the options without // duplicating everything.
ceres::Solver::Options options; - options.max_num_iterations = 25; + options.max_num_iterations = 100; options.linear_solver_type = ceres::DENSE_QR; options.minimizer_progress_to_stdout = true;