LCPSolver.h 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628
  1. // David Eberly, Geometric Tools, Redmond WA 98052
  2. // Copyright (c) 1998-2020
  3. // Distributed under the Boost Software License, Version 1.0.
  4. // https://www.boost.org/LICENSE_1_0.txt
  5. // https://www.geometrictools.com/License/Boost/LICENSE_1_0.txt
  6. // Version: 4.0.2019.10.04
  7. #pragma once
  8. #include <Mathematics/Logger.h>
  9. #include <algorithm>
  10. #include <array>
  11. #include <vector>
  12. // A class for solving the Linear Complementarity Problem (LCP)
  13. // w = q + M * z, w^T * z = 0, w >= 0, z >= 0. The vectors q, w, and z are
  14. // n-tuples and the matrix M is n-by-n. The inputs to Solve(...) are q and M.
  15. // The outputs are w and z, which are valid when the returned bool is true but
  16. // are invalid when the returned bool is false.
  17. //
  // The comments near the end of LCPSolverShared::Solve(...) explain how
  // numerical round-off errors can affect this LCP solver implementation.
  // If the algorithm fails to converge within the specified maximum number
  // of iterations, consider increasing that number (SetMaxIterations) and
  // calling Solve(...) again.
  // Define the following preprocessor symbol if you want the code to throw
  // an exception when the algorithm fails to converge. You can choose to
  // trap the exception and handle it as you please. If you do not define
  // the preprocessor symbol, you can pass a pointer to a Result object and
  // check whether the algorithm failed to converge. Again, you can handle
  // this as you please.
  27. //
  28. //#define GTE_THROW_ON_LCPSOLVER_ERRORS
  29. namespace WwiseGTE
  30. {
  31. // Support templates for number of dimensions known at compile time or
  32. // known only at run time.
  33. template <typename Real, int... Dimensions>
  34. class LCPSolver {};
  template <typename Real>
  class LCPSolverShared
  {
  protected:
      // Abstract base class construction. A virtual destructor is not
      // provided because there are no required side effects when destroying
      // objects from the derived classes. The member mMaxIterations is set
      // by this call to the default value n*n. A nonpositive n produces a
      // solver of dimension 0 whose Solve(...) reports INVALID_INPUT.
      LCPSolverShared(int n)
          :
          LCPSolverShared(n, static_cast<Real>(0), static_cast<Real>(1))
      {
      }

      // Use this constructor when you need a specific representation of
      // zero and of one to be used when manipulating the polynomials of the
      // base class. In particular, this is needed to select the correct
      // zero and correct one for QFNumber objects.
      LCPSolverShared(int n, Real const& zero, Real const& one)
          :
          mDimension(n > 0 ? n : 0),
          mMaxIterations(n > 0 ? n * n : 0),
          mNumIterations(0),
          mVarBasic(nullptr),
          mVarNonbasic(nullptr),
          mNumCols(0),
          mAugmented(nullptr),
          mQMin(nullptr),
          mMinRatio(nullptr),
          mRatio(nullptr),
          mPoly(nullptr),
          mZero(zero),
          mOne(one)
      {
      }

  public:
      // Theoretically, when there is a solution the algorithm must converge
      // in a finite number of iterations. The number of iterations depends
      // on the problem at hand, but we need to guard against an infinite
      // loop by limiting the number. The limit bounds the number of pivots;
      // the default is n*n (chosen arbitrarily). Convergence is still
      // checked after the last permitted pivot, so a problem that needs
      // exactly mMaxIterations pivots succeeds. A nonpositive argument
      // restores the default. If a call to Solve fails to converge, you
      // can increase the number of iterations and call Solve again.
      void SetMaxIterations(int maxIterations)
      {
          mMaxIterations = (maxIterations > 0 ? maxIterations : mDimension * mDimension);
      }

      int GetMaxIterations() const
      {
          return mMaxIterations;
      }

      // Access the actual number of pivots performed in a call to Solve.
      int GetNumIterations() const
      {
          return mNumIterations;
      }

      enum Result
      {
          HAS_TRIVIAL_SOLUTION,
          HAS_NONTRIVIAL_SOLUTION,
          NO_SOLUTION,
          FAILED_TO_CONVERGE,
          INVALID_INPUT
      };

  protected:
      // Bookkeeping of variables during the iterations of the solver. The
      // name is either 'w' or 'z' and is used for human-readable debugging
      // help. The 'index' is that for the original variables w[index] or
      // z[index]. The 'complementary' index is the location of the
      // complementary variable in mVarBasic[] or in mVarNonbasic[]. The
      // 'tuple' is a pointer to &w[0] or &z[0], the choice based on name of
      // 'w' or 'z', and is used to fill in the solution values (the
      // variables are permuted during the pivoting algorithm).
      struct Variable
      {
          char name;
          int index;
          int complementary;
          Real* tuple;
      };

      // The augmented problem is w = q + M*z + z[n]*U = 0, where U is an
      // n-tuple of 1-values. We manipulate the augmented matrix
      // [M | U | p(t)] where p(t) is a column vector of polynomials of at
      // most degree n. If p[r](t) is the polynomial for row r, then
      // p[r](0) = q[r]. These are perturbations of q[r] designed so that
      // the algorithm avoids degeneracies (a q-term becomes zero during the
      // iterations). The basic variables are w[0] through w[n-1] and the
      // nonbasic variables are z[0] through z[n]. The returned z consists
      // only of z[0] through z[n-1].
      //
      // The derived classes ensure that the pointers point to the correct
      // number of elements for each array. The matrix M must be stored in
      // row-major order. q and M are read; w and z (n elements each) are
      // written when the function returns true, and are zeroed when
      // NO_SOLUTION is reported. If 'result' is not null it receives the
      // reason for the returned value.
      bool Solve(Real const* q, Real const* M, Real* w, Real* z, Result* result)
      {
          // A solver constructed with a nonpositive dimension has no
          // workspace bound to it (the derived classes leave the pointers
          // null), so reject the call instead of dereferencing mPoly.
          if (mDimension <= 0)
          {
              SetResult(result, INVALID_INPUT);
              return false;
          }

          // Perturb the q[r] constants to be polynomials of degree r+1
          // represented as an array of n+1 coefficients. The coefficient
          // with index r+1 is 1 and the coefficients with indices larger
          // than r+1 are 0.
          for (int r = 0; r < mDimension; ++r)
          {
              mPoly[r] = &Augmented(r, mDimension + 1);
              MakeZero(mPoly[r]);
              mPoly[r][0] = q[r];
              mPoly[r][r + 1] = mOne;
          }

          // Determine whether there is the trivial solution w = q, z = 0,
          // which is valid when the lexicographically smallest perturbed
          // q-polynomial is not negative. Otherwise that row is the first
          // one to pivot on.
          Copy(mPoly[0], mQMin);
          int basic = 0;
          for (int r = 1; r < mDimension; ++r)
          {
              if (LessThan(mPoly[r], mQMin))
              {
                  Copy(mPoly[r], mQMin);
                  basic = r;
              }
          }
          if (!LessThanZero(mQMin))
          {
              for (int r = 0; r < mDimension; ++r)
              {
                  w[r] = q[r];
                  z[r] = mZero;
              }
              SetResult(result, HAS_TRIVIAL_SOLUTION);
              return true;
          }

          LoadMatrix(M);
          InitializeVariables(w, z);

          // The augmented variable z[n] is the initial driving variable for
          // pivoting. The equation 'basic' is the one to solve for z[n]
          // and pivoting with w[basic].
          int driving = mDimension;
          PivotOnAugmentedVariable(basic);

          mNumIterations = 0;
          for (;;)
          {
              // The basic variable of equation 'basic' exited the
              // dictionary, so its complementary (nonbasic) variable must
              // become the next driving variable in order for it to enter
              // the dictionary.
              int nextDriving = mVarBasic[basic].complementary;
              mVarNonbasic[nextDriving].complementary = driving;
              std::swap(mVarBasic[basic], mVarNonbasic[driving]);
              if (mVarNonbasic[driving].index == mDimension)
              {
                  // The auxiliary variable z[n] left the dictionary, so
                  // the algorithm has converged.
                  StoreSolution();
                  SetResult(result, HAS_NONTRIVIAL_SOLUTION);
                  return true;
              }

              // The convergence test above also runs after the last
              // permitted pivot; only then is the budget exhausted.
              if (mNumIterations >= mMaxIterations)
              {
                  break;
              }

              driving = nextDriving;
              basic = SelectPivotEquation(driving);
              if (basic == -1)
              {
                  // The coefficients of z[driving] in all the equations are
                  // nonnegative, so the z[driving] variable cannot leave
                  // the dictionary. There is no solution to the LCP.
                  for (int r = 0; r < mDimension; ++r)
                  {
                      w[r] = mZero;
                      z[r] = mZero;
                  }
                  SetResult(result, NO_SOLUTION);
                  return false;
              }

              Pivot(basic, driving);
              ++mNumIterations;
          }

          // Numerical round-off errors can cause the Lemke algorithm not to
          // converge. In particular, SelectPivotEquation has a test
          //   if (Augmented(r, driving) < mZero) { ... }
          // to determine the 'basic' equation with which to pivot. It is
          // possible that theoretically Augmented(r, driving) is zero but
          // rounding errors cause it to be slightly negative. If
          // theoretically all Augmented(r, driving) >= 0, there is no
          // solution to the LCP. With the rounding errors, if the
          // algorithm fails to converge within the specified number of
          // iterations, FAILED_TO_CONVERGE is reported. It is also possible
          // that the rounding errors lead to a NO_SOLUTION (reported from
          // inside the loop) when in fact there is a solution. When the LCP
          // solver is used by intersection testing algorithms, the hope is
          // that misclassifications occur only when the two objects are
          // nearly in tangential contact.
          //
          // To determine whether the rounding errors are the problem, you
          // can execute the query using exact arithmetic with the following
          // type used for 'Real' (replacing 'float' or 'double') of
          // BSRational<UIntegerAP32> Rational.
          //
          // That said, if the algorithm fails to converge and you believe
          // that the rounding errors are not causing this, please file a
          // bug report and provide the input data to the solver.
#if defined(GTE_THROW_ON_LCPSOLVER_ERRORS)
          LogError("LCPSolverShared::Solve failed to converge.");
#endif
          SetResult(result, FAILED_TO_CONVERGE);
          return false;
      }

      // Store 'value' in *result when the caller asked for it.
      static void SetResult(Result* result, Result value)
      {
          if (result)
          {
              *result = value;
          }
      }

      // Fill columns 0..n-1 of the augmented matrix with the row-major
      // matrix M and column n with the 1-values of U.
      void LoadMatrix(Real const* M)
      {
          for (int r = 0; r < mDimension; ++r)
          {
              for (int c = 0; c < mDimension; ++c)
              {
                  Augmented(r, c) = M[c + mDimension * r];
              }
              Augmented(r, mDimension) = mOne;
          }
      }

      // Keep track of when the variables enter and exit the dictionary,
      // including where complementary variables are relocated. Initially
      // every w is basic and every z is nonbasic.
      void InitializeVariables(Real* w, Real* z)
      {
          for (int i = 0; i <= mDimension; ++i)
          {
              mVarBasic[i].name = 'w';
              mVarBasic[i].index = i;
              mVarBasic[i].complementary = i;
              mVarBasic[i].tuple = w;
              mVarNonbasic[i].name = 'z';
              mVarNonbasic[i].index = i;
              mVarNonbasic[i].complementary = i;
              mVarNonbasic[i].tuple = z;
          }
      }

      // First pivot: solve equation 'basic' for z[n]. The z[n] column is
      // all 1-values, so each other row subtracts row 'basic', the row
      // 'basic' is negated, and column n (the coefficient of w[basic]
      // after the exchange) keeps its 1-values.
      void PivotOnAugmentedVariable(int basic)
      {
          for (int r = 0; r < mDimension; ++r)
          {
              if (r != basic)
              {
                  for (int c = 0; c < mNumCols; ++c)
                  {
                      if (c != mDimension)
                      {
                          Augmented(r, c) -= Augmented(basic, c);
                      }
                  }
              }
          }
          for (int c = 0; c < mNumCols; ++c)
          {
              if (c != mDimension)
              {
                  Augmented(basic, c) = -Augmented(basic, c);
              }
          }
      }

      // Determine the 'basic' equation for which the ratio
      // -q[r]/M(r,driving) (compared lexicographically as perturbed
      // polynomials) is minimized among all equations r with
      // M(r,driving) < 0. Returns -1 when no coefficient is negative.
      int SelectPivotEquation(int driving)
      {
          int basic = -1;
          for (int r = 0; r < mDimension; ++r)
          {
              if (Augmented(r, driving) < mZero)
              {
                  Real factor = -mOne / Augmented(r, driving);
                  Multiply(mPoly[r], factor, mRatio);
                  if (basic == -1 || LessThan(mRatio, mMinRatio))
                  {
                      Copy(mRatio, mMinRatio);
                      basic = r;
                  }
              }
          }
          return basic;
      }

      // Solve the basic equation so that z[driving] enters the dictionary
      // and the variable of equation 'basic' exits it; substitute the
      // result into every other equation that involves z[driving].
      void Pivot(int basic, int driving)
      {
          Real invDenom = mOne / Augmented(basic, driving);
          for (int r = 0; r < mDimension; ++r)
          {
              if (r != basic && Augmented(r, driving) != mZero)
              {
                  Real multiplier = Augmented(r, driving) * invDenom;
                  for (int c = 0; c < mNumCols; ++c)
                  {
                      if (c != driving)
                      {
                          Augmented(r, c) -= Augmented(basic, c) * multiplier;
                      }
                      else
                      {
                          Augmented(r, driving) = multiplier;
                      }
                  }
              }
          }
          for (int c = 0; c < mNumCols; ++c)
          {
              if (c != driving)
              {
                  Augmented(basic, c) = -Augmented(basic, c) * invDenom;
              }
              else
              {
                  Augmented(basic, driving) = invDenom;
              }
          }
      }

      // Write the converged dictionary into the caller's w and z. Each
      // basic variable takes the constant term p[r](0) of its equation;
      // each nonbasic variable other than the auxiliary z[n] is zero.
      void StoreSolution()
      {
          for (int r = 0; r < mDimension; ++r)
          {
              Variable const& var = mVarBasic[r];
              var.tuple[var.index] = mPoly[r][0];
          }
          for (int c = 0; c <= mDimension; ++c)
          {
              Variable const& var = mVarNonbasic[c];
              if (var.index < mDimension)
              {
                  var.tuple[var.index] = mZero;
              }
          }
      }

      // Access mAugmented as a 2-dimensional array.
      Real const& Augmented(int row, int col) const
      {
          return mAugmented[col + mNumCols * row];
      }

      Real& Augmented(int row, int col)
      {
          return mAugmented[col + mNumCols * row];
      }

      // Support for polynomials with n+1 coefficients and degree no larger
      // than n.
      void MakeZero(Real* poly) const
      {
          for (int i = 0; i <= mDimension; ++i)
          {
              poly[i] = mZero;
          }
      }

      void Copy(Real const* poly0, Real* poly1) const
      {
          for (int i = 0; i <= mDimension; ++i)
          {
              poly1[i] = poly0[i];
          }
      }

      // Lexicographic comparison of coefficient arrays, lowest index first.
      bool LessThan(Real const* poly0, Real const* poly1) const
      {
          for (int i = 0; i <= mDimension; ++i)
          {
              if (poly0[i] < poly1[i])
              {
                  return true;
              }
              if (poly0[i] > poly1[i])
              {
                  return false;
              }
          }
          return false;
      }

      // True when the first nonzero coefficient is negative.
      bool LessThanZero(Real const* poly) const
      {
          for (int i = 0; i <= mDimension; ++i)
          {
              if (poly[i] < mZero)
              {
                  return true;
              }
              if (poly[i] > mZero)
              {
                  return false;
              }
          }
          return false;
      }

      void Multiply(Real const* poly, Real scalar, Real* product) const
      {
          for (int i = 0; i <= mDimension; ++i)
          {
              product[i] = poly[i] * scalar;
          }
      }

      int mDimension;
      int mMaxIterations;
      int mNumIterations;

      // These pointers are set by the derived-class constructors to arrays
      // that have the correct number of elements. The arrays mVarBasic,
      // mVarNonbasic, mQMin, mMinRatio, and mRatio each have n+1 elements.
      // The mAugmented array has n rows and 2*(n+1) columns stored in
      // row-major order in a 1-dimensional array. The array of pointers
      // mPoly has n elements.
      Variable* mVarBasic;
      Variable* mVarNonbasic;
      int mNumCols;
      Real* mAugmented;
      Real* mQMin;
      Real* mMinRatio;
      Real* mRatio;
      Real** mPoly;
      Real mZero, mOne;
  };
  452. template <typename Real, int n>
  453. class LCPSolver<Real, n> : public LCPSolverShared<Real>
  454. {
  455. public:
  456. // Construction. The member mMaxIterations is set by this call to the
  457. // default value n*n.
  458. LCPSolver()
  459. :
  460. LCPSolverShared<Real>(n)
  461. {
  462. this->mVarBasic = mArrayVarBasic.data();
  463. this->mVarNonbasic = mArrayVarNonbasic.data();
  464. this->mNumCols = 2 * (n + 1);
  465. this->mAugmented = mArrayAugmented.data();
  466. this->mQMin = mArrayQMin.data();
  467. this->mMinRatio = mArrayMinRatio.data();
  468. this->mRatio = mArrayRatio.data();
  469. this->mPoly = mArrayPoly.data();
  470. }
  471. // Use this constructor when you need a specific representation of
  472. // zero and of one to be used when manipulating the polynomials of the
  473. // base class. In particular, this is needed to select the correct
  474. // zero and correct one for QFNumber objects.
  475. LCPSolver(Real const& zero, Real const& one)
  476. :
  477. LCPSolverShared<Real>(n, zero, one)
  478. {
  479. this->mVarBasic = mArrayVarBasic.data();
  480. this->mVarNonbasic = mArrayVarNonbasic.data();
  481. this->mNumCols = 2 * (n + 1);
  482. this->mAugmented = mArrayAugmented.data();
  483. this->mQMin = mArrayQMin.data();
  484. this->mMinRatio = mArrayMinRatio.data();
  485. this->mRatio = mArrayRatio.data();
  486. this->mPoly = mArrayPoly.data();
  487. }
  488. // If you want to know specifically why 'true' or 'false' was
  489. // returned, pass the address of a Result variable as the last
  490. // parameter.
  491. bool Solve(std::array<Real, n> const& q, std::array<std::array<Real, n>, n> const& M,
  492. std::array<Real, n>& w, std::array<Real, n>& z,
  493. typename LCPSolverShared<Real>::Result* result = nullptr)
  494. {
  495. return LCPSolverShared<Real>::Solve(q.data(), M.front().data(), w.data(), z.data(), result);
  496. }
  497. private:
  498. std::array<typename LCPSolverShared<Real>::Variable, n + 1> mArrayVarBasic;
  499. std::array<typename LCPSolverShared<Real>::Variable, n + 1> mArrayVarNonbasic;
  500. std::array<Real, 2 * (n + 1)* n> mArrayAugmented;
  501. std::array<Real, n + 1> mArrayQMin;
  502. std::array<Real, n + 1> mArrayMinRatio;
  503. std::array<Real, n + 1> mArrayRatio;
  504. std::array<Real*, n> mArrayPoly;
  505. };
  506. template <typename Real>
  507. class LCPSolver<Real> : public LCPSolverShared<Real>
  508. {
  509. public:
  510. // Construction. The member mMaxIterations is set by this call to the
  511. // default value n*n.
  512. LCPSolver(int n)
  513. :
  514. LCPSolverShared<Real>(n)
  515. {
  516. if (n > 0)
  517. {
  518. mVectorVarBasic.resize(n + 1);
  519. mVectorVarNonbasic.resize(n + 1);
  520. mVectorAugmented.resize(2 * (n + 1) * n);
  521. mVectorQMin.resize(n + 1);
  522. mVectorMinRatio.resize(n + 1);
  523. mVectorRatio.resize(n + 1);
  524. mVectorPoly.resize(n);
  525. this->mVarBasic = mVectorVarBasic.data();
  526. this->mVarNonbasic = mVectorVarNonbasic.data();
  527. this->mNumCols = 2 * (n + 1);
  528. this->mAugmented = mVectorAugmented.data();
  529. this->mQMin = mVectorQMin.data();
  530. this->mMinRatio = mVectorMinRatio.data();
  531. this->mRatio = mVectorRatio.data();
  532. this->mPoly = mVectorPoly.data();
  533. }
  534. }
  535. // The input q must have n elements and the input M must be an n-by-n
  536. // matrix stored in row-major order. The outputs w and z have n
  537. // elements. If you want to know specifically why 'true' or 'false'
  538. // was returned, pass the address of a Result variable as the last
  539. // parameter.
  540. bool Solve(std::vector<Real> const& q, std::vector<Real> const& M,
  541. std::vector<Real>& w, std::vector<Real>& z,
  542. typename LCPSolverShared<Real>::Result* result = nullptr)
  543. {
  544. if (this->mDimension > static_cast<int>(q.size())
  545. || this->mDimension * this->mDimension > static_cast<int>(M.size()))
  546. {
  547. if (result)
  548. {
  549. *result = this->INVALID_INPUT;
  550. }
  551. return false;
  552. }
  553. if (this->mDimension > static_cast<int>(w.size()))
  554. {
  555. w.resize(this->mDimension);
  556. }
  557. if (this->mDimension > static_cast<int>(z.size()))
  558. {
  559. z.resize(this->mDimension);
  560. }
  561. return LCPSolverShared<Real>::Solve(q.data(), M.data(), w.data(), z.data(), result);
  562. }
  563. private:
  564. std::vector<typename LCPSolverShared<Real>::Variable> mVectorVarBasic;
  565. std::vector<typename LCPSolverShared<Real>::Variable> mVectorVarNonbasic;
  566. std::vector<Real> mVectorAugmented;
  567. std::vector<Real> mVectorQMin;
  568. std::vector<Real> mVectorMinRatio;
  569. std::vector<Real> mVectorRatio;
  570. std::vector<Real*> mVectorPoly;
  571. };
  572. }