LevenbergMarquardtMinimizer.h 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. // David Eberly, Geometric Tools, Redmond WA 98052
  2. // Copyright (c) 1998-2020
  3. // Distributed under the Boost Software License, Version 1.0.
  4. // https://www.boost.org/LICENSE_1_0.txt
  5. // https://www.geometrictools.com/License/Boost/LICENSE_1_0.txt
  6. // Version: 4.0.2019.08.13
  7. #pragma once
  8. #include <Mathematics/CholeskyDecomposition.h>
  9. #include <functional>
  10. // See GteGaussNewtonMinimizer.h for a formulation of the minimization
  11. // problem and how Levenberg-Marquardt relates to Gauss-Newton.
  12. namespace WwiseGTE
  13. {
  14. template <typename Real>
  15. class LevenbergMarquardtMinimizer
  16. {
  17. public:
  18. // Convenient types for the domain vectors, the range vectors, the
  19. // function F and the Jacobian J.
  20. typedef GVector<Real> DVector; // numPDimensions
  21. typedef GVector<Real> RVector; // numFDImensions
  22. typedef GMatrix<Real> JMatrix; // numFDimensions-by-numPDimensions
  23. typedef GMatrix<Real> JTJMatrix; // numPDimensions-by-numPDimensions
  24. typedef GVector<Real> JTFVector; // numPDimensions
  25. typedef std::function<void(DVector const&, RVector&)> FFunction;
  26. typedef std::function<void(DVector const&, JMatrix&)> JFunction;
  27. typedef std::function<void(DVector const&, JTJMatrix&, JTFVector&)> JPlusFunction;
  28. // Create the minimizer that computes F(p) and J(p) directly.
  29. LevenbergMarquardtMinimizer(int numPDimensions, int numFDimensions,
  30. FFunction const& inFFunction, JFunction const& inJFunction)
  31. :
  32. mNumPDimensions(numPDimensions),
  33. mNumFDimensions(numFDimensions),
  34. mFFunction(inFFunction),
  35. mJFunction(inJFunction),
  36. mF(mNumFDimensions),
  37. mJ(mNumFDimensions, mNumPDimensions),
  38. mJTJ(mNumPDimensions, mNumPDimensions),
  39. mNegJTF(mNumPDimensions),
  40. mDecomposer(mNumPDimensions),
  41. mUseJFunction(true)
  42. {
  43. LogAssert(mNumPDimensions > 0 && mNumFDimensions > 0, "Invalid dimensions.");
  44. }
  45. // Create the minimizer that computes J^T(p)*J(p) and -J(p)*F(p).
  46. LevenbergMarquardtMinimizer(int numPDimensions, int numFDimensions,
  47. FFunction const& inFFunction, JPlusFunction const& inJPlusFunction)
  48. :
  49. mNumPDimensions(numPDimensions),
  50. mNumFDimensions(numFDimensions),
  51. mFFunction(inFFunction),
  52. mJPlusFunction(inJPlusFunction),
  53. mF(mNumFDimensions),
  54. mJ(mNumFDimensions, mNumPDimensions),
  55. mJTJ(mNumPDimensions, mNumPDimensions),
  56. mNegJTF(mNumPDimensions),
  57. mDecomposer(mNumPDimensions),
  58. mUseJFunction(false)
  59. {
  60. LogAssert(mNumPDimensions > 0 && mNumFDimensions > 0, "Invalid dimensions.");
  61. }
  62. // Disallow copy, assignment and move semantics.
  63. LevenbergMarquardtMinimizer(LevenbergMarquardtMinimizer const&) = delete;
  64. LevenbergMarquardtMinimizer& operator=(LevenbergMarquardtMinimizer const&) = delete;
  65. LevenbergMarquardtMinimizer(LevenbergMarquardtMinimizer&&) = delete;
  66. LevenbergMarquardtMinimizer& operator=(LevenbergMarquardtMinimizer&&) = delete;
  67. inline int GetNumPDimensions() const { return mNumPDimensions; }
  68. inline int GetNumFDimensions() const { return mNumFDimensions; }
  69. // The lambda is positive, the multiplier is positive, and the initial
  70. // guess for the p-parameter is p0. Typical choices are lambda =
  71. // 0.001 and multiplier = 10. TODO: Explain lambda in more detail,
  72. // Multiview Geometry mentions lambda = 0.001*average(diagonal(JTJ)),
  73. // but let's just expose the factor in front of the average.
  74. struct Result
  75. {
  76. DVector minLocation;
  77. Real minError;
  78. Real minErrorDifference;
  79. Real minUpdateLength;
  80. size_t numIterations;
  81. size_t numAdjustments;
  82. bool converged;
  83. };
  84. Result operator()(DVector const& p0, size_t maxIterations,
  85. Real updateLengthTolerance, Real errorDifferenceTolerance,
  86. Real lambdaFactor, Real lambdaAdjust, size_t maxAdjustments)
  87. {
  88. Result result;
  89. result.minLocation = p0;
  90. result.minError = std::numeric_limits<Real>::max();
  91. result.minErrorDifference = std::numeric_limits<Real>::max();
  92. result.minUpdateLength = (Real)0;
  93. result.numIterations = 0;
  94. result.numAdjustments = 0;
  95. result.converged = false;
  96. // As a simple precaution, ensure that the lambda inputs are
  97. // valid. If invalid, fall back to Gauss-Newton iteration.
  98. if (lambdaFactor <= (Real)0 || lambdaAdjust <= (Real)0)
  99. {
  100. maxAdjustments = 1;
  101. lambdaFactor = (Real)0;
  102. lambdaAdjust = (Real)1;
  103. }
  104. // As a simple precaution, ensure the tolerances are nonnegative.
  105. updateLengthTolerance = std::max(updateLengthTolerance, (Real)0);
  106. errorDifferenceTolerance = std::max(errorDifferenceTolerance, (Real)0);
  107. // Compute the initial error.
  108. mFFunction(p0, mF);
  109. result.minError = Dot(mF, mF);
  110. // Do the Levenberg-Marquart iterations.
  111. auto pCurrent = p0;
  112. for (result.numIterations = 1; result.numIterations <= maxIterations; ++result.numIterations)
  113. {
  114. std::pair<bool, bool> status;
  115. DVector pNext;
  116. for (result.numAdjustments = 0; result.numAdjustments < maxAdjustments; ++result.numAdjustments)
  117. {
  118. status = DoIteration(pCurrent, lambdaFactor, updateLengthTolerance,
  119. errorDifferenceTolerance, pNext, result);
  120. if (status.first)
  121. {
  122. // Either the Cholesky decomposition failed or the
  123. // iterates converged within tolerance. TODO: See the
  124. // note in DoIteration about not failing on Cholesky
  125. // decomposition.
  126. return result;
  127. }
  128. if (status.second)
  129. {
  130. // The error has been reduced but we have not yet
  131. // converged within tolerance.
  132. break;
  133. }
  134. lambdaFactor *= lambdaAdjust;
  135. }
  136. if (result.numAdjustments < maxAdjustments)
  137. {
  138. // The current value of lambda led us to an update that
  139. // reduced the error, but the error is not yet small
  140. // enough to conclude we converged. Reduce lambda for the
  141. // next outer-loop iteration.
  142. lambdaFactor /= lambdaAdjust;
  143. }
  144. else
  145. {
  146. // All lambdas tried during the inner-loop iteration did
  147. // not lead to a reduced error. If we do nothing here,
  148. // the next inner-loop iteration will continue to multiply
  149. // lambda, risking eventual floating-point overflow. To
  150. // avoid this, fall back to a Gauss-Newton iterate.
  151. status = DoIteration(pCurrent, lambdaFactor, updateLengthTolerance,
  152. errorDifferenceTolerance, pNext, result);
  153. if (status.first)
  154. {
  155. // Either the Cholesky decomposition failed or the
  156. // iterates converged within tolerance. TODO: See the
  157. // note in DoIteration about not failing on Cholesky
  158. // decomposition.
  159. return result;
  160. }
  161. }
  162. pCurrent = pNext;
  163. }
  164. return result;
  165. }
  166. private:
  167. void ComputeLinearSystemInputs(DVector const& pCurrent, Real lambda)
  168. {
  169. if (mUseJFunction)
  170. {
  171. mJFunction(pCurrent, mJ);
  172. mJTJ = MultiplyATB(mJ, mJ);
  173. mNegJTF = -(mF * mJ);
  174. }
  175. else
  176. {
  177. mJPlusFunction(pCurrent, mJTJ, mNegJTF);
  178. }
  179. Real diagonalSum(0);
  180. for (int i = 0; i < mNumPDimensions; ++i)
  181. {
  182. diagonalSum += mJTJ(i, i);
  183. }
  184. Real diagonalAdjust = lambda * diagonalSum / static_cast<Real>(mNumPDimensions);
  185. for (int i = 0; i < mNumPDimensions; ++i)
  186. {
  187. mJTJ(i, i) += diagonalAdjust;
  188. }
  189. }
  190. // The returned 'first' is true when the linear system cannot be
  191. // solved (result.converged is false in this case) or when the
  192. // error is reduced to within the tolerances specified by the caller
  193. // (result.converged is true in this case). When the 'first' value
  194. // is true, the 'second' value is true when the error is reduced or
  195. // false when it is not.
  196. std::pair<bool, bool> DoIteration(DVector const& pCurrent, Real lambdaFactor,
  197. Real updateLengthTolerance, Real errorDifferenceTolerance, DVector& pNext,
  198. Result& result)
  199. {
  200. ComputeLinearSystemInputs(pCurrent, lambdaFactor);
  201. if (!mDecomposer.Factor(mJTJ))
  202. {
  203. // TODO: The matrix mJTJ is positive semi-definite, so the
  204. // failure can occur when mJTJ has a zero eigenvalue in
  205. // which case mJTJ is not invertible. Generate an iterate
  206. // anyway, perhaps using gradient descent?
  207. return std::make_pair(true, false);
  208. }
  209. mDecomposer.SolveLower(mJTJ, mNegJTF);
  210. mDecomposer.SolveUpper(mJTJ, mNegJTF);
  211. pNext = pCurrent + mNegJTF;
  212. mFFunction(pNext, mF);
  213. Real error = Dot(mF, mF);
  214. if (error < result.minError)
  215. {
  216. result.minErrorDifference = result.minError - error;
  217. result.minUpdateLength = Length(mNegJTF);
  218. result.minLocation = pNext;
  219. result.minError = error;
  220. if (result.minErrorDifference <= errorDifferenceTolerance
  221. || result.minUpdateLength <= updateLengthTolerance)
  222. {
  223. result.converged = true;
  224. return std::make_pair(true, true);
  225. }
  226. else
  227. {
  228. return std::make_pair(false, true);
  229. }
  230. }
  231. else
  232. {
  233. return std::make_pair(false, false);
  234. }
  235. }
  236. int mNumPDimensions, mNumFDimensions;
  237. FFunction mFFunction;
  238. JFunction mJFunction;
  239. JPlusFunction mJPlusFunction;
  240. // Storage for J^T(p)*J(p) and -J^T(p)*F(p) during the iterations.
  241. RVector mF;
  242. JMatrix mJ;
  243. JTJMatrix mJTJ;
  244. JTFVector mNegJTF;
  245. CholeskyDecomposition<Real> mDecomposer;
  246. bool mUseJFunction;
  247. };
  248. }