diff --git a/src/Raycore.jl b/src/Raycore.jl index dde188f..2de323c 100644 --- a/src/Raycore.jl +++ b/src/Raycore.jl @@ -58,7 +58,7 @@ export @_inbounds export Ray, RayDifferentials, Triangle, TriangleMesh, AccelPrimitive, BVH, Bounds3, Normal3f # Instanced BVH types -export BLAS, TLAS, InstanceDescriptor, BVHNode2, build_blas, build_tlas, INVALID_NODE +export BLAS, BLASDescriptor, TLAS, InstanceDescriptor, BVHNode2, build_blas, build_tlas, INVALID_NODE export Instance, n_instances, n_geometries, build_triangle, is_degenerate_face # TLASBuilder (new MultiTypeSet-style API) diff --git a/src/instanced-bvh.jl b/src/instanced-bvh.jl index e0411e9..930b369 100644 --- a/src/instanced-bvh.jl +++ b/src/instanced-bvh.jl @@ -101,27 +101,53 @@ struct BLAS{ root_aabb::Bounds3 end +""" + BLASDescriptor + +Lightweight descriptor for a BLAS in flat-array layout. +Instead of storing device pointers to per-BLAS arrays (which fail on Metal when +stored in GPU buffers), this stores offsets into concatenated flat arrays. + +Fields: +- `nodes_offset`: 0-based offset into the flat all_blas_nodes array +- `primitives_offset`: 0-based offset into the flat all_blas_prims array +- `root_aabb`: Bounding box of the BLAS in local space +""" +struct BLASDescriptor + nodes_offset::UInt32 + primitives_offset::UInt32 + root_aabb::Bounds3 +end + # ============================================================================== # StaticTLAS - Immutable structure for kernel traversal # ============================================================================== """ - StaticTLAS{NodeArray, InstArray, BLASArray} + StaticTLAS{NodeArray, InstArray, BLASNodeArray, BLASPrimArray, DescArray} Immutable Top-Level Acceleration Structure for GPU kernel traversal. This is what `Adapt.adapt_structure` returns from a TLAS. +Uses flat arrays with offset-based indexing instead of per-BLAS pointer arrays. +This avoids the Metal issue where device pointers stored in GPU buffers cannot +be reliably dereferenced by kernels. + The struct is immutable and contains only the arrays needed for ray traversal. No management state (dictionaries, free lists, etc.) - those stay on CPU in TLAS. """ struct StaticTLAS{ NodeArray <: AbstractVector{BVHNode2}, InstArray <: AbstractVector{InstanceDescriptor}, - BLASArray <: AbstractVector{<:BLAS} + BLASNodeArray <: AbstractVector{BVHNode2}, + BLASPrimArray <: AbstractVector{<:Triangle}, + DescArray <: AbstractVector{BLASDescriptor} } nodes::NodeArray instances::InstArray - blas_array::BLASArray + all_blas_nodes::BLASNodeArray + all_blas_prims::BLASPrimArray + blas_descriptors::DescArray root_aabb::Bounds3 end @@ -195,6 +221,11 @@ mutable struct TLAS{Backend} # Backend arrays kept alive for GC (isbits pointers in blas_array reference these) gpu_blas_arrays::Vector{Any} + # Flat BLAS arrays for StaticTLAS traversal (built during adapt, kept alive for isbits pointers) + _flat_blas_nodes::Any # concatenated BVH nodes from all BLASes + _flat_blas_prims::Any # concatenated triangles from all BLASes + _flat_blas_descs::Any # BLASDescriptor array on backend + # Whether BVH topology needs rebuild dirty::Bool @@ -281,6 +312,9 @@ function TLAS(backend) Dict{TLASHandle, UnitRange{Int}}(), # handle_to_range Set{TLASHandle}(), # deleted_handles Any[], # gpu_blas_arrays (GC roots) + nothing, # _flat_blas_nodes + nothing, # _flat_blas_prims + nothing, # _flat_blas_descs true, # dirty UInt32(1), # next_handle_id UInt32(1) # next_instance_id @@ -309,6 +343,12 @@ function free!(tlas::TLAS) finalize(arr) end empty!(tlas.gpu_blas_arrays) + tlas._flat_blas_nodes !== nothing && finalize(tlas._flat_blas_nodes) + tlas._flat_blas_prims !== nothing && finalize(tlas._flat_blas_prims) + tlas._flat_blas_descs !== nothing && finalize(tlas._flat_blas_descs) + tlas._flat_blas_nodes = nothing + tlas._flat_blas_prims = nothing + tlas._flat_blas_descs = nothing return nothing end @@ -332,7 +372,12 @@ end """ Convert a BLAS with backend arrays to isbits BLAS with device pointers. -Stores the backend arrays in `keep_alive` vector to prevent GC. +Stores the backend arrays in `keep_alive` vector to prevent GC +(entries are stored in groups of 2: nodes, primitives). + +Note: The isbits BLAS is only used by management kernels that read root_aabb +(inline data). For traversal, StaticTLAS uses flat arrays with offset-based +indexing instead (see BLASDescriptor). """ function _to_isbits_blas(backend, blas::BLAS, keep_alive::Vector{Any}) # Store the backend arrays to keep them alive @@ -365,6 +410,72 @@ function _append_blas!(backend, blas_array, isbits_blas) end end +""" + _build_flat_blas_arrays!(tlas::TLAS) + +Build concatenated flat arrays from individual BLAS GPU arrays and store them +in `tlas._flat_blas_nodes`, `tlas._flat_blas_prims`, `tlas._flat_blas_descs`. + +This avoids storing device pointers in GPU buffers (which fails on Metal). +Instead, traversal kernels use BLASDescriptor offsets to index into the flat arrays. + +The flat arrays are MtlVector/CuVector etc., kept alive by the TLAS. +During adapt, they are converted to isbits device pointers for kernels. +""" +function _build_flat_blas_arrays!(tlas::TLAS) + n_blas = length(tlas.gpu_blas_arrays) ÷ 2 + backend = tlas.backend + + if n_blas == 0 + tlas._flat_blas_nodes = nothing + tlas._flat_blas_prims = nothing + tlas._flat_blas_descs = nothing + return + end + + # Read root_aabb from blas_array (inline data, always correct even on Metal) + cpu_blas = Array(tlas.blas_array) + + # Compute total sizes and build descriptors + descriptors = Vector{BLASDescriptor}(undef, n_blas) + total_nodes = 0 + total_prims = 0 + for i in 1:n_blas + nodes_arr = tlas.gpu_blas_arrays[2(i-1) + 1] + prims_arr = tlas.gpu_blas_arrays[2(i-1) + 2] + descriptors[i] = BLASDescriptor(UInt32(total_nodes), UInt32(total_prims), cpu_blas[i].root_aabb) + total_nodes += length(nodes_arr) + total_prims += length(prims_arr) + end + + # Allocate flat arrays on backend + first_nodes = tlas.gpu_blas_arrays[1] + first_prims = tlas.gpu_blas_arrays[2] + all_nodes = similar(first_nodes, total_nodes) + all_prims = similar(first_prims, total_prims) + + # Copy BLAS data into flat arrays + nodes_pos = 1 + prims_pos = 1 + for i in 1:n_blas + nodes_arr = tlas.gpu_blas_arrays[2(i-1) + 1] + prims_arr = tlas.gpu_blas_arrays[2(i-1) + 2] + + nn = length(nodes_arr) + copyto!(all_nodes, nodes_pos, nodes_arr, 1, nn) + nodes_pos += nn + + np = length(prims_arr) + copyto!(all_prims, prims_pos, prims_arr, 1, np) + prims_pos += np + end + + # Store on TLAS to keep alive (prevents GC of backing GPU buffers) + tlas._flat_blas_nodes = all_nodes + tlas._flat_blas_prims = all_prims + tlas._flat_blas_descs = Adapt.adapt(backend, descriptors) +end + """ is_valid(tlas::TLAS, handle::TLASHandle) -> Bool @@ -385,13 +496,6 @@ function n_instances(tlas::TLAS, handle::TLASHandle)::Int return length(tlas.handle_to_range[handle]) end -""" - n_geometries(tlas::TLAS) -> Int - -Get the total number of distinct geometries (BLASes) in the TLAS. -""" -n_geometries(tlas::TLAS) = tlas.blas_array === nothing ? 0 : length(tlas.blas_array) - """ n_total_instances(tlas::TLAS) -> Int @@ -790,13 +894,12 @@ function _rebuild_bvh!(tlas::TLAS) return end - # Build TLAS BVH structure from existing GPU arrays - # instances are already on backend, blas_array has isbits BLASes - built_tlas = build_tlas(tlas.blas_array, tlas.instances) + # Build TLAS BVH topology from existing GPU arrays + # blas_array is only used for root_aabb (inline data, safe on Metal) + nodes, root_aabb = _build_tlas_topology(tlas.blas_array, tlas.instances, tlas.backend) - # Update nodes with rebuilt structure - tlas.nodes = built_tlas.nodes - tlas.root_aabb = built_tlas.root_aabb + tlas.nodes = nodes + tlas.root_aabb = root_aabb tlas.dirty = false return @@ -901,14 +1004,32 @@ The TLAS must stay alive while the StaticTLAS is in use. """ function Adapt.adapt_structure(to, tlas::TLAS) sync!(tlas) - blas = tlas.blas_array - if blas === nothing - blas = BLAS{Vector{BVHNode2}, Vector{Triangle{UInt32}}}[] + + # Build flat BLAS arrays for traversal (avoids pointer-in-buffer on Metal) + _build_flat_blas_arrays!(tlas) + + if tlas._flat_blas_nodes === nothing + # Empty scene — need correct types for StaticTLAS type parameters + prim_type = length(tlas.gpu_blas_arrays) >= 2 ? eltype(tlas.gpu_blas_arrays[2]) : Triangle{UInt32} + empty_nodes = KA.allocate(tlas.backend, BVHNode2, 0) + empty_prims = KA.allocate(tlas.backend, prim_type, 0) + empty_descs = Adapt.adapt(tlas.backend, BLASDescriptor[]) + return StaticTLAS( + adapt(to, tlas.nodes), + adapt(to, tlas.instances), + adapt(to, empty_nodes), + adapt(to, empty_prims), + adapt(to, empty_descs), + tlas.root_aabb + ) end + return StaticTLAS( adapt(to, tlas.nodes), adapt(to, tlas.instances), - adapt(to, blas), + adapt(to, tlas._flat_blas_nodes), + adapt(to, tlas._flat_blas_prims), + adapt(to, tlas._flat_blas_descs), tlas.root_aabb ) end @@ -924,15 +1045,14 @@ already have isbits device pointers. Use TLAS(items) to create a mutable TLAS that properly manages GPU array lifetimes. """ function Adapt.adapt_structure(to, tlas::StaticTLAS) - # If already isbits, return as-is (already kernel-ready) isbitstype(typeof(tlas)) && return tlas - # BLASes should already have isbits pointers from mutable TLAS path - # Just adapt the outer arrays (CLArray → CLDeviceVector) return StaticTLAS( Adapt.adapt(to, tlas.nodes), Adapt.adapt(to, tlas.instances), - Adapt.adapt(to, tlas.blas_array), + Adapt.adapt(to, tlas.all_blas_nodes), + Adapt.adapt(to, tlas.all_blas_prims), + Adapt.adapt(to, tlas.blas_descriptors), tlas.root_aabb ) end @@ -1285,35 +1405,24 @@ end end """ - build_tlas(blas_array::AbstractVector{BLAS}, instances::AbstractVector{InstanceDescriptor}) -> StaticTLAS + _build_tlas_topology(blas_array, instances, backend) -> (nodes, root_aabb) -Build a Top-Level Acceleration Structure over instances. -Uses LBVH over transformed instance AABBs. +Internal: Build TLAS BVH topology (Morton codes, sorting, tree construction, refit). +Returns (nodes, root_aabb). Only accesses blas_array for root_aabb (inline data). -Returns a StaticTLAS suitable for ray traversal. -Uses KernelAbstractions for automatic CPU/GPU execution based on input array type. +`instances` must already be on the backend. """ -function build_tlas( - blas_array::AbstractVector{B}, - instances::AbstractVector{InstanceDescriptor} -) where {B <: BLAS} +function _build_tlas_topology(blas_array, instances, backend) n = length(instances) - n == 0 && return StaticTLAS(BVHNode2[], instances, blas_array, Bounds3()) - - # Infer backend from blas_array type (instances may be CPU Vector) - backend = KA.get_backend(blas_array) # Compute scene AABB from transformed instance bounds using GPU kernel # Allocate arrays for per-instance world AABBs aabb_mins = KA.allocate(backend, Point3f, n) aabb_maxs = KA.allocate(backend, Point3f, n) - # Adapt instances to backend for kernel execution (instances may be CPU Vector) - backend_instances = Adapt.adapt(backend, instances) - # Launch kernel to compute world AABBs in parallel aabb_kernel! = compute_instance_aabbs_kernel!(backend) - aabb_kernel!(aabb_mins, aabb_maxs, backend_instances, blas_array, ndrange=n) + aabb_kernel!(aabb_mins, aabb_maxs, instances, blas_array, ndrange=n) KA.synchronize(backend) # Copy results to CPU and compute scene AABB via reduction @@ -1341,12 +1450,10 @@ function build_tlas( max(aabb_extent[3], 1f-6) ) - # Allocate arrays on same backend as input + # Calculate Morton codes on same backend as input morton_codes = KA.allocate(backend, UInt32, n) - - # Launch kernel: Calculate Morton codes for instances calc_kernel! = calculate_tlas_morton_codes_kernel!(backend) - calc_kernel!(morton_codes, backend_instances, blas_array, scene_min, scene_extent, ndrange=n) + calc_kernel!(morton_codes, instances, blas_array, scene_min, scene_extent, ndrange=n) KA.synchronize(backend) # Sort indices by Morton codes @@ -1376,7 +1483,6 @@ function build_tlas( if n == 1 # For CPU, sorted_indices is already a CPU array; for GPU, copy to avoid scalar indexing original_idx = backend isa KA.CPU ? sorted_indices[1] : Array(sorted_indices[1:1])[1] - # Use scene_aabb computed from kernel (same as the single instance's world AABB) world_aabb = scene_aabb @@ -1390,11 +1496,10 @@ function build_tlas( cpu_nodes = [leaf_node] copyto!(nodes, Adapt.adapt(backend, cpu_nodes)) - return StaticTLAS(nodes, backend_instances, blas_array, world_aabb) + return (nodes, world_aabb) end # Multi-instance case: build proper LBVH - # Launch kernel: Emit topology (reuse BLAS topology kernel - same algorithm) topo_kernel! = emit_topology_kernel!(backend) topo_kernel!(nodes, morton_codes, Int32(n), ndrange=n-1) @@ -1403,7 +1508,7 @@ function build_tlas( parent_kernel!(nodes, Int32(n), ndrange=n-1) # Launch kernel: Create TLAS leaf nodes (different from BLAS - stores AABBs, not vertices) leaf_kernel! = create_tlas_leaf_nodes_kernel!(backend) - leaf_kernel!(nodes, sorted_indices, backend_instances, blas_array, Int32(n), ndrange=n) + leaf_kernel!(nodes, sorted_indices, instances, blas_array, Int32(n), ndrange=n) # Refit AABBs bottom-up (parallel using atomic counters) update_flags = KA.zeros(backend, UInt32, n - 1) refit_kernel! = refit_tlas_aabbs_kernel!(backend) @@ -1413,7 +1518,64 @@ function build_tlas( root_node = Array(nodes[1:1])[1] root_aabb = get_tlas_node_aabb(root_node, true) - return StaticTLAS(nodes, backend_instances, blas_array, root_aabb) + return (nodes, root_aabb) +end + +""" + build_tlas(blas_array::AbstractVector{BLAS}, instances::AbstractVector{InstanceDescriptor}) -> StaticTLAS + +Build a Top-Level Acceleration Structure over instances. +Uses LBVH over transformed instance AABBs. + +Returns a StaticTLAS with flat BLAS arrays suitable for ray traversal. +Uses KernelAbstractions for automatic CPU/GPU execution based on input array type. +""" +function build_tlas( + blas_array::AbstractVector{B}, + instances::AbstractVector{InstanceDescriptor} +) where {B <: BLAS} + n_blas = length(blas_array) + n = length(instances) + + if n == 0 + prim_type = n_blas > 0 ? eltype(blas_array[1].primitives) : Triangle{UInt32} + return StaticTLAS( + BVHNode2[], instances, + BVHNode2[], prim_type[], + BLASDescriptor[], + Bounds3() + ) + end + + backend = KA.get_backend(blas_array) + backend_instances = Adapt.adapt(backend, instances) + + nodes, root_aabb = _build_tlas_topology(blas_array, backend_instances, backend) + + # Build flat arrays from BLAS data + descriptors = Vector{BLASDescriptor}(undef, n_blas) + total_nodes = 0 + total_prims = 0 + for i in 1:n_blas + descriptors[i] = BLASDescriptor(UInt32(total_nodes), UInt32(total_prims), blas_array[i].root_aabb) + total_nodes += length(blas_array[i].nodes) + total_prims += length(blas_array[i].primitives) + end + + all_nodes = similar(blas_array[1].nodes, total_nodes) + all_prims = similar(blas_array[1].primitives, total_prims) + nodes_pos = 1 + prims_pos = 1 + for i in 1:n_blas + nn = length(blas_array[i].nodes) + copyto!(all_nodes, nodes_pos, blas_array[i].nodes, 1, nn) + nodes_pos += nn + np = length(blas_array[i].primitives) + copyto!(all_prims, prims_pos, blas_array[i].primitives, 1, np) + prims_pos += np + end + + return StaticTLAS(nodes, backend_instances, all_nodes, all_prims, descriptors, root_aabb) end @@ -1608,7 +1770,7 @@ Algorithm: 4. Transform back to world space 5. Return closest hit across all instances """ -@inline function closest_hit(tlas::TraversableTLAS, ray::R) where {R <: AbstractRay} +@inline function closest_hit(tlas::StaticTLAS, ray::R) where {R <: AbstractRay} # Initialize traversal state - matches HLSL TraceRays ray = check_direction(ray) ray_o::Point3f = ray.o @@ -1632,20 +1794,22 @@ Algorithm: # Entry point is node 1 (1-indexed in Julia) node_index::UInt32 = UInt32(1) + # Cached BLAS offset for current instance (avoids repeated descriptor lookup) + current_blas_offset::UInt32 = UInt32(0) + # Get typed references to avoid repeated field access tlas_nodes = tlas.nodes tlas_instances = tlas.instances - tlas_blas_array = tlas.blas_array + tlas_blas_nodes = tlas.all_blas_nodes + tlas_blas_prims = tlas.all_blas_prims + tlas_blas_descs = tlas.blas_descriptors @inbounds while node_index != INVALID_NODE # Fetch node based on current level node::BVHNode2 = if current_instance < Int32(0) tlas_nodes[node_index] else - # current_instance is a 0-indexed instance index - inst = tlas_instances[current_instance + Int32(1)] - blas = tlas_blas_array[inst.blas_index] - blas.nodes[node_index] + tlas_blas_nodes[current_blas_offset + node_index] end is_leaf::Bool = (node.child0 == INVALID_NODE) @@ -1676,6 +1840,8 @@ Algorithm: # Get instance and transform ray node_index = UInt32(1) # Start at root of BLAS inst = tlas_instances[current_instance + Int32(1)] + desc = tlas_blas_descs[inst.blas_index] + current_blas_offset = desc.nodes_offset ray_o = transform_point(inst.inv_transform, ray.o) ray_d = transform_direction(inst.inv_transform, ray.d) ray_inv_d = safe_invdir(ray_d) @@ -1714,14 +1880,14 @@ Algorithm: # Fill in hit output - matches HLSL @inbounds if closest_instance >= Int32(0) inst = tlas_instances[closest_instance + Int32(1)] - blas = tlas_blas_array[inst.blas_index] - tri = blas.primitives[closest_prim] + desc = tlas_blas_descs[inst.blas_index] + tri = tlas_blas_prims[desc.primitives_offset + closest_prim] w = 1.0f0 - hit_u - hit_v bary = SVector{3, Float32}(w, hit_u, hit_v) return (true, tri, ray_maxt, bary, inst.instance_id) else # No hit - return dummy values - dummy_tri = tlas_blas_array[1].primitives[1] + dummy_tri = tlas_blas_prims[1] bary = SVector{3, Float32}(0.0f0, 0.0f0, 0.0f0) return (false, dummy_tri, 0.0f0, bary, INVALID_NODE) end @@ -1735,7 +1901,7 @@ Faster than closest_hit when only occlusion testing is needed. Matches HLSL TraceRays with ANY_HIT defined. """ -@inline function any_hit(tlas::TraversableTLAS, ray::R) where {R <: AbstractRay} +@inline function any_hit(tlas::StaticTLAS, ray::R) where {R <: AbstractRay} # Initialize traversal state - matches HLSL TraceRays ray = check_direction(ray) ray_o::Point3f = ray.o @@ -1751,24 +1917,23 @@ Matches HLSL TraceRays with ANY_HIT defined. # Traversal state - use Int32 for indices to avoid UInt32 arithmetic issues current_instance::Int32 = Int32(-1) # -1 means no instance (top level) - # Entry point is node 1 (1-indexed in Julia) node_index::UInt32 = UInt32(1) + current_blas_offset::UInt32 = UInt32(0) # Get typed references to avoid repeated field access tlas_nodes = tlas.nodes tlas_instances = tlas.instances - tlas_blas_array = tlas.blas_array + tlas_blas_nodes = tlas.all_blas_nodes + tlas_blas_prims = tlas.all_blas_prims + tlas_blas_descs = tlas.blas_descriptors @inbounds while node_index != INVALID_NODE # Fetch node based on current level node::BVHNode2 = if current_instance < Int32(0) tlas_nodes[node_index] else - # current_instance is a 0-indexed instance index - inst = tlas_instances[current_instance + Int32(1)] - blas = tlas_blas_array[inst.blas_index] - blas.nodes[node_index] + tlas_blas_nodes[current_blas_offset + node_index] end is_leaf::Bool = (node.child0 == INVALID_NODE) @@ -1799,6 +1964,8 @@ Matches HLSL TraceRays with ANY_HIT defined. # Get instance and transform ray node_index = UInt32(1) # Start at root of BLAS inst = tlas_instances[current_instance + Int32(1)] + desc = tlas_blas_descs[inst.blas_index] + current_blas_offset = desc.nodes_offset ray_o = transform_point(inst.inv_transform, ray.o) ray_d = transform_direction(inst.inv_transform, ray.d) ray_inv_d = safe_invdir(ray_d) @@ -1809,8 +1976,8 @@ Matches HLSL TraceRays with ANY_HIT defined. if hit # ANY_HIT: return immediately on first hit inst = tlas_instances[current_instance + Int32(1)] - blas = tlas_blas_array[inst.blas_index] - tri = blas.primitives[node.child1] + desc = tlas_blas_descs[inst.blas_index] + tri = tlas_blas_prims[desc.primitives_offset + node.child1] w = 1.0f0 - u - v bary = SVector{3, Float32}(w, u, v) return (true, tri, t, bary, inst.instance_id) @@ -1836,7 +2003,7 @@ Matches HLSL TraceRays with ANY_HIT defined. end # No hit found - @inbounds dummy_tri = tlas_blas_array[1].primitives[1] + @inbounds dummy_tri = tlas_blas_prims[1] bary = SVector{3, Float32}(0.0f0, 0.0f0, 0.0f0) return (false, dummy_tri, 0.0f0, bary, INVALID_NODE) end @@ -1901,6 +2068,7 @@ function refit_tlas!(tlas::TLAS) backend = tlas.backend # Update leaf node AABBs from new transforms (kernel) + # blas_array is only used for root_aabb (inline data, safe on Metal) leaf_kernel! = update_tlas_leaf_aabbs_kernel!(backend) leaf_kernel!(tlas.nodes, tlas.instances, tlas.blas_array, Int32(n), ndrange=n) # Refit internal nodes bottom-up using atomic counters @@ -2058,18 +2226,18 @@ end BVH-compatible argument order for closest_hit. Returns (hit_found, triangle, distance, barycentric) - same as BVH. """ -function closest_hit(ray::AbstractRay, tlas::TraversableTLAS) +function closest_hit(ray::AbstractRay, tlas::StaticTLAS) hit, tri, t, bary, inst_id = closest_hit(tlas, ray) return (hit, tri, t, bary) end """ - any_hit(ray::AbstractRay, tlas::TLAS) + any_hit(ray::AbstractRay, tlas::StaticTLAS) BVH-compatible argument order for any_hit. Returns (hit_found, triangle, distance, barycentric) - same as BVH. """ -function any_hit(ray::AbstractRay, tlas::TraversableTLAS) +function any_hit(ray::AbstractRay, tlas::StaticTLAS) hit, tri, t, bary, inst_id = any_hit(tlas, ray) return (hit, tri, t, bary) end @@ -2091,9 +2259,8 @@ function Base.eltype(tlas::TLAS) return eltype(prims_array) end -function Base.eltype(::StaticTLAS{NA, IA, BA}) where {NA, IA, BA} - BLASType = eltype(BA) - return eltype(fieldtype(BLASType, :primitives)) +function Base.eltype(::StaticTLAS{NA, IA, BNA, BPA, DA}) where {NA, IA, BNA, BPA, DA} + return eltype(BPA) end # ============================================================================== @@ -2157,10 +2324,11 @@ n_instances(tlas::TraversableTLAS) = length(tlas.instances) Return number of unique BLAS geometries in the TLAS. """ -n_geometries(tlas::TraversableTLAS) = length(tlas.blas_array) +n_geometries(tlas::TLAS) = tlas.blas_array === nothing ? 0 : length(tlas.blas_array) +n_geometries(tlas::StaticTLAS) = length(tlas.blas_descriptors) # Export public API -export BLAS, TLAS, StaticTLAS, TraversableTLAS, InstanceDescriptor, BVHNode2 +export BLAS, BLASDescriptor, TLAS, StaticTLAS, TraversableTLAS, InstanceDescriptor, BVHNode2 export build_blas, build_tlas, closest_hit, any_hit, world_bound export update_instance_transform!, update_instance_transforms!, refit_tlas! export INVALID_NODE diff --git a/test/test_instanced_bvh.jl b/test/test_instanced_bvh.jl index 5ebb888..593a545 100644 --- a/test/test_instanced_bvh.jl +++ b/test/test_instanced_bvh.jl @@ -1212,7 +1212,9 @@ end # Verify all fields are isbits @test isbitstype(typeof(cl_tlas.nodes)) @test isbitstype(typeof(cl_tlas.instances)) - @test isbitstype(typeof(cl_tlas.blas_array)) + @test isbitstype(typeof(cl_tlas.all_blas_nodes)) + @test isbitstype(typeof(cl_tlas.all_blas_prims)) + @test isbitstype(typeof(cl_tlas.blas_descriptors)) end @testset "World bound preserved after adapt" begin