Fix Tukey loss function
Since the output of LossFunction::Evaluate is multiplied by 0.5, the
current implementation of the Tukey loss function must be multiplied
by 2.
Change-Id: Ia94753eef1a375fe48cc2e0d2cc61350904c8248
diff --git a/include/ceres/loss_function.h b/include/ceres/loss_function.h
index 1294100..7aabf7d 100644
--- a/include/ceres/loss_function.h
+++ b/include/ceres/loss_function.h
@@ -186,7 +186,7 @@
//
// rho(s) = 2 (sqrt(1 + s) - 1).
//
-// At s = 0: rho = [0, 1, -1/2].
+// At s = 0: rho = [0, 1, -1 / (2 * a^2)].
class CERES_EXPORT SoftLOneLoss : public LossFunction {
public:
explicit SoftLOneLoss(double a) : b_(a * a), c_(1 / b_) {}
@@ -203,7 +203,7 @@
//
// rho(s) = log(1 + s).
//
-// At s = 0: rho = [0, 1, -1].
+// At s = 0: rho = [0, 1, -1 / a^2].
class CERES_EXPORT CauchyLoss : public LossFunction {
public:
explicit CauchyLoss(double a) : b_(a * a), c_(1 / b_) {}
@@ -276,12 +276,13 @@
// This is the Tukey biweight loss function which aggressively
// attempts to suppress large errors.
//
-// The term is computed as:
+// The term is computed as follows where the equations are scaled by a
+// factor of 2 because the cost function is given by 1/2 rho(s):
//
-// rho(s) = a^2 / 6 * (1 - (1 - s / a^2)^3 ) for s <= a^2,
-// rho(s) = a^2 / 6 for s > a^2.
+// rho(s) = a^2 / 3 * (1 - (1 - s / a^2)^3 ) for s <= a^2,
+// rho(s) = a^2 / 3 for s > a^2.
//
-// At s = 0: rho = [0, 0.5, -1 / a^2]
+// At s = 0: rho = [0, 1, -2 / a^2]
class CERES_EXPORT TukeyLoss : public ceres::LossFunction {
public:
explicit TukeyLoss(double a) : a_squared_(a * a) {}
diff --git a/internal/ceres/loss_function.cc b/internal/ceres/loss_function.cc
index bf41b9e..5963d48 100644
--- a/internal/ceres/loss_function.cc
+++ b/internal/ceres/loss_function.cc
@@ -120,12 +120,12 @@
// Inlier region.
const double value = 1.0 - s / a_squared_;
const double value_sq = value * value;
- rho[0] = a_squared_ / 6.0 * (1.0 - value_sq * value);
- rho[1] = 0.5 * value_sq;
- rho[2] = -1.0 / a_squared_ * value;
+ rho[0] = a_squared_ / 3.0 * (1.0 - value_sq * value);
+ rho[1] = value_sq;
+ rho[2] = -2.0 / a_squared_ * value;
} else {
// Outlier region.
- rho[0] = a_squared_ / 6.0;
+ rho[0] = a_squared_ / 3.0;
rho[1] = 0.0;
rho[2] = 0.0;
}
diff --git a/internal/ceres/loss_function_test.cc b/internal/ceres/loss_function_test.cc
index 406ace7..6302dbe 100644
--- a/internal/ceres/loss_function_test.cc
+++ b/internal/ceres/loss_function_test.cc
@@ -81,6 +81,12 @@
TEST(LossFunction, TrivialLoss) {
AssertLossFunctionIsValid(TrivialLoss(), 0.357);
AssertLossFunctionIsValid(TrivialLoss(), 1.792);
+ // Check that at s = 0: rho = [0, 1, 0].
+ double rho[3];
+ TrivialLoss().Evaluate(0.0, rho);
+ ASSERT_NEAR(rho[0], 0.0, 1e-6);
+ ASSERT_NEAR(rho[1], 1.0, 1e-6);
+ ASSERT_NEAR(rho[2], 0.0, 1e-6);
}
TEST(LossFunction, HuberLoss) {
@@ -88,6 +94,12 @@
AssertLossFunctionIsValid(HuberLoss(0.7), 1.792);
AssertLossFunctionIsValid(HuberLoss(1.3), 0.357);
AssertLossFunctionIsValid(HuberLoss(1.3), 1.792);
+ // Check that at s = 0: rho = [0, 1, 0].
+ double rho[3];
+ HuberLoss(0.7).Evaluate(0.0, rho);
+ ASSERT_NEAR(rho[0], 0.0, 1e-6);
+ ASSERT_NEAR(rho[1], 1.0, 1e-6);
+ ASSERT_NEAR(rho[2], 0.0, 1e-6);
}
TEST(LossFunction, SoftLOneLoss) {
@@ -95,6 +107,12 @@
AssertLossFunctionIsValid(SoftLOneLoss(0.7), 1.792);
AssertLossFunctionIsValid(SoftLOneLoss(1.3), 0.357);
AssertLossFunctionIsValid(SoftLOneLoss(1.3), 1.792);
+ // Check that at s = 0: rho = [0, 1, -1 / (2 * a^2)].
+ double rho[3];
+ SoftLOneLoss(0.7).Evaluate(0.0, rho);
+ ASSERT_NEAR(rho[0], 0.0, 1e-6);
+ ASSERT_NEAR(rho[1], 1.0, 1e-6);
+ ASSERT_NEAR(rho[2], -0.5 / (0.7 * 0.7), 1e-6);
}
TEST(LossFunction, CauchyLoss) {
@@ -102,6 +120,12 @@
AssertLossFunctionIsValid(CauchyLoss(0.7), 1.792);
AssertLossFunctionIsValid(CauchyLoss(1.3), 0.357);
AssertLossFunctionIsValid(CauchyLoss(1.3), 1.792);
+ // Check that at s = 0: rho = [0, 1, -1 / a^2].
+ double rho[3];
+ CauchyLoss(0.7).Evaluate(0.0, rho);
+ ASSERT_NEAR(rho[0], 0.0, 1e-6);
+ ASSERT_NEAR(rho[1], 1.0, 1e-6);
+ ASSERT_NEAR(rho[2], -1.0 / (0.7 * 0.7), 1e-6);
}
TEST(LossFunction, ArctanLoss) {
@@ -109,6 +133,12 @@
AssertLossFunctionIsValid(ArctanLoss(0.7), 1.792);
AssertLossFunctionIsValid(ArctanLoss(1.3), 0.357);
AssertLossFunctionIsValid(ArctanLoss(1.3), 1.792);
+ // Check that at s = 0: rho = [0, 1, 0].
+ double rho[3];
+ ArctanLoss(0.7).Evaluate(0.0, rho);
+ ASSERT_NEAR(rho[0], 0.0, 1e-6);
+ ASSERT_NEAR(rho[1], 1.0, 1e-6);
+ ASSERT_NEAR(rho[2], 0.0, 1e-6);
}
TEST(LossFunction, TolerantLoss) {
@@ -135,6 +165,12 @@
AssertLossFunctionIsValid(TukeyLoss(0.7), 1.792);
AssertLossFunctionIsValid(TukeyLoss(1.3), 0.357);
AssertLossFunctionIsValid(TukeyLoss(1.3), 1.792);
+ // Check that at s = 0: rho = [0, 1, -2 / a^2].
+ double rho[3];
+ TukeyLoss(0.7).Evaluate(0.0, rho);
+ ASSERT_NEAR(rho[0], 0.0, 1e-6);
+ ASSERT_NEAR(rho[1], 1.0, 1e-6);
+ ASSERT_NEAR(rho[2], -2.0 / (0.7 * 0.7), 1e-6);
}
TEST(LossFunction, ComposedLoss) {