From f3e33bac04dd306028b78ab35965b97f56080b52 Mon Sep 17 00:00:00 2001 From: Jesse Mapel <jmapel@usgs.gov> Date: Thu, 16 Apr 2020 06:52:35 -0700 Subject: [PATCH] Added root finding algorithm (#277) --- include/usgscsm/Utilities.h | 8 +++++ src/Utilities.cpp | 69 +++++++++++++++++++++++++++++++++++++ tests/UtilitiesTests.cpp | 11 ++++++ 3 files changed, 88 insertions(+) diff --git a/include/usgscsm/Utilities.h b/include/usgscsm/Utilities.h index 8110e3e..0f32944 100644 --- a/include/usgscsm/Utilities.h +++ b/include/usgscsm/Utilities.h @@ -69,6 +69,14 @@ void lagrangeInterp ( const int& i_order, double* valueVector); +// Brent's algorithm for finding the roots of a function +// Arguments are two inputs that bracket a root, the function, and a convergence tolerance +double brentRoot( + double lowerBound, + double upperBound, + double (*func)(double), + double epsilon = 1e-10); + // Methods for checking/accessing the ISD double metric_conversion(double val, std::string from, std::string to="m"); diff --git a/src/Utilities.cpp b/src/Utilities.cpp index 39ce6ed..cf2f404 100644 --- a/src/Utilities.cpp +++ b/src/Utilities.cpp @@ -4,6 +4,7 @@ #include <Error.h> #include <stack> #include <utility> +#include <stdexcept> using json = nlohmann::json; @@ -276,6 +277,74 @@ void lagrangeInterp( } } +double brentRoot( + double lowerBound, + double upperBound, + double (*func)(double), + double epsilon) { + double counterPoint = lowerBound; + double currentPoint = upperBound; + double counterFunc = func(counterPoint); + double currentFunc = func(currentPoint); + if (counterFunc * currentFunc > 0.0) { + throw std::invalid_argument("Function values at the boundaries have the same sign."); + } + if (fabs(counterFunc) < fabs(currentFunc)) { + std::swap(counterPoint, currentPoint); + std::swap(counterFunc, currentFunc); + } + + double previousPoint = counterPoint; + double previousFunc = counterFunc; + double evenOlderPoint = previousPoint; + double nextPoint; + double nextFunc; + int iteration = 0; + bool bisected = true; + + do { + // Inverse quadratic interpolation + if (counterFunc != previousFunc && counterFunc != currentFunc) { + nextPoint = (counterPoint * currentFunc * previousFunc) / ((counterFunc - currentFunc) * (counterFunc - previousFunc)); + nextPoint += (currentPoint * counterFunc * previousFunc) / ((currentFunc - counterFunc) * (currentFunc - previousFunc)); + nextPoint += (previousPoint * currentFunc * counterFunc) / ((previousFunc - counterFunc) * (previousFunc - currentFunc)); + } + // Secant method + else { + nextPoint = currentPoint - currentFunc * (currentPoint - counterPoint) / (currentFunc - counterFunc); + } + + // Bisection method + if (((currentPoint - nextPoint) * (nextPoint - (3 * counterPoint + currentPoint) / 4) < 0) || + (bisected && fabs(nextPoint - currentPoint) >= fabs(currentPoint - previousPoint) / 2) || + (!bisected && fabs(nextPoint - currentPoint) >= fabs(previousPoint - evenOlderPoint) / 2) || + (bisected && fabs(currentPoint - previousPoint) < epsilon) || + (!bisected && fabs(previousPoint - evenOlderPoint) < epsilon)) { + nextPoint = (currentPoint + counterPoint) / 2; + bisected = true; + } + else { + bisected = false; + } + + // Setup for next iteration + evenOlderPoint = previousPoint; + previousPoint = currentPoint; + previousFunc = currentFunc; + nextFunc = func(nextPoint); + if (counterFunc * nextFunc < 0) { + currentPoint = nextPoint; + currentFunc = nextFunc; + } + else { + counterPoint = nextPoint; + counterFunc = nextFunc; + } + } while (++iteration < 30 && fabs(counterPoint - currentPoint) > epsilon); + + return nextPoint; + } + // convert a measurement double metric_conversion(double val, std::string from, std::string to) { json typemap = { diff --git a/tests/UtilitiesTests.cpp b/tests/UtilitiesTests.cpp index f837808..f1918bd 100644 --- a/tests/UtilitiesTests.cpp +++ b/tests/UtilitiesTests.cpp @@ -8,6 +8,7 @@ #include <gtest/gtest.h> #include <math.h> +#include <stdexcept> using json = nlohmann::json; @@ -337,3 +338,13 @@ TEST(UtilitiesTests, lagrangeInterp2D) { EXPECT_DOUBLE_EQ(outputValue[0], 0.5); EXPECT_DOUBLE_EQ(outputValue[1], 1.5); } + +double testPoly(double x) { + return (x - 2) * (x + 1) * (x + 7); +}; + +TEST(UtilitiesTests, brentRoot) { + EXPECT_NEAR(brentRoot(1.0, 3.0, testPoly, 1e-10), 2.0, 1e-10); + EXPECT_NEAR(brentRoot(0.0, -3.0, testPoly, 1e-10), -1.0, 1e-10); + EXPECT_THROW(brentRoot(-3.0, 3.0, testPoly), std::invalid_argument); +} -- GitLab