Extend the C API to support loss functions
This extends the C API to support loss functions. Both user-supplied
loss functions and the stock Ceres loss functions (Cauchy, Huber,
etc.) are supported. In addition, this adds a simple unit test for
the C API.
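Supporting loss functions required changing the signature of the
ceres_problem_add_residual_block() function to also take a thunk
(function pointer plus user data) for the loss function; the single
user_data argument is split into cost_function_data and
loss_function_data.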
Change-Id: Iefa58cf709adbb8f24588e5eb6aed9aef46b6d73
diff --git a/internal/ceres/c_api.cc b/internal/ceres/c_api.cc
index 4d7d59b..02bc129 100644
--- a/internal/ceres/c_api.cc
+++ b/internal/ceres/c_api.cc
@@ -32,10 +32,13 @@
//
// TODO(keir): Figure out why logging does not seem to work.
-#include <vector>
-#include <iostream> // XXX remove me
#include "ceres/c_api.h"
+
+#include <vector>
+#include <iostream>
+#include <string>
#include "ceres/cost_function.h"
+#include "ceres/loss_function.h"
#include "ceres/problem.h"
#include "ceres/solver.h"
#include "ceres/types.h" // for std
@@ -57,6 +60,8 @@
delete reinterpret_cast<Problem*>(problem);
}
+// This cost function wraps a C-level function pointer from the user, to bridge
+// between C and C++.
class CallbackCostFunction : public ceres::CostFunction {
public:
CallbackCostFunction(ceres_cost_function_t cost_function,
@@ -88,11 +93,56 @@
void* user_data_;
};
+// This loss function wraps a C-level function pointer and its user data, to
+// bridge between C and C++. The class owns neither the pointer nor the data.
+class CallbackLossFunction : public ceres::LossFunction {
+ public:
+ explicit CallbackLossFunction(ceres_loss_function_t loss_function,
+ void* user_data)
+ : loss_function_(loss_function), user_data_(user_data) {}
+ virtual void Evaluate(double sq_norm, double* rho) const {
+ (*loss_function_)(user_data_, sq_norm, rho);  // Forward verbatim to C.
+ }
+
+ private:
+ ceres_loss_function_t loss_function_;
+ void* user_data_;  // Opaque; only ever handed back to loss_function_.
+};
+
+// Stock loss wrappers. Each returns a ceres::LossFunction* stored in a void*.
+void* ceres_create_huber_loss_function_data(double a) {
+ return static_cast<ceres::LossFunction*>(new ceres::HuberLoss(a));
+}
+void* ceres_create_softl1_loss_function_data(double a) {
+ return static_cast<ceres::LossFunction*>(new ceres::SoftLOneLoss(a));
+}
+void* ceres_create_cauchy_loss_function_data(double a) {
+ return static_cast<ceres::LossFunction*>(new ceres::CauchyLoss(a));
+}
+void* ceres_create_arctan_loss_function_data(double a) {
+ return static_cast<ceres::LossFunction*>(new ceres::ArctanLoss(a));
+}
+void* ceres_create_tolerant_loss_function_data(double a, double b) {
+ return static_cast<ceres::LossFunction*>(new ceres::TolerantLoss(a, b));
+}
+// Frees data returned by any ceres_create_*_loss_function_data() above.
+void ceres_free_stock_loss_function_data(void* loss_function_data) {
+ delete static_cast<ceres::LossFunction*>(loss_function_data);
+}
+// C callback; user_data must come from ceres_create_*_loss_function_data().
+void ceres_stock_loss_function(void* user_data,
+ double squared_norm,
+ double out[3]) {
+ static_cast<ceres::LossFunction*>(user_data)
+ ->Evaluate(squared_norm, out);
+}
+
ceres_residual_block_id_t* ceres_problem_add_residual_block(
ceres_problem_t* problem,
ceres_cost_function_t cost_function,
+ void* cost_function_data,
ceres_loss_function_t loss_function,
- void* user_data,
+ void* loss_function_data,
int num_residuals,
int num_parameter_blocks,
int* parameter_block_sizes,
@@ -101,16 +151,22 @@
ceres::CostFunction* callback_cost_function =
new CallbackCostFunction(cost_function,
- user_data,
+ cost_function_data,
num_residuals,
num_parameter_blocks,
parameter_block_sizes);
+ ceres::LossFunction* callback_loss_function = NULL;
+ if (loss_function != NULL) {
+ callback_loss_function = new CallbackLossFunction(loss_function,
+ loss_function_data);
+ }
+
std::vector<double*> parameter_blocks(parameters,
parameters + num_parameter_blocks);
return reinterpret_cast<ceres_residual_block_id_t*>(
ceres_problem->AddResidualBlock(callback_cost_function,
- NULL, /* Ignore loss for now */
+ callback_loss_function,
parameter_blocks));
}
@@ -121,7 +177,7 @@
// Instead, figure out a way to specify some of the options without
// duplicating everything.
ceres::Solver::Options options;
- options.max_num_iterations = 25;
+ options.max_num_iterations = 100;
options.linear_solver_type = ceres::DENSE_QR;
options.minimizer_progress_to_stdout = true;