NearestNeighborQuery.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. // David Eberly, Geometric Tools, Redmond WA 98052
  2. // Copyright (c) 1998-2020
  3. // Distributed under the Boost Software License, Version 1.0.
  4. // https://www.boost.org/LICENSE_1_0.txt
  5. // https://www.geometrictools.com/License/Boost/LICENSE_1_0.txt
  6. // Version: 4.0.2019.08.13
  7. #pragma once
  8. #include <Mathematics/Logger.h>
  9. #include <Mathematics/Vector.h>
  10. #include <vector>
// TODO: This is not a KD-tree nearest neighbor query. Instead, it is an
// algorithm to get "approximate" nearest neighbors. Replace this by the
// actual KD-tree query.
//
// Use a kd-tree for sorting used in a query for finding nearest neighbors of
// a point in a space of the specified dimension N. The split axis cycles
// through 0,1,2,...,N-1 as the tree level increases. The number of sites at
// a leaf node is controlled by 'maxLeafSize' and the maximum level of the
// tree is controlled by 'maxLevel'. The points are of type Vector<N,Real>.
// The 'Site' is a structure of information that minimally implements the
// function 'Vector<N,Real> GetPosition() const'. The Site template
// parameter allows the query to be applied even when it has more local
// information than just point location.
  23. namespace WwiseGTE
  24. {
  25. // Predefined site structs for convenience.
  26. template <int N, typename T>
  27. struct PositionSite
  28. {
  29. Vector<N, T> position;
  30. PositionSite(Vector<N, T> const& p)
  31. :
  32. position(p)
  33. {
  34. }
  35. Vector<N, T> GetPosition() const
  36. {
  37. return position;
  38. }
  39. };
  40. // Predefined site structs for convenience.
  41. template <int N, typename T>
  42. struct PositionDirectionSite
  43. {
  44. Vector<N, T> position;
  45. Vector<N, T> direction;
  46. PositionDirectionSite(Vector<N, T> const& p, Vector<N, T> const& d)
  47. :
  48. position(p),
  49. direction(d)
  50. {
  51. }
  52. Vector<N, T> GetPosition() const
  53. {
  54. return position;
  55. }
  56. };
  57. template <int N, typename Real, typename Site>
  58. class NearestNeighborQuery
  59. {
  60. public:
  61. // Supporting data structures.
  62. typedef std::pair<Vector<N, Real>, int> SortedPoint;
  63. struct Node
  64. {
  65. Real split;
  66. int axis;
  67. int numSites;
  68. int siteOffset;
  69. int left;
  70. int right;
  71. };
  72. // Construction.
  73. NearestNeighborQuery(std::vector<Site> const& sites, int maxLeafSize, int maxLevel)
  74. :
  75. mMaxLeafSize(maxLeafSize),
  76. mMaxLevel(maxLevel),
  77. mSortedPoints(sites.size()),
  78. mDepth(0),
  79. mLargestNodeSize(0)
  80. {
  81. LogAssert(mMaxLevel > 0 && mMaxLevel <= 32, "Invalid max level.");
  82. int const numSites = static_cast<int>(sites.size());
  83. for (int i = 0; i < numSites; ++i)
  84. {
  85. mSortedPoints[i] = std::make_pair(sites[i].GetPosition(), i);
  86. }
  87. mNodes.push_back(Node());
  88. Build(numSites, 0, 0, 0);
  89. }
  90. // Member access.
  91. inline int GetMaxLeafSize() const
  92. {
  93. return mMaxLeafSize;
  94. }
  95. inline int GetMaxLevel() const
  96. {
  97. return mMaxLevel;
  98. }
  99. inline int GetDepth() const
  100. {
  101. return mDepth;
  102. }
  103. inline int GetLargestNodeSize() const
  104. {
  105. return mLargestNodeSize;
  106. }
  107. int GetNumNodes() const
  108. {
  109. return static_cast<int>(mNodes.size());
  110. }
  111. inline std::vector<Node> const& GetNodes() const
  112. {
  113. return mNodes;
  114. }
  115. // Compute up to MaxNeighbors nearest neighbors within the specified
  116. // radius of the point. The returned integer is the number of
  117. // neighbors found, possibly zero. The neighbors array stores indices
  118. // into the array passed to the constructor.
  119. template <int MaxNeighbors>
  120. int FindNeighbors(Vector<N, Real> const& point, Real radius, std::array<int, MaxNeighbors>& neighbors) const
  121. {
  122. Real sqrRadius = radius * radius;
  123. int numNeighbors = 0;
  124. std::array<int, MaxNeighbors + 1> localNeighbors;
  125. std::array<Real, MaxNeighbors + 1> neighborSqrLength;
  126. for (int i = 0; i <= MaxNeighbors; ++i)
  127. {
  128. localNeighbors[i] = -1;
  129. neighborSqrLength[i] = std::numeric_limits<Real>::max();
  130. }
  131. // The kd-tree construction is recursive, simulated here by using
  132. // a stack. The maximum depth is limited to 32, because the number
  133. // of sites is limited to 2^{32} (the number of 32-bit integer
  134. // indices).
  135. std::array<int, 32> stack;
  136. int top = 0;
  137. stack[0] = 0;
  138. int maxNeighbors = MaxNeighbors;
  139. if (maxNeighbors == 1)
  140. {
  141. while (top >= 0)
  142. {
  143. Node node = mNodes[stack[top--]];
  144. if (node.siteOffset != -1)
  145. {
  146. for (int i = 0, j = node.siteOffset; i < node.numSites; ++i, ++j)
  147. {
  148. auto diff = mSortedPoints[j].first - point;
  149. auto sqrLength = Dot(diff, diff);
  150. if (sqrLength <= sqrRadius)
  151. {
  152. // Maintain the nearest neighbors.
  153. if (sqrLength <= neighborSqrLength[0])
  154. {
  155. localNeighbors[0] = mSortedPoints[j].second;
  156. neighborSqrLength[0] = sqrLength;
  157. numNeighbors = 1;
  158. }
  159. }
  160. }
  161. }
  162. if (node.left != -1 && point[node.axis] - radius <= node.split)
  163. {
  164. stack[++top] = node.left;
  165. }
  166. if (node.right != -1 && point[node.axis] + radius >= node.split)
  167. {
  168. stack[++top] = node.right;
  169. }
  170. }
  171. }
  172. else
  173. {
  174. while (top >= 0)
  175. {
  176. Node node = mNodes[stack[top--]];
  177. if (node.siteOffset != -1)
  178. {
  179. for (int i = 0, j = node.siteOffset; i < node.numSites; ++i, ++j)
  180. {
  181. Vector<N, Real> diff = mSortedPoints[j].first - point;
  182. Real sqrLength = Dot(diff, diff);
  183. if (sqrLength <= sqrRadius)
  184. {
  185. // Maintain the nearest neighbors.
  186. int k;
  187. for (k = 0; k < numNeighbors; ++k)
  188. {
  189. if (sqrLength <= neighborSqrLength[k])
  190. {
  191. for (int n = numNeighbors; n > k; --n)
  192. {
  193. localNeighbors[n] = localNeighbors[n - 1];
  194. neighborSqrLength[n] = neighborSqrLength[n - 1];
  195. }
  196. break;
  197. }
  198. }
  199. if (k < MaxNeighbors)
  200. {
  201. localNeighbors[k] = mSortedPoints[j].second;
  202. neighborSqrLength[k] = sqrLength;
  203. }
  204. if (numNeighbors < MaxNeighbors)
  205. {
  206. ++numNeighbors;
  207. }
  208. }
  209. }
  210. }
  211. if (node.left != -1 && point[node.axis] - radius <= node.split)
  212. {
  213. stack[++top] = node.left;
  214. }
  215. if (node.right != -1 && point[node.axis] + radius >= node.split)
  216. {
  217. stack[++top] = node.right;
  218. }
  219. }
  220. }
  221. for (int i = 0; i < numNeighbors; ++i)
  222. {
  223. neighbors[i] = localNeighbors[i];
  224. }
  225. return numNeighbors;
  226. }
  227. inline std::vector<SortedPoint> const& GetSortedPoints() const
  228. {
  229. return mSortedPoints;
  230. }
  231. private:
  232. // Populate the node so that it contains the points split along the
  233. // coordinate axes.
  234. void Build(int numSites, int siteOffset, int nodeIndex, int level)
  235. {
  236. LogAssert(siteOffset != -1, "Invalid site offset.");
  237. LogAssert(nodeIndex != -1, "Invalid node index.");
  238. LogAssert(numSites > 0, "Empty point list.");
  239. mDepth = std::max(mDepth, level);
  240. Node& node = mNodes[nodeIndex];
  241. node.numSites = numSites;
  242. if (numSites > mMaxLeafSize && level <= mMaxLevel)
  243. {
  244. int halfNumSites = numSites / 2;
  245. // The point set is too large for a leaf node, so split it at
  246. // the median. The O(m log m) sort is not needed; rather, we
  247. // locate the median using an order statistic construction
  248. // that is expected time O(m).
  249. int const axis = level % N;
  250. auto sorter = [axis](SortedPoint const& p0, SortedPoint const& p1)
  251. {
  252. return p0.first[axis] < p1.first[axis];
  253. };
  254. auto begin = mSortedPoints.begin() + siteOffset;
  255. auto mid = mSortedPoints.begin() + siteOffset + halfNumSites;
  256. auto end = mSortedPoints.begin() + siteOffset + numSites;
  257. std::nth_element(begin, mid, end, sorter);
  258. // Get the median position.
  259. node.split = mSortedPoints[siteOffset + halfNumSites].first[axis];
  260. node.axis = axis;
  261. node.siteOffset = -1;
  262. // Apply a divide-and-conquer step.
  263. int left = (int)mNodes.size(), right = left + 1;
  264. node.left = left;
  265. node.right = right;
  266. mNodes.push_back(Node());
  267. mNodes.push_back(Node());
  268. int nextLevel = level + 1;
  269. Build(halfNumSites, siteOffset, left, nextLevel);
  270. Build(numSites - halfNumSites, siteOffset + halfNumSites, right, nextLevel);
  271. }
  272. else
  273. {
  274. // The number of points is small enough or we have run out of
  275. // depth, so make this node a leaf.
  276. node.split = std::numeric_limits<Real>::max();
  277. node.axis = -1;
  278. node.siteOffset = siteOffset;
  279. node.left = -1;
  280. node.right = -1;
  281. mLargestNodeSize = std::max(mLargestNodeSize, node.numSites);
  282. }
  283. }
  284. int mMaxLeafSize;
  285. int mMaxLevel;
  286. std::vector<SortedPoint> mSortedPoints;
  287. std::vector<Node> mNodes;
  288. int mDepth;
  289. int mLargestNodeSize;
  290. };
  291. }