diff --git a/libkineto/src/RocprofLogger.cpp b/libkineto/src/RocprofLogger.cpp index 3c23e9786..24a8b34bf 100644 --- a/libkineto/src/RocprofLogger.cpp +++ b/libkineto/src/RocprofLogger.cpp @@ -154,6 +154,29 @@ auto extract_malloc_args = return 0; }; +// Correlation-to-stream map: fixes rocprofiler-sdk stream mapping. +// rocprofiler-sdk buffer records only expose HSA queue_id (same for all HIP +// streams on a device). This map captures the actual hipStream_t from the +// callback phase and makes it available to the buffer callback for correct +// per-stream trace rows. +std::mutex g_corrToStreamMutex; +std::unordered_map g_corrToStream; + +std::mutex g_streamIdMutex; +std::unordered_map g_streamPtrToId; +uint64_t g_nextStreamId = 1; + +uint64_t getStreamId(hipStream_t stream) { + uint64_t ptr = reinterpret_cast(stream); + if (ptr == 0) return 0; + std::lock_guard lk(g_streamIdMutex); + auto it = g_streamPtrToId.find(ptr); + if (it != g_streamPtrToId.end()) return it->second; + uint64_t id = g_nextStreamId++; + g_streamPtrToId[ptr] = id; + return id; +} + // copy api calls bool isCopyApi(uint32_t id) { switch (id) { @@ -589,6 +612,12 @@ void RocprofLogger::api_callback( args.stream); insert_row_to_buffer(row); + { + std::lock_guard lk(g_corrToStreamMutex); + g_corrToStream[record.correlation_id.internal] = + getStreamId(args.stream); + } + } // Copy Records else if (isCopyApi(record.operation)) { @@ -614,6 +643,12 @@ void RocprofLogger::api_callback( args.copyKind, args.stream); insert_row_to_buffer(row); + + { + std::lock_guard lk(g_corrToStreamMutex); + g_corrToStream[record.correlation_id.internal] = + getStreamId(args.stream); + } } // Malloc Records else if (isMallocApi(record.operation)) { @@ -639,6 +674,32 @@ void RocprofLogger::api_callback( } // Default Records else { + struct { hipStream_t stream{nullptr}; } default_args; + rocprofiler_iterate_callback_tracing_kind_operation_args( + record, + []([[maybe_unused]] rocprofiler_callback_tracing_kind_t kind, + [[maybe_unused]] rocprofiler_tracing_operation_t op, + [[maybe_unused]] uint32_t arg_num, + const void* const arg_value_addr, + [[maybe_unused]] int32_t indirection_count, + [[maybe_unused]] const char* arg_type, + const char* arg_name, + [[maybe_unused]] const char* arg_value_str, + [[maybe_unused]] int32_t dereference_count, + void* cb_data) -> int { + auto& a = *static_cast(cb_data); + if (strcmp("stream", arg_name) == 0) + a.stream = *(reinterpret_cast(arg_value_addr)); + return 0; + }, + 1, &default_args); + + { + std::lock_guard lk(g_corrToStreamMutex); + g_corrToStream[record.correlation_id.internal] = + getStreamId(default_args.stream); + } + rocprofRow* row = new rocprofRow( record.correlation_id.internal, record.kind, @@ -695,13 +756,23 @@ void RocprofLogger::buffer_callback( ? kernel_it->second : ""; + uint64_t stream_as_queue = dispatch.queue_id.handle; + { + std::lock_guard lk(g_corrToStreamMutex); + auto it = g_corrToStream.find(record.correlation_id.internal); + if (it != g_corrToStream.end()) { + stream_as_queue = it->second; + g_corrToStream.erase(it); + } + } + rocprofAsyncRow* row = new rocprofAsyncRow( record.correlation_id.internal, record.kind, record.operation, record.operation, // shared op - No longer a thing. Placeholder device_id, - dispatch.queue_id.handle, + stream_as_queue, record.start_timestamp, record.end_timestamp, kernel_name); @@ -718,13 +789,23 @@ void RocprofLogger::buffer_callback( ? agent_it->second.logical_node_type_id : -1; + uint64_t stream_as_queue = 0; + { + std::lock_guard lk(g_corrToStreamMutex); + auto it = g_corrToStream.find(record.correlation_id.internal); + if (it != g_corrToStream.end()) { + stream_as_queue = it->second; + g_corrToStream.erase(it); + } + } + rocprofAsyncRow* row = new rocprofAsyncRow( record.correlation_id.internal, record.kind, record.operation, record.operation, // shared op - No longer a thing. Placeholder device_id, - 0, + stream_as_queue, record.start_timestamp, record.end_timestamp, "");