From a571543e240a5d91a9745cff505e888acd427135 Mon Sep 17 00:00:00 2001 From: skishore Date: Tue, 3 Mar 2026 10:41:10 +0000 Subject: [PATCH] add support for workspace size, fix workspace size for channel first format, destroy variables if fusion fails, if fusion fails then fallback to non-fusion approach --- .../conv_bias_relu/conv_bias_relu_rocm.cpp | 59 +++++++++++++++---- 1 file changed, 47 insertions(+), 12 deletions(-) diff --git a/apex/contrib/csrc/conv_bias_relu/conv_bias_relu_rocm.cpp b/apex/contrib/csrc/conv_bias_relu/conv_bias_relu_rocm.cpp index 7668053e2..3fe5e2d45 100644 --- a/apex/contrib/csrc/conv_bias_relu/conv_bias_relu_rocm.cpp +++ b/apex/contrib/csrc/conv_bias_relu/conv_bias_relu_rocm.cpp @@ -28,6 +28,7 @@ struct FusionPlanEntry { miopenFusionOpDescriptor_t conv_op; miopenFusionOpDescriptor_t bias_op; miopenFusionOpDescriptor_t activ_op; + size_t workspace_size; }; static std::unordered_map plan_cache; @@ -166,7 +167,10 @@ static std::vector conv_bias_forward_dispatch(const at::Tensor& x, bool use_fusion) { if (x.is_cuda()) { if (use_fusion) { - return conv_bias_relu_forward_fused(x, weight, bias, padding, stride, use_relu); + auto result = conv_bias_relu_forward_fused(x, weight, bias, padding, stride, use_relu); + if (!result.empty()) { + return result; + } } return conv_bias_forward(x, weight, bias, padding, stride, use_relu); } @@ -236,15 +240,17 @@ static std::vector conv_bias_relu_forward_fused(const at::Tensor& x, MIOPEN_CHECK(miopenCreateOpConvForward(plan, &conv_op, conv_desc, weight_desc)); + miopenConvFwdAlgorithm_t selected_algo = miopenConvolutionFwdAlgoGEMM; + int returned_algo_count = 0; + miopenFusionPlanConvolutionGetAlgo(plan, 1, &returned_algo_count, &selected_algo); + // 2. Bias Op miopenFusionOpDescriptor_t bias_op = nullptr; if (bias.defined()) { miopenTensorDescriptor_t bias_desc = nullptr; MIOPEN_CHECK(miopenCreateTensorDescriptor(&bias_desc)); - if(is_nhwc) - MIOPEN_CHECK(miopenSet4dTensorDescriptor(bias_desc, dtype, 1, (int)x.size(3), 1, 1)); - else - MIOPEN_CHECK(miopenSet4dTensorDescriptor(bias_desc, dtype, 1, (int)x.size(1), 1, 1)); + int64_t oc = weight.size(0); + MIOPEN_CHECK(miopenSet4dTensorDescriptor(bias_desc, dtype, 1, (int)oc, 1, 1)); MIOPEN_CHECK(miopenCreateOpBiasForward(plan, &bias_op, bias_desc)); miopenDestroyTensorDescriptor(bias_desc); } @@ -253,18 +259,38 @@ static std::vector conv_bias_relu_forward_fused(const at::Tensor& x, miopenFusionOpDescriptor_t activ_op = nullptr; if (use_relu) { MIOPEN_CHECK(miopenCreateOpActivationForward(plan, &activ_op, miopenActivationRELU)); - }else - { - MIOPEN_CHECK(miopenCreateOpActivationForward(plan, &activ_op, miopenActivationCLAMP)); + } else { + MIOPEN_CHECK(miopenCreateOpActivationForward(plan, &activ_op, miopenActivationCLAMP)); } // Compile - MIOPEN_CHECK(miopenCompileFusionPlan(handle, plan)); + miopenStatus_t compile_status = miopenCompileFusionPlan(handle, plan); + if (compile_status != miopenStatusSuccess) { + miopenDestroyFusionPlan(plan); + miopenDestroyTensorDescriptor(input_desc); + miopenDestroyTensorDescriptor(weight_desc); + miopenDestroyConvolutionDescriptor(conv_desc); + + return {}; + } + + size_t ws_size = 0; + miopenFusionPlanGetWorkSpaceSize(handle, plan, &ws_size, selected_algo); + + // miopenFusionPlanGetWorkSpaceSize may not account for NCHW->NHWC + // layout transform workspace. Ensure enough space for input + output. + size_t elem_size = (dtype == miopenHalf) ? 2 : 4; + int64_t out_h = (x.size(2) + 2 * padding - weight.size(2)) / stride + 1; + int64_t out_w = (x.size(3) + 2 * padding - weight.size(3)) / stride + 1; + size_t input_bytes = static_cast(x.numel()) * elem_size; + size_t output_bytes = static_cast(x.size(0)) * weight.size(0) * out_h * out_w * elem_size; + ws_size = std::max(ws_size, input_bytes + output_bytes); plan_cache[key].fusion_plan = plan; plan_cache[key].conv_op = conv_op; plan_cache[key].bias_op = bias_op; plan_cache[key].activ_op = activ_op; + plan_cache[key].workspace_size = ws_size; miopenDestroyTensorDescriptor(input_desc); miopenDestroyTensorDescriptor(weight_desc); @@ -312,21 +338,30 @@ static std::vector conv_bias_relu_forward_fused(const at::Tensor& x, if (entry.activ_op) { if (use_relu) MIOPEN_CHECK(miopenSetOpArgsActivForward(args, entry.activ_op, &alpha, &beta, 0.0, 0.0, 0.0)); - else{ + else { float alpha1 = -3.402823466e+38F, beta1 = 3.402823466e+38F; MIOPEN_CHECK(miopenSetOpArgsActivForward(args, entry.activ_op, &alpha, &beta, alpha1, beta1, 0.0)); } } - MIOPEN_CHECK(miopenExecuteFusionPlan(handle, entry.fusion_plan, + auto workspace = at::empty({static_cast(entry.workspace_size)}, + x.options().dtype(at::kByte)); + void* ws_ptr = entry.workspace_size > 0 ? workspace.data_ptr() : nullptr; + + miopenStatus_t exec_status = miopenExecuteFusionPlan_v2(handle, entry.fusion_plan, input_desc, x.data_ptr(), output_desc, out.data_ptr(), - args)); + args, ws_ptr, entry.workspace_size); miopenDestroyOperatorArgs(args); miopenDestroyTensorDescriptor(input_desc); miopenDestroyTensorDescriptor(output_desc); + if (exec_status != miopenStatusSuccess) { + plan_cache.erase(key); + return {}; + } + return {out}; }