diff --git a/tests/cpu/st/testcase/trelu/gen_data.py b/tests/cpu/st/testcase/trelu/gen_data.py index ef55aa64..370de925 100644 --- a/tests/cpu/st/testcase/trelu/gen_data.py +++ b/tests/cpu/st/testcase/trelu/gen_data.py @@ -34,7 +34,8 @@ def gen_golden_data_trelu(case_name, param): class TReluParams: - def __init__(self, dtype, src_tile_row, src_tile_col, dst_tile_row, dst_tile_col, valid_row, valid_col): + def __init__(self, output_case_name, dtype, src_tile_row, src_tile_col, dst_tile_row, dst_tile_col, valid_row, valid_col): + self.output_case_name = output_case_name self.dtype = dtype self.src_tile_row = src_tile_row self.src_tile_col = src_tile_col @@ -64,24 +65,21 @@ def substring(a, b) -> str: script_dir = os.path.dirname(os.path.abspath(__file__)) case_params_list = [ - TReluParams(np.float32, 64, 64, 64, 64, 64, 64), - TReluParams(np.int32, 64, 64, 64, 64, 64, 64), - TReluParams(np.float16, 16, 256, 16, 256, 16, 256), - TReluParams(np.int16, 64, 64, 64, 64, 64, 64), - TReluParams(np.float32, 64, 64, 64, 64, 60, 55), - TReluParams(np.int32, 64, 64, 64, 64, 60, 55), - TReluParams(np.float16, 64, 64, 96, 96, 64, 60), - TReluParams(np.int16, 64, 64, 96, 96, 64, 60), + TReluParams("case_0", np.float32, 64, 64, 64, 64, 64, 64), + TReluParams("case_1", np.int32, 64, 64, 64, 64, 64, 64), + TReluParams("case_2", np.float16, 16, 256, 16, 256, 16, 256), + TReluParams("case_3", np.int16, 64, 64, 64, 64, 64, 64), + TReluParams("case_4", np.float32, 64, 64, 64, 64, 60, 55), + TReluParams("case_5", np.int32, 64, 64, 64, 64, 60, 55), + TReluParams("case_6", np.float16, 64, 64, 96, 96, 64, 60), + TReluParams("case_7", np.int16, 64, 64, 96, 96, 64, 60), ] if os.getenv("PTO_CPU_SIM_ENABLE_BF16") == "1": - case_params_list.append(TReluParams(BF16_DTYPE, 16, 256, 16, 256, 16, 256)) + case_params_list.append(TReluParams("case_bf16_16x256_16x256_16x256", BF16_DTYPE, 16, 256, 16, 256, 16, 256)) - for i, param in enumerate(case_params_list): + for param in case_params_list: case_name = generate_case_name(param) - if i < 8: - output_dir = os.path.join(script_dir, f"TRELUTest.case_{i}") - else: - output_dir = os.path.join(script_dir, case_name) + output_dir = os.path.join(script_dir, f"TRELUTest.{param.output_case_name}") os.makedirs(output_dir, exist_ok=True) original_dir = os.getcwd() os.chdir(output_dir)