diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp index 7ae13d514..5be8b2840 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp @@ -4,17 +4,37 @@ void multi_tensor_fused_adam_cuda( int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, - at::Tensor per_tensor_beta1, - at::Tensor per_tensor_beta2, - at::Tensor per_tensor_bias_correction, - at::Tensor per_tensor_eps, - at::Tensor per_tensor_weight_decay, + at::Tensor grad_scale, float lr, - float grad_scale, + float beta1, + float beta2, + float eps, int step, - int mode); + int mode, + int bias_correction, + float weight_decay); + +void multi_tensor_fused_adam_with_param_remainders_cuda( + int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor grad_scale, + float lr, + float beta1, + float beta2, + float eps, + int step, + int mode, + int bias_correction, + float weight_decay); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("multi_tensor_fused_adam", &multi_tensor_fused_adam_cuda, - "Multi tensor Adam optimized CUDA implementation."); + m.def("multi_tensor_fused_adam", + &multi_tensor_fused_adam_cuda, + "CUDA kernels for multi-tensor Adam, " + "with param copy"); + m.def("multi_tensor_fused_adam_with_param_remainders", + &multi_tensor_fused_adam_with_param_remainders_cuda, + "CUDA kernel for multi-tensor Adam, " + "with stored param remainders and param copy"); } diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu index f89fb594e..1c7f02e64 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu @@ -14,151 +14,315 @@ #define ILP 4 template -__device__ __forceinline__ bool is_aligned(T* p){ +__device__ __forceinline__ bool is_aligned(const T* p){ return ((uint64_t)p) % (ILP*sizeof(T)) == 0; } template -__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ +__device__ __forceinline__ void load_store( + T* dst, + const T* src, + int dst_offset = 0, + int src_offset = 0){ typedef typename std::aligned_storage::type LT; - ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; + ((LT*)dst)[dst_offset] = ((const LT*)src)[src_offset]; +} + +// (1-t)*x + t*y +__device__ __forceinline__ float lerp(float t, float x, float y) { + // See https://developer.nvidia.com/blog/lerp-faster-cuda/ + return fma(t, y, fma(-t, x, x)); } typedef enum{ - ADAM_MODE_0 =0, // eps under square root - ADAM_MODE_1 =1 // eps outside square root + ADAM_MODE_0 =0, // L2 regularization mode + ADAM_MODE_1 =1 // Decoupled weight decay mode(AdamW) } adamMode_t; -template +/* Multi-tensor Adam + * + * Updates params in-place and outputs a copy with a desired datatype. + */ +template struct DistAdamFunctor { + // Vectorized local compute + __device__ __forceinline__ static void local_step( + T p[ILP], + T m[ILP], + T v[ILP], + const GRAD_T g[ILP], + const float grad_scale, + const float beta1, + const float beta2, + const float beta1_correction, + const float beta2_correction, + const float eps, + const float lr, + adamMode_t mode, + const float weight_decay) { + if (mode == ADAM_MODE_0) { // L2 +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + float scaled_grad = (g[ii] * grad_scale) + (weight_decay * p[ii]); + float next_m = lerp(beta1, scaled_grad, m[ii]); + float next_v = lerp(beta2, scaled_grad*scaled_grad, v[ii]); + float next_m_unbiased = next_m / beta1_correction; + float next_v_unbiased = next_v / beta2_correction; + float denom = sqrtf(next_v_unbiased) + eps; + float update = next_m_unbiased / denom; + m[ii] = next_m; + v[ii] = next_v; + p[ii] -= lr * update; + } + } else { // weight decay +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + float scaled_grad = g[ii] * grad_scale; + float next_m = lerp(beta1, scaled_grad, m[ii]); + float next_v = lerp(beta2, scaled_grad*scaled_grad, v[ii]); + float next_m_unbiased = next_m / beta1_correction; + float next_v_unbiased = next_v / beta2_correction; + float denom = sqrtf(next_v_unbiased) + eps; + float update = (next_m_unbiased / denom) + (weight_decay * p[ii]); + m[ii] = next_m; + v[ii] = next_v; + p[ii] -= lr * update; + } + } + } + __device__ __forceinline__ void operator()( int chunk_size, volatile int* noop_gmem, - TensorListMetadata& tl, - const float* per_tensor_beta1, - const float* per_tensor_beta2, - const int* per_tensor_bias_correction, - const float* per_tensor_eps, - const float* per_tensor_weight_decay, + TensorListMetadata<5>& tl, + const float* grad_scale_ptr, + const float beta1, + const float beta2, + const float beta1_correction, + const float beta2_correction, + const float eps, const float lr, - const float grad_scale, - const int step, - adamMode_t mode) + adamMode_t mode, + const float weight_decay) const { int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int tensor_num = tl.start_tensor_this_launch + tensor_loc; int chunk_idx = tl.block_to_chunk[blockIdx.x]; int n = tl.sizes[tensor_loc]; - float b1 = per_tensor_beta1[tensor_num]; - float b2 = per_tensor_beta2[tensor_num]; - float eps = per_tensor_eps[tensor_num]; - float decay = per_tensor_weight_decay[tensor_num]; + const float grad_scale = *grad_scale_ptr; - float beta1_correction = 1.0f, beta2_correction = 1.0f; - if (per_tensor_bias_correction[tensor_num] == 1) { - beta1_correction = 1 - std::pow(b1, step); - beta2_correction = 1 - std::pow(b2, step); - } - - T* p = (T *)tl.addresses[0][tensor_loc]; - p += chunk_idx*chunk_size; + T* p_in = (T *)tl.addresses[0][tensor_loc]; + p_in += chunk_idx*chunk_size; T* m = (T *)tl.addresses[1][tensor_loc]; m += chunk_idx*chunk_size; T* v = (T *)tl.addresses[2][tensor_loc]; v += chunk_idx*chunk_size; - GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc]; + const GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc]; g += chunk_idx*chunk_size; - GRAD_T* p_copy = NULL; - if (DEPTH == 5) { - p_copy = (GRAD_T *)tl.addresses[4][tensor_loc]; - p_copy += chunk_idx*chunk_size; - } + PARAM_OUT_T* p_out = (PARAM_OUT_T *)tl.addresses[4][tensor_loc]; + p_out += chunk_idx*chunk_size; n -= chunk_idx*chunk_size; - - T incoming_p[ILP]; - T incoming_m[ILP]; - T incoming_v[ILP]; - T incoming_g[ILP]; - - // to make things simple, we put aligned case in a different code path - if (n % ILP == 0 && - chunk_size % ILP == 0 && - is_aligned(p) && - is_aligned(m) && - is_aligned(v) && - is_aligned(g) && - is_aligned(p_copy)) { - for (int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) { - // load - GRAD_T tmp_g[ILP]; - load_store(incoming_p, p, 0, i_start); - load_store(incoming_m, m, 0, i_start); - load_store(incoming_v, v, 0, i_start); - load_store(tmp_g, g, 0, i_start); + n = chunk_size < n ? chunk_size : n; + + const bool aligned = (n % ILP == 0 && + is_aligned(p_in) && + is_aligned(m) && + is_aligned(v) && + is_aligned(g) && + is_aligned(p_out)); + + for (int i_start = threadIdx.x*ILP; i_start < n; i_start += blockDim.x*ILP) { + T local_p[ILP]; + T local_m[ILP]; + T local_v[ILP]; + GRAD_T local_g[ILP]; + PARAM_OUT_T local_p_out[ILP]; + + // Load + if (aligned) { + load_store(local_p, p_in + i_start); + load_store(local_m, m + i_start); + load_store(local_v, v + i_start); + load_store(local_g, g + i_start); + } else { #pragma unroll - for (int ii = 0; ii < ILP; ii++) { - incoming_g[ii] = static_cast(tmp_g[ii]); - T scaled_grad = incoming_g[ii]/grad_scale; - incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad; - incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad; - T next_m_unbiased = incoming_m[ii] / beta1_correction; - T next_v_unbiased = incoming_v[ii] / beta2_correction; - float denom; - if (mode == ADAM_MODE_0) - denom = sqrtf(next_v_unbiased + eps); - else // Mode 1 - denom = sqrtf(next_v_unbiased) + eps; - float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]); - incoming_p[ii] = incoming_p[ii] - (lr * update); - if (DEPTH == 5) tmp_g[ii] = static_cast(incoming_p[ii]); + for (int ii = 0, i = i_start; ii < ILP; ii++, i++) { + if (i < n) { + local_p[ii] = p_in[i]; + local_m[ii] = m[i]; + local_v[ii] = v[i]; + local_g[ii] = g[i]; + } else { + local_p[ii] = 0; + local_m[ii] = 0; + local_v[ii] = 0; + local_g[ii] = 0; + } } - load_store(p, incoming_p, i_start, 0); - load_store(m, incoming_m, i_start, 0); - load_store(v, incoming_v, i_start, 0); - if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0); } - } else { - for (int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) { + // Local compute + local_step( + local_p, local_m, local_v, local_g, grad_scale, + beta1, beta2, beta1_correction, beta2_correction, + eps, lr, mode, weight_decay); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + local_p_out[ii] = static_cast(local_p[ii]); + } + + // Store + if (aligned) { + load_store(p_in + i_start, local_p); + load_store(m + i_start, local_m); + load_store(v + i_start, local_v); + load_store(p_out + i_start, local_p_out); + } else { +#pragma unroll + for (int ii = 0, i = i_start; ii < ILP; ii++, i++) { + if (i < n) { + p_in[i] = local_p[ii]; + m[i] = local_m[ii]; + v[i] = local_v[ii]; + p_out[i] = local_p_out[ii]; + } + } + } + } + } +}; + +/* Functor for multi-tensor Adam with implicit main params + * + * If params are BF16 and optimizer state is FP32, it is not necessary + * to store FP32 main params. Instead, store 16-bit param remainder + * and combine with BF16 param to reconstruct the FP32 main param. + */ +struct DistAdamWithParamRemaindersFunctor +{ + __device__ __forceinline__ void operator()( + int chunk_size, + volatile int* noop_gmem, + TensorListMetadata<6>& tl, + const float* grad_scale_ptr, + const float beta1, + const float beta2, + const float beta1_correction, + const float beta2_correction, + const float eps, + const float lr, + adamMode_t mode, + const float weight_decay) const + { + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + const float grad_scale = *grad_scale_ptr; + + int16_t* p_in = (int16_t *)tl.addresses[0][tensor_loc]; + p_in += chunk_idx*chunk_size; + int16_t* p_rem = (int16_t *)tl.addresses[1][tensor_loc]; + p_rem += chunk_idx*chunk_size; + float* m = (float *)tl.addresses[2][tensor_loc]; + m += chunk_idx*chunk_size; + float* v = (float *)tl.addresses[3][tensor_loc]; + v += chunk_idx*chunk_size; + float* g = (float *)tl.addresses[4][tensor_loc]; + g += chunk_idx*chunk_size; + int16_t* p_out = (int16_t *)tl.addresses[5][tensor_loc]; + p_out += chunk_idx*chunk_size; + + n -= chunk_idx*chunk_size; + n = chunk_size < n ? chunk_size : n; + + const bool aligned = (n % ILP == 0 && + is_aligned(p_in) && + is_aligned(p_rem) && + is_aligned(m) && + is_aligned(v) && + is_aligned(g) && + is_aligned(p_out)); + + for (int i_start = threadIdx.x*ILP; i_start < n; i_start += blockDim.x*ILP) { + union fp32_or_int162 { + float fp32; + int16_t int16[2]; + }; + fp32_or_int162 local_p[ILP]; + int16_t local_p_bf16[ILP]; + int16_t local_p_rem[ILP]; + float local_m[ILP]; + float local_v[ILP]; + float local_g[ILP]; + + // Load + if (aligned) { + load_store(local_p_bf16, p_in + i_start); + load_store(local_p_rem, p_rem + i_start); + load_store(local_m, m + i_start); + load_store(local_v, v + i_start); + load_store(local_g, g + i_start); + } else { #pragma unroll - for (int ii = 0; ii < ILP; ii++) { - incoming_p[ii] = 0; - incoming_m[ii] = 0; - incoming_v[ii] = 0; - incoming_g[ii] = 0; - - int i = i_start + threadIdx.x + ii*blockDim.x; - if (i < n && i < chunk_size) { - incoming_p[ii] = p[i]; - incoming_m[ii] = m[i]; - incoming_v[ii] = v[i]; - incoming_g[ii] = static_cast(g[i]); + for (int ii = 0, i = i_start; ii < ILP; ii++, i++) { + if (i < n) { + local_p_bf16[ii] = p_in[i]; + local_p_rem[ii] = p_rem[i]; + local_m[ii] = m[i]; + local_v[ii] = v[i]; + local_g[ii] = g[i]; + } else { + local_p_bf16[ii] = 0; + local_p_rem[ii] = 0; + local_m[ii] = 0; + local_v[ii] = 0; + local_g[ii] = 0; } } + } + + // Reconstruct FP32 params +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (local_p_rem[ii] < 0) + local_p_bf16[ii]--; // Undo rounding + local_p[ii].int16[1] = local_p_bf16[ii]; + local_p[ii].int16[0] = local_p_rem[ii]; + } + + // Local compute + using LocalFunctor = DistAdamFunctor; + LocalFunctor::local_step( + reinterpret_cast(local_p), local_m, local_v, local_g, grad_scale, + beta1, beta2, beta1_correction, beta2_correction, + eps, lr, mode, weight_decay); + // Split into BF16 params (rounded-to-nearest) and remainders #pragma unroll - for (int ii = 0; ii < ILP; ii++) { - int j = i_start + threadIdx.x + ii*blockDim.x; - - if (j < n && j < chunk_size) { - T scaled_grad = incoming_g[ii]/grad_scale; - m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad; - v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad; - T next_m_unbiased = m[j] / beta1_correction; - T next_v_unbiased = v[j] / beta2_correction; - float denom; - if (mode == ADAM_MODE_0) - denom = sqrtf(next_v_unbiased + eps); - else // Mode 1 - denom = sqrtf(next_v_unbiased) + eps; - float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]); - p[j] = incoming_p[ii] - (lr * update); - if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j]; + for (int ii = 0; ii < ILP; ii++) { + local_p_bf16[ii] = local_p[ii].int16[1]; + local_p_rem[ii] = local_p[ii].int16[0]; + if (local_p_rem[ii] < 0) + local_p_bf16[ii]++; // Round up + } + + // Store + if (aligned) { + load_store(p_rem + i_start, local_p_rem); + load_store(m + i_start, local_m); + load_store(v + i_start, local_v); + load_store(p_out + i_start, local_p_bf16); + } else { +#pragma unroll + for (int ii = 0, i = i_start; ii < ILP; ii++, i++) { + if (i < n) { + p_rem[i] = local_p_rem[ii]; + m[i] = local_m[ii]; + v[i] = local_v[ii]; + p_out[i] = local_p_bf16[ii]; } } } @@ -169,60 +333,96 @@ struct DistAdamFunctor void multi_tensor_fused_adam_cuda( int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, // p, m, v, g, p_copy - at::Tensor per_tensor_beta1, - at::Tensor per_tensor_beta2, - at::Tensor per_tensor_bias_correction, - at::Tensor per_tensor_eps, - at::Tensor per_tensor_weight_decay, + std::vector> tensor_lists, // p_in, m, v, g, p_out + at::Tensor grad_scale, float lr, - float grad_scale, + float beta1, + float beta2, + float eps, int step, - int mode) + int mode, + int bias_correction, + float weight_decay) { using namespace at; + // Expect p_in, m, v, g, p_out size_t tl_sz = tensor_lists.size(); - AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5"); + AT_ASSERTM(tl_sz == 5, "expected tensor lists of size 5"); - if (tl_sz == 5) { - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g - using accscalar_t = at::acc_type; + // Assume p_in and g have same type + auto p_in_type = tensor_lists[0][0].scalar_type(); + auto g_type = tensor_lists[3][0].scalar_type(); + auto p_out_type = tensor_lists[4][0].scalar_type(); + AT_ASSERTM(p_in_type == g_type, "expected main params and grads to have same type"); + + float beta1_correction = 1.0f, beta2_correction = 1.0f; + if (bias_correction == 1) { + beta1_correction = 1 - std::pow(beta1, step); + beta2_correction = 1 - std::pow(beta2, step); + } + + DISPATCH_FLOAT_HALF_AND_BFLOAT(p_in_type, 0, "dist_adam_cuda_kernel", + DISPATCH_FLOAT_HALF_AND_BFLOAT(p_out_type, 1, "dist_adam_cuda_kernel", multi_tensor_apply<5>( BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - DistAdamFunctor<5, accscalar_t, scalar_t_0>(), - per_tensor_beta1.DATA_PTR(), - per_tensor_beta2.DATA_PTR(), - per_tensor_bias_correction.DATA_PTR(), - per_tensor_eps.DATA_PTR(), - per_tensor_weight_decay.DATA_PTR(), - lr, - grad_scale, - step, - (adamMode_t) mode); - ); - } else { - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g - using accscalar_t = at::acc_type; - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - DistAdamFunctor<4, accscalar_t, scalar_t_0>(), - per_tensor_beta1.DATA_PTR(), - per_tensor_beta2.DATA_PTR(), - per_tensor_bias_correction.DATA_PTR(), - per_tensor_eps.DATA_PTR(), - per_tensor_weight_decay.DATA_PTR(), + DistAdamFunctor(), + grad_scale.DATA_PTR(), + beta1, + beta2, + beta1_correction, + beta2_correction, + eps, lr, - grad_scale, - step, - (adamMode_t) mode); - ); + (adamMode_t) mode, + weight_decay); + )); + C10_CUDA_CHECK(cudaGetLastError()); +} + +void multi_tensor_fused_adam_with_param_remainders_cuda( + int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, // p_in, p_rem, m, v, g, p_out + at::Tensor grad_scale, + float lr, + float beta1, + float beta2, + float eps, + int step, + int mode, + int bias_correction, + float weight_decay) +{ + using namespace at; + + // Expect p_in, p_rem, m, v, g, p_out + size_t tl_sz = tensor_lists.size(); + AT_ASSERTM(tl_sz == 6, "expected tensor lists of size 6"); + + float beta1_correction = 1.0f, beta2_correction = 1.0f; + if (bias_correction == 1) { + beta1_correction = 1 - std::pow(beta1, step); + beta2_correction = 1 - std::pow(beta2, step); } + + multi_tensor_apply<6>( + BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + DistAdamWithParamRemaindersFunctor(), + grad_scale.DATA_PTR(), + beta1, + beta2, + beta1_correction, + beta2_correction, + eps, + lr, + (adamMode_t) mode, + weight_decay); C10_CUDA_CHECK(cudaGetLastError()); } diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py index 550068022..4a49b5d8a 100644 --- a/apex/contrib/optimizers/distributed_fused_adam.py +++ b/apex/contrib/optimizers/distributed_fused_adam.py @@ -1,26 +1,83 @@ import collections import contextlib import enum -import importlib import inspect import io -import math +import itertools import threading import torch -import amp_C -from apex.multi_tensor_apply import multi_tensor_applier from torch.distributed.distributed_c10d import _get_default_group, _get_global_rank +from apex.multi_tensor_apply import multi_tensor_applier +import amp_C +import distributed_adam_cuda + +_FOUND_DEPRECATED_FUSED_ADAM = False +try: + import fused_adam_cuda + _FOUND_DEPRECATED_FUSED_ADAM = True +except ImportError: + import warnings + warnings.warn( + 'Could not find recommended CUDA kernels when importing ' + '`DistributedFusedAdam`. ' + 'For best performance, Apex should be installed with ' + '`--deprecated_fused_adam`.' + ) def _round_to_multiple(number, multiple, round_up=True): """Assumes arguments are positive integers""" return (number+multiple-1 if round_up else number) // multiple * multiple +def _multi_tensor_copy(buffers_in, buffers_out): + """Copy between corresponding buffers + + Uses fused copy kernel if possible. + """ + + # Group buffers by device and dtype + buffer_groups = collections.defaultdict(list) + for buf_in, buf_out in zip(buffers_in, buffers_out): + if buf_in.data_ptr() == buf_out.data_ptr(): + # Nothing to be done if input and output buffers are same + continue + if buf_in.dtype == buf_out.dtype: + # Just copy bytes if dtypes are same + buf_in = buf_in.view(torch.uint8) + buf_out = buf_out.view(torch.uint8) + key = (buf_in.is_cuda, buf_in.dtype, buf_out.is_cuda, buf_out.dtype) + buffer_groups[key].append((buf_in, buf_out)) + + # Copy each group of buffers + for key, buffers in buffer_groups.items(): + + # Check if buffers support fused kernel + is_cuda_in, dtype_in, is_cuda_out, dtype_out = key + supported_dtypes = (torch.float32, torch.float16) + use_fused_kernel = ( + (dtype_in in supported_dtypes and dtype_out in supported_dtypes) + or + (dtype_in == torch.uint8 and dtype_out == torch.uint8) + ) + use_fused_kernel = use_fused_kernel and is_cuda_in and is_cuda_out + + # Copy buffers + if use_fused_kernel and _FOUND_DEPRECATED_FUSED_ADAM: + dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda') + multi_tensor_applier( + fused_adam_cuda.maybe_cast_mt, + dummy_overflow_buf, + list(zip(*buffers)), + ) + else: + for buf_in, buf_out in buffers: + buf_out.copy_(buf_in) + class DistributedFusedAdam(torch.optim.Optimizer): - """AdamW optimizer with ZeRO algorithm. + """Adam optimizer with ZeRO algorithm. Currently GPU-only. Requires Apex to be installed via - ``python setup.py install --cuda_ext --cpp_ext``. + ``python setup.py install --cuda_ext --cpp_ext --distributed_adam --deprecated_fused_adam``. This implements the ZeRO-2 algorithm, which distributes the optimizer state and gradients between parallel processes. In @@ -38,11 +95,16 @@ class DistributedFusedAdam(torch.optim.Optimizer): params (iterable): iterable of parameters to optimize or dicts defining parameter groups. lr (float, optional): learning rate. (default: 1e-3) + bias_correction (bool, optional): apply correction factor to + moment estimates. (default: True) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve numerical stability. (default: 1e-8) + adam_w_mode (boolean, optional): Decouple weight decay + regularization (also known as AdamW algorithm) (default: + True) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) amsgrad (boolean, optional): whether to use the AMSGrad @@ -85,6 +147,13 @@ class DistributedFusedAdam(torch.optim.Optimizer): externally (see grad_buffer_view function). It also maximizes memory usage and may prevent overlapping communication and compute. + store_params (bool, optional): store a distributed copy of the + parameters as optimizer state (default: True). This may be + desirable if the optimizer dtype has higher precision than + the parameter dtype. + store_param_remainders (bool, optional): if model is BF16 and + optimizer is FP32, store bits required to reconstruct FP32 + params (default: False). This is an experimental feature. .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 @@ -135,16 +204,37 @@ def __init__( self.shard_param_range = shard_param_range class StateBucket: - def __init__(self, shard_size, dtype, device): + def __init__( + self, + shard_size, + dtype, + device, + store_params=False, + store_param_remainders=False, + ): """Optimizer state for a bucket""" # Buffer ranges corresponding to parameter fragments self.fragments = [] # Local shard of parameters - self.params_shard = torch.zeros([shard_size], dtype=dtype, device=device) + self.params_shard = None + if store_params: + self.params_shard = torch.zeros( + [shard_size], dtype=dtype, device=device, + ) + # Local shard of parameter remainders + self.param_remainders_shard = None + if store_param_remainders: + self.param_remainders_shard = torch.zeros( + [shard_size], dtype=torch.int16, device=device, + ) # Local shard of first moment estimate - self.exp_avg_shard = torch.zeros([shard_size], dtype=dtype, device=device) + self.exp_avg_shard = torch.zeros( + [shard_size], dtype=dtype, device=device, + ) # Local shard of second moment estimate - self.exp_avg_sq_shard = torch.zeros([shard_size], dtype=dtype, device=device) + self.exp_avg_sq_shard = torch.zeros( + [shard_size], dtype=dtype, device=device, + ) class GradientStatus(enum.Enum): """Status of gradients within a bucket""" @@ -185,6 +275,7 @@ def __init__(self, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, + adam_w_mode=True, weight_decay=0., amsgrad=False, dtype=torch.float32, @@ -199,12 +290,15 @@ def __init__(self, bucket_cap_mb=100, pipeline_size=2, contiguous_grad_buffer=False, + store_params=True, + store_param_remainders=False, ): defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay) super(DistributedFusedAdam, self).__init__(params, defaults) # Adam options + self.adam_w_mode = adam_w_mode if amsgrad: raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.') @@ -213,22 +307,34 @@ def __init__(self, grad_sync_dtype = dtype if param_sync_dtype is None: param_sync_dtype = dtype - supported_dtypes = [ - (torch.float32, torch.float16), - (torch.float32, torch.float32), - ] - if (dtype, grad_sync_dtype) not in supported_dtypes: + supported_dtypes = (torch.float32, torch.float16, torch.bfloat16) + if (dtype not in supported_dtypes + or grad_sync_dtype not in supported_dtypes + or param_sync_dtype not in supported_dtypes): raise RuntimeError( - 'Invalid dtypes for DistributedFusedAdam ' + 'Unsupported dtypes for DistributedFusedAdam ' f'(dtype={dtype}, ' f'grad_sync_dtype={grad_sync_dtype}, ' - f'param_sync_dtype={param_sync_dtype}))') - if device != 'cuda': - raise RuntimeError('DistributedFusedAdam only supports GPU') + f'param_sync_dtype={param_sync_dtype}))' + ) + if grad_sync_dtype != dtype: + raise RuntimeError( + 'DistributedFusedAdam requires dtype to match grad dtype ' + f'(dtype={dtype}, grad_sync_dtype={grad_sync_dtype})' + ) self.dtype = dtype self.grad_sync_dtype = grad_sync_dtype self.param_sync_dtype = param_sync_dtype - self.device = device + + # Device options + device = torch.device(device) + if (device.type != 'cuda' + or device.index not in (None, torch.cuda.current_device())): + raise RuntimeError( + 'Invalid device for DistributedFusedAdam ' + f'(device={device})' + ) + self.device = torch.device('cuda', torch.cuda.current_device()) # Process groups self.process_group = ( @@ -276,6 +382,27 @@ def __init__(self, # Allocate contiguous buffer for gradients self.contiguous_grad_buffer = contiguous_grad_buffer + # Store params or param remainders + if store_param_remainders: + if store_params: + raise RuntimeError( + 'Attempted to construct DistributedFusedAdam ' + 'with store_params=True and store_param_remainders=True' + ) + if (self.dtype != torch.float32 + or self.grad_sync_dtype != torch.float32 + or self.param_sync_dtype != torch.bfloat16): + raise RuntimeError( + 'DistributedFusedAdam requires ' + 'BF16 params and FP32 optimizer state ' + 'when storing parameter remainders ' + f'(dtype={self.dtype}, ' + f'grad_sync_dtype={self.grad_sync_dtype}, ' + f'param_sync_dtype={self.param_sync_dtype}))' + ) + self.store_params = store_params + self.store_param_remainders = store_param_remainders + # Determine bucket sizes dtype_size = torch.finfo(self.grad_sync_dtype).bits // 8 self.alignment = 128 // dtype_size @@ -287,11 +414,6 @@ def __init__(self, self.bucket_size = bucket_size self.shard_size = shard_size - # Load CUDA kernels - global fused_adam_cuda, distributed_adam_cuda - fused_adam_cuda = importlib.import_module("fused_adam_cuda") - distributed_adam_cuda = importlib.import_module("distributed_adam_cuda") - # Optimizer state self.state['buckets'] = [] self.state['step'] = 0 @@ -301,9 +423,9 @@ def __init__(self, self._grads_generated = set() self._pipeline_streams = [torch.cuda.Stream() for _ in range(self.pipeline_size)] - # Divide gradients by factor before optimizer step. Used for - # grad clipping and gradient scaler. - self._inv_grad_scale = torch.full([1], 1.0, dtype=self.dtype, device=self.device) + # Scale by factor before optimizer step. Used for grad + # clipping and gradient scaler. + self._grad_scale = torch.full([], 1.0, dtype=self.dtype, device=self.device) # Norm of parameter gradients. Used for gradient clipping and # gradient scaler. self._grad_norm = None @@ -378,6 +500,12 @@ def reduction_hook(*unused): device=self.device, ) + def parameters(self): + """Returns an iterator over optimizer parameters""" + return itertools.chain.from_iterable( + group['params'] for group in self.param_groups + ) + def init_params(self, params=None): """Initialize optimizer state for parameters @@ -388,12 +516,10 @@ def init_params(self, params=None): """ # Default cases - if isinstance(params, torch.Tensor): + if params is None: + params = self.parameters() + elif isinstance(params, torch.Tensor): params = [params] - elif params is None: - params = [] - for group in self.param_groups: - params.extend(group['params']) # Get indices corresponding to parameters id_map = dict() @@ -418,7 +544,13 @@ def _init_param_state( # Make sure there is at least one bucket if not self.state['buckets']: self.state['buckets'].append( - self.StateBucket(self.shard_size, self.dtype, self.device) + self.StateBucket( + self.shard_size, + self.dtype, + self.device, + store_params=self.store_params, + store_param_remainders=self.store_param_remainders, + ) ) # Split parameter values into fragments @@ -446,7 +578,13 @@ def _init_param_state( # Create new bucket if current one is full if fragment_size <= 0: self.state['buckets'].append( - self.StateBucket(self.shard_size, self.dtype, self.device) + self.StateBucket( + self.shard_size, + self.dtype, + self.device, + store_params=self.store_params, + store_param_remainders=self.store_param_remainders, + ) ) continue @@ -482,15 +620,16 @@ def _init_param_state( bucket.fragments.append(fragment) param_start = param_end - # Initialize master param buffer - for fragment in self.state[param]['fragments']: - if fragment.in_local_shard: - bucket = self.state['buckets'][fragment.bucket_id] - param_start, param_end = fragment.shard_param_range - shard_start, shard_end = fragment.shard_range - model_param_fragment = param.view(-1)[param_start:param_end] - master_param_fragment = bucket.params_shard[shard_start:shard_end] - master_param_fragment.copy_(model_param_fragment) + # Initialize main param buffer + if self.store_params: + for fragment in self.state[param]['fragments']: + if fragment.in_local_shard: + bucket = self.state['buckets'][fragment.bucket_id] + param_start, param_end = fragment.shard_param_range + shard_start, shard_end = fragment.shard_range + model_param_fragment = param.detach().view(-1)[param_start:param_end] + main_param_fragment = bucket.params_shard[shard_start:shard_end] + main_param_fragment.copy_(model_param_fragment) def zero_grad(self, set_to_none=True): """Clear parameter gradients""" @@ -508,16 +647,15 @@ def zero_grad(self, set_to_none=True): bucket.grads_bucket = self._grad_buffer[bucket_start:bucket_end] # Reset param grads - for group in self.param_groups: - for param in group['params']: - if param.grad is None or set_to_none: - param.grad = None - else: - param.grad.zero_() + for param in self.parameters(): + if param.grad is None or set_to_none: + param.grad = None + else: + param.grad.zero_() # Reset other state self._grads_generated = set() - self._inv_grad_scale = torch.full([1], 1.0, dtype=self.dtype, device=self.device) + self._grad_scale = torch.full([], 1.0, dtype=self.dtype, device=self.device) self._grad_norm = None def _grad_copy(self, param): @@ -686,60 +824,49 @@ def _start_bucket_grad_sync(self, buckets): else: reduce_op = torch.distributed.ReduceOp.SUM - # Reduce gradients - main_stream = torch.cuda.current_stream() - for stream in self._pipeline_streams: - stream.wait_stream(main_stream) + # Reduce-scatter over distributed process group for i, bucket in enumerate(buckets): bucket.status = self.GradientStatus.SYNCING - stream = self._pipeline_streams[i % self.pipeline_size] - with torch.cuda.stream(stream): - - # Reduce-scatter over distributed process group - bucket.sync_wait() - if self.distributed_size == 1: - bucket.sync_grads_shard = bucket.grads_bucket + bucket.sync_wait() + if self.distributed_size == 1: + bucket.sync_grads_shard = bucket.grads_bucket + else: + bucket.sync_grads_shard = torch.zeros( + [self.shard_size], + dtype=self.grad_sync_dtype, + device=self.device, + ) + grads_bucket_shards = [ + bucket.grads_bucket[i*self.shard_size:(i+1)*self.shard_size] + for i in range(self.distributed_size) + ] + if self._reduce_scatter_no_copy: + no_copy_kwarg = { 'no_copy': True } else: - with torch.cuda.stream(main_stream): - bucket.sync_grads_shard = torch.zeros( - [self.shard_size], - dtype=self.grad_sync_dtype, - device=self.device, - ) - grads_bucket_shards = [ - bucket.grads_bucket[i*self.shard_size:(i+1)*self.shard_size] - for i in range(self.distributed_size) - ] - if self._reduce_scatter_no_copy: - no_copy_kwarg = { 'no_copy': True } - else: - no_copy_kwarg = {} - bucket.sync_request = ( - torch.distributed.reduce_scatter( - bucket.sync_grads_shard, - grads_bucket_shards, - op=reduce_op, - group=self.distributed_process_group, - async_op=True, - **no_copy_kwarg, - ) + no_copy_kwarg = {} + bucket.sync_request = ( + torch.distributed.reduce_scatter( + bucket.sync_grads_shard, + grads_bucket_shards, + op=reduce_op, + group=self.distributed_process_group, + async_op=True, + **no_copy_kwarg, ) + ) - # All-reduce over redundant process group - # Note: Assuming reduce-scatters are finished in the - # order they are submitted, all-reduces should be - # submitted in a consistent order. There could be race - # conditions if wait doesn't finish in order. - if self.redundant_size > 1: - bucket.sync_wait() - bucket.sync_request = ( - torch.distributed.all_reduce( - bucket.sync_grads_shard, - op=reduce_op, - group=self.redundant_process_group, - async_op=True, - ) + # All-reduce over redundant process group + if self.redundant_size > 1: + for i, bucket in enumerate(buckets): + bucket.sync_wait() + bucket.sync_request = ( + torch.distributed.all_reduce( + bucket.sync_grads_shard, + op=reduce_op, + group=self.redundant_process_group, + async_op=True, ) + ) def _finish_bucket_grad_sync(self): """Wait for any gradient synchronizations that are in progress""" @@ -823,7 +950,7 @@ def _local_grad_norm(self, parameters=[], norm_type=2.0): if not parameters or len(parameters) == self._num_grads: # Compute norm of all local gradients - dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda') + dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=self.device) grad_norm_sq = multi_tensor_applier( amp_C.multi_tensor_l2norm, dummy_overflow_buf, @@ -840,7 +967,7 @@ def _local_grad_norm(self, parameters=[], norm_type=2.0): shard_start, shard_end = fragment.shard_range grads.append(bucket.grads_shard[shard_start:shard_end]) if grads: - dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda') + dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=self.device) grad_norm_sq = multi_tensor_applier( amp_C.multi_tensor_l2norm, dummy_overflow_buf, @@ -848,9 +975,12 @@ def _local_grad_norm(self, parameters=[], norm_type=2.0): False, )[0] ** 2 else: - grad_norm_sq = torch.zeros([1], dtype=torch.float32, device=self.device) + grad_norm_sq = torch.zeros([1], dtype=self.dtype, device=self.device) - return grad_norm_sq.detach().view([]) + grad_norm_sq = grad_norm_sq.detach() + grad_norm_sq = grad_norm_sq.to(dtype=self.dtype, device=self.device) + grad_norm_sq = grad_norm_sq.view([]) + return grad_norm_sq def grad_norm(self, parameters=[], norm_type=2.0, force=False): """Gradient norm of parameters in optimizer @@ -907,8 +1037,8 @@ def clip_grad_norm(self, max_norm, parameters=[], norm_type=2.0): """ assert max_norm > 0 total_norm = self.grad_norm(parameters=parameters, norm_type=norm_type) - inv_clip_coef = (total_norm + 1e-6) / max_norm - self._inv_grad_scale = torch.clamp(inv_clip_coef, min=1.0).view(1) + clip_coef = max_norm / (total_norm + 1e-6) + self._grad_scale = torch.minimum(self._grad_scale, clip_coef) return total_norm def step(self, closure=None, *, grad_scaler=None): @@ -945,8 +1075,8 @@ def step(self, closure=None, *, grad_scaler=None): return else: assert grad_scaler._scale is not None - self._inv_grad_scale *= grad_scaler._scale - inv_grad_scale = self._inv_grad_scale.item() + self._grad_scale /= grad_scaler._scale.view([]) + self._grad_scale = self._grad_scale.to(dtype=torch.float32, device=self.device) # Construct workspace buffers params_bucket_buffers = [ @@ -957,22 +1087,6 @@ def step(self, closure=None, *, grad_scaler=None): ) for _ in range(self.pipeline_size) ] - if self.grad_sync_dtype == self.param_sync_dtype: - shard_start = self.distributed_rank * self.shard_size - shard_end = shard_start + self.shard_size - params_copy_buffers = [ - params_bucket[shard_start:shard_end] - for params_bucket in params_bucket_buffers - ] - else: - params_copy_buffers = [ - torch.empty( - [self.shard_size], - dtype=self.grad_sync_dtype, - device=self.device, - ) - for _ in range(self.pipeline_size) - ] # Apply optimizer step to each bucket and synchronize params self.state['step'] += 1 @@ -981,138 +1095,61 @@ def step(self, closure=None, *, grad_scaler=None): stream.wait_stream(main_stream) for bucket_id in range(len(self.state['buckets'])): stream_id = bucket_id % self.pipeline_size - - # Bucket buffers - fragments = self.state['buckets'][bucket_id].fragments - shard_start = self.distributed_rank * self.shard_size - shard_end = shard_start + self.shard_size - params_bucket = params_bucket_buffers[stream_id] - params_bucket_shard = params_bucket[shard_start:shard_end] - params_shard = self.state['buckets'][bucket_id].params_shard - params_copy = params_copy_buffers[stream_id] - exp_avg = self.state['buckets'][bucket_id].exp_avg_shard - exp_avg_sq = self.state['buckets'][bucket_id].exp_avg_sq_shard - grads = self._grads_buckets[bucket_id].grads_shard - - # Perform compute on parallel stream stream = self._pipeline_streams[stream_id] with torch.cuda.stream(stream): - # Find param fragments in local shard - buffers = collections.defaultdict(list) # p, m, v, g, p_copy - for fragment in fragments: - if fragment.in_local_shard: - param_group_id = fragment.param_group_id - shard_start, shard_end = fragment.shard_range - buffers[param_group_id].append([ - params_shard[shard_start:shard_end], - exp_avg[shard_start:shard_end], - exp_avg_sq[shard_start:shard_end], - grads[shard_start:shard_end], - params_copy[shard_start:shard_end], - ]) - - # Fuse param fragments if possible - if len(buffers) == 1: - group_id = list(buffers.keys())[0] - buffers[group_id] = [( - params_shard, - exp_avg, - exp_avg_sq, - grads, - params_copy, - )] - - # Apply optimizer step to each param group - for group_id, group_buffers in buffers.items(): - - # Get param group configs - group = self.param_groups[group_id] - beta1, beta2 = group['betas'] - bias_correction = 1 if group['bias_correction'] else 0 - eps = group['eps'] - weight_decay = group['weight_decay'] - - # Copy param group configs to GPU - num_fragments = len(group_buffers) - beta1 = torch.full([num_fragments], beta1, dtype=self.dtype, device='cuda') - beta2 = torch.full([num_fragments], beta2, dtype=self.dtype, device='cuda') - bias_correction = torch.full([num_fragments], bias_correction, dtype=torch.int32, device='cuda') - eps = torch.full([num_fragments], eps, dtype=self.dtype, device='cuda') - weight_decay = torch.full([num_fragments], weight_decay, dtype=self.dtype, device='cuda') - - # Apply Adam step - dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda') - multi_tensor_applier( - distributed_adam_cuda.multi_tensor_fused_adam, - dummy_overflow_buf, - list(zip(*group_buffers)), - beta1, - beta2, - bias_correction, - eps, - weight_decay, - group['lr'], - inv_grad_scale, - self.state['step'], - 1, # Set to 0 to apply eps inside sqrt - ) + # Buffers for param sync + params_bucket = params_bucket_buffers[stream_id] + params_bucket_shards = [ + params_bucket[i*self.shard_size:(i+1)*self.shard_size] + for i in range(self.distributed_size) + ] - # Cast parameter dtype if needed - if params_copy.data_ptr() != params_bucket_shard.data_ptr(): - params_bucket_shard.copy_(params_copy) + # Apply optimizer step to local shard + if self.store_param_remainders: + self._local_step_with_param_remainders( + bucket_id, + params_bucket_shards[self.distributed_rank], + ) + else: + self._local_step( + bucket_id, + params_bucket_shards[self.distributed_rank], + ) - # Allgather updated parameters + # All-gather updated parameters + # Note: All-gather seems to allocate memory + # internally, which can cause significant memory pool + # overheads when called in side streams. Avoid this by + # only calling in main stream. if self.distributed_size > 1: - all_params_bucket_shards = [ - params_bucket[i*self.shard_size:(i+1)*self.shard_size] - for i in range(self.distributed_size) - ] if self._all_gather_no_copy: no_copy_kwarg = { 'no_copy': True } else: no_copy_kwarg = {} - torch.distributed.all_gather( - all_params_bucket_shards, - params_bucket_shard, - group=self.distributed_process_group, - **no_copy_kwarg, - ) + main_stream.wait_stream(stream) + with torch.cuda.stream(main_stream): + torch.distributed.all_gather( + params_bucket_shards, + params_bucket_shards[self.distributed_rank], + group=self.distributed_process_group, + **no_copy_kwarg, + ) + stream.wait_stream(main_stream) # Copy values to param buffers - buffers = collections.defaultdict(list) # param_in, param_out + params_in = [] + params_out = [] + fragments = self.state['buckets'][bucket_id].fragments for fragment in fragments: param_group_id = fragment.param_group_id param_id = fragment.param_id param = self.param_groups[param_group_id]['params'][param_id] bucket_start, bucket_end = fragment.bucket_range param_start, param_end = fragment.param_range - param_in = params_bucket[bucket_start:bucket_end] - param_out = param.detach().view(-1)[param_start:param_end] - if param_in.dtype == param_out.dtype: - # Just copy bytes if buffers have same type - param_in = param_in.view(torch.uint8) - param_out = param_out.view(torch.uint8) - buffers[(param.is_cuda, param.dtype)].append( - (param_in, param_out) - ) - for (is_cuda, dtype), dtype_buffers in buffers.items(): - fused_kernel_dtypes = ( - self.param_sync_dtype, - torch.float32, - torch.float16, - torch.uint8, - ) - if is_cuda and dtype in fused_kernel_dtypes: - dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda') - multi_tensor_applier( - fused_adam_cuda.maybe_cast_mt, - dummy_overflow_buf, - list(zip(*dtype_buffers)), - ) - else: - for param_in, param_out in dtype_buffers: - param_out.copy_(param_in) + params_in.append(params_bucket[bucket_start:bucket_end]) + params_out.append(param.detach().view(-1)[param_start:param_end]) + _multi_tensor_copy(params_in, params_out) # Synchronize pipeline streams for stream in self._pipeline_streams: @@ -1120,6 +1157,114 @@ def step(self, closure=None, *, grad_scaler=None): return loss + def _local_step(self, bucket_id, params_out): + """Apply optimizer step to local shard of parameter bucket""" + + # Optimizer state buffers for local shard + fragments = self.state['buckets'][bucket_id].fragments + exp_avg = self.state['buckets'][bucket_id].exp_avg_shard + exp_avg_sq = self.state['buckets'][bucket_id].exp_avg_sq_shard + grads = self._grads_buckets[bucket_id].grads_shard + + # Find param fragments in local shard + buffers = collections.defaultdict(list) # p_in, m, v, g, p_out + for fragment in fragments: + if fragment.in_local_shard: + param_group_id = fragment.param_group_id + shard_start, shard_end = fragment.shard_range + if self.store_params: + params_shard = self.state['buckets'][bucket_id].params_shard + param_fragment = params_shard[shard_start:shard_end] + else: + param_id = fragment.param_id + param = self.param_groups[param_group_id]['params'][param_id] + param_start, param_end = fragment.shard_param_range + param_fragment = param.detach().view(-1)[param_start:param_end] + param_fragment = param_fragment.to(dtype=self.dtype, device=self.device) + buffers[param_group_id].append([ + param_fragment, + exp_avg[shard_start:shard_end], + exp_avg_sq[shard_start:shard_end], + grads[shard_start:shard_end], + params_out[shard_start:shard_end], + ]) + + # Apply optimizer step to each param group + for group_id, group_buffers in buffers.items(): + group = self.param_groups[group_id] + beta1, beta2 = group['betas'] + dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=self.device) + multi_tensor_applier( + distributed_adam_cuda.multi_tensor_fused_adam, + dummy_overflow_buf, + list(zip(*group_buffers)), + self._grad_scale, + group['lr'], + beta1, + beta2, + group['eps'], + self.state['step'], + 1 if self.adam_w_mode else 0, + 1 if group['bias_correction'] else 0, + group['weight_decay'], + ) + + def _local_step_with_param_remainders(self, bucket_id, params_out): + """Apply optimizer step to local shard of parameter bucket + + This is an experimental implementation that expects + store_params=False and store_param_remainders=True. The + optimizer dtype must be FP32 and the params must all be BF16 + and GPU. + """ + + # State buffers for local shard + fragments = self.state['buckets'][bucket_id].fragments + param_remainders_shard = self.state['buckets'][bucket_id].param_remainders_shard + exp_avg = self.state['buckets'][bucket_id].exp_avg_shard + exp_avg_sq = self.state['buckets'][bucket_id].exp_avg_sq_shard + grads = self._grads_buckets[bucket_id].grads_shard + + # Find param fragments in local shard + buffers = collections.defaultdict(list) # p_in, p_rem, m, v, g, p_out + for fragment in fragments: + if fragment.in_local_shard: + param_group_id = fragment.param_group_id + param_id = fragment.param_id + param_start, param_end = fragment.shard_param_range + shard_start, shard_end = fragment.shard_range + param = self.param_groups[param_group_id]['params'][param_id] + param_fragment = param.detach().view(-1)[param_start:param_end] + param_fragment = param_fragment.to(dtype=torch.bfloat16, device=self.device) + buffers[param_group_id].append([ + param_fragment, + param_remainders_shard[shard_start:shard_end], + exp_avg[shard_start:shard_end], + exp_avg_sq[shard_start:shard_end], + grads[shard_start:shard_end], + params_out[shard_start:shard_end], + ]) + + # Apply optimizer step to each param group + for group_id, group_buffers in buffers.items(): + group = self.param_groups[group_id] + beta1, beta2 = group['betas'] + dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=self.device) + multi_tensor_applier( + distributed_adam_cuda.multi_tensor_fused_adam_with_param_remainders, + dummy_overflow_buf, + list(zip(*group_buffers)), + self._grad_scale, + group['lr'], + beta1, + beta2, + group['eps'], + self.state['step'], + 1 if self.adam_w_mode else 0, + 1 if group['bias_correction'] else 0, + group['weight_decay'], + ) + def state_dict(self, gather_on_root=True): """Get dictionary containing optimizer state @@ -1157,8 +1302,13 @@ def state_dict(self, gather_on_root=True): # Construct workspace buffers chunk_size = self.shard_size * torch.finfo(self.grad_sync_dtype).bits // 8 if self.distributed_rank == 0: - gathered_state_bytes = [state_bytes.getvalue()] - gathered_state_bytes.extend(bytearray(size) for size in state_sizes[1:]) + gathered_state_bytes = [ + torch.empty([size], dtype=torch.uint8, device='cpu') + for size in state_sizes + ] + gathered_state_bytes[0].copy_( + torch.frombuffer(state_bytes_view, dtype=torch.uint8) + ) gathered_chunks_buffers = [ torch.empty( [chunk_size * self.distributed_size], @@ -1238,17 +1388,13 @@ def state_dict(self, gather_on_root=True): # Copy back to CPU if self.distributed_rank == 0: for rank in range(1, self.distributed_size): - if offset < state_sizes[rank]: - rank_chunk_size = min(chunk_size, state_sizes[rank]-offset) - torch.frombuffer( - gathered_state_bytes[rank], - dtype=torch.uint8, - count=rank_chunk_size, - offset=offset, - ).copy_( - gathered_chunks[rank][:rank_chunk_size], - non_blocking=True, - ) + rank_chunk_start = offset + rank_chunk_end = min(offset + chunk_size, state_sizes[rank]) + rank_chunk_size = rank_chunk_end - rank_chunk_start + if rank_chunk_size > 0: + src = gathered_chunks[rank][:rank_chunk_size] + dst = gathered_state_bytes[rank][rank_chunk_start:rank_chunk_end] + dst.copy_(src, non_blocking=True) # Synchronize GPU for stream in self._pipeline_streams: @@ -1274,7 +1420,7 @@ def load_state_dict(self, state_dict): # Get state for current rank and parse byte string state_bytes = state_dict['gathered_states'][self.distributed_rank] - state_bytes = io.BytesIO(state_bytes) + state_bytes = io.BytesIO(state_bytes.numpy()) state_dict = torch.load(state_bytes) return super().load_state_dict(state_dict) diff --git a/apex/contrib/test/optimizers/test_dist_adam.py b/apex/contrib/test/optimizers/test_dist_adam.py index bd23ce2ae..875436e57 100644 --- a/apex/contrib/test/optimizers/test_dist_adam.py +++ b/apex/contrib/test/optimizers/test_dist_adam.py @@ -8,32 +8,34 @@ from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase class SimpleModel(torch.nn.Module): - def __init__(self, num_layers, size): super().__init__() - self.layers = torch.nn.ModuleList([ - torch.nn.Linear(size, size, bias=(i%3==0)) - for i in range(num_layers) + self.params = torch.nn.ParameterList([ + torch.nn.Parameter(torch.rand(1, size) + 1) + for _ in range(num_layers) ]) - def forward(self, x): y = 0 - for i, l in enumerate(self.layers): - y += (i+1) * l(x) + for i, param in enumerate(self.params): + y += (i+1) * param * x return y def make_models( num_layers, size, - dtype=torch.float32, + adam_w_mode=True, + model_dtype=torch.float32, + optim_dtype=None, param_sync_dtype=None, device='cuda', overlap_communication=True, + store_params=False, + store_param_remainders=False, ): # Construct models with same parameters - ref_model = SimpleModel(num_layers, size).to(dtype=dtype, device=device) - dist_model = SimpleModel(num_layers, size).to(dtype=dtype, device=device) + ref_model = SimpleModel(num_layers, size).to(dtype=model_dtype, device=device) + dist_model = SimpleModel(num_layers, size).to(dtype=model_dtype, device=device) with torch.no_grad(): for ref_param, dist_param in zip(dist_model.parameters(), ref_model.parameters()): @@ -48,8 +50,11 @@ def make_models( ) # Construct optimizers with same hyperparameters + if optim_dtype is None: + optim_dtype = model_dtype optim_args = dict(lr=0.1, betas=(0.1,0.2), eps=0.25, weight_decay=0.1) - ref_optim = torch.optim.AdamW( + ref_optim_class = torch.optim.AdamW if adam_w_mode else torch.optim.Adam + ref_optim = ref_optim_class( [ {'params': list(ref_model.parameters())[1::2], 'lr': 0.2}, {'params': list(ref_model.parameters())[0::2]}, @@ -61,10 +66,13 @@ def make_models( {'params': list(dist_model.parameters())[1::2], 'lr': 0.2}, {'params': list(dist_model.parameters())[0::2]}, ], + adam_w_mode=adam_w_mode, overlap_grad_sync=overlap_communication, bucket_cap_mb=71/(4*1024*1024), - dtype=torch.float32, + dtype=optim_dtype, param_sync_dtype=param_sync_dtype, + store_params=store_params, + store_param_remainders=store_param_remainders, **optim_args, ) @@ -83,18 +91,22 @@ class TestDistributedFusedAdam(NcclDistributedTestBase): def test_matches_pytorch( self, + rtol=None, + atol=None, num_layers=11, layer_size=7, batch_size=3, num_steps=3, micro_batch_steps=3, + adam_w_mode=True, overlap_communication=True, use_nosync=True, - dtype=torch.float32, + model_dtype=torch.float32, + optim_dtype=None, param_sync_dtype=None, device='cuda', - rtol=None, - atol=None, + store_params=False, + store_param_remainders=False, ): torch.manual_seed(self.seed + self.rank) @@ -103,10 +115,14 @@ def test_matches_pytorch( ref_model, ref_optim, dist_model, dist_optim = make_models( num_layers, layer_size, - dtype=dtype, + adam_w_mode=adam_w_mode, + model_dtype=model_dtype, + optim_dtype=optim_dtype, param_sync_dtype=param_sync_dtype, device=device, overlap_communication=overlap_communication, + store_params=store_params, + store_param_remainders=store_param_remainders, ) # Training loop @@ -122,8 +138,8 @@ def test_matches_pytorch( # Synthetic data x = torch.rand(batch_size, layer_size) - 0.5 dy = torch.rand_like(x) - 0.5 - x = x.to(dtype=dtype, device=device) - dy = dy.to(dtype=dtype, device=device) + x = x.to(dtype=model_dtype, device=device) + dy = dy.to(dtype=model_dtype, device=device) # Reference implementation x_ref = x.detach().clone().requires_grad_(True) @@ -155,6 +171,11 @@ def test_matches_pytorch( torch.testing.assert_close( dist_param, ref_param, rtol=rtol, atol=atol) + def test_matches_pytorch_l2_reg(self): + self.test_matches_pytorch( + adam_w_mode=False, + ) + def test_matches_pytorch_no_overlap(self): self.test_matches_pytorch( overlap_communication=False, @@ -166,24 +187,51 @@ def test_matches_pytorch_sync_every_step(self): def test_matches_pytorch_fp64(self): self.test_matches_pytorch( - dtype=torch.float64, rtol=1.3e-6, atol=1e-5, + model_dtype=torch.float64, + optim_dtype=torch.float32, ) def test_matches_pytorch_fp16(self): self.test_matches_pytorch( - dtype=torch.float16, - rtol=1e-2, - atol=1e-2, + rtol=5e-3, + atol=1e-5, + micro_batch_steps=1, + model_dtype=torch.float16, + optim_dtype=torch.float16, ) - def test_matches_pytorch_allgather_fp16(self): + def test_matches_pytorch_bf16(self): self.test_matches_pytorch( - dtype=torch.float32, + rtol=5e-2, + atol=1e-5, + micro_batch_steps=1, + model_dtype=torch.bfloat16, + optim_dtype=torch.bfloat16, + ) + + def test_matches_pytorch_fp16_params(self): + self.test_matches_pytorch( + rtol=5e-3, + atol=1e-5, + micro_batch_steps=1, + model_dtype=torch.float16, + optim_dtype=torch.float32, param_sync_dtype=torch.float16, - rtol=1e-2, - atol=1e-2, + store_params=True, + ) + + def test_matches_pytorch_bf16_param_remainders(self): + self.test_matches_pytorch( + rtol=5e-2, + atol=1e-5, + micro_batch_steps=1, + model_dtype=torch.bfloat16, + optim_dtype=torch.float32, + param_sync_dtype=torch.bfloat16, + store_params=False, + store_param_remainders=True, ) def test_raises_on_mismatch(self): @@ -200,9 +248,9 @@ def test_raises_on_mismatch(self): # Only perform training step with distributed model dist_optim.zero_grad() - x = torch.rand(3, layer_size) + 0.5 + x = torch.rand(3, layer_size) - 0.5 x = x.to(dtype=torch.float32, device='cuda') - dy = torch.rand_like(x) + 0.5 + dy = torch.rand_like(x) - 0.5 y = dist_model(x) y.backward(dy) dist_optim.step() @@ -227,8 +275,8 @@ def test_clip_grad_norm(self): xs = [3, 1, 4, 1, 5, 9] dys = [1, -1, 1, -1, 1, -1] for x, dy in zip(xs, dys): - x = torch.tensor([x], dtype=torch.float32, device='cuda') - dy = torch.tensor([dy], dtype=torch.float32, device='cuda') + x = torch.tensor([[x]], dtype=torch.float32, device='cuda') + dy = torch.tensor([[dy]], dtype=torch.float32, device='cuda') # Reference implementation ref_optim.zero_grad() @@ -269,8 +317,8 @@ def test_grad_scaler(self): xs = [3, 1, 4, 1, 5, 9] dys = [1, float('inf'), 1, 1, float('nan'), -1] for x, dy in zip(xs, dys): - x = torch.tensor([x], dtype=torch.float32, device='cuda') - dy = torch.tensor([dy], dtype=torch.float32, device='cuda') + x = torch.tensor([[x]], dtype=torch.float32, device='cuda') + dy = torch.tensor([[dy]], dtype=torch.float32, device='cuda') # Reference implementation ref_optim.zero_grad() @@ -337,7 +385,7 @@ def test_checkpoint(self): state_bytes = [None] torch.distributed.broadcast_object_list(state_bytes, src=0) state_bytes = io.BytesIO(state_bytes[0]) - state_dict = torch.load(state_bytes, map_location='cuda') + state_dict = torch.load(state_bytes) model_load.load_state_dict(state_dict['model']) optim_load.load_state_dict(state_dict['optim']) diff --git a/apex/contrib/test/run_rocm_extensions.py b/apex/contrib/test/run_rocm_extensions.py index e0f4e1f5b..1e9d29f1f 100644 --- a/apex/contrib/test/run_rocm_extensions.py +++ b/apex/contrib/test/run_rocm_extensions.py @@ -2,7 +2,7 @@ import sys -test_dirs = ["groupbn", "layer_norm", "multihead_attn", "transducer", "focal_loss", "index_mul_2d", "."] # "." for test_label_smoothing.py +test_dirs = ["groupbn", "layer_norm", "multihead_attn", "optimizers", "transducer", "focal_loss", "index_mul_2d", "."] # "." for test_label_smoothing.py ROCM_BLACKLIST = [ "layer_norm" ] diff --git a/csrc/multi_tensor_apply.cuh b/csrc/multi_tensor_apply.cuh index b6a9f17de..814b02c3b 100644 --- a/csrc/multi_tensor_apply.cuh +++ b/csrc/multi_tensor_apply.cuh @@ -13,8 +13,8 @@ // TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) -constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; -constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; +constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24}; +constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320}; template struct TensorListMetadata {