ApprQuery.h 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. // David Eberly, Geometric Tools, Redmond WA 98052
  2. // Copyright (c) 1998-2020
  3. // Distributed under the Boost Software License, Version 1.0.
  4. // https://www.boost.org/LICENSE_1_0.txt
  5. // https://www.geometrictools.com/License/Boost/LICENSE_1_0.txt
  6. // Version: 4.0.2019.08.13
  7. #pragma once
  8. #include <algorithm>
  9. #include <numeric>
  10. #include <random>
  11. #include <vector>
  // Base class support for least-squares fitting algorithms and for RANSAC
  // algorithms.
  //
  // While GTE_APPR_QUERY_VALIDATE_INDICES is defined, ValidIndices checks the
  // pointers, the index count and every index it is given. Comment out the
  // define to skip those checks; callers are then responsible for passing
  // well-formed data.
  #define GTE_APPR_QUERY_VALIDATE_INDICES

  namespace WwiseGTE
  {
      // Real is the scalar type of the model error; ObservationType is the
      // type of a single observation. A derived class supplies the indexed
      // fit (FitIndexed), the minimum sample size (GetMinimumRequired), the
      // per-observation error (Error) and parameter copying (CopyParameters).
      // This base class builds every Fit overload and the RANSAC driver on
      // top of those four functions.
      template <typename Real, typename ObservationType>
      class ApprQuery
      {
      public:
          // Construction and destruction.
          ApprQuery() = default;
          virtual ~ApprQuery() = default;

          // The base-class Fit* functions are generic but need to call the
          // indexed fitting function for the specific derived class. Every
          // Fit overload below funnels into this function.
          virtual bool FitIndexed(
              size_t numObservations, ObservationType const* observations,
              size_t numIndices, int const* indices) = 0;

          // Checks the arguments of an indexed fit. When validation is
          // enabled, returns true only if both pointers are non-null, the
          // index count lies in [GetMinimumRequired(), numObservations] and
          // every index lies in [0, numObservations). When validation is
          // disabled, it always returns true. This base class never calls
          // it; it is available to derived FitIndexed implementations.
          bool ValidIndices(
              size_t numObservations, ObservationType const* observations,
              size_t numIndices, int const* indices)
          {
  #if defined(GTE_APPR_QUERY_VALIDATE_INDICES)
              if (observations == nullptr || indices == nullptr
                  || numIndices < GetMinimumRequired() || numIndices > numObservations)
              {
                  return false;
              }

              for (size_t i = 0; i < numIndices; ++i)
              {
                  // Negative indices must be rejected as well; comparing only
                  // against the upper bound lets -1 through.
                  int const index = indices[i];
                  if (index < 0 || static_cast<size_t>(index) >= numObservations)
                  {
                      return false;
                  }
              }
              return true;
  #else
              // The caller is responsible for passing correctly formed data.
              (void)numObservations;
              (void)observations;
              (void)numIndices;
              (void)indices;
              return true;
  #endif
          }

          // Estimate the model parameters for all observations passed in via
          // raw pointers.
          bool Fit(size_t numObservations, ObservationType const* observations)
          {
              std::vector<int> indices(numObservations);
              std::iota(indices.begin(), indices.end(), 0);
              return FitIndexed(numObservations, observations, indices.size(), indices.data());
          }

          // Estimate the model parameters for all observations passed in via
          // std::vector.
          bool Fit(std::vector<ObservationType> const& observations)
          {
              std::vector<int> indices(observations.size());
              std::iota(indices.begin(), indices.end(), 0);
              return FitIndexed(observations.size(), observations.data(), indices.size(), indices.data());
          }

          // Estimate the model parameters for the contiguous subset of
          // observations with indices imin through imax (inclusive). Returns
          // false when the range is reversed or extends past the end of
          // 'observations'.
          bool Fit(std::vector<ObservationType> const& observations, size_t imin, size_t imax)
          {
              // Check the upper bound here so an out-of-range request fails
              // even when index validation is compiled out.
              if (imin > imax || imax >= observations.size())
              {
                  return false;
              }

              std::vector<int> indices(imax - imin + 1);
              std::iota(indices.begin(), indices.end(), static_cast<int>(imin));
              return FitIndexed(observations.size(), observations.data(), indices.size(), indices.data());
          }

          // Estimate the model parameters for an indexed subset of observations.
          virtual bool Fit(std::vector<ObservationType> const& observations,
              std::vector<int> const& indices)
          {
              return FitIndexed(observations.size(), observations.data(), indices.size(), indices.data());
          }

          // Estimate the model parameters for the subset of observations
          // specified by the first min(numIndices, indices.size()) entries
          // of 'indices'.
          bool Fit(std::vector<ObservationType> const& observations,
              std::vector<int> const& indices, size_t numIndices)
          {
              // The leading entries are passed in place rather than copied.
              // RANSAC calls this on every iteration, so avoiding a
              // per-call allocation matters.
              size_t const count = std::min(numIndices, indices.size());
              return FitIndexed(observations.size(), observations.data(), count, indices.data());
          }

          // Apply the RANdom SAmple Consensus algorithm for fitting a model to
          // observations. The algorithm requires three virtual functions to be
          // implemented by the derived classes.

          // The minimum number of observations required to fit the model.
          virtual size_t GetMinimumRequired() const = 0;

          // Compute the model error for the specified observation for the
          // current model parameters.
          virtual Real Error(ObservationType const& observation) const = 0;

          // Copy the parameters between two models. This is used to copy the
          // candidate-model parameters to the current best-fit model.
          virtual void CopyParameters(ApprQuery const* input) = 0;

          // RANSAC driver.
          //   candidateModel        scratch model refit on every iteration;
          //                         its final state is whatever the last
          //                         attempted fit left behind.
          //   numRequiredForGoodFit minimum consensus size for a model to be
          //                         accepted.
          //   maxErrorForGoodFit    an observation joins the consensus set
          //                         when Error(observation) is at most this.
          //   numIterations         number of random samples drawn.
          //   bestConsensus         receives the indices of the largest
          //                         accepted consensus set. It is untouched
          //                         when no model is accepted.
          //   bestModel             receives the parameters refit to that set.
          // Returns true when some sample produced an accepted model whose
          // refit succeeded. Special case: when there are exactly
          // GetMinimumRequired() observations, the result is bestModel.Fit of
          // all of them, and numRequiredForGoodFit is not consulted.
          static bool RANSAC(ApprQuery& candidateModel, std::vector<ObservationType> const& observations,
              size_t numRequiredForGoodFit, Real maxErrorForGoodFit, size_t numIterations,
              std::vector<int>& bestConsensus, ApprQuery& bestModel)
          {
              size_t const numObservations = observations.size();
              size_t const minRequired = candidateModel.GetMinimumRequired();
              if (numObservations < minRequired)
              {
                  // Too few observations for model fitting.
                  return false;
              }

              // The first part of the array stores the consensus set,
              // initially the minRequired candidate inliers. The last part
              // stores the remaining indices; those are tested against the
              // model and swapped into the consensus set when they fit. All
              // index manipulation is done in place. Initially the
              // candidates are the identity permutation.
              std::vector<int> candidates(numObservations);
              std::iota(candidates.begin(), candidates.end(), 0);

              if (numObservations == minRequired)
              {
                  // We have the minimum number of observations to generate
                  // the model, so RANSAC cannot be used. Compute the model
                  // with the entire set of observations.
                  bestConsensus = candidates;
                  return bestModel.Fit(observations);
              }

              // A single engine for the whole run, so successive iterations
              // draw different samples. A fresh engine per iteration would
              // replay the same permutation every time. It is
              // default-seeded, so each call to RANSAC is reproducible.
              std::default_random_engine dre{};

              // Track acceptance explicitly. Seeding the best count with
              // minRequired would silently drop good fits of exactly
              // minRequired observations.
              bool foundGoodFit = false;
              size_t bestNumFittedObservations = 0;

              for (size_t i = 0; i < numIterations; ++i)
              {
                  // Randomly permute the candidates. The first minRequired
                  // entries become the candidate inliers.
                  std::shuffle(candidates.begin(), candidates.end(), dre);

                  // Fit the model to the candidate inliers.
                  if (!candidateModel.Fit(observations, candidates, minRequired))
                  {
                      continue;
                  }

                  // Test each remaining observation against the model and
                  // move those that fit into the consensus prefix.
                  size_t numFittedObservations = minRequired;
                  for (size_t j = minRequired; j < numObservations; ++j)
                  {
                      Real const error = candidateModel.Error(observations[candidates[j]]);
                      if (error <= maxErrorForGoodFit)
                      {
                          std::swap(candidates[j], candidates[numFittedObservations]);
                          ++numFittedObservations;
                      }
                  }

                  bool const isGoodFit = (numFittedObservations >= numRequiredForGoodFit);
                  bool const isBetter = (!foundGoodFit || numFittedObservations > bestNumFittedObservations);

                  // Refit to the whole consensus set. Adopt the result only
                  // if that refit succeeds, so parameters from a failed fit
                  // never reach bestModel.
                  if (isGoodFit && isBetter
                      && candidateModel.Fit(observations, candidates, numFittedObservations))
                  {
                      bestModel.CopyParameters(&candidateModel);
                      bestConsensus.assign(candidates.begin(), candidates.begin() + numFittedObservations);
                      bestNumFittedObservations = numFittedObservations;
                      foundGoodFit = true;
                  }
              }

              return foundGoodFit;
          }
      };
  }