QuarticRootsQR.h 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. // David Eberly, Geometric Tools, Redmond WA 98052
  2. // Copyright (c) 1998-2020
  3. // Distributed under the Boost Software License, Version 1.0.
  4. // https://www.boost.org/LICENSE_1_0.txt
  5. // https://www.geometrictools.com/License/Boost/LICENSE_1_0.txt
  6. // Version: 4.0.2019.08.13
  7. #pragma once
  8. #include <Mathematics/CubicRootsQR.h>
  9. // An implementation of the QR algorithm described in "Matrix Computations,
  10. // 2nd edition" by G. H. Golub and C. F. Van Loan, The Johns Hopkins
  11. // University Press, Baltimore MD, Fourth Printing 1993. In particular,
  12. // the implementation is based on Chapter 7 (The Unsymmetric Eigenvalue
  13. // Problem), Section 7.5 (The Practical QR Algorithm). The algorithm is
  14. // specialized for the companion matrix associated with a quartic polynomial.
  15. namespace WwiseGTE
  16. {
  17. template <typename Real>
  18. class QuarticRootsQR
  19. {
  20. public:
  21. typedef std::array<std::array<Real, 4>, 4> Matrix;
  22. // Solve p(x) = c0 + c1 * x + c2 * x^2 + c3 * x^3 + x^4 = 0.
  23. uint32_t operator() (uint32_t maxIterations, Real c0, Real c1, Real c2, Real c3,
  24. uint32_t& numRoots, std::array<Real, 4>& roots) const
  25. {
  26. // Create the companion matrix for the polynomial. The matrix is
  27. // in upper Hessenberg form.
  28. Matrix A;
  29. A[0][0] = (Real)0;
  30. A[0][1] = (Real)0;
  31. A[0][2] = (Real)0;
  32. A[0][3] = -c0;
  33. A[1][0] = (Real)1;
  34. A[1][1] = (Real)0;
  35. A[1][2] = (Real)0;
  36. A[1][3] = -c1;
  37. A[2][0] = (Real)0;
  38. A[2][1] = (Real)1;
  39. A[2][2] = (Real)0;
  40. A[2][3] = -c2;
  41. A[3][0] = (Real)0;
  42. A[3][1] = (Real)0;
  43. A[3][2] = (Real)1;
  44. A[3][3] = -c3;
  45. // Avoid the QR-cycle when c1 = c2 = 0 and avoid the slow
  46. // convergence when c1 and c2 are nearly zero.
  47. std::array<Real, 3> V{
  48. (Real)1,
  49. (Real)0.36602540378443865,
  50. (Real)0.36602540378443865 };
  51. DoIteration(V, A);
  52. return operator()(maxIterations, A, numRoots, roots);
  53. }
  54. // Compute the real eigenvalues of the upper Hessenberg matrix A. The
  55. // matrix is modified by in-place operations, so if you need to
  56. // remember A, you must make your own copy before calling this
  57. // function.
  58. uint32_t operator() (uint32_t maxIterations, Matrix& A,
  59. uint32_t& numRoots, std::array<Real, 4>& roots) const
  60. {
  61. numRoots = 0;
  62. std::fill(roots.begin(), roots.end(), (Real)0);
  63. for (uint32_t numIterations = 0; numIterations < maxIterations; ++numIterations)
  64. {
  65. // Apply a Francis QR iteration.
  66. Real tr = A[2][2] + A[3][3];
  67. Real det = A[2][2] * A[3][3] - A[2][3] * A[3][2];
  68. std::array<Real, 3> X{
  69. A[0][0] * A[0][0] + A[0][1] * A[1][0] - tr * A[0][0] + det,
  70. A[1][0] * (A[0][0] + A[1][1] - tr),
  71. A[1][0] * A[2][1] };
  72. std::array<Real, 3> V = House<3>(X);
  73. DoIteration(V, A);
  74. // Test for uncoupling of A.
  75. Real tr12 = A[1][1] + A[2][2];
  76. if (tr12 + A[2][1] == tr12)
  77. {
  78. GetQuadraticRoots(0, 1, A, numRoots, roots);
  79. GetQuadraticRoots(2, 3, A, numRoots, roots);
  80. return numIterations;
  81. }
  82. Real tr01 = A[0][0] + A[1][1];
  83. if (tr01 + A[1][0] == tr01)
  84. {
  85. numRoots = 1;
  86. roots[0] = A[0][0];
  87. // TODO: The cubic solver is not designed to process 3x3
  88. // submatrices of an NxN matrix, so the copy of a
  89. // submatrix of A to B is a simple workaround for running
  90. // the solver. Write general root-finding/ code that
  91. // avoids such copying.
  92. uint32_t subMaxIterations = maxIterations - numIterations;
  93. typename CubicRootsQR<Real>::Matrix B;
  94. for (int r = 0; r < 3; ++r)
  95. {
  96. for (int c = 0; c < 3; ++c)
  97. {
  98. B[r][c] = A[r + 1][c + 1];
  99. }
  100. }
  101. uint32_t numSubroots = 0;
  102. std::array<Real, 3> subroots;
  103. uint32_t numSubiterations = CubicRootsQR<Real>()(subMaxIterations, B,
  104. numSubroots, subroots);
  105. for (uint32_t i = 0; i < numSubroots; ++i)
  106. {
  107. roots[numRoots++] = subroots[i];
  108. }
  109. return numIterations + numSubiterations;
  110. }
  111. Real tr23 = A[2][2] + A[3][3];
  112. if (tr23 + A[3][2] == tr23)
  113. {
  114. numRoots = 1;
  115. roots[0] = A[3][3];
  116. // TODO: The cubic solver is not designed to process 3x3
  117. // submatrices of an NxN matrix, so the copy of a
  118. // submatrix of A to B is a simple workaround for running
  119. // the solver. Write general root-finding/ code that
  120. // avoids such copying.
  121. uint32_t subMaxIterations = maxIterations - numIterations;
  122. typename CubicRootsQR<Real>::Matrix B;
  123. for (int r = 0; r < 3; ++r)
  124. {
  125. for (int c = 0; c < 3; ++c)
  126. {
  127. B[r][c] = A[r][c];
  128. }
  129. }
  130. uint32_t numSubroots = 0;
  131. std::array<Real, 3> subroots;
  132. uint32_t numSubiterations = CubicRootsQR<Real>()(subMaxIterations, B,
  133. numSubroots, subroots);
  134. for (uint32_t i = 0; i < numSubroots; ++i)
  135. {
  136. roots[numRoots++] = subroots[i];
  137. }
  138. return numIterations + numSubiterations;
  139. }
  140. }
  141. return maxIterations;
  142. }
  143. private:
  144. void DoIteration(std::array<Real, 3> const& V, Matrix& A) const
  145. {
  146. Real multV = (Real)-2 / (V[0] * V[0] + V[1] * V[1] + V[2] * V[2]);
  147. std::array<Real, 3> MV{ multV * V[0], multV * V[1], multV * V[2] };
  148. RowHouse<3>(0, 2, 0, 3, V, MV, A);
  149. ColHouse<3>(0, 3, 0, 2, V, MV, A);
  150. std::array<Real, 3> X{ A[1][0], A[2][0], A[3][0] };
  151. std::array<Real, 3> locV = House<3>(X);
  152. multV = (Real)-2 / (locV[0] * locV[0] + locV[1] * locV[1] + locV[2] * locV[2]);
  153. MV = { multV * locV[0], multV * locV[1], multV * locV[2] };
  154. RowHouse<3>(1, 3, 0, 3, locV, MV, A);
  155. ColHouse<3>(0, 3, 1, 3, locV, MV, A);
  156. std::array<Real, 2> Y{ A[2][1], A[3][1] };
  157. std::array<Real, 2> W = House<2>(Y);
  158. Real multW = (Real)-2 / (W[0] * W[0] + W[1] * W[1]);
  159. std::array<Real, 2> MW = { multW * W[0], multW * W[1] };
  160. RowHouse<2>(2, 3, 0, 3, W, MW, A);
  161. ColHouse<2>(0, 3, 2, 3, W, MW, A);
  162. }
  163. template <int N>
  164. std::array<Real, N> House(std::array<Real, N> const& X) const
  165. {
  166. std::array<Real, N> V;
  167. Real length = (Real)0;
  168. for (int i = 0; i < N; ++i)
  169. {
  170. length += X[i] * X[i];
  171. }
  172. length = std::sqrt(length);
  173. if (length != (Real)0)
  174. {
  175. Real sign = (X[0] >= (Real)0 ? (Real)1 : (Real)-1);
  176. Real denom = X[0] + sign * length;
  177. for (int i = 1; i < N; ++i)
  178. {
  179. V[i] = X[i] / denom;
  180. }
  181. }
  182. else
  183. {
  184. V.fill((Real)0);
  185. }
  186. V[0] = (Real)1;
  187. return V;
  188. }
  189. template <int N>
  190. void RowHouse(int rmin, int rmax, int cmin, int cmax,
  191. std::array<Real, N> const& V, std::array<Real, N> const& MV, Matrix& A) const
  192. {
  193. // Only elements cmin through cmax are used.
  194. std::array<Real, 4> W;
  195. for (int c = cmin; c <= cmax; ++c)
  196. {
  197. W[c] = (Real)0;
  198. for (int r = rmin, k = 0; r <= rmax; ++r, ++k)
  199. {
  200. W[c] += V[k] * A[r][c];
  201. }
  202. }
  203. for (int r = rmin, k = 0; r <= rmax; ++r, ++k)
  204. {
  205. for (int c = cmin; c <= cmax; ++c)
  206. {
  207. A[r][c] += MV[k] * W[c];
  208. }
  209. }
  210. }
  211. template <int N>
  212. void ColHouse(int rmin, int rmax, int cmin, int cmax,
  213. std::array<Real, N> const& V, std::array<Real, N> const& MV, Matrix& A) const
  214. {
  215. // Only elements rmin through rmax are used.
  216. std::array<Real, 4> W;
  217. for (int r = rmin; r <= rmax; ++r)
  218. {
  219. W[r] = (Real)0;
  220. for (int c = cmin, k = 0; c <= cmax; ++c, ++k)
  221. {
  222. W[r] += V[k] * A[r][c];
  223. }
  224. }
  225. for (int r = rmin; r <= rmax; ++r)
  226. {
  227. for (int c = cmin, k = 0; c <= cmax; ++c, ++k)
  228. {
  229. A[r][c] += W[r] * MV[k];
  230. }
  231. }
  232. }
  233. void GetQuadraticRoots(int i0, int i1, Matrix const& A,
  234. uint32_t& numRoots, std::array<Real, 4>& roots) const
  235. {
  236. // Solve x^2 - t * x + d = 0, where t is the trace and d is the
  237. // determinant of the 2x2 matrix defined by indices i0 and i1.
  238. // The discriminant is D = (t/2)^2 - d. When D >= 0, the roots
  239. // are real values t/2 - sqrt(D) and t/2 + sqrt(D). To avoid
  240. // potential numerical issues with subtractive cancellation, the
  241. // roots are computed as
  242. // r0 = t/2 + sign(t/2)*sqrt(D), r1 = trace - r0
  243. Real trace = A[i0][i0] + A[i1][i1];
  244. Real halfTrace = trace * (Real)0.5;
  245. Real determinant = A[i0][i0] * A[i1][i1] - A[i0][i1] * A[i1][i0];
  246. Real discriminant = halfTrace * halfTrace - determinant;
  247. if (discriminant >= (Real)0)
  248. {
  249. Real sign = (trace >= (Real)0 ? (Real)1 : (Real)-1);
  250. Real root = halfTrace + sign * std::sqrt(discriminant);
  251. roots[numRoots++] = root;
  252. roots[numRoots++] = trace - root;
  253. }
  254. }
  255. };
  256. }