Fix Tukey loss function. Since the output of LossFunction::Evaluate is multiplied by 0.5, the current implementation of the Tukey loss function must be multiplied by 2. Change-Id: Ia94753eef1a375fe48cc2e0d2cc61350904c8248
diff --git a/include/ceres/loss_function.h b/include/ceres/loss_function.h index 1294100..7aabf7d 100644 --- a/include/ceres/loss_function.h +++ b/include/ceres/loss_function.h
@@ -186,7 +186,7 @@ // // rho(s) = 2 (sqrt(1 + s) - 1). // -// At s = 0: rho = [0, 1, -1/2]. +// At s = 0: rho = [0, 1, -1 / (2 * a^2)]. class CERES_EXPORT SoftLOneLoss : public LossFunction { public: explicit SoftLOneLoss(double a) : b_(a * a), c_(1 / b_) {} @@ -203,7 +203,7 @@ // // rho(s) = log(1 + s). // -// At s = 0: rho = [0, 1, -1]. +// At s = 0: rho = [0, 1, -1 / a^2]. class CERES_EXPORT CauchyLoss : public LossFunction { public: explicit CauchyLoss(double a) : b_(a * a), c_(1 / b_) {} @@ -276,12 +276,13 @@ // This is the Tukey biweight loss function which aggressively // attempts to suppress large errors. // -// The term is computed as: +// The term is computed as follows where the equations are scaled by a +// factor of 2 because the cost function is given by 1/2 rho(s): // -// rho(s) = a^2 / 6 * (1 - (1 - s / a^2)^3 ) for s <= a^2, -// rho(s) = a^2 / 6 for s > a^2. +// rho(s) = a^2 / 3 * (1 - (1 - s / a^2)^3 ) for s <= a^2, +// rho(s) = a^2 / 3 for s > a^2. // -// At s = 0: rho = [0, 0.5, -1 / a^2] +// At s = 0: rho = [0, 1, -2 / a^2] class CERES_EXPORT TukeyLoss : public ceres::LossFunction { public: explicit TukeyLoss(double a) : a_squared_(a * a) {}
diff --git a/internal/ceres/loss_function.cc b/internal/ceres/loss_function.cc index bf41b9e..5963d48 100644 --- a/internal/ceres/loss_function.cc +++ b/internal/ceres/loss_function.cc
@@ -120,12 +120,12 @@ // Inlier region. const double value = 1.0 - s / a_squared_; const double value_sq = value * value; - rho[0] = a_squared_ / 6.0 * (1.0 - value_sq * value); - rho[1] = 0.5 * value_sq; - rho[2] = -1.0 / a_squared_ * value; + rho[0] = a_squared_ / 3.0 * (1.0 - value_sq * value); + rho[1] = value_sq; + rho[2] = -2.0 / a_squared_ * value; } else { // Outlier region. - rho[0] = a_squared_ / 6.0; + rho[0] = a_squared_ / 3.0; rho[1] = 0.0; rho[2] = 0.0; }
diff --git a/internal/ceres/loss_function_test.cc b/internal/ceres/loss_function_test.cc index 406ace7..6302dbe 100644 --- a/internal/ceres/loss_function_test.cc +++ b/internal/ceres/loss_function_test.cc
@@ -81,6 +81,12 @@ TEST(LossFunction, TrivialLoss) { AssertLossFunctionIsValid(TrivialLoss(), 0.357); AssertLossFunctionIsValid(TrivialLoss(), 1.792); + // Check that at s = 0: rho = [0, 1, 0]. + double rho[3]; + TrivialLoss().Evaluate(0.0, rho); + ASSERT_NEAR(rho[0], 0.0, 1e-6); + ASSERT_NEAR(rho[1], 1.0, 1e-6); + ASSERT_NEAR(rho[2], 0.0, 1e-6); } TEST(LossFunction, HuberLoss) { @@ -88,6 +94,12 @@ AssertLossFunctionIsValid(HuberLoss(0.7), 1.792); AssertLossFunctionIsValid(HuberLoss(1.3), 0.357); AssertLossFunctionIsValid(HuberLoss(1.3), 1.792); + // Check that at s = 0: rho = [0, 1, 0]. + double rho[3]; + HuberLoss(0.7).Evaluate(0.0, rho); + ASSERT_NEAR(rho[0], 0.0, 1e-6); + ASSERT_NEAR(rho[1], 1.0, 1e-6); + ASSERT_NEAR(rho[2], 0.0, 1e-6); } TEST(LossFunction, SoftLOneLoss) { @@ -95,6 +107,12 @@ AssertLossFunctionIsValid(SoftLOneLoss(0.7), 1.792); AssertLossFunctionIsValid(SoftLOneLoss(1.3), 0.357); AssertLossFunctionIsValid(SoftLOneLoss(1.3), 1.792); + // Check that at s = 0: rho = [0, 1, -1 / (2 * a^2)]. + double rho[3]; + SoftLOneLoss(0.7).Evaluate(0.0, rho); + ASSERT_NEAR(rho[0], 0.0, 1e-6); + ASSERT_NEAR(rho[1], 1.0, 1e-6); + ASSERT_NEAR(rho[2], -0.5 / (0.7 * 0.7), 1e-6); } TEST(LossFunction, CauchyLoss) { @@ -102,6 +120,12 @@ AssertLossFunctionIsValid(CauchyLoss(0.7), 1.792); AssertLossFunctionIsValid(CauchyLoss(1.3), 0.357); AssertLossFunctionIsValid(CauchyLoss(1.3), 1.792); + // Check that at s = 0: rho = [0, 1, -1 / a^2]. + double rho[3]; + CauchyLoss(0.7).Evaluate(0.0, rho); + ASSERT_NEAR(rho[0], 0.0, 1e-6); + ASSERT_NEAR(rho[1], 1.0, 1e-6); + ASSERT_NEAR(rho[2], -1.0 / (0.7 * 0.7), 1e-6); } TEST(LossFunction, ArctanLoss) { @@ -109,6 +133,12 @@ AssertLossFunctionIsValid(ArctanLoss(0.7), 1.792); AssertLossFunctionIsValid(ArctanLoss(1.3), 0.357); AssertLossFunctionIsValid(ArctanLoss(1.3), 1.792); + // Check that at s = 0: rho = [0, 1, 0]. 
+ double rho[3]; + ArctanLoss(0.7).Evaluate(0.0, rho); + ASSERT_NEAR(rho[0], 0.0, 1e-6); + ASSERT_NEAR(rho[1], 1.0, 1e-6); + ASSERT_NEAR(rho[2], 0.0, 1e-6); } TEST(LossFunction, TolerantLoss) { @@ -135,6 +165,12 @@ AssertLossFunctionIsValid(TukeyLoss(0.7), 1.792); AssertLossFunctionIsValid(TukeyLoss(1.3), 0.357); AssertLossFunctionIsValid(TukeyLoss(1.3), 1.792); + // Check that at s = 0: rho = [0, 1, -2 / a^2]. + double rho[3]; + TukeyLoss(0.7).Evaluate(0.0, rho); + ASSERT_NEAR(rho[0], 0.0, 1e-6); + ASSERT_NEAR(rho[1], 1.0, 1e-6); + ASSERT_NEAR(rho[2], -2.0 / (0.7 * 0.7), 1e-6); } TEST(LossFunction, ComposedLoss) {