Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 47 additions & 12 deletions apex/contrib/csrc/conv_bias_relu/conv_bias_relu_rocm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ struct FusionPlanEntry {
miopenFusionOpDescriptor_t conv_op;
miopenFusionOpDescriptor_t bias_op;
miopenFusionOpDescriptor_t activ_op;
size_t workspace_size;
};

static std::unordered_map<std::string, FusionPlanEntry> plan_cache;
Expand Down Expand Up @@ -166,7 +167,10 @@ static std::vector<at::Tensor> conv_bias_forward_dispatch(const at::Tensor& x,
bool use_fusion) {
if (x.is_cuda()) {
if (use_fusion) {
return conv_bias_relu_forward_fused(x, weight, bias, padding, stride, use_relu);
auto result = conv_bias_relu_forward_fused(x, weight, bias, padding, stride, use_relu);
if (!result.empty()) {
return result;
}
}
return conv_bias_forward(x, weight, bias, padding, stride, use_relu);
}
Expand Down Expand Up @@ -236,15 +240,17 @@ static std::vector<at::Tensor> conv_bias_relu_forward_fused(const at::Tensor& x,

MIOPEN_CHECK(miopenCreateOpConvForward(plan, &conv_op, conv_desc, weight_desc));

miopenConvFwdAlgorithm_t selected_algo = miopenConvolutionFwdAlgoGEMM;
int returned_algo_count = 0;
miopenFusionPlanConvolutionGetAlgo(plan, 1, &returned_algo_count, &selected_algo);

// 2. Bias Op
miopenFusionOpDescriptor_t bias_op = nullptr;
if (bias.defined()) {
miopenTensorDescriptor_t bias_desc = nullptr;
MIOPEN_CHECK(miopenCreateTensorDescriptor(&bias_desc));
if(is_nhwc)
MIOPEN_CHECK(miopenSet4dTensorDescriptor(bias_desc, dtype, 1, (int)x.size(3), 1, 1));
else
MIOPEN_CHECK(miopenSet4dTensorDescriptor(bias_desc, dtype, 1, (int)x.size(1), 1, 1));
int64_t oc = weight.size(0);
MIOPEN_CHECK(miopenSet4dTensorDescriptor(bias_desc, dtype, 1, (int)oc, 1, 1));
MIOPEN_CHECK(miopenCreateOpBiasForward(plan, &bias_op, bias_desc));
miopenDestroyTensorDescriptor(bias_desc);
}
Expand All @@ -253,18 +259,38 @@ static std::vector<at::Tensor> conv_bias_relu_forward_fused(const at::Tensor& x,
miopenFusionOpDescriptor_t activ_op = nullptr;
if (use_relu) {
MIOPEN_CHECK(miopenCreateOpActivationForward(plan, &activ_op, miopenActivationRELU));
}else
{
MIOPEN_CHECK(miopenCreateOpActivationForward(plan, &activ_op, miopenActivationCLAMP));
} else {
MIOPEN_CHECK(miopenCreateOpActivationForward(plan, &activ_op, miopenActivationCLAMP));
}

// Compile
MIOPEN_CHECK(miopenCompileFusionPlan(handle, plan));
miopenStatus_t compile_status = miopenCompileFusionPlan(handle, plan);
if (compile_status != miopenStatusSuccess) {
miopenDestroyFusionPlan(plan);
miopenDestroyTensorDescriptor(input_desc);
miopenDestroyTensorDescriptor(weight_desc);
miopenDestroyConvolutionDescriptor(conv_desc);

return {};
}

size_t ws_size = 0;
miopenFusionPlanGetWorkSpaceSize(handle, plan, &ws_size, selected_algo);

// miopenFusionPlanGetWorkSpaceSize may not account for the workspace needed by
// the NCHW->NHWC layout transform, so make sure the workspace is at least large
// enough to hold one copy of the input plus one copy of the output.
size_t elem_size = (dtype == miopenHalf) ? 2 : 4;
int64_t out_h = (x.size(2) + 2 * padding - weight.size(2)) / stride + 1;
int64_t out_w = (x.size(3) + 2 * padding - weight.size(3)) / stride + 1;
size_t input_bytes = static_cast<size_t>(x.numel()) * elem_size;
size_t output_bytes = static_cast<size_t>(x.size(0)) * weight.size(0) * out_h * out_w * elem_size;
ws_size = std::max(ws_size, input_bytes + output_bytes);

plan_cache[key].fusion_plan = plan;
plan_cache[key].conv_op = conv_op;
plan_cache[key].bias_op = bias_op;
plan_cache[key].activ_op = activ_op;
plan_cache[key].workspace_size = ws_size;

miopenDestroyTensorDescriptor(input_desc);
miopenDestroyTensorDescriptor(weight_desc);
Expand Down Expand Up @@ -312,21 +338,30 @@ static std::vector<at::Tensor> conv_bias_relu_forward_fused(const at::Tensor& x,
if (entry.activ_op) {
if (use_relu)
MIOPEN_CHECK(miopenSetOpArgsActivForward(args, entry.activ_op, &alpha, &beta, 0.0, 0.0, 0.0));
else{
else {
float alpha1 = -3.402823466e+38F, beta1 = 3.402823466e+38F;
MIOPEN_CHECK(miopenSetOpArgsActivForward(args, entry.activ_op, &alpha, &beta, alpha1, beta1, 0.0));
}
}

MIOPEN_CHECK(miopenExecuteFusionPlan(handle, entry.fusion_plan,
auto workspace = at::empty({static_cast<long long>(entry.workspace_size)},
x.options().dtype(at::kByte));
void* ws_ptr = entry.workspace_size > 0 ? workspace.data_ptr() : nullptr;

miopenStatus_t exec_status = miopenExecuteFusionPlan_v2(handle, entry.fusion_plan,
input_desc, x.data_ptr(),
output_desc, out.data_ptr(),
args));
args, ws_ptr, entry.workspace_size);

miopenDestroyOperatorArgs(args);
miopenDestroyTensorDescriptor(input_desc);
miopenDestroyTensorDescriptor(output_desc);

if (exec_status != miopenStatusSuccess) {
plan_cache.erase(key);
return {};
}

return {out};
}

Expand Down