diff --git a/src/array/cuda/gather_mm.cu b/src/array/cuda/gather_mm.cu index c40d53bb05ec..6f3c1e1c279b 100644 --- a/src/array/cuda/gather_mm.cu +++ b/src/array/cuda/gather_mm.cu @@ -113,7 +113,7 @@ __global__ void GatherMMScatterKernel( for (unsigned int outloop = 0; outloop < out_len; outloop += 32) { DType out_reg = static_cast<DType>(0.0f); // thread private const unsigned int l = laneId; - if (l < out_len) { + if (l + outloop < out_len) { // iterate over elements of a row of A for (unsigned int i = 0; i < a_tile; i++) { const DType a_val = sh_A[local_row * sh_a_tile + i]; @@ -170,7 +170,7 @@ __global__ void GatherMMScatterKernel2( for (unsigned int outloop = 0; outloop < out_len; outloop += 32) { DType out_reg = static_cast<DType>(0.0f); // thread private const unsigned int l = laneId; - if (l < out_len) { + if (l + outloop < out_len) { const DType b_val = B[row_b * out_len + (outloop + l)]; /* iterate over elements of a row of A */ for (unsigned int i = 0; i < a_tile; i++) {