From bf9448073e7c6ccdee020a6a018282b084b6739c Mon Sep 17 00:00:00 2001 From: Juan del Carmen Grados Vasquez Date: Tue, 28 Apr 2026 05:38:43 +0000 Subject: [PATCH] FIX/refactor: centralize CP MiniZinc model assembly Co-authored-by: Copilot --- claasp/cipher_modules/models/cp/mzn_model.py | 62 ++++++++++++++----- .../models/cp/mzn_model_test.py | 36 ++++++++++- 2 files changed, 81 insertions(+), 17 deletions(-) diff --git a/claasp/cipher_modules/models/cp/mzn_model.py b/claasp/cipher_modules/models/cp/mzn_model.py index 26a7cc97..9647310f 100644 --- a/claasp/cipher_modules/models/cp/mzn_model.py +++ b/claasp/cipher_modules/models/cp/mzn_model.py @@ -21,6 +21,7 @@ import subprocess import time from copy import deepcopy +from dataclasses import dataclass, field from datetime import timedelta from minizinc import Instance, Model, Solver, Status @@ -52,6 +53,18 @@ CONSTRAINT_TYPE_ERROR = "Constraint type not defined" +@dataclass +class MiniZincModelParts: + prefix: list[str] = field(default_factory=list) + variables: list[str] = field(default_factory=list) + constraints: list[str] = field(default_factory=list) + outputs: list[str] = field(default_factory=list) + carries_outputs: list[str] = field(default_factory=list) + + def lines(self): + return self.prefix + self.variables + self.constraints + self.outputs + self.carries_outputs + + class MznModel: def __init__(self, cipher, sat_or_milp='sat'): @@ -106,6 +119,20 @@ def initialise_model(self): self.component_probability_var = {} self._model_prefix = ['include "globals.mzn";', f"{usefulfunctions.MINIZINC_USEFUL_FUNCTIONS}"] + def current_model_parts(self): + return MiniZincModelParts( + variables=list(self._variables_list), + constraints=list(self._model_constraints), + outputs=list(self.mzn_output_directives), + carries_outputs=list(self.mzn_carries_output_directives), + ) + + def assemble_model_lines(self, parts=None): + return (parts or self.current_model_parts()).lines() + + def assemble_model(self, parts=None): + return "\n".join(self.assemble_model_lines(parts)) + "\n" + def add_comment(self, comment): """ Write a 'comment' at the beginning of the model. @@ -813,8 +840,6 @@ def solve( ): truncated = True - mzn_model = self._variables_list + self._model_constraints - solutions = [] if solve_external: command = self.get_command_for_solver_process( @@ -824,7 +849,7 @@ def solve( timeout_in_seconds_, intermediate_solutions=intermediate_solutions_, ) - model = "\n".join(mzn_model) + model = self.assemble_model() start = time.time() solver_process = subprocess.run(command, input=model, capture_output=True, text=True) end = time.time() @@ -832,7 +857,7 @@ def solve( if solver_process.returncode >= 0: solver_output = solver_process.stdout.splitlines() else: - mzn_model_string = "\n".join(mzn_model) + mzn_model_string = self.assemble_model() solver_name_mzn = Solver.lookup(solver_name) bit_mzn_model = Model() bit_mzn_model.add_string(mzn_model_string) @@ -972,9 +997,12 @@ def solve_for_ARX( sage: result.statistics['nSolutions'] 1 """ - constraints = self._model_constraints - variables = self._variables_list - mzn_model_string = "\n".join(constraints) + "\n".join(variables) + mzn_model_string = self.assemble_model( + MiniZincModelParts( + variables=list(self._model_constraints), + constraints=list(self._variables_list), + ) + ) solver_name_mzn = Solver.lookup(solver_name) bit_mzn_model = Model() bit_mzn_model.add_string(mzn_model_string) @@ -1077,20 +1105,22 @@ def write_minizinc_model_to_file(self, file_path, prefix=""): - ``file_path`` -- **string**; the path of the file that will contain the model - ``prefix`` -- **str** (default: ``) """ - model_string = ( - "\n".join(self.mzn_comments) - + "\n".join(self._variables_list) - + "\n".join(self._model_constraints) - + "\n".join(self.mzn_output_directives) - + "\n".join(self.mzn_carries_output_directives) - ) if prefix == "": filename = f"{file_path}/{self.cipher_id}_mzn_{self.sat_or_milp}.mzn" else: filename = f"{file_path}/{prefix}_{self.cipher_id}_mzn_{self.sat_or_milp}.mzn" with open(filename, "w") as file: - file.write(model_string) + file.write( + self.assemble_model( + MiniZincModelParts( + variables=self.mzn_comments + list(self._variables_list), + constraints=list(self._model_constraints), + outputs=list(self.mzn_output_directives), + carries_outputs=list(self.mzn_carries_output_directives), + ) + ) + ) @property def cipher(self): @@ -1152,4 +1182,4 @@ def model_variables(self): """ if not self._variables_list: raise ValueError("No model generated") - return self._variables_list \ No newline at end of file + return self._variables_list diff --git a/tests/unit/cipher_modules/models/cp/mzn_model_test.py b/tests/unit/cipher_modules/models/cp/mzn_model_test.py index e053eb92..6bfc0087 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_model_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_model_test.py @@ -5,7 +5,7 @@ from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher from claasp.ciphers.block_ciphers.midori_block_cipher import MidoriBlockCipher import claasp.cipher_modules.models.cp.mzn_model as mzn_model_module -from claasp.cipher_modules.models.cp.mzn_model import MznModel +from claasp.cipher_modules.models.cp.mzn_model import MiniZincModelParts, MznModel from claasp.ciphers.block_ciphers.raiden_block_cipher import RaidenBlockCipher from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model import MznXorDifferentialModel from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model_arx_optimized import ( @@ -28,6 +28,40 @@ def total_seconds(self): return self._seconds +def test_assemble_model_preserves_legacy_order(): + speck = SpeckBlockCipher(number_of_rounds=1) + model = MznModel(speck) + model._variables_list = ["var int: x;"] + model._model_constraints = ["constraint x = 1;", "solve satisfy;"] + model.mzn_output_directives = ['output ["x=", show(x)];'] + + assert model.assemble_model() == ( + "var int: x;\n" + "constraint x = 1;\n" + "solve satisfy;\n" + 'output ["x=", show(x)];\n' + ) + + +def test_assemble_model_accepts_explicit_parts(): + speck = SpeckBlockCipher(number_of_rounds=1) + model = MznModel(speck) + parts = MiniZincModelParts( + prefix=['include "globals.mzn";'], + variables=["var int: x;"], + constraints=["constraint x = 1;", "solve satisfy;"], + outputs=['output ["x=", show(x)];'], + ) + + assert model.assemble_model(parts) == ( + 'include "globals.mzn";\n' + "var int: x;\n" + "constraint x = 1;\n" + "solve satisfy;\n" + 'output ["x=", show(x)];\n' + ) + + @pytest.mark.parametrize( "solver_stats,expected_time", [