GaussianElimination.h 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. // David Eberly, Geometric Tools, Redmond WA 98052
  2. // Copyright (c) 1998-2020
  3. // Distributed under the Boost Software License, Version 1.0.
  4. // https://www.boost.org/LICENSE_1_0.txt
  5. // https://www.geometrictools.com/License/Boost/LICENSE_1_0.txt
  6. // Version: 4.0.2019.11.23
  7. #pragma once
  8. #include <Mathematics/Logger.h>
  9. #include <Mathematics/LexicoArray2.h>
  10. #include <cstring>
  11. #include <vector>
  // The input matrix M must be NxN. The storage convention for element lookup
  // is determined by GTE_USE_ROW_MAJOR or GTE_USE_COL_MAJOR, whichever is
  // active. If you want the inverse of M, pass a nonnull pointer inverseM;
  // this matrix must also be NxN and use the same storage convention as M. If
  // you do not want the inverse of M, pass a nullptr for inverseM. If you want
  // to solve M*X = B for X, where X and B are Nx1, pass nonnull pointers for B
  // and X. If you want to solve M*Y = C for Y, where Y and C are NxK, pass
  // nonnull pointers for C and Y and pass K to numCols. In all cases, pass
  // N to numRows.
  21. namespace WwiseGTE
  22. {
  23. template <typename Real>
  24. class GaussianElimination
  25. {
  26. public:
  27. bool operator()(int numRows,
  28. Real const* M, Real* inverseM, Real& determinant,
  29. Real const* B, Real* X,
  30. Real const* C, int numCols, Real* Y) const
  31. {
  32. if (numRows <= 0 || !M
  33. || ((B != nullptr) != (X != nullptr))
  34. || ((C != nullptr) != (Y != nullptr))
  35. || (C != nullptr && numCols < 1))
  36. {
  37. LogError("Invalid input.");
  38. }
  39. int numElements = numRows * numRows;
  40. bool wantInverse = (inverseM != nullptr);
  41. std::vector<Real> localInverseM;
  42. if (!wantInverse)
  43. {
  44. localInverseM.resize(numElements);
  45. inverseM = localInverseM.data();
  46. }
  47. Set(numElements, M, inverseM);
  48. if (B)
  49. {
  50. Set(numRows, B, X);
  51. }
  52. if (C)
  53. {
  54. Set(numRows * numCols, C, Y);
  55. }
  56. #if defined(GTE_USE_ROW_MAJOR)
  57. LexicoArray2<true, Real> matInvM(numRows, numRows, inverseM);
  58. LexicoArray2<true, Real> matY(numRows, numCols, Y);
  59. #else
  60. LexicoArray2<false, Real> matInvM(numRows, numRows, inverseM);
  61. LexicoArray2<false, Real> matY(numRows, numCols, Y);
  62. #endif
  63. std::vector<int> colIndex(numRows), rowIndex(numRows), pivoted(numRows);
  64. std::fill(pivoted.begin(), pivoted.end(), 0);
  65. Real const zero = (Real)0;
  66. Real const one = (Real)1;
  67. bool odd = false;
  68. determinant = one;
  69. // Elimination by full pivoting.
  70. int i1, i2, row = 0, col = 0;
  71. for (int i0 = 0; i0 < numRows; ++i0)
  72. {
  73. // Search matrix (excluding pivoted rows) for maximum absolute entry.
  74. Real maxValue = zero;
  75. for (i1 = 0; i1 < numRows; ++i1)
  76. {
  77. if (!pivoted[i1])
  78. {
  79. for (i2 = 0; i2 < numRows; ++i2)
  80. {
  81. if (!pivoted[i2])
  82. {
  83. Real value = matInvM(i1, i2);
  84. Real absValue = (value >= zero ? value : -value);
  85. if (absValue > maxValue)
  86. {
  87. maxValue = absValue;
  88. row = i1;
  89. col = i2;
  90. }
  91. }
  92. }
  93. }
  94. }
  95. if (maxValue == zero)
  96. {
  97. // The matrix is not invertible.
  98. if (wantInverse)
  99. {
  100. Set(numElements, nullptr, inverseM);
  101. }
  102. determinant = zero;
  103. if (B)
  104. {
  105. Set(numRows, nullptr, X);
  106. }
  107. if (C)
  108. {
  109. Set(numRows * numCols, nullptr, Y);
  110. }
  111. return false;
  112. }
  113. pivoted[col] = true;
  114. // Swap rows so that the pivot entry is in row 'col'.
  115. if (row != col)
  116. {
  117. odd = !odd;
  118. for (int i = 0; i < numRows; ++i)
  119. {
  120. std::swap(matInvM(row, i), matInvM(col, i));
  121. }
  122. if (B)
  123. {
  124. std::swap(X[row], X[col]);
  125. }
  126. if (C)
  127. {
  128. for (int i = 0; i < numCols; ++i)
  129. {
  130. std::swap(matY(row, i), matY(col, i));
  131. }
  132. }
  133. }
  134. // Keep track of the permutations of the rows.
  135. rowIndex[i0] = row;
  136. colIndex[i0] = col;
  137. // Scale the row so that the pivot entry is 1.
  138. Real diagonal = matInvM(col, col);
  139. determinant *= diagonal;
  140. Real inv = one / diagonal;
  141. matInvM(col, col) = one;
  142. for (i2 = 0; i2 < numRows; ++i2)
  143. {
  144. matInvM(col, i2) *= inv;
  145. }
  146. if (B)
  147. {
  148. X[col] *= inv;
  149. }
  150. if (C)
  151. {
  152. for (i2 = 0; i2 < numCols; ++i2)
  153. {
  154. matY(col, i2) *= inv;
  155. }
  156. }
  157. // Zero out the pivot column locations in the other rows.
  158. for (i1 = 0; i1 < numRows; ++i1)
  159. {
  160. if (i1 != col)
  161. {
  162. Real save = matInvM(i1, col);
  163. matInvM(i1, col) = zero;
  164. for (i2 = 0; i2 < numRows; ++i2)
  165. {
  166. matInvM(i1, i2) -= matInvM(col, i2) * save;
  167. }
  168. if (B)
  169. {
  170. X[i1] -= X[col] * save;
  171. }
  172. if (C)
  173. {
  174. for (i2 = 0; i2 < numCols; ++i2)
  175. {
  176. matY(i1, i2) -= matY(col, i2) * save;
  177. }
  178. }
  179. }
  180. }
  181. }
  182. if (wantInverse)
  183. {
  184. // Reorder rows to undo any permutations in Gaussian elimination.
  185. for (i1 = numRows - 1; i1 >= 0; --i1)
  186. {
  187. if (rowIndex[i1] != colIndex[i1])
  188. {
  189. for (i2 = 0; i2 < numRows; ++i2)
  190. {
  191. std::swap(matInvM(i2, rowIndex[i1]),
  192. matInvM(i2, colIndex[i1]));
  193. }
  194. }
  195. }
  196. }
  197. if (odd)
  198. {
  199. determinant = -determinant;
  200. }
  201. return true;
  202. }
  203. private:
  204. // Support for copying source to target or to set target to zero. If
  205. // source is nullptr, then target is set to zero; otherwise source is
  206. // copied to target. This function hides the type traits used to
  207. // determine whether Real is native floating-point or otherwise (such
  208. // as BSNumber or BSRational).
  209. void Set(int numElements, Real const* source, Real* target) const
  210. {
  211. if (std::is_floating_point<Real>() == std::true_type())
  212. {
  213. // Fast set/copy for native floating-point.
  214. size_t numBytes = numElements * sizeof(Real);
  215. if (source)
  216. {
  217. std::memcpy(target, source, numBytes);
  218. }
  219. else
  220. {
  221. std::memset(target, 0, numBytes);
  222. }
  223. }
  224. else
  225. {
  226. // The inputs are not std containers, so ensure assignment works
  227. // correctly.
  228. if (source)
  229. {
  230. for (int i = 0; i < numElements; ++i)
  231. {
  232. target[i] = source[i];
  233. }
  234. }
  235. else
  236. {
  237. Real const zero = (Real)0;
  238. for (int i = 0; i < numElements; ++i)
  239. {
  240. target[i] = zero;
  241. }
  242. }
  243. }
  244. }
  245. };
  246. }