123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215 |
- #pragma once
- #include <algorithm>
- #include <numeric>
- #include <random>
- #include <vector>
- #define GTE_APPR_QUERY_VALIDATE_INDICES
namespace WwiseGTE
{
    // Base class for approximation (fitting) queries.
    //
    // Every Fit overload forwards to the pure-virtual FitIndexed. The class also
    // provides an optional index validator and a RANSAC driver written purely
    // in terms of the virtual interface.
    //
    // Real            - scalar type returned by Error and used as the RANSAC
    //                   inlier tolerance.
    // ObservationType - type of a single observation handed to the fitter.
    template <typename Real, typename ObservationType>
    class ApprQuery
    {
    public:
        ApprQuery() = default;
        virtual ~ApprQuery() = default;

        // Fit the model to the observations selected by
        // indices[0..numIndices). Derived classes implement this, and every
        // Fit overload below forwards here. RANSAC treats a return value of
        // false as a failed fit.
        virtual bool FitIndexed(
            size_t numObservations, ObservationType const* observations,
            size_t numIndices, int const* indices) = 0;

        // Check the arguments intended for FitIndexed.
        //
        // When GTE_APPR_QUERY_VALIDATE_INDICES is defined, this returns true
        // only if:
        //   - both pointers are non-null,
        //   - GetMinimumRequired() <= numIndices <= numObservations, and
        //   - every index lies in [0, numObservations).
        // When the macro is not defined, the check is compiled out and the
        // function always returns true.
        bool ValidIndices(
            size_t numObservations, ObservationType const* observations,
            size_t numIndices, int const* indices)
        {
#if defined(GTE_APPR_QUERY_VALIDATE_INDICES)
            if (observations == nullptr || indices == nullptr ||
                numIndices < GetMinimumRequired() || numIndices > numObservations)
            {
                return false;
            }

            // Negative indices must be rejected too: they would pass an
            // upper-bound-only test and then read before the array start.
            // Comparing as size_t also avoids narrowing numObservations to int.
            return std::all_of(indices, indices + numIndices,
                [numObservations](int index)
                {
                    return index >= 0 &&
                        static_cast<size_t>(index) < numObservations;
                });
#else
            (void)numObservations;
            (void)observations;
            (void)numIndices;
            (void)indices;
            return true;
#endif
        }

        // Fit the model to all numObservations observations.
        bool Fit(size_t numObservations, ObservationType const* observations)
        {
            std::vector<int> indices(numObservations);
            std::iota(indices.begin(), indices.end(), 0);
            return FitIndexed(numObservations, observations, indices.size(), indices.data());
        }

        // Fit the model to every element of observations.
        bool Fit(std::vector<ObservationType> const& observations)
        {
            return Fit(observations.size(), observations.data());
        }

        // Fit the model to the contiguous range observations[imin..imax],
        // where both ends are inclusive. Returns false if imin > imax or if
        // imax is not a valid index into observations.
        bool Fit(std::vector<ObservationType> const& observations, size_t imin, size_t imax)
        {
            // Reject out-of-range requests here. Relying on the derived
            // FitIndexed is unsafe, because ValidIndices is compiled out unless
            // GTE_APPR_QUERY_VALIDATE_INDICES is defined.
            if (imin > imax || imax >= observations.size())
            {
                return false;
            }

            std::vector<int> indices(imax - imin + 1);
            std::iota(indices.begin(), indices.end(), static_cast<int>(imin));
            return FitIndexed(observations.size(), observations.data(), indices.size(), indices.data());
        }

        // Fit the model to the observations selected by indices.
        virtual bool Fit(std::vector<ObservationType> const& observations,
            std::vector<int> const& indices)
        {
            return FitIndexed(observations.size(), observations.data(), indices.size(), indices.data());
        }

        // Fit the model to the observations selected by the first
        // min(numIndices, indices.size()) entries of indices.
        bool Fit(std::vector<ObservationType> const& observations,
            std::vector<int> const& indices, size_t numIndices)
        {
            // FitIndexed takes a pointer plus a count, so the prefix can be
            // passed directly without copying it into a temporary vector.
            size_t const count = std::min(numIndices, indices.size());
            return FitIndexed(observations.size(), observations.data(), count, indices.data());
        }

        // Minimum number of observations needed to determine the model.
        // ValidIndices rejects fewer indices than this, and RANSAC uses it as
        // the size of each random sample.
        virtual size_t GetMinimumRequired() const = 0;

        // Fitting error of one observation against the current model.
        // RANSAC counts an observation as an inlier when this is
        // <= maxErrorForGoodFit.
        virtual Real Error(ObservationType const& observation) const = 0;

        // Copy the fitted model parameters from input into this object.
        // RANSAC uses this to snapshot the best candidate into bestModel.
        // Implementations presumably expect input to have the same concrete
        // type as *this (TODO confirm per derived class).
        virtual void CopyParameters(ApprQuery const* input) = 0;

        // RANdom SAmple Consensus fit.
        //
        // Each of numIterations iterations does the following:
        //   1. Fit candidateModel to a random sample of GetMinimumRequired()
        //      observations.
        //   2. Add every other observation whose Error is <= maxErrorForGoodFit
        //      to the consensus set.
        //   3. If the consensus has at least numRequiredForGoodFit members,
        //      refit candidateModel to it.
        // The largest successfully refit consensus is kept: its indices into
        // observations go to bestConsensus, and its parameters are copied into
        // bestModel.
        //
        // Returns true only if bestModel/bestConsensus were written by the
        // search. Special cases:
        //   - Fewer observations than GetMinimumRequired(): returns false.
        //   - Exactly that many: sampling is skipped, bestConsensus receives
        //     all indices, and the result of bestModel.Fit(observations) is
        //     returned.
        //
        // candidateModel is used as scratch space and is modified. The random
        // engine is default-constructed, so on a given standard library the
        // results are reproducible for identical inputs.
        static bool RANSAC(ApprQuery& candidateModel, std::vector<ObservationType> const& observations,
            size_t numRequiredForGoodFit, Real maxErrorForGoodFit, size_t numIterations,
            std::vector<int>& bestConsensus, ApprQuery& bestModel)
        {
            size_t const numObservations = observations.size();
            size_t const minRequired = candidateModel.GetMinimumRequired();
            if (numObservations < minRequired)
            {
                // Too few observations to fit the model at all.
                return false;
            }

            // candidates is permuted in place. Its first minRequired entries
            // are the random sample. Inliers found afterwards are swapped
            // forward, so candidates[0..numFitted) is always the current
            // consensus set.
            std::vector<int> candidates(numObservations);
            std::iota(candidates.begin(), candidates.end(), 0);

            if (numObservations == minRequired)
            {
                // No room for random sampling; fit to every observation.
                bestConsensus = candidates;
                return bestModel.Fit(observations);
            }

            // One engine is reused across iterations, so each shuffle draws a
            // new permutation. A default-constructed engine per shuffle would
            // apply the identical permutation on every iteration.
            std::default_random_engine engine;

            // Track explicitly whether a model was stored. This lets a
            // consensus of exactly minRequired observations be recorded, and
            // success is never reported without a stored model.
            bool foundGoodFit = false;
            size_t bestNumFittedObservations = 0;

            for (size_t i = 0; i < numIterations; ++i)
            {
                std::shuffle(candidates.begin(), candidates.end(), engine);

                if (!candidateModel.Fit(observations, candidates, minRequired))
                {
                    continue;
                }

                // Grow the consensus set with every remaining observation that
                // the sample model explains within tolerance.
                size_t numFittedObservations = minRequired;
                for (size_t j = minRequired; j < numObservations; ++j)
                {
                    if (candidateModel.Error(observations[candidates[j]]) <= maxErrorForGoodFit)
                    {
                        std::swap(candidates[j], candidates[numFittedObservations]);
                        ++numFittedObservations;
                    }
                }

                if (numFittedObservations < numRequiredForGoodFit)
                {
                    continue;
                }

                // Refit to the whole consensus set. If the refit fails, the
                // candidate's parameters cannot be trusted, so skip recording.
                if (!candidateModel.Fit(observations, candidates, numFittedObservations))
                {
                    continue;
                }

                if (!foundGoodFit || numFittedObservations > bestNumFittedObservations)
                {
                    bestModel.CopyParameters(&candidateModel);
                    bestConsensus.assign(candidates.begin(),
                        candidates.begin() + numFittedObservations);
                    bestNumFittedObservations = numFittedObservations;
                    foundGoodFit = true;
                }
            }

            return foundGoodFit;
        }
    };
}
|