Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 83 additions & 2 deletions libkineto/src/RocprofLogger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,29 @@ auto extract_malloc_args =
return 0;
};

// Correlation-to-stream map: fixes rocprofiler-sdk stream mapping.
// rocprofiler-sdk buffer records only expose HSA queue_id (same for all HIP
// streams on a device). This map records, during the callback phase, a stream
// id derived from the actual hipStream_t and makes it available to the buffer
// callback for correct per-stream trace rows.
//
// Key:   record.correlation_id.internal of the API call.
// Value: the id returned by getStreamId() for that call's stream argument
//        (0 when the stream is null).
// The buffer callback erases an entry once it has looked it up.
// Every access must hold g_corrToStreamMutex.
std::mutex g_corrToStreamMutex;
std::unordered_map<uint64_t, uint64_t> g_corrToStream;

// Stream-pointer-to-id table used by getStreamId().
// Key:   the hipStream_t handle reinterpreted as a 64-bit integer.
// Value: an id handed out in first-seen order from g_nextStreamId.
// Ids start at 1 because getStreamId() returns 0 for the null stream.
// Both the table and the counter are guarded by g_streamIdMutex.
std::mutex g_streamIdMutex;
std::unordered_map<uint64_t, uint64_t> g_streamPtrToId;
uint64_t g_nextStreamId = 1;

// Translate a HIP stream handle into a small, stable numeric id.
//
// The null stream always yields 0. Any other handle is assigned the next
// unused id (taken from g_nextStreamId) the first time it is seen, and the
// same id is returned on every later call with that handle.
//
// The lookup table and the counter are accessed under g_streamIdMutex.
uint64_t getStreamId(hipStream_t stream) {
  const auto handleBits = reinterpret_cast<uint64_t>(stream);
  if (handleBits == 0) {
    // Null stream: no table entry is created and no lock is taken.
    return 0;
  }

  std::lock_guard<std::mutex> guard(g_streamIdMutex);

  // try_emplace only inserts when the handle is new, so the counter is
  // advanced solely on first sight of a stream.
  const auto [slot, isNew] =
      g_streamPtrToId.try_emplace(handleBits, g_nextStreamId);
  if (isNew) {
    ++g_nextStreamId;
  }
  return slot->second;
}

// copy api calls
bool isCopyApi(uint32_t id) {
switch (id) {
Expand Down Expand Up @@ -589,6 +612,12 @@ void RocprofLogger::api_callback(
args.stream);
insert_row_to_buffer(row);

{
std::lock_guard<std::mutex> lk(g_corrToStreamMutex);
g_corrToStream[record.correlation_id.internal] =
getStreamId(args.stream);
}

}
// Copy Records
else if (isCopyApi(record.operation)) {
Expand All @@ -614,6 +643,12 @@ void RocprofLogger::api_callback(
args.copyKind,
args.stream);
insert_row_to_buffer(row);

{
std::lock_guard<std::mutex> lk(g_corrToStreamMutex);
g_corrToStream[record.correlation_id.internal] =
getStreamId(args.stream);
}
}
// Malloc Records
else if (isMallocApi(record.operation)) {
Expand All @@ -639,6 +674,32 @@ void RocprofLogger::api_callback(
}
// Default Records
else {
struct { hipStream_t stream{nullptr}; } default_args;
rocprofiler_iterate_callback_tracing_kind_operation_args(
record,
[]([[maybe_unused]] rocprofiler_callback_tracing_kind_t kind,
[[maybe_unused]] rocprofiler_tracing_operation_t op,
[[maybe_unused]] uint32_t arg_num,
const void* const arg_value_addr,
[[maybe_unused]] int32_t indirection_count,
[[maybe_unused]] const char* arg_type,
const char* arg_name,
[[maybe_unused]] const char* arg_value_str,
[[maybe_unused]] int32_t dereference_count,
void* cb_data) -> int {
auto& a = *static_cast<decltype(default_args)*>(cb_data);
if (strcmp("stream", arg_name) == 0)
a.stream = *(reinterpret_cast<const hipStream_t*>(arg_value_addr));
return 0;
},
1, &default_args);

{
std::lock_guard<std::mutex> lk(g_corrToStreamMutex);
g_corrToStream[record.correlation_id.internal] =
getStreamId(default_args.stream);
}

rocprofRow* row = new rocprofRow(
record.correlation_id.internal,
record.kind,
Expand Down Expand Up @@ -695,13 +756,23 @@ void RocprofLogger::buffer_callback(
? kernel_it->second
: "<unknown kernel>";

uint64_t stream_as_queue = dispatch.queue_id.handle;
{
std::lock_guard<std::mutex> lk(g_corrToStreamMutex);
auto it = g_corrToStream.find(record.correlation_id.internal);
if (it != g_corrToStream.end()) {
stream_as_queue = it->second;
g_corrToStream.erase(it);
}
}

rocprofAsyncRow* row = new rocprofAsyncRow(
record.correlation_id.internal,
record.kind,
record.operation,
record.operation, // shared op - No longer a thing. Placeholder
device_id,
dispatch.queue_id.handle,
stream_as_queue,
record.start_timestamp,
record.end_timestamp,
kernel_name);
Expand All @@ -718,13 +789,23 @@ void RocprofLogger::buffer_callback(
? agent_it->second.logical_node_type_id
: -1;

uint64_t stream_as_queue = 0;
{
std::lock_guard<std::mutex> lk(g_corrToStreamMutex);
auto it = g_corrToStream.find(record.correlation_id.internal);
if (it != g_corrToStream.end()) {
stream_as_queue = it->second;
g_corrToStream.erase(it);
}
}

rocprofAsyncRow* row = new rocprofAsyncRow(
record.correlation_id.internal,
record.kind,
record.operation,
record.operation, // shared op - No longer a thing. Placeholder
device_id,
0,
stream_as_queue,
record.start_timestamp,
record.end_timestamp,
"");
Expand Down
Loading