Skip to content

Commit 70a5a53

Browse files
authored
Fix contiguous memory check for SGEMM and DGEMM. (#5815)
2 parents fda55ad + f3f718b commit 70a5a53

2 files changed

Lines changed: 3 additions & 38 deletions

File tree

kernel/riscv64/dgemm_kernel_8x8_zvl256b.c

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1628,11 +1628,12 @@ static void NM_TAIL(BLASLONG K, BLASLONG M, const BLASLONG m_edge, const BLASLON
16281628
}
16291629
}
16301630
}
1631+
16311632
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT* A, FLOAT* B, FLOAT* C, BLASLONG ldc)
16321633
{
16331634
if (K <= 0) return 0;
16341635
const BLASLONG m_edge = M & 7;
1635-
const bool S = (M == (ldc & 0x7));
1636+
const bool S = (ldc == m_edge);
16361637

16371638
// -- MAIN PASS
16381639

@@ -1689,24 +1690,6 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT* A, FLOAT* B, F
16891690
resultEF = __riscv_vfmacc_vf_f64m2( resultEF, B7, A00, 8 );
16901691
}
16911692

1692-
// LMUL = 2 does worst here
1693-
vfloat64m1_t result0 = __riscv_vget_v_f64m2_f64m1(result01, 0);
1694-
vfloat64m1_t result1 = __riscv_vget_v_f64m2_f64m1(result01, 1);
1695-
vfloat64m1_t result2 = __riscv_vget_v_f64m2_f64m1(result23, 0);
1696-
vfloat64m1_t result3 = __riscv_vget_v_f64m2_f64m1(result23, 1);
1697-
vfloat64m1_t result4 = __riscv_vget_v_f64m2_f64m1(result45, 0);
1698-
vfloat64m1_t result5 = __riscv_vget_v_f64m2_f64m1(result45, 1);
1699-
vfloat64m1_t result6 = __riscv_vget_v_f64m2_f64m1(result67, 0);
1700-
vfloat64m1_t result7 = __riscv_vget_v_f64m2_f64m1(result67, 1);
1701-
vfloat64m1_t result8 = __riscv_vget_v_f64m2_f64m1(result89, 0);
1702-
vfloat64m1_t result9 = __riscv_vget_v_f64m2_f64m1(result89, 1);
1703-
vfloat64m1_t result10 = __riscv_vget_v_f64m2_f64m1(resultAB, 0);
1704-
vfloat64m1_t result11 = __riscv_vget_v_f64m2_f64m1(resultAB, 1);
1705-
vfloat64m1_t result12 = __riscv_vget_v_f64m2_f64m1(resultCD, 0);
1706-
vfloat64m1_t result13 = __riscv_vget_v_f64m2_f64m1(resultCD, 1);
1707-
vfloat64m1_t result14 = __riscv_vget_v_f64m2_f64m1(resultEF, 0);
1708-
vfloat64m1_t result15 = __riscv_vget_v_f64m2_f64m1(resultEF, 1);
1709-
17101693
FLOAT *C2 = C;
17111694

17121695
vfloat64m2_t c01 = __riscv_vle64_v_f64m2(C, 8); C += ldc;

kernel/riscv64/sgemm_kernel_16x8_zvl256b.c

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2137,7 +2137,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT* A, FLOAT* B, F
21372137
{
21382138
if (K <= 0) return 0;
21392139
const BLASLONG m_edge = M & 15;
2140-
const bool S = (M == (ldc & 0xF));
2140+
const bool S = (ldc == m_edge);
21412141

21422142
// -- MAIN PASS
21432143

@@ -2194,24 +2194,6 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT* A, FLOAT* B, F
21942194
resultEF = __riscv_vfmacc_vf_f32m2( resultEF, B7, A00, 16 );
21952195
}
21962196

2197-
// LMUL = 2 does worst here
2198-
vfloat32m1_t result0 = __riscv_vget_v_f32m2_f32m1(result01, 0);
2199-
vfloat32m1_t result1 = __riscv_vget_v_f32m2_f32m1(result01, 1);
2200-
vfloat32m1_t result2 = __riscv_vget_v_f32m2_f32m1(result23, 0);
2201-
vfloat32m1_t result3 = __riscv_vget_v_f32m2_f32m1(result23, 1);
2202-
vfloat32m1_t result4 = __riscv_vget_v_f32m2_f32m1(result45, 0);
2203-
vfloat32m1_t result5 = __riscv_vget_v_f32m2_f32m1(result45, 1);
2204-
vfloat32m1_t result6 = __riscv_vget_v_f32m2_f32m1(result67, 0);
2205-
vfloat32m1_t result7 = __riscv_vget_v_f32m2_f32m1(result67, 1);
2206-
vfloat32m1_t result8 = __riscv_vget_v_f32m2_f32m1(result89, 0);
2207-
vfloat32m1_t result9 = __riscv_vget_v_f32m2_f32m1(result89, 1);
2208-
vfloat32m1_t result10 = __riscv_vget_v_f32m2_f32m1(resultAB, 0);
2209-
vfloat32m1_t result11 = __riscv_vget_v_f32m2_f32m1(resultAB, 1);
2210-
vfloat32m1_t result12 = __riscv_vget_v_f32m2_f32m1(resultCD, 0);
2211-
vfloat32m1_t result13 = __riscv_vget_v_f32m2_f32m1(resultCD, 1);
2212-
vfloat32m1_t result14 = __riscv_vget_v_f32m2_f32m1(resultEF, 0);
2213-
vfloat32m1_t result15 = __riscv_vget_v_f32m2_f32m1(resultEF, 1);
2214-
22152197
FLOAT *C2 = C;
22162198

22172199
vfloat32m2_t c01 = __riscv_vle32_v_f32m2(C, 16); C += ldc;

0 commit comments

Comments
 (0)