LinearSystem.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. // David Eberly, Geometric Tools, Redmond WA 98052
  2. // Copyright (c) 1998-2020
  3. // Distributed under the Boost Software License, Version 1.0.
  4. // https://www.boost.org/LICENSE_1_0.txt
  5. // https://www.geometrictools.com/License/Boost/LICENSE_1_0.txt
  6. // Version: 4.0.2019.08.13
  7. #pragma once
  8. #include <Mathematics/Matrix2x2.h>
  9. #include <Mathematics/Matrix3x3.h>
  10. #include <Mathematics/Matrix4x4.h>
  11. #include <Mathematics/GaussianElimination.h>
  12. #include <map>
// Solve linear systems of equations where the matrix A is NxN. The return
// value of a function is 'true' when A is invertible. In this case the
// solution X is valid. If the return value is 'false', A is not invertible
// and X is invalid, so do not use it. When a matrix is passed as Real*,
// the storage order is assumed to be the one consistent with your choice
// of GTE_USE_ROW_MAJOR or GTE_USE_COL_MAJOR.
  19. //
  20. // The linear solvers that use the conjugate gradient algorithm are based
  21. // on the discussion in "Matrix Computations, 2nd edition" by G. H. Golub
  22. // and Charles F. Van Loan, The Johns Hopkins Press, Baltimore MD, Fourth
  23. // Printing 1993.
  24. namespace WwiseGTE
  25. {
  26. template <typename Real>
  27. class LinearSystem
  28. {
  29. public:
  30. // Solve 2x2, 3x3, and 4x4 systems by inverting the matrix directly.
  31. // This avoids the overhead of Gaussian elimination in small
  32. // dimensions.
  33. static bool Solve(Matrix2x2<Real> const& A, Vector2<Real> const& B, Vector2<Real>& X)
  34. {
  35. bool invertible;
  36. Matrix2x2<Real> invA = Inverse(A, &invertible);
  37. if (invertible)
  38. {
  39. X = invA * B;
  40. }
  41. else
  42. {
  43. X = Vector2<Real>::Zero();
  44. }
  45. return invertible;
  46. }
  47. static bool Solve(Matrix3x3<Real> const& A, Vector3<Real> const& B, Vector3<Real>& X)
  48. {
  49. bool invertible;
  50. Matrix3x3<Real> invA = Inverse(A, &invertible);
  51. if (invertible)
  52. {
  53. X = invA * B;
  54. }
  55. else
  56. {
  57. X = Vector3<Real>::Zero();
  58. }
  59. return invertible;
  60. }
  61. static bool Solve(Matrix4x4<Real> const& A, Vector4<Real> const& B, Vector4<Real>& X)
  62. {
  63. bool invertible;
  64. Matrix4x4<Real> invA = Inverse(A, &invertible);
  65. if (invertible)
  66. {
  67. X = invA * B;
  68. }
  69. else
  70. {
  71. X = Vector4<Real>::Zero();
  72. }
  73. return invertible;
  74. }
  75. // Solve A*X = B, where B is Nx1 and the solution X is Nx1.
  76. static bool Solve(int N, Real const* A, Real const* B, Real* X)
  77. {
  78. Real determinant;
  79. return GaussianElimination<Real>()(N, A, nullptr, determinant, B, X,
  80. nullptr, 0, nullptr);
  81. }
  82. // Solve A*X = B, where B is NxM and the solution X is NxM.
  83. static bool Solve(int N, int M, Real const* A, Real const* B, Real* X)
  84. {
  85. Real determinant;
  86. return GaussianElimination<Real>()(N, A, nullptr, determinant, nullptr,
  87. nullptr, B, M, X);
  88. }
  89. // Solve A*X = B, where A is tridiagonal. The function expects the
  90. // subdiagonal, diagonal, and superdiagonal of A. The diagonal input
  91. // must have N elements. The subdiagonal and superdiagonal inputs
  92. // must have N-1 elements.
  // Elimination without pivoting followed by back substitution (the Thomas
  // algorithm) for a tridiagonal system.
  //
  // N             number of unknowns
  // subdiagonal   N-1 entries below the diagonal
  // diagonal      N entries on the diagonal
  // superdiagonal N-1 entries above the diagonal
  // B             right-hand side, N entries
  // X             output, N entries
  //
  // Returns false when N <= 0 or when a zero pivot is met; X must not be
  // used in that case (it may have been partially written).
  static bool SolveTridiagonal(int N, Real const* subdiagonal,
      Real const* diagonal, Real const* superdiagonal, Real const* B, Real* X)
  {
      Real const zero = static_cast<Real>(0);
      Real const one = static_cast<Real>(1);

      // Without this guard an empty system would read diagonal[0] out of
      // bounds and the scratch size N - 1 would wrap to a huge value.
      if (N <= 0)
      {
          return false;
      }

      Real pivot = diagonal[0];
      if (pivot == zero)
      {
          return false;
      }

      // Forward elimination. scaledSuper[i] stores superdiagonal[i] divided
      // by the pivot of row i and is reused during back substitution.
      std::vector<Real> scaledSuper(N - 1);
      Real invPivot = one / pivot;
      X[0] = B[0] * invPivot;
      for (int row = 1; row < N; ++row)
      {
          int const prev = row - 1;
          scaledSuper[prev] = superdiagonal[prev] * invPivot;
          pivot = diagonal[row] - subdiagonal[prev] * scaledSuper[prev];
          if (pivot == zero)
          {
              // No row exchanges are performed, so elimination stops here.
              return false;
          }
          invPivot = one / pivot;
          X[row] = (B[row] - subdiagonal[prev] * X[prev]) * invPivot;
      }

      // Back substitution from the second-to-last row upward.
      for (int row = N - 2; row >= 0; --row)
      {
          X[row] -= scaledSuper[row] * X[row + 1];
      }
      return true;
  }
  122. // Solve A*X = B, where A is tridiagonal. The function expects the
  123. // subdiagonal, diagonal, and superdiagonal of A. Moreover, the
  124. // subdiagonal elements are a constant, the diagonal elements are a
  125. // constant, and the superdiagonal elements are a constant.
  // Elimination without pivoting followed by back substitution (the Thomas
  // algorithm) for a tridiagonal system whose three bands are each a
  // single repeated value.
  //
  // N             number of unknowns
  // subdiagonal   value of every entry below the diagonal
  // diagonal      value of every entry on the diagonal
  // superdiagonal value of every entry above the diagonal
  // B             right-hand side, N entries
  // X             output, N entries
  //
  // Returns false when N <= 0 or when a zero pivot is met; X must not be
  // used in that case (it may have been partially written).
  static bool SolveConstantTridiagonal(int N, Real subdiagonal,
      Real diagonal, Real superdiagonal, Real const* B, Real* X)
  {
      Real const zero = static_cast<Real>(0);
      Real const one = static_cast<Real>(1);

      // Without this guard an empty system would read B[0] out of bounds
      // and the scratch size N - 1 would wrap to a huge value.
      if (N <= 0)
      {
          return false;
      }

      Real pivot = diagonal;
      if (pivot == zero)
      {
          return false;
      }

      // Forward elimination. scaledSuper[i] stores the superdiagonal value
      // divided by the pivot of row i and is reused during back
      // substitution.
      std::vector<Real> scaledSuper(N - 1);
      Real invPivot = one / pivot;
      X[0] = B[0] * invPivot;
      for (int row = 1; row < N; ++row)
      {
          int const prev = row - 1;
          scaledSuper[prev] = superdiagonal * invPivot;
          pivot = diagonal - subdiagonal * scaledSuper[prev];
          if (pivot == zero)
          {
              // No row exchanges are performed, so elimination stops here.
              return false;
          }
          invPivot = one / pivot;
          X[row] = (B[row] - subdiagonal * X[prev]) * invPivot;
      }

      // Back substitution from the second-to-last row upward.
      for (int row = N - 2; row >= 0; --row)
      {
          X[row] -= scaledSuper[row] * X[row + 1];
      }
      return true;
  }
  155. // Solve A*X = B using the conjugate gradient method, where A is
  156. // symmetric. You must specify the maximum number of iterations and a
  157. // tolerance for terminating the iterations. Reasonable choices for
  158. // tolerance are 1e-06f for 'float' or 1e-08 for 'double'.
  159. static unsigned int SolveSymmetricCG(int N, Real const* A, Real const* B,
  160. Real* X, unsigned int maxIterations, Real tolerance)
  161. {
  162. // The first iteration.
  163. std::vector<Real> tmpR(N), tmpP(N), tmpW(N);
  164. Real* R = &tmpR[0];
  165. Real* P = &tmpP[0];
  166. Real* W = &tmpW[0];
  167. size_t numBytes = N * sizeof(Real);
  168. std::memset(X, 0, numBytes);
  169. std::memcpy(R, B, numBytes);
  170. Real rho0 = Dot(N, R, R);
  171. std::memcpy(P, R, numBytes);
  172. Mul(N, A, P, W);
  173. Real alpha = rho0 / Dot(N, P, W);
  174. UpdateX(N, X, alpha, P);
  175. UpdateR(N, R, alpha, W);
  176. Real rho1 = Dot(N, R, R);
  177. // The remaining iterations.
  178. unsigned int iteration;
  179. for (iteration = 1; iteration <= maxIterations; ++iteration)
  180. {
  181. Real root0 = std::sqrt(rho1);
  182. Real norm = Dot(N, B, B);
  183. Real root1 = std::sqrt(norm);
  184. if (root0 <= tolerance * root1)
  185. {
  186. break;
  187. }
  188. Real beta = rho1 / rho0;
  189. UpdateP(N, P, beta, R);
  190. Mul(N, A, P, W);
  191. alpha = rho1 / Dot(N, P, W);
  192. UpdateX(N, X, alpha, P);
  193. UpdateR(N, R, alpha, W);
  194. rho0 = rho1;
  195. rho1 = Dot(N, R, R);
  196. }
  197. return iteration;
  198. }
  // Solve A*X = B using the conjugate gradient method, where A is
  // sparse and symmetric. The nonzero entries of the symmetric matrix
  // A are stored in a map whose keys are pairs (i,j) and whose values
  // are real numbers. The pair (i,j) is the location of the value in
  // the array. Only one of (i,j) and (j,i) should be stored since A is
  // symmetric. The column vector B is stored as an array of contiguous
  // values. You must specify the maximum number of iterations and a
  // tolerance for terminating the iterations. Reasonable choices for
  // tolerance are 1e-06f for 'float' or 1e-08 for 'double'.
  // Sparse matrix storage: maps an index pair {i, j} to the value stored
  // at that location.
  using SparseMatrix = std::map<std::array<int, 2>, Real>;
  209. static unsigned int SolveSymmetricCG(int N, SparseMatrix const& A,
  210. Real const* B, Real* X, unsigned int maxIterations, Real tolerance)
  211. {
  212. // The first iteration.
  213. std::vector<Real> tmpR(N), tmpP(N), tmpW(N);
  214. Real* R = &tmpR[0];
  215. Real* P = &tmpP[0];
  216. Real* W = &tmpW[0];
  217. size_t numBytes = N * sizeof(Real);
  218. std::memset(X, 0, numBytes);
  219. std::memcpy(R, B, numBytes);
  220. Real rho0 = Dot(N, R, R);
  221. std::memcpy(P, R, numBytes);
  222. Mul(N, A, P, W);
  223. Real alpha = rho0 / Dot(N, P, W);
  224. UpdateX(N, X, alpha, P);
  225. UpdateR(N, R, alpha, W);
  226. Real rho1 = Dot(N, R, R);
  227. // The remaining iterations.
  228. unsigned int iteration;
  229. for (iteration = 1; iteration <= maxIterations; ++iteration)
  230. {
  231. Real root0 = std::sqrt(rho1);
  232. Real norm = Dot(N, B, B);
  233. Real root1 = std::sqrt(norm);
  234. if (root0 <= tolerance * root1)
  235. {
  236. break;
  237. }
  238. Real beta = rho1 / rho0;
  239. UpdateP(N, P, beta, R);
  240. Mul(N, A, P, W);
  241. alpha = rho1 / Dot(N, P, W);
  242. UpdateX(N, X, alpha, P);
  243. UpdateR(N, R, alpha, W);
  244. rho0 = rho1;
  245. rho1 = Dot(N, R, R);
  246. }
  247. return iteration;
  248. }
  249. private:
  250. // Support for the conjugate gradient method.
  // Returns the sum of U[i] * V[i] over the first N entries, accumulated
  // in index order starting from zero. Yields zero when N <= 0.
  static Real Dot(int N, Real const* U, Real const* V)
  {
      Real sum = static_cast<Real>(0);
      Real const* u = U;
      Real const* v = V;
      for (int remaining = N; remaining > 0; --remaining, ++u, ++v)
      {
          sum += (*u) * (*v);
      }
      return sum;
  }
  260. static void Mul(int N, Real const* A, Real const* X, Real* P)
  261. {
  262. #if defined(GTE_USE_ROW_MAJOR)
  263. LexicoArray2<true, Real> matA(N, N, const_cast<Real*>(A));
  264. #else
  265. LexicoArray2<false, Real> matA(N, N, const_cast<Real*>(A));
  266. #endif
  267. std::memset(P, 0, N * sizeof(Real));
  268. for (int row = 0; row < N; ++row)
  269. {
  270. for (int col = 0; col < N; ++col)
  271. {
  272. P[row] += matA(row, col) * X[col];
  273. }
  274. }
  275. }
  276. static void Mul(int N, SparseMatrix const& A, Real const* X, Real* P)
  277. {
  278. std::memset(P, 0, N * sizeof(Real));
  279. for (auto const& element : A)
  280. {
  281. int i = element.first[0];
  282. int j = element.first[1];
  283. Real value = element.second;
  284. P[i] += value * X[j];
  285. if (i != j)
  286. {
  287. P[j] += value * X[i];
  288. }
  289. }
  290. }
  // In-place update X[i] += alpha * P[i] for the first N entries.
  static void UpdateX(int N, Real* X, Real alpha, Real const* P)
  {
      int index = 0;
      while (index < N)
      {
          Real const step = alpha * P[index];
          X[index] += step;
          ++index;
      }
  }
  // In-place update R[i] -= alpha * W[i] for the first N entries.
  static void UpdateR(int N, Real* R, Real alpha, Real const* W)
  {
      int index = 0;
      while (index < N)
      {
          Real const step = alpha * W[index];
          R[index] -= step;
          ++index;
      }
  }
  // In-place update P[i] = R[i] + beta * P[i] for the first N entries.
  static void UpdateP(int N, Real* P, Real beta, Real const* R)
  {
      int index = 0;
      while (index < N)
      {
          Real const scaled = beta * P[index];
          P[index] = R[index] + scaled;
          ++index;
      }
  }
  312. };
  313. }