MUDA
Loading...
Searching...
No Matches
bvh.h
1#pragma once
2#include <muda/ext/geo/lbvh/bvh_viewer.h>
3#include <muda/ext/geo/lbvh/aabb.h>
4#include <thrust/swap.h>
5#include <thrust/pair.h>
6#include <thrust/tuple.h>
7#include <thrust/host_vector.h>
8#include <thrust/device_vector.h>
9#include <thrust/functional.h>
10#include <thrust/scan.h>
11#include <thrust/sort.h>
12#include <thrust/fill.h>
13#include <thrust/for_each.h>
14#include <thrust/transform.h>
15#include <thrust/reduce.h>
16#include <thrust/iterator/constant_iterator.h>
17#include <thrust/iterator/counting_iterator.h>
18#include <thrust/unique.h>
19#include <thrust/execution_policy.h>
20#include <muda/cub/device/device_reduce.h>
21
22namespace muda::lbvh
23{
namespace details
{
    /// Build the topology of the internal nodes of a linear BVH.
    ///
    /// Node layout (shared with BVH::build): internal nodes occupy indices
    /// [0, num_objects - 1); leaf nodes occupy [num_objects - 1, 2 * num_objects - 1).
    /// For every internal node this sets `object_idx` to 0xFFFFFFFF and sets
    /// `left_idx` / `right_idx`. It also writes `parent_idx` of both children.
    ///
    /// @param policy      thrust execution policy the work is launched with.
    /// @param nodes       device array of 2 * num_objects - 1 nodes.
    /// @param node_code   device array of num_objects codes. Assumed sorted
    ///                    ascending and unique, as prepared by BVH::build.
    /// @param num_objects number of leaves.
    template <typename DerivedPolicy, typename UInt>
    void construct_internal_nodes(const thrust::detail::execution_policy_base<DerivedPolicy>& policy,
                                  Node*          nodes,
                                  UInt const*    node_code,
                                  const uint32_t num_objects)
    {
        // Fewer than two objects means there are no internal nodes.
        // This also prevents `num_objects - 1` below from wrapping around to
        // 0xFFFFFFFF when num_objects == 0.
        if(num_objects < 2)
        {
            return;
        }

        thrust::for_each(
            policy,
            thrust::make_counting_iterator<uint32_t>(0),
            thrust::make_counting_iterator<uint32_t>(num_objects - 1),
            [nodes, node_code, num_objects] __device__(const uint32_t idx)
            {
                nodes[idx].object_idx = 0xFFFFFFFF;  // internal nodes

                const uint2 ij    = determine_range(node_code, num_objects, idx);
                const int   gamma = find_split(node_code, num_objects, ij.x, ij.y);

                // Children are internal nodes gamma / gamma + 1, unless the
                // split touches an end of the range. In that case the child
                // is a leaf, whose index is offset by the internal-node count.
                uint32_t left_idx  = gamma;
                uint32_t right_idx = gamma + 1;
                if(thrust::min(ij.x, ij.y) == gamma)
                {
                    left_idx += num_objects - 1;
                }
                if(thrust::max(ij.x, ij.y) == gamma + 1)
                {
                    right_idx += num_objects - 1;
                }

                nodes[idx].left_idx         = left_idx;
                nodes[idx].right_idx        = right_idx;
                nodes[left_idx].parent_idx  = idx;
                nodes[right_idx].parent_idx = idx;
            });
    }
}  // namespace details
59
60template <typename Real, typename Object>
62{
64 : whole(w)
65 {
66 }
73
__device__ __host__ inline uint32_t operator()(const Object&, const AABB<Real>& box) noexcept
{
    // Rescale one coordinate relative to the scene bounds `whole`:
    // shift by the lower corner, then divide by the extent on that axis.
    // The result lies in [0, 1] when the centroid lies inside `whole`.
    // NOTE(review): a zero extent on an axis (e.g. all centroids coplanar)
    // yields a non-finite value here. How morton_code treats that is not
    // visible in this file.
    const auto to_unit = [](auto& coord, const auto lo, const auto hi)
    {
        coord -= lo;
        coord /= (hi - lo);
    };

    auto c = centroid(box);
    to_unit(c.x, whole.lower.x, whole.upper.x);
    to_unit(c.y, whole.lower.y, whole.upper.y);
    to_unit(c.z, whole.lower.z, whole.upper.z);
    return morton_code(c);
}
85 AABB<Real> whole;
86};
87
88template <typename Real, typename Object, typename AABBGetter, typename MortonCodeCalculator = DefaultMortonCodeCalculator<Real, Object>>
89class BVH
90{
91 public:
92 using real_type = Real;
93 using index_type = std::uint32_t;
94 using object_type = Object;
97 using aabb_getter_type = AABBGetter;
98 using morton_code_calculator_type = MortonCodeCalculator;
99
public:
    // The BVH has value semantics: construction, copy, move and destruction
    // are all compiler-generated from the members.
    BVH()  = default;
    BVH(const BVH&) = default;
    BVH(BVH&&)      = default;
    BVH& operator=(const BVH&) = default;
    BVH& operator=(BVH&&)      = default;
    ~BVH() = default;
107
void clear()
{
    // Drop the input objects together with the built hierarchy.
    this->m_objects.clear();
    this->m_aabbs.clear();
    this->m_nodes.clear();

    // The host mirrors returned by host_objects(), host_aabbs() and
    // host_nodes() were copied from the vectors just emptied. Invalidate
    // them so the next host_* call does not return stale data.
    m_host_dirty = true;
}
115
116 BVHViewer<real_type, object_type> viewer() noexcept
117 {
119 static_cast<uint32_t>(m_nodes.size()),
120 static_cast<uint32_t>(m_objects.size()),
121 thrust::raw_pointer_cast(m_nodes.data()),
122 thrust::raw_pointer_cast(m_aabbs.data()),
123 thrust::raw_pointer_cast(m_objects.data())};
124 }
125
126 CBVHViewer<real_type, object_type> cviewer() const noexcept
127 {
129 static_cast<uint32_t>(m_nodes.size()),
130 static_cast<uint32_t>(m_objects.size()),
131 thrust::raw_pointer_cast(m_nodes.data()),
132 thrust::raw_pointer_cast(m_aabbs.data()),
133 thrust::raw_pointer_cast(m_objects.data())};
134 }
135
136
/// Build (or rebuild) the hierarchy from the current objects().
///
/// Layout: with N = objects().size(), internal nodes occupy
/// nodes()/aabbs() indices [0, N-1) and leaves occupy [N-1, 2N-1).
/// Node 0 is the root when N > 1.
///
/// @param stream CUDA stream all thrust calls are issued on (par_nosync).
///        Calls whose result is needed on the host (reduce, unique_copy)
///        still wait for that result. The final bottom-up pass is left
///        unsynchronized.
void build(cudaStream_t stream = nullptr)
{
    auto policy = thrust::system::cuda::par_nosync.on(stream);

    // Whatever happens below, the host mirrors no longer match the device.
    m_host_dirty = true;

    if(m_objects.size() == 0u)
    {
        // No objects: also drop any hierarchy left over from a previous
        // build, so nodes()/aabbs()/viewer() do not describe stale objects.
        this->m_aabbs.clear();
        this->m_nodes.clear();
        return;
    }

    const uint32_t num_objects        = static_cast<uint32_t>(m_objects.size());
    const uint32_t num_internal_nodes = num_objects - 1;
    const uint32_t num_nodes          = num_objects * 2 - 1;

    // --------------------------------------------------------------------
    // calculate morton code of each object

    // Inverted (empty) box: used as the initial value of the merge
    // reduction and as the fill value for the AABB array.
    const auto inf = std::numeric_limits<real_type>::infinity();
    aabb_type  default_aabb;
    default_aabb.upper.x = -inf;
    default_aabb.lower.x = inf;
    default_aabb.upper.y = -inf;
    default_aabb.lower.y = inf;
    default_aabb.upper.z = -inf;
    default_aabb.lower.z = inf;

    this->m_aabbs.resize(num_nodes, default_aabb);
    m_morton.resize(num_objects);
    m_indices.resize(num_objects);
    m_morton64.resize(num_objects);
    node_type default_node;
    default_node.parent_idx = 0xFFFFFFFF;
    default_node.left_idx   = 0xFFFFFFFF;
    default_node.right_idx  = 0xFFFFFFFF;
    default_node.object_idx = 0xFFFFFFFF;
    m_nodes.resize(num_nodes, default_node);
    // Arrival flags for the bottom-up pass: clear() + resize() zeroes every entry.
    m_flag_container.clear();
    m_flag_container.resize(num_internal_nodes, 0);

    // leaf AABBs
    thrust::transform(policy,
                      this->m_objects.begin(),
                      this->m_objects.end(),
                      m_aabbs.begin() + num_internal_nodes,
                      aabb_getter_type());

    // bounding box of the whole scene (returned to the host)
    const auto aabb_whole = thrust::reduce(
        policy,
        m_aabbs.data() + num_internal_nodes,
        m_aabbs.data() + m_aabbs.size(),
        default_aabb,
        [] __device__ __host__(const aabb_type& lhs, const aabb_type& rhs) -> aabb_type
        { return merge(lhs, rhs); });

    thrust::transform(policy,
                      this->m_objects.begin(),
                      this->m_objects.end(),
                      m_aabbs.begin() + num_internal_nodes,
                      m_morton.begin(),
                      morton_code_calculator_type(aabb_whole));

    // --------------------------------------------------------------------
    // sort object-indices by morton code

    // iota the indices
    thrust::copy(policy,
                 thrust::make_counting_iterator<index_type>(0),
                 thrust::make_counting_iterator<index_type>(num_objects),
                 m_indices.begin());

    // stable sort keeps equal codes in ascending index order
    thrust::stable_sort_by_key(
        policy,
        m_morton.begin(),
        m_morton.end(),
        thrust::make_zip_iterator(thrust::make_tuple(m_aabbs.begin() + num_internal_nodes,
                                                     m_indices.begin())));

    // --------------------------------------------------------------------
    // check morton codes are unique; if not, make them unique by appending
    // the object index in the low 32 bits of a 64-bit key

    const auto uniqued = thrust::unique_copy(
        policy, m_morton.begin(), m_morton.end(), m_morton64.begin());

    const bool morton_code_is_unique = (m_morton64.end() == uniqued);
    if(!morton_code_is_unique)
    {
        thrust::transform(policy,
                          m_morton.begin(),
                          m_morton.end(),
                          m_indices.begin(),
                          m_morton64.begin(),
                          [] __device__ __host__(const uint32_t m, const uint32_t idx)
                          {
                              unsigned long long int m64 = m;
                              m64 <<= 32;
                              m64 |= idx;
                              return m64;
                          });
    }

    // --------------------------------------------------------------------
    // construct leaf nodes (one per sorted object)

    thrust::transform(policy,
                      m_indices.begin(),
                      m_indices.end(),
                      this->m_nodes.begin() + num_internal_nodes,
                      [] __device__ __host__(const index_type idx)
                      {
                          node_type n;
                          n.parent_idx = 0xFFFFFFFF;
                          n.left_idx   = 0xFFFFFFFF;
                          n.right_idx  = 0xFFFFFFFF;
                          n.object_idx = idx;
                          return n;
                      });

    // --------------------------------------------------------------------
    // construct internal nodes

    if(morton_code_is_unique)
    {
        const uint32_t* node_code = thrust::raw_pointer_cast(m_morton.data());
        details::construct_internal_nodes(
            policy, thrust::raw_pointer_cast(m_nodes.data()), node_code, num_objects);
    }
    else  // 64bit version
    {
        const unsigned long long int* node_code =
            thrust::raw_pointer_cast(m_morton64.data());
        details::construct_internal_nodes(
            policy, thrust::raw_pointer_cast(m_nodes.data()), node_code, num_objects);
    }

    // --------------------------------------------------------------------
    // create AABB for each internal node bottom-up: one thread per leaf
    // climbs towards the root; at each parent the first child to arrive
    // stops, and the second merges both child boxes and keeps climbing.

    const auto flags = thrust::raw_pointer_cast(m_flag_container.data());

    thrust::for_each(policy,
                     thrust::make_counting_iterator<index_type>(num_internal_nodes),
                     thrust::make_counting_iterator<index_type>(num_nodes),
                     [nodes = thrust::raw_pointer_cast(m_nodes.data()),
                      aabbs = thrust::raw_pointer_cast(m_aabbs.data()),
                      flags] __device__(index_type idx)
                     {
                         uint32_t parent = nodes[idx].parent_idx;
                         while(parent != 0xFFFFFFFF)  // 0xFFFFFFFF: climbed past the root
                         {
                             // Release: make the AABB this thread wrote in the
                             // previous iteration visible device-wide before
                             // announcing arrival at `parent` through the flag.
                             __threadfence();
                             const int old = atomicCAS(flags + parent, 0, 1);
                             if(old == 0)
                             {
                                 // First child to arrive: the sibling's thread
                                 // will merge `parent` once its subtree is done.
                                 return;
                             }
                             MUDA_KERNEL_ASSERT(old == 1, "old=%d", old);
                             // Acquire: pairs with the sibling thread's release
                             // fence so its published AABB is observed below,
                             // not a stale value.
                             __threadfence();

                             // Second thread: merge the AABBs of both children.
                             const auto lidx = nodes[parent].left_idx;
                             const auto ridx = nodes[parent].right_idx;
                             const auto lbox = aabbs[lidx];
                             const auto rbox = aabbs[ridx];
                             aabbs[parent]   = merge(lbox, rbox);

                             // look at the next parent...
                             parent = nodes[parent].parent_idx;
                         }
                     });
}
321
// Device-side storage. aabbs() and nodes() reflect the most recent build().
const auto& objects() const noexcept { return m_objects; }
// Mutable access to the objects. The caller may modify them through this
// reference, so the cached host mirrors are invalidated; the next host_*()
// call re-downloads. Call build() to bring aabbs()/nodes() up to date.
auto& objects() noexcept
{
    m_host_dirty = true;
    return m_objects;
}
const auto& aabbs() const noexcept { return m_aabbs; }
const auto& nodes() const noexcept { return m_nodes; }
326
// Host mirrors of the device buffers. Each accessor first refreshes all
// three mirrors if they are marked dirty, then returns the cached copy.
// NOTE(review): these are noexcept but may perform device->host copies;
// a failing copy would therefore terminate the program.
const thrust::host_vector<object_type>& host_objects() const noexcept
{
    download_if_dirty();
    return m_h_objects;
}

const thrust::host_vector<aabb_type>& host_aabbs() const noexcept
{
    download_if_dirty();
    return m_h_aabbs;
}

const thrust::host_vector<node_type>& host_nodes() const noexcept
{
    download_if_dirty();
    return m_h_nodes;
}
342
343 private:
347 muda::DeviceVector<int> m_flag_container;
348
352
353 mutable bool m_host_dirty = true;
354 mutable thrust::host_vector<object_type> m_h_objects;
355 mutable thrust::host_vector<aabb_type> m_h_aabbs;
356 mutable thrust::host_vector<node_type> m_h_nodes;
357
// Refresh the host mirrors from the device vectors when they are stale.
void download_if_dirty() const
{
    if(!m_host_dirty)
    {
        return;  // mirrors are already current
    }

    // NOTE(review): these copies do not name the stream build() ran on.
    // If build() used a non-blocking stream, confirm the caller has
    // synchronized it before requesting host data.
    m_h_objects  = m_objects;
    m_h_aabbs    = m_aabbs;
    m_h_nodes    = m_nodes;
    m_host_dirty = false;
}
368};
369} // namespace muda::lbvh
Definition vector.h:23
Definition bvh.h:90
Definition bvh_viewer.h:120
Definition aabb.h:11
Definition bvh_viewer.h:16