diff --git a/include/usgscsm/Utilities.h b/include/usgscsm/Utilities.h index 8110e3e28797bbb9d0a878400c09140a0dfa646e..0f329449a3e467d47fe8350accb703ec50bc5c50 100644 --- a/include/usgscsm/Utilities.h +++ b/include/usgscsm/Utilities.h @@ -69,6 +69,14 @@ void lagrangeInterp ( const int& i_order, double* valueVector); +// Brent's algorithm for finding the roots of a function +// Arguments are two inputs that bracket a root, the function, and a convergence tolerance +double brentRoot( + double lowerBound, + double upperBound, + double (*func)(double), + double epsilon = 1e-10); + // Methods for checking/accessing the ISD double metric_conversion(double val, std::string from, std::string to="m"); diff --git a/src/Utilities.cpp b/src/Utilities.cpp index 39ce6ed2ac2c2b56f994e37b23b06201a9951c8f..cf2f4047b38c2cfc4ea76af72ac8fc6fa559e5f3 100644 --- a/src/Utilities.cpp +++ b/src/Utilities.cpp @@ -4,6 +4,7 @@ #include <Error.h> #include <stack> #include <utility> +#include <stdexcept> using json = nlohmann::json; @@ -276,6 +277,74 @@ void lagrangeInterp( } } +double brentRoot( + double lowerBound, + double upperBound, + double (*func)(double), + double epsilon) { + double counterPoint = lowerBound; + double currentPoint = upperBound; + double counterFunc = func(counterPoint); + double currentFunc = func(currentPoint); + if (counterFunc * currentFunc > 0.0) { + throw std::invalid_argument("Function values at the boundaries have the same sign."); + } + if (fabs(counterFunc) < fabs(currentFunc)) { + std::swap(counterPoint, currentPoint); + std::swap(counterFunc, currentFunc); + } + + double previousPoint = counterPoint; + double previousFunc = counterFunc; + double evenOlderPoint = previousPoint; + double nextPoint; + double nextFunc; + int iteration = 0; + bool bisected = true; + + do { + // Inverse quadratic interpolation + if (counterFunc != previousFunc && counterFunc != currentFunc) { + nextPoint = (counterPoint * currentFunc * previousFunc) / ((counterFunc - currentFunc) * (counterFunc - previousFunc)); + nextPoint += (currentPoint * counterFunc * previousFunc) / ((currentFunc - counterFunc) * (currentFunc - previousFunc)); + nextPoint += (previousPoint * currentFunc * counterFunc) / ((previousFunc - counterFunc) * (previousFunc - currentFunc)); + } + // Secant method + else { + nextPoint = currentPoint - currentFunc * (currentPoint - counterPoint) / (currentFunc - counterFunc); + } + + // Bisection method + if (((currentPoint - nextPoint) * (nextPoint - (3 * counterPoint + currentPoint) / 4) < 0) || + (bisected && fabs(nextPoint - currentPoint) >= fabs(currentPoint - previousPoint) / 2) || + (!bisected && fabs(nextPoint - currentPoint) >= fabs(previousPoint - evenOlderPoint) / 2) || + (bisected && fabs(currentPoint - previousPoint) < epsilon) || + (!bisected && fabs(previousPoint - evenOlderPoint) < epsilon)) { + nextPoint = (currentPoint + counterPoint) / 2; + bisected = true; + } + else { + bisected = false; + } + + // Setup for next iteration + evenOlderPoint = previousPoint; + previousPoint = currentPoint; + previousFunc = currentFunc; + nextFunc = func(nextPoint); + if (counterFunc * nextFunc < 0) { + currentPoint = nextPoint; + currentFunc = nextFunc; + } + else { + counterPoint = nextPoint; + counterFunc = nextFunc; + } + } while (++iteration < 30 && fabs(counterPoint - currentPoint) > epsilon); + + return nextPoint; + } + // convert a measurement double metric_conversion(double val, std::string from, std::string to) { json typemap = { diff --git a/tests/UtilitiesTests.cpp b/tests/UtilitiesTests.cpp index f837808e0dcf28c4091f5d212b2a63031b6629c0..f1918bdf24539c3ab3e0c21275212835ceec9cc1 100644 --- a/tests/UtilitiesTests.cpp +++ b/tests/UtilitiesTests.cpp @@ -8,6 +8,7 @@ #include <gtest/gtest.h> #include <math.h> +#include <stdexcept> using json = nlohmann::json; @@ -337,3 +338,13 @@ TEST(UtilitiesTests, lagrangeInterp2D) { EXPECT_DOUBLE_EQ(outputValue[0], 0.5); EXPECT_DOUBLE_EQ(outputValue[1], 1.5); } + +double testPoly(double x) { + return (x - 2) * (x + 1) * (x + 7); +}; + +TEST(UtilitiesTests, brentRoot) { + EXPECT_NEAR(brentRoot(1.0, 3.0, testPoly, 1e-10), 2.0, 1e-10); + EXPECT_NEAR(brentRoot(0.0, -3.0, testPoly, 1e-10), -1.0, 1e-10); + EXPECT_THROW(brentRoot(-3.0, 3.0, testPoly), std::invalid_argument); +}