BandedMatrix.h 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. // David Eberly, Geometric Tools, Redmond WA 98052
  2. // Copyright (c) 1998-2020
  3. // Distributed under the Boost Software License, Version 1.0.
  4. // https://www.boost.org/LICENSE_1_0.txt
  5. // https://www.geometrictools.com/License/Boost/LICENSE_1_0.txt
  6. // Version: 4.0.2019.08.13
  7. #pragma once
  8. #include <Mathematics/Math.h>
  9. #include <Mathematics/LexicoArray2.h>
  10. #include <vector>
  11. namespace WwiseGTE
  12. {
  13. template <typename Real>
  14. class BandedMatrix
  15. {
  16. public:
  17. // Construction and destruction.
  18. BandedMatrix(int size, int numLBands, int numUBands)
  19. :
  20. mSize(size),
  21. mZero((Real)0)
  22. {
  23. if (size > 0
  24. && 0 <= numLBands && numLBands < size
  25. && 0 <= numUBands && numUBands < size)
  26. {
  27. mDBand.resize(size);
  28. std::fill(mDBand.begin(), mDBand.end(), (Real)0);
  29. if (numLBands > 0)
  30. {
  31. mLBands.resize(numLBands);
  32. int numElements = size - 1;
  33. for (auto& band : mLBands)
  34. {
  35. band.resize(numElements--);
  36. std::fill(band.begin(), band.end(), (Real)0);
  37. }
  38. }
  39. if (numUBands > 0)
  40. {
  41. mUBands.resize(numUBands);
  42. int numElements = size - 1;
  43. for (auto& band : mUBands)
  44. {
  45. band.resize(numElements--);
  46. std::fill(band.begin(), band.end(), (Real)0);
  47. }
  48. }
  49. }
  50. else
  51. {
  52. // Invalid argument to BandedMatrix constructor.
  53. mSize = 0;
  54. }
  55. }
  56. ~BandedMatrix()
  57. {
  58. }
  59. // Member access.
  60. inline int GetSize() const
  61. {
  62. return mSize;
  63. }
  64. inline std::vector<Real>& GetDBand()
  65. {
  66. return mDBand;
  67. }
  68. inline std::vector<Real> const& GetDBand() const
  69. {
  70. return mDBand;
  71. }
  72. inline std::vector<std::vector<Real>>& GetLBands()
  73. {
  74. return mLBands;
  75. }
  76. inline std::vector<std::vector<Real>> const& GetLBands() const
  77. {
  78. return mLBands;
  79. }
  80. inline std::vector<std::vector<Real>>& GetUBands()
  81. {
  82. return mUBands;
  83. }
  84. inline std::vector<std::vector<Real>> const& GetUBands() const
  85. {
  86. return mUBands;
  87. }
  88. Real& operator()(int r, int c)
  89. {
  90. if (0 <= r && r < mSize && 0 <= c && c < mSize)
  91. {
  92. int band = c - r;
  93. if (band > 0)
  94. {
  95. int const numUBands = static_cast<int>(mUBands.size());
  96. if (--band < numUBands && r < mSize - 1 - band)
  97. {
  98. return mUBands[band][r];
  99. }
  100. }
  101. else if (band < 0)
  102. {
  103. band = -band;
  104. int const numLBands = static_cast<int>(mLBands.size());
  105. if (--band < numLBands && c < mSize - 1 - band)
  106. {
  107. return mLBands[band][c];
  108. }
  109. }
  110. else
  111. {
  112. return mDBand[r];
  113. }
  114. }
  115. // else invalid index
  116. // Set the value to zero in case someone unknowingly modified mZero on a
  117. // previous call to operator(int,int).
  118. mZero = (Real)0;
  119. return mZero;
  120. }
  121. Real const& operator()(int r, int c) const
  122. {
  123. if (0 <= r && r < mSize && 0 <= c && c < mSize)
  124. {
  125. int band = c - r;
  126. if (band > 0)
  127. {
  128. int const numUBands = static_cast<int>(mUBands.size());
  129. if (--band < numUBands && r < mSize - 1 - band)
  130. {
  131. return mUBands[band][r];
  132. }
  133. }
  134. else if (band < 0)
  135. {
  136. band = -band;
  137. int const numLBands = static_cast<int>(mLBands.size());
  138. if (--band < numLBands && c < mSize - 1 - band)
  139. {
  140. return mLBands[band][c];
  141. }
  142. }
  143. else
  144. {
  145. return mDBand[r];
  146. }
  147. }
  148. // else invalid index
  149. // Set the value to zero in case someone unknowingly modified
  150. // mZero on a previous call to operator(int,int).
  151. mZero = (Real)0;
  152. return mZero;
  153. }
  154. // Factor the square banded matrix A into A = L*L^T, where L is a
  155. // lower-triangular matrix (L^T is an upper-triangular matrix). This
  156. // is an LU decomposition that allows for stable inversion of A to
  157. // solve A*X = B. The return value is 'true' iff the factorizing is
  158. // successful (L is invertible). If successful, A contains the
  159. // Cholesky factorization: L in the lower-triangular part of A and
  160. // L^T in the upper-triangular part of A.
  161. bool CholeskyFactor()
  162. {
  163. if (mDBand.size() == 0 || mLBands.size() != mUBands.size())
  164. {
  165. // Invalid number of bands.
  166. return false;
  167. }
  168. int const sizeM1 = mSize - 1;
  169. int const numBands = static_cast<int>(mLBands.size());
  170. int k, kMax;
  171. for (int i = 0; i < mSize; ++i)
  172. {
  173. int jMin = i - numBands;
  174. if (jMin < 0)
  175. {
  176. jMin = 0;
  177. }
  178. int j;
  179. for (j = jMin; j < i; ++j)
  180. {
  181. kMax = j + numBands;
  182. if (kMax > sizeM1)
  183. {
  184. kMax = sizeM1;
  185. }
  186. for (k = i; k <= kMax; ++k)
  187. {
  188. operator()(k, i) -= operator()(i, j) * operator()(k, j);
  189. }
  190. }
  191. kMax = j + numBands;
  192. if (kMax > sizeM1)
  193. {
  194. kMax = sizeM1;
  195. }
  196. for (k = 0; k < i; ++k)
  197. {
  198. operator()(k, i) = operator()(i, k);
  199. }
  200. Real diagonal = operator()(i, i);
  201. if (diagonal <= (Real)0)
  202. {
  203. return false;
  204. }
  205. Real invSqrt = ((Real)1) / std::sqrt(diagonal);
  206. for (k = i; k <= kMax; ++k)
  207. {
  208. operator()(k, i) *= invSqrt;
  209. }
  210. }
  211. return true;
  212. }
  213. // Solve the linear system A*X = B, where A is an NxN banded matrix
  214. // and B is an Nx1 vector. The unknown X is also Nx1. The input to
  215. // this function is B. The output X is computed and stored in B. The
  216. // return value is 'true' iff the system has a solution. The matrix A
  217. // and the vector B are both modified by this function. If
  218. // successful, A contains the Cholesky factorization: L in the
  219. // lower-triangular part of A and L^T in the upper-triangular part
  220. // of A.
  221. bool SolveSystem(Real* bVector)
  222. {
  223. return CholeskyFactor()
  224. && SolveLower(bVector)
  225. && SolveUpper(bVector);
  226. }
  227. // Solve the linear system A*X = B, where A is an NxN banded matrix
  228. // and B is an NxM matrix. The unknown X is also NxM. The input to
  229. // this function is B. The output X is computed and stored in B. The
  230. // return value is 'true' iff the system has a solution. The matrix A
  231. // and the vector B are both modified by this function. If
  232. // successful, A contains the Cholesky factorization: L in the
  233. // lower-triangular part of A and L^T in the upper-triangular part
  234. // of A.
  235. //
  236. // 'bMatrix' must have the storage order specified by the template
  237. // parameter.
  238. template <bool RowMajor>
  239. bool SolveSystem(Real* bMatrix, int numBColumns)
  240. {
  241. return CholeskyFactor()
  242. && SolveLower<RowMajor>(bMatrix, numBColumns)
  243. && SolveUpper<RowMajor>(bMatrix, numBColumns);
  244. }
  245. // Compute the inverse of the banded matrix. The return value is
  246. // 'true' when the matrix is invertible, in which case the 'inverse'
  247. // output is valid. The return value is 'false' when the matrix is
  248. // not invertible, in which case 'inverse' is invalid and should not
  249. // be used. The input matrix 'inverse' must be the same size as
  250. // 'this'.
  251. //
  252. // 'bMatrix' must have the storage order specified by the template
  253. // parameter.
  254. template <bool RowMajor>
  255. bool ComputeInverse(Real* inverse) const
  256. {
  257. LexicoArray2<RowMajor, Real> invA(mSize, mSize, inverse);
  258. BandedMatrix<Real> tmpA = *this;
  259. for (int row = 0; row < mSize; ++row)
  260. {
  261. for (int col = 0; col < mSize; ++col)
  262. {
  263. if (row != col)
  264. {
  265. invA(row, col) = (Real)0;
  266. }
  267. else
  268. {
  269. invA(row, row) = (Real)1;
  270. }
  271. }
  272. }
  273. // Forward elimination.
  274. for (int row = 0; row < mSize; ++row)
  275. {
  276. // The pivot must be nonzero in order to proceed.
  277. Real diag = tmpA(row, row);
  278. if (diag == (Real)0)
  279. {
  280. return false;
  281. }
  282. Real invDiag = ((Real)1) / diag;
  283. tmpA(row, row) = (Real)1;
  284. // Multiply the row to be consistent with diagonal term of 1.
  285. int colMin = row + 1;
  286. int colMax = colMin + static_cast<int>(mUBands.size());
  287. if (colMax > mSize)
  288. {
  289. colMax = mSize;
  290. }
  291. int c;
  292. for (c = colMin; c < colMax; ++c)
  293. {
  294. tmpA(row, c) *= invDiag;
  295. }
  296. for (c = 0; c <= row; ++c)
  297. {
  298. invA(row, c) *= invDiag;
  299. }
  300. // Reduce the remaining rows.
  301. int rowMin = row + 1;
  302. int rowMax = rowMin + static_cast<int>(mLBands.size());
  303. if (rowMax > mSize)
  304. {
  305. rowMax = mSize;
  306. }
  307. for (int r = rowMin; r < rowMax; ++r)
  308. {
  309. Real mult = tmpA(r, row);
  310. tmpA(r, row) = (Real)0;
  311. for (c = colMin; c < colMax; ++c)
  312. {
  313. tmpA(r, c) -= mult * tmpA(row, c);
  314. }
  315. for (c = 0; c <= row; ++c)
  316. {
  317. invA(r, c) -= mult * invA(row, c);
  318. }
  319. }
  320. }
  321. // Backward elimination.
  322. for (int row = mSize - 1; row >= 1; --row)
  323. {
  324. int rowMax = row - 1;
  325. int rowMin = row - static_cast<int>(mUBands.size());
  326. if (rowMin < 0)
  327. {
  328. rowMin = 0;
  329. }
  330. for (int r = rowMax; r >= rowMin; --r)
  331. {
  332. Real mult = tmpA(r, row);
  333. tmpA(r, row) = (Real)0;
  334. for (int c = 0; c < mSize; ++c)
  335. {
  336. invA(r, c) -= mult * invA(row, c);
  337. }
  338. }
  339. }
  340. return true;
  341. }
  342. private:
  343. // The linear system is L*U*X = B, where A = L*U and U = L^T, Reduce
  344. // this to U*X = L^{-1}*B. The return value is 'true' iff the
  345. // operation is successful.
  346. bool SolveLower(Real* dataVector) const
  347. {
  348. int const size = static_cast<int>(mDBand.size());
  349. for (int r = 0; r < size; ++r)
  350. {
  351. Real lowerRR = operator()(r, r);
  352. if (lowerRR > (Real)0)
  353. {
  354. for (int c = 0; c < r; ++c)
  355. {
  356. Real lowerRC = operator()(r, c);
  357. dataVector[r] -= lowerRC * dataVector[c];
  358. }
  359. dataVector[r] /= lowerRR;
  360. }
  361. else
  362. {
  363. return false;
  364. }
  365. }
  366. return true;
  367. }
  368. // The linear system is U*X = L^{-1}*B. Reduce this to
  369. // X = U^{-1}*L^{-1}*B. The return value is 'true' iff the operation
  370. // is successful.
  371. bool SolveUpper(Real* dataVector) const
  372. {
  373. int const size = static_cast<int>(mDBand.size());
  374. for (int r = size - 1; r >= 0; --r)
  375. {
  376. Real upperRR = operator()(r, r);
  377. if (upperRR > (Real)0)
  378. {
  379. for (int c = r + 1; c < size; ++c)
  380. {
  381. Real upperRC = operator()(r, c);
  382. dataVector[r] -= upperRC * dataVector[c];
  383. }
  384. dataVector[r] /= upperRR;
  385. }
  386. else
  387. {
  388. return false;
  389. }
  390. }
  391. return true;
  392. }
  393. // The linear system is L*U*X = B, where A = L*U and U = L^T, Reduce
  394. // this to U*X = L^{-1}*B. The return value is 'true' iff the
  395. // operation is successful. See the comments for
  396. // SolveSystem(Real*,int) about the storage for dataMatrix.
  397. template <bool RowMajor>
  398. bool SolveLower(Real* dataMatrix, int numColumns) const
  399. {
  400. LexicoArray2<RowMajor, Real> data(mSize, numColumns, dataMatrix);
  401. for (int r = 0; r < mSize; ++r)
  402. {
  403. Real lowerRR = operator()(r, r);
  404. if (lowerRR > (Real)0)
  405. {
  406. for (int c = 0; c < r; ++c)
  407. {
  408. Real lowerRC = operator()(r, c);
  409. for (int bCol = 0; bCol < numColumns; ++bCol)
  410. {
  411. data(r, bCol) -= lowerRC * data(c, bCol);
  412. }
  413. }
  414. Real inverse = ((Real)1) / lowerRR;
  415. for (int bCol = 0; bCol < numColumns; ++bCol)
  416. {
  417. data(r, bCol) *= inverse;
  418. }
  419. }
  420. else
  421. {
  422. return false;
  423. }
  424. }
  425. return true;
  426. }
  427. // The linear system is U*X = L^{-1}*B. Reduce this to
  428. // X = U^{-1}*L^{-1}*B. The return value is 'true' iff the operation
  429. // is successful. See the comments for SolveSystem(Real*,int) about
  430. // the storage for dataMatrix.
  431. template <bool RowMajor>
  432. bool SolveUpper(Real* dataMatrix, int numColumns) const
  433. {
  434. LexicoArray2<RowMajor, Real> data(mSize, numColumns, dataMatrix);
  435. for (int r = mSize - 1; r >= 0; --r)
  436. {
  437. Real upperRR = operator()(r, r);
  438. if (upperRR > (Real)0)
  439. {
  440. for (int c = r + 1; c < mSize; ++c)
  441. {
  442. Real upperRC = operator()(r, c);
  443. for (int bCol = 0; bCol < numColumns; ++bCol)
  444. {
  445. data(r, bCol) -= upperRC * data(c, bCol);
  446. }
  447. }
  448. Real inverse = ((Real)1) / upperRR;
  449. for (int bCol = 0; bCol < numColumns; ++bCol)
  450. {
  451. data(r, bCol) *= inverse;
  452. }
  453. }
  454. else
  455. {
  456. return false;
  457. }
  458. }
  459. return true;
  460. }
  461. int mSize;
  462. std::vector<Real> mDBand;
  463. std::vector<std::vector<Real>> mLBands, mUBands;
  464. // For return by operator()(int,int) for valid indices not in the
  465. // bands, in which case the matrix entries are zero,
  466. mutable Real mZero;
  467. };
  468. }